Skip to content

Commit b6c85f7

Browse files
committed
added validation of basis orthonormality by default
1 parent 4773008 commit b6c85f7

File tree

7 files changed

+268
-91
lines changed

7 files changed

+268
-91
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ libloading = "0.8"
3939
indicatif = { version = "0.17", features = ["rayon", "improved_unicode"] }
4040
console = { version = "0.15" }
4141
nalgebra = "0.33.1"
42+
rand = "0.9.0"
4243

4344
[profile.release]
4445
opt-level = 3

python/inflatox/consistency_conditions.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,19 @@ class InflationCondition:
3535
into the Rust API or C ABI.
3636
"""
3737

38-
def __init__(self, compiled_artifact: CompilationArtifact):
38+
def __init__(self, compiled_artifact: CompilationArtifact, validate_basis: bool = True):
39+
"""Passes compiled model to lib_inflx_rs.
40+
41+
### Args:
42+
- `compiled_artifact` (`CompilationArtifact`): output of `Compiler` (see its docs)
43+
- `validate_basis` (`bool`, optional): if `True`, lib_inflx_rs will check that the
44+
field-space basis defined in the `CompilationArtifact` is orthonormal
45+
at some number of random field-space points for with random parameter values. It will
46+
throw an exception if this is not the case. You may disable this if inflatox picks
47+
random points outside the domain of your model.
48+
"""
3949
self.artifact = compiled_artifact
40-
self.dylib = open_inflx_dylib(compiled_artifact.shared_object_path)
50+
self.dylib = open_inflx_dylib(compiled_artifact.shared_object_path, validate_basis)
4151

4252
def calc_V(self, x: np.ndarray, args: np.ndarray) -> float:
4353
"""calculates the scalar potential at field-space coordinates `x` with
@@ -84,9 +94,7 @@ def calc_V_array(
8494
coordinates
8595
"""
8696
n_fields = self.artifact.n_fields
87-
start_stop = np.array(
88-
[[float(start), float(stop)] for (start, stop) in zip(start, stop)]
89-
)
97+
start_stop = np.array([[float(start), float(stop)] for (start, stop) in zip(start, stop)])
9098
N = N if N is not None else (8000 for _ in range(n_fields))
9199
x = np.zeros(N)
92100
self.dylib.potential_array(x, args, start_stop)
@@ -145,9 +153,7 @@ def calc_H_array(
145153
[[float(x0_start), float(x0_stop)], [float(x1_start), float(x1_stop)]]
146154
)
147155
N = N if N is not None else (8000 for _ in range(n_fields))
148-
return self.dylib.hesse_array(
149-
np.array(n_fields, dtype=np.int64), args, start_stop
150-
)
156+
return self.dylib.hesse_array(np.array(n_fields, dtype=np.int64), args, start_stop)
151157

152158

153159
class GeneralisedAL(InflationCondition):
@@ -623,9 +629,7 @@ def consistency_rapidturn_ot(
623629
threads = threads if threads is not None else 1
624630

625631
# evaluate and return
626-
consistency_rapidturn_only_on_trajectory(
627-
self.dylib, args, x, out, progress, threads
628-
)
632+
consistency_rapidturn_only_on_trajectory(self.dylib, args, x, out, progress, threads)
629633
return out
630634

631635
def epsilon_v_ot(

0 commit comments

Comments
 (0)