diff --git a/.github/workflows/pr_to_pr.yml b/.github/workflows/pr_to_pr.yml index 224f7909..1861700f 100644 --- a/.github/workflows/pr_to_pr.yml +++ b/.github/workflows/pr_to_pr.yml @@ -1,9 +1,12 @@ -name: Pull request to pr/** branches +name: Pull request and push to pr/** branches on: pull_request: branches: - "pr/**" + push: + branches: + - "pr/**" jobs: check-coding-style: diff --git a/VERSION b/VERSION index 3c80e4f0..e1df5de7 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.4.3 \ No newline at end of file +1.4.4 \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index cffb629f..7b6218b3 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -24,8 +24,8 @@ project = 'f3dasm' author = 'Martin van der Schelling' copyright = '2022, Martin van der Schelling' -version = '1.4.3' -release = '1.4.3' +version = '1.4.4' +release = '1.4.4' # -- General configuration ---------------------------------------------------- diff --git a/docs/source/rst_doc_files/classes/design/experimentsample.rst b/docs/source/rst_doc_files/classes/design/experimentsample.rst index bd928b30..00be3b85 100644 --- a/docs/source/rst_doc_files/classes/design/experimentsample.rst +++ b/docs/source/rst_doc_files/classes/design/experimentsample.rst @@ -54,6 +54,19 @@ An KeyError will be raised if the key is not found. >>> experiment_sample.get('param_1') 0.0249 +Manually iterating over ExperimentData +---------------------------------------- + +The :class:`~f3dasm.design.ExperimentData` object can be manually iterated over to get :class:`~f3dasm.design.ExperimentSample` objects for each experiment: + +.. code-block:: python + + >>> for experiment_sample in experiment_data: + ... print(experiment_sample) + ExperimentSample(0 : {'x0': 0.8184054141827567, 'x1': 0.937852542255321, 'x2': 0.7376563782762678} - {}) + ExperimentSample(1 : {'x0': 0.7203461491873061, 'x1': 0.7320604457665572, 'x2': 0.2524387342272223} - {}) + ExperimentSample(2 : {'x0': 0.35449352388104904, 'x1': 0.11413412225748525, 'x2': 0.1467895592274866} - {}) + Storing output parameters to the experiment sample -------------------------------------------------- diff --git a/src/f3dasm/__init__.py b/src/f3dasm/__init__.py index b4e60693..58852bc9 100644 --- a/src/f3dasm/__init__.py +++ b/src/f3dasm/__init__.py @@ -37,7 +37,7 @@ # ============================================================================= -__version__ = '1.4.3' +__version__ = '1.4.4' # Log welcome message and the version of f3dasm diff --git a/src/f3dasm/_src/datageneration/datagenerator.py b/src/f3dasm/_src/datageneration/datagenerator.py index 81a4742e..93d421f5 100644 --- a/src/f3dasm/_src/datageneration/datagenerator.py +++ b/src/f3dasm/_src/datageneration/datagenerator.py @@ -7,6 +7,7 @@ # Standard import sys +from abc import abstractmethod from functools import partial from typing import Any, Callable @@ -43,30 +44,80 @@ class DataGenerator: """Base class for a data generator""" def pre_process(self, experiment_sample: ExperimentSample, **kwargs) -> None: - """Function that handles the pre-processing""" + """Interface function that handles the pre-processing of the data generator + + Notes + ----- + If not implemented the function will be skipped + + The experiment_sample is cached inside the data generator. This + allows the user to access the experiment_sample in the pre_process, execute + and post_process functions as a class variable called self.experiment_sample. + """ ... + @abstractmethod def execute(self, **kwargs) -> None: - """Function that calls the FEM simulator the pre-processing""" - raise NotImplementedError("No execute function implemented!") + """Interface function that handles the execution of the data generator + + Raises + ------ + NotImplementedError + If the function is not implemented by the user + + Notes + ----- + The experiment_sample is cached inside the data generator. This + allows the user to access the experiment_sample in the pre_process, execute + and post_process functions as a class variable called self.experiment_sample. + """ + + ... def post_process(self, experiment_sample: ExperimentSample, **kwargs) -> None: - """Function that handles the post-processing""" + """Interface function that handles the post-processing of the data generator + + Notes + ----- + If not implemented the function will be skipped + + The experiment_sample is cached inside the data generator. This + allows the user to access the experiment_sample in the pre_process, execute + and post_process functions as a class variable called self.experiment_sample. + """ ... @time_and_log - def run(self, experiment_sample: ExperimentSample, **kwargs) -> ExperimentSample: - """Run the data generator + def _run(self, experiment_sample: ExperimentSample, **kwargs) -> ExperimentSample: + """ + Run the data generator + This function chains the following methods together + + * pre_process(); to combine the experiment_sample and the parameters + of the data generator to an input file that can be used to run the data generator + + * execute(); to run the data generator and generate the response of the experiment + + * post_process(); to process the response of the experiment and store it back + in the experiment_sample + + The function also caches the experiment_sample in the data generator. This + allows the user to access the experiment_sample in the pre_process, execute + and post_process functions as a class variable called self.experiment_sample. Parameters ---------- ExperimentSample : ExperimentSample The design to run the data generator on + kwargs : dict + The keyword arguments to pass to the pre_process, execute and post_process + Returns ------- ExperimentSample - Processed design + Processed design with the response of the data generator saved in the + experiment_sample """ # Cache the design self.experiment_sample: ExperimentSample = experiment_sample @@ -88,7 +139,25 @@ def _post_simulation(self) -> None: ... def add_pre_process(self, func: Callable, **kwargs): + """Add a pre-processing function to the data generator + + Parameters + ---------- + func : Callable + The function to add to the pre-processing + kwargs : dict + The keyword arguments to pass to the pre-processing function + """ self.pre_process = partial(func, **kwargs) def add_post_process(self, func: Callable, **kwargs): + """Add a post-processing function to the data generator + + Parameters + ---------- + func : Callable + The function to add to the post-processing + kwargs : dict + The keyword arguments to pass to the post-processing function + """ self.post_process = partial(func, **kwargs) diff --git a/src/f3dasm/_src/datageneration/functions/function.py b/src/f3dasm/_src/datageneration/functions/function.py index c5e376bd..2036e939 100644 --- a/src/f3dasm/_src/datageneration/functions/function.py +++ b/src/f3dasm/_src/datageneration/functions/function.py @@ -99,7 +99,7 @@ def execute(self, experiment_sample: ExperimentSample) -> ExperimentSample: experiment_sample["y"] = self(x).ravel().astype(np.float32) return experiment_sample - def run(self, experiment_sample: ExperimentSample, **kwargs) -> ExperimentSample: + def _run(self, experiment_sample: ExperimentSample, **kwargs) -> ExperimentSample: return self.execute(experiment_sample) def _retrieve_original_input(self, x: np.ndarray): diff --git a/src/f3dasm/_src/design/domain.py b/src/f3dasm/_src/design/domain.py index e4ea5e43..0fe9e747 100644 --- a/src/f3dasm/_src/design/domain.py +++ b/src/f3dasm/_src/design/domain.py @@ -8,10 +8,11 @@ from __future__ import annotations # Standard +import math import pickle from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, Iterator, List, Sequence, Type +from typing import Any, Dict, Iterable, Iterator, List, Sequence, Type # Third-party core import numpy as np @@ -252,7 +253,14 @@ def add_int(self, name: str, low: int, high: int, step: int = 1): >>> domain.add_int('param1', 0, 10, 2) >>> domain.space {'param1': DiscreteParameter(lower_bound=0, upper_bound=10, step=2)} + + Note + ---- + If the lower and upper bound are equal, then a constant parameter + will be added to the domain! """ + if low == high: + self.add_constant(name, low) self._add(name, DiscreteParameter(low, high, step)) def add_float(self, name: str, low: float, high: float, log: bool = False): @@ -275,8 +283,16 @@ def add_float(self, name: str, low: float, high: float, log: bool = False): >>> domain.add_float('param1', 0., 10., log=True) >>> domain.space {'param1': ContinuousParameter(lower_bound=0., upper_bound=10., log=True)} + + Note + ---- + If the lower and upper bound are equal, then a constant parameter + will be added to the domain! """ - self._add(name, ContinuousParameter(low, high, log)) + if math.isclose(low, high): + self.add_constant(name, low) + else: + self._add(name, ContinuousParameter(low, high, log)) def add_category(self, name: str, categories: Sequence[CategoricalType]): """Add a new categorical input parameter to the domain. @@ -573,6 +589,38 @@ def _filter(self, type: Type[Parameter]) -> Domain: if isinstance(parameter, type)} ) + def select(self, names: str | Iterable[str]) -> Domain: + """Select a subset of parameters from the domain. + + Parameters + ---------- + + names : str or Iterable[str] + The names of the parameters to select. + + Returns + ------- + Domain + A new domain with the selected parameters. + + Example + ------- + >>> domain = Domain() + >>> domain.space = { + ... 'param1': ContinuousParameter(lower_bound=0., upper_bound=1.), + ... 'param2': DiscreteParameter(lower_bound=0, upper_bound=8), + ... 'param3': CategoricalParameter(categories=['cat1', 'cat2']) + ... } + >>> domain.select(['param1', 'param3']) + Domain({'param1': ContinuousParameter(lower_bound=0, upper_bound=1), + 'param3': CategoricalParameter(categories=['cat1', 'cat2'])}) + """ + + if isinstance(names, str): + names = [names] + + return Domain(space={key: self.space[key] for key in names}) + # Miscellaneous # ============================================================================= diff --git a/src/f3dasm/_src/experimentdata/_data.py b/src/f3dasm/_src/experimentdata/_data.py index 8faa63e6..abc85881 100644 --- a/src/f3dasm/_src/experimentdata/_data.py +++ b/src/f3dasm/_src/experimentdata/_data.py @@ -291,6 +291,23 @@ def n_best_samples(self, nosamples: int, column_name: List[str] | str) -> pd.Dat """ return self.data.nsmallest(n=nosamples, columns=column_name) + def select_columns(self, columns: Iterable[str] | str) -> _Data: + """Filter the data on the selected columns. + + Parameters + ---------- + columns : Iterable[str] | str + The columns to select. + + Returns + ------- + _Data + The data only with the selected columns + """ + # This is necessary otherwise self.data[columns] will be a Series + if isinstance(columns, str): + columns = [columns] + return _Data(self.data[columns]) # Append and remove data # ============================================================================= diff --git a/src/f3dasm/_src/experimentdata/experimentdata.py b/src/f3dasm/_src/experimentdata/experimentdata.py index ade061cc..4e5a5f39 100644 --- a/src/f3dasm/_src/experimentdata/experimentdata.py +++ b/src/f3dasm/_src/experimentdata/experimentdata.py @@ -104,7 +104,7 @@ def __init__(self, domain: Optional[Domain] = None, input_data: Optional[DataTyp if self.input_data.is_empty(): self.input_data = _Data.from_domain(self.domain) - self.jobs = jobs_factory(jobs, self.input_data, job_value) + self.jobs = jobs_factory(jobs, self.input_data, self.output_data, job_value) # Check if the columns of input_data are in the domain if not self.input_data.has_columnnames(self.domain.names): @@ -119,10 +119,16 @@ def __len__(self): return len(self.input_data) def __iter__(self) -> Iterator[Tuple[Dict[str, Any]]]: - return self.input_data.__iter__() + self.current_index = 0 + return self - def __next__(self): - return self.input_data.__next__() + def __next__(self) -> ExperimentSample: + if self.current_index >= len(self): + raise StopIteration + else: + index = self.index[self.current_index] + self.current_index += 1 + return self.get_experiment_sample(index) def __add__(self, other: ExperimentData | ExperimentSample) -> ExperimentData: """The + operator combines two ExperimentData objects""" @@ -145,11 +151,6 @@ def __eq__(self, __o: ExperimentData) -> bool: self.jobs == __o.jobs, self.domain == __o.domain]) - def __getitem__(self, index: int | slice | Iterable[int]) -> _Data: - """The [] operator returns a single datapoint or a subset of datapoints""" - return ExperimentData(input_data=self.input_data[index], output_data=self.output_data[index], - jobs=self.jobs[index], domain=self.domain, filename=self.filename, path=self.path) - def _repr_html_(self) -> str: return self.input_data.combine_data_to_multiindex(self.output_data, self.jobs.to_dataframe())._repr_html_() @@ -171,6 +172,20 @@ def wrapper_func(self, *args, **kwargs) -> None: return wrapper_func + # Properties + # ============================================================================= + + @property + def index(self) -> pd.Index: + """Returns an iterable of the job number of the experiments + + Returns + ------- + pd.Index + The job number of all the experiments in pandas Index format + """ + return self.input_data.indices + # Alternative Constructors # ============================================================================= @@ -280,7 +295,7 @@ def _from_file_attempt(cls: Type[ExperimentData], filename: Path) -> ExperimentD except FileNotFoundError: raise FileNotFoundError(f"Cannot find the files from {filename}.") - # Export + # Selecting subsets # ============================================================================= def select(self, indices: int | slice | Iterable[int]) -> ExperimentData: @@ -296,15 +311,85 @@ def select(self, indices: int | slice | Iterable[int]) -> ExperimentData: ExperimentData The selected ExperimentData object with only the selected indices. """ - return self[indices] - def store(self, filename: str = None): + return ExperimentData(input_data=self.input_data[indices], output_data=self.output_data[indices], + jobs=self.jobs[indices], domain=self.domain, filename=self.filename, path=self.path) + + def get_input_data(self, parameter_names: Optional[str | Iterable[str]] = None) -> ExperimentData: + """Retrieve a subset of the input data from the ExperimentData object + + Parameters + ---------- + parameter_names : str | Iterable[str], optional + The name(s) of the input parameters that you want to retrieve, + if None all input parameters are retrieved, by default None + + Returns + ------- + ExperimentData + The selected ExperimentData object with only the selected input data. + + Notes + ----- + If parameter_names is None, all input data is retrieved. + The returned ExperimentData object has the domain of the original ExperimentData object, + but only with the selected input parameters. + """ + if parameter_names is None: + return ExperimentData(input_data=self.input_data, jobs=self.jobs, + domain=self.domain, filename=self.filename, path=self.path) + else: + return ExperimentData(input_data=self.input_data.select_columns(parameter_names), jobs=self.jobs, + domain=self.domain.select(parameter_names), filename=self.filename, path=self.path) + + def get_output_data(self, parameter_names: Optional[str | Iterable[str]] = None) -> ExperimentData: + """Retrieve a subset of the output data from the ExperimentData object + + Parameters + ---------- + parameter_names : str | Iterable[str], optional + The name(s) of the output parameters that you want to retrieve, + if None all output parameters are retrieved, by default None + + Returns + ------- + ExperimentData + The selected ExperimentData object with only the selected output data. + + Notes + ----- + If parameter_names is None, all output data is retrieved. + The returned ExperimentData object has no domain object and no input data! + """ + if parameter_names is None: + return ExperimentData(output_data=self.output_data, jobs=self.jobs, + filename=self.filename, path=self.path) + else: + return ExperimentData(output_data=self.output_data.select_columns(parameter_names), jobs=self.jobs, + filename=self.filename, path=self.path) + + # Export + # ============================================================================= + + def store(self, filename: Optional[str] = None): """Store the ExperimentData to disk, with checking for a lock Parameters ---------- filename : str, optional filename of the files to store, without suffix + + Notes + ----- + If no filename is given, the filename of the ExperimentData object is used. + + The ExperimentData object is stored at the location provided by the `.path` attribute + that is set upon creation of the object. + The ExperimentData object is stored in four files. The name is used as a prefix for the four files: + - the input data (_input.csv) + - the output data (_output.csv) + - the jobs (_jobs.pkl) + - the domain (_domain.pkl) """ if filename is None: filename = self.filename @@ -362,7 +447,7 @@ def get_n_best_output(self, n_samples: int) -> ExperimentData: New experimentData object with a selection of the n best samples. """ df = self.output_data.n_best_samples(n_samples, self.output_data.names) - return self[df.index] + return self.select(df.index) # Append or remove data # ============================================================================= @@ -483,10 +568,10 @@ def _reset_index(self) -> None: self.output_data.reset_index() self.jobs.reset_index() - # ExperimentSample +# ExperimentSample # ============================================================================= - def _get_experiment_sample(self, index: int) -> ExperimentSample: + def get_experiment_sample(self, index: int) -> ExperimentSample: """ Gets the experiment_sample at the given index. @@ -540,7 +625,7 @@ def _access_open_job_data(self) -> ExperimentSample: """ job_index = self.jobs.get_open_job() self.jobs.mark(job_index, status=Status.IN_PROGRESS) - experiment_sample = self._get_experiment_sample(job_index) + experiment_sample = self.get_experiment_sample(job_index) return experiment_sample @_access_file @@ -702,7 +787,7 @@ def _run_sequential(self, data_generator: DataGenerator, kwargs: dict): logger.debug( f"Running experiment_sample {experiment_sample._jobnumber} with kwargs {kwargs}") - _experiment_sample = data_generator.run(experiment_sample, **kwargs) # no *args! + _experiment_sample = data_generator._run(experiment_sample, **kwargs) # no *args! self._set_experiment_sample(_experiment_sample) except Exception as e: error_msg = f"Error in experiment_sample {experiment_sample._jobnumber}: {e}" @@ -737,7 +822,7 @@ def _run_multiprocessing(self, data_generator: DataGenerator, kwargs: dict): def f(options: Dict[str, Any]) -> Any: logger.debug(f"Running experiment_sample {options['experiment_sample'].job_number}") - return data_generator.run(**options) + return data_generator._run(**options) with mp.Pool() as pool: # maybe implement pool.starmap_async ? @@ -775,7 +860,7 @@ def _run_cluster(self, data_generator: DataGenerator, kwargs: dict): break try: - _experiment_sample = data_generator.run(experiment_sample, **kwargs) + _experiment_sample = data_generator._run(experiment_sample, **kwargs) self._write_experiment_sample(_experiment_sample) except Exception as e: error_msg = f"Error in experiment_sample {experiment_sample._jobnumber}: {e}" @@ -902,7 +987,7 @@ def _iterate_scipy(self, optimizer: Optimizer, data_generator: DataGenerator, # Repeat last iteration to fill up total iteration if len(self) < n_data_before_iterate + iterations: - last_design = self._get_experiment_sample(len(self)-1) + last_design = self.get_experiment_sample(len(self)-1) for repetition in range(iterations - (len(self) - n_data_before_iterate)): self._add_experiments(last_design) @@ -986,7 +1071,8 @@ def domain_factory(domain: Domain | None, input_data: _Data) -> Domain: raise TypeError(f"Domain must be of type Domain or None, not {type(domain)}") -def jobs_factory(jobs: Path | str | _JobQueue | None, input_data: _Data, job_value: Status) -> _JobQueue: +def jobs_factory(jobs: Path | str | _JobQueue | None, input_data: _Data, + output_data: _Data, job_value: Status) -> _JobQueue: """Creates a _JobQueue object from particular inpute Parameters @@ -994,7 +1080,9 @@ def jobs_factory(jobs: Path | str | _JobQueue | None, input_data: _Data, job_val jobs : Path | str | None input data for the jobs input_data : _Data - _Data object to extract indices from, if necessary + _Data object of input data to extract indices from, if necessary + output_data : _Data + _Data object of output data to extract indices from, if necessary job_value : Status initial value of all the jobs @@ -1009,4 +1097,7 @@ def jobs_factory(jobs: Path | str | _JobQueue | None, input_data: _Data, job_val if isinstance(jobs, (Path, str)): return _JobQueue.from_file(Path(jobs)) + if input_data.is_empty(): + return _JobQueue.from_data(output_data, value=job_value) + return _JobQueue.from_data(input_data, value=job_value) diff --git a/src/f3dasm/_src/experimentdata/experimentsample.py b/src/f3dasm/_src/experimentdata/experimentsample.py index 5d4f979b..0d8433b4 100644 --- a/src/f3dasm/_src/experimentdata/experimentsample.py +++ b/src/f3dasm/_src/experimentdata/experimentsample.py @@ -349,27 +349,26 @@ def to_dict(self) -> Dict[str, Any]: """ return {**self.input_data, **self.output_data_loaded, 'job_number': self.job_number} - def store(self, object: Any, name: str, to_disk: bool = False, + def store(self, name: str, object: Any, to_disk: bool = False, store_method: Optional[Type[_Store]] = None) -> None: """Store an object to disk. Parameters ---------- - object : Any - The object to store. name : str The name of the file to store the object in. + object : Any + The object to store. to_disk : bool, optional Whether to store the object to disk, by default False store_method : Store, optional The method to use to store the object, by default None - Raises - ------ - - TypeError - If the object type is not supported and no store_method is provided. + Notes + ----- + If to_disk is True and no store_method is provided, the default store method will be used. + The default store method is saving the object as a pickle file (.pkl). """ if to_disk: self._store_to_disk(object=object, name=name, store_method=store_method) diff --git a/src/f3dasm/_src/logger.py b/src/f3dasm/_src/logger.py index 92515338..ed30ee35 100644 --- a/src/f3dasm/_src/logger.py +++ b/src/f3dasm/_src/logger.py @@ -39,8 +39,8 @@ handler = logging.StreamHandler() handler.setFormatter(formatter) -# Set the level for the "f3dasm" logger -logger.setLevel(logging.INFO) +# Set the default level for the "f3dasm" logger +logger.setLevel(logging.WARNING) # Add the custom handler to the "f3dasm" logger logger.addHandler(handler) diff --git a/src/f3dasm/_src/optimization/adapters/scipy_implementations.py b/src/f3dasm/_src/optimization/adapters/scipy_implementations.py index 269fc7e8..fca44036 100644 --- a/src/f3dasm/_src/optimization/adapters/scipy_implementations.py +++ b/src/f3dasm/_src/optimization/adapters/scipy_implementations.py @@ -44,7 +44,7 @@ def run_algorithm(self, iterations: int, data_generator: DataGenerator): """ def fun(x): - sample: ExperimentSample = data_generator.run( + sample: ExperimentSample = data_generator._run( ExperimentSample.from_numpy(x)) _, y = sample.to_numpy() return float(y) diff --git a/tests/experimentdata/test_experimentdata.py b/tests/experimentdata/test_experimentdata.py index 19f53301..79422f34 100644 --- a/tests/experimentdata/test_experimentdata.py +++ b/tests/experimentdata/test_experimentdata.py @@ -10,7 +10,7 @@ import pytest import xarray as xr -from f3dasm import ExperimentData +from f3dasm import ExperimentData, ExperimentSample from f3dasm._src.experimentdata.experimentdata import DataTypes from f3dasm.design import (ContinuousParameter, Domain, Status, _Data, _JobQueue, make_nd_continuous_domain) @@ -53,16 +53,6 @@ def test_experiment_data_len_equals_output_data(experimentdata: ExperimentData): assert len(experimentdata) == len(experimentdata.output_data) -@pytest.mark.parametrize("slice_type", [3, [0, 1, 3], slice(0, 3)]) -def test_experiment_data_getitem_(slice_type: int | Iterable[int], experimentdata: ExperimentData): - input_data = experimentdata.input_data[slice_type] - output_data = experimentdata.output_data[slice_type] - jobs = experimentdata.jobs[slice_type] - constructed_experimentdata = ExperimentData( - input_data=input_data, output_data=output_data, jobs=jobs, domain=experimentdata.domain) - assert constructed_experimentdata == experimentdata[slice_type] - - @pytest.mark.parametrize("slice_type", [3, [0, 1, 3], slice(0, 3)]) def test_experiment_data_select(slice_type: int | Iterable[int], experimentdata: ExperimentData): input_data = experimentdata.input_data[slice_type] @@ -555,8 +545,8 @@ def mock_pd_read_pickle(*args, **kwargs): assert experiment_data.jobs == experimentdata_expected_no_output.jobs assert experiment_data == experimentdata_expected_no_output - - + + @pytest.mark.parametrize("input_data", [None]) @pytest.mark.parametrize("output_data", [None]) @pytest.mark.parametrize("domain", [make_nd_continuous_domain(bounds=np.array([[0., 1.], [0., 1.], [0., 1.]]), @@ -666,6 +656,44 @@ def test_evaluate_mode(mode: str, experimentdata_continuous: ExperimentData, tmp experimentdata_continuous.evaluate("ackley", mode=mode, kwargs={ "scale_bounds": np.array([[0., 1.], [0., 1.], [0., 1.]]), 'seed': SEED}) +def test_get_input_data(experimentdata_expected_no_output: ExperimentData): + input_data = experimentdata_expected_no_output.get_input_data() + df, _ = input_data.to_pandas() + pd.testing.assert_frame_equal(df, pd_input()) + assert experimentdata_expected_no_output.input_data == input_data.input_data + + +@pytest.mark.parametrize("selection", ["x0", ["x0"], ["x0", "x2"]]) +def test_get_input_data_selection(experimentdata_expected_no_output: ExperimentData, selection: Iterable[str] | str): + input_data = experimentdata_expected_no_output.get_input_data(selection) + df, _ = input_data.to_pandas() + if isinstance(selection, str): + selection = [selection] + selected_pd = pd_input()[selection] + pd.testing.assert_frame_equal(df, selected_pd) + +def test_get_output_data(experimentdata_expected: ExperimentData): + output_data = experimentdata_expected.get_output_data() + _, df = output_data.to_pandas() + pd.testing.assert_frame_equal(df, pd_output()) + assert experimentdata_expected.output_data == output_data.output_data + +@pytest.mark.parametrize("selection", ["y", ["y"]]) +def test_get_output_data_selection(experimentdata_expected: ExperimentData, selection: Iterable[str] | str): + output_data = experimentdata_expected.get_output_data(selection) + _, df = output_data.to_pandas() + if isinstance(selection, str): + selection = [selection] + selected_pd = pd_output()[selection] + pd.testing.assert_frame_equal(df, selected_pd) + +def test_iter_behaviour(experimentdata_continuous: ExperimentData): + for i in experimentdata_continuous: + assert isinstance(i, ExperimentSample) + + selected_experimentdata = experimentdata_continuous.select([0, 2, 4]) + for i in selected_experimentdata: + assert isinstance(i, ExperimentSample) if __name__ == "__main__": # pragma: no cover pytest.main()