Skip to content

Commit 56f63eb

Browse files
authored
Fix promote_batch_shape logic to take batch shapes of all parameters (#1973)
* fix promote batch shape logic * lint
1 parent fa3f731 commit 56f63eb

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

numpyro/distributions/batch_util.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -504,12 +504,16 @@ def promote_batch_shape(d: Distribution):
504504

505505
@promote_batch_shape.register
506506
def _default_promote_batch_shape(d: Distribution):
507-
attr_name = list(d.arg_constraints.keys())[0]
508-
attr_event_dim = d.arg_constraints[attr_name].event_dim
509-
attr = getattr(d, attr_name)
510-
resolved_batch_shape = attr.shape[
511-
: max(0, attr.ndim - d.event_dim - attr_event_dim)
512-
]
507+
attr_batch_shapes = [d.batch_shape]
508+
for attr_name, constraint in d.arg_constraints.items():
509+
try:
510+
attr_event_dim = constraint.event_dim
511+
except NotImplementedError:
512+
continue
513+
attr = getattr(d, attr_name)
514+
attr_batch_ndim = max(0, jnp.ndim(attr) - attr_event_dim)
515+
attr_batch_shapes.append(jnp.shape(attr)[:attr_batch_ndim])
516+
resolved_batch_shape = jnp.broadcast_shapes(*attr_batch_shapes)
513517
new_self = copy.deepcopy(d)
514518
new_self._batch_shape = resolved_batch_shape
515519
return new_self

test/contrib/test_control_flow.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,3 +272,19 @@ def transition(x_prev, y_curr):
272272

273273
xhat = results.params["x_auto_loc"]
274274
assert_allclose(xhat, tr["x"]["value"], rtol=0.1, atol=0.2)
275+
276+
277+
def test_scan_mvn():
278+
def model():
279+
def transition(c, a):
280+
with numpyro.plate("foo", 5):
281+
c2 = numpyro.sample(
282+
"val", dist.MultivariateNormal(c + a, scale_tril=jnp.eye(2))
283+
)
284+
return c2, c2
285+
286+
scan(transition, jnp.zeros((5, 2)), jnp.ones((4, 5, 2)))
287+
288+
with numpyro.handlers.seed(rng_seed=0), numpyro.handlers.trace() as tr:
289+
model()
290+
assert tr["val"]["fn"].batch_shape == (4, 5)

0 commit comments

Comments
 (0)