Skip to content

Commit 89325d5

Browse files
Add sample batching to avoid memory issues.
1 parent e924067 commit 89325d5

File tree

1 file changed

+34
-4
lines changed

1 file changed

+34
-4
lines changed

harmonic/evidence.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def get_masks(self, chain_start_ixs: jnp.ndarray) -> jnp.ndarray:
212212

213213
return masks_arr
214214

215-
def add_chains(self, chains):
215+
def add_chains(self, chains, num_slices=None):
216216
"""Add new chains and calculate an estimate of the inverse evidence, its
217217
variance, and the variance of the variance.
218218
@@ -228,6 +228,10 @@ def add_chains(self, chains):
228228
chains (Chains): An instance of the chains class containing the chains to
229229
be used in the calculation.
230230
231+
num_slices (int): Number of slices into which the samples are divided row-wise
232+
when using flow models to avoid memory issues. If None, the samples are
233+
considered all-together. Defaults to None.
234+
231235
Raises:
232236
233237
ValueError: Raised if the input number of chains to not match the
@@ -247,10 +251,36 @@ def add_chains(self, chains):
247251
Y = chains.ln_posterior
248252
nchains = self.nchains
249253

254+
if not num_slices is None:
255+
if num_slices > X.shape[0]:
256+
raise ValueError(
257+
"Can't split chains into more blocks than there are samples."
258+
)
259+
250260
if self.batch_calculation:
251-
lnpred = self.model.predict(x=X)
252-
lnargs = lnpred - Y
253-
lnargs = lnargs.at[jnp.isinf(lnargs)].set(jnp.nan)
261+
if num_slices:
262+
# Number of rows in each slice
263+
slice_size = X.shape[0] // num_slices
264+
lnpred_list = []
265+
266+
# Calculate lnpred in row-wise slices
267+
for i in range(num_slices):
268+
start_row = i * slice_size
269+
end_row = (i + 1) * slice_size if i < num_slices - 1 else X.shape[0]
270+
X_slice = X[start_row:end_row]
271+
272+
# Predict for each row slice and append result
273+
lnpred_slice = self.model.predict(x=X_slice)
274+
lnpred_list.append(lnpred_slice)
275+
276+
# Concatenate all row slice predictions
277+
lnpred = jnp.concatenate(lnpred_list, axis=0)
278+
lnargs = lnpred - Y
279+
lnargs = lnargs.at[jnp.isinf(lnargs)].set(jnp.nan)
280+
else:
281+
lnpred = self.model.predict(x=X)
282+
lnargs = lnpred - Y
283+
lnargs = lnargs.at[jnp.isinf(lnargs)].set(jnp.nan)
254284

255285
else:
256286
lnpred = np.zeros_like(Y)

0 commit comments

Comments
 (0)