diff --git a/hamilflow/models/harmonic_oscillator.py b/hamilflow/models/harmonic_oscillator/__init__.py similarity index 97% rename from hamilflow/models/harmonic_oscillator.py rename to hamilflow/models/harmonic_oscillator/__init__.py index 072d869..941c70e 100644 --- a/hamilflow/models/harmonic_oscillator.py +++ b/hamilflow/models/harmonic_oscillator/__init__.py @@ -7,6 +7,11 @@ from numpy.typing import ArrayLike from pydantic import BaseModel, computed_field, field_validator +from hamilflow.models.harmonic_oscillator.initial_conditions import ( + HarmonicOscillatorIC, + parse_ic_for_sho, +) + class HarmonicOscillatorSystem(BaseModel): """The params for the harmonic oscillator @@ -54,19 +59,6 @@ def check_zeta_non_negative(cls, v: float) -> float: return v -class HarmonicOscillatorIC(BaseModel): - """The initial condition for a harmonic oscillator - - :cvar x0: the initial displacement - :cvar v0: the initial velocity - :cvar phi: initial phase - """ - - x0: float = 1.0 - v0: float = 0.0 - phi: float = 0.0 - - class HarmonicOscillatorBase(ABC): r"""Base class to generate time series data for a [harmonic oscillator](https://en.wikipedia.org/wiki/Harmonic_oscillator). @@ -155,6 +147,7 @@ def __init__( system: Dict[str, float], initial_condition: Optional[Dict[str, float]] = {}, ): + initial_condition = parse_ic_for_sho(system["omega"], **initial_condition) super().__init__(system, initial_condition) if self.system.type != "simple": raise ValueError( diff --git a/hamilflow/models/harmonic_oscillator/initial_conditions.py b/hamilflow/models/harmonic_oscillator/initial_conditions.py new file mode 100644 index 0000000..3407973 --- /dev/null +++ b/hamilflow/models/harmonic_oscillator/initial_conditions.py @@ -0,0 +1,39 @@ +import math +from typing import Any, Dict, cast + +from pydantic import BaseModel, Field + + +class HarmonicOscillatorIC(BaseModel): + """The initial condition for a harmonic oscillator + + :cvar x0: initial displacement + :cvar v0: initial velocity + :cvar phi: initial phase + """ + + x0: float = Field(default=1.0) + v0: float = Field(default=0.0) + phi: float = Field(default=0.0) + + +def parse_ic_for_sho(omega: float, **kwargs: Any) -> Dict[str, float]: + "Support alternative initial conditions" + match keys := {*kwargs.keys()}: + case set() if keys <= {"x0", "v0", "phi"}: + ret = {str(k): float(v) for k, v in kwargs.items()} + case set() if keys == {"x0", "t0"}: + ret = dict(x0=float(kwargs["x0"]), v0=0.0, phi=-float(omega * kwargs["t0"])) + case set() if keys == {"E", "t0"}: + ene = cast(float, kwargs["E"]) + ret = dict( + x0=math.sqrt(2 * ene) / omega, v0=0.0, phi=-float(omega * kwargs["t0"]) + ) + case _: + raise ValueError( + f"Unsupported variable names as an initial condition: {keys}" + ) + if phi := ret.get("phi"): + ret["phi"] = phi % (2 * math.pi) + + return ret diff --git a/poetry.lock b/poetry.lock index eba39c3..9933071 100644 --- a/poetry.lock +++ b/poetry.lock @@ -201,6 +201,17 @@ files = [ [package.dependencies] pycparser = "*" +[[package]] +name = "cfgv" +version = "3.4.0" +description = "Validate configuration and produce human readable error messages." +optional = false +python-versions = ">=3.8" +files = [ + {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"}, + {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"}, +] + [[package]] name = "charset-normalizer" version = "3.3.2" @@ -575,6 +586,17 @@ files = [ {file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"}, ] +[[package]] +name = "distlib" +version = "0.3.8" +description = "Distribution utilities" +optional = false +python-versions = "*" +files = [ + {file = "distlib-0.3.8-py2.py3-none-any.whl", hash = "sha256:034db59a0b96f8ca18035f36290806a9a6e6bd9d1ff91e45a7f172eb17e51784"}, + {file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"}, +] + [[package]] name = "exceptiongroup" version = "1.2.0" @@ -617,6 +639,22 @@ files = [ [package.extras] devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benchmark", "pytest-cache", "validictory"] +[[package]] +name = "filelock" +version = "3.15.1" +description = "A platform independent file lock." +optional = false +python-versions = ">=3.8" +files = [ + {file = "filelock-3.15.1-py3-none-any.whl", hash = "sha256:71b3102950e91dfc1bb4209b64be4dc8854f40e5f534428d8684f953ac847fac"}, + {file = "filelock-3.15.1.tar.gz", hash = "sha256:58a2549afdf9e02e10720eaa4d4470f56386d7a6f72edd7d0596337af8ed7ad8"}, +] + +[package.extras] +docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] +typing = ["typing-extensions (>=4.8)"] + [[package]] name = "fonttools" version = "4.50.0" @@ -713,6 +751,20 @@ files = [ [package.dependencies] colorama = ">=0.4" +[[package]] +name = "identify" +version = "2.5.36" +description = "File identification library for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "identify-2.5.36-py2.py3-none-any.whl", hash = "sha256:37d93f380f4de590500d9dba7db359d0d3da95ffe7f9de1753faa159e71e7dfa"}, + {file = "identify-2.5.36.tar.gz", hash = "sha256:e5e00f54165f9047fbebeb4a560f9acfb8af4c88232be60a488e9b68d122745d"}, +] + +[package.extras] +license = ["ukkonen"] + [[package]] name = "idna" version = "3.6" @@ -1594,6 +1646,17 @@ files = [ {file = "nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe"}, ] +[[package]] +name = "nodeenv" +version = "1.9.1" +description = "Node.js virtual environment builder" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +files = [ + {file = "nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9"}, + {file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"}, +] + [[package]] name = "numpy" version = "1.26.4" @@ -1912,6 +1975,24 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "pre-commit" +version = "3.7.1" +description = "A framework for managing and maintaining multi-language pre-commit hooks." +optional = false +python-versions = ">=3.9" +files = [ + {file = "pre_commit-3.7.1-py2.py3-none-any.whl", hash = "sha256:fae36fd1d7ad7d6a5a1c0b0d5adb2ed1a3bda5a21bf6c3e5372073d7a11cd4c5"}, + {file = "pre_commit-3.7.1.tar.gz", hash = "sha256:8ca3ad567bc78a4972a3f1a477e94a79d4597e8140a6e0b651c5e33899c3654a"}, +] + +[package.dependencies] +cfgv = ">=2.0.0" +identify = ">=1.0.0" +nodeenv = ">=0.11.1" +pyyaml = ">=5.1" +virtualenv = ">=20.10.0" + [[package]] name = "prompt-toolkit" version = "3.0.36" @@ -2910,6 +2991,26 @@ h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "virtualenv" +version = "20.26.2" +description = "Virtual Python Environment builder" +optional = false +python-versions = ">=3.7" +files = [ + {file = "virtualenv-20.26.2-py3-none-any.whl", hash = "sha256:a624db5e94f01ad993d476b9ee5346fdf7b9de43ccaee0e0197012dc838a0e9b"}, + {file = "virtualenv-20.26.2.tar.gz", hash = "sha256:82bf0f4eebbb78d36ddaee0283d43fe5736b53880b8a8cdcd37390a07ac3741c"}, +] + +[package.dependencies] +distlib = ">=0.3.7,<1" +filelock = ">=3.12.2,<4" +platformdirs = ">=3.9.1,<5" + +[package.extras] +docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] +test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] + [[package]] name = "watchdog" version = "4.0.0" @@ -3005,4 +3106,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "bdfc8bd2c281e1e3e6de09bf53f4242ceae30e56228796e7a51100f3f66b1ce8" +content-hash = "448b2995036f304d1f84d62991ea713a7b23be4903cf21369a763d3ae48eb907" diff --git a/pyproject.toml b/pyproject.toml index 21d4985..a9c300f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ plotly = "^5.19.0" [tool.poetry.group.dev.dependencies] commitizen = "*" +pre-commit = "*" [tool.commitizen] diff --git a/tests/test_models/test_harmonic_oscillator.py b/tests/test_models/test_harmonic_oscillators/test_harmonic_oscillator.py similarity index 100% rename from tests/test_models/test_harmonic_oscillator.py rename to tests/test_models/test_harmonic_oscillators/test_harmonic_oscillator.py diff --git a/tests/test_models/test_harmonic_oscillators/test_initial_conditions.py b/tests/test_models/test_harmonic_oscillators/test_initial_conditions.py new file mode 100644 index 0000000..e62b078 --- /dev/null +++ b/tests/test_models/test_harmonic_oscillators/test_initial_conditions.py @@ -0,0 +1,35 @@ +import math + +import pytest + +from hamilflow.models.harmonic_oscillator.initial_conditions import parse_ic_for_sho + + +@pytest.fixture() +def omega() -> float: + return 2 * math.pi + + +class TestParseICForSHO: + @pytest.mark.parametrize( + ("input", "expected"), + [ + ( + dict(x0=1.0, v0=1.0, phi=7.0), + dict(x0=1.0, v0=1.0, phi=7 % (2 * math.pi)), + ), + (dict(x0=1.0, t0=1.0), dict(x0=1.0, v0=0.0, phi=0.0)), + ( + dict(E=1.0, t0=1.0), + dict(x0=math.sqrt(2.0) / (2 * math.pi), v0=0.0, phi=0.0), + ), + ], + ) + def test_output( + self, omega: float, input: dict[str, float], expected: dict[str, float] + ) -> None: + assert parse_ic_for_sho(omega, **input) == expected + + def test_raise(self, omega: float) -> None: + with pytest.raises(ValueError): + parse_ic_for_sho(omega, **dict(x0=1.0, E=2.0))