Skip to content

Commit

Permalink
feat(comment): #61 reuse the class HarmonicOscillatorSystem to suppor…
Browse files Browse the repository at this point in the history
…t complex cases (#62)

* feat(comment): #61 #58 (comment)

* feat: #59 improve exception messages

* chore(typo): #61
  • Loading branch information
cmp0xff committed Jul 23, 2024
1 parent ac1a678 commit 1347914
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 26 deletions.
10 changes: 5 additions & 5 deletions hamilflow/models/discrete/d0/free_particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ class FreeParticleIC(BaseModel):

@model_validator(mode="after")
def check_dimensions_match(self) -> Self:
assert (
len(self.x0) == len(cast(Sequence, self.v0))
if isinstance(self.x0, Sequence)
else not isinstance(self.v0, Sequence)
)
if (x0_seq := isinstance(self.x0, Sequence)) != isinstance(self.v0, Sequence):
raise TypeError("x0 and v0 need both to be scalars or Sequences")
elif x0_seq and len(cast(Sequence, self.x0)) != len(cast(Sequence, self.v0)):
raise ValueError("Sequences x0 and v0 need to have the same length")

return self


Expand Down
64 changes: 46 additions & 18 deletions hamilflow/models/harmonic_oscillator.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,31 @@
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Literal
from typing import Literal, Mapping, Sequence

import numpy as np
import pandas as pd
from numpy.typing import ArrayLike
from pydantic import BaseModel, computed_field, field_validator
from pydantic import BaseModel, Field, computed_field, field_validator, model_validator

try:
from typing import Self
except ImportError:
from typing_extensions import Self


class HarmonicOscillatorSystem(BaseModel):
"""The params for the harmonic oscillator
:cvar omega: angular frequency of the harmonic oscillator
:cvar zeta: damping ratio
:cvar real: use real solution (only supported for the undamped case)
"""

omega: float
zeta: float = 0.0

real: bool = Field(default=True)

@computed_field # type: ignore[misc]
@cached_property
def period(self) -> float:
Expand Down Expand Up @@ -53,6 +61,13 @@ def check_zeta_non_negative(cls, v: float) -> float:

return v

@model_validator(mode="after")
def check_real_zeta(self) -> Self:
if not self.real and self.zeta != 0.0:
raise NotImplementedError("real = False only implemented for zeta = 0.0")

return self


class HarmonicOscillatorIC(BaseModel):
"""The initial condition for a harmonic oscillator
Expand All @@ -77,23 +92,23 @@ class HarmonicOscillatorBase(ABC):

def __init__(
self,
system: dict[str, float],
initial_condition: dict[str, float] | None = None,
system: Mapping[str, float | int | bool],
initial_condition: Mapping[str, float | int] | None = None,
) -> None:
initial_condition = initial_condition or {}
self.system = HarmonicOscillatorSystem.model_validate(system)
self.initial_condition = HarmonicOscillatorIC.model_validate(initial_condition)

@cached_property
def definition(self) -> dict[str, float]:
def definition(self) -> dict[str, dict[str, float | int | bool]]:
"""model params and initial conditions defined as a dictionary."""
return {
"system": self.system.model_dump(),
"initial_condition": self.initial_condition.model_dump(),
}

@abstractmethod
def _x(self, t: ArrayLike) -> ArrayLike:
def _x(self, t: float | int | Sequence[float | int]) -> ArrayLike:
r"""Solution to simple harmonic oscillators."""
...

Expand Down Expand Up @@ -129,13 +144,17 @@ class SimpleHarmonicOscillator(HarmonicOscillatorBase):
The mass behaves like a simple harmonic oscillator.
In general, the solution to a simple harmonic oscillator is
In general, the solution to a real simple harmonic oscillator is
$$
x(t) = A \cos(\omega t + \phi),
$$
where $\omega$ is the angular frequency, $\phi$ is the initial phase, and $A$ is the amplitude.
The complex solution is
$$
x(t) = A \exp(-\mathbb{i} (\omega t + \phi)).
$$
To use this generator,
Expand All @@ -153,23 +172,32 @@ class SimpleHarmonicOscillator(HarmonicOscillatorBase):

def __init__(
self,
system: dict[str, float],
initial_condition: dict[str, float] | None = None,
system: Mapping[str, float | int | bool],
initial_condition: Mapping[str, float | int] | None = None,
) -> None:
super().__init__(system, initial_condition)
if self.system.type != "simple":
raise ValueError(
f"System is not a Simple Harmonic Oscillator: {self.system}"
)

def _x(self, t: ArrayLike) -> ArrayLike:
def _f(self, phase: float | int | Sequence[float | int]) -> np.ndarray:
np_phase = np.array(phase, copy=False)
return np.cos(np_phase) if self.system.real else np.exp(-1j * np_phase)

def _x(self, t: float | int | Sequence[float | int]) -> np.ndarray:
r"""Solution to simple harmonic oscillators:
$$
x(t) = x_0 \cos(\omega t + \phi).
x(t) = x_0 \cos(\omega t + \phi)
$$
if real, or
$$
x(t) = x_0 \exp(-\mathbb{i} (\omega t + \phi))
$$
if not real.
"""
return self.initial_condition.x0 * np.cos(
return self.initial_condition.x0 * self._f(
self.system.omega * t + self.initial_condition.phi
)

Expand Down Expand Up @@ -225,8 +253,8 @@ class DampedHarmonicOscillator(HarmonicOscillatorBase):

def __init__(
self,
system: dict[str, float],
initial_condition: dict[str, float] | None = None,
system: Mapping[str, float | int],
initial_condition: Mapping[str, float | int] | None = None,
) -> None:
super().__init__(system, initial_condition)
if self.system.type == "simple":
Expand All @@ -235,7 +263,7 @@ def __init__(
f"This is a simple harmonic oscillator, use `SimpleHarmonicOscillator`."
)

def _x_under_damped(self, t: float | np.ndarray) -> float | np.ndarray:
def _x_under_damped(self, t: float | int | Sequence[float | int]) -> ArrayLike:
r"""Solution to under damped harmonic oscillators:
$$
Expand All @@ -260,7 +288,7 @@ def _x_under_damped(self, t: float | np.ndarray) -> float | np.ndarray:
* np.sin(omega_damp * t)
) * np.exp(-self.system.zeta * self.system.omega * t)

def _x_critical_damped(self, t: float | np.ndarray) -> float | np.ndarray:
def _x_critical_damped(self, t: float | int | Sequence[float | int]) -> ArrayLike:
r"""Solution to critical damped harmonic oscillators:
$$
Expand All @@ -278,7 +306,7 @@ def _x_critical_damped(self, t: float | np.ndarray) -> float | np.ndarray:
-self.system.zeta * self.system.omega * t
)

def _x_over_damped(self, t: float | np.ndarray) -> float | np.ndarray:
def _x_over_damped(self, t: float | int | Sequence[float | int]) -> ArrayLike:
r"""Solution to over harmonic oscillators:
$$
Expand All @@ -304,7 +332,7 @@ def _x_over_damped(self, t: float | np.ndarray) -> float | np.ndarray:
* np.sinh(gamma_damp * t)
) * np.exp(-self.system.zeta * self.system.omega * t)

def _x(self, t: float | np.ndarray) -> float | np.ndarray:
def _x(self, t: float | int | Sequence[float | int]) -> ArrayLike:
r"""Solution to damped harmonic oscillators."""
if self.system.type == "under_damped":
x = self._x_under_damped(t)
Expand Down
10 changes: 7 additions & 3 deletions tests/test_models/discrete/d0/test_free_particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@ def test_constructor(
) -> None:
assert FreeParticleIC(x0=x0, v0=v0)

@pytest.mark.parametrize(("x0", "v0"), [(1, (2,)), ((1,), (2, 3))])
def test_raise(self, x0: int | Sequence[int], v0: int | Sequence[int]) -> None:
with pytest.raises(ValidationError):
@pytest.mark.parametrize(
("x0", "v0", "expected"), [(1, (2,), TypeError), ((1,), (2, 3), ValueError)]
)
def test_raise(
self, x0: int | Sequence[int], v0: Sequence[int], expected: type[Exception]
) -> None:
with pytest.raises(expected):
FreeParticleIC(x0=x0, v0=v0)


Expand Down
13 changes: 13 additions & 0 deletions tests/test_models/test_harmonic_oscillator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pandas as pd
import pytest
from pydantic import ValidationError

from hamilflow.models.harmonic_oscillator import (
DampedHarmonicOscillator,
Expand Down Expand Up @@ -136,3 +137,15 @@ def test_criticaldamped_harmonic_oscillator(omega, zeta, expected):
df = ho(n_periods=1, n_samples_per_period=10)

pd.testing.assert_frame_equal(df, pd.DataFrame(expected))


class TestHarmonicOscillatorSystem:
@pytest.mark.parametrize("omega", [-1, 1])
def test_complex(self, omega: int) -> None:
HarmonicOscillatorSystem(omega=omega, real=False)

@pytest.mark.parametrize("omega", [-1, 1])
@pytest.mark.parametrize("zeta", [0.5, 1.0, 1.5])
def test_raise_complex(self, omega: int, zeta: float) -> None:
with pytest.raises(NotImplementedError):
HarmonicOscillatorSystem(omega=omega, zeta=zeta, real=False)

0 comments on commit 1347914

Please sign in to comment.