@@ -207,18 +207,28 @@ def get_partial_traces(traces):
207207 ):
208208 if init_state is not None :
209209 # MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
210- trace = jnp .atleast_1d (trace )
211- init_state = jnp .expand_dims (
212- init_state , range (trace .ndim - init_state .ndim )
213- )
214- full_trace = jnp .concatenate ([init_state , trace ], axis = 0 )
215210 buffer_size = buffer .shape [0 ]
211+ if trace .shape [0 ] > buffer_size :
212+ # Trace is longer than buffer, keep just the last `buffer.shape[0]` entries
213+ partial_trace = trace [- buffer_size :]
214+ else :
215+ # Trace is shorter than buffer, this happens when we keep the initial_state
216+ if init_state .ndim < buffer .ndim :
217+ init_state = init_state [None ]
218+ if (
219+ n_init_needed := buffer_size - trace .shape [0 ]
220+ ) < init_state .shape [0 ]:
221+ # We may not need to keep all the initial states
222+ init_state = init_state [- n_init_needed :]
223+ partial_trace = jnp .concatenate ([init_state , trace ], axis = 0 )
216224 else :
217225 # NIT-SOT: Buffer is just the number of entries that should be returned
218- full_trace = jnp .atleast_1d (trace )
219226 buffer_size = buffer
227+ partial_trace = (
228+ trace [- buffer_size :] if trace .shape [0 ] > buffer else trace
229+ )
220230
221- partial_trace = full_trace [ - buffer_size :]
231+ assert partial_trace . shape [ 0 ] == buffer_size
222232 partial_traces .append (partial_trace )
223233
224234 return partial_traces
0 commit comments