Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
ordabayevy committed Dec 17, 2023
1 parent 24ea976 commit 49dbab1
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
6 changes: 4 additions & 2 deletions pyro/poutine/enum_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def _tmc_mixture_sample(msg: Message) -> torch.Tensor:
# find batch dims that aren't plate dims
batch_shape = [1] * len(dist.batch_shape)
for f in msg["cond_indep_stack"]:
if f.dim is not None:
if f.vectorized:
assert f.dim is not None
batch_shape[f.dim] = f.size if f.size > 0 else dist.batch_shape[f.dim]
batch_shape_tuple = tuple(batch_shape)

Expand Down Expand Up @@ -71,7 +72,8 @@ def _tmc_diagonal_sample(msg: Message) -> torch.Tensor:
# find batch dims that aren't plate dims
batch_shape = [1] * len(dist.batch_shape)
for f in msg["cond_indep_stack"]:
if f.dim is not None:
if f.vectorized:
assert f.dim is not None
batch_shape[f.dim] = f.size if f.size > 0 else dist.batch_shape[f.dim]
batch_shape_tuple = tuple(batch_shape)

Expand Down
2 changes: 1 addition & 1 deletion pyro/poutine/indep_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pyro.util import ignore_jit_warnings


@dataclass
@dataclass(eq=False)
class CondIndepStackFrame:
name: str
dim: Optional[int]
Expand Down
3 changes: 2 additions & 1 deletion pyro/poutine/trace_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,8 @@ def symbolize_dims(self, plate_to_symbol: Optional[Dict[str, str]] = None) -> No
# allocate even symbols for plate dims
dim_to_symbol: Dict[int, str] = {}
for frame in site["cond_indep_stack"]:
if frame.dim is not None:
if frame.vectorized:
assert frame.dim is not None
if frame.name in plate_to_symbol:
symbol = plate_to_symbol[frame.name]
else:
Expand Down

0 comments on commit 49dbab1

Please sign in to comment.