Skip to content

Commit

Permalink
Refactor domain initialization and data loading
Browse files Browse the repository at this point in the history
  • Loading branch information
mpvanderschelling committed Jun 25, 2024
1 parent 6d99381 commit 74fd315
Show file tree
Hide file tree
Showing 7 changed files with 257 additions and 118 deletions.
47 changes: 36 additions & 11 deletions src/f3dasm/_src/design/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from dataclasses import dataclass, field
from pathlib import Path
from typing import (Any, Dict, Iterable, Iterator, List, Literal, Optional,
Sequence, Type)
Protocol, Sequence, Type)

# Third-party core
import numpy as np
Expand All @@ -36,6 +36,13 @@
# =============================================================================


class _Data(Protocol):
def to_dataframe(self) -> pd.DataFrame:
...

# =============================================================================


@dataclass
class Domain:
"""Main class for defining the domain of the design of experiments.
Expand Down Expand Up @@ -238,6 +245,26 @@ def from_dataframe(cls, df_input: pd.DataFrame,

return cls(space=input_space, output_space=output_space)

@classmethod
def from_data(cls: Type[Domain],
input_data: _Data, output_data: _Data) -> Domain:
"""Initializes a Domain from input and output data.
Parameters
----------
input_data : _Data
Input data.
output_data : _Data
Output data.
Returns
-------
Domain
Domain object
"""
return cls.from_dataframe(
input_data.to_dataframe(), output_data.to_dataframe())

# Export
# =============================================================================

Expand Down Expand Up @@ -645,9 +672,7 @@ def make_nd_continuous_domain(bounds: np.ndarray | List[List[float]],
return Domain(space)


def _domain_factory(domain: Domain | DictConfig | None,
input_data: pd.DataFrame,
output_data: pd.DataFrame) -> Domain:
def _domain_factory(domain: Domain | DictConfig | str | Path) -> Domain:
if isinstance(domain, Domain):
return domain

Expand All @@ -657,14 +682,14 @@ def _domain_factory(domain: Domain | DictConfig | None,
elif isinstance(domain, DictConfig):
return Domain.from_yaml(domain)

elif (input_data.empty and output_data.empty and domain is None):
return Domain()
# elif (input_data.empty and output_data.empty and domain is None):
# return Domain()

elif domain is None:
return Domain.from_dataframe(
input_data, output_data)
# elif domain is None:
# return Domain.from_dataframe(
# input_data, output_data)

else:
raise TypeError(
f"Domain must be of type Domain, DictConfig "
f"or None, not {type(domain)}")
f"Domain must be of type Domain, DictConfig, str or Path, "
f"not {type(domain)}")
10 changes: 6 additions & 4 deletions src/f3dasm/_src/experimentdata/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ def from_file(cls, filename: Path | str) -> _Data:
return cls(df, columns=_Columns(_columns))

@classmethod
def from_numpy(cls: Type[_Data], array: np.ndarray) -> _Data:
def from_numpy(cls: Type[_Data],
array: np.ndarray, keys: Iterable[str]) -> _Data:
"""Loads the data from a numpy array.
Parameters
Expand Down Expand Up @@ -458,7 +459,8 @@ def _convert_dict_to_data(dictionary: Dict[str, Any]) -> _Data:
return _Data(data=df, columns=_Columns(_columns))


def _data_factory(data: DataTypes) -> _Data:
def _data_factory(data: DataTypes,
keys: Optional[Iterable[str]] = None) -> _Data:
if data is None:
return _Data()

Expand All @@ -469,10 +471,10 @@ def _data_factory(data: DataTypes) -> _Data:
return _Data.from_dataframe(data)

elif isinstance(data, (Path, str)):
return _Data.from_file(data)
return _Data.from_file(Path(data))

elif isinstance(data, np.ndarray):
return _Data.from_numpy(data)
return _Data.from_numpy(data, keys=keys)

else:
raise TypeError(
Expand Down
Loading

0 comments on commit 74fd315

Please sign in to comment.