Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add dependency injection framework #58

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/bench-release.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Bump version
name: Benchmark latest release

on:
push:
Expand Down Expand Up @@ -46,5 +46,5 @@ jobs:
git config --global user.email "[email protected]"
git config --global user.name "GitHub Action"
git add .benchmarks/
git commit -m "bench: bench: add benchmark current release"
git commit -m "bench: current release"
git push origin master
40 changes: 34 additions & 6 deletions .github/workflows/pr-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ on:
required: true

jobs:
build_test_bench:
test:
runs-on: ubuntu-latest
strategy:
matrix:
Expand Down Expand Up @@ -56,15 +56,43 @@ jobs:
git config --global user.email "[email protected]"
git config --global user.name "GitHub Action"
./scripts/test

- name: Benchmark regression test
run: |
./scripts/bench-compare

- name: Upload coverage to Codecov
uses: codecov/[email protected]
with:
file: ./coverage.xml
name: kstreams
fail_ci_if_error: true
token: ${{secrets.CODECOV_TOKEN}}
bench:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup python
uses: actions/setup-python@v5
with:
python-version: '3.13'
architecture: x64
- name: Set Cache
uses: actions/cache@v4
id: cache # name for referring later
with:
path: .venv/
# The cache key depends on poetry.lock
key: ${{ runner.os }}-cache-${{ hashFiles('poetry.lock') }}
restore-keys: |
${{ runner.os }}-cache-
${{ runner.os }}-
- name: Install Dependencies
# if: steps.cache.outputs.cache-hit != 'true'
run: |
python -m pip install -U pip poetry
poetry --version
poetry config --local virtualenvs.in-project true
poetry install
- name: Benchmark regression test
run: |
./scripts/bench-current
./scripts/bench-compare

1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ if __name__ == "__main__":
- [ ] Store (kafka streams pattern)
- [ ] Stream Join
- [ ] Windowing
- [ ] PEP 593

## Development

Expand Down
6 changes: 6 additions & 0 deletions kstreams/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from aiokafka.structs import RecordMetadata, TopicPartition

from ._di.parameters import FromHeader, Header
from .backends.kafka import Kafka
from .clients import Consumer, Producer
from .create import StreamEngine, create_engine
from .prometheus.monitor import PrometheusMonitor, PrometheusMonitorType
Expand Down Expand Up @@ -31,4 +33,8 @@
"TestStreamClient",
"TopicPartition",
"TopicPartitionOffset",
"Kafka",
"StreamDependencyManager",
"FromHeader",
"Header",
]
68 changes: 68 additions & 0 deletions kstreams/_di/binders/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import inspect
from typing import Any, AsyncIterator, Awaitable, Protocol, TypeVar, Union

from di.api.dependencies import CacheKey
from di.dependent import Dependent, Marker

from kstreams.types import ConsumerRecord


class ExtractorTrait(Protocol):
    """Implement to extract data from an incoming `ConsumerRecord`.

    Consumers always work with a `ConsumerRecord`; implementing this protocol
    lets a dependency pull a specific piece of information out of it (for
    example a header value or a deserialized payload).
    """

    def __hash__(self) -> int:
        """Required by `di` in order to cache the deps"""
        ...

    def __eq__(self, __o: object) -> bool:
        """Required by `di` in order to cache the deps"""
        ...

    async def extract(
        self, consumer_record: ConsumerRecord
    ) -> Union[Awaitable[Any], AsyncIterator[Any]]:
        """This is where the magic should happen.

        For example, you could "extract" here a json from the `ConsumerRecord.value`
        """
        ...


# Covariant so a marker producing a subtype is accepted where a marker
# producing the supertype is expected.
T = TypeVar("T", covariant=True)


class MarkerTrait(Protocol[T]):
    """Given a function parameter, produce the object bound to it (typically
    an extractor implementing `ExtractorTrait`)."""

    def register_parameter(self, param: inspect.Parameter) -> T: ...


class Binder(Dependent[Any]):
    """A `di` dependent whose value is produced by running an extractor
    against the incoming consumer record."""

    def __init__(self, *, extractor: ExtractorTrait) -> None:
        self.extractor = extractor
        # The extractor's `extract` coroutine is the actual provider; it is
        # resolved once per record, hence the "consumer_record" scope.
        super().__init__(call=extractor.extract, scope="consumer_record")

    @property
    def cache_key(self) -> CacheKey:
        # The extractor itself identifies this dependency for di's cache.
        return self.extractor


class BinderMarker(Marker):
    """Bind together the different dependencies.

    NEXT: Add asyncapi marker here, like `MarkerTrait[AsyncApiTrait]`.
    Recommendation to wait until 3.0:
        - [#618](https://github.com/asyncapi/spec/issues/618)
    """

    def __init__(self, *, extractor_marker: MarkerTrait[ExtractorTrait]) -> None:
        self.extractor_marker = extractor_marker

    def register_parameter(self, param: inspect.Parameter) -> Binder:
        # Delegate to the extractor marker, then wrap its extractor in a
        # Binder so di can resolve it per consumer record.
        return Binder(extractor=self.extractor_marker.register_parameter(param))
44 changes: 44 additions & 0 deletions kstreams/_di/binders/header.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import inspect
from typing import Any, NamedTuple, Optional

from kstreams.exceptions import HeaderNotFound
from kstreams.types import ConsumerRecord


class HeaderExtractor(NamedTuple):
    """Extracts a single header value, by name, from a `ConsumerRecord`."""

    name: str

    def __hash__(self) -> int:
        # Hash on (class, name) so di can cache this dependency per header.
        return hash((self.__class__, self.name))

    def __eq__(self, __o: object) -> bool:
        return isinstance(__o, HeaderExtractor) and __o.name == self.name

    async def extract(self, consumer_record: ConsumerRecord) -> Any:
        headers = dict(consumer_record.headers)
        try:
            return headers[self.name]
        except KeyError as e:
            raise HeaderNotFound(
                f"No header `{self.name}` found.\n"
                "Check if your broker is sending the header.\n"
                "Try adding a default value to your parameter like `None`.\n"
                "Or set `convert_underscores = False`."
            ) from e


class HeaderMarker(NamedTuple):
    """Marker holding the configuration for header-based injection."""

    alias: Optional[str]
    convert_underscores: bool

    def register_parameter(self, param: inspect.Parameter) -> HeaderExtractor:
        # Header name resolution order: an explicit alias wins, then the
        # underscore->dash convention, and finally the raw parameter name.
        if self.alias is not None:
            header_name = self.alias
        else:
            header_name = (
                param.name.replace("_", "-")
                if self.convert_underscores
                else param.name
            )
        return HeaderExtractor(name=header_name)
113 changes: 113 additions & 0 deletions kstreams/_di/dependencies/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from typing import Any, Callable, Optional

from di import Container, bind_by_type
from di.dependent import Dependent
from di.executors import AsyncExecutor

from kstreams._di.dependencies.hooks import bind_by_generic
from kstreams.streams import Stream
from kstreams.types import ConsumerRecord, Send

# Any user-defined callable that can participate in dependency injection.
LayerFn = Callable[..., Any]


class StreamDependencyManager:
    """Core of dependency injection on kstreams.

    This is an internal class of kstreams that manages the dependency injection,
    as a user you should not use this class directly.

    Attributes:
        container: dependency store.
        stream: the stream wrapping the user function. Optional to improve testability.
            When instantiating this class you must provide a stream, otherwise users
            won't be able to use the `stream` parameter in their functions.
        send: send object. Optional to improve testability, same as stream.

    Usage:

        stream and send are omitted for simplicity

        ```python
        def user_func(cr: ConsumerRecord):
            ...

        sdm = StreamDependencyManager()
        sdm.solve_user_fn(user_func)
        sdm.execute(consumer_record)
        ```
    """

    container: Container

    def __init__(
        self,
        container: Optional[Container] = None,
        stream: Optional[Stream] = None,
        send: Optional[Send] = None,
    ):
        self.container = container or Container()
        self.async_executor = AsyncExecutor()
        self.stream = stream
        self.send = send

    def solve_user_fn(self, fn: LayerFn) -> None:
        """Build the dependency graph for the given function.

        Objects must be injected before this function is called.

        Args:
            fn: user defined function, using allowed kstreams params
        """
        # Registration order matters: hooks are bound before solving so the
        # container can substitute ConsumerRecord/Stream/Send parameters.
        self._register_consumer_record()

        if isinstance(self.stream, Stream):
            self._register_stream(self.stream)

        if self.send is not None:
            self._register_send(self.send)

        self.solved_user_fn = self.container.solve(
            Dependent(fn, scope="consumer_record"),
            scopes=["consumer_record", "stream", "application"],
        )

    async def execute(self, consumer_record: ConsumerRecord) -> Any:
        """Execute the dependencies graph with external values.

        `solve_user_fn` must have been called first, otherwise
        `self.solved_user_fn` does not exist.

        Args:
            consumer_record: A kafka record containing `values`, `headers`, etc.
        """
        # A fresh "consumer_record" scope per record, so per-record cached
        # dependencies do not leak between records.
        async with self.container.enter_scope("consumer_record") as state:
            return await self.solved_user_fn.execute_async(
                executor=self.async_executor,
                state=state,
                values={ConsumerRecord: consumer_record},
            )

    def _register_stream(self, stream: Stream):
        """Register the stream with the container."""
        # wire=False: the lambda has no dependencies of its own to solve.
        hook = bind_by_type(
            Dependent(lambda: stream, scope="consumer_record", wire=False), Stream
        )
        self.container.bind(hook)

    def _register_consumer_record(self):
        """Register consumer record with the container.

        We bind_by_generic because we want to bind the `ConsumerRecord` type which
        is generic.

        The value must be injected at runtime.
        """
        hook = bind_by_generic(
            Dependent(ConsumerRecord, scope="consumer_record", wire=False),
            ConsumerRecord,
        )
        self.container.bind(hook)

    def _register_send(self, send: Send):
        # Same pattern as _register_stream: a constant provider for Send.
        hook = bind_by_type(
            Dependent(lambda: send, scope="consumer_record", wire=False), Send
        )
        self.container.bind(hook)
31 changes: 31 additions & 0 deletions kstreams/_di/dependencies/hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import inspect
from typing import Any, Optional, get_origin

from di._container import BindHook
from di._utils.inspect import get_type
from di.api.dependencies import DependentBase


def bind_by_generic(
    provider: DependentBase[Any],
    dependency: type,
) -> BindHook:
    """Hook to substitute the matched dependency based on its generic.

    Args:
        provider: Dependent injected whenever the dependency matches.
        dependency: The (generic) type to match against, e.g. `ConsumerRecord`.

    Returns:
        A bind hook suitable for `Container.bind`.
    """

    # NOTE: `Optional[...]` instead of `X | None` — the rest of the codebase
    # uses `Optional`, and PEP 604 unions in def-time-evaluated annotations
    # raise TypeError on Python < 3.10.
    def hook(
        param: Optional[inspect.Parameter], dependent: DependentBase[Any]
    ) -> Optional[DependentBase[Any]]:
        # Direct match: the dependent's callable *is* the dependency type.
        if dependent.call == dependency:
            return provider
        if param is None:
            return None

        # Otherwise inspect the parameter annotation: match when its generic
        # origin (e.g. ConsumerRecord for ConsumerRecord[str, bytes]) is the
        # dependency.
        type_annotation_option = get_type(param)
        if type_annotation_option is None:
            return None
        type_annotation = type_annotation_option.value
        if get_origin(type_annotation) is dependency:
            return provider
        return None

    return hook
37 changes: 37 additions & 0 deletions kstreams/_di/parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import Optional, TypeVar

from kstreams._di.binders.api import BinderMarker
from kstreams._di.binders.header import HeaderMarker
from kstreams.typing import Annotated


def Header(
    *, alias: Optional[str] = None, convert_underscores: bool = True
) -> BinderMarker:
    """Construct another type from the headers of a kafka record.

    Args:
        alias: Use a different header name.
        convert_underscores: If True, convert underscores to dashes.

    Usage:

    ```python
    from kstreams import Header, Annotated

    def user_fn(event_type: Annotated[str, Header(alias="EventType")]):
        ...
    ```
    """
    return BinderMarker(
        extractor_marker=HeaderMarker(
            alias=alias, convert_underscores=convert_underscores
        )
    )


T = TypeVar("T")

# `FromHeader[str]` is shorthand for `Annotated[str, Header()]`; use
# `Annotated` directly to provide custom params.
FromHeader = Annotated[T, Header()]
# NOTE(review): setting __doc__ on an Annotated alias relies on typing
# internals delegating the attribute — confirm it works on all supported
# Python versions.
FromHeader.__doc__ = """General purpose convenient header type.

Use `Annotated` to provide custom params.
"""
Loading
Loading