@@ -338,6 +338,38 @@ def body_fn(carry, xs):
338338 return betas
339339
340340
def initialize_new_session(n_samples, is_new_session):
    """Build the boolean session-start indicator.

    Parameters
    ----------
    n_samples :
        Number of time bins. Only used to size the default indicator when
        ``is_new_session`` is None.
    is_new_session :
        Optional indicator of session starts, converted to a boolean array.
        When None, the data are treated as one single session.

    Returns
    -------
    :
        Boolean array whose first entry is always True.
    """
    if is_new_session is None:
        # Single-session default: no boundary anywhere; the first bin is
        # switched on below.
        session_flags = jnp.zeros(n_samples, dtype=bool)
    else:
        # User-provided indicator, cast to boolean.
        session_flags = jnp.asarray(is_new_session, dtype=bool)

    # The first time bin always opens a session, whatever the input says.
    return jax.lax.dynamic_update_index_in_dim(session_flags, True, 0, axis=0)
357+
358+
def compute_rate_per_state(X, glm_params, inverse_link_function):
    """Evaluate the GLM mean for every latent state.

    Parameters
    ----------
    X :
        Design matrix; its last axis is contracted with the first axis of the
        coefficients.
    glm_params :
        Length two tuple ``(coef, intercept)``.
    inverse_link_function :
        Callable applied to the linear predictor to obtain the mean.

    Returns
    -------
    :
        ``inverse_link_function(linear_predictor + intercept)``.
    """
    weights, bias = glm_params

    if weights.ndim > 2:
        # Coefficients carry an extra trailing axis: contract the feature
        # axis only and keep the two remaining axes.
        linear_predictor = jnp.einsum("ik, kjw->ijw", X, weights)
    else:
        linear_predictor = X @ weights

    return inverse_link_function(linear_predictor + bias)
371+
372+
341373@partial (jax .jit , static_argnames = ["inverse_link_function" , "likelihood_func" ])
342374def forward_backward (
343375 X : Array ,
@@ -412,30 +444,12 @@ def forward_backward(
412444 ----------
413445 .. [1] Bishop, C. M. (2006). *Pattern recognition and machine learning*. Springer.
414446 """
415- coef , intercept = glm_params
416447 # Initialize variables
417448 n_time_bins = X .shape [0 ]
418-
419- # Revise if the data is one single session or multiple sessions.
420- # If new_sess is not provided, assume one session
421- if is_new_session is None :
422- # default: all False, but first time bin must be True
423- is_new_session = jax .lax .dynamic_update_index_in_dim (
424- jnp .zeros (y .shape [0 ], dtype = bool ), True , 0 , axis = 0
425- )
426- else :
427- # use the user-provided tree, but force the first time bin to be True
428- is_new_session = jax .lax .dynamic_update_index_in_dim (
429- jnp .asarray (is_new_session , dtype = bool ), True , 0 , axis = 0
430- )
431-
432- # Predicted y
433- if coef .ndim > 2 :
434- predicted_rate_given_state = inverse_link_function (
435- jnp .einsum ("ik, kjw->ijw" , X , coef ) + intercept
436- )
437- else :
438- predicted_rate_given_state = inverse_link_function (X @ coef + intercept )
449+ is_new_session = initialize_new_session (y .shape [0 ], is_new_session )
450+ predicted_rate_given_state = compute_rate_per_state (
451+ X , glm_params , inverse_link_function
452+ )
439453
440454 # Compute likelihood given the fixed weights
441455 # Data likelihood p(y|z) from emissions model
@@ -636,3 +650,138 @@ def run_m_step(
636650 optimized_projection_weights , state = solver_run (glm_params , X , y , posteriors )
637651
638652 return optimized_projection_weights , new_initial_prob , new_transition_prob , state
653+
654+
def max_sum(
    X: Array,
    y: Array,
    initial_prob: Array,
    transition_prob: Array,
    glm_params: Tuple[Array, Array],
    inverse_link_function: Callable,
    log_likelihood_func: Callable[[Array, Array], Array],
    is_new_session: Array | None = None,
    return_index: bool = False,
):
    """
    Find maximum a posteriori (MAP) state path via the max-sum algorithm.

    This function implements the max-sum algorithm for a GLM-HMM, also known as
    the Viterbi algorithm. Session boundaries restart the chain from the initial
    distribution, and the backward pass restarts from the best state at the end
    of the preceding session.

    Parameters
    ----------
    X :
        Design matrix, array of shape ``(n_time_bins, n_features)``.

    y :
        Observations, array whose first axis has length ``n_time_bins``.

    initial_prob :
        Initial latent state probability, array of shape ``(n_states,)``.
        It is added to a single row of the log-emission and used as the scan
        carry, so it must broadcast to ``(n_states,)``.

    transition_prob :
        Latent state transition matrix, array of shape ``(n_states, n_states)``.
        ``transition_prob[i, j]`` is the probability of transitioning from state ``i`` to state ``j``.

    glm_params :
        Length two tuple with the GLM coefficients of shape ``(n_features, n_states)``
        and intercept of shape ``(n_states,)``.

    inverse_link_function :
        Function mapping linear predictors to the mean of the observation distribution
        (e.g., exp for Poisson, sigmoid for Bernoulli).

    log_likelihood_func :
        Function called as ``log_likelihood_func(y, predicted_rate_given_state)``
        returning the log-likelihood of the observations under each state.
        The result is indexed by time bin along its first axis and is expected
        to have shape ``(n_time_bins, n_states)``.

    is_new_session :
        Boolean array marking the start of a new session.
        If None, treats the full set of time bins as a single session.
        The first time bin is always treated as a session start.

    return_index:
        If False, return 1-hot encoded map states, if True, return map state indices.

    Returns
    -------
    map_path:
        The MAP state path. Array of state indices of shape ``(n_time_bins,)``
        when ``return_index`` is True, otherwise a one-hot ``int32`` array of
        shape ``(n_time_bins, n_states)``.

    """
    is_new_session = initialize_new_session(y.shape[0], is_new_session)
    predicted_rate_given_state = compute_rate_per_state(
        X, glm_params, inverse_link_function
    )
    log_emission = log_likelihood_func(y, predicted_rate_given_state)

    log_transition = jnp.log(transition_prob)
    log_init = jnp.log(initial_prob)
    n_states = initial_prob.shape[0]

    def forward_max_sum(omega_prev, xs):
        log_em, is_new_sess = xs

        def reset_chain(omega_prev, log_em):
            # New session: reset to initial distribution
            omega = log_init + log_em
            # -1 is not a valid state index; it marks "no predecessor" so the
            # backward pass can detect the session boundary.
            max_prob_state = jnp.full(n_states, -1)
            return omega, max_prob_state

        def continue_chain(omega_prev, log_em):
            # Continue existing session: Viterbi step.
            # step[i, j] = omega_prev[i] + log p(j | i) + log p(y_t | j)
            step = log_em[None, :] + log_transition + omega_prev[:, None]
            # Best predecessor for each current state, and its score.
            max_prob_state = jnp.argmax(step, axis=0)
            omega = jnp.max(step, axis=0)
            return omega, max_prob_state

        omega, max_prob_state = jax.lax.cond(
            is_new_sess,
            reset_chain,
            continue_chain,
            omega_prev,
            log_em,
        )

        return omega, (omega, max_prob_state)

    # The first time bin is handled outside the scan; the scan covers bins 1..T-1.
    init_omega = log_init + log_emission[0]
    _, (omegas, max_prob_states) = jax.lax.scan(
        forward_max_sum, init_omega, (log_emission[1:], is_new_session[1:])
    )

    # Backward pass
    # Prepend the initial omega BEFORE reading the last row, so that a single
    # time bin (empty scan output) still has a well-defined final omega.
    all_omegas = jnp.concatenate([init_omega[None, :], omegas], axis=0)
    best_final_state = jnp.argmax(all_omegas[-1])
    # Row k holds the omega of time bin k, i.e. the bin preceding scan step k.
    # The last omega is excluded since it is already consumed above.
    omegas_prev = all_omegas[:-1]

    def backward_max_sum(current_state_idx, xs):
        max_prob_st, omega_t = xs

        def session_boundary(state_idx, max_prob, omega):
            # Hit a session start, pick best state at this boundary
            return jnp.argmax(omega)

        def continue_backward(state_idx, max_prob, omega):
            # Normal backtracking
            return max_prob[state_idx]

        # The forward pass stored -1 for every state at session starts.
        is_boundary = max_prob_st[current_state_idx] == -1

        prev_state_idx = jax.lax.cond(
            is_boundary,
            session_boundary,
            continue_backward,
            current_state_idx,
            max_prob_st,
            omega_t,
        )

        return prev_state_idx, prev_state_idx

    # reverse=True walks from the last time bin to the first while still
    # stacking the outputs in forward time order.
    _, map_path = jax.lax.scan(
        backward_max_sum,
        best_final_state,
        (max_prob_states, omegas_prev),
        reverse=True,
    )

    # Append the final state
    map_path = jnp.concatenate([map_path, jnp.array([best_final_state])])

    if not return_index:
        map_path = jax.nn.one_hot(map_path, n_states, dtype=jnp.int32)

    return map_path
0 commit comments