import asyncio
import functools
import logging
from uuid import uuid4
from confluent_kafka import Consumer
from pydantic import BaseModel
from heizer._source import get_logger
from heizer._source.admin import create_new_topic, get_admin_client
from heizer._source.message import HeizerMessage
from heizer._source.topic import HeizerTopic
from heizer.config import HeizerConfig
from heizer.types import (
Any,
Awaitable,
Callable,
Concatenate,
Coroutine,
List,
Optional,
ParamSpec,
Stopper,
Type,
TypeVar,
Union,
)
R = TypeVar("R")
F = TypeVar("F", bound=Callable[..., Any])
P = ParamSpec("P")
T = TypeVar("T")
logger = get_logger(__name__)
[docs]class consumer(object):
"""A decorator to create a consumer"""
__id__: str
name: Optional[str]
topics: List[HeizerTopic]
config: HeizerConfig = HeizerConfig()
call_once: bool = False
stopper: Optional[Stopper] = None
deserializer: Optional[Type[BaseModel]] = None
is_async: bool = False
poll_timeout: int = 1
init_topics: bool = True
def __init__(
self,
*,
topics: List[HeizerTopic],
config: HeizerConfig = HeizerConfig(),
call_once: bool = False,
stopper: Optional[Stopper] = None,
deserializer: Optional[Type[BaseModel]] = None,
is_async: bool = False,
name: Optional[str] = None,
poll_timeout: Optional[int] = None,
init_topics: bool = True,
):
self.topics = topics
self.config = config
self.call_once = call_once
self.stopper = stopper
self.deserializer = deserializer
self.__id__ = str(uuid4())
self.name = name or self.__id__
self.is_async = is_async
self.poll_timeout = poll_timeout or 1
self.init_topics = init_topics
async def __run__(
self,
func: Callable[Concatenate[HeizerMessage, P], Union[T, Awaitable[T]]],
c: Consumer,
is_async: bool,
*args: P.args,
**kwargs: P.kwargs,
) -> Union[Optional[T]]: # ignore type
"""Run the consumer"""
if self.init_topics:
logger.debug("Initializing topics")
admin_client = get_admin_client(self.config)
create_new_topic(admin_client, self.topics)
while True:
msg = c.poll(self.poll_timeout)
if msg is None:
continue
if msg.error():
logger.error(f"Consumer error: {msg.error()}")
continue
logger.debug("Received message")
hmessage = HeizerMessage(msg)
if self.deserializer is not None:
logger.debug("Parsing message")
try:
hmessage.formatted_value = self.deserializer.parse_raw(hmessage.value)
except Exception as e:
logger.exception("Failed to deserialize message", exc_info=e)
result = None
logger.debug("Executing function")
try:
if is_async:
result = await func(hmessage, *args, **kwargs) # type: ignore
else:
result = func(hmessage, *args, **kwargs)
except Exception as e:
logger.exception(
f"Failed to execute function {func.__name__}",
exc_info=e,
)
finally:
# TODO: add failed message to a retry queue
logger.debug("Committing message")
c.commit()
if self.stopper is not None:
logger.debug("Executing stopper function")
try:
should_stop = self.stopper(hmessage)
except Exception as e:
logger.warning(
f"Failed to execute stopper function {self.stopper.__name__}.",
exc_info=e,
)
should_stop = False
if should_stop:
return result
if self.call_once is True:
logger.debug("Call Once is on, returning result")
return result
def __call__(
self, func: Callable[Concatenate[HeizerMessage, P], T]
) -> Callable[P, Union[Coroutine[Any, Any, Optional[T]], T, None]]:
logger.debug("Creating consumer")
c = Consumer(self.config.value)
logger.debug("Subscribing to topics")
c.subscribe([topic.name for topic in self.topics])
@functools.wraps(func)
async def async_decorator(*args: P.args, **kwargs: P.kwargs) -> Optional[T]:
"""Async decorator"""
logging.debug("Running async decorator")
return await self.__run__(func, c, self.is_async, *args, **kwargs)
@functools.wraps(func)
def decorator(*args: P.args, **kwargs: P.kwargs) -> Optional[T]:
"""Sync decorator"""
logging.debug("Running sync decorator")
return asyncio.get_event_loop().run_until_complete(self.__run__(func, c, self.is_async, *args, **kwargs))
return async_decorator if self.is_async else decorator