@@ -35,9 +35,19 @@ class InflationCondition:
35
35
into the Rust API or C ABI.
36
36
"""
37
37
38
- def __init__ (self , compiled_artifact : CompilationArtifact ):
38
+ def __init__ (self , compiled_artifact : CompilationArtifact , validate_basis : bool = True ):
39
+ """Passes compiled model to lib_inflx_rs.
40
+
41
+ ### Args:
42
+ - `compiled_artifact` (`CompilationArtifact`): output of `Compiler` (see its docs)
43
+ - `validate_basis` (`bool`, optional): if `True`, lib_inflx_rs will check that the
44
+ field-space basis defined in the `CompilationArtifact` is orthonormal
45
+ at some number of random field-space points for with random parameter values. It will
46
+ throw an exception if this is not the case. You may disable this if inflatox picks
47
+ random points outside the domain of your model.
48
+ """
39
49
self .artifact = compiled_artifact
40
- self .dylib = open_inflx_dylib (compiled_artifact .shared_object_path )
50
+ self .dylib = open_inflx_dylib (compiled_artifact .shared_object_path , validate_basis )
41
51
42
52
def calc_V (self , x : np .ndarray , args : np .ndarray ) -> float :
43
53
"""calculates the scalar potential at field-space coordinates `x` with
@@ -84,9 +94,7 @@ def calc_V_array(
84
94
coordinates
85
95
"""
86
96
n_fields = self .artifact .n_fields
87
- start_stop = np .array (
88
- [[float (start ), float (stop )] for (start , stop ) in zip (start , stop )]
89
- )
97
+ start_stop = np .array ([[float (start ), float (stop )] for (start , stop ) in zip (start , stop )])
90
98
N = N if N is not None else (8000 for _ in range (n_fields ))
91
99
x = np .zeros (N )
92
100
self .dylib .potential_array (x , args , start_stop )
@@ -145,9 +153,7 @@ def calc_H_array(
145
153
[[float (x0_start ), float (x0_stop )], [float (x1_start ), float (x1_stop )]]
146
154
)
147
155
N = N if N is not None else (8000 for _ in range (n_fields ))
148
- return self .dylib .hesse_array (
149
- np .array (n_fields , dtype = np .int64 ), args , start_stop
150
- )
156
+ return self .dylib .hesse_array (np .array (n_fields , dtype = np .int64 ), args , start_stop )
151
157
152
158
153
159
class GeneralisedAL (InflationCondition ):
@@ -623,9 +629,7 @@ def consistency_rapidturn_ot(
623
629
threads = threads if threads is not None else 1
624
630
625
631
# evaluate and return
626
- consistency_rapidturn_only_on_trajectory (
627
- self .dylib , args , x , out , progress , threads
628
- )
632
+ consistency_rapidturn_only_on_trajectory (self .dylib , args , x , out , progress , threads )
629
633
return out
630
634
631
635
def epsilon_v_ot (
0 commit comments