Skip to content

New ADVI fails for variables whose value transforms change the variables' shape #646

@Dekermanjian

Description

@Dekermanjian

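Quick reproduction of the issue. Here `p` has a Dirichlet prior. Its default simplex transform maps the 3-element probability vector to a 2-element unconstrained value variable, so the value variable's shape differs from the shape of `p` itself. NUTS (`pm.sample()`) handles this model fine, but fitting it with the new ADVI `SVITrainer` fails: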

from typing import Any

import numpy as np
import pymc as pm

from pymc_extras.inference.advi.autoguide import AutoDiagonalNormal
from pymc_extras.inference.advi.training import SVIModule, SVITrainer

class SGDOptimizer:
    """Vanilla gradient-descent optimizer with no internal state.

    ``init`` produces ``None`` as the optimizer state, and ``update`` passes
    whatever state it receives straight back to the caller.
    """

    def __init__(self, learning_rate: float = 1e-5):
        # Multiplier applied to every gradient in ``update``.
        self.learning_rate = learning_rate

    def init(self, params: dict[str, np.ndarray]) -> None:
        """Return the initial optimizer state, which for this optimizer is ``None``."""
        return None

    def update(
        self,
        grads: dict[str, np.ndarray],
        state: None,
        params: dict[str, np.ndarray],
    ) -> tuple[dict[str, np.ndarray], None]:
        """Move each parameter one step against its gradient.

        Each new value is ``param - learning_rate * grad``. Only keys present in
        ``params`` are updated, in ``params``' order. A parameter without a
        matching entry in ``grads`` raises ``KeyError``. Extra gradient keys
        are ignored.

        Returns the new parameter dict together with ``state``, unchanged.
        """
        stepped: dict[str, np.ndarray] = {}
        for name, value in params.items():
            stepped[name] = value - self.learning_rate * grads[name]
        return stepped, state

class NormalModel(SVIModule):
    """SVI module that uses an ``AutoDiagonalNormal`` guide and ``SGDOptimizer``.

    The hooks below are presumably invoked by ``SVITrainer`` during ``fit``.
    NOTE(review): confirm against the ``SVIModule`` contract.
    """

    def configure_guide(self, model):
        """Return an ``AutoDiagonalNormal`` guide built for ``model``."""
        return AutoDiagonalNormal(model)

    def configure_optimizer(self, params: dict[str, np.ndarray]) -> tuple[Any, Any]:
        """Create the optimizer and its initial state for ``params``.

        ``SGDOptimizer.init`` returns ``None``, so the state is annotated as
        ``Any`` rather than ``dict[str, Any]``.
        """
        optimizer = SGDOptimizer(learning_rate=1e-5)
        opt_state = optimizer.init(params)
        return optimizer, opt_state

    def apply_gradients(
        self,
        params: dict[str, np.ndarray],
        grads: dict[str, np.ndarray],
        optimizer: Any,
        optimizer_state: Any,
    ) -> tuple[dict[str, np.ndarray], Any]:
        """Apply one optimizer step.

        Delegates to ``optimizer.update(grads, optimizer_state, params)`` and
        returns ``(updated_params, updated_optimizer_state)``. The state is
        whatever ``configure_optimizer`` produced (``None`` for
        ``SGDOptimizer``).
        """
        updated_params, updated_opt_state = optimizer.update(grads, optimizer_state, params)
        return updated_params, updated_opt_state

# Model with a Dirichlet-distributed probability vector. The Dirichlet's default
# (simplex) value transform gives its unconstrained value variable a different
# shape than ``p`` itself; that shape change is what the ADVI trainer trips over.
with pm.Model() as m:
    p = pm.Dirichlet("p", np.ones(3))
    obs = pm.Categorical("obs", p=p, observed=[0, 1, 2])

    # NUTS copes with the transformed variable: this call samples successfully.
    idata = pm.sample()

# Fitting the very same model with the new ADVI trainer fails.
svi_trainer = SVITrainer(stick_the_landing=True, module=NormalModel())
svi_state = svi_trainer.fit(model=m, n_steps=10_000, draws_per_step=1)  # Fails

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions