Skip to content

Commit b7b1aff

Browse files
[patch] fix bug in solver constraint handling (#16)
* small bug fixes * lambo tweaks * revert coord selection changes * clean up commented code * fix normalization term
1 parent c318f9e commit b7b1aff

File tree

3 files changed

+21
-7
lines changed

3 files changed

+21
-7
lines changed

cortex/acquisition/_graph_nei.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,15 @@ def tree_output_to_dict(
126126
)
127127

128128
if constraints is not None:
129-
for constraint in constraints:
130-
constraint_values = tree_output.fetch_task_outputs(constraint)["logits"]
131-
constraint_values = constraint_values.softmax(dim=-1)[..., 1]
129+
for c_list in constraints.values():
130+
for constraint in c_list:
131+
if constraint in result:
132+
continue
132133

133-
result[constraint] = constraint_values
134+
constraint_values = tree_output.fetch_task_outputs(constraint)["logits"]
135+
constraint_values = constraint_values.softmax(dim=-1)[..., 1]
136+
137+
result[constraint] = constraint_values
134138

135139
return result
136140

@@ -163,6 +167,7 @@ def get_graph_nei_runtime_kwargs(
163167
"f_ref": f_ref,
164168
"f_baseline": f_baseline,
165169
}
170+
print(f"[INFO][LaMBO-2] Baseline value: {f_baseline.mean(0).max().item():.4f}")
166171
return res
167172

168173

cortex/model/leaf/_classifier_leaf.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def check_probs(probs: torch.Tensor, dim: int = -1) -> bool:
2222
if torch.any(probs < 0) or torch.any(probs > 1):
2323
raise ValueError("Probabilities must be between 0 and 1")
2424

25-
if not torch.allclose(probs.sum(dim=dim), torch.ones(probs.shape[:-1])):
25+
if not torch.allclose(probs.sum(dim=dim), torch.ones(probs.shape[:-1], device=probs.device)):
2626
raise ValueError("Probabilities must sum to 1")
2727

2828
return True

cortex/optim/generative/_lambo.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import math
2+
import pprint
23
import warnings
34
from typing import Callable, Optional
45

@@ -139,7 +140,7 @@ def step(self) -> None:
139140
activations, trunk_outputs = self._get_latent_variables(generation_inputs)
140141

141142
delta = torch.nn.Parameter(torch.zeros_like(activations))
142-
optimizer = torch.optim.Adam([delta], lr=self.guidance_step_size)
143+
optimizer = torch.optim.Adam([delta], lr=self.guidance_step_size, betas=(0.09, 0.0999))
143144
metrics = {"step": self._step_count}
144145

145146
# get initial solution before guidance
@@ -207,13 +208,15 @@ def step(self) -> None:
207208
grad_norm = feature_grad.norm(dim=(-2, -1), keepdim=True)
208209
metrics.update(
209210
{
211+
"act_obj_val": tgt_obj_vals.mean().item(),
210212
"masked_design_loss": design_loss.item(),
211213
"masked_design_loss_grad_norm": grad_norm.mean().item(),
212214
"masked_token_loss": kl_div.item(),
213215
"masked_obj_loss": obj_loss.item(),
214216
"token_entropy": entropy.item(),
215217
}
216218
)
219+
pprint.pp(metrics)
217220

218221
self._step_count += 1
219222

@@ -257,7 +260,13 @@ def coord_score(tok_embeddings):
257260
null_embedding,
258261
is_excluded=~pos_is_feasible,
259262
)
260-
position_probs = (position_scores * self.feature_attr_temp).softmax(-1)
263+
denom = torch.where(position_scores > float("-inf"), position_scores, 0.0).abs().sum(-1, keepdim=True)
264+
position_scores = position_scores / (denom + 1e-6)
265+
266+
position_probs = (position_scores / self.feature_attr_temp).softmax(-1)
267+
hand_tuned_entropy = torch.distributions.Categorical(probs=position_probs).entropy().median()
268+
print(f"[INFO][LaMBO-2]: Hand-tuned entropy = {hand_tuned_entropy}")
269+
261270
edit_idxs = torch.multinomial(position_probs, self.num_mutations_per_step, replacement=False)
262271
edit_idxs = edit_idxs.sort(dim=-1).values
263272

0 commit comments

Comments
 (0)