Skip to content

Commit 1347914

Browse files
authored
feat(comment): #61 reuse the class HarmonicOscillatorSystem to support complex cases (#62)
* feat(comment): #61 #58 (comment) * feat: #59 improve exception messages * chore(typo): #61
1 parent ac1a678 commit 1347914

File tree

4 files changed

+71
-26
lines changed

4 files changed

+71
-26
lines changed

hamilflow/models/discrete/d0/free_particle.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@ class FreeParticleIC(BaseModel):
2323

2424
@model_validator(mode="after")
2525
def check_dimensions_match(self) -> Self:
26-
assert (
27-
len(self.x0) == len(cast(Sequence, self.v0))
28-
if isinstance(self.x0, Sequence)
29-
else not isinstance(self.v0, Sequence)
30-
)
26+
if (x0_seq := isinstance(self.x0, Sequence)) != isinstance(self.v0, Sequence):
27+
raise TypeError("x0 and v0 need both to be scalars or Sequences")
28+
elif x0_seq and len(cast(Sequence, self.x0)) != len(cast(Sequence, self.v0)):
29+
raise ValueError("Sequences x0 and v0 need to have the same length")
30+
3131
return self
3232

3333

hamilflow/models/harmonic_oscillator.py

Lines changed: 46 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,31 @@
11
from abc import ABC, abstractmethod
22
from functools import cached_property
3-
from typing import Literal
3+
from typing import Literal, Mapping, Sequence
44

55
import numpy as np
66
import pandas as pd
77
from numpy.typing import ArrayLike
8-
from pydantic import BaseModel, computed_field, field_validator
8+
from pydantic import BaseModel, Field, computed_field, field_validator, model_validator
9+
10+
try:
11+
from typing import Self
12+
except ImportError:
13+
from typing_extensions import Self
914

1015

1116
class HarmonicOscillatorSystem(BaseModel):
1217
"""The params for the harmonic oscillator
1318
1419
:cvar omega: angular frequency of the harmonic oscillator
1520
:cvar zeta: damping ratio
21+
:cvar real: use real solution (only supported for the undamped case)
1622
"""
1723

1824
omega: float
1925
zeta: float = 0.0
2026

27+
real: bool = Field(default=True)
28+
2129
@computed_field # type: ignore[misc]
2230
@cached_property
2331
def period(self) -> float:
@@ -53,6 +61,13 @@ def check_zeta_non_negative(cls, v: float) -> float:
5361

5462
return v
5563

64+
@model_validator(mode="after")
65+
def check_real_zeta(self) -> Self:
66+
if not self.real and self.zeta != 0.0:
67+
raise NotImplementedError("real = False only implemented for zeta = 0.0")
68+
69+
return self
70+
5671

5772
class HarmonicOscillatorIC(BaseModel):
5873
"""The initial condition for a harmonic oscillator
@@ -77,23 +92,23 @@ class HarmonicOscillatorBase(ABC):
7792

7893
def __init__(
7994
self,
80-
system: dict[str, float],
81-
initial_condition: dict[str, float] | None = None,
95+
system: Mapping[str, float | int | bool],
96+
initial_condition: Mapping[str, float | int] | None = None,
8297
) -> None:
8398
initial_condition = initial_condition or {}
8499
self.system = HarmonicOscillatorSystem.model_validate(system)
85100
self.initial_condition = HarmonicOscillatorIC.model_validate(initial_condition)
86101

87102
@cached_property
88-
def definition(self) -> dict[str, float]:
103+
def definition(self) -> dict[str, dict[str, float | int | bool]]:
89104
"""model params and initial conditions defined as a dictionary."""
90105
return {
91106
"system": self.system.model_dump(),
92107
"initial_condition": self.initial_condition.model_dump(),
93108
}
94109

95110
@abstractmethod
96-
def _x(self, t: ArrayLike) -> ArrayLike:
111+
def _x(self, t: float | int | Sequence[float | int]) -> ArrayLike:
97112
r"""Solution to simple harmonic oscillators."""
98113
...
99114

@@ -129,13 +144,17 @@ class SimpleHarmonicOscillator(HarmonicOscillatorBase):
129144
130145
The mass behaves like a simple harmonic oscillator.
131146
132-
In general, the solution to a simple harmonic oscillator is
147+
In general, the solution to a real simple harmonic oscillator is
133148
134149
$$
135150
x(t) = A \cos(\omega t + \phi),
136151
$$
137152
138153
where $\omega$ is the angular frequency, $\phi$ is the initial phase, and $A$ is the amplitude.
154+
The complex solution is
155+
$$
156+
x(t) = A \exp(-\mathbb{i} (\omega t + \phi)).
157+
$$
139158
140159
141160
To use this generator,
@@ -153,23 +172,32 @@ class SimpleHarmonicOscillator(HarmonicOscillatorBase):
153172

154173
def __init__(
155174
self,
156-
system: dict[str, float],
157-
initial_condition: dict[str, float] | None = None,
175+
system: Mapping[str, float | int | bool],
176+
initial_condition: Mapping[str, float | int] | None = None,
158177
) -> None:
159178
super().__init__(system, initial_condition)
160179
if self.system.type != "simple":
161180
raise ValueError(
162181
f"System is not a Simple Harmonic Oscillator: {self.system}"
163182
)
164183

165-
def _x(self, t: ArrayLike) -> ArrayLike:
184+
def _f(self, phase: float | int | Sequence[float | int]) -> np.ndarray:
185+
np_phase = np.array(phase, copy=False)
186+
return np.cos(np_phase) if self.system.real else np.exp(-1j * np_phase)
187+
188+
def _x(self, t: float | int | Sequence[float | int]) -> np.ndarray:
166189
r"""Solution to simple harmonic oscillators:
167190
168191
$$
169-
x(t) = x_0 \cos(\omega t + \phi).
192+
x(t) = x_0 \cos(\omega t + \phi)
193+
$$
194+
if real, or
195+
$$
196+
x(t) = x_0 \exp(-\mathbb{i} (\omega t + \phi))
170197
$$
198+
if not real.
171199
"""
172-
return self.initial_condition.x0 * np.cos(
200+
return self.initial_condition.x0 * self._f(
173201
self.system.omega * t + self.initial_condition.phi
174202
)
175203

@@ -225,8 +253,8 @@ class DampedHarmonicOscillator(HarmonicOscillatorBase):
225253

226254
def __init__(
227255
self,
228-
system: dict[str, float],
229-
initial_condition: dict[str, float] | None = None,
256+
system: Mapping[str, float | int],
257+
initial_condition: Mapping[str, float | int] | None = None,
230258
) -> None:
231259
super().__init__(system, initial_condition)
232260
if self.system.type == "simple":
@@ -235,7 +263,7 @@ def __init__(
235263
f"This is a simple harmonic oscillator, use `SimpleHarmonicOscillator`."
236264
)
237265

238-
def _x_under_damped(self, t: float | np.ndarray) -> float | np.ndarray:
266+
def _x_under_damped(self, t: float | int | Sequence[float | int]) -> ArrayLike:
239267
r"""Solution to under damped harmonic oscillators:
240268
241269
$$
@@ -260,7 +288,7 @@ def _x_under_damped(self, t: float | np.ndarray) -> float | np.ndarray:
260288
* np.sin(omega_damp * t)
261289
) * np.exp(-self.system.zeta * self.system.omega * t)
262290

263-
def _x_critical_damped(self, t: float | np.ndarray) -> float | np.ndarray:
291+
def _x_critical_damped(self, t: float | int | Sequence[float | int]) -> ArrayLike:
264292
r"""Solution to critical damped harmonic oscillators:
265293
266294
$$
@@ -278,7 +306,7 @@ def _x_critical_damped(self, t: float | np.ndarray) -> float | np.ndarray:
278306
-self.system.zeta * self.system.omega * t
279307
)
280308

281-
def _x_over_damped(self, t: float | np.ndarray) -> float | np.ndarray:
309+
def _x_over_damped(self, t: float | int | Sequence[float | int]) -> ArrayLike:
282310
r"""Solution to over harmonic oscillators:
283311
284312
$$
@@ -304,7 +332,7 @@ def _x_over_damped(self, t: float | np.ndarray) -> float | np.ndarray:
304332
* np.sinh(gamma_damp * t)
305333
) * np.exp(-self.system.zeta * self.system.omega * t)
306334

307-
def _x(self, t: float | np.ndarray) -> float | np.ndarray:
335+
def _x(self, t: float | int | Sequence[float | int]) -> ArrayLike:
308336
r"""Solution to damped harmonic oscillators."""
309337
if self.system.type == "under_damped":
310338
x = self._x_under_damped(t)

tests/test_models/discrete/d0/test_free_particle.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,13 @@ def test_constructor(
1515
) -> None:
1616
assert FreeParticleIC(x0=x0, v0=v0)
1717

18-
@pytest.mark.parametrize(("x0", "v0"), [(1, (2,)), ((1,), (2, 3))])
19-
def test_raise(self, x0: int | Sequence[int], v0: int | Sequence[int]) -> None:
20-
with pytest.raises(ValidationError):
18+
@pytest.mark.parametrize(
19+
("x0", "v0", "expected"), [(1, (2,), TypeError), ((1,), (2, 3), ValueError)]
20+
)
21+
def test_raise(
22+
self, x0: int | Sequence[int], v0: Sequence[int], expected: type[Exception]
23+
) -> None:
24+
with pytest.raises(expected):
2125
FreeParticleIC(x0=x0, v0=v0)
2226

2327

tests/test_models/test_harmonic_oscillator.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pandas as pd
22
import pytest
3+
from pydantic import ValidationError
34

45
from hamilflow.models.harmonic_oscillator import (
56
DampedHarmonicOscillator,
@@ -136,3 +137,15 @@ def test_criticaldamped_harmonic_oscillator(omega, zeta, expected):
136137
df = ho(n_periods=1, n_samples_per_period=10)
137138

138139
pd.testing.assert_frame_equal(df, pd.DataFrame(expected))
140+
141+
142+
class TestHarmonicOscillatorSystem:
143+
@pytest.mark.parametrize("omega", [-1, 1])
144+
def test_complex(self, omega: int) -> None:
145+
HarmonicOscillatorSystem(omega=omega, real=False)
146+
147+
@pytest.mark.parametrize("omega", [-1, 1])
148+
@pytest.mark.parametrize("zeta", [0.5, 1.0, 1.5])
149+
def test_raise_complex(self, omega: int, zeta: float) -> None:
150+
with pytest.raises(NotImplementedError):
151+
HarmonicOscillatorSystem(omega=omega, zeta=zeta, real=False)

0 commit comments

Comments
 (0)