Skip to content

Commit ee91f54

Browse files
Merge pull request #428 from flatironinstitute/development
Development
2 parents ee432dc + 55c9562 commit ee91f54

File tree

7 files changed

+391
-46
lines changed

7 files changed

+391
-46
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,3 +153,7 @@ docs/data/
153153

154154
# rst generated files
155155
docs/stubs
156+
157+
# NPZ/NPY files
158+
*.npz
159+
*.npy

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ dev = [
6161
"dandi", # Required by doctest for fetch module
6262
"seaborn", # Required by doctest for _documentation_utils module
6363
"myst-nb", # Test myst_nb utils for glue
64+
"hmmlearn", # Test algorithmic implementations of HMM
6465
]
6566
docs = [
6667
"numpydoc",

scripts/check_parameter_naming.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,6 @@
7777
{"flat_dict", "flat_map_dict"},
7878
{"fit_params", "init_params"},
7979
{"args", "arg"},
80-
{"tol", "atol"},
81-
{"tol", "rtol"},
82-
{"fit_params", "flat_params"},
83-
{"solver_kwargs", "solver_init_kwargs"},
84-
{"unaccepted_name", "accepted_name"},
85-
{"fn", "fun"},
8680
{"initialize_init_proba", "initialize_transition_proba"},
8781
*(
8882
{a, b}
@@ -95,6 +89,14 @@
9589
r=2,
9690
)
9791
),
92+
{"likelihood_func", "log_likelihood_func"},
93+
{"negative_log_likelihood_func", "log_likelihood_func"},
94+
{"tol", "atol"},
95+
{"tol", "rtol"},
96+
{"fit_params", "flat_params"},
97+
{"solver_kwargs", "solver_init_kwargs"},
98+
{"unaccepted_name", "accepted_name"},
99+
{"fn", "fun"},
98100
]
99101

100102

src/nemos/basis/_transformer_basis.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ def __getattr__(self, name: str):
335335
if name in self._wrapped_methods:
336336
return self._wrapped_methods[name]
337337

338-
if not hasattr(self._basis, name) or name == "to_transformer":
338+
if not hasattr(self._basis, name):
339339
raise AttributeError(f"'TransformerBasis' object has no attribute '{name}'")
340340

341341
# Get the original attribute from the basis
@@ -463,7 +463,6 @@ def __dir__(self) -> list[str]:
463463
"""Extend the list of properties of methods with the ones from the underlying Basis."""
464464
unique_attrs = set(list(super().__dir__()) + list(self.basis.__dir__()))
465465
# discard without raising errors if not present
466-
unique_attrs.discard("to_transformer")
467466
return list(unique_attrs)
468467

469468
def __add__(self, other: TransformerBasis | Basis) -> TransformerBasis:

src/nemos/glm_hmm/expectation_maximization.py

Lines changed: 171 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,38 @@ def body_fn(carry, xs):
338338
return betas
339339

340340

341+
def initialize_new_session(n_samples, is_new_session):
    """Build the boolean session-start indicator.

    Parameters
    ----------
    n_samples :
        Number of samples; used as the indicator length when
        ``is_new_session`` is None.
    is_new_session :
        Optional session-start flags. When None, the data are treated as a
        single session. Otherwise the input is converted to a boolean array.

    Returns
    -------
    :
        Boolean array whose first entry is always True.
    """
    if is_new_session is None:
        # No flags supplied: one session spanning every sample.
        session_flags = jnp.zeros(n_samples, dtype=bool)
    else:
        # Start from the caller's flags, cast to boolean.
        session_flags = jnp.asarray(is_new_session, dtype=bool)

    # The first time bin opens a session regardless of what was provided.
    return jax.lax.dynamic_update_index_in_dim(session_flags, True, 0, axis=0)
357+
358+
359+
def compute_rate_per_state(X, glm_params, inverse_link_function):
    """Evaluate the GLM mean for every state.

    Parameters
    ----------
    X :
        Design matrix, two-dimensional (samples by features).
    glm_params :
        Pair ``(coef, intercept)``. ``coef`` is either at most
        two-dimensional (combined with ``X`` via ``X @ coef``) or
        three-dimensional with the feature axis first.
    inverse_link_function :
        Callable applied to the linear predictor.

    Returns
    -------
    :
        ``inverse_link_function`` evaluated on the linear predictor.
    """
    weights, offset = glm_params

    if weights.ndim > 2:
        # Contract the feature axis, keeping the two trailing coef axes.
        linear_predictor = jnp.einsum("ik, kjw->ijw", X, weights) + offset
    else:
        linear_predictor = X @ weights + offset

    return inverse_link_function(linear_predictor)
371+
372+
341373
@partial(jax.jit, static_argnames=["inverse_link_function", "likelihood_func"])
342374
def forward_backward(
343375
X: Array,
@@ -412,30 +444,12 @@ def forward_backward(
412444
----------
413445
.. [1] Bishop, C. M. (2006). *Pattern recognition and machine learning*. Springer.
414446
"""
415-
coef, intercept = glm_params
416447
# Initialize variables
417448
n_time_bins = X.shape[0]
418-
419-
# Revise if the data is one single session or multiple sessions.
420-
# If new_sess is not provided, assume one session
421-
if is_new_session is None:
422-
# default: all False, but first time bin must be True
423-
is_new_session = jax.lax.dynamic_update_index_in_dim(
424-
jnp.zeros(y.shape[0], dtype=bool), True, 0, axis=0
425-
)
426-
else:
427-
# use the user-provided tree, but force the first time bin to be True
428-
is_new_session = jax.lax.dynamic_update_index_in_dim(
429-
jnp.asarray(is_new_session, dtype=bool), True, 0, axis=0
430-
)
431-
432-
# Predicted y
433-
if coef.ndim > 2:
434-
predicted_rate_given_state = inverse_link_function(
435-
jnp.einsum("ik, kjw->ijw", X, coef) + intercept
436-
)
437-
else:
438-
predicted_rate_given_state = inverse_link_function(X @ coef + intercept)
449+
is_new_session = initialize_new_session(y.shape[0], is_new_session)
450+
predicted_rate_given_state = compute_rate_per_state(
451+
X, glm_params, inverse_link_function
452+
)
439453

440454
# Compute likelihood given the fixed weights
441455
# Data likelihood p(y|z) from emissions model
@@ -636,3 +650,138 @@ def run_m_step(
636650
optimized_projection_weights, state = solver_run(glm_params, X, y, posteriors)
637651

638652
return optimized_projection_weights, new_initial_prob, new_transition_prob, state
653+
654+
655+
def max_sum(
    X: Array,
    y: Array,
    initial_prob: Array,
    transition_prob: Array,
    glm_params: Tuple[Array, Array],
    inverse_link_function: Callable,
    log_likelihood_func: Callable[[Array, Array], Array],
    is_new_session: Array | None = None,
    return_index: bool = False,
):
    """
    Find maximum a posteriori (MAP) state path via the max-sum algorithm.

    This function implements the max-sum algorithm for a GLM-HMM, also known
    as the Viterbi algorithm. Session starts reset the recursion to the
    initial distribution, so each session is decoded independently.

    Parameters
    ----------
    X :
        Design matrix of shape ``(n_time_bins, n_features)``.

    y :
        Observations; the first axis indexes time bins.

    initial_prob :
        Initial latent state probability, shape ``(n_states,)``.

    transition_prob :
        Latent state transition matrix of shape ``(n_states, n_states)``.
        ``transition_prob[i, j]`` is the probability of transitioning from
        state ``i`` to state ``j``.

    glm_params :
        Length two tuple with the GLM coefficients and intercept.

    inverse_link_function :
        Function mapping linear predictors to the mean of the observation
        distribution (e.g., exp for Poisson, sigmoid for Bernoulli).

    log_likelihood_func :
        Function called as ``log_likelihood_func(y, predicted_rate)`` that
        returns the log-likelihood of the observations under each state,
        with shape ``(n_time_bins, n_states)``.

    is_new_session :
        Boolean array marking the start of a new session. If None, the full
        set of time bins is treated as a single session. The first time bin
        is always treated as a session start.

    return_index :
        If False, return one-hot encoded MAP states; if True, return MAP
        state indices.

    Returns
    -------
    map_path :
        The MAP state path: integer indices of shape ``(n_time_bins,)`` when
        ``return_index`` is True, otherwise a one-hot ``int32`` array of
        shape ``(n_time_bins, n_states)``.

    """
    is_new_session = initialize_new_session(y.shape[0], is_new_session)
    predicted_rate_given_state = compute_rate_per_state(
        X, glm_params, inverse_link_function
    )
    log_emission = log_likelihood_func(y, predicted_rate_given_state)

    log_transition = jnp.log(transition_prob)
    log_init = jnp.log(initial_prob)
    n_states = initial_prob.shape[0]

    def forward_max_sum(omega_prev, xs):
        log_em, is_new_sess = xs

        def reset_chain(omega_prev, log_em):
            # New session: reset to initial distribution
            omega = log_init + log_em
            max_prob_state = jnp.full(n_states, -1)  # Boundary marker
            return omega, max_prob_state

        def continue_chain(omega_prev, log_em):
            # Continue existing session: Viterbi step.
            # step[i, j] scores arriving in state j from previous state i.
            step = log_em[None, :] + log_transition + omega_prev[:, None]
            max_prob_state = jnp.argmax(step, axis=0)
            omega = jnp.max(step, axis=0)
            return omega, max_prob_state

        omega, max_prob_state = jax.lax.cond(
            is_new_sess,
            reset_chain,
            continue_chain,
            omega_prev,
            log_em,
        )

        return omega, (omega, max_prob_state)

    init_omega = log_init + log_emission[0]
    # The final carry is the score vector of the last time bin. Using it
    # (rather than indexing the stacked outputs) keeps a single-bin input
    # valid, since the scan then runs over zero steps.
    final_omega, (omegas, max_prob_states) = jax.lax.scan(
        forward_max_sum, init_omega, (log_emission[1:], is_new_session[1:])
    )

    # Backward pass
    best_final_state = jnp.argmax(final_omega)
    # Scores for time bins 0..T-2: prepend the initial scores and drop the
    # last entry, which was already used to pick ``best_final_state``.
    # Slicing after concatenating keeps the length equal to that of
    # ``max_prob_states`` even when there is a single time bin.
    omegas = jnp.concatenate([init_omega[None, :], omegas], axis=0)[:-1]

    def backward_max_sum(current_state_idx, xs):
        max_prob_st, omega_t = xs

        def session_boundary(state_idx, max_prob, omega):
            # Hit a session start, pick best state at this boundary
            return jnp.argmax(omega)

        def continue_backward(state_idx, max_prob, omega):
            # Normal backtracking
            return max_prob[state_idx]

        # Back-pointers equal to -1 were written at session starts.
        is_boundary = max_prob_st[current_state_idx] == -1

        prev_state_idx = jax.lax.cond(
            is_boundary,
            session_boundary,
            continue_backward,
            current_state_idx,
            max_prob_st,
            omega_t,
        )

        return prev_state_idx, prev_state_idx

    _, map_path = jax.lax.scan(
        backward_max_sum, best_final_state, (max_prob_states, omegas), reverse=True
    )

    # Append the final state
    map_path = jnp.concatenate([map_path, jnp.array([best_final_state])])

    if not return_index:
        map_path = jax.nn.one_hot(map_path, n_states, dtype=jnp.int32)

    return map_path

0 commit comments

Comments
 (0)