diff --git a/hamilflow/models/discrete/d0/free_particle.py b/hamilflow/models/discrete/d0/free_particle.py index 8512bd0..8435358 100644 --- a/hamilflow/models/discrete/d0/free_particle.py +++ b/hamilflow/models/discrete/d0/free_particle.py @@ -23,11 +23,11 @@ class FreeParticleIC(BaseModel): @model_validator(mode="after") def check_dimensions_match(self) -> Self: - assert ( - len(self.x0) == len(cast(Sequence, self.v0)) - if isinstance(self.x0, Sequence) - else not isinstance(self.v0, Sequence) - ) + if (x0_seq := isinstance(self.x0, Sequence)) != isinstance(self.v0, Sequence): + raise TypeError("x0 and v0 need both to be scalars or Sequences") + elif x0_seq and len(cast(Sequence, self.x0)) != len(cast(Sequence, self.v0)): + raise ValueError("Sequences x0 and v0 need to have the same length") + return self diff --git a/hamilflow/models/harmonic_oscillator.py b/hamilflow/models/harmonic_oscillator.py index 5b56293..955b8a7 100644 --- a/hamilflow/models/harmonic_oscillator.py +++ b/hamilflow/models/harmonic_oscillator.py @@ -1,11 +1,16 @@ from abc import ABC, abstractmethod from functools import cached_property -from typing import Literal +from typing import Literal, Mapping, Sequence import numpy as np import pandas as pd from numpy.typing import ArrayLike -from pydantic import BaseModel, computed_field, field_validator +from pydantic import BaseModel, Field, computed_field, field_validator, model_validator + +try: + from typing import Self +except ImportError: + from typing_extensions import Self class HarmonicOscillatorSystem(BaseModel): @@ -13,11 +18,14 @@ class HarmonicOscillatorSystem(BaseModel): :cvar omega: angular frequency of the harmonic oscillator :cvar zeta: damping ratio + :cvar real: use real solution (only supported for the undamped case) """ omega: float zeta: float = 0.0 + real: bool = Field(default=True) + @computed_field # type: ignore[misc] @cached_property def period(self) -> float: @@ -53,6 +61,13 @@ def check_zeta_non_negative(cls, v: float) -> float: return v + @model_validator(mode="after") + def check_real_zeta(self) -> Self: + if not self.real and self.zeta != 0.0: + raise NotImplementedError("real = False only implemented for zeta = 0.0") + + return self + class HarmonicOscillatorIC(BaseModel): """The initial condition for a harmonic oscillator @@ -77,15 +92,15 @@ class HarmonicOscillatorBase(ABC): def __init__( self, - system: dict[str, float], - initial_condition: dict[str, float] | None = None, + system: Mapping[str, float | int | bool], + initial_condition: Mapping[str, float | int] | None = None, ) -> None: initial_condition = initial_condition or {} self.system = HarmonicOscillatorSystem.model_validate(system) self.initial_condition = HarmonicOscillatorIC.model_validate(initial_condition) @cached_property - def definition(self) -> dict[str, float]: + def definition(self) -> dict[str, dict[str, float | int | bool]]: """model params and initial conditions defined as a dictionary.""" return { "system": self.system.model_dump(), @@ -93,7 +108,7 @@ def definition(self) -> dict[str, float]: } @abstractmethod - def _x(self, t: ArrayLike) -> ArrayLike: + def _x(self, t: float | int | Sequence[float | int]) -> ArrayLike: r"""Solution to simple harmonic oscillators.""" ... @@ -129,13 +144,17 @@ class SimpleHarmonicOscillator(HarmonicOscillatorBase): The mass behaves like a simple harmonic oscillator. - In general, the solution to a simple harmonic oscillator is + In general, the solution to a real simple harmonic oscillator is $$ x(t) = A \cos(\omega t + \phi), $$ where $\omega$ is the angular frequency, $\phi$ is the initial phase, and $A$ is the amplitude. + The complex solution is + $$ + x(t) = A \exp(-\mathbb{i} (\omega t + \phi)). + $$ To use this generator, @@ -153,8 +172,8 @@ class SimpleHarmonicOscillator(HarmonicOscillatorBase): def __init__( self, - system: dict[str, float], - initial_condition: dict[str, float] | None = None, + system: Mapping[str, float | int | bool], + initial_condition: Mapping[str, float | int] | None = None, ) -> None: super().__init__(system, initial_condition) if self.system.type != "simple": @@ -162,14 +181,23 @@ def __init__( f"System is not a Simple Harmonic Oscillator: {self.system}" ) - def _x(self, t: ArrayLike) -> ArrayLike: + def _f(self, phase: float | int | Sequence[float | int]) -> np.ndarray: + np_phase = np.array(phase, copy=False) + return np.cos(np_phase) if self.system.real else np.exp(-1j * np_phase) + + def _x(self, t: float | int | Sequence[float | int]) -> np.ndarray: r"""Solution to simple harmonic oscillators: $$ - x(t) = x_0 \cos(\omega t + \phi). + x(t) = x_0 \cos(\omega t + \phi) + $$ + if real, or + $$ + x(t) = x_0 \exp(-\mathbb{i} (\omega t + \phi)) $$ + if not real. """ - return self.initial_condition.x0 * np.cos( + return self.initial_condition.x0 * self._f( self.system.omega * t + self.initial_condition.phi ) @@ -225,8 +253,8 @@ class DampedHarmonicOscillator(HarmonicOscillatorBase): def __init__( self, - system: dict[str, float], - initial_condition: dict[str, float] | None = None, + system: Mapping[str, float | int], + initial_condition: Mapping[str, float | int] | None = None, ) -> None: super().__init__(system, initial_condition) if self.system.type == "simple": @@ -235,7 +263,7 @@ def __init__( f"This is a simple harmonic oscillator, use `SimpleHarmonicOscillator`." ) - def _x_under_damped(self, t: float | np.ndarray) -> float | np.ndarray: + def _x_under_damped(self, t: float | int | Sequence[float | int]) -> ArrayLike: r"""Solution to under damped harmonic oscillators: $$ @@ -260,7 +288,7 @@ def _x_under_damped(self, t: float | np.ndarray) -> float | np.ndarray: * np.sin(omega_damp * t) ) * np.exp(-self.system.zeta * self.system.omega * t) - def _x_critical_damped(self, t: float | np.ndarray) -> float | np.ndarray: + def _x_critical_damped(self, t: float | int | Sequence[float | int]) -> ArrayLike: r"""Solution to critical damped harmonic oscillators: $$ @@ -278,7 +306,7 @@ def _x_critical_damped(self, t: float | np.ndarray) -> float | np.ndarray: -self.system.zeta * self.system.omega * t ) - def _x_over_damped(self, t: float | np.ndarray) -> float | np.ndarray: + def _x_over_damped(self, t: float | int | Sequence[float | int]) -> ArrayLike: r"""Solution to over harmonic oscillators: $$ @@ -304,7 +332,7 @@ def _x_over_damped(self, t: float | np.ndarray) -> float | np.ndarray: * np.sinh(gamma_damp * t) ) * np.exp(-self.system.zeta * self.system.omega * t) - def _x(self, t: float | np.ndarray) -> float | np.ndarray: + def _x(self, t: float | int | Sequence[float | int]) -> ArrayLike: r"""Solution to damped harmonic oscillators.""" if self.system.type == "under_damped": x = self._x_under_damped(t) diff --git a/tests/test_models/discrete/d0/test_free_particle.py b/tests/test_models/discrete/d0/test_free_particle.py index b2a6731..574bee0 100644 --- a/tests/test_models/discrete/d0/test_free_particle.py +++ b/tests/test_models/discrete/d0/test_free_particle.py @@ -15,9 +15,13 @@ def test_constructor( ) -> None: assert FreeParticleIC(x0=x0, v0=v0) - @pytest.mark.parametrize(("x0", "v0"), [(1, (2,)), ((1,), (2, 3))]) - def test_raise(self, x0: int | Sequence[int], v0: int | Sequence[int]) -> None: - with pytest.raises(ValidationError): + @pytest.mark.parametrize( + ("x0", "v0", "expected"), [(1, (2,), TypeError), ((1,), (2, 3), ValueError)] + ) + def test_raise( + self, x0: int | Sequence[int], v0: Sequence[int], expected: type[Exception] + ) -> None: + with pytest.raises(expected): FreeParticleIC(x0=x0, v0=v0) diff --git a/tests/test_models/test_harmonic_oscillator.py b/tests/test_models/test_harmonic_oscillator.py index 24dd291..94bfb0a 100644 --- a/tests/test_models/test_harmonic_oscillator.py +++ b/tests/test_models/test_harmonic_oscillator.py @@ -1,5 +1,6 @@ import pandas as pd import pytest +from pydantic import ValidationError from hamilflow.models.harmonic_oscillator import ( DampedHarmonicOscillator, @@ -136,3 +137,15 @@ def test_criticaldamped_harmonic_oscillator(omega, zeta, expected): df = ho(n_periods=1, n_samples_per_period=10) pd.testing.assert_frame_equal(df, pd.DataFrame(expected)) + + +class TestHarmonicOscillatorSystem: + @pytest.mark.parametrize("omega", [-1, 1]) + def test_complex(self, omega: int) -> None: + HarmonicOscillatorSystem(omega=omega, real=False) + + @pytest.mark.parametrize("omega", [-1, 1]) + @pytest.mark.parametrize("zeta", [0.5, 1.0, 1.5]) + def test_raise_complex(self, omega: int, zeta: float) -> None: + with pytest.raises(NotImplementedError): + HarmonicOscillatorSystem(omega=omega, zeta=zeta, real=False)