-
Notifications
You must be signed in to change notification settings - Fork 81
New ADVI fails for variables whose value transform changes the variable's shape #646
Copy link
Copy link
Open
Labels
bug — Something isn't working
Description
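Quick reproduction of the issue. Here `p` is a 3-element Dirichlet whose default simplex value transform maps it to a 2-element unconstrained vector. NUTS sampling succeeds on this model, but `SVITrainer.fit` fails: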
from typing import Any

import numpy as np
import pymc as pm
from pymc_extras.inference.advi.autoguide import AutoDiagonalNormal
from pymc_extras.inference.advi.training import SVIModule, SVITrainer
class SGDOptimizer:
    """Plain gradient descent with a fixed step size.

    The optimizer keeps no internal state: ``init`` returns ``None`` and
    ``update`` hands back whatever state object it was given, untouched.
    """

    def __init__(self, learning_rate: float = 1e-5):
        self.learning_rate = learning_rate

    def init(self, params: dict[str, np.ndarray]) -> None:
        """Return the initial optimizer state, which is always ``None``."""
        return None

    def update(
        self,
        grads: dict[str, np.ndarray],
        state: None,
        params: dict[str, np.ndarray],
    ) -> tuple[dict[str, np.ndarray], None]:
        """Apply one step ``value - learning_rate * grad`` to every entry of ``params``.

        Only keys present in ``params`` are updated; each one must also be a
        key of ``grads`` (otherwise ``KeyError`` is raised). Extra keys in
        ``grads`` are ignored.

        Returns:
            A new parameter dict (same key order as ``params``) and ``state``
            passed through unchanged.
        """
        stepped: dict[str, np.ndarray] = {}
        for name, value in params.items():
            stepped[name] = value - self.learning_rate * grads[name]
        return stepped, state
class NormalModel(SVIModule):
    """SVI module using an ``AutoDiagonalNormal`` guide and ``SGDOptimizer`` (step 1e-5)."""

    def configure_guide(self, model):
        """Return an ``AutoDiagonalNormal`` guide constructed from ``model``."""
        guide = AutoDiagonalNormal(model)
        return guide

    def configure_optimizer(self, params: dict[str, np.ndarray]) -> tuple[Any, dict[str, Any]]:
        """Create the SGD optimizer and return it together with its initial state.

        Note: ``SGDOptimizer.init`` returns ``None``, so the state half of the
        pair is ``None`` despite the ``dict`` annotation.
        """
        sgd = SGDOptimizer(learning_rate=1e-5)
        return sgd, sgd.init(params)

    def apply_gradients(
        self,
        params: dict[str, np.ndarray],
        grads: dict[str, np.ndarray],
        optimizer: Any,
        optimizer_state: dict[str, Any],
    ):
        """Run one ``optimizer.update`` step and return ``(new_params, new_state)``."""
        new_params, new_state = optimizer.update(grads, optimizer_state, params)
        return new_params, new_state
with pm.Model() as m:
    # 3-category probability vector. Its default (simplex) value transform
    # maps the 3 constrained entries to 2 unconstrained ones, so the value
    # variable's shape differs from the variable's shape.
    p = pm.Dirichlet("p", np.ones(3))
    obs = pm.Categorical("obs", p=p, observed=[0, 1, 2])
    # Sampling with NUTS handles the shape-changing transform without error.
    idata = pm.sample()

svi_trainer = SVITrainer(module=NormalModel(), stick_the_landing=True)
# Fitting the same model with the new ADVI auto-guide raises (the bug in this issue).
svi_state = svi_trainer.fit(
    n_steps=10_000,
    model=m,
    draws_per_step=1,
)
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
bug — Something isn't working