Source code for heizer._source.consumer

import asyncio
import atexit
import functools
import os
import signal
from dataclasses import dataclass
from uuid import uuid4

import confluent_kafka as ck

from heizer._source import get_logger
from heizer._source.admin import create_new_topics
from heizer._source.enums import ConsumerStatusEnum
from heizer._source.message import Message
from heizer._source.producer import Producer
from heizer._source.status_manager import write_consumer_status
from heizer._source.topic import Topic
from heizer.config import ConsumerConfig
from heizer.types import (
    Any,
    Awaitable,
    Callable,
    Concatenate,
    Coroutine,
    KafkaConfig,
    List,
    Optional,
    ParamSpec,
    Stopper,
    TypeVar,
    Union,
)

R = TypeVar("R")
F = TypeVar("F", bound=Callable[..., Any])
P = ParamSpec("P")
T = TypeVar("T")

logger = get_logger(__name__)


def _make_consumer_config_dict(config: ConsumerConfig) -> KafkaConfig:
    config_dict = {
        "bootstrap.servers": config.bootstrap_servers,
        "group.id": config.group_id,
        "auto.offset.reset": config.auto_offset_reset,
        "enable.auto.commit": config.enable_auto_commit,
    }

    if config.other_configs:
        config_dict.update(config.other_configs)

    return config_dict
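
# Illustrative only: a hypothetical ConsumerConfig (the field values below are
# assumptions, not taken from this module) maps to a confluent-kafka config
# dict roughly like this:
#
#   ConsumerConfig(
#       bootstrap_servers="localhost:9092",
#       group_id="my-group",
#       auto_offset_reset="earliest",
#       enable_auto_commit=False,
#       other_configs={"session.timeout.ms": 6000},
#   )
#   -> {
#       "bootstrap.servers": "localhost:9092",
#       "group.id": "my-group",
#       "auto.offset.reset": "earliest",
#       "enable.auto.commit": False,
#       "session.timeout.ms": 6000,
#   }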


@dataclass
class ConsumerSignal:
    is_running: bool = True

    def stop(self) -> None:
        self.is_running = False
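
# ConsumerSignal is the cooperative stop flag checked by the poll loop in
# consumer._long_run below: calling stop() flips is_running, so the loop exits
# after the in-flight message has been handled. A minimal sketch (the variable
# name here is illustrative, not part of this module):
#
#   stop_signal = ConsumerSignal()
#   # pass it as consumer_signal=stop_signal when constructing a consumer,
#   # then from another thread or a callback:
#   stop_signal.stop()
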
class consumer(object):
    """
    A decorator to create a consumer
    """

    id: str
    name: Optional[str]
    topics: List[Topic]
    config: KafkaConfig
    call_once: bool = False
    stopper: Optional[Stopper] = None
    deserializer: Optional[Callable] = None
    consumer_signal: ConsumerSignal
    is_async: bool = False
    poll_timeout: int = 1
    init_topics: bool = False
    enable_retry: bool = False
    retry_topic: Optional[Topic] = None
    retry_times: int = 1
    retry_times_header_key: str = "!!-heizer_func_retried_times-!!"

    # private attrs
    _consumer_instance: ck.Consumer
    _producer_instance: Optional[Producer] = None  # used for retry

    def __init__(
        self,
        *,
        topics: List[Topic],
        config: Union[KafkaConfig, ConsumerConfig],
        call_once: bool = False,
        stopper: Optional[Stopper] = None,
        deserializer: Optional[Callable] = None,
        is_async: bool = False,
        id: Optional[str] = None,
        name: Optional[str] = None,
        poll_timeout: Optional[int] = None,
        init_topics: bool = False,
        consumer_signal: Optional[ConsumerSignal] = None,
        enable_retry: bool = False,
        retry_topic: Optional[Topic] = None,
        retry_times: int = 1,
    ):
        self.topics = topics
        self.config = _make_consumer_config_dict(config) if isinstance(config, ConsumerConfig) else config
        self.call_once = call_once
        self.stopper = stopper
        self.deserializer = deserializer
        self.id = id or str(uuid4())
        self.name = name or self.id
        self.is_async = is_async
        self.poll_timeout = poll_timeout if poll_timeout is not None else 1
        self.init_topics = init_topics
        self.consumer_signal = consumer_signal or ConsumerSignal()
        self.enable_retry = enable_retry
        self.retry_topic = retry_topic
        self.retry_times = retry_times

    def __repr__(self) -> str:
        return self.name or self.id

    @property
    def ck_consumer(self) -> ck.Consumer:
        return self._consumer_instance

    @property
    def ck_producer(self) -> Optional[ck.Producer]:
        return self._producer_instance
    def commit(self, asynchronous: bool = False, *arg, **kwargs):
        self.ck_consumer.commit(*arg, asynchronous=asynchronous, **kwargs)
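
    # Note on delivery semantics: _long_run below commits offsets in a `finally`
    # block after every polled message, whether or not the user function raised,
    # so failed messages are not redelivered by Kafka. When enable_retry is set,
    # failures are re-published to the retry topic instead (see
    # produce_retry_message below).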
    async def _long_run(self, f: Callable, is_async: bool, *args, **kwargs):
        target_topics: List[Topic] = [topic for topic in self.topics]

        if self.enable_retry and self.retry_topic:
            target_topics.append(self.retry_topic)

        logger.info(f"[{self}] Subscribing to topics {[topic.name for topic in target_topics]}")
        self.ck_consumer.subscribe([topic.name for topic in target_topics])

        result = None

        while self.consumer_signal.is_running:
            write_consumer_status(
                consumer_id=self.id, consumer_name=self.name, pid=os.getpid(), status=ConsumerStatusEnum.RUNNING
            )
            result = None
            msg = self.ck_consumer.poll(self.poll_timeout)

            if msg is None:
                continue

            if msg.error():
                logger.error(f"[{self}] Consumer error: {msg.error()}")
                continue

            logger.debug(f"[{self}] Received message")
            h_message = Message(msg)

            if self.deserializer is not None:
                logger.debug(f"[{self}] Parsing message")
                try:
                    h_message.formatted_value = self.deserializer(h_message.value)
                except Exception as e:
                    logger.exception(f"[{self}] Failed to deserialize message", exc_info=e)

            # If the heizer retry-count header exists, remove it before passing the message to the user func
            has_retired: Optional[int] = None
            if h_message.headers and self.retry_times_header_key in h_message.headers:
                has_retired = int(h_message.headers.pop(self.retry_times_header_key))

            logger.debug(f"[{self}] Executing function {f.__name__}")
            try:
                if is_async:
                    result = await f(h_message, self, *args, **kwargs)  # type: ignore
                else:
                    result = f(h_message, self, *args, **kwargs)
            except Exception as e:
                logger.exception(
                    f"[{self}] Failed to execute function {f.__name__}",
                    exc_info=e,
                )
                if self.enable_retry:
                    logger.info(f"[{self}] Start producing retry message for function {f.__name__}")
                    await self.produce_retry_message(message=h_message, has_retired=has_retired, func_name=f.__name__)
            finally:
                logger.debug(f"[{self}] Committing message")
                self.commit()

            if self.stopper is not None:
                logger.debug(f"[{self}] Executing stopper function")
                try:
                    should_stop = self.stopper(h_message, self)
                except Exception as e:
                    logger.warning(
                        f"[{self}] Failed to execute stopper function {self.stopper.__name__}.",
                        exc_info=e,
                    )
                    should_stop = False

                if should_stop:
                    self.consumer_signal.stop()
                    break

            if self.call_once is True:
                logger.debug(f"[{self}] Call Once is on, returning result")
                self.consumer_signal.stop()
                break

        return result

    async def _run(  # type: ignore
        self,
        func: Callable[Concatenate[Message, P], Union[T, Awaitable[T]]],
        is_async: bool,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> Optional[T]:
        """Run the consumer"""
        if self.init_topics:
            logger.info(f"[{self}] Initializing topics")
            create_new_topics({"bootstrap.servers": self.config["bootstrap.servers"]}, self.topics)

        logger.info(f"[{self}] Creating consumer")
        self._consumer_instance = ck.Consumer(self.config)

        if self.enable_retry:
            self.retry_topic = self.retry_topic or Topic(name=f"{func.__name__}_heizer_retry")
            self._producer_instance = Producer(config={"bootstrap.servers": self.config["bootstrap.servers"]})
            logger.info(f"[{self}] Creating retry topic {self.retry_topic.name}")
            create_new_topics({"bootstrap.servers": self.config["bootstrap.servers"]}, [self.retry_topic])

        atexit.register(self._atexit)
        signal.signal(signal.SIGTERM, self._exit)
        signal.signal(signal.SIGINT, self._exit)

        result = None

        try:
            result = await self._long_run(func, is_async, *args, **kwargs)
        except KeyboardInterrupt:
            logger.info(
                f"[{self}] Stopped because of keyboard interruption",
            )
        except Exception as e:
            logger.exception(f"[{self}] stopped", exc_info=e)
        finally:
            self.consumer_signal.stop()
            self.ck_consumer.close()
            write_consumer_status(
                consumer_id=self.id, consumer_name=self.name, pid=os.getpid(), status=ConsumerStatusEnum.CLOSED
            )
            logger.info(
                f"[{self}] Closed",
            )

        return result

    def __call__(
        self, func: Callable[Concatenate[Message, P], T]
    ) -> Callable[P, Union[Coroutine[Any, Any, Optional[T]], T, None]]:
        @functools.wraps(func)
        async def async_decorator(*args: P.args, **kwargs: P.kwargs) -> Optional[T]:
            """Async decorator"""
            logger.debug(f"[{self}] Running async decorator")
            return await self._run(func, self.is_async, *args, **kwargs)

        @functools.wraps(func)
        def decorator(*args: P.args, **kwargs: P.kwargs) -> Optional[T]:
            """Sync decorator"""
            logger.debug(f"[{self}] Running sync decorator")
            return asyncio.run(self._run(func, self.is_async, *args, **kwargs))

        return async_decorator if self.is_async else decorator

    def _exit(self, sig, frame, *args, **kwargs) -> None:
        self.consumer_signal.stop()
        self.ck_consumer.close()

    def _atexit(self) -> None:
        self.ck_consumer.close()
    async def produce_retry_message(self, message: Message, has_retired: Optional[int], func_name: str) -> None:
        """
        Produces a retry message for a given function.

        :param message: The original message that needs to be retried.
        :type message: Message
        :param has_retired: The number of times the function has been retried before. Defaults to None.
        :type has_retired: Optional[int]
        :param func_name: The name of the function.
        :type func_name: str
        :return: None
        """
        if not self._producer_instance:
            logger.error(
                f"[{self}] Confluent producer instance not found,"
                f" failed to produce retry message for function {func_name}"
            )
            return

        if not self.retry_topic:
            logger.error(f"[{self}] Retry topic not found, failed to produce retry message for function {func_name}")
            return

        if has_retired == self.retry_times:
            logger.info(f"[{self}] Function {func_name} reached retry limit ({self.retry_times}), will give up")
            return
        elif has_retired is None:
            has_retired = 1
        else:
            has_retired += 1

        new_headers = message.headers or {}
        new_headers[self.retry_times_header_key] = str(has_retired)

        await self._producer_instance.async_produce(
            topic=self.retry_topic, key=message.key, value=message.value or "", headers=new_headers, auto_flush=True
        )

        logger.info(f"[{self}] Produced retry message for function {func_name}, retry time(s): {has_retired}")
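
# A minimal usage sketch of the `consumer` decorator above. The broker address,
# topic name, handler body, stopper predicate, and the ConsumerConfig constructor
# keywords are illustrative assumptions; the `consumer(...)` keyword arguments and
# the handler/stopper call signatures come from this module.
#
#   import json
#
#   from heizer._source.consumer import consumer
#   from heizer._source.topic import Topic
#   from heizer.config import ConsumerConfig
#
#   config = ConsumerConfig(bootstrap_servers="localhost:9092", group_id="demo-group")
#
#   @consumer(
#       topics=[Topic(name="demo.topic")],
#       config=config,
#       deserializer=json.loads,  # stored on Message.formatted_value before the handler runs
#       stopper=lambda message, c: message.formatted_value.get("status") == "done",
#   )
#   def handle(message, c):
#       # the handler receives (Message, consumer instance, *args, **kwargs)
#       return message.formatted_value
#
#   result = handle()  # blocks in the poll loop until the stopper returns True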