diff --git a/requirements/base.txt b/requirements/base.txt index b9f2d5b73..b5bdbda54 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -5,5 +5,5 @@ # # pip-compile-multi # -numpy==1.25.2 +numpy==1.26.1 # via -r base.in diff --git a/requirements/ci.txt b/requirements/ci.txt index 9e194f2ea..5cde028b0 100644 --- a/requirements/ci.txt +++ b/requirements/ci.txt @@ -11,46 +11,46 @@ certifi==2023.7.22 # via requests chardet==5.2.0 # via tox -charset-normalizer==3.2.0 +charset-normalizer==3.3.0 # via requests colorama==0.4.6 # via tox distlib==0.3.7 # via virtualenv -filelock==3.12.2 +filelock==3.12.4 # via # tox # virtualenv gitdb==4.0.10 # via gitpython -gitpython==3.1.32 +gitpython==3.1.38 # via -r ci.in idna==3.4 # via requests -packaging==23.1 +packaging==23.2 # via # -r ci.in # pyproject-api # tox -platformdirs==3.10.0 +platformdirs==3.11.0 # via # tox # virtualenv -pluggy==1.2.0 +pluggy==1.3.0 # via tox -pyproject-api==1.5.3 +pyproject-api==1.6.1 # via tox requests==2.31.0 # via -r ci.in -smmap==5.0.0 +smmap==5.0.1 # via gitdb tomli==2.0.1 # via # pyproject-api # tox -tox==4.7.0 +tox==4.11.3 # via -r ci.in -urllib3==2.0.4 +urllib3==2.0.7 # via requests -virtualenv==20.24.2 +virtualenv==20.24.5 # via tox diff --git a/requirements/dev.txt b/requirements/dev.txt index 773a7eb1c..e753dba9b 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -12,19 +12,19 @@ -r static.txt -r test.txt -r wheels.txt -anyio==3.7.1 +anyio==4.0.0 # via jupyter-server -argon2-cffi==21.3.0 +argon2-cffi==23.1.0 # via jupyter-server argon2-cffi-bindings==21.2.0 # via argon2-cffi -arrow==1.2.3 +arrow==1.3.0 # via isoduration async-lru==2.0.4 # via jupyterlab -cffi==1.15.1 +cffi==1.16.0 # via argon2-cffi-bindings -click==8.1.6 +click==8.1.7 # via # pip-compile-multi # pip-tools @@ -36,16 +36,16 @@ json5==0.9.14 # via jupyterlab-server jsonpointer==2.4 # via jsonschema -jsonschema[format-nongpl]==4.19.0 +jsonschema[format-nongpl]==4.19.1 # via # jupyter-events # jupyterlab-server # nbformat -jupyter-events==0.7.0 +jupyter-events==0.8.0 # via jupyter-server jupyter-lsp==2.2.0 # via jupyterlab -jupyter-server==2.7.0 +jupyter-server==2.8.0 # via # jupyter-lsp # jupyterlab @@ -53,9 +53,9 @@ jupyter-server==2.7.0 # notebook-shim jupyter-server-terminals==0.4.4 # via jupyter-server -jupyterlab==4.0.4 +jupyterlab==4.0.7 # via -r dev.in -jupyterlab-server==2.24.0 +jupyterlab-server==2.25.0 # via jupyterlab notebook-shim==0.2.3 # via jupyterlab @@ -89,13 +89,15 @@ terminado==0.17.1 # jupyter-server-terminals toposort==1.10 # via pip-compile-multi +types-python-dateutil==2.8.19.14 + # via arrow uri-template==1.3.0 # via jsonschema webcolors==1.13 # via jsonschema -websocket-client==1.6.1 +websocket-client==1.6.4 # via jupyter-server -wheel==0.41.1 +wheel==0.41.2 # via pip-tools # The following packages are considered to be unsafe in a requirements file: diff --git a/requirements/docs.txt b/requirements/docs.txt index 04be43353..3caa9a364 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -11,9 +11,9 @@ accessible-pygments==0.0.4 # via pydata-sphinx-theme alabaster==0.7.13 # via sphinx -annotated-types==0.5.0 +annotated-types==0.6.0 # via pydantic -asttokens==2.2.1 +asttokens==2.4.0 # via stack-data attrs==23.1.0 # via @@ -21,7 +21,7 @@ attrs==23.1.0 # referencing autodoc-pydantic==2.0.1 # via -r docs.in -babel==2.12.1 +babel==2.13.0 # via # pydata-sphinx-theme # sphinx @@ -31,15 +31,11 @@ beautifulsoup4==4.12.2 # via # nbconvert # pydata-sphinx-theme -bleach==6.0.0 +bleach==6.1.0 # via 
nbconvert -certifi==2023.7.22 - # via requests -charset-normalizer==3.2.0 - # via requests comm==0.1.4 # via ipykernel -debugpy==1.6.7.post1 +debugpy==1.8.0 # via ipykernel decorator==5.1.1 # via ipython @@ -51,26 +47,21 @@ docutils==0.20.1 # nbsphinx # pydata-sphinx-theme # sphinx -executing==1.2.0 +exceptiongroup==1.1.3 + # via ipython +executing==2.0.0 # via stack-data -fastjsonschema==2.18.0 +fastjsonschema==2.18.1 # via nbformat -idna==3.4 - # via requests imagesize==1.4.1 # via sphinx -importlib-metadata==6.8.0 - # via - # jupyter-client - # nbconvert - # sphinx -ipykernel==6.25.1 +ipykernel==6.25.2 # via -r docs.in -ipython==8.14.0 +ipython==8.16.1 # via # -r docs.in # ipykernel -jedi==0.19.0 +jedi==0.19.1 # via ipython jinja2==3.1.2 # via @@ -78,15 +69,15 @@ jinja2==3.1.2 # nbconvert # nbsphinx # sphinx -jsonschema==4.19.0 +jsonschema==4.19.1 # via nbformat jsonschema-specifications==2023.7.1 # via jsonschema -jupyter-client==8.3.0 +jupyter-client==8.4.0 # via # ipykernel # nbclient -jupyter-core==5.3.1 +jupyter-core==5.4.0 # via # ipykernel # jupyter-client @@ -111,22 +102,22 @@ mdit-py-plugins==0.4.0 # via myst-parser mdurl==0.1.2 # via markdown-it-py -mistune==3.0.1 +mistune==3.0.2 # via nbconvert myst-parser==2.0.0 # via -r docs.in nbclient==0.8.0 # via nbconvert -nbconvert==7.7.3 +nbconvert==7.9.2 # via nbsphinx nbformat==5.9.2 # via # nbclient # nbconvert # nbsphinx -nbsphinx==0.9.2 +nbsphinx==0.9.3 # via -r docs.in -nest-asyncio==1.5.7 +nest-asyncio==1.5.8 # via ipykernel pandocfilters==1.5.0 # via nbconvert @@ -136,25 +127,23 @@ pexpect==4.8.0 # via ipython pickleshare==0.7.5 # via ipython -platformdirs==3.10.0 - # via jupyter-core prompt-toolkit==3.0.39 # via ipython -psutil==5.9.5 +psutil==5.9.6 # via ipykernel ptyprocess==0.7.0 # via pexpect pure-eval==0.2.2 # via stack-data -pydantic==2.1.1 +pydantic==2.4.2 # via # autodoc-pydantic # pydantic-settings -pydantic-core==2.4.0 +pydantic-core==2.10.1 # via pydantic -pydantic-settings==2.0.2 +pydantic-settings==2.0.3 # via autodoc-pydantic -pydata-sphinx-theme==0.13.3 +pydata-sphinx-theme==0.14.1 # via -r docs.in pygments==2.16.1 # via @@ -173,17 +162,15 @@ referencing==0.30.2 # via # jsonschema # jsonschema-specifications -requests==2.31.0 - # via sphinx -rpds-py==0.9.2 +rpds-py==0.10.6 # via # jsonschema # referencing snowballstemmer==2.2.0 # via sphinx -soupsieve==2.4.1 +soupsieve==2.5 # via beautifulsoup4 -sphinx==7.1.2 +sphinx==7.2.6 # via # -r docs.in # autodoc-pydantic @@ -204,27 +191,27 @@ sphinx-copybutton==0.5.2 # via -r docs.in sphinx-design==0.5.0 # via -r docs.in -sphinxcontrib-applehelp==1.0.6 +sphinxcontrib-applehelp==1.0.7 # via sphinx -sphinxcontrib-devhelp==1.0.4 +sphinxcontrib-devhelp==1.0.5 # via sphinx -sphinxcontrib-htmlhelp==2.0.3 +sphinxcontrib-htmlhelp==2.0.4 # via sphinx sphinxcontrib-jsmath==1.0.1 # via sphinx -sphinxcontrib-qthelp==1.0.5 +sphinxcontrib-qthelp==1.0.6 # via sphinx -sphinxcontrib-serializinghtml==1.1.7 +sphinxcontrib-serializinghtml==1.1.9 # via sphinx -stack-data==0.6.2 +stack-data==0.6.3 # via ipython tinycss2==1.2.1 # via nbconvert -tornado==6.3.2 +tornado==6.3.3 # via # ipykernel # jupyter-client -traitlets==5.9.0 +traitlets==5.11.2 # via # comm # ipykernel @@ -236,15 +223,12 @@ traitlets==5.9.0 # nbconvert # nbformat # nbsphinx -typing-extensions==4.7.1 +typing-extensions==4.8.0 # via - # ipython # pydantic # pydantic-core # pydata-sphinx-theme -urllib3==2.0.4 - # via requests -wcwidth==0.2.6 +wcwidth==0.2.8 # via prompt-toolkit webencodings==0.5.1 # via diff --git 
a/requirements/extra.in b/requirements/extra.in index 79256f24c..f868e73d8 100644 --- a/requirements/extra.in +++ b/requirements/extra.in @@ -1,6 +1,7 @@ confluent-kafka colorama scipp +scippneutron plopp h5py ess-streaming-data-types diff --git a/requirements/extra.txt b/requirements/extra.txt index 9ef0e47ff..abf022e3e 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -1,54 +1,88 @@ -# SHA1:37c1f0ef266520fca7370542ebdef12c4f90ba80 +# SHA1:ce2fd237e95e3b0ea4fe4ec48793a278cea1c392 # # This file is autogenerated by pip-compile-multi # To update, run: # # pip-compile-multi # +certifi==2023.7.22 + # via requests +charset-normalizer==3.3.0 + # via requests colorama==0.4.6 # via -r extra.in confluent-kafka==2.2.0 # via -r extra.in confuse==2.0.1 # via scipp -contourpy==1.1.0 +contourpy==1.1.1 # via matplotlib -cycler==0.11.0 +cycler==0.12.1 # via matplotlib -ess-streaming-data-types==0.22.1 +ess-streaming-data-types==0.23.1 # via -r extra.in flatbuffers==23.5.26 # via ess-streaming-data-types -fonttools==4.42.0 +fonttools==4.43.1 # via matplotlib graphlib-backport==1.0.3 # via scipp -h5py==3.9.0 - # via -r extra.in -kiwisolver==1.4.4 +h5py==3.10.0 + # via + # -r extra.in + # scippneutron + # scippnexus +idna==3.4 + # via requests +kiwisolver==1.4.5 # via matplotlib -matplotlib==3.7.2 +matplotlib==3.8.0 # via plopp -numpy==1.25.2 +numpy==1.26.1 # via # contourpy # ess-streaming-data-types # h5py # matplotlib # scipp -packaging==23.1 - # via matplotlib -pillow==10.0.0 + # scippneutron + # scipy +packaging==23.2 + # via + # matplotlib + # pooch +pillow==10.1.0 # via matplotlib -plopp==23.5.1 +platformdirs==3.11.0 + # via pooch +plopp==23.10.1 # via -r extra.in -pyparsing==3.0.9 +pooch==1.7.0 + # via scippneutron +pyparsing==3.1.1 # via matplotlib python-dateutil==2.8.2 - # via matplotlib + # via + # matplotlib + # scippnexus pyyaml==6.0.1 # via confuse +requests==2.31.0 + # via pooch scipp==23.8.0 + # via + # -r extra.in + # scippneutron + # scippnexus +scippneutron==23.9.0 # via -r extra.in +scippnexus==23.8.0 + # via scippneutron +scipy==1.11.3 + # via + # scippneutron + # scippnexus six==1.16.0 # via python-dateutil +urllib3==2.0.7 + # via requests diff --git a/requirements/mypy.txt b/requirements/mypy.txt index 0018786e6..5df22e915 100644 --- a/requirements/mypy.txt +++ b/requirements/mypy.txt @@ -6,9 +6,9 @@ # pip-compile-multi # -r test.txt -mypy==1.5.0 +mypy==1.6.1 # via -r mypy.in mypy-extensions==1.0.0 # via mypy -typing-extensions==4.7.1 +typing-extensions==4.8.0 # via mypy diff --git a/requirements/static.txt b/requirements/static.txt index 30caa4e57..39ca601d5 100644 --- a/requirements/static.txt +++ b/requirements/static.txt @@ -5,23 +5,23 @@ # # pip-compile-multi # -cfgv==3.3.1 +cfgv==3.4.0 # via pre-commit distlib==0.3.7 # via virtualenv -filelock==3.12.2 +filelock==3.12.4 # via virtualenv -identify==2.5.26 +identify==2.5.30 # via pre-commit nodeenv==1.8.0 # via pre-commit -platformdirs==3.10.0 +platformdirs==3.11.0 # via virtualenv -pre-commit==3.3.3 +pre-commit==3.5.0 # via -r static.in pyyaml==6.0.1 # via pre-commit -virtualenv==20.24.2 +virtualenv==20.24.5 # via pre-commit # The following packages are considered to be unsafe in a requirements file: diff --git a/requirements/test.in b/requirements/test.in index bbf3a7c0c..baafff8fc 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -1,5 +1,6 @@ -r base.in -r extra.in +scipy pytest pytest-cov pytest-xdist diff --git a/requirements/test.txt b/requirements/test.txt index 52b8cfa5f..de6998725 100644 --- 
a/requirements/test.txt +++ b/requirements/test.txt @@ -1,4 +1,4 @@ -# SHA1:9982cf5a9dd835ca57251e3dac244292199f22e1 +# SHA1:dc2bc3cc1deab915f302e8b4464480c77d7fc324 # # This file is autogenerated by pip-compile-multi # To update, run: @@ -7,17 +7,17 @@ # -r base.txt -r extra.txt -coverage[toml]==7.2.7 +coverage[toml]==7.3.2 # via pytest-cov -exceptiongroup==1.1.2 +exceptiongroup==1.1.3 # via pytest execnet==2.0.2 # via pytest-xdist iniconfig==2.0.0 # via pytest -pluggy==1.2.0 +pluggy==1.3.0 # via pytest -pytest==7.4.0 +pytest==7.4.2 # via # -r test.in # pytest-cov diff --git a/requirements/wheels.txt b/requirements/wheels.txt index ea0610f0f..4935673c6 100644 --- a/requirements/wheels.txt +++ b/requirements/wheels.txt @@ -5,9 +5,9 @@ # # pip-compile-multi # -build==0.10.0 +build==1.0.3 # via -r wheels.in -packaging==23.1 +packaging==23.2 # via build pyproject-hooks==1.0.0 # via build diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..a1514956a --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,37 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +# These fixtures cannot be found by pytest, +# if they are not defined in `conftest.py` under `tests` directory. +import pytest + + +def pytest_addoption(parser: pytest.Parser): + parser.addoption("--full-benchmark", action="store_true", default=False) + parser.addoption("--kafka-test", action="store_true", default=False) + + +@pytest.fixture +def kafka_test(request: pytest.FixtureRequest) -> bool: + """ + Requires --kafka-test flag. + """ + if not request.config.getoption('--kafka-test'): + pytest.skip( + "Skipping kafka required tests. " + "Use ``--kafka-test`` option to run this test." + ) + return True + + +@pytest.fixture +def full_benchmark(request: pytest.FixtureRequest) -> bool: + """ + Requires --full-benchmark flag. + """ + if not request.config.getoption('--full-benchmark'): + pytest.skip( + "Skipping full benchmark. " + "Use ``--full-benchmark`` option to run this test." 
+ ) + + return True diff --git a/tests/prototypes/__init__.py b/tests/prototypes/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/prototypes/parameters.py b/tests/prototypes/parameters.py new file mode 100644 index 000000000..11136cd56 --- /dev/null +++ b/tests/prototypes/parameters.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +from dataclasses import dataclass +from typing import NewType + +from beamlime.constructors import ProviderGroup + +# Random Events Generation +RandomSeed = NewType("RandomSeed", int) +EventRate = NewType("EventRate", int) # [events/s] +NumPixels = NewType("NumPixels", int) # [pixels/detector] +FrameRate = NewType("FrameRate", int) # [Hz] +NumFrames = NewType("NumFrames", int) # [dimensionless] + +# Workflow +ChunkSize = NewType("ChunkSize", int) +HistogramBinSize = NewType("HistogramBinSize", int) + + +default_param_providers = ProviderGroup() +default_params = { + RandomSeed: 123, + FrameRate: 14, + NumFrames: 140, + ChunkSize: 28, + HistogramBinSize: 50, +} + +default_param_providers[RandomSeed] = lambda: default_params[RandomSeed] +default_param_providers[FrameRate] = lambda: default_params[FrameRate] +default_param_providers[NumFrames] = lambda: default_params[NumFrames] +default_param_providers[ChunkSize] = lambda: default_params[ChunkSize] +default_param_providers[HistogramBinSize] = lambda: default_params[HistogramBinSize] + + +@dataclass +class BenchmarkParameters: + num_pixels: NumPixels + event_rate: EventRate + num_frames: NumFrames + frame_rate: FrameRate + + +default_param_providers[BenchmarkParameters] = BenchmarkParameters diff --git a/tests/prototypes/prototype_kafka.py b/tests/prototypes/prototype_kafka.py new file mode 100644 index 000000000..6315cb303 --- /dev/null +++ b/tests/prototypes/prototype_kafka.py @@ -0,0 +1,395 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +from __future__ import annotations + +import asyncio +from abc import ABC, abstractmethod +from queue import Empty +from typing import Callable, List, NewType, Optional + +import scipp as sc +from confluent_kafka import Consumer, Message, Producer, TopicPartition +from confluent_kafka.admin import AdminClient, PartitionMetadata, TopicMetadata + +from beamlime.constructors import Factory, ProviderGroup +from beamlime.core.schedulers import retry + +from .parameters import ChunkSize, FrameRate, NumFrames +from .prototype_mini import BaseApp, BasePrototype, DataStreamListener +from .random_data_providers import RandomEvents +from .workflows import Events + +KafkaBrokerAddress = NewType("KafkaBrokerAddress", str) +KafkaTopic = NewType("KafkaTopic", str) +KafkaBootstrapServer = NewType("KafkaBootstrapServer", str) +ConsumerContextManager = Callable[[], Consumer] + + +def provide_kafka_bootstrap_server( + broker_address: Optional[KafkaBrokerAddress] = None, +) -> KafkaBootstrapServer: + boostrap_server_addr = broker_address or 'localhost:9092' + return KafkaBootstrapServer(boostrap_server_addr) + + +def provide_kafka_admin(broker_address: KafkaBootstrapServer) -> AdminClient: + return AdminClient({'bootstrap.servers': broker_address}) + + +def provide_kafka_producer(broker_address: KafkaBootstrapServer) -> Producer: + return Producer({'bootstrap.servers': broker_address}) + + +def provide_kafka_consumer_ctxt_manager( + broker_address: KafkaBootstrapServer, kafka_topic_partition: TopicPartition +) -> 
ConsumerContextManager: + from contextlib import contextmanager + + @contextmanager + def consumer_manager(): + cs = Consumer( + { + 'bootstrap.servers': broker_address, + 'group.id': "BEAMLIME", + 'auto.offset.reset': 'earliest', + 'enable.auto.commit': False, + } + ) + cs.assign([kafka_topic_partition]) + yield cs + cs.close() + + return consumer_manager + + +def kafka_topic_exists(topic: KafkaTopic, admin_cli: AdminClient) -> bool: + return topic in admin_cli.list_topics().topics + + +TopicCreated = NewType("TopicCreated", bool) + + +def create_topic(admin: AdminClient, topic: KafkaTopic) -> TopicCreated: + from confluent_kafka.admin import NewTopic + + if not admin.list_topics().topics.get(topic): + admin.create_topics([NewTopic(topic)]) + + @retry(RuntimeError, max_trials=5, interval=0.1) + def wait_for_topic_to_be_created(): + # Wait for a topic to be created by a broker. + if not admin.list_topics().topics.get(topic): + raise RuntimeError( + "Kafka topic could not be created " "within timeout of 0.5 second." + ) + + wait_for_topic_to_be_created() + return TopicCreated(True) + + +def retrieve_topic_partition( + admin: AdminClient, topic: KafkaTopic, topic_created: TopicCreated +) -> TopicPartition: + topic_meta: TopicMetadata + + if not (topic_meta := admin.list_topics().topics.get(topic)) or not topic_created: + raise ValueError(f"There is no topic named {topic} in the broker") + elif len(topic_meta.partitions) != 1: + raise NotImplementedError("There should be exactly 1 partition for testing.") + + part_meta: PartitionMetadata + _, part_meta = topic_meta.partitions.popitem() + return TopicPartition(topic, part_meta.id, offset=0) + + +def provide_random_kafka_topic(admin_cli: AdminClient) -> KafkaTopic: + import uuid + + def generate_ru_tp(prefix: str = "BEAMLIMETEST") -> KafkaTopic: + return KafkaTopic('.'.join((prefix, uuid.uuid4().hex))) + + random_topic: KafkaTopic + while kafka_topic_exists((random_topic := generate_ru_tp()), admin_cli): + ... + + return random_topic + + +KafkaTopicDeleted = NewType("KafkaTopicDeleted", bool) + + +class TemporaryTopicNotDeleted(Exception): + ... 
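# [Editor's sketch, not part of the patch] The providers above provision a
# throw-away topic per run: ``provide_random_kafka_topic`` picks an unused name,
# ``create_topic`` polls the admin client via ``retry`` until the broker reports
# the topic, and ``retrieve_topic_partition`` pins consumers to its single
# partition. Assuming a broker on the default ``localhost:9092``, the wiring can
# be exercised by hand roughly as follows (``delete_topic`` is defined just below):
#
#     server = provide_kafka_bootstrap_server()
#     admin = provide_kafka_admin(server)
#     topic = provide_random_kafka_topic(admin)
#     partition = retrieve_topic_partition(admin, topic, create_topic(admin, topic))
#     with provide_kafka_consumer_ctxt_manager(server, partition)() as consumer:
#         assert consumer.poll(timeout=0) is None  # fresh topic, nothing to read yet
#     delete_topic(topic, admin)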
+ + +def delete_topic(topic: KafkaTopic, admin_cli: AdminClient) -> KafkaTopicDeleted: + from concurrent.futures import Future + + futures: dict[str, Future] = admin_cli.delete_topics([topic]) + + @retry(TemporaryTopicNotDeleted, max_trials=10, interval=0.1) + def wait_for_kafka_topic_deleted(): + if futures[topic].running(): + raise TemporaryTopicNotDeleted + + try: + wait_for_kafka_topic_deleted() + return KafkaTopicDeleted(True) + except TemporaryTopicNotDeleted: + return KafkaTopicDeleted(False) + + +RandomEventBuffers = NewType('RandomEventBuffers', list[bytes]) + + +def provide_random_event_buffers(random_events: RandomEvents) -> RandomEventBuffers: + from streaming_data_types.eventdata_ev44 import serialise_ev44 + + return RandomEventBuffers( + [ + serialise_ev44( + source_name='LIME', + message_id=i_event, + reference_time=event.coords['event_time_zero'].values[:1], + reference_time_index=[i_event], + time_of_flight=event.coords['event_time_offset'].values, + pixel_id=event.coords['pixel_id'].values, + ) + for i_event, event in enumerate(random_events) + ] + ) + + +class KafkaStreamSimulatorBase(BaseApp, ABC): + kafka_topic: KafkaTopic + admin: AdminClient + producer: Producer + + def produce_data(self, raw_data: bytes, max_size: int = 100_000) -> None: + slice_steps = range((len(raw_data) + max_size - 1) // max_size) + slices = [ + raw_data[i_slice * max_size : (i_slice + 1) * max_size] + for i_slice in slice_steps + ] + for sliced in slices: + try: + self.producer.produce(self.kafka_topic, sliced) + except Exception: + self.producer.flush() + self.producer.produce(self.kafka_topic, sliced) + + @abstractmethod + def stream(self) -> None: + ... + + async def run(self) -> None: + self.stream() + await asyncio.sleep(0.5) + self.producer.flush() + self.info("Data streaming to kafka finished...") + + def __del__(self) -> None: + if kafka_topic_exists(self.kafka_topic, self.admin): + delete_topic(self.kafka_topic, self.admin) + + +class KafkaStreamSimulatorScippOnly(KafkaStreamSimulatorBase): + random_events: RandomEvents + + def stream(self) -> None: + import json + + from scipp.serialization import serialize + + for random_event in self.random_events: + header, serialized_list = serialize(random_event) + self.producer.produce(self.kafka_topic, f'header:{json.dumps(header)}') + + for data_buffer in serialized_list: + self.produce_data(data_buffer) + self.producer.produce(self.kafka_topic, 'finished') + self.debug("Result: %s", self.producer.poll(1)) + + +class KafkaStreamSimulator(KafkaStreamSimulatorBase): + random_events: RandomEventBuffers + + def stream(self) -> None: + for i_frame, random_event in enumerate(self.random_events): + self.producer.produce(self.kafka_topic, "starts") + self.produce_data(random_event) + self.producer.produce(self.kafka_topic, 'finished') + self.debug( + "Produced %sth message divided into %s chunks.", + i_frame + 1, + self.producer.poll(1), + ) + + +class KafkaListenerBase(BaseApp, ABC): + raw_data_pipe: List[Events] + chunk_size: ChunkSize + kafka_topic: KafkaTopic + consumer_cxt: ConsumerContextManager + num_frames: NumFrames + frame_rate: FrameRate + + def merge_bytes(self, datalist: list[bytes]) -> bytes: + from functools import reduce + + return reduce(lambda x, y: x + y, datalist) + + @retry(Empty, max_trials=10, interval=0.1) + def _poll(self, consumer: Consumer) -> Message: + if (msg := consumer.poll(timeout=0)) is None: + raise Empty + else: + return msg + + def poll(self, consumer: Consumer) -> Message | None: + try: + return 
self._poll(consumer) + except Empty: + return None + + @abstractmethod + def poll_one_data(self, consumer: Consumer) -> sc.DataArray | None: + ... + + async def send_data_chunk(self, data_chunk: Events) -> None: + self.debug( + "Sending %s th, %s pieces of data.", self.data_counts, len(data_chunk) + ) + self.raw_data_pipe.append(Events(data_chunk)) + await self.commit_process() + + def start_stop_watch(self) -> None: + self.stop_watch.start() + + async def run(self) -> None: + with self.consumer_cxt() as consumer: + self.start_stop_watch() + data_chunk: Events = Events([]) + while (event := self.poll_one_data(consumer)) is not None: + data_chunk.append(event) + if len(data_chunk) >= self.chunk_size: + await self.send_data_chunk(data_chunk) + data_chunk = Events([]) + + if data_chunk: + await self.send_data_chunk(data_chunk) + + self.info("Data streaming finished...") + + +class KafkaListenerScippOnly(KafkaListenerBase): + def poll_one_data(self, consumer: Consumer) -> sc.DataArray | None: + import json + + from scipp.serialization import deserialize + + header: dict = {} + data_list: list[bytes] = [] + header_prefix = b'header' + finished_prefix = b'finished' + + while msg := self.poll(consumer): + raw_msg: bytes = msg.value() + if raw_msg.startswith(header_prefix): + header = json.loads(raw_msg.removeprefix(header_prefix)) + elif raw_msg.startswith(finished_prefix): + da = deserialize(header, [self.merge_bytes(data_list)]) + if not isinstance(da, sc.DataArray): + raise TypeError('Expected sc.DataArray, but got ', type(da)) + return da + else: + data_list.append(raw_msg) + + return None + + +class KafkaListener(KafkaListenerBase): + def deserialize(self, data_list: list[bytes]) -> sc.DataArray: + import numpy as np + from streaming_data_types.eventdata_ev44 import deserialise_ev44 + + data = deserialise_ev44(self.merge_bytes(data_list)) + event_zeros = np.full(len(data.pixel_id), data.reference_time[0]) + + return sc.DataArray( + data=sc.ones(dims=['event'], shape=(len(data.pixel_id),), unit='counts'), + coords={ + 'event_time_offset': sc.Variable( + dims=['event'], values=data.time_of_flight, unit='ns', dtype=float + ), + 'event_time_zero': sc.Variable( + dims=['event'], values=event_zeros, unit='ns' + ), + 'pixel_id': sc.Variable( + dims=['event'], values=data.pixel_id, dtype='int' + ), + }, + ) + + def poll_one_data(self, consumer: Consumer) -> sc.DataArray | None: + data_list: list[bytes] = [] + header_prefix = b'starts' + finished_prefix = b'finished' + + while msg := self.poll(consumer): + raw_msg: bytes = msg.value() + if raw_msg.startswith(header_prefix): + ... 
+ elif raw_msg.startswith(finished_prefix): + return self.deserialize(data_list) + else: + data_list.append(raw_msg) + + return None + + +class KafkaPrototype(BasePrototype): + kafka_simulator: KafkaStreamSimulator + + def collect_sub_daemons(self) -> list[BaseApp]: + return [self.kafka_simulator] + super().collect_sub_daemons() + + +def collect_kafka_providers() -> ProviderGroup: + from beamlime.constructors.providers import SingletonProvider + + kafka_providers = ProviderGroup( + SingletonProvider(provide_random_kafka_topic), + SingletonProvider(provide_kafka_bootstrap_server), + SingletonProvider(provide_kafka_admin), + ) + kafka_providers[Producer] = provide_kafka_producer + kafka_providers[ConsumerContextManager] = provide_kafka_consumer_ctxt_manager + kafka_providers[TopicCreated] = create_topic + kafka_providers[TopicPartition] = retrieve_topic_partition + kafka_providers[KafkaStreamSimulator] = KafkaStreamSimulator + kafka_providers[KafkaPrototype] = KafkaPrototype + kafka_providers[KafkaTopicDeleted] = delete_topic + kafka_providers[RandomEventBuffers] = provide_random_event_buffers + + return kafka_providers + + +def kafka_prototype_factory() -> Factory: + from .prototype_mini import Prototype, prototype_base_providers + + kafka_providers = collect_kafka_providers() + base_providers = prototype_base_providers() + base_providers[Prototype] = KafkaPrototype + base_providers[DataStreamListener] = KafkaListener + + return Factory(base_providers, kafka_providers) + + +if __name__ == "__main__": + from .prototype_mini import prototype_arg_parser, run_standalone_prototype + + kafka_factory = kafka_prototype_factory() + parser = prototype_arg_parser() + + run_standalone_prototype(kafka_factory, parser.parse_args()) diff --git a/tests/prototypes/prototype_mini.py b/tests/prototypes/prototype_mini.py new file mode 100644 index 000000000..9b05e294b --- /dev/null +++ b/tests/prototypes/prototype_mini.py @@ -0,0 +1,510 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +from __future__ import annotations + +import argparse +import asyncio +from abc import ABC, abstractmethod +from contextlib import contextmanager +from typing import Any, Generator, Generic, List, NewType, Optional, TypeVar + +from beamlime.constructors import Factory, ProviderGroup +from beamlime.logging import BeamlimeLogger +from beamlime.logging.mixins import LogMixin + +from .parameters import ChunkSize, EventRate, NumFrames, NumPixels +from .random_data_providers import RandomEvents +from .workflows import ( + Events, + FirstPulseTime, + Histogrammed, + MergedData, + PixelGrouped, + ReducedData, + Workflow, +) + +TargetCounts = NewType("TargetCounts", int) + + +def calculate_target_counts( + num_frames: NumFrames, chunk_size: ChunkSize +) -> TargetCounts: + import math + + return TargetCounts(math.ceil(num_frames / chunk_size)) + + +class StopWatch(LogMixin): + logger: BeamlimeLogger + + def __init__(self) -> None: + self.lapse: dict[str, list[float]] = dict() + self._start_timestamp: Optional[float] = None + self._stop_timestamp: Optional[float] = None + + @property + def duration(self) -> float: + if self._start_timestamp is None: + raise TypeError( + "Start-timestamp is not available. ``start`` was never called." + ) + elif self._stop_timestamp is None: + raise TypeError( + "Stop-timestamp is not available. ``stop`` was never called." 
+ ) + return self._stop_timestamp - self._start_timestamp + + def start(self) -> None: + import time + + if self._start_timestamp is None: + self._start_timestamp = time.time() + else: + raise RuntimeError( + "Start-timestamp is already recorded. " + "``start`` cannot be called twice." + ) + + def stop(self) -> None: + import time + + if self._start_timestamp is None: + raise RuntimeError("``start`` must be called before ``stop``.") + elif self._stop_timestamp is None: + self._stop_timestamp = time.time() + else: + raise RuntimeError( + "Stop-timestamp is already recorded. " + "``stop`` cannot be called twice." + ) + + def lap(self, app_name: str) -> None: + import time + + app_lapse = self.lapse.setdefault(app_name, []) + app_lapse.append(time.time()) + + @property + def lap_counts(self) -> dict[str, int]: + return {app_name: len(app_lapse) for app_name, app_lapse in self.lapse.items()} + + def log_benchmark_result(self): + self.info("Lap counts: %s", self.lap_counts) + self.info("Benchmark result: %s [s]", self.duration) + + +class BaseApp(LogMixin, ABC): + logger: BeamlimeLogger + stop_watch: StopWatch + target_counts: TargetCounts + + @property + def app_name(self) -> str: + return self.__class__.__name__ + + @property + def data_counts(self) -> int: + return len(self.stop_watch.lapse.get(self.app_name, [])) + + @property + def target_count_reached(self) -> bool: + return self.target_counts <= self.data_counts + + async def commit_process(self): + self.stop_watch.lap(self.app_name) + await asyncio.sleep(0) + + def data_pipe_monitor( + self, + pipe: List[Any], + timeout: float = 5, + interval: float = 1 / 14, + preferred_size: int = 1, + target_size: int = 1, + ): + from beamlime.core.schedulers import async_retry + + @async_retry( + TimeoutError, max_trials=int(timeout / interval), interval=interval + ) + async def wait_for_preferred_size() -> None: + if len(pipe) < preferred_size: + raise TimeoutError + + async def is_pipe_filled() -> bool: + try: + await wait_for_preferred_size() + except TimeoutError: + ... + return len(pipe) >= target_size + + return is_pipe_filled + + @abstractmethod + async def run(self): + ... + + +DataStreamListener = NewType("DataStreamListener", BaseApp) + + +class DataStreamSimulator(BaseApp): + raw_data_pipe: List[Events] + random_events: RandomEvents + chunk_size: ChunkSize + + def slice_chunk(self) -> Events: + chunk, self.random_events = ( + Events(self.random_events[: self.chunk_size]), + RandomEvents(self.random_events[self.chunk_size :]), + ) + return chunk + + async def run(self) -> None: + self.stop_watch.start() + + for i_chunk in range(self.target_counts): + chunk = self.slice_chunk() + self.raw_data_pipe.append(chunk) + self.debug("Sent %s th chunk, with %s pieces.", i_chunk + 1, len(chunk)) + await self.commit_process() + + self.info("Data streaming finished...") + + +InputType = TypeVar("InputType") +OutputType = TypeVar("OutputType") + + +class DataReductionApp(BaseApp, Generic[InputType, OutputType]): + workflow: Workflow + input_pipe: List[InputType] + output_pipe: List[OutputType] + + def __init__(self) -> None: + self.input_type = self._retrieve_type_arg('input_pipe') + self.output_type = self._retrieve_type_arg('output_pipe') + self.first_pulse_time: FirstPulseTime + super().__init__() + + @classmethod + def _retrieve_type_arg(cls, attr_name: str) -> type: + """ + Retrieve type arguments of an attribute with generic type. + It is only for retrieving input/output pipe type. + + >>> class C(DataReductionApp): + ... attr0: list[int] + ... 
+ >>> C._retrieve_type_arg('attr0') + + """ + from typing import get_args, get_type_hints + + if not (attr_type := get_type_hints(cls).get(attr_name)): + raise ValueError( + f"Class {cls} does not have an attribute " + f"{attr_name} or it is missing type annotation." + ) + elif not (type_args := get_args(attr_type)): + raise TypeError(f"Attribute {attr_name} does not have any type arguments.") + else: + return type_args[0] + + def format_received(self, data: InputType) -> str: + return str(data) + + async def process_every_data(self, data: InputType) -> None: + self.debug("Received, %s", self.format_received(data)) + with self.workflow.constant_provider(self.input_type, data): + self.output_pipe.append(self.workflow[self.output_type]) + + await self.commit_process() + + def process_first_data(self, data: InputType) -> None: + ... + + def wrap_up(self, *args, **kwargs) -> Any: + self.info("No more data coming in. Finishing ...") + + async def run(self) -> None: + data_monitor = self.data_pipe_monitor(self.input_pipe, target_size=1) + if not self.target_count_reached and await data_monitor(): + data = self.input_pipe.pop(0) + self.process_first_data(data) + await self.process_every_data(data) + + while not self.target_count_reached and await data_monitor(): + data = self.input_pipe.pop(0) + await self.process_every_data(data) + + self.wrap_up() + + +class DataMerge(DataReductionApp[InputType, OutputType]): + input_pipe: List[Events] + output_pipe: List[MergedData] + + def format_received(self, data: Any) -> str: + return f"{len(data)} pieces of {self.input_type.__name__}" + + def process_first_data(self, data: Events) -> None: + sample_event = data[0] + first_pulse_time = sample_event.coords['event_time_zero'][0] + self.workflow.providers[FirstPulseTime] = lambda: first_pulse_time + + +class DataBinning(DataReductionApp[InputType, OutputType]): + input_pipe: List[MergedData] + output_pipe: List[PixelGrouped] + + +class DataReduction(DataReductionApp[InputType, OutputType]): + input_pipe: List[PixelGrouped] + output_pipe: List[ReducedData] + + +class DataHistogramming(DataReductionApp[InputType, OutputType]): + input_pipe: List[ReducedData] + output_pipe: List[Histogrammed] + + +class VisualizationDaemon(DataReductionApp[InputType, OutputType]): + input_pipe: List[Histogrammed] + output_pipe: Optional[List[None]] = None + + def show(self): + if not hasattr(self, "fig"): + raise AttributeError("Please wait until the first figure is created.") + return self.fig + + def process_first_data(self, data: Histogrammed) -> None: + import plopp as pp + + self.first_data = data + self.debug("First data as a seed of histogram: %s", self.first_data) + self.stream_node = pp.Node(self.first_data) + self.fig = pp.figure1d(self.stream_node) + + async def process_every_data(self, data: Histogrammed) -> None: + if data is not self.first_data: + self.first_data.values += data.values + self.stream_node.notify_children("update") + self.debug("Updated plot.") + await self.commit_process() + + def wrap_up(self, *args, **kwargs) -> Any: + self.stop_watch.stop() + self.stop_watch.log_benchmark_result() + return super().wrap_up(*args, **kwargs) + + +@contextmanager +def asyncio_event_loop() -> Generator[asyncio.AbstractEventLoop, Any, Any]: + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + yield loop + + loop.close() + asyncio.set_event_loop(asyncio.new_event_loop()) + + +class BasePrototype(BaseApp, ABC): + data_stream_listener: 
DataStreamListener + data_merger: DataMerge[Events, MergedData] + data_binner: DataBinning[MergedData, PixelGrouped] + data_reducer: DataReduction[PixelGrouped, ReducedData] + data_histogrammer: DataHistogramming[ReducedData, Histogrammed] + visualizer: VisualizationDaemon + + def collect_sub_daemons(self) -> list[BaseApp]: + return [ + self.data_stream_listener, + self.data_merger, + self.data_binner, + self.data_reducer, + self.data_histogrammer, + self.visualizer, + ] + + def run(self): + """ + Collect all coroutines of daemons and schedule them into the event loop. + + Notes + ----- + **Debugging log while running async daemons under various circumstances.** + + - ``asyncio.get_event_loop`` vs ``asyncio.new_event_loop`` + 1. ``asyncio.get_event_loop`` + ``get_event_loop`` will always return the current event loop. + If there is no event loop set in the thread, it will create a new one + and set it as a current event loop of the thread, and return the loop. + Many of ``asyncio`` free functions internally use ``get_event_loop``, + i.e. ``asyncio.create_task``. + + **Things to be considered while using ``asyncio.get_event_loop``. + - ``asyncio.create_task`` does not guarantee + whether the current loop is/will be alive until the task is done. + You may use ``run_until_complete`` to make sure the loop is not closed + until the task is finished. + When you need to throw multiple async calls to the loop, + use ``asyncio.gather`` to merge all the tasks like in this method. + - ``close`` or ``stop`` might accidentally destroy/interrupt + other tasks running in the same event loop. + i.e. You can accidentally destroy the main event loop of a jupyter kernel. + - [1]``RuntimeError`` if there has been an event loop set in the + thread object before but it is now removed. + + 2. ``asyncio.new_event_loop`` + ``asyncio.new_event_loop`` will always return the new event loop, + but it is not set it as a current loop of the thread automatically. + + However, sometimes it is automatically handled within the thread, + and it caused errors which was hard to be debugged under ``pytest`` session. + For example, + - The new event loop was not closed properly as it is destroyed. + - The new event loop was never started until it is destroyed. + ``Traceback`` of ``pytest`` did not show + where exactly the error is from in those cases. + It was resolved by using ``get_event_loop``, + or manually closing the event loop at the end of the test. + + **When to use ``asyncio.new_event_loop``.** + - ``asyncio.get_event_loop`` raises ``RuntimeError``[1] + - Multi-threads + + Please note that the loop object might need to be ``close``ed manually. 
+ """ + self.debug('Start running ...') + with asyncio_event_loop() as loop: + daemon_coroutines = [daemon.run() for daemon in self.collect_sub_daemons()] + tasks = [loop.create_task(coro) for coro in daemon_coroutines] + + if not loop.is_running(): + loop.run_until_complete(asyncio.gather(*tasks)) + + +Prototype = NewType("Prototype", BasePrototype) + + +def prototype_app_providers() -> ProviderGroup: + from beamlime.constructors.providers import SingletonProvider + + app_providers = ProviderGroup( + SingletonProvider(StopWatch), + SingletonProvider(VisualizationDaemon), + SingletonProvider(calculate_target_counts), + ) + app_providers[DataMerge[Events, MergedData]] = DataMerge + app_providers[DataBinning[MergedData, PixelGrouped]] = DataBinning + app_providers[DataReduction[PixelGrouped, ReducedData]] = DataReduction + app_providers[DataHistogramming[ReducedData, Histogrammed]] = DataHistogramming + for pipe_type in (Events, PixelGrouped, MergedData, ReducedData, Histogrammed): + app_providers[List[pipe_type]] = SingletonProvider(list) + + return app_providers + + +def prototype_base_providers() -> ProviderGroup: + from beamlime.constructors.providers import merge + from beamlime.logging.providers import log_providers + + from .parameters import default_param_providers + from .random_data_providers import random_data_providers + from .workflows import workflow_providers + + return merge( + default_param_providers, + random_data_providers, + prototype_app_providers(), + log_providers, + workflow_providers, + ) + + +@contextmanager +def multiple_constant_providers( + factory: Factory, constants: Optional[dict[type, Any]] = None +): + if constants: + tp, val = constants.popitem() + with factory.constant_provider(tp, val): + with multiple_constant_providers(factory, constants): + yield + else: + yield + + +@contextmanager +def multiple_temporary_providers( + factory: Factory, providers: Optional[dict[type, Any]] = None +): + if providers: + tp, prov = providers.popitem() + with factory.temporary_provider(tp, prov): + with multiple_temporary_providers(factory, providers): + yield + else: + yield + + +def mini_prototype_factory() -> Factory: + providers = prototype_base_providers() + providers[Prototype] = BasePrototype + providers[DataStreamListener] = DataStreamSimulator + return Factory(providers) + + +def run_prototype( + prototype_factory: Factory, + parameters: Optional[dict[type, Any]] = None, + providers: Optional[dict[type, Any]] = None, +): + with multiple_constant_providers(prototype_factory, parameters): + with multiple_temporary_providers(prototype_factory, providers): + prototype = prototype_factory[Prototype] + prototype.run() + + +def prototype_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + parser.add_argument_group('Event Generator Configuration') + parser.add_argument( + '--event-rate', default=10**4, help=f": {EventRate}", type=int + ) + parser.add_argument( + '--num-pixels', default=10**4, help=f": {NumPixels}", type=int + ) + parser.add_argument('--num-frames', default=140, help=f": {NumFrames}", type=int) + + return parser + + +def run_standalone_prototype( + prototype_factory: Factory, arg_name_space: argparse.Namespace +): + import logging + + prototype_factory[BeamlimeLogger].setLevel(logging.DEBUG) + run_prototype( + prototype_factory=prototype_factory, + parameters={ + EventRate: arg_name_space.event_rate, + NumPixels: arg_name_space.num_pixels, + NumFrames: arg_name_space.num_frames, + }, + ) + + +if __name__ == "__main__": + 
factory = mini_prototype_factory() + arg_parser = prototype_arg_parser() + + run_standalone_prototype(factory, arg_name_space=arg_parser.parse_args()) diff --git a/tests/prototypes/prototype_test.py b/tests/prototypes/prototype_test.py new file mode 100644 index 000000000..80b18ad6d --- /dev/null +++ b/tests/prototypes/prototype_test.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +from typing import Generator + +import pytest + +from beamlime.constructors import Factory + +from .prototype_mini import ( + BaseApp, + StopWatch, + TargetCounts, + VisualizationDaemon, + run_prototype, +) + + +@pytest.fixture +def mini_factory() -> Generator[Factory, None, None]: + from .prototype_mini import mini_prototype_factory + + yield mini_prototype_factory() + + +@pytest.fixture +def kafka_factory(kafka_test: bool) -> Generator[Factory, None, None]: + from .prototype_kafka import KafkaTopicDeleted, kafka_prototype_factory + + assert kafka_test + kafka = kafka_prototype_factory() + yield kafka + assert kafka[KafkaTopicDeleted] + + +def prototype_test_helper(prototype_factory: Factory, reference_app_tp: type[BaseApp]): + from .parameters import ChunkSize, EventRate, NumFrames, NumPixels + + # No laps recorded. + assert len(prototype_factory[StopWatch].lap_counts) == 0 + + num_frames = 140 + chunk_size = 28 + run_prototype( + prototype_factory=prototype_factory, + parameters={ + EventRate: 10**4, + NumPixels: 10**4, + NumFrames: num_frames, + ChunkSize: chunk_size, + }, + ) + stop_watch = prototype_factory[StopWatch] + reference_app_name = prototype_factory[reference_app_tp].app_name + assert prototype_factory[TargetCounts] == int(num_frames / chunk_size) + assert stop_watch.lap_counts[reference_app_name] == prototype_factory[TargetCounts] + + +def test_mini_prototype(mini_factory: Factory): + prototype_test_helper(mini_factory, VisualizationDaemon) + + +def test_kafka_prototype(kafka_factory: Factory): + prototype_test_helper(kafka_factory, VisualizationDaemon) diff --git a/tests/prototypes/random_data_providers.py b/tests/prototypes/random_data_providers.py new file mode 100644 index 000000000..e10be2bd2 --- /dev/null +++ b/tests/prototypes/random_data_providers.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +from __future__ import annotations + +from typing import Any, Generator, List, NewType, Optional + +import scipp as sc +from numpy.random import Generator as RNG + +from beamlime.constructors import ProviderGroup +from beamlime.constructors.providers import SingletonProvider + +from .parameters import EventRate, FrameRate, NumFrames, NumPixels, RandomSeed + +# Derived Configuration +EventFrameRate = NewType("EventFrameRate", int) # [events/frame] +ReferenceTimeZero = NewType("ReferenceTimeZero", int) # [ns] + +# Generated Data +DetectorCounts = NewType("DetectorCounts", sc.Variable) +TimeCoords = NewType("TimeCoords", dict[str, sc.Variable]) +RandomPixelId = NewType("RandomPixelId", sc.Variable) +RandomEvent = NewType("RandomEvent", sc.DataArray) +RandomEvents = NewType("RandomEvents", List[sc.DataArray]) + + +def provide_rng(random_seed: RandomSeed) -> RNG: + from numpy.random import default_rng + + return default_rng(random_seed) + + +def calculate_event_per_frame( + frame_rate: FrameRate, event_rate: EventRate +) -> EventFrameRate: + return EventFrameRate(int(event_rate / frame_rate)) + + +def provide_time_coords( + rng: RNG, + 
ef_rate: EventFrameRate, + ref_time: Optional[ReferenceTimeZero] = None, +) -> TimeCoords: + ref_time = ref_time or ReferenceTimeZero(13620492**11) + et_zero = sc.array(dims=["event"], values=[ref_time] * ef_rate, unit='ns') + et_offset = sc.array( + dims=["event"], values=rng.random((ef_rate,)) * 800 + 200, unit='ns' + ) + + return TimeCoords({"event_time_zero": et_zero, "event_time_offset": et_offset}) + + +def provide_time_coords_generator( + rng: RNG, ef_rate: EventFrameRate, num_frames: NumFrames, frame_rate: FrameRate +) -> Generator[TimeCoords, Any, Any]: + for i_frame in range(num_frames): + ref_time = ReferenceTimeZero( + int(13620492e11 + (i_frame / frame_rate) * 10**9) + ) + yield provide_time_coords(rng, ef_rate, ref_time) + + +def provide_dummy_counts(ef_rate: EventFrameRate) -> DetectorCounts: + return DetectorCounts(sc.ones(sizes={"event": ef_rate}, unit='counts')) + + +def provide_random_pixel_id_generator( + rng: RNG, ef_rate: EventFrameRate, num_pixels: NumPixels, num_frames: NumFrames +) -> Generator[RandomPixelId, Any, Any]: + for _ in range(num_frames): + yield RandomPixelId( + sc.array( + dims=["event"], + values=rng.integers(low=0, high=num_pixels, size=ef_rate), + unit=sc.units.dimensionless, + ) + ) + + +def provide_random_event_generator( + pixel_id_generator: Generator[RandomPixelId, Any, Any], + time_coords_generator: Generator[TimeCoords, Any, Any], + data: DetectorCounts, +) -> Generator[RandomEvent, Any, Any]: + for pixel_id, time_coords in zip(pixel_id_generator, time_coords_generator): + yield RandomEvent( + sc.DataArray(data=data, coords={"pixel_id": pixel_id, **time_coords}) + ) + + +def provide_random_events( + random_event_generator: Generator[RandomEvent, Any, Any] +) -> RandomEvents: + """ + Whole set of random events should be created at once in advance + since randomly generating data can consume non-trivial amount of time + and it should not interfere other async applications. 
+ """ + + return RandomEvents([random_event for random_event in random_event_generator]) + + +random_data_providers = ProviderGroup( + provide_rng, + provide_random_pixel_id_generator, + provide_time_coords_generator, + provide_random_event_generator, + provide_random_events, + SingletonProvider(calculate_event_per_frame), + SingletonProvider(provide_time_coords), + SingletonProvider(provide_dummy_counts), +) + + +def dump_random_dummy_events() -> RandomEvents: + num_pixels = NumPixels(10_000) + event_frame_rate = EventFrameRate(10_000) + num_frames = NumFrames(10) + frame_rate = FrameRate(14) + rng = provide_rng(RandomSeed(123)) + time_coords_generator = provide_time_coords_generator( + rng, event_frame_rate, num_frames, frame_rate + ) + data = provide_dummy_counts(event_frame_rate) + pixel_id_generator = provide_random_pixel_id_generator( + rng, event_frame_rate, num_pixels, num_frames + ) + random_event_generator = provide_random_event_generator( + pixel_id_generator, time_coords_generator, data + ) + return provide_random_events(random_event_generator) diff --git a/tests/prototypes/workflows.py b/tests/prototypes/workflows.py new file mode 100644 index 000000000..f5397d81f --- /dev/null +++ b/tests/prototypes/workflows.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +from __future__ import annotations + +from typing import List, NewType + +import scipp as sc + +from beamlime.constructors import Factory, ProviderGroup, SingletonProvider + +from .parameters import FrameRate, HistogramBinSize, NumPixels + +# Coordinates +PixelID = NewType("PixelID", sc.Variable) +EventTimeOffset = NewType("EventTimeOffset", sc.Variable) +Length = NewType("Length", sc.Variable) +WaveLength = NewType("WaveLength", sc.Variable) + +# Constants +FirstPulseTime = NewType("FirstPulseTime", sc.Variable) +FrameUnwrappingGraph = NewType("FrameUnwrappingGraph", dict) +LtotalGraph = NewType("LtotalGraph", dict) +WavelengthGraph = NewType("WavelengthGraph", dict) + +# Generated/Calculated +Events = NewType("Events", List[sc.DataArray]) +MergedData = NewType("MergedData", sc.DataArray) +PixelIDEdges = NewType("PixelIDEdges", sc.Variable) +PixelGrouped = NewType("PixelGrouped", sc.DataArray) +LTotalCalculated = NewType("Transformed", sc.DataArray) +FrameUnwrapped = NewType("FrameUnwrapped", sc.DataArray) +ReducedData = NewType("ReducedData", sc.DataArray) +Histogrammed = NewType("Histogrammed", sc.DataArray) + + +def provide_Ltotal_graph() -> LtotalGraph: + c_a = sc.scalar(0.00001, unit='m') + c_b = sc.scalar(0.1, unit='m') + + return LtotalGraph( + { + 'L1': lambda pixel_id: (pixel_id * c_a) + c_b, + 'L2': lambda pixel_id: (pixel_id * c_a) + c_b, + 'Ltotal': lambda L1, L2: L1 + L2, + } + ) + + +def provide_wavelength_graph() -> WavelengthGraph: + c_c = sc.scalar(1, unit='1e-3m^2/s') + + return WavelengthGraph( + {'wavelength': lambda tof, Ltotal: (c_c * tof / Ltotal).to(unit='angstrom')} + ) + + +def merge_data_list(da_list: Events) -> MergedData: + return MergedData(sc.concat(da_list, dim='event')) + + +def provide_pixel_id_bin_edges(num_pixels: NumPixels) -> PixelIDEdges: + return PixelIDEdges(sc.arange(dim='pixel_id', start=0, stop=num_pixels)) + + +def bin_pixel_id(da: MergedData, pixel_bin_coord: PixelIDEdges) -> PixelGrouped: + return PixelGrouped(da.group(pixel_bin_coord)) + + +def calculate_ltotal( + binned: PixelGrouped, + graph: LtotalGraph, +) -> LTotalCalculated: + da = binned.transform_coords(['Ltotal'], graph=graph) + if not 
isinstance(da, sc.DataArray): + raise TypeError + + return LTotalCalculated(da) + + +def unwrap_frames( + da: LTotalCalculated, frame_rate: FrameRate, first_pulse_time: FirstPulseTime +) -> FrameUnwrapped: + from scippneutron.tof import unwrap_frames + + return FrameUnwrapped( + unwrap_frames( + da, + pulse_period=sc.scalar(1 / frame_rate, unit='ns'), # No pulse skipping + lambda_min=sc.scalar(5.0, unit='angstrom'), + frame_offset=first_pulse_time.to(unit='ms'), + first_pulse_time=first_pulse_time, + ) + ) + + +def calculate_wavelength( + unwrapped: FrameUnwrapped, graph: WavelengthGraph +) -> ReducedData: + da = unwrapped.transform_coords(['wavelength'], graph=graph) + if not isinstance(da, sc.DataArray): + raise TypeError + + return ReducedData(da) + + +def histogram_result( + bin_size: HistogramBinSize, reduced_data: ReducedData +) -> Histogrammed: + return reduced_data.hist(wavelength=bin_size) + + +Workflow = NewType("Workflow", Factory) + + +def provide_workflow( + num_pixels: NumPixels, histogram_binsize: HistogramBinSize, frame_rate: FrameRate +) -> Workflow: + providers = ProviderGroup( + SingletonProvider(provide_wavelength_graph), + SingletonProvider(provide_Ltotal_graph), + merge_data_list, + bin_pixel_id, + calculate_ltotal, + calculate_wavelength, + unwrap_frames, + histogram_result, + provide_pixel_id_bin_edges, + ) + + providers[NumPixels] = lambda: num_pixels + providers[HistogramBinSize] = lambda: histogram_binsize + providers[FrameRate] = lambda: frame_rate + return Workflow(Factory(providers)) + + +workflow_providers = ProviderGroup(SingletonProvider(provide_workflow))
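
Editor's note, appended after the patch: the sketch below illustrates what the DataMerge -> DataBinning -> DataReduction -> DataHistogramming chain in prototype_mini.py computes for a single chunk, by driving the workflow factory from workflows.py directly. It assumes the repository root is importable (so ``tests.prototypes.*`` resolves), that scippneutron from extra.in is installed, and that the beamlime Factory resolves ``Histogrammed`` transitively from an ``Events`` constant, mirroring the ``constant_provider`` / ``providers[...]`` usage shown in DataReductionApp and DataMerge above. It is an illustration, not part of the patch.

    from tests.prototypes.random_data_providers import dump_random_dummy_events
    from tests.prototypes.workflows import (
        Events,
        FirstPulseTime,
        Histogrammed,
        provide_workflow,
    )

    # Ten frames of dummy events for 10_000 pixels, matching the pixel-id bin
    # edges that provide_workflow derives from num_pixels below.
    random_events = dump_random_dummy_events()
    chunk = Events(random_events[:5])  # one chunk, as DataStreamSimulator would send it

    workflow = provide_workflow(num_pixels=10_000, histogram_binsize=50, frame_rate=14)
    # DataMerge.process_first_data registers the first pulse time before reducing:
    workflow.providers[FirstPulseTime] = lambda: chunk[0].coords['event_time_zero'][0]

    with workflow.constant_provider(Events, chunk):
        # Merged, grouped by pixel, frame-unwrapped, converted to wavelength,
        # and histogrammed in one pull from the factory.
        histogram = workflow[Histogrammed]
    print(histogram)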