fix(cached_acts): re-implement changes mistakenly removed in 71ff9f9
Hzfinfdu committed Feb 19, 2025
1 parent 59ce8ac commit 44bb55b
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/lm_saes/activation/processors/cached_activation.py
@@ -240,6 +240,15 @@ def process(self, data: None = None, **kwargs) -> Iterable[dict[str, Any]]:
        for k, v in activations.items():
            if k in self.hook_points:
                activations[k] = v.to(self.dtype)

        while activations["tokens"].ndim >= 3:
            def flatten(x: torch.Tensor | list[list[Any]]) -> torch.Tensor | list[Any]:
                if isinstance(x, torch.Tensor):
                    return x.flatten(start_dim=0, end_dim=1)
                else:
                    return [a for b in x for a in b]
            activations = {k: flatten(v) for k, v in activations.items()}

        yield activations  # Use pin_memory to load data on cpu, then transfer them to cuda in the main process, as advised in https://discuss.pytorch.org/t/dataloader-multiprocessing-with-dataset-returning-a-cuda-tensor/151022/2.
        # I wrote this utility function as I noticed it is used multiple times in this repo. Do we need to apply it elsewhere?
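The added loop repeatedly collapses the two leading dimensions of every cached field until the `tokens` tensor is at most 2-D (batch x context), handling tensors and nested lists (e.g. per-sequence metadata) uniformly. A standalone sketch of the same flattening logic, with hypothetical shapes and field names (`meta`, `blocks.0.hook_resid_post` are illustrative, not taken from the diff):

```python
from typing import Any

import torch


def flatten(x: "torch.Tensor | list[list[Any]]") -> "torch.Tensor | list[Any]":
    # Merge the first two dimensions of a tensor, or the first two
    # nesting levels of a list, mirroring the loop body in the patch.
    if isinstance(x, torch.Tensor):
        return x.flatten(start_dim=0, end_dim=1)
    return [a for b in x for a in b]


# Hypothetical cached batch: 2 shards x 4 sequences x 16 tokens.
activations = {
    "tokens": torch.zeros(2, 4, 16, dtype=torch.long),
    "blocks.0.hook_resid_post": torch.zeros(2, 4, 16, 768),
    "meta": [[{"shard": s, "seq": q} for q in range(4)] for s in range(2)],
}

# Keep collapsing leading dims until "tokens" is 2-D, as in the patch.
while activations["tokens"].ndim >= 3:
    activations = {k: flatten(v) for k, v in activations.items()}

print(activations["tokens"].shape)                    # torch.Size([8, 16])
print(activations["blocks.0.hook_resid_post"].shape)  # torch.Size([8, 16, 768])
print(len(activations["meta"]))                       # 8
```

Driving the loop off `activations["tokens"]` keeps all fields in lockstep: each pass merges the same pair of leading dimensions across every tensor and nested list in the dict.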
