-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add dependency injection framework
- Loading branch information
Showing
14 changed files
with
458 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import inspect | ||
from typing import Any, AsyncIterator, Awaitable, Protocol, TypeVar, Union | ||
|
||
from di.api.dependencies import CacheKey | ||
from di.dependant import Dependant, Marker | ||
|
||
from kstreams.types import ConsumerRecord | ||
|
||
|
||
class ExtractorTrait(Protocol): | ||
"""Implement to extract data from incoming `ConsumerRecord`. | ||
Consumers will always work with a consumer Record. | ||
Implementing this would let you extract information from the `ConsumerRecord`. | ||
""" | ||
|
||
def __hash__(self) -> int: | ||
"""Required by di in order to cache the deps""" | ||
... | ||
|
||
def __eq__(self, __o: object) -> bool: | ||
"""Required by di in order to cache the deps""" | ||
... | ||
|
||
async def extract( | ||
self, consumer_record: ConsumerRecord | ||
) -> Union[Awaitable[Any], AsyncIterator[Any]]: | ||
"""This is where the magic should happen. | ||
For example, you could "extract" here a json from the `ConsumerRecord.value` | ||
""" | ||
... | ||
|
||
|
||
T = TypeVar("T", covariant=True) | ||
|
||
|
||
class MarkerTrait(Protocol[T]): | ||
def register_parameter(self, param: inspect.Parameter) -> T: | ||
... | ||
|
||
|
||
class Binder(Dependant[Any]): | ||
def __init__( | ||
self, | ||
*, | ||
extractor: ExtractorTrait, | ||
) -> None: | ||
super().__init__(call=extractor.extract, scope="consumer_record") | ||
self.extractor = extractor | ||
|
||
@property | ||
def cache_key(self) -> CacheKey: | ||
return self.extractor | ||
|
||
|
||
class BinderMarker(Marker): | ||
"""Bind together the different dependencies. | ||
NETX: Add asyncapi marker here, like `MarkerTrait[AsyncApiTrait]`. | ||
Recommendation to wait until 3.0: | ||
- [#618](https://github.com/asyncapi/spec/issues/618) | ||
""" | ||
|
||
def __init__(self, *, extractor_marker: MarkerTrait[ExtractorTrait]) -> None: | ||
self.extractor_marker = extractor_marker | ||
|
||
def register_parameter(self, param: inspect.Parameter) -> Binder: | ||
return Binder(extractor=self.extractor_marker.register_parameter(param)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import inspect | ||
from typing import Any, NamedTuple, Optional | ||
|
||
from kstreams.exceptions import HeaderNotFound | ||
from kstreams.types import ConsumerRecord | ||
|
||
|
||
class HeaderExtractor(NamedTuple): | ||
name: str | ||
|
||
def __hash__(self) -> int: | ||
return hash((self.__class__, self.name)) | ||
|
||
def __eq__(self, __o: object) -> bool: | ||
return isinstance(__o, HeaderExtractor) and __o.name == self.name | ||
|
||
async def extract(self, consumer_record: ConsumerRecord) -> Any: | ||
if isinstance(consumer_record.headers, dict): | ||
headers = tuple(consumer_record.headers.items()) | ||
# NEXT: Check also if it is a sequence, if not it means | ||
# someone modified the headers in the serializer in a way | ||
# that we cannot extract it, raise a readable error | ||
else: | ||
headers = consumer_record.headers | ||
found_headers = [value for header, value in headers if header == self.name] | ||
|
||
try: | ||
header = found_headers.pop() | ||
except IndexError as e: | ||
message = ( | ||
f"No header `{self.name}` found.\n" | ||
"Check if your broker is sending the header.\n" | ||
"Try adding a default value to your parameter like `None`.\n" | ||
"Or set `convert_underscores = False`." | ||
) | ||
raise HeaderNotFound(message) from e | ||
else: | ||
return header | ||
|
||
|
||
class HeaderMarker(NamedTuple): | ||
alias: Optional[str] | ||
convert_underscores: bool | ||
|
||
def register_parameter(self, param: inspect.Parameter) -> HeaderExtractor: | ||
if self.alias is not None: | ||
name = self.alias | ||
elif self.convert_underscores: | ||
name = param.name.replace("_", "-") | ||
else: | ||
name = param.name | ||
return HeaderExtractor(name=name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
from typing import Any, Callable, Optional | ||
|
||
from di.container import Container, bind_by_type | ||
from di.dependant import Dependant | ||
from di.executors import AsyncExecutor | ||
|
||
from kstreams.types import ConsumerRecord | ||
|
||
|
||
class StreamDependencyManager: | ||
"""Core of dependency injection on kstreams. | ||
Attributes: | ||
container: dependency store. | ||
Usage | ||
```python | ||
consumer_record = ConsumerRecord(...) | ||
def user_func(event_type: FromHeader[str]): | ||
... | ||
sdm = StreamDependencyManager() | ||
sdm.solve(user_func) | ||
sdm.execute(consumer_record) | ||
``` | ||
""" | ||
|
||
container: Container | ||
|
||
def __init__(self, container: Optional[Container] = None): | ||
self.container = container or Container() | ||
|
||
def _register_framework_deps(self): | ||
"""Register with the container types that belong to kstream. | ||
These types can be injectable things available on startup. | ||
But also they can be things created on a per connection basis. | ||
They must be available at the time of executing the users' function. | ||
For example: | ||
- App | ||
- ConsumerRecord | ||
- Consumer | ||
- KafkaBackend | ||
Here we are just "registering", that's why we use `bind_by_type`. | ||
And eventually, when "executing", we should inject all the values we | ||
want to be "available". | ||
""" | ||
self.container.bind( | ||
bind_by_type( | ||
Dependant(ConsumerRecord, scope="consumer_record", wire=False), | ||
ConsumerRecord, | ||
) | ||
) | ||
# NEXT: Add Consumer as a dependency, so it can be injected. | ||
|
||
def build(self, user_fn: Callable[..., Any]) -> None: | ||
"""Build the dependency graph for the given function. | ||
Attributes: | ||
user_fn: user defined function, using allowed kstreams params | ||
""" | ||
self._register_framework_deps() | ||
|
||
self.solved_user_fn = self.container.solve( | ||
Dependant(user_fn, scope="consumer_record"), | ||
scopes=["consumer_record"], | ||
) | ||
|
||
async def execute(self, consumer_record: ConsumerRecord) -> Any: | ||
"""Execute the dependencies graph with external values. | ||
Attributes: | ||
consumer_record: A kafka record containing `values`, `headers`, etc. | ||
""" | ||
with self.container.enter_scope("consumer_record") as state: | ||
return await self.container.execute_async( | ||
self.solved_user_fn, | ||
values={ConsumerRecord: consumer_record}, | ||
executor=AsyncExecutor(), | ||
state=state, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,3 +26,7 @@ def __str__(self) -> str: | |
|
||
class BackendNotSet(StreamException): | ||
... | ||
|
||
|
||
class HeaderNotFound(StreamException): | ||
... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from typing import Optional, TypeVar | ||
|
||
from kstreams.binders.api import BinderMarker | ||
from kstreams.binders.header import HeaderMarker | ||
from kstreams.typing import Annotated | ||
|
||
|
||
def Header( | ||
*, alias: Optional[str] = None, convert_underscores: bool = True | ||
) -> BinderMarker: | ||
"""Construct another type from the headers of a kafka record. | ||
Usage: | ||
```python | ||
from kstream import Header, Annotated | ||
def user_fn(event_type: Annotated[str, Header(alias="EventType")]): | ||
... | ||
``` | ||
""" | ||
header_marker = HeaderMarker(alias=alias, convert_underscores=convert_underscores) | ||
binder = BinderMarker(extractor_marker=header_marker) | ||
return binder | ||
|
||
|
||
T = TypeVar("T") | ||
|
||
FromHeader = Annotated[T, Header(alias=None)] | ||
FromHeader.__doc__ = """General purpose convenient header type. | ||
Use `Annotated` to provide custom params. | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,10 @@ | ||
from typing import Dict, Sequence, Tuple | ||
|
||
from aiokafka.structs import ConsumerRecord as AIOConsumerRecord | ||
|
||
Headers = Dict[str, str] | ||
EncodedHeaders = Sequence[Tuple[str, bytes]] | ||
|
||
|
||
class ConsumerRecord(AIOConsumerRecord): | ||
... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
"""Remove this file when python3.8 support is dropped.""" | ||
import sys | ||
|
||
if sys.version_info < (3, 9): | ||
from typing_extensions import Annotated as Annotated # noqa: F401 | ||
else: | ||
from typing import Annotated as Annotated # noqa: F401 |
Oops, something went wrong.