Skip to content

Commit

Permalink
feat: add dependency injection framework
Browse files Browse the repository at this point in the history
  • Loading branch information
woile committed Oct 15, 2024
1 parent 8623734 commit 68e98c9
Show file tree
Hide file tree
Showing 21 changed files with 594 additions and 33 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ if __name__ == "__main__":
- [ ] Store (kafka streams pattern)
- [ ] Stream Join
- [ ] Windowing
- [ ] PEP 593

## Development

Expand Down
11 changes: 9 additions & 2 deletions kstreams/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from aiokafka.structs import ConsumerRecord, RecordMetadata, TopicPartition
from aiokafka.structs import TopicPartition

from ._di.dependencies.core import StreamDependencyManager
from ._di.parameters import FromHeader, Header
from .backends.kafka import Kafka
from .clients import Consumer, Producer
from .create import StreamEngine, create_engine
from .prometheus.monitor import PrometheusMonitor, PrometheusMonitorType
Expand All @@ -11,7 +14,7 @@
from .streams import Stream, stream
from .structs import TopicPartitionOffset
from .test_utils import TestStreamClient
from .types import Send
from .types import ConsumerRecord, RecordMetadata, Send

__all__ = [
"Consumer",
Expand All @@ -31,4 +34,8 @@
"TestStreamClient",
"TopicPartition",
"TopicPartitionOffset",
"Kafka",
"StreamDependencyManager",
"FromHeader",
"Header",
]
68 changes: 68 additions & 0 deletions kstreams/_di/binders/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import inspect
from typing import Any, AsyncIterator, Awaitable, Protocol, TypeVar, Union

from di.api.dependencies import CacheKey
from di.dependent import Dependent, Marker

from kstreams.types import ConsumerRecord


class ExtractorTrait(Protocol):
    """Implement to extract data from an incoming `ConsumerRecord`.

    Consumers will always work with a `ConsumerRecord`; implementing this
    protocol lets a binder pull a specific piece of information out of it.
    Instances are used by `di` as cache keys, so implementations must be
    hashable and support equality.
    """

    def __hash__(self) -> int:
        """Required by di in order to cache the deps"""
        ...

    def __eq__(self, __o: object) -> bool:
        """Required by di in order to cache the deps"""
        ...

    async def extract(
        self, consumer_record: ConsumerRecord
    ) -> Union[Awaitable[Any], AsyncIterator[Any]]:
        """This is where the magic should happen.

        For example, you could "extract" here a json from the
        `ConsumerRecord.value`.
        """
        ...


# Covariant: a marker producing a subtype of T is usable where a
# MarkerTrait[T] is expected.
T = TypeVar("T", covariant=True)


class MarkerTrait(Protocol[T]):
    """Structural interface for markers that resolve a function parameter
    into a concrete binder artifact (e.g. an extractor)."""

    def register_parameter(self, param: inspect.Parameter) -> T: ...


class Binder(Dependent[Any]):
    """A `di` dependent whose call target is the extractor's `extract`
    coroutine, scoped to a single consumer record."""

    def __init__(
        self,
        *,
        extractor: ExtractorTrait,
    ) -> None:
        self.extractor = extractor
        super().__init__(call=extractor.extract, scope="consumer_record")

    @property
    def cache_key(self) -> CacheKey:
        # Equal extractors share one resolved value within a scope.
        return self.extractor


class BinderMarker(Marker):
    """Bind together the different dependencies.

    NEXT: Add asyncapi marker here, like `MarkerTrait[AsyncApiTrait]`.
    Recommendation to wait until 3.0:
    - [#618](https://github.com/asyncapi/spec/issues/618)
    """

    def __init__(self, *, extractor_marker: MarkerTrait[ExtractorTrait]) -> None:
        self.extractor_marker = extractor_marker

    def register_parameter(self, param: inspect.Parameter) -> Binder:
        extractor = self.extractor_marker.register_parameter(param)
        return Binder(extractor=extractor)
52 changes: 52 additions & 0 deletions kstreams/_di/binders/header.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import inspect
from typing import Any, NamedTuple, Optional

from kstreams.exceptions import HeaderNotFound
from kstreams.types import ConsumerRecord


class HeaderExtractor(NamedTuple):
    """Extract a single header value from a `ConsumerRecord` by name."""

    name: str

    def __hash__(self) -> int:
        # Include the class so the cache key never collides with a plain
        # tuple (or another NamedTuple) holding the same name.
        return hash((self.__class__, self.name))

    def __eq__(self, __o: object) -> bool:
        return isinstance(__o, HeaderExtractor) and __o.name == self.name

    async def extract(self, consumer_record: ConsumerRecord) -> Any:
        """Return the value of the header `self.name`.

        Raises:
            HeaderNotFound: when no header with that name is present.
        """
        if isinstance(consumer_record.headers, dict):
            headers = tuple(consumer_record.headers.items())
            # NEXT: Check also if it is a sequence, if not it means
            # someone modified the headers in the serializer in a way
            # that we cannot extract it, raise a readable error
        else:
            headers = consumer_record.headers
        found_headers = [value for header, value in headers if header == self.name]

        try:
            # NOTE: pop() returns the *last* matching header when duplicates exist.
            header = found_headers.pop()
        except IndexError as e:
            message = (
                f"No header `{self.name}` found.\n"
                "Check if your broker is sending the header.\n"
                "Try adding a default value to your parameter like `None`.\n"
                "Or set `convert_underscores = False`."
            )
            raise HeaderNotFound(message) from e
        else:
            return header


class HeaderMarker(NamedTuple):
    """Resolve a function parameter into a `HeaderExtractor`.

    Attributes:
        alias: exact header name to look up; overrides the parameter name.
        convert_underscores: when no alias is given, replace `_` with `-`
            in the parameter name (identifiers cannot contain dashes, the
            conventional header-name separator).
    """

    alias: Optional[str]
    convert_underscores: bool

    def register_parameter(self, param: inspect.Parameter) -> HeaderExtractor:
        if self.alias is not None:
            name = self.alias
        elif self.convert_underscores:
            name = param.name.replace("_", "-")
        else:
            name = param.name
        return HeaderExtractor(name=name)
86 changes: 86 additions & 0 deletions kstreams/_di/dependencies/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from typing import Any, Callable, Optional

from di import Container, bind_by_type
from di.dependent import Dependent as Dependant
from di.executors import AsyncExecutor

from kstreams.types import ConsumerRecord


class StreamDependencyManager:
    """Core of dependency injection on kstreams.

    Attributes:
        container: dependency store.

    Usage
    ```python
    consumer_record = ConsumerRecord(...)

    def user_func(event_type: FromHeader[str]):
        ...

    sdm = StreamDependencyManager()
    sdm.build(user_func)
    await sdm.execute(consumer_record)
    ```
    """

    container: Container

    def __init__(self, container: Optional[Container] = None):
        self.container = container or Container()
        self.async_executor = AsyncExecutor()

    def _register_framework_deps(self):
        """Register with the container types that belong to kstream.

        These types can be injectable things available on startup.
        But also they can be things created on a per connection basis.
        They must be available at the time of executing the users' function.

        For example:
        - App
        - ConsumerRecord
        - Consumer
        - KafkaBackend

        Here we are just "registering", that's why we use `bind_by_type`.
        And eventually, when "executing", we should inject all the values we
        want to be "available".
        """
        # wire=False: the record is supplied at execution time, never built
        # by the container itself.
        self.container.bind(
            bind_by_type(
                Dependant(ConsumerRecord, scope="consumer_record", wire=False),
                ConsumerRecord,
            )
        )
        # NEXT: Add Consumer as a dependency, so it can be injected.

    def build(self, user_fn: Callable[..., Any]) -> None:
        """Build the dependency graph for the given function.

        Args:
            user_fn: user defined function, using allowed kstreams params
        """
        self._register_framework_deps()

        self.solved_user_fn = self.container.solve(
            Dependant(user_fn, scope="consumer_record"),
            scopes=["consumer_record"],
        )

    async def execute(self, consumer_record: ConsumerRecord) -> Any:
        """Execute the dependencies graph with external values.

        Args:
            consumer_record: A kafka record containing `values`, `headers`, etc.
        """
        with self.container.enter_scope("consumer_record") as state:
            return await self.container.execute_async(
                self.solved_user_fn,
                values={ConsumerRecord: consumer_record},
                executor=self.async_executor,
                state=state,
            )
33 changes: 33 additions & 0 deletions kstreams/_di/parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import Optional, TypeVar

from kstreams._di.binders.api import BinderMarker
from kstreams._di.binders.header import HeaderMarker
from kstreams.typing import Annotated


def Header(
    *, alias: Optional[str] = None, convert_underscores: bool = True
) -> BinderMarker:
    """Construct another type from the headers of a kafka record.

    Args:
        alias: exact header name to look up; overrides the parameter name.
        convert_underscores: when no alias is given, replace `_` with `-`
            in the parameter name before looking up the header.

    Usage:
    ```python
    from kstreams import Header
    from kstreams.typing import Annotated

    def user_fn(event_type: Annotated[str, Header(alias="EventType")]):
        ...
    ```
    """
    header_marker = HeaderMarker(alias=alias, convert_underscores=convert_underscores)
    binder = BinderMarker(extractor_marker=header_marker)
    return binder


T = TypeVar("T")

# `FromHeader[X]` expands to `Annotated[X, Header(alias=None)]`: with no
# alias, the parameter name (underscores converted to dashes by default)
# selects which header to extract.
FromHeader = Annotated[T, Header(alias=None)]
# Attach documentation to the alias itself so it shows up via help()/inspect.
FromHeader.__doc__ = """General purpose convenient header type.
Use `Annotated` to provide custom params.
"""
3 changes: 3 additions & 0 deletions kstreams/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@ def __str__(self) -> str:


class BackendNotSet(StreamException): ...


# Raised when a requested header is absent from the consumer record's headers.
class HeaderNotFound(StreamException): ...
10 changes: 6 additions & 4 deletions kstreams/middleware/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from aiokafka import errors

from kstreams import ConsumerRecord, types
from kstreams import types
from kstreams.streams_utils import StreamErrorPolicy

if typing.TYPE_CHECKING:
Expand All @@ -25,7 +25,9 @@ def __init__(
**kwargs: typing.Any,
) -> None: ... # pragma: no cover

async def __call__(self, cr: ConsumerRecord) -> typing.Any: ... # pragma: no cover
async def __call__(
self, cr: types.ConsumerRecord
) -> typing.Any: ... # pragma: no cover


class Middleware:
Expand Down Expand Up @@ -56,7 +58,7 @@ def __init__(
self.send = send
self.stream = stream

async def __call__(self, cr: ConsumerRecord) -> typing.Any:
async def __call__(self, cr: types.ConsumerRecord) -> typing.Any:
raise NotImplementedError


Expand All @@ -76,7 +78,7 @@ def __init__(
self.engine = engine
self.error_policy = error_policy

async def __call__(self, cr: ConsumerRecord) -> typing.Any:
async def __call__(self, cr: types.ConsumerRecord) -> typing.Any:
try:
return await self.next_call(cr)
except errors.ConsumerStoppedError as exc:
Expand Down
3 changes: 2 additions & 1 deletion kstreams/middleware/udf_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
import sys
import typing

from kstreams import ConsumerRecord, types
from kstreams import types
from kstreams.streams import Stream
from kstreams.streams_utils import UDFType, setup_type
from kstreams.types import ConsumerRecord

from .middleware import BaseMiddleware

Expand Down
2 changes: 1 addition & 1 deletion kstreams/serializers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Dict, Optional, Protocol

from kstreams import ConsumerRecord
from kstreams.types import ConsumerRecord

from .types import Headers

Expand Down
4 changes: 2 additions & 2 deletions kstreams/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from aiokafka import errors

from kstreams import ConsumerRecord, TopicPartition
from kstreams import TopicPartition
from kstreams.exceptions import BackendNotSet
from kstreams.middleware.middleware import ExceptionMiddleware
from kstreams.structs import TopicPartitionOffset
Expand All @@ -19,7 +19,7 @@
from .rebalance_listener import RebalanceListener
from .serializers import Deserializer
from .streams_utils import StreamErrorPolicy, UDFType
from .types import Deprecated, StreamFunc
from .types import ConsumerRecord, Deprecated, StreamFunc

if typing.TYPE_CHECKING:
from kstreams import StreamEngine
Expand Down
6 changes: 3 additions & 3 deletions kstreams/test_utils/test_clients.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from datetime import datetime
from typing import Any, Coroutine, Dict, List, Optional, Sequence, Set

from kstreams import ConsumerRecord, RebalanceListener, TopicPartition
from kstreams import RebalanceListener, TopicPartition
from kstreams.clients import Consumer, Producer
from kstreams.serializers import Serializer
from kstreams.types import Headers
from kstreams.types import ConsumerRecord, Headers

from .structs import RecordMetadata
from .topics import TopicManager
Expand Down Expand Up @@ -204,7 +204,7 @@ async def getmany(
*partitions: List[TopicPartition],
timeout_ms: int = 0,
max_records: int = 1,
) -> Dict[TopicPartition, List[ConsumerRecord]]:
) -> Dict[TopicPartition, List[ConsumerRecord | None]]:
"""
Basic getmany implementation.
`partitions` and `timeout_ms` could be added to the logic
Expand Down
4 changes: 2 additions & 2 deletions kstreams/test_utils/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from types import TracebackType
from typing import Any, Dict, List, Optional, Type

from kstreams import Consumer, ConsumerRecord, Producer
from kstreams import Consumer, Producer
from kstreams.engine import StreamEngine
from kstreams.prometheus.monitor import PrometheusMonitor
from kstreams.serializers import Serializer
from kstreams.streams import Stream
from kstreams.types import Headers
from kstreams.types import ConsumerRecord, Headers

from .structs import RecordMetadata
from .test_clients import TestConsumer, TestProducer
Expand Down
Loading

0 comments on commit 68e98c9

Please sign in to comment.