feat: add dependency injection framework
woile committed Oct 31, 2024
1 parent 5e85da3 commit 2990d73
Showing 20 changed files with 621 additions and 9 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -75,6 +75,7 @@ if __name__ == "__main__":
- [ ] Store (kafka streams pattern)
- [ ] Stream Join
- [ ] Windowing
- [ ] PEP 593

## Development

6 changes: 6 additions & 0 deletions kstreams/__init__.py
@@ -1,5 +1,7 @@
from aiokafka.structs import RecordMetadata, TopicPartition

from ._di.parameters import FromHeader, Header
from .backends.kafka import Kafka
from .clients import Consumer, Producer
from .create import StreamEngine, create_engine
from .prometheus.monitor import PrometheusMonitor, PrometheusMonitorType
@@ -31,4 +33,8 @@
"TestStreamClient",
"TopicPartition",
"TopicPartitionOffset",
"Kafka",
"StreamDependencyManager",
"FromHeader",
"Header",
]
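With `FromHeader` and `Header` exported from the package root, a stream can declare header parameters directly. A minimal sketch follows; the engine title, topic names, and handler bodies are placeholders, and `Annotated` is imported from `typing` here since only `FromHeader` and `Header` are added to the package exports in this diff:

```python
from typing import Annotated

from kstreams import FromHeader, Header, create_engine

stream_engine = create_engine(title="my-stream-engine")


# `FromHeader[str]` injects the `event-type` header (underscores become dashes).
@stream_engine.stream("local--hello-world")
async def on_event(event_type: FromHeader[str]):
    print(f"event type: {event_type}")


# An explicit alias keeps the exact header name.
@stream_engine.stream("local--hello-world-2")
async def on_aliased_event(event_type: Annotated[str, Header(alias="EventType")]):
    print(f"event type: {event_type}")
```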
68 changes: 68 additions & 0 deletions kstreams/_di/binders/api.py
@@ -0,0 +1,68 @@
import inspect
from typing import Any, AsyncIterator, Awaitable, Protocol, TypeVar, Union

from di.api.dependencies import CacheKey
from di.dependent import Dependent, Marker

from kstreams.types import ConsumerRecord


class ExtractorTrait(Protocol):
"""Implement this protocol to extract data from an incoming `ConsumerRecord`.
Consumers always receive a `ConsumerRecord`; an extractor pulls the piece of
information the user's function needs out of it.
"""

def __hash__(self) -> int:
"""Required by di in order to cache the deps"""
...

def __eq__(self, __o: object) -> bool:
"""Required by di in order to cache the deps"""
...

async def extract(
self, consumer_record: ConsumerRecord
) -> Union[Awaitable[Any], AsyncIterator[Any]]:
"""This is where the extraction happens.
For example, you could extract a JSON payload from `ConsumerRecord.value`.
"""
...


T = TypeVar("T", covariant=True)


class MarkerTrait(Protocol[T]):
def register_parameter(self, param: inspect.Parameter) -> T: ...


class Binder(Dependent[Any]):
def __init__(
self,
*,
extractor: ExtractorTrait,
) -> None:
super().__init__(call=extractor.extract, scope="consumer_record")
self.extractor = extractor

@property
def cache_key(self) -> CacheKey:
return self.extractor


class BinderMarker(Marker):
"""Bind together the different dependencies.
NEXT: Add an asyncapi marker here, like `MarkerTrait[AsyncApiTrait]`.
Recommendation: wait until 3.0:
- [#618](https://github.com/asyncapi/spec/issues/618)
"""

def __init__(self, *, extractor_marker: MarkerTrait[ExtractorTrait]) -> None:
self.extractor_marker = extractor_marker

def register_parameter(self, param: inspect.Parameter) -> Binder:
return Binder(extractor=self.extractor_marker.register_parameter(param))
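For orientation, a minimal custom extractor that satisfies `ExtractorTrait` could look like the sketch below; the `JsonValueExtractor` name and the JSON decoding are assumptions for illustration, not part of this commit:

```python
import json
from typing import Any

from kstreams.types import ConsumerRecord


class JsonValueExtractor:
    """Hypothetical extractor: decode `ConsumerRecord.value` as JSON."""

    def __hash__(self) -> int:
        # A stable hash so `di` can cache the dependency.
        return hash(self.__class__)

    def __eq__(self, __o: object) -> bool:
        return isinstance(__o, JsonValueExtractor)

    async def extract(self, consumer_record: ConsumerRecord) -> Any:
        return json.loads(consumer_record.value)
```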
44 changes: 44 additions & 0 deletions kstreams/_di/binders/header.py
@@ -0,0 +1,44 @@
import inspect
from typing import Any, NamedTuple, Optional

from kstreams.exceptions import HeaderNotFound
from kstreams.types import ConsumerRecord


class HeaderExtractor(NamedTuple):
name: str

def __hash__(self) -> int:
return hash((self.__class__, self.name))

def __eq__(self, __o: object) -> bool:
return isinstance(__o, HeaderExtractor) and __o.name == self.name

async def extract(self, consumer_record: ConsumerRecord) -> Any:
headers = dict(consumer_record.headers)
try:
header = headers[self.name]
except KeyError as e:
message = (
f"No header `{self.name}` found.\n"
"Check if your broker is sending the header.\n"
"Try adding a default value to your parameter like `None`.\n"
"Or set `convert_underscores = False`."
)
raise HeaderNotFound(message) from e
else:
return header


class HeaderMarker(NamedTuple):
alias: Optional[str]
convert_underscores: bool

def register_parameter(self, param: inspect.Parameter) -> HeaderExtractor:
if self.alias is not None:
name = self.alias
elif self.convert_underscores:
name = param.name.replace("_", "-")
else:
name = param.name
return HeaderExtractor(name=name)
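A quick sketch of how `HeaderMarker.register_parameter` resolves the header name, exercising only the code above:

```python
import inspect

from kstreams._di.binders.header import HeaderMarker


def user_fn(event_type: str): ...


param = inspect.signature(user_fn).parameters["event_type"]

# Default behaviour: underscores become dashes, so `event_type` -> `event-type`.
extractor = HeaderMarker(alias=None, convert_underscores=True).register_parameter(param)
assert extractor.name == "event-type"

# An explicit alias always wins.
extractor = HeaderMarker(alias="EventType", convert_underscores=True).register_parameter(param)
assert extractor.name == "EventType"

# Disabling conversion uses the parameter name verbatim.
extractor = HeaderMarker(alias=None, convert_underscores=False).register_parameter(param)
assert extractor.name == "event_type"
```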
95 changes: 95 additions & 0 deletions kstreams/_di/dependencies/core.py
@@ -0,0 +1,95 @@
from typing import Any, Callable, Optional

from di import Container, bind_by_type
from di.dependent import Dependent
from di.executors import AsyncExecutor

from kstreams._di.dependencies.hooks import bind_by_generic
from kstreams.streams import Stream
from kstreams.types import ConsumerRecord, Send

LayerFn = Callable[..., Any]


class StreamDependencyManager:
"""Core of dependency injection on kstreams.
Attributes:
container: dependency store.
Usage:
```python
consumer_record = ConsumerRecord(...)
def user_func(event_type: FromHeader[str]):
...
sdm = StreamDependencyManager()
sdm.solve_user_fn(user_func)
await sdm.execute(consumer_record)
```
"""

container: Container

def __init__(self, container: Optional[Container] = None):
self.container = container or Container()
self.async_executor = AsyncExecutor()

def _register_framework_deps(self):
"""Register the types that belong to kstreams with the container.
These can be injectables available on startup as well as objects created
on a per-connection basis; either way they must be available when the
user's function is executed.
For example:
- App
- ConsumerRecord
- Consumer
- KafkaBackend
Here we are just "registering", which is why we use `bind_by_type`.
Eventually, when "executing", we inject the actual values that should be
"available".
"""
hook = bind_by_generic(
Dependent(ConsumerRecord, scope="consumer_record", wire=False),
ConsumerRecord,
)
self.container.bind(hook)
# NEXT: Add Consumer as a dependency, so it can be injected.

def solve_user_fn(self, fn: LayerFn) -> None:
"""Build the dependency graph for the given function.
Args:
fn: user-defined function using allowed kstreams params.
"""
self._register_framework_deps()

self.solved_user_fn = self.container.solve(
Dependent(fn, scope="consumer_record"),
scopes=["consumer_record"],
)

def register_stream(self, stream: Stream):
hook = bind_by_type(Dependent(stream, scope="consumer_record"), Stream)
self.container.bind(hook)

def register_send(self, send: Send): ...

async def execute(self, consumer_record: ConsumerRecord) -> Any:
"""Execute the dependency graph with external values.
Args:
consumer_record: A Kafka record containing `value`, `headers`, etc.
"""
async with self.container.enter_scope("consumer_record") as state:
return await self.container.execute_async(
self.solved_user_fn,
values={ConsumerRecord: consumer_record},
executor=self.async_executor,
state=state,
)
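End to end, the manager can be exercised outside a running engine roughly as follows. This is a sketch: the record fields follow aiokafka's `ConsumerRecord` layout and all values are made up.

```python
import asyncio

from kstreams import FromHeader
from kstreams._di.dependencies.core import StreamDependencyManager
from kstreams.types import ConsumerRecord


def user_fn(event_type: FromHeader[str]):
    return f"handling {event_type!r}"


async def main() -> None:
    # Hypothetical record; the header key is `event-type` because the default
    # Header() marker converts underscores in the parameter name to dashes.
    cr = ConsumerRecord(
        topic="local--hello",
        partition=0,
        offset=1,
        timestamp=0,
        timestamp_type=0,
        key=None,
        value=b"{}",
        checksum=None,
        serialized_key_size=-1,
        serialized_value_size=2,
        headers=(("event-type", b"created"),),
    )
    sdm = StreamDependencyManager()
    sdm.solve_user_fn(user_fn)
    print(await sdm.execute(cr))


asyncio.run(main())
```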
31 changes: 31 additions & 0 deletions kstreams/_di/dependencies/hooks.py
@@ -0,0 +1,31 @@
import inspect
from typing import Any, get_origin

from di._container import BindHook
from di._utils.inspect import get_type
from di.api.dependencies import DependentBase


def bind_by_generic(
provider: DependentBase[Any],
dependency: type,
) -> BindHook:
"""Hook to substitute the matched dependency based on its generic."""

def hook(
param: inspect.Parameter | None, dependent: DependentBase[Any]
) -> DependentBase[Any] | None:
if dependent.call == dependency:
return provider
if param is None:
return None

type_annotation_option = get_type(param)
if type_annotation_option is None:
return None
type_annotation = type_annotation_option.value
if get_origin(type_annotation) is dependency:
return provider
return None

return hook
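The hook matches a parameter by the *origin* of its generic annotation; a small standalone illustration of that check (assuming `ConsumerRecord` is generic in the installed aiokafka version):

```python
import inspect
from typing import get_origin

from kstreams.types import ConsumerRecord


def user_fn(cr: ConsumerRecord[bytes, bytes]) -> bytes:
    return cr.value


param = inspect.signature(user_fn).parameters["cr"]

# This is the comparison `bind_by_generic` relies on: the annotation
# `ConsumerRecord[bytes, bytes]` has `ConsumerRecord` as its origin, so the
# bound provider is substituted for this parameter.
assert get_origin(param.annotation) is ConsumerRecord
```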
37 changes: 37 additions & 0 deletions kstreams/_di/parameters.py
@@ -0,0 +1,37 @@
from typing import Optional, TypeVar

from kstreams._di.binders.api import BinderMarker
from kstreams._di.binders.header import HeaderMarker
from kstreams.typing import Annotated


def Header(
*, alias: Optional[str] = None, convert_underscores: bool = True
) -> BinderMarker:
"""Construct a value from a header of a Kafka record.
Args:
alias: Use a different header name.
convert_underscores: If True, underscores in the parameter name are
replaced with dashes when looking up the header.
Usage:
```python
from kstreams import Header, Annotated
def user_fn(event_type: Annotated[str, Header(alias="EventType")]):
...
```
"""
header_marker = HeaderMarker(alias=alias, convert_underscores=convert_underscores)
binder = BinderMarker(extractor_marker=header_marker)
return binder


T = TypeVar("T")

FromHeader = Annotated[T, Header()]
FromHeader.__doc__ = """General-purpose convenience header type.
Use `Annotated` to provide custom params.
"""
13 changes: 10 additions & 3 deletions kstreams/engine.py
@@ -5,6 +5,7 @@

from aiokafka.structs import RecordMetadata

from kstreams.middleware.di_middleware import DependencyInjectionHandler
from kstreams.structs import TopicPartitionOffset

from .backends.kafka import Kafka
@@ -389,16 +390,22 @@ def add_stream(
stream.rebalance_listener.stream = stream
stream.rebalance_listener.engine = self

stream.udf_handler = UdfHandler(
# stream.udf_handler = UdfHandler(
# next_call=stream.func,
# send=self.send,
# stream=stream,
# )

stream.udf_handler = DependencyInjectionHandler(
next_call=stream.func,
send=self.send,
stream=stream,
)

# NOTE: When `no typing` support is deprecated this check can
# be removed
if stream.udf_handler.type != UDFType.NO_TYPING:
stream.func = self._build_stream_middleware_stack(stream=stream)
# if stream.udf_handler.type != UDFType.NO_TYPING:
stream.func = self._build_stream_middleware_stack(stream=stream)

def _build_stream_middleware_stack(self, *, stream: Stream) -> NextMiddlewareCall:
assert stream.udf_handler, "UdfHandler can not be None"
3 changes: 3 additions & 0 deletions kstreams/exceptions.py
@@ -25,3 +25,6 @@ def __str__(self) -> str:


class BackendNotSet(StreamException): ...


class HeaderNotFound(StreamException): ...
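Because the new exception derives from `StreamException`, existing error handling that catches the base class also covers a missing header; a one-line check:

```python
from kstreams.exceptions import HeaderNotFound, StreamException

# Illustrative: broad stream-level error handling keeps working for missing headers.
assert issubclass(HeaderNotFound, StreamException)
```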
19 changes: 19 additions & 0 deletions kstreams/middleware/di_middleware.py
@@ -0,0 +1,19 @@
import typing

from kstreams import types
from kstreams._di.dependencies.core import StreamDependencyManager

from .middleware import BaseMiddleware


class DependencyInjectionHandler(BaseMiddleware):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.dependency_manager = StreamDependencyManager()
self.dependency_manager.register_stream(self.stream)
self.dependency_manager.register_send(self.send)
self.dependency_manager.solve_user_fn(fn=self.next_call)
self.type = None

async def __call__(self, cr: types.ConsumerRecord) -> typing.Any:
return await self.dependency_manager.execute(cr)
8 changes: 6 additions & 2 deletions kstreams/streams.py
@@ -15,7 +15,7 @@

from .backends.kafka import Kafka
from .clients import Consumer
from .middleware import Middleware, udf_middleware
from .middleware import Middleware
from .rebalance_listener import RebalanceListener
from .serializers import Deserializer
from .streams_utils import StreamErrorPolicy, UDFType
@@ -172,7 +172,7 @@ def __init__(
self.seeked_initial_offsets = False
self.rebalance_listener = rebalance_listener
self.middlewares = middlewares or []
self.udf_handler: typing.Optional[udf_middleware.UdfHandler] = None
self.udf_handler = None
self.topics = [topics] if isinstance(topics, str) else topics
self.subscribe_by_pattern = subscribe_by_pattern
self.error_policy = error_policy
@@ -356,6 +356,10 @@ async def start(self) -> None:
await func
else:
# Typing cases

# If it's an async generator, DON'T await the function: we only want
# to start it and let the user retrieve the values while iterating
# the stream
if not inspect.isasyncgenfunction(self.udf_handler.next_call):
# It is not an async generator, so `await` the func
await self.func_wrapper_with_typing()
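The branch above hinges on `inspect.isasyncgenfunction`; a small standalone illustration of the two kinds of handlers (the names are made up):

```python
import inspect


async def coroutine_handler(cr):
    # A plain coroutine function: the stream awaits it once per record.
    return cr


async def generator_handler(cr):
    # An async generator function: the stream only starts it and the caller
    # pulls values while iterating.
    yield cr


assert not inspect.isasyncgenfunction(coroutine_handler)
assert inspect.isasyncgenfunction(generator_handler)
```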
