@@ -212,7 +212,7 @@ def get_masks(self, chain_start_ixs: jnp.ndarray) -> jnp.ndarray:
212
212
213
213
return masks_arr
214
214
215
- def add_chains (self , chains ):
215
+ def add_chains (self , chains , num_slices = None ):
216
216
"""Add new chains and calculate an estimate of the inverse evidence, its
217
217
variance, and the variance of the variance.
218
218
@@ -228,6 +228,10 @@ def add_chains(self, chains):
228
228
chains (Chains): An instance of the chains class containing the chains to
229
229
be used in the calculation.
230
230
231
+ num_slices (int): Number of slices into which the samples are divided row-wise
232
+ when using flow models to avoid memory issues. If None, the samples are
233
+ considered all-together. Defaults to None.
234
+
231
235
Raises:
232
236
233
237
ValueError: Raised if the input number of chains to not match the
@@ -247,10 +251,36 @@ def add_chains(self, chains):
247
251
Y = chains .ln_posterior
248
252
nchains = self .nchains
249
253
254
+ if not num_slices is None :
255
+ if num_slices > X .shape [0 ]:
256
+ raise ValueError (
257
+ "Can't split chains into more blocks than there are samples."
258
+ )
259
+
250
260
if self .batch_calculation :
251
- lnpred = self .model .predict (x = X )
252
- lnargs = lnpred - Y
253
- lnargs = lnargs .at [jnp .isinf (lnargs )].set (jnp .nan )
261
+ if num_slices :
262
+ # Number of rows in each slice
263
+ slice_size = X .shape [0 ] // num_slices
264
+ lnpred_list = []
265
+
266
+ # Calculate lnpred in row-wise slices
267
+ for i in range (num_slices ):
268
+ start_row = i * slice_size
269
+ end_row = (i + 1 ) * slice_size if i < num_slices - 1 else X .shape [0 ]
270
+ X_slice = X [start_row :end_row ]
271
+
272
+ # Predict for each row slice and append result
273
+ lnpred_slice = self .model .predict (x = X_slice )
274
+ lnpred_list .append (lnpred_slice )
275
+
276
+ # Concatenate all row slice predictions
277
+ lnpred = jnp .concatenate (lnpred_list , axis = 0 )
278
+ lnargs = lnpred - Y
279
+ lnargs = lnargs .at [jnp .isinf (lnargs )].set (jnp .nan )
280
+ else :
281
+ lnpred = self .model .predict (x = X )
282
+ lnargs = lnpred - Y
283
+ lnargs = lnargs .at [jnp .isinf (lnargs )].set (jnp .nan )
254
284
255
285
else :
256
286
lnpred = np .zeros_like (Y )
0 commit comments