-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add dependency injection framework
- Loading branch information
Showing
20 changed files
with
621 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import inspect | ||
from typing import Any, AsyncIterator, Awaitable, Protocol, TypeVar, Union | ||
|
||
from di.api.dependencies import CacheKey | ||
from di.dependent import Dependent, Marker | ||
|
||
from kstreams.types import ConsumerRecord | ||
|
||
|
||
class ExtractorTrait(Protocol): | ||
"""Implement to extract data from incoming `ConsumerRecord`. | ||
Consumers will always work with a consumer Record. | ||
Implementing this would let you extract information from the `ConsumerRecord`. | ||
""" | ||
|
||
def __hash__(self) -> int: | ||
"""Required by di in order to cache the deps""" | ||
... | ||
|
||
def __eq__(self, __o: object) -> bool: | ||
"""Required by di in order to cache the deps""" | ||
... | ||
|
||
async def extract( | ||
self, consumer_record: ConsumerRecord | ||
) -> Union[Awaitable[Any], AsyncIterator[Any]]: | ||
"""This is where the magic should happen. | ||
For example, you could "extract" here a json from the `ConsumerRecord.value` | ||
""" | ||
... | ||
|
||
|
||
T = TypeVar("T", covariant=True) | ||
|
||
|
||
class MarkerTrait(Protocol[T]): | ||
def register_parameter(self, param: inspect.Parameter) -> T: ... | ||
|
||
|
||
class Binder(Dependent[Any]): | ||
def __init__( | ||
self, | ||
*, | ||
extractor: ExtractorTrait, | ||
) -> None: | ||
super().__init__(call=extractor.extract, scope="consumer_record") | ||
self.extractor = extractor | ||
|
||
@property | ||
def cache_key(self) -> CacheKey: | ||
return self.extractor | ||
|
||
|
||
class BinderMarker(Marker): | ||
"""Bind together the different dependencies. | ||
NETX: Add asyncapi marker here, like `MarkerTrait[AsyncApiTrait]`. | ||
Recommendation to wait until 3.0: | ||
- [#618](https://github.com/asyncapi/spec/issues/618) | ||
""" | ||
|
||
def __init__(self, *, extractor_marker: MarkerTrait[ExtractorTrait]) -> None: | ||
self.extractor_marker = extractor_marker | ||
|
||
def register_parameter(self, param: inspect.Parameter) -> Binder: | ||
return Binder(extractor=self.extractor_marker.register_parameter(param)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import inspect | ||
from typing import Any, NamedTuple, Optional | ||
|
||
from kstreams.exceptions import HeaderNotFound | ||
from kstreams.types import ConsumerRecord | ||
|
||
|
||
class HeaderExtractor(NamedTuple): | ||
name: str | ||
|
||
def __hash__(self) -> int: | ||
return hash((self.__class__, self.name)) | ||
|
||
def __eq__(self, __o: object) -> bool: | ||
return isinstance(__o, HeaderExtractor) and __o.name == self.name | ||
|
||
async def extract(self, consumer_record: ConsumerRecord) -> Any: | ||
headers = dict(consumer_record.headers) | ||
try: | ||
header = headers[self.name] | ||
except KeyError as e: | ||
message = ( | ||
f"No header `{self.name}` found.\n" | ||
"Check if your broker is sending the header.\n" | ||
"Try adding a default value to your parameter like `None`.\n" | ||
"Or set `convert_underscores = False`." | ||
) | ||
raise HeaderNotFound(message) from e | ||
else: | ||
return header | ||
|
||
|
||
class HeaderMarker(NamedTuple): | ||
alias: Optional[str] | ||
convert_underscores: bool | ||
|
||
def register_parameter(self, param: inspect.Parameter) -> HeaderExtractor: | ||
if self.alias is not None: | ||
name = self.alias | ||
elif self.convert_underscores: | ||
name = param.name.replace("_", "-") | ||
else: | ||
name = param.name | ||
return HeaderExtractor(name=name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
from typing import Any, Callable, Optional | ||
|
||
from di import Container, bind_by_type | ||
from di.dependent import Dependent | ||
from di.executors import AsyncExecutor | ||
|
||
from kstreams._di.dependencies.hooks import bind_by_generic | ||
from kstreams.streams import Stream | ||
from kstreams.types import ConsumerRecord, Send | ||
|
||
LayerFn = Callable[..., Any] | ||
|
||
|
||
class StreamDependencyManager: | ||
"""Core of dependency injection on kstreams. | ||
Attributes: | ||
container: dependency store. | ||
Usage | ||
```python | ||
consumer_record = ConsumerRecord(...) | ||
def user_func(event_type: FromHeader[str]): | ||
... | ||
sdm = StreamDependencyManager() | ||
sdm.solve(user_func) | ||
sdm.execute(consumer_record) | ||
``` | ||
""" | ||
|
||
container: Container | ||
|
||
def __init__(self, container: Optional[Container] = None): | ||
self.container = container or Container() | ||
self.async_executor = AsyncExecutor() | ||
|
||
def _register_framework_deps(self): | ||
"""Register with the container types that belong to kstream. | ||
These types can be injectable things available on startup. | ||
But also they can be things created on a per connection basis. | ||
They must be available at the time of executing the users' function. | ||
For example: | ||
- App | ||
- ConsumerRecord | ||
- Consumer | ||
- KafkaBackend | ||
Here we are just "registering", that's why we use `bind_by_type`. | ||
And eventually, when "executing", we should inject all the values we | ||
want to be "available". | ||
""" | ||
hook = bind_by_generic( | ||
Dependent(ConsumerRecord, scope="consumer_record", wire=False), | ||
ConsumerRecord, | ||
) | ||
self.container.bind(hook) | ||
# NEXT: Add Consumer as a dependency, so it can be injected. | ||
|
||
def solve_user_fn(self, fn: LayerFn) -> None: | ||
"""Build the dependency graph for the given function. | ||
Attributes: | ||
fn: user defined function, using allowed kstreams params | ||
""" | ||
self._register_framework_deps() | ||
|
||
self.solved_user_fn = self.container.solve( | ||
Dependent(fn, scope="consumer_record"), | ||
scopes=["consumer_record"], | ||
) | ||
|
||
def register_stream(self, stream: Stream): | ||
hook = bind_by_type(Dependent(stream, scope="consumer_record"), Stream) | ||
self.container.bind(hook) | ||
|
||
def register_send(self, send: Send): ... | ||
|
||
async def execute(self, consumer_record: ConsumerRecord) -> Any: | ||
"""Execute the dependencies graph with external values. | ||
Attributes: | ||
consumer_record: A kafka record containing `values`, `headers`, etc. | ||
""" | ||
async with self.container.enter_scope("consumer_record") as state: | ||
return await self.container.execute_async( | ||
self.solved_user_fn, | ||
values={ConsumerRecord: consumer_record}, | ||
executor=self.async_executor, | ||
state=state, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import inspect | ||
from typing import Any, get_origin | ||
|
||
from di._container import BindHook | ||
from di._utils.inspect import get_type | ||
from di.api.dependencies import DependentBase | ||
|
||
|
||
def bind_by_generic( | ||
provider: DependentBase[Any], | ||
dependency: type, | ||
) -> BindHook: | ||
"""Hook to substitute the matched dependency based on its generic.""" | ||
|
||
def hook( | ||
param: inspect.Parameter | None, dependent: DependentBase[Any] | ||
) -> DependentBase[Any] | None: | ||
if dependent.call == dependency: | ||
return provider | ||
if param is None: | ||
return None | ||
|
||
type_annotation_option = get_type(param) | ||
if type_annotation_option is None: | ||
return None | ||
type_annotation = type_annotation_option.value | ||
if get_origin(type_annotation) is dependency: | ||
return provider | ||
return None | ||
|
||
return hook |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from typing import Optional, TypeVar | ||
|
||
from kstreams._di.binders.api import BinderMarker | ||
from kstreams._di.binders.header import HeaderMarker | ||
from kstreams.typing import Annotated | ||
|
||
|
||
def Header( | ||
*, alias: Optional[str] = None, convert_underscores: bool = True | ||
) -> BinderMarker: | ||
"""Construct another type from the headers of a kafka record. | ||
Args: | ||
alias: Use a different header name | ||
convert_underscores: If True, convert underscores to dashes. | ||
Usage: | ||
```python | ||
from kstream import Header, Annotated | ||
def user_fn(event_type: Annotated[str, Header(alias="EventType")]): | ||
... | ||
``` | ||
""" | ||
header_marker = HeaderMarker(alias=alias, convert_underscores=convert_underscores) | ||
binder = BinderMarker(extractor_marker=header_marker) | ||
return binder | ||
|
||
|
||
T = TypeVar("T") | ||
|
||
FromHeader = Annotated[T, Header()] | ||
FromHeader.__doc__ = """General purpose convenient header type. | ||
Use `Annotated` to provide custom params. | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,3 +25,6 @@ def __str__(self) -> str: | |
|
||
|
||
class BackendNotSet(StreamException): ... | ||
|
||
|
||
class HeaderNotFound(StreamException): ... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import typing | ||
|
||
from kstreams import types | ||
from kstreams._di.dependencies.core import StreamDependencyManager | ||
|
||
from .middleware import BaseMiddleware | ||
|
||
|
||
class DependencyInjectionHandler(BaseMiddleware): | ||
def __init__(self, *args, **kwargs) -> None: | ||
super().__init__(*args, **kwargs) | ||
self.dependecy_manager = StreamDependencyManager() | ||
self.dependecy_manager.register_stream(self.stream) | ||
self.dependecy_manager.register_send(self.send) | ||
self.dependecy_manager.solve_user_fn(fn=self.next_call) | ||
self.type = None | ||
|
||
async def __call__(self, cr: types.ConsumerRecord) -> typing.Any: | ||
return await self.dependecy_manager.execute(cr) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.