Skip to content

Commit

Permalink
Support for deterministic dependent samples in PyroSample.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben Zickel committed Jun 23, 2024
1 parent 64e71ee commit e2788b9
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 8 deletions.
21 changes: 16 additions & 5 deletions pyro/nn/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ class PyroSample:
assert isinstance(my_module, PyroModule)
my_module.x = PyroSample(Normal(0, 1)) # independent
my_module.y = PyroSample(lambda self: Normal(self.x, 1)) # dependent
my_module.z = PyroSample(lambda self: self.y ** 2) # deterministic dependent
or EXPERIMENTALLY as a decorator on lazy initialization methods::
Expand All @@ -175,16 +176,22 @@ def x(self):
def y(self):
return Normal(self.x, 1) # dependent
@PyroSample
def z(self):
return self.y ** 2 # deterministic dependent
def forward(self):
return self.y # accessed like a @property
return self.z # accessed like a @property
:param prior: distribution object or function that inputs the
:class:`PyroModule` instance ``self`` and returns a distribution
object.
object or a deterministic value.
"""

prior: Union[
"TorchDistributionMixin", Callable[["PyroModule"], "TorchDistributionMixin"]
"TorchDistributionMixin",
Callable[["PyroModule"], "TorchDistributionMixin"],
Callable[["PyroModule"], torch.Tensor],
]

def __post_init__(self) -> None:
Expand Down Expand Up @@ -605,13 +612,17 @@ def __getattr__(self, name: str) -> Any:
if value is None:
if not hasattr(prior, "sample"): # if not a distribution
prior = prior(self)
value = pyro.sample(fullname, prior)
value = (
pyro.deterministic(fullname, prior)
if isinstance(prior, torch.Tensor)
else pyro.sample(fullname, prior)
)
context.set(fullname, value)
return value
else: # Cannot determine supermodule and hence cannot compute fullname.
if not hasattr(prior, "sample"): # if not a distribution
prior = prior(self)
return prior()
return prior if isinstance(prior, torch.Tensor) else prior()

result = super().__getattr__(name)

Expand Down
24 changes: 21 additions & 3 deletions tests/nn/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,9 +491,10 @@ def __init__(self, size):
)
self.s = PyroSample(dist.Normal(0, 1))
self.t = PyroSample(lambda self: dist.Normal(self.s, self.z))
self.u = PyroSample(lambda self: self.t**2)

def forward(self):
return self.x + self.y + self.t
return self.x + self.y + self.u


class DecoratorModel(PyroModule):
Expand Down Expand Up @@ -521,8 +522,12 @@ def s(self):
def t(self):
return dist.Normal(self.s, self.z).to_event(1)

@PyroSample
def u(self):
return self.t**2

def forward(self):
return self.x + self.y + self.t
return self.x + self.y + self.u


@pytest.mark.parametrize("Model", [AttributeModel, DecoratorModel])
Expand All @@ -531,19 +536,32 @@ def test_decorator(Model, size):
model = Model(size)
for i in range(2):
trace = poutine.trace(model).get_trace()
assert set(trace.nodes.keys()) == {"_INPUT", "x", "y", "z", "s", "t", "_RETURN"}
assert set(trace.nodes.keys()) == {
"_INPUT",
"x",
"y",
"z",
"s",
"t",
"u",
"_RETURN",
}

assert trace.nodes["x"]["type"] == "param"
assert trace.nodes["y"]["type"] == "param"
assert trace.nodes["z"]["type"] == "param"
assert trace.nodes["s"]["type"] == "sample"
assert trace.nodes["t"]["type"] == "sample"
assert trace.nodes["u"]["type"] == "sample"

assert trace.nodes["x"]["value"].shape == (size,)
assert trace.nodes["y"]["value"].shape == (size,)
assert trace.nodes["z"]["value"].shape == (size,)
assert trace.nodes["s"]["value"].shape == ()
assert trace.nodes["t"]["value"].shape == (size,)
assert trace.nodes["u"]["value"].shape == (size,)

assert trace.nodes["u"]["infer"] == {"_deterministic": True}


def test_mixin_factory():
Expand Down

0 comments on commit e2788b9

Please sign in to comment.