diff --git a/pyro/poutine/indep_messenger.py b/pyro/poutine/indep_messenger.py index 3d68136b9a..c224b032fd 100644 --- a/pyro/poutine/indep_messenger.py +++ b/pyro/poutine/indep_messenger.py @@ -2,7 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 import numbers -from typing import Iterator, NamedTuple, Optional, Tuple +from dataclasses import dataclass +from typing import Iterator, Optional, Tuple import torch from typing_extensions import Self @@ -12,7 +13,8 @@ from pyro.util import ignore_jit_warnings -class CondIndepStackFrame(NamedTuple): +@dataclass +class CondIndepStackFrame: name: str dim: Optional[int] size: int @@ -97,7 +99,6 @@ def __enter__(self) -> Self: self._vectorized = True if self._vectorized is True: - assert self.dim is not None self.dim = _DIM_ALLOCATOR.allocate(self.name, self.dim) return super().__enter__()