Description
Is your feature request related to a problem? Please describe.
The forward-backward algorithm can be parallelized with an associative scan, since it is essentially a sequence of matrix multiplications. However, this only holds within a session: multiple sessions require "message resetting" at each session boundary, which breaks this simple structure. We should decide whether it is worth spending time parallelizing the computation with this in mind.
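A minimal JAX sketch of the matmul claim for the forward pass of a single session (the backward pass is analogous). The function names, the argument layout (`init_prob` of shape `(K,)`, `transition` of shape `(K, K)`, `emission_lik` of shape `(T, K)`), and the absence of normalization are assumptions for illustration. A practical version would rescale the messages or work in log-space to avoid underflow on long sessions.

```python
import jax
import jax.numpy as jnp


def forward_sequential(init_prob, transition, emission_lik):
    """Unnormalized forward messages via a sequential lax.scan.

    alpha_0 = init_prob * b_0 and alpha_t = (alpha_{t-1} @ transition) * b_t,
    where b_t = emission_lik[t] is the likelihood of observation t in each state.
    """
    alpha0 = init_prob * emission_lik[0]

    def step(alpha_prev, b_t):
        alpha = (alpha_prev @ transition) * b_t
        return alpha, alpha

    _, alphas = jax.lax.scan(step, alpha0, emission_lik[1:])
    return jnp.concatenate([alpha0[None], alphas], axis=0)


def forward_associative(init_prob, transition, emission_lik):
    """Same messages, written as prefix products of matrices.

    With M_t = transition @ diag(b_t), alpha_t = alpha_0 @ M_1 @ ... @ M_t.
    Matrix multiplication is associative, so the prefix products can be
    computed with jax.lax.associative_scan in O(log T) parallel depth.
    """
    alpha0 = init_prob * emission_lik[0]
    # transition * b_t[None, :] equals transition @ diag(b_t); shape (T - 1, K, K)
    mats = transition[None, :, :] * emission_lik[1:, None, :]
    # Prefix products P_t = M_1 @ ... @ M_t (earlier factor on the left).
    prefix = jax.lax.associative_scan(jnp.matmul, mats)
    alphas = jnp.einsum("k,tkj->tj", alpha0, prefix)
    return jnp.concatenate([alpha0[None], alphas], axis=0)
```

Both functions return the same `(T, K)` messages up to floating-point error. The trade-off is that each associative combine is a `K x K` matrix product, so the parallel version does more total work (`O(T K^3)` versus `O(T K^2)`) in exchange for logarithmic depth.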
Describe the solution you'd like
Possible solution: implement multiple versions of the algorithm:
1 - an associative_scan version of the forward and backward passes;
2 - the current version, which uses a regular scan plus resetting.
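Version 2 can be sketched as a sequential `lax.scan` over the concatenated sessions, where a boolean flag marks the first time bin of each session and resets the message there. The function name, the `is_new_session` flag, and the unnormalized messages are hypothetical choices for illustration, not the project's current implementation.

```python
import jax
import jax.numpy as jnp


def forward_with_reset(init_prob, transition, emission_lik, is_new_session):
    """Unnormalized forward messages over concatenated sessions.

    is_new_session[t] is True at the first time bin of each session; there the
    message is reset to init_prob * b_t instead of being propagated from t - 1.
    """

    def step(alpha_prev, inputs):
        b_t, reset = inputs
        propagated = (alpha_prev @ transition) * b_t
        restarted = init_prob * b_t
        alpha = jnp.where(reset, restarted, propagated)
        return alpha, alpha

    # The initial carry is ignored as long as is_new_session[0] is True.
    alpha_dummy = jnp.zeros_like(init_prob)
    _, alphas = jax.lax.scan(step, alpha_dummy, (emission_lik, is_new_session))
    return alphas
```

The `jnp.where` on the reset flag is what makes each step depend on session boundaries, which is why this version does not map directly onto a plain associative scan over matrix products.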
At fit time, check which case applies:
a. A single session -> call the associative_scan algorithm.
b. Multiple sessions of the same duration -> vmap the associative_scan over sessions and call that algorithm.
c. Multiple sessions with different durations -> call the current scan-based algorithm.
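The dispatch rule above could look roughly like this. `forward_dispatch`, its helpers, and the list-of-arrays input format are hypothetical. The helpers are compact restatements of an unnormalized associative-scan forward pass and a scan-with-reset forward pass, and every session is assumed to have at least two time bins.

```python
import jax
import jax.numpy as jnp


def _forward_assoc(init_prob, transition, emission_lik):
    # alpha_t = alpha_0 @ M_1 @ ... @ M_t with M_t = transition @ diag(b_t)
    alpha0 = init_prob * emission_lik[0]
    mats = transition[None] * emission_lik[1:, None, :]
    prefix = jax.lax.associative_scan(jnp.matmul, mats)
    return jnp.concatenate([alpha0[None], jnp.einsum("k,tkj->tj", alpha0, prefix)])


def _forward_reset(init_prob, transition, emission_lik, is_new_session):
    # Sequential scan; the message restarts from init_prob at each session start.
    def step(alpha_prev, inputs):
        b_t, reset = inputs
        alpha = jnp.where(reset, init_prob * b_t, (alpha_prev @ transition) * b_t)
        return alpha, alpha

    xs = (emission_lik, is_new_session)
    _, alphas = jax.lax.scan(step, jnp.zeros_like(init_prob), xs)
    return alphas


def forward_dispatch(init_prob, transition, sessions):
    """Pick a forward-pass implementation from the session layout.

    sessions: list of (T_i, K) emission-likelihood arrays, one per session
    (each assumed to have T_i >= 2). Returns the forward messages of all
    sessions concatenated along time.
    """
    lengths = {s.shape[0] for s in sessions}
    if len(sessions) == 1:
        # a. single session: plain associative scan
        return _forward_assoc(init_prob, transition, sessions[0])
    if len(lengths) == 1:
        # b. equal durations: vmap the associative scan over sessions
        batched = jax.vmap(_forward_assoc, in_axes=(None, None, 0))(
            init_prob, transition, jnp.stack(sessions)
        )
        return batched.reshape(-1, batched.shape[-1])
    # c. different durations: fall back to the sequential scan with resetting
    flags = jnp.concatenate([jnp.arange(s.shape[0]) == 0 for s in sessions])
    return _forward_reset(init_prob, transition, jnp.concatenate(sessions), flags)
```

The branching depends only on array shapes, which are static in JAX, so the same choice would also be made at trace time if the caller wraps the fit in `jax.jit`.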