Skip to content

Commit

Permalink
feat: add dependency injection framework
Browse files Browse the repository at this point in the history
  • Loading branch information
woile committed Nov 19, 2024
1 parent c222b81 commit 27669df
Show file tree
Hide file tree
Showing 24 changed files with 969 additions and 202 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ if __name__ == "__main__":
- [ ] Store (kafka streams pattern)
- [ ] Stream Join
- [ ] Windowing
- [ ] PEP 593

## Development

Expand Down
6 changes: 6 additions & 0 deletions kstreams/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from aiokafka.structs import RecordMetadata, TopicPartition

from ._di.parameters import FromHeader, Header
from .backends.kafka import Kafka
from .clients import Consumer, Producer
from .create import StreamEngine, create_engine
from .prometheus.monitor import PrometheusMonitor, PrometheusMonitorType
Expand Down Expand Up @@ -31,4 +33,8 @@
"TestStreamClient",
"TopicPartition",
"TopicPartitionOffset",
"Kafka",
"StreamDependencyManager",
"FromHeader",
"Header",
]
68 changes: 68 additions & 0 deletions kstreams/_di/binders/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import inspect
from typing import Any, AsyncIterator, Awaitable, Protocol, TypeVar, Union

from di.api.dependencies import CacheKey
from di.dependent import Dependent, Marker

from kstreams.types import ConsumerRecord


class ExtractorTrait(Protocol):
    """Implement to extract data from incoming `ConsumerRecord`.

    Consumers will always work with a consumer Record.
    Implementing this would let you extract information from the `ConsumerRecord`.

    Besides the async `extract` method, an implementation has to provide
    `__hash__` and `__eq__`.
    """

    def __hash__(self) -> int:
        """Required by di in order to cache the deps"""
        # Codecov reported this stub line as not covered by tests in this commit.
        ...

    def __eq__(self, __o: object) -> bool:
        """Required by di in order to cache the deps"""
        # Codecov reported this stub line as not covered by tests in this commit.
        ...

    async def extract(
        self, consumer_record: ConsumerRecord
    ) -> Union[Awaitable[Any], AsyncIterator[Any]]:
        """This is where the magic should happen.

        For example, you could "extract" here a json from the `ConsumerRecord.value`

        Args:
            consumer_record: the record to extract the data from.
        """
        ...

Check warning on line 32 in kstreams/_di/binders/api.py

View check run for this annotation

Codecov / codecov/patch

kstreams/_di/binders/api.py#L32

Added line #L32 was not covered by tests


T = TypeVar("T", covariant=True)


class MarkerTrait(Protocol[T]):
    """Structural interface for objects that map a function parameter to a `T`."""

    def register_parameter(self, param: inspect.Parameter) -> T: ...


class Binder(Dependent[Any]):
    """`Dependent` whose callable is the `extract` method of an extractor.

    The dependent is created with the "consumer_record" scope.
    """

    def __init__(
        self,
        *,
        extractor: ExtractorTrait,
    ) -> None:
        """Wrap `extractor` as a dependency.

        Args:
            extractor: object providing `extract`; it is also kept on the
                instance and returned by `cache_key`.
        """
        super().__init__(call=extractor.extract, scope="consumer_record")
        self.extractor = extractor

    @property
    def cache_key(self) -> CacheKey:
        """Return the extractor itself as the cache key.

        `ExtractorTrait` documents `__hash__` and `__eq__` as required by di
        for caching.
        """
        return self.extractor


class BinderMarker(Marker):
    """Bind together the different dependencies.

    NEXT: Add asyncapi marker here, like `MarkerTrait[AsyncApiTrait]`.

    Recommendation to wait until 3.0:
    - [#618](https://github.com/asyncapi/spec/issues/618)
    """

    def __init__(self, *, extractor_marker: MarkerTrait[ExtractorTrait]) -> None:
        """Store the marker that produces the extractor for a parameter.

        Args:
            extractor_marker: object whose `register_parameter` returns an
                `ExtractorTrait` for a given parameter.
        """
        self.extractor_marker = extractor_marker

    def register_parameter(self, param: inspect.Parameter) -> Binder:
        """Build the `Binder` for `param`.

        Args:
            param: the function parameter being registered.

        Returns:
            A `Binder` wrapping the extractor returned by
            `extractor_marker.register_parameter(param)`.
        """
        return Binder(extractor=self.extractor_marker.register_parameter(param))
44 changes: 44 additions & 0 deletions kstreams/_di/binders/header.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import inspect
from typing import Any, NamedTuple, Optional

from kstreams.exceptions import HeaderNotFound
from kstreams.types import ConsumerRecord


class HeaderExtractor(NamedTuple):
    """Extractor that reads a single named header from a `ConsumerRecord`.

    Attributes:
        name: key of the header to look up.
    """

    name: str

    def __hash__(self) -> int:
        """Hash on the class together with the header name."""
        return hash((self.__class__, self.name))

    def __eq__(self, __o: object) -> bool:
        """Return True only for another `HeaderExtractor` with the same name."""
        # Codecov reported this line as not covered by tests in this commit.
        if not isinstance(__o, HeaderExtractor):
            return False
        return __o.name == self.name

    async def extract(self, consumer_record: ConsumerRecord) -> Any:
        """Return the value stored under header `name`.

        Args:
            consumer_record: record whose `headers` are converted to a dict
                and searched.

        Raises:
            HeaderNotFound: when no header named `name` is present.
        """
        available = dict(consumer_record.headers)
        try:
            return available[self.name]
        except KeyError as missing:
            raise HeaderNotFound(
                f"No header `{self.name}` found.\n"
                "Check if your broker is sending the header.\n"
                "Try adding a default value to your parameter like `None`.\n"
                "Or set `convert_underscores = False`."
            ) from missing


class HeaderMarker(NamedTuple):
    """Settings that decide which header name a parameter is bound to.

    Attributes:
        alias: explicit header name; when set it is used as is.
        convert_underscores: when no alias is set, replace every `_` in the
            parameter name with `-`.
    """

    alias: Optional[str]
    convert_underscores: bool

    def register_parameter(self, param: inspect.Parameter) -> HeaderExtractor:
        """Build the `HeaderExtractor` for `param`.

        Args:
            param: the function parameter annotated with the header marker.

        Returns:
            An extractor named after the alias if given, otherwise after the
            parameter name (with underscores converted when enabled).
        """
        # An explicit alias always wins and is never rewritten.
        if self.alias is not None:
            return HeaderExtractor(name=self.alias)

        header_name = param.name
        if self.convert_underscores:
            header_name = header_name.replace("_", "-")
        return HeaderExtractor(name=header_name)
114 changes: 114 additions & 0 deletions kstreams/_di/dependencies/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from typing import Any, Callable, Optional

from di import Container, bind_by_type
from di.dependent import Dependent
from di.executors import AsyncExecutor

from kstreams._di.dependencies.hooks import bind_by_generic
from kstreams.streams import Stream
from kstreams.types import ConsumerRecord, Send

LayerFn = Callable[..., Any]


class StreamDependencyManager:
    """Core of dependency injection on kstreams.

    This is an internal class of kstreams that manages the dependency injection,
    as a user you should not use this class directly.

    Attributes:
        container: dependency store.
        stream: the stream wrapping the user function. Optional to improve testability.
            When instantiating this class you must provide a stream, otherwise users
            won't be able to use the `stream` parameter in their functions.
        send: send object. Optional to improve testability, same as stream.

    Usage:

    stream and send are omitted for simplicity

    ```python
    def user_func(cr: ConsumerRecord):
        ...

    sdm = StreamDependencyManager()
    sdm.solve_user_fn(user_func)
    await sdm.execute(consumer_record)
    ```
    """

    container: Container

    def __init__(
        self,
        container: Optional[Container] = None,
        stream: Optional[Stream] = None,
        send: Optional[Send] = None,
    ):
        """Store the collaborators and create the async executor.

        Args:
            container: dependency store; a new `Container` is created when
                none (or a falsy value) is given.
            stream: stream to expose to the user function, if any.
            send: send object to expose to the user function, if any.
        """
        self.container = container or Container()
        self.async_executor = AsyncExecutor()
        self.stream = stream
        self.send = send

    def solve_user_fn(self, fn: LayerFn) -> None:
        """Build the dependency graph for the given function.

        Objects must be injected before this function is called.
        The solved graph is stored on `self.solved_user_fn`.

        Attributes:
            fn: user defined function, using allowed kstreams params
        """
        self._register_consumer_record()

        # Only a real `Stream` instance is registered; `None` is skipped.
        if isinstance(self.stream, Stream):
            self._register_stream(self.stream)

        if self.send is not None:
            self._register_send(self.send)

        # NOTE(review): confirm against the di documentation whether `scopes`
        # is order-sensitive (outermost to innermost) and whether listing
        # "consumer_record" first is intended.
        self.solved_user_fn = self.container.solve(
            Dependent(fn, scope="consumer_record"),
            scopes=["consumer_record", "stream", "application"],
        )

    async def execute(self, consumer_record: ConsumerRecord) -> Any:
        """Execute the dependencies graph with external values.

        Reads `self.solved_user_fn`, which is only assigned by
        `solve_user_fn`, so that method has to be called first.

        Attributes:
            consumer_record: A kafka record containing `values`, `headers`, etc.
        """
        async with self.container.enter_scope("consumer_record") as state:
            return await self.container.execute_async(
                self.solved_user_fn,
                # The record is supplied at execution time, keyed by its type.
                values={ConsumerRecord: consumer_record},
                executor=self.async_executor,
                state=state,
            )

    def _register_stream(self, stream: Stream):
        """Register the stream with the container."""
        hook = bind_by_type(
            Dependent(lambda: stream, scope="consumer_record", wire=False), Stream
        )
        self.container.bind(hook)

    def _register_consumer_record(self):
        """Register consumer record with the container.

        We bind_by_generic because we want to bind the `ConsumerRecord` type which
        is generic.

        The value must be injected at runtime.
        """
        hook = bind_by_generic(
            Dependent(ConsumerRecord, scope="consumer_record", wire=False),
            ConsumerRecord,
        )
        self.container.bind(hook)

    def _register_send(self, send: Send):
        """Register the send object with the container, bound by the `Send` type."""
        hook = bind_by_type(
            Dependent(lambda: send, scope="consumer_record", wire=False), Send
        )
        self.container.bind(hook)
31 changes: 31 additions & 0 deletions kstreams/_di/dependencies/hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import inspect
from typing import Any, Optional, get_origin

from di._container import BindHook
from di._utils.inspect import get_type
from di.api.dependencies import DependentBase


def bind_by_generic(
    provider: DependentBase[Any],
    dependency: type,
) -> BindHook:
    """Hook to substitute the matched dependency based on its generic.

    Args:
        provider: dependent returned in place of every match.
        dependency: the type to match, compared against the dependent's
            `call` and against the origin of the parameter annotation.

    Returns:
        A bind hook that returns `provider` on a match and `None` otherwise.
    """

    # `Optional[...]` instead of `X | None`: these annotations are evaluated
    # when `hook` is defined, and the PEP 604 form raises `TypeError` on
    # Python < 3.10. The rest of the package uses `Optional`/`Union` as well.
    def hook(
        param: Optional[inspect.Parameter], dependent: DependentBase[Any]
    ) -> Optional[DependentBase[Any]]:
        # The dependent is the dependency itself.
        if dependent.call == dependency:
            return provider
        if param is None:
            return None

        type_annotation_option = get_type(param)
        if type_annotation_option is None:
            # Codecov reported this line as not covered by tests in this commit.
            return None

        # Match on the origin of the annotation rather than on the
        # annotation itself.
        type_annotation = type_annotation_option.value
        if get_origin(type_annotation) is dependency:
            return provider
        return None

    return hook
37 changes: 37 additions & 0 deletions kstreams/_di/parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import Optional, TypeVar

from kstreams._di.binders.api import BinderMarker
from kstreams._di.binders.header import HeaderMarker
from kstreams.typing import Annotated


def Header(
    *, alias: Optional[str] = None, convert_underscores: bool = True
) -> BinderMarker:
    """Construct another type from the headers of a kafka record.

    Args:
        alias: Use a different header name
        convert_underscores: If True, convert underscores to dashes.

    Returns:
        A `BinderMarker` wrapping a `HeaderMarker` built from the arguments.

    Usage:

    ```python
    from kstreams import Header
    from kstreams.typing import Annotated

    def user_fn(event_type: Annotated[str, Header(alias="EventType")]):
        ...
    ```
    """
    return BinderMarker(
        extractor_marker=HeaderMarker(
            alias=alias, convert_underscores=convert_underscores
        )
    )


# Type variable used to parametrise the `FromHeader` alias below.
T = TypeVar("T")

# `Annotated` alias carrying a `Header()` marker built with its default
# arguments (no alias, underscores converted).
FromHeader = Annotated[T, Header()]
FromHeader.__doc__ = """General purpose convenient header type.
Use `Annotated` to provide custom params.
"""
12 changes: 9 additions & 3 deletions kstreams/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@

from aiokafka.structs import RecordMetadata

from kstreams.middleware.di_middleware import DependencyInjectionHandler
from kstreams.structs import TopicPartitionOffset

from .backends.kafka import Kafka
from .clients import Consumer, Producer
from .exceptions import DuplicateStreamException, EngineNotStartedException
from .middleware import Middleware
from .middleware.udf_middleware import UdfHandler
from .prometheus.monitor import PrometheusMonitor
from .rebalance_listener import MetricsRebalanceListener, RebalanceListener
from .serializers import Deserializer, Serializer
Expand Down Expand Up @@ -389,15 +389,21 @@ def add_stream(
stream.rebalance_listener.stream = stream
stream.rebalance_listener.engine = self

stream.udf_handler = UdfHandler(
# stream.udf_handler = UdfHandler(
# next_call=stream.func,
# send=self.send,
# stream=stream,
# )

stream.udf_handler = DependencyInjectionHandler(
next_call=stream.func,
send=self.send,
stream=stream,
)

# NOTE: When `no typing` support is deprecated this check can
# be removed
if stream.udf_handler.type != UDFType.NO_TYPING:
if stream.udf_handler.get_type() != UDFType.NO_TYPING:
stream.func = self._build_stream_middleware_stack(stream=stream)

def _build_stream_middleware_stack(self, *, stream: Stream) -> NextMiddlewareCall:
Expand Down
3 changes: 3 additions & 0 deletions kstreams/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@ def __str__(self) -> str:


class BackendNotSet(StreamException): ...


# Raised by `HeaderExtractor.extract` when the requested header is not
# present in the consumer record's headers.
class HeaderNotFound(StreamException): ...
Loading

0 comments on commit 27669df

Please sign in to comment.