diff --git a/README.md b/README.md
index cfbb6dc1..4f80f042 100644
--- a/README.md
+++ b/README.md
@@ -75,6 +75,7 @@ if __name__ == "__main__":
 - [ ] Store (kafka streams pattern)
 - [ ] Stream Join
 - [ ] Windowing
+- [ ] PEP 593 (dependency injection via `Annotated`)
 
 ## Development
 
diff --git a/kstreams/__init__.py b/kstreams/__init__.py
index 54e82471..c8b045eb 100644
--- a/kstreams/__init__.py
+++ b/kstreams/__init__.py
@@ -1,5 +1,8 @@
 from aiokafka.structs import RecordMetadata, TopicPartition
 
+from ._di.dependencies.core import StreamDependencyManager
+from ._di.parameters import FromHeader, Header
+from .backends.kafka import Kafka
 from .clients import Consumer, Producer
 from .create import StreamEngine, create_engine
 from .prometheus.monitor import PrometheusMonitor, PrometheusMonitorType
@@ -31,4 +34,8 @@
     "TestStreamClient",
     "TopicPartition",
     "TopicPartitionOffset",
+    "Kafka",
+    "StreamDependencyManager",
+    "FromHeader",
+    "Header",
 ]
diff --git a/kstreams/_di/binders/api.py b/kstreams/_di/binders/api.py
new file mode 100644
index 00000000..02b959e8
--- /dev/null
+++ b/kstreams/_di/binders/api.py
@@ -0,0 +1,70 @@
+import inspect
+from typing import Any, AsyncIterator, Awaitable, Protocol, TypeVar, Union
+
+from di.api.dependencies import CacheKey
+from di.dependent import Dependent, Marker
+
+from kstreams.types import ConsumerRecord
+
+
+class ExtractorTrait(Protocol):
+    """Implement to extract data from an incoming `ConsumerRecord`.
+
+    Consumers always work with a `ConsumerRecord`; implementing this
+    protocol lets you pull specific information out of the record before
+    it reaches the user's function.
+    """
+
+    def __hash__(self) -> int:
+        """Required by `di` in order to cache the deps"""
+        ...
+
+    def __eq__(self, __o: object) -> bool:
+        """Required by `di` in order to cache the deps"""
+        ...
+
+    async def extract(
+        self, consumer_record: ConsumerRecord
+    ) -> Union[Awaitable[Any], AsyncIterator[Any]]:
+        """Extract this dependency's value from the record.
+
+        For example, you could "extract" a json object from
+        `ConsumerRecord.value`.
+        """
+        ...
+
+
+T = TypeVar("T", covariant=True)
+
+
+class MarkerTrait(Protocol[T]):
+    def register_parameter(self, param: inspect.Parameter) -> T: ...
+
+
+class Binder(Dependent[Any]):
+    def __init__(
+        self,
+        *,
+        extractor: ExtractorTrait,
+    ) -> None:
+        super().__init__(call=extractor.extract, scope="consumer_record")
+        self.extractor = extractor
+
+    @property
+    def cache_key(self) -> CacheKey:
+        return self.extractor
+
+
+class BinderMarker(Marker):
+    """Bind together the different dependencies.
+
+    NEXT: Add an AsyncAPI marker here, like `MarkerTrait[AsyncApiTrait]`.
+    Recommendation is to wait until 3.0:
+    - [#618](https://github.com/asyncapi/spec/issues/618)
+    """
+
+    def __init__(self, *, extractor_marker: MarkerTrait[ExtractorTrait]) -> None:
+        self.extractor_marker = extractor_marker
+
+    def register_parameter(self, param: inspect.Parameter) -> Binder:
+        return Binder(extractor=self.extractor_marker.register_parameter(param))
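The `ExtractorTrait` / `MarkerTrait` pair above is the extension point for new parameter sources. A minimal sketch of a custom binder built on it, mirroring the header binder that follows in this diff — the `JsonValue*` names are hypothetical and not part of this PR:

```python
import inspect
import json
from typing import Any, NamedTuple

from kstreams._di.binders.api import BinderMarker
from kstreams.types import ConsumerRecord


class JsonValueExtractor(NamedTuple):
    """Hypothetical extractor: decode `ConsumerRecord.value` as JSON."""

    def __hash__(self) -> int:
        return hash(self.__class__)

    def __eq__(self, __o: object) -> bool:
        return isinstance(__o, JsonValueExtractor)

    async def extract(self, consumer_record: ConsumerRecord) -> Any:
        return json.loads(consumer_record.value)


class JsonValueMarker(NamedTuple):
    def register_parameter(self, param: inspect.Parameter) -> JsonValueExtractor:
        return JsonValueExtractor()


def JsonValue() -> BinderMarker:
    return BinderMarker(extractor_marker=JsonValueMarker())
```

A parameter annotated with `Annotated[dict, JsonValue()]` would then receive the decoded payload instead of the raw record.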
diff --git a/kstreams/_di/binders/header.py b/kstreams/_di/binders/header.py
new file mode 100644
index 00000000..c0f46de6
--- /dev/null
+++ b/kstreams/_di/binders/header.py
@@ -0,0 +1,44 @@
+import inspect
+from typing import Any, NamedTuple, Optional
+
+from kstreams.exceptions import HeaderNotFound
+from kstreams.types import ConsumerRecord
+
+
+class HeaderExtractor(NamedTuple):
+    name: str
+
+    def __hash__(self) -> int:
+        return hash((self.__class__, self.name))
+
+    def __eq__(self, __o: object) -> bool:
+        return isinstance(__o, HeaderExtractor) and __o.name == self.name
+
+    async def extract(self, consumer_record: ConsumerRecord) -> Any:
+        headers = dict(consumer_record.headers)
+        try:
+            header = headers[self.name]
+        except KeyError as e:
+            message = (
+                f"No header `{self.name}` found.\n"
+                "Check if your broker is sending the header.\n"
+                "Try adding a default value to your parameter like `None`.\n"
+                "Or set `convert_underscores = False`."
+            )
+            raise HeaderNotFound(message) from e
+        else:
+            return header
+
+
+class HeaderMarker(NamedTuple):
+    alias: Optional[str]
+    convert_underscores: bool
+
+    def register_parameter(self, param: inspect.Parameter) -> HeaderExtractor:
+        if self.alias is not None:
+            name = self.alias
+        elif self.convert_underscores:
+            name = param.name.replace("_", "-")
+        else:
+            name = param.name
+        return HeaderExtractor(name=name)
diff --git a/kstreams/_di/dependencies/core.py b/kstreams/_di/dependencies/core.py
new file mode 100644
index 00000000..b0d92a46
--- /dev/null
+++ b/kstreams/_di/dependencies/core.py
@@ -0,0 +1,93 @@
+from typing import Any, Callable, Optional
+
+from di import Container, bind_by_type
+from di.dependent import Dependent
+from di.executors import AsyncExecutor
+
+from kstreams._di.dependencies.hooks import bind_by_generic
+from kstreams.streams import Stream
+from kstreams.types import ConsumerRecord, Send
+
+LayerFn = Callable[..., Any]
+
+
+class StreamDependencyManager:
+    """Core of dependency injection on kstreams.
+
+    Attributes:
+        container: dependency store.
+
+    Usage:
+
+    ```python
+    consumer_record = ConsumerRecord(...)
+
+    def user_func(event_type: FromHeader[str]):
+        ...
+
+    sdm = StreamDependencyManager()
+    sdm.solve_user_fn(user_func)
+    await sdm.execute(consumer_record)
+    ```
+    """
+
+    container: Container
+
+    def __init__(self, container: Optional[Container] = None):
+        self.container = container or Container()
+        self.async_executor = AsyncExecutor()
+
+    def _register_framework_deps(self):
+        """Register the types that belong to kstreams with the container.
+
+        Some of these types are available on startup, others are created
+        on a per-connection basis; all of them must be available by the
+        time the user's function is executed. For example:
+
+        - App
+        - ConsumerRecord
+        - Consumer
+        - KafkaBackend
+
+        Here we are only *registering* (hence `bind_by_generic`); the
+        concrete values are injected later, when `execute` runs.
+        """
+        hook = bind_by_generic(
+            Dependent(ConsumerRecord, scope="consumer_record", wire=False),
+            ConsumerRecord,
+        )
+        self.container.bind(hook)
+        # NEXT: Add Consumer as a dependency, so it can be injected.
+
+    def solve_user_fn(self, fn: LayerFn) -> None:
+        """Build the dependency graph for the given function.
+
+        Attributes:
+            fn: user defined function, using allowed kstreams params
+        """
+        self._register_framework_deps()
+
+        self.solved_user_fn = self.container.solve(
+            Dependent(fn, scope="consumer_record"),
+            scopes=["consumer_record"],
+        )
+
+    def register_stream(self, stream: Stream):
+        hook = bind_by_type(Dependent(stream, scope="consumer_record"), Stream)
+        self.container.bind(hook)
+
+    def register_send(self, send: Send): ...
+
+    async def execute(self, consumer_record: ConsumerRecord) -> Any:
+        """Execute the dependency graph with external values.
+
+        Attributes:
+            consumer_record: A kafka record containing `values`, `headers`, etc.
+        """
+        async with self.container.enter_scope("consumer_record") as state:
+            return await self.container.execute_async(
+                self.solved_user_fn,
+                values={ConsumerRecord: consumer_record},
+                executor=self.async_executor,
+                state=state,
+            )
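To make the binder mechanics above concrete, here is how `HeaderMarker.register_parameter` resolves header names for a parameter — a sketch that only exercises the three branches in `header.py`:

```python
import inspect

from kstreams._di.binders.header import HeaderMarker


def probe(event_type: str): ...


param = inspect.signature(probe).parameters["event_type"]

# Default: underscores become dashes, so the lookup key is "event-type"
default = HeaderMarker(alias=None, convert_underscores=True)
assert default.register_parameter(param).name == "event-type"

# convert_underscores=False keeps the Python name: "event_type"
verbatim = HeaderMarker(alias=None, convert_underscores=False)
assert verbatim.register_parameter(param).name == "event_type"

# An alias always wins: "EventType"
aliased = HeaderMarker(alias="EventType", convert_underscores=True)
assert aliased.register_parameter(param).name == "EventType"
```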
diff --git a/kstreams/_di/dependencies/hooks.py b/kstreams/_di/dependencies/hooks.py
new file mode 100644
index 00000000..11e0799e
--- /dev/null
+++ b/kstreams/_di/dependencies/hooks.py
@@ -0,0 +1,31 @@
+import inspect
+from typing import Any, Optional, get_origin
+
+from di._container import BindHook
+from di._utils.inspect import get_type
+from di.api.dependencies import DependentBase
+
+
+def bind_by_generic(
+    provider: DependentBase[Any],
+    dependency: type,
+) -> BindHook:
+    """Hook to substitute a matched dependency based on its generic origin."""
+
+    def hook(
+        param: Optional[inspect.Parameter], dependent: DependentBase[Any]
+    ) -> Optional[DependentBase[Any]]:
+        if dependent.call == dependency:
+            return provider
+        if param is None:
+            return None
+
+        type_annotation_option = get_type(param)
+        if type_annotation_option is None:
+            return None
+        type_annotation = type_annotation_option.value
+        if get_origin(type_annotation) is dependency:
+            return provider
+        return None
+
+    return hook
diff --git a/kstreams/_di/parameters.py b/kstreams/_di/parameters.py
new file mode 100644
index 00000000..b16a1b1e
--- /dev/null
+++ b/kstreams/_di/parameters.py
@@ -0,0 +1,39 @@
+from typing import Optional, TypeVar
+
+from kstreams._di.binders.api import BinderMarker
+from kstreams._di.binders.header import HeaderMarker
+from kstreams.typing import Annotated
+
+
+def Header(
+    *, alias: Optional[str] = None, convert_underscores: bool = True
+) -> BinderMarker:
+    """Build a parameter's value from the headers of a kafka record.
+
+    Args:
+        alias: Use a different header name.
+        convert_underscores: If True, convert underscores in the parameter
+            name to dashes when looking up the header.
+
+    Usage:
+
+    ```python
+    from kstreams import Header
+    from kstreams.typing import Annotated
+
+    def user_fn(event_type: Annotated[str, Header(alias="EventType")]):
+        ...
+    ```
+    """
+    header_marker = HeaderMarker(alias=alias, convert_underscores=convert_underscores)
+    binder = BinderMarker(extractor_marker=header_marker)
+    return binder
+
+
+T = TypeVar("T")
+
+FromHeader = Annotated[T, Header()]
+FromHeader.__doc__ = """General purpose, convenient header type.
+
+Use `Annotated` directly to provide custom params.
+"""
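Since `FromHeader` is just `Annotated` with a default `Header()` marker pre-applied, the two spellings below are equivalent; use the explicit form whenever you need `alias` or `convert_underscores` (sketch):

```python
from kstreams import FromHeader, Header
from kstreams.typing import Annotated


# These two annotations produce equivalent binders:
async def fn_a(event_type: FromHeader[str]) -> str:
    return event_type


async def fn_b(event_type: Annotated[str, Header()]) -> str:
    return event_type
```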
+""" diff --git a/kstreams/engine.py b/kstreams/engine.py index b69095f6..056f0a4d 100644 --- a/kstreams/engine.py +++ b/kstreams/engine.py @@ -5,6 +5,7 @@ from aiokafka.structs import RecordMetadata +from kstreams.middleware.di_middleware import DependencyInjectionHandler from kstreams.structs import TopicPartitionOffset from .backends.kafka import Kafka @@ -389,7 +390,13 @@ def add_stream( stream.rebalance_listener.stream = stream stream.rebalance_listener.engine = self - stream.udf_handler = UdfHandler( + # stream.udf_handler = UdfHandler( + # next_call=stream.func, + # send=self.send, + # stream=stream, + # ) + + stream.udf_handler = DependencyInjectionHandler( next_call=stream.func, send=self.send, stream=stream, @@ -397,8 +404,8 @@ def add_stream( # NOTE: When `no typing` support is deprecated this check can # be removed - if stream.udf_handler.type != UDFType.NO_TYPING: - stream.func = self._build_stream_middleware_stack(stream=stream) + # if stream.udf_handler.type != UDFType.NO_TYPING: + stream.func = self._build_stream_middleware_stack(stream=stream) def _build_stream_middleware_stack(self, *, stream: Stream) -> NextMiddlewareCall: assert stream.udf_handler, "UdfHandler can not be None" diff --git a/kstreams/exceptions.py b/kstreams/exceptions.py index 249f2db7..dd5a65d9 100644 --- a/kstreams/exceptions.py +++ b/kstreams/exceptions.py @@ -25,3 +25,6 @@ def __str__(self) -> str: class BackendNotSet(StreamException): ... + + +class HeaderNotFound(StreamException): ... diff --git a/kstreams/middleware/di_middleware.py b/kstreams/middleware/di_middleware.py new file mode 100644 index 00000000..a97b8664 --- /dev/null +++ b/kstreams/middleware/di_middleware.py @@ -0,0 +1,19 @@ +import typing + +from kstreams import types +from kstreams._di.dependencies.core import StreamDependencyManager + +from .middleware import BaseMiddleware + + +class DependencyInjectionHandler(BaseMiddleware): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.dependecy_manager = StreamDependencyManager() + self.dependecy_manager.register_stream(self.stream) + self.dependecy_manager.register_send(self.send) + self.dependecy_manager.solve_user_fn(fn=self.next_call) + self.type = None + + async def __call__(self, cr: types.ConsumerRecord) -> typing.Any: + return await self.dependecy_manager.execute(cr) diff --git a/kstreams/streams.py b/kstreams/streams.py index b247c6d7..5b703914 100644 --- a/kstreams/streams.py +++ b/kstreams/streams.py @@ -15,7 +15,7 @@ from .backends.kafka import Kafka from .clients import Consumer -from .middleware import Middleware, udf_middleware +from .middleware import Middleware from .rebalance_listener import RebalanceListener from .serializers import Deserializer from .streams_utils import StreamErrorPolicy, UDFType @@ -172,7 +172,7 @@ def __init__( self.seeked_initial_offsets = False self.rebalance_listener = rebalance_listener self.middlewares = middlewares or [] - self.udf_handler: typing.Optional[udf_middleware.UdfHandler] = None + self.udf_handler = None self.topics = [topics] if isinstance(topics, str) else topics self.subscribe_by_pattern = subscribe_by_pattern self.error_policy = error_policy @@ -356,6 +356,10 @@ async def start(self) -> None: await func else: # Typing cases + + # If it's an async generator, then DON'T await the function + # because we want to start ONLY and let the user retrieve the + # values while iterating the stream if not inspect.isasyncgenfunction(self.udf_handler.next_call): # Is not an async_generator, then create 
`await` the func await self.func_wrapper_with_typing() diff --git a/kstreams/types.py b/kstreams/types.py index 3562f3b6..af4ca5ee 100644 --- a/kstreams/types.py +++ b/kstreams/types.py @@ -9,7 +9,6 @@ Headers = typing.Dict[str, str] EncodedHeaders = typing.Sequence[typing.Tuple[str, bytes]] StreamFunc = typing.Callable - EngineHooks = typing.Sequence[typing.Callable[[], typing.Any]] diff --git a/kstreams/typing.py b/kstreams/typing.py new file mode 100644 index 00000000..1707d56b --- /dev/null +++ b/kstreams/typing.py @@ -0,0 +1,8 @@ +"""Remove this file when python3.8 support is dropped.""" + +import sys + +if sys.version_info < (3, 9): + from typing_extensions import Annotated as Annotated # noqa: F401 +else: + from typing import Annotated as Annotated # noqa: F401 diff --git a/poetry.lock b/poetry.lock index 6270fef6..109c6012 100644 --- a/poetry.lock +++ b/poetry.lock @@ -446,6 +446,24 @@ files = [ {file = "decli-0.6.2.tar.gz", hash = "sha256:36f71eb55fd0093895efb4f416ec32b7f6e00147dda448e3365cf73ceab42d6f"}, ] +[[package]] +name = "di" +version = "0.79.2" +description = "Dependency injection toolkit" +optional = false +python-versions = ">=3.8,<4" +files = [ + {file = "di-0.79.2-py3-none-any.whl", hash = "sha256:4b2ac7c46d4d9e941ca47d37c2029ba739c1f8a0e19e5288731224870f00d6e6"}, + {file = "di-0.79.2.tar.gz", hash = "sha256:0c65b9ccb984252dadbdcdb39743eeddef0c1f167f791c59fcd70e97bb0d3af8"}, +] + +[package.dependencies] +graphlib2 = ">=0.4.1,<0.5.0" +typing-extensions = {version = ">=3", markers = "python_version < \"3.9\""} + +[package.extras] +anyio = ["anyio (>=3.5.0)"] + [[package]] name = "exceptiongroup" version = "1.2.2" @@ -460,6 +478,20 @@ files = [ [package.extras] test = ["pytest (>=6)"] +[[package]] +name = "faker" +version = "14.2.1" +description = "Faker is a Python package that generates fake data for you." 
 [[package]]
 name = "fastapi"
 version = "0.115.2"
@@ -508,6 +540,40 @@ python-dateutil = ">=2.8.1"
 
 [package.extras]
 dev = ["flake8", "markdown", "twine", "wheel"]
 
+[[package]]
+name = "graphlib2"
+version = "0.4.7"
+description = "Rust port of the Python stdlib graphlib modules"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "graphlib2-0.4.7-cp37-abi3-macosx_10_7_x86_64.whl", hash = "sha256:483710733215783cdc76452ccde1247af8f697685c9c1dfd9bb9ff4f52d990ee"},
+    {file = "graphlib2-0.4.7-cp37-abi3-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:3619c7d3c5aca95e6cbbfc283aa6bf42ffa5b59d7f39c8d0ad615bce65dc406f"},
+    {file = "graphlib2-0.4.7-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5b19f1b91d0f22ca3d1cfb2965478db98cf5916a5c6cea5fdc7caf4bf1bfbc33"},
+    {file = "graphlib2-0.4.7-cp37-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:624020f6808ee21ffbb2e455f8dd4196bbb37032a35aa3327f0f5b65fb6a35d1"},
+    {file = "graphlib2-0.4.7-cp37-abi3-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:6efc6a197a619a97f1b105aea14b202101241c1db9014bd100ad19cf29288cbf"},
+    {file = "graphlib2-0.4.7-cp37-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d7cc38b68775cb2cdfc487bbaca2f7991da0d76d42a68f412c2ca61461e6e026"},
+    {file = "graphlib2-0.4.7-cp37-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b06bed98d42f4e10adfe2a8332efdca06b5bac6e7c86dd1d22a4dea4de9b275a"},
+    {file = "graphlib2-0.4.7-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c9ec3a5645bdf020d8bd9196b2665e26090d60e523fd498df29628f2c5fbecc"},
+    {file = "graphlib2-0.4.7-cp37-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:824df87f767471febfd785a05a2cc77c0c973e0112a548df827763ca0aa8c126"},
+    {file = "graphlib2-0.4.7-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:2de5e32ca5c0b06d442d2be4b378cc0bc335c5fcbc14a7d531a621eb8294d019"},
+    {file = "graphlib2-0.4.7-cp37-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:13a23fcf07c7bef8a5ad0e04ab826d3a2a2bcb493197005300c68b4ea7b8f581"},
+    {file = "graphlib2-0.4.7-cp37-abi3-musllinux_1_2_i686.whl", hash = "sha256:15a8a6daa28c1fb5c518d387879f3bbe313264fbbc2fab5635b718bc71a24913"},
+    {file = "graphlib2-0.4.7-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:0cb6c4449834077972c3cea4602f86513b4b75fcf2d40b12e4fe4bf1aa5c8da2"},
+    {file = "graphlib2-0.4.7-cp37-abi3-win32.whl", hash = "sha256:31b40cea537845d80b69403ae306d7c6a68716b76f5171f68daed1804aadefec"},
+    {file = "graphlib2-0.4.7-cp37-abi3-win_amd64.whl", hash = "sha256:d40935a9da81a046ebcaa0216ad593ef504ae8a5425a59bdbd254c0462adedc8"},
+    {file = "graphlib2-0.4.7-pp37-pypy37_pp73-macosx_10_7_x86_64.whl", hash = "sha256:9cef08a50632e75a9e11355e68fa1f8c9371d0734642855f8b5c4ead1b058e6f"},
+    {file = "graphlib2-0.4.7-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aeecb604d70317c20ca6bc3556f7f5c40146ad1f0ded837e978b2fe6edf3e567"},
+    {file = "graphlib2-0.4.7-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cb4ae9df7ed895c6557619049c9f73e1c2e6d1fbed568010fd5d4af94e2f0692"},
+    {file = "graphlib2-0.4.7-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:3ee3a99fc39df948fef340b01254709cc603263f8b176f72ed26f1eea44070a4"},
+    {file = "graphlib2-0.4.7-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5873480df8991273bd1585122df232acd0f946c401c254bd9f0d661c72589dcf"},
+    {file = "graphlib2-0.4.7-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:297c817229501255cd3a744c62c8f91e5139ee79bc550488f5bc765ffa33f7c5"},
+    {file = "graphlib2-0.4.7-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:853ef22df8e9f695706e0b8556cda9342d4d617f7d7bd02803e824bcc0c30b20"},
+    {file = "graphlib2-0.4.7-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee62ff1042fde980adf668e30393eca79aee8f1fa1274ab3b98d69091c70c5e8"},
+    {file = "graphlib2-0.4.7-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b16e21e70938132d4160c2591fed59f79b5f8b702e4860c8933111b5fedb55c2"},
+    {file = "graphlib2-0.4.7.tar.gz", hash = "sha256:a951c18cb4c2c2834eec898b4c75d3f930d6f08beb37496f0e0ce56eb3f571f5"},
+]
+
"sha256:cb4ae9df7ed895c6557619049c9f73e1c2e6d1fbed568010fd5d4af94e2f0692"}, + {file = "graphlib2-0.4.7-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:3ee3a99fc39df948fef340b01254709cc603263f8b176f72ed26f1eea44070a4"}, + {file = "graphlib2-0.4.7-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5873480df8991273bd1585122df232acd0f946c401c254bd9f0d661c72589dcf"}, + {file = "graphlib2-0.4.7-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:297c817229501255cd3a744c62c8f91e5139ee79bc550488f5bc765ffa33f7c5"}, + {file = "graphlib2-0.4.7-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:853ef22df8e9f695706e0b8556cda9342d4d617f7d7bd02803e824bcc0c30b20"}, + {file = "graphlib2-0.4.7-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee62ff1042fde980adf668e30393eca79aee8f1fa1274ab3b98d69091c70c5e8"}, + {file = "graphlib2-0.4.7-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b16e21e70938132d4160c2591fed59f79b5f8b702e4860c8933111b5fedb55c2"}, + {file = "graphlib2-0.4.7.tar.gz", hash = "sha256:a951c18cb4c2c2834eec898b4c75d3f930d6f08beb37496f0e0ce56eb3f571f5"}, +] + [[package]] name = "griffe" version = "1.4.0" @@ -1782,4 +1848,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "c710ce8ba36d7aaaa7d9e2a499a85c6a48786ed4cc880f56f28b7fda8b4b7c57" +content-hash = "648b74eb2f37b8a895a99518c3b1747446de67dc24371dd02b883cd8a697564e" diff --git a/pyproject.toml b/pyproject.toml index 38912d62..373b2712 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ prometheus-client = "<1.0" future = "^1.0.0" PyYAML = ">=5.4,<7.0.0" pydantic = ">=2.0.0,<3.0.0" +di = "^0.79.2" [tool.poetry.group.dev.dependencies] pytest = "^8.3.3" @@ -47,6 +48,7 @@ mkdocs-material = "^9.5.39" starlette-prometheus = "^0.10.0" codecov = "^2.1.12" mkdocstrings = { version = "^0.26.1", extras = ["python"] } +Faker = "^14.2.0" [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/scripts/test b/scripts/test index a20ebb12..5cff7269 100755 --- a/scripts/test +++ b/scripts/test @@ -5,7 +5,7 @@ if [ -d '.venv' ] ; then export PREFIX=".venv/bin/" fi -${PREFIX}pytest -x --cov-report term-missing --cov-report=xml:coverage.xml --cov=kstreams ${1-"./tests"} $2 +${PREFIX}pytest --cov-report term-missing --cov-report=xml:coverage.xml --cov=kstreams ${1-"./tests"} $2 ${PREFIX}ruff check kstreams tests ${PREFIX}ruff format --check kstreams tests examples -${PREFIX}mypy kstreams/ +${PREFIX}mypy kstreams/ tests diff --git a/tests/_di/test_dependency_manager.py b/tests/_di/test_dependency_manager.py new file mode 100644 index 00000000..bfdceeb1 --- /dev/null +++ b/tests/_di/test_dependency_manager.py @@ -0,0 +1,41 @@ +from typing import Any, AsyncGenerator, TypeVar + +from kstreams._di.dependencies.core import StreamDependencyManager +from kstreams.types import ConsumerRecord + + +async def test_cr_is_injected(rand_consumer_record): + + async def user_fn(event_type: ConsumerRecord) -> str: + event_type.value = "hello" + return event_type.value + + stream_manager = StreamDependencyManager() + stream_manager.solve_user_fn(user_fn) + content = await stream_manager.execute(rand_consumer_record) + assert content == "hello" + + +async def test_cr_generics_is_injected(rand_consumer_record): + + async def user_fn(event_type: ConsumerRecord[Any, Any]) -> str: + event_type.value = "hello" + return event_type.value + + stream_manager = 
diff --git a/tests/_di/test_hooks.py b/tests/_di/test_hooks.py
new file mode 100644
index 00000000..dfcf26ae
--- /dev/null
+++ b/tests/_di/test_hooks.py
@@ -0,0 +1,61 @@
+import typing
+
+import pytest
+from di import Container
+from di.dependent import Dependent
+from di.executors import SyncExecutor
+
+from kstreams._di.dependencies.hooks import bind_by_generic
+
+KT = typing.TypeVar("KT")
+VT = typing.TypeVar("VT")
+
+
+class Record(typing.Generic[KT, VT]):
+    def __init__(self, key: KT, value: VT):
+        self.key = key
+        self.value = value
+
+
+def func_hinted(record: Record[str, int]) -> Record[str, int]:
+    return record
+
+
+def func_base(record: Record) -> Record:
+    return record
+
+
+@pytest.mark.parametrize("func", [func_hinted, func_base])
+def test_bind_generic_ok(func: typing.Callable):
+    dep = Dependent(func)
+    container = Container()
+    container.bind(
+        bind_by_generic(
+            Dependent(lambda: Record("foo", 1), wire=False),
+            Record,
+        )
+    )
+    solved = container.solve(dep, scopes=[None])
+    with container.enter_scope(None) as state:
+        instance = solved.execute_sync(executor=SyncExecutor(), state=state)
+    assert isinstance(instance, Record)
+
+
+def func_str(record: str) -> str:
+    return record
+
+
+def test_bind_generic_unrelated():
+    dep = Dependent(func_str)
+    container = Container()
+    container.bind(
+        bind_by_generic(
+            Dependent(lambda: Record("foo", 1), wire=False),
+            Record,
+        )
+    )
+    solved = container.solve(dep, scopes=[None])
+    with container.enter_scope(None) as state:
+        instance = solved.execute_sync(executor=SyncExecutor(), state=state)
+    assert not isinstance(instance, Record)
+    assert isinstance(instance, str)
diff --git a/tests/_di/test_param_headers.py b/tests/_di/test_param_headers.py
new file mode 100644
index 00000000..4440d83d
--- /dev/null
+++ b/tests/_di/test_param_headers.py
@@ -0,0 +1,70 @@
+import pytest
+
+from kstreams import FromHeader, Header
+from kstreams._di.dependencies.core import StreamDependencyManager
+from kstreams.exceptions import HeaderNotFound
+from kstreams.typing import Annotated
+
+
+async def test_from_headers_ok(rand_consumer_record):
+    consumer_record = rand_consumer_record(headers=(("event-type", "hello"),))
+
+    async def user_fn(event_type: FromHeader[str]) -> str:
+        return event_type
+
+    stream_manager = StreamDependencyManager()
+    stream_manager.solve_user_fn(user_fn)
+    header_content = await stream_manager.execute(consumer_record)
+    assert header_content == "hello"
+
+
+async def test_from_header_not_found(rand_consumer_record):
+    # `a_header` resolves to `a-header`, which this record does not carry
+    consumer_record = rand_consumer_record(headers=(("event_type", "hello"),))
+
+    def user_fn(a_header: FromHeader[str]) -> str:
+        return a_header
+
+    stream_manager = StreamDependencyManager()
+    stream_manager.solve_user_fn(user_fn)
+    with pytest.raises(HeaderNotFound):
+        await stream_manager.execute(consumer_record)
+
+
+@pytest.mark.xfail(reason="not implemented yet")
+async def test_from_headers_numbers(rand_consumer_record):
+    consumer_record = rand_consumer_record(headers=(("event-type", "1"),))
+
+    async def user_fn(event_type: FromHeader[int]) -> int:
+        return event_type
+
+    stream_manager = StreamDependencyManager()
+    stream_manager.solve_user_fn(user_fn)
+    header_content = await stream_manager.execute(consumer_record)
+    assert header_content == 1
+
+
+async def test_headers_alias(rand_consumer_record):
+    consumer_record = rand_consumer_record(headers=(("EventType", "hello"),))
+
+    async def user_fn(event_type: Annotated[str, Header(alias="EventType")]) -> str:
+        return event_type
+
+    stream_manager = StreamDependencyManager()
+    stream_manager.solve_user_fn(user_fn)
+    header_content = await stream_manager.execute(consumer_record)
+    assert header_content == "hello"
+
+
+async def test_headers_convert_underscores(rand_consumer_record):
+    consumer_record = rand_consumer_record(headers=(("event_type", "hello"),))
+
+    async def user_fn(
+        event_type: Annotated[str, Header(convert_underscores=False)],
+    ) -> str:
+        return event_type
+
+    stream_manager = StreamDependencyManager()
+    stream_manager.solve_user_fn(user_fn)
+    header_content = await stream_manager.execute(consumer_record)
+    assert header_content == "hello"
diff --git a/tests/conftest.py b/tests/conftest.py
index 985194e4..3f706792 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,14 +1,21 @@
+import logging
 from collections import namedtuple
 from dataclasses import field
 from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple
 
 import pytest
 import pytest_asyncio
+from faker import Faker
 from pytest_httpserver import HTTPServer
 
 from kstreams import clients, create_engine
+from kstreams.types import ConsumerRecord
 from kstreams.utils import create_ssl_context_from_mem
 
+# Silence faker DEBUG logs
+logger = logging.getLogger("faker")
+logger.setLevel(logging.INFO)
+
 
 class RecordMetadata(NamedTuple):
     offset: int = 1
@@ -195,3 +202,42 @@ class ConsumerRecord(NamedTuple):
     )
 
     return consumer_record
+
+
+@pytest.fixture
+def fake():
+    return Faker()
+
+
+@pytest.fixture
+def rand_consumer_record(fake: Faker):
+    """Factory fixture: build a `ConsumerRecord`, randomizing unset fields."""
+
+    def generate(
+        topic: Optional[str] = None,
+        headers: Optional[Sequence[Tuple[str, bytes]]] = None,
+        partition: Optional[int] = None,
+        offset: Optional[int] = None,
+        timestamp: Optional[int] = None,
+        timestamp_type: Optional[int] = None,
+        key: Optional[Any] = None,
+        value: Optional[Any] = None,
+        checksum: Optional[int] = None,
+        serialized_key_size: Optional[int] = None,
+        serialized_value_size: Optional[int] = None,
+    ) -> ConsumerRecord:
+        return ConsumerRecord(
+            topic=topic or fake.slug(),
+            headers=headers or tuple(),
+            partition=partition or fake.pyint(max_value=10),
+            offset=offset or fake.pyint(max_value=99999999),
+            timestamp=timestamp or fake.unix_time(),
+            timestamp_type=timestamp_type or 1,
+            key=key or fake.pystr(),
+            value=value or fake.pystr().encode(),
+            checksum=checksum or fake.pyint(max_value=10),
+            serialized_key_size=serialized_key_size or fake.pyint(max_value=10),
+            serialized_value_size=serialized_value_size or fake.pyint(max_value=10),
+        )
+
+    return generate
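Putting the pieces together — the factory fixture plus the DI manager — a test under this PR reads like the sketch below (names and values are illustrative):

```python
from kstreams import FromHeader
from kstreams._di.dependencies.core import StreamDependencyManager
from kstreams.types import ConsumerRecord


async def test_event_type_is_extracted(rand_consumer_record):
    cr = rand_consumer_record(headers=(("event-type", "created"),))

    async def user_fn(record: ConsumerRecord, event_type: FromHeader[str]) -> str:
        # `record` is bound by the generic hook, `event_type` by the
        # header binder; both live in the "consumer_record" scope.
        return f"{record.topic}:{event_type}"

    manager = StreamDependencyManager()
    manager.solve_user_fn(user_fn)
    assert (await manager.execute(cr)).endswith(":created")
```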