|
1 | 1 | import math
|
| 2 | +import pprint |
2 | 3 | import warnings
|
3 | 4 | from typing import Callable, Optional
|
4 | 5 |
|
@@ -139,7 +140,7 @@ def step(self) -> None:
|
139 | 140 | activations, trunk_outputs = self._get_latent_variables(generation_inputs)
|
140 | 141 |
|
141 | 142 | delta = torch.nn.Parameter(torch.zeros_like(activations))
|
142 |
| - optimizer = torch.optim.Adam([delta], lr=self.guidance_step_size) |
| 143 | + optimizer = torch.optim.Adam([delta], lr=self.guidance_step_size, betas=(0.09, 0.0999)) |
143 | 144 | metrics = {"step": self._step_count}
|
144 | 145 |
|
145 | 146 | # get initial solution before guidance
|
@@ -207,13 +208,15 @@ def step(self) -> None:
|
207 | 208 | grad_norm = feature_grad.norm(dim=(-2, -1), keepdim=True)
|
208 | 209 | metrics.update(
|
209 | 210 | {
|
| 211 | + "act_obj_val": tgt_obj_vals.mean().item(), |
210 | 212 | "masked_design_loss": design_loss.item(),
|
211 | 213 | "masked_design_loss_grad_norm": grad_norm.mean().item(),
|
212 | 214 | "masked_token_loss": kl_div.item(),
|
213 | 215 | "masked_obj_loss": obj_loss.item(),
|
214 | 216 | "token_entropy": entropy.item(),
|
215 | 217 | }
|
216 | 218 | )
|
| 219 | + pprint.pp(metrics) |
217 | 220 |
|
218 | 221 | self._step_count += 1
|
219 | 222 |
|
@@ -257,7 +260,13 @@ def coord_score(tok_embeddings):
|
257 | 260 | null_embedding,
|
258 | 261 | is_excluded=~pos_is_feasible,
|
259 | 262 | )
|
260 |
| - position_probs = (position_scores * self.feature_attr_temp).softmax(-1) |
| 263 | + denom = torch.where(position_scores > float("-inf"), position_scores, 0.0).abs().sum(-1, keepdim=True) |
| 264 | + position_scores = position_scores / (denom + 1e-6) |
| 265 | + |
| 266 | + position_probs = (position_scores / self.feature_attr_temp).softmax(-1) |
| 267 | + hand_tuned_entropy = torch.distributions.Categorical(probs=position_probs).entropy().median() |
| 268 | + print(f"[INFO][LaMBO-2]: Hand-tuned entropy = {hand_tuned_entropy}") |
| 269 | + |
261 | 270 | edit_idxs = torch.multinomial(position_probs, self.num_mutations_per_step, replacement=False)
|
262 | 271 | edit_idxs = edit_idxs.sort(dim=-1).values
|
263 | 272 |
|
|
0 commit comments