Skip to content

Commit

Permalink
feat: add dependency injection framework
Browse files Browse the repository at this point in the history
  • Loading branch information
woile committed Sep 17, 2022
1 parent c203741 commit f6fbed5
Show file tree
Hide file tree
Showing 14 changed files with 458 additions and 5 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ if __name__ == "__main__":
- [ ] Store (kafka streams pattern)
- [ ] Stream Join
- [ ] Windowing
- [ ] PEP 593

## Development

Expand Down
10 changes: 8 additions & 2 deletions kstreams/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from aiokafka.structs import ConsumerRecord

from .backends.kafka import Kafka
from .clients import Consumer, ConsumerType, Producer, ProducerType
from .create import StreamEngine, create_engine
from .dependencies.core import StreamDependencyManager
from .parameters import FromHeader, Header
from .prometheus.monitor import PrometheusMonitor, PrometheusMonitorType
from .streams import Stream, stream
from .types import ConsumerRecord

__all__ = [
"Consumer",
Expand All @@ -17,4 +19,8 @@
"Stream",
"stream",
"ConsumerRecord",
"Kafka",
"StreamDependencyManager",
"FromHeader",
"Header",
]
69 changes: 69 additions & 0 deletions kstreams/binders/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import inspect
from typing import Any, AsyncIterator, Awaitable, Protocol, TypeVar, Union

from di.api.dependencies import CacheKey
from di.dependant import Dependant, Marker

from kstreams.types import ConsumerRecord


class ExtractorTrait(Protocol):
    """Structural interface for objects that read data from a `ConsumerRecord`.

    Consumers always work with a `ConsumerRecord`. An implementation of this
    protocol decides which piece of information is taken out of it.
    Implementations have to be hashable and comparable, because `di` relies
    on both operations to cache dependencies.
    """

    def __hash__(self) -> int:
        """Return the hash of the extractor (needed by di to cache the deps)."""
        ...

    def __eq__(self, __o: object) -> bool:
        """Tell whether `__o` is an equivalent extractor (needed by di to cache the deps)."""
        ...

    async def extract(
        self, consumer_record: ConsumerRecord
    ) -> Union[Awaitable[Any], AsyncIterator[Any]]:
        """Take the wanted piece of data out of `consumer_record`.

        For example, an implementation could decode a json document stored
        in `ConsumerRecord.value` and return it.
        """
        ...


# Covariant type variable: the kind of object a marker builds from a parameter.
T = TypeVar("T", covariant=True)


class MarkerTrait(Protocol[T]):
    """Structural interface for markers that build a `T` out of a parameter."""

    def register_parameter(self, param: inspect.Parameter) -> T:
        """Return the object associated with the function parameter `param`."""
        ...


class Binder(Dependant[Any]):
    """`di` dependant whose value is produced by an extractor.

    The extractor's `extract` method is registered as the dependency call,
    bound to the "consumer_record" scope.
    """

    def __init__(self, *, extractor: ExtractorTrait) -> None:
        """Wire `extractor.extract` into `di` and keep the extractor around.

        Attributes:
            extractor: object that reads the value out of the consumer record.
        """
        super().__init__(call=extractor.extract, scope="consumer_record")
        self.extractor = extractor

    @property
    def cache_key(self) -> CacheKey:
        """Return the extractor itself as the cache key.

        The extractor provides `__hash__` and `__eq__` for this purpose.
        """
        return self.extractor


class BinderMarker(Marker):
    """Marker that turns an annotated parameter into a `Binder`.

    NEXT: Add an asyncapi marker here, like `MarkerTrait[AsyncApiTrait]`.
    Recommendation to wait until 3.0:
    - [#618](https://github.com/asyncapi/spec/issues/618)
    """

    def __init__(self, *, extractor_marker: MarkerTrait[ExtractorTrait]) -> None:
        """Remember the marker that knows how to build the extractor.

        NOTE(review): `Marker.__init__` is not invoked here -- confirm that
        this is intended.
        """
        self.extractor_marker = extractor_marker

    def register_parameter(self, param: inspect.Parameter) -> Binder:
        """Build the `Binder` for `param`.

        The extractor marker creates the extractor for the parameter, and
        the extractor is then wrapped in a `Binder`.
        """
        extractor = self.extractor_marker.register_parameter(param)
        return Binder(extractor=extractor)
52 changes: 52 additions & 0 deletions kstreams/binders/header.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import inspect
from typing import Any, NamedTuple, Optional

from kstreams.exceptions import HeaderNotFound
from kstreams.types import ConsumerRecord


class HeaderExtractor(NamedTuple):
    """Extractor that reads one header, selected by name, from a record.

    Attributes:
        name: name of the header to look up.
    """

    name: str

    def __hash__(self) -> int:
        """Hash on the class and the header name."""
        identity = (self.__class__, self.name)
        return hash(identity)

    def __eq__(self, __o: object) -> bool:
        """Equal only to another `HeaderExtractor` with the same name."""
        if not isinstance(__o, HeaderExtractor):
            return False
        return __o.name == self.name

    async def extract(self, consumer_record: ConsumerRecord) -> Any:
        """Return the value of the header called `self.name`.

        When the header appears more than once, the last occurrence wins.

        Raises:
            HeaderNotFound: if the record carries no header with that name.
        """
        raw_headers = consumer_record.headers
        # Headers may arrive as a dict; normalise them to (name, value) pairs.
        # NEXT: Check also if it is a sequence, if not it means
        # someone modified the headers in the serializer in a way
        # that we cannot extract it, raise a readable error
        pairs = (
            tuple(raw_headers.items())
            if isinstance(raw_headers, dict)
            else raw_headers
        )
        matches = [value for key, value in pairs if key == self.name]

        try:
            return matches.pop()
        except IndexError as e:
            message = (
                f"No header `{self.name}` found.\n"
                "Check if your broker is sending the header.\n"
                "Try adding a default value to your parameter like `None`.\n"
                "Or set `convert_underscores = False`."
            )
            raise HeaderNotFound(message) from e


class HeaderMarker(NamedTuple):
    """Decide which header name a function parameter is bound to.

    Attributes:
        alias: explicit header name; takes precedence when it is not `None`.
        convert_underscores: without an alias, replace every `_` of the
            parameter name with `-`.
    """

    alias: Optional[str]
    convert_underscores: bool

    def register_parameter(self, param: inspect.Parameter) -> HeaderExtractor:
        """Create the `HeaderExtractor` for the function parameter `param`."""
        # An explicit alias always wins over the parameter name.
        if self.alias is not None:
            return HeaderExtractor(name=self.alias)

        header_name = param.name
        if self.convert_underscores:
            header_name = header_name.replace("_", "-")
        return HeaderExtractor(name=header_name)
85 changes: 85 additions & 0 deletions kstreams/dependencies/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from typing import Any, Callable, Optional

from di.container import Container, bind_by_type
from di.dependant import Dependant
from di.executors import AsyncExecutor

from kstreams.types import ConsumerRecord


class StreamDependencyManager:
    """Core of dependency injection on kstreams.

    Attributes:
        container: dependency store.
        solved_user_fn: solved dependency graph of the user function.
            It only exists after `build` has been called.

    Usage

    ```python
    consumer_record = ConsumerRecord(...)

    def user_func(event_type: FromHeader[str]):
        ...

    sdm = StreamDependencyManager()
    sdm.build(user_func)
    await sdm.execute(consumer_record)
    ```
    """

    container: Container

    def __init__(self, container: Optional[Container] = None):
        """Use the given `container`, or a fresh one when none is provided."""
        self.container = container or Container()

    def _register_framework_deps(self):
        """Register with the container types that belong to kstream.

        These types can be injectable things available on startup.
        But also they can be things created on a per connection basis.
        They must be available at the time of executing the users' function.

        For example:

        - App
        - ConsumerRecord
        - Consumer
        - KafkaBackend

        Here we are just "registering", that's why we use `bind_by_type`.
        And eventually, when "executing", we should inject all the values we
        want to be "available".
        """
        # `wire=False`: the actual record is supplied through `values` in
        # `execute`.
        record_dependant = Dependant(
            ConsumerRecord, scope="consumer_record", wire=False
        )
        self.container.bind(bind_by_type(record_dependant, ConsumerRecord))
        # NEXT: Add Consumer as a dependency, so it can be injected.

    def build(self, user_fn: Callable[..., Any]) -> None:
        """Build the dependency graph for the given function.

        The result is stored in `self.solved_user_fn` and used by `execute`.

        Attributes:
            user_fn: user defined function, using allowed kstreams params
        """
        self._register_framework_deps()

        user_dependant = Dependant(user_fn, scope="consumer_record")
        self.solved_user_fn = self.container.solve(
            user_dependant,
            scopes=["consumer_record"],
        )

    async def execute(self, consumer_record: ConsumerRecord) -> Any:
        """Execute the dependencies graph with external values.

        `build` has to be called first; before that `solved_user_fn` is
        not set.

        Attributes:
            consumer_record: A kafka record containing `values`, `headers`, etc.

        Returns:
            Whatever the user function returns.
        """
        with self.container.enter_scope("consumer_record") as state:
            return await self.container.execute_async(
                self.solved_user_fn,
                # Provide the concrete record for the type registered in
                # `_register_framework_deps`.
                values={ConsumerRecord: consumer_record},
                executor=AsyncExecutor(),
                state=state,
            )
4 changes: 4 additions & 0 deletions kstreams/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,7 @@ def __str__(self) -> str:

class BackendNotSet(StreamException):
    """Stream error signalling that a backend has not been set.

    NOTE(review): the place raising this error is not visible here --
    confirm the exact condition.
    """


class HeaderNotFound(StreamException):
    """Raised when a requested header is missing from a consumer record."""
33 changes: 33 additions & 0 deletions kstreams/parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import Optional, TypeVar

from kstreams.binders.api import BinderMarker
from kstreams.binders.header import HeaderMarker
from kstreams.typing import Annotated


def Header(
    *, alias: Optional[str] = None, convert_underscores: bool = True
) -> BinderMarker:
    """Construct another type from the headers of a kafka record.

    Attributes:
        alias: header name to look up instead of the parameter name.
        convert_underscores: when no alias is given, look up the parameter
            name with every `_` replaced by `-`.

    Returns:
        The marker to place inside `Annotated[...]`.

    Usage:

    ```python
    from kstreams import Header
    from kstreams.typing import Annotated

    def user_fn(event_type: Annotated[str, Header(alias="EventType")]):
        ...
    ```
    """
    return BinderMarker(
        extractor_marker=HeaderMarker(
            alias=alias, convert_underscores=convert_underscores
        )
    )


# Type of the value the user function receives for the annotated parameter.
T = TypeVar("T")

# Generic alias: `FromHeader[str]` stands for `Annotated[str, Header(alias=None)]`.
# Without an alias the header name is derived from the parameter name.
FromHeader = Annotated[T, Header(alias=None)]
FromHeader.__doc__ = """General purpose convenient header type.
Use `Annotated` to provide custom params.
"""
2 changes: 1 addition & 1 deletion kstreams/serializers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Dict, Optional, Protocol

from kstreams import ConsumerRecord
from kstreams.types import ConsumerRecord

from .types import Headers

Expand Down
6 changes: 6 additions & 0 deletions kstreams/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from typing import Dict, Sequence, Tuple

from aiokafka.structs import ConsumerRecord as AIOConsumerRecord

# Headers as a mapping of header name to string value.
Headers = Dict[str, str]
# Headers as a sequence of (name, bytes value) pairs.
EncodedHeaders = Sequence[Tuple[str, bytes]]


class ConsumerRecord(AIOConsumerRecord):
    """kstreams' own record type.

    A plain subclass of aiokafka's `ConsumerRecord` that adds no members.
    """
7 changes: 7 additions & 0 deletions kstreams/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Remove this file when python3.8 support is dropped."""
import sys

if sys.version_info >= (3, 9):
    # `Annotated` is available in the standard library on python3.9 and newer.
    from typing import Annotated as Annotated  # noqa: F401
else:
    # python3.8 has to rely on the `typing_extensions` backport.
    from typing_extensions import Annotated as Annotated  # noqa: F401
Loading

0 comments on commit f6fbed5

Please sign in to comment.