From fc5c5c02ce272c82d709a3c8b8e1b5d73cabb133 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Rozet?= Date: Mon, 18 Dec 2023 12:59:19 -0600 Subject: [PATCH] Shorten comment --- pyro/contrib/zuko.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyro/contrib/zuko.py b/pyro/contrib/zuko.py index 7b17f0bd61..19cea0de38 100644 --- a/pyro/contrib/zuko.py +++ b/pyro/contrib/zuko.py @@ -27,7 +27,7 @@ def batch_shape(self) -> Size: return self.dist.batch_shape def __call__(self, shape: Size = ()) -> Tensor: - if hasattr(self.dist, "rsample_and_log_prob"): # special method for fast sampling + scoring + if hasattr(self.dist, "rsample_and_log_prob"): # fast sampling + scoring x, self.cache[x] = self.dist.rsample_and_log_prob(shape) elif self.has_rsample: x = self.dist.rsample(shape)