diff --git a/pyro/distributions/zero_inflated.py b/pyro/distributions/zero_inflated.py index 8ee654b345..7f017f8429 100644 --- a/pyro/distributions/zero_inflated.py +++ b/pyro/distributions/zero_inflated.py @@ -60,11 +60,11 @@ def support(self): @lazy_property def gate(self): - return logits_to_probs(self.gate_logits) + return logits_to_probs(self.gate_logits, is_binary=True) @lazy_property def gate_logits(self): - return probs_to_logits(self.gate) + return probs_to_logits(self.gate, is_binary=True) def log_prob(self, value): if self._validate_args: