Skip to content

Commit 24cf099

Browse files
committed
Optimize partial trace definition
We rarely need to both concatenate and truncate. And if we need, we can truncate the initial state directly
1 parent 65ecc35 commit 24cf099

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

pytensor/link/jax/dispatch/scan.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -207,18 +207,28 @@ def get_partial_traces(traces):
207207
):
208208
if init_state is not None:
209209
# MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
210-
trace = jnp.atleast_1d(trace)
211-
init_state = jnp.expand_dims(
212-
init_state, range(trace.ndim - init_state.ndim)
213-
)
214-
full_trace = jnp.concatenate([init_state, trace], axis=0)
215210
buffer_size = buffer.shape[0]
211+
if trace.shape[0] > buffer_size:
212+
# Trace is longer than buffer, keep just the last `buffer.shape[0]` entries
213+
partial_trace = trace[-buffer_size:]
214+
else:
215+
# Trace is shorter than buffer, this happens when we keep the initial_state
216+
if init_state.ndim < buffer.ndim:
217+
init_state = init_state[None]
218+
if (
219+
n_init_needed := buffer_size - trace.shape[0]
220+
) < init_state.shape[0]:
221+
# We may not need to keep all the initial states
222+
init_state = init_state[-n_init_needed:]
223+
partial_trace = jnp.concatenate([init_state, trace], axis=0)
216224
else:
217225
# NIT-SOT: Buffer is just the number of entries that should be returned
218-
full_trace = jnp.atleast_1d(trace)
219226
buffer_size = buffer
227+
partial_trace = (
228+
trace[-buffer_size:] if trace.shape[0] > buffer else trace
229+
)
220230

221-
partial_trace = full_trace[-buffer_size:]
231+
assert partial_trace.shape[0] == buffer_size
222232
partial_traces.append(partial_trace)
223233

224234
return partial_traces

0 commit comments

Comments
 (0)