From 44bb55b49ff8a1896519ce18bbf00140ca6ba3af Mon Sep 17 00:00:00 2001
From: Hzfinfdu
Date: Wed, 19 Feb 2025 20:25:49 +0800
Subject: [PATCH] fix(cached_acts): re-implement changes mistakenly removed in
 71ff9f9

---
 src/lm_saes/activation/processors/cached_activation.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/src/lm_saes/activation/processors/cached_activation.py b/src/lm_saes/activation/processors/cached_activation.py
index e1dba22..df32694 100644
--- a/src/lm_saes/activation/processors/cached_activation.py
+++ b/src/lm_saes/activation/processors/cached_activation.py
@@ -240,6 +240,15 @@ def process(self, data: None = None, **kwargs) -> Iterable[dict[str, Any]]:
             for k, v in activations.items():
                 if k in self.hook_points:
                     activations[k] = v.to(self.dtype)
+
+            while activations["tokens"].ndim >= 3:
+
+                def flatten(x: torch.Tensor | list[list[Any]]) -> torch.Tensor | list[Any]:
+                    if isinstance(x, torch.Tensor):
+                        return x.flatten(start_dim=0, end_dim=1)
+                    else:
+                        return [a for b in x for a in b]
+
+                activations = {k: flatten(v) for k, v in activations.items()}
             yield activations

 # Use pin_memory to load data on cpu, then transfer them to cuda in the main process, as advised in https://discuss.pytorch.org/t/dataloader-multiprocessing-with-dataset-returning-a-cuda-tensor/151022/2.
 # I wrote this utils function as I notice it is used multiple times in this repo. Do we need to apply it elsewhere?