Skip to content

Commit 748a3e4

Browse files
committed
fix bug in EKL + SChooses
1 parent 0dee1a8 commit 748a3e4

File tree

3 files changed

+25
-3
lines changed

3 files changed

+25
-3
lines changed

demo/test.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@ def normpdf(x, mu, sigma):
2424
def test_[x: X](t):
2525
return x
2626

27+
@memo_test(mod, expect='ce')
28+
def chooses_multiple():
29+
bob: chooses(x in X, wpp=1)
30+
bob: chooses(x in X, wpp=1)
31+
return 1
32+
2733
@memo_test(mod)
2834
def observes_call[x: X]():
2935
a: thinks[ b: chooses(x in X, wpp=1) ]
@@ -271,4 +277,12 @@ def kl_fail_known[r: R]():
271277
def kl_fail_dom():
272278
alice: chooses(p in Z, wpp=normpdf(p, 0, 1))
273279
alice: chooses(q in R, wpp=normpdf(q, 0, 2))
274-
return KL[alice.p | alice.q]
280+
return KL[alice.p | alice.q]
281+
282+
@memo_test(mod)
283+
def kl_victor[x: X]():
284+
bob: thinks[
285+
alice: given(x in X, wpp=1),
286+
env: chooses(x in X, wpp=1)
287+
]
288+
return bob[KL[alice.x | env.x]]

memo/core.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -806,7 +806,7 @@ def _(e: EKL, ctxt: Context) -> Value:
806806
ctxt.emit(f"{q_p} = marg({ctxt.frame.ll}, {idxs_q})")
807807

808808
ctxt.emit(
809-
f"{out} = marg({p_p} * jnp.nan_to_num(jnp.log({p_p}) - jnp.log(jnp.swapaxes({q_p}, {p_c.idx}, {q_c.idx}))), [{p_c.idx}])"
809+
f"{out} = marg({p_p} * jnp.nan_to_num(jnp.log({p_p}) - jnp.log(jnp.swapaxes({q_p}, {-1 - p_c.idx}, {-1 - q_c.idx}))), [{p_c.idx}])"
810810
)
811811
return Value(
812812
tag=out,
@@ -952,6 +952,14 @@ def eval_stmt(s: Stmt, ctxt: Context) -> None:
952952

953953
idx_list = []
954954
for id, dom in choices:
955+
if (Name("self"), id) in ctxt.frame.choices:
956+
raise MemoError(
957+
"Repeated choice",
958+
hint=f"{who} has already chosen {id} earlier in this model! Pick a new name?",
959+
user=True,
960+
ctxt=ctxt,
961+
loc=s.loc
962+
)
955963
idx = ctxt.next_idx
956964
ctxt.next_idx += 1
957965
idx_list.append(idx)

memo/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.1.1'
1+
__version__ = '1.1.2'

0 commit comments

Comments
 (0)