-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add dependency injection framework
- Loading branch information
Showing
24 changed files
with
969 additions
and
202 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import inspect | ||
from typing import Any, AsyncIterator, Awaitable, Protocol, TypeVar, Union | ||
|
||
from di.api.dependencies import CacheKey | ||
from di.dependent import Dependent, Marker | ||
|
||
from kstreams.types import ConsumerRecord | ||
|
||
|
||
class ExtractorTrait(Protocol): | ||
"""Implement to extract data from incoming `ConsumerRecord`. | ||
Consumers will always work with a consumer Record. | ||
Implementing this would let you extract information from the `ConsumerRecord`. | ||
""" | ||
|
||
def __hash__(self) -> int: | ||
"""Required by di in order to cache the deps""" | ||
... | ||
|
||
def __eq__(self, __o: object) -> bool: | ||
"""Required by di in order to cache the deps""" | ||
... | ||
|
||
async def extract( | ||
self, consumer_record: ConsumerRecord | ||
) -> Union[Awaitable[Any], AsyncIterator[Any]]: | ||
"""This is where the magic should happen. | ||
For example, you could "extract" here a json from the `ConsumerRecord.value` | ||
""" | ||
... | ||
|
||
|
||
T = TypeVar("T", covariant=True) | ||
|
||
|
||
class MarkerTrait(Protocol[T]): | ||
def register_parameter(self, param: inspect.Parameter) -> T: ... | ||
|
||
|
||
class Binder(Dependent[Any]): | ||
def __init__( | ||
self, | ||
*, | ||
extractor: ExtractorTrait, | ||
) -> None: | ||
super().__init__(call=extractor.extract, scope="consumer_record") | ||
self.extractor = extractor | ||
|
||
@property | ||
def cache_key(self) -> CacheKey: | ||
return self.extractor | ||
|
||
|
||
class BinderMarker(Marker): | ||
"""Bind together the different dependencies. | ||
NETX: Add asyncapi marker here, like `MarkerTrait[AsyncApiTrait]`. | ||
Recommendation to wait until 3.0: | ||
- [#618](https://github.com/asyncapi/spec/issues/618) | ||
""" | ||
|
||
def __init__(self, *, extractor_marker: MarkerTrait[ExtractorTrait]) -> None: | ||
self.extractor_marker = extractor_marker | ||
|
||
def register_parameter(self, param: inspect.Parameter) -> Binder: | ||
return Binder(extractor=self.extractor_marker.register_parameter(param)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import inspect | ||
from typing import Any, NamedTuple, Optional | ||
|
||
from kstreams.exceptions import HeaderNotFound | ||
from kstreams.types import ConsumerRecord | ||
|
||
|
||
class HeaderExtractor(NamedTuple): | ||
name: str | ||
|
||
def __hash__(self) -> int: | ||
return hash((self.__class__, self.name)) | ||
|
||
def __eq__(self, __o: object) -> bool: | ||
return isinstance(__o, HeaderExtractor) and __o.name == self.name | ||
|
||
async def extract(self, consumer_record: ConsumerRecord) -> Any: | ||
headers = dict(consumer_record.headers) | ||
try: | ||
header = headers[self.name] | ||
except KeyError as e: | ||
message = ( | ||
f"No header `{self.name}` found.\n" | ||
"Check if your broker is sending the header.\n" | ||
"Try adding a default value to your parameter like `None`.\n" | ||
"Or set `convert_underscores = False`." | ||
) | ||
raise HeaderNotFound(message) from e | ||
else: | ||
return header | ||
|
||
|
||
class HeaderMarker(NamedTuple): | ||
alias: Optional[str] | ||
convert_underscores: bool | ||
|
||
def register_parameter(self, param: inspect.Parameter) -> HeaderExtractor: | ||
if self.alias is not None: | ||
name = self.alias | ||
elif self.convert_underscores: | ||
name = param.name.replace("_", "-") | ||
else: | ||
name = param.name | ||
return HeaderExtractor(name=name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
from typing import Any, Callable, Optional | ||
|
||
from di import Container, bind_by_type | ||
from di.dependent import Dependent | ||
from di.executors import AsyncExecutor | ||
|
||
from kstreams._di.dependencies.hooks import bind_by_generic | ||
from kstreams.streams import Stream | ||
from kstreams.types import ConsumerRecord, Send | ||
|
||
LayerFn = Callable[..., Any] | ||
|
||
|
||
class StreamDependencyManager: | ||
"""Core of dependency injection on kstreams. | ||
This is an internal class of kstreams that manages the dependency injection, | ||
as a user you should not use this class directly. | ||
Attributes: | ||
container: dependency store. | ||
stream: the stream wrapping the user function. Optional to improve testability. | ||
When instanciating this class you must provide a stream, otherwise users | ||
won't be able to use the `stream` parameter in their functions. | ||
send: send object. Optional to improve testability, same as stream. | ||
Usage: | ||
stream and send are ommited for simplicity | ||
```python | ||
def user_func(cr: ConsumerRecord): | ||
... | ||
sdm = StreamDependencyManager() | ||
sdm.solve(user_func) | ||
sdm.execute(consumer_record) | ||
``` | ||
""" | ||
|
||
container: Container | ||
|
||
def __init__( | ||
self, | ||
container: Optional[Container] = None, | ||
stream: Optional[Stream] = None, | ||
send: Optional[Send] = None, | ||
): | ||
self.container = container or Container() | ||
self.async_executor = AsyncExecutor() | ||
self.stream = stream | ||
self.send = send | ||
|
||
def solve_user_fn(self, fn: LayerFn) -> None: | ||
"""Build the dependency graph for the given function. | ||
Objects must be injected before this function is called. | ||
Attributes: | ||
fn: user defined function, using allowed kstreams params | ||
""" | ||
self._register_consumer_record() | ||
|
||
if isinstance(self.stream, Stream): | ||
self._register_stream(self.stream) | ||
|
||
if self.send is not None: | ||
self._register_send(self.send) | ||
|
||
self.solved_user_fn = self.container.solve( | ||
Dependent(fn, scope="consumer_record"), | ||
scopes=["consumer_record", "stream", "application"], | ||
) | ||
|
||
async def execute(self, consumer_record: ConsumerRecord) -> Any: | ||
"""Execute the dependencies graph with external values. | ||
Attributes: | ||
consumer_record: A kafka record containing `values`, `headers`, etc. | ||
""" | ||
async with self.container.enter_scope("consumer_record") as state: | ||
return await self.container.execute_async( | ||
self.solved_user_fn, | ||
values={ConsumerRecord: consumer_record}, | ||
executor=self.async_executor, | ||
state=state, | ||
) | ||
|
||
def _register_stream(self, stream: Stream): | ||
"""Register the stream with the container.""" | ||
hook = bind_by_type( | ||
Dependent(lambda: stream, scope="consumer_record", wire=False), Stream | ||
) | ||
self.container.bind(hook) | ||
|
||
def _register_consumer_record(self): | ||
"""Register consumer record with the container. | ||
We bind_by_generic because we want to bind the `ConsumerRecord` type which | ||
is generic. | ||
The value must be injected at runtime. | ||
""" | ||
hook = bind_by_generic( | ||
Dependent(ConsumerRecord, scope="consumer_record", wire=False), | ||
ConsumerRecord, | ||
) | ||
self.container.bind(hook) | ||
|
||
def _register_send(self, send: Send): | ||
hook = bind_by_type( | ||
Dependent(lambda: send, scope="consumer_record", wire=False), Send | ||
) | ||
self.container.bind(hook) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import inspect | ||
from typing import Any, get_origin | ||
|
||
from di._container import BindHook | ||
from di._utils.inspect import get_type | ||
from di.api.dependencies import DependentBase | ||
|
||
|
||
def bind_by_generic( | ||
provider: DependentBase[Any], | ||
dependency: type, | ||
) -> BindHook: | ||
"""Hook to substitute the matched dependency based on its generic.""" | ||
|
||
def hook( | ||
param: inspect.Parameter | None, dependent: DependentBase[Any] | ||
) -> DependentBase[Any] | None: | ||
if dependent.call == dependency: | ||
return provider | ||
if param is None: | ||
return None | ||
|
||
type_annotation_option = get_type(param) | ||
if type_annotation_option is None: | ||
return None | ||
type_annotation = type_annotation_option.value | ||
if get_origin(type_annotation) is dependency: | ||
return provider | ||
return None | ||
|
||
return hook |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from typing import Optional, TypeVar | ||
|
||
from kstreams._di.binders.api import BinderMarker | ||
from kstreams._di.binders.header import HeaderMarker | ||
from kstreams.typing import Annotated | ||
|
||
|
||
def Header( | ||
*, alias: Optional[str] = None, convert_underscores: bool = True | ||
) -> BinderMarker: | ||
"""Construct another type from the headers of a kafka record. | ||
Args: | ||
alias: Use a different header name | ||
convert_underscores: If True, convert underscores to dashes. | ||
Usage: | ||
```python | ||
from kstream import Header, Annotated | ||
def user_fn(event_type: Annotated[str, Header(alias="EventType")]): | ||
... | ||
``` | ||
""" | ||
header_marker = HeaderMarker(alias=alias, convert_underscores=convert_underscores) | ||
binder = BinderMarker(extractor_marker=header_marker) | ||
return binder | ||
|
||
|
||
T = TypeVar("T") | ||
|
||
FromHeader = Annotated[T, Header()] | ||
FromHeader.__doc__ = """General purpose convenient header type. | ||
Use `Annotated` to provide custom params. | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,3 +25,6 @@ def __str__(self) -> str: | |
|
||
|
||
class BackendNotSet(StreamException): ... | ||
|
||
|
||
class HeaderNotFound(StreamException): ... |
Oops, something went wrong.