fix(misc): fix calculate_activation_norm method
Frankstein73 authored and dest1n1s committed Jan 14, 2025
1 parent fb7be43 commit 1c98127
Showing 5 changed files with 33 additions and 16 deletions.
src/lm_saes/config.py (2 changes: 1 addition & 1 deletion)

@@ -57,7 +57,7 @@ class BaseSAEConfig(BaseModelConfig):
     act_fn: str = "relu"
     jump_relu_threshold: float = 0.0
     apply_decoder_bias_to_pre_encoder: bool = True
-    norm_activation: str = "token-wise"
+    norm_activation: str = "dataset-wise"
     sparsity_include_decoder_norm: bool = True
     top_k: int = 50
     sae_pretrained_name_or_path: Optional[str] = None
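The default flip has a downstream consequence spelled out by the diffs below: with "dataset-wise" normalization, the initializer needs either a precomputed activation_norm dict or an activation stream to measure one from, while "token-wise" needs neither. A sketch mirroring the updated unit test (initializer and sae_config stand in for the test fixtures):

    sae_config.norm_activation = "token-wise"
    sae = initializer.initialize_sae_from_config(sae_config)  # no stream or norm needed

    sae_config.norm_activation = "dataset-wise"
    sae = initializer.initialize_sae_from_config(
        sae_config, activation_norm={"in": 3.0, "out": 2.0}  # or activation_stream=...
    )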
src/lm_saes/initializer.py (4 changes: 3 additions & 1 deletion)

@@ -159,7 +159,9 @@ def initialize_sae_from_config(
             assert (
                 activation_stream is not None
             ), "Activation iterator must be provided for dataset-wise normalization"
-            activation_norm = calculate_activation_norm(activation_stream)
+            activation_norm = calculate_activation_norm(
+                activation_stream, [cfg.hook_point_in, cfg.hook_point_out]
+            )
             sae.set_dataset_average_activation_norm(activation_norm)

         if self.cfg.init_search:
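Since cfg.hook_point_in and cfg.hook_point_out may name the same hook (the updated test below sets them equal), the deduplication inside calculate_activation_norm keeps this call safe. A minimal sketch with a stand-in stream (hook-point name and shapes are illustrative):

    import torch
    from lm_saes.utils.misc import calculate_activation_norm

    # Stand-in stream of 8 batches; "resid" plays both hook_point_in and hook_point_out.
    stream = ({"resid": torch.randn(16, 64)} for _ in range(8))

    # ["resid", "resid"] collapses to a single hook point, so each batch is read once.
    norms = calculate_activation_norm(stream, ["resid", "resid"])
    print(norms)  # {"resid": <dataset-average per-token L2 norm>}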
src/lm_saes/utils/misc.py (17 changes: 12 additions & 5 deletions)

@@ -1,4 +1,5 @@
 import os
+import warnings
 from typing import Iterable

 import torch
@@ -121,17 +122,23 @@ def assert_tensor_consistency(tensor):


 def calculate_activation_norm(
-    activation_stream: Iterable[dict[str, torch.Tensor]], batch_num: int = 8
+    activation_stream: Iterable[dict[str, torch.Tensor]], hook_points: list[str], batch_num: int = 8
 ) -> dict[str, float]:
     activation_norm = {}
     stream_iter = iter(activation_stream)
+    hook_points = list(set(hook_points))
+    assert len(hook_points) > 0, "No hook points provided"
     while batch_num > 0:
-        batch = next(stream_iter)
-        for key, value in batch.items():
+        try:
+            batch = next(stream_iter)
+        except StopIteration:
+            warnings.warn(f"Activation stream ended prematurely. {batch_num} batches not processed.")
+            break
+        for key in hook_points:
             if key not in activation_norm:
-                activation_norm[key] = value.norm(p=2, dim=1)
+                activation_norm[key] = batch[key].norm(p=2, dim=1)
             else:
-                activation_norm[key] = torch.cat((activation_norm[key], value.norm(p=2, dim=1)), dim=0)
+                activation_norm[key] = torch.cat((activation_norm[key], batch[key].norm(p=2, dim=1)), dim=0)
         batch_num -= 1
     for key in activation_norm:
         activation_norm[key] = activation_norm[key].mean().item()
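The rewritten function can be exercised end to end. A quick sketch (hook-point name and shapes are illustrative) showing both the per-hook averaging and the new early-exit warning when the stream runs dry before batch_num batches:

    import warnings
    import torch
    from lm_saes.utils.misc import calculate_activation_norm

    # Three batches of unit activations: each row of torch.ones(2, 9) has
    # L2 norm sqrt(9) = 3, so the averaged norm should come out as 3.0.
    stream = [{"resid": torch.ones(2, 9)} for _ in range(3)]

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        norms = calculate_activation_norm(stream, ["resid"])

    print(norms)              # {'resid': 3.0}
    print(caught[0].message)  # the stream ended after 3 of the default 8 batches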
tests/unit/test_initializer.py (12 changes: 10 additions & 2 deletions)

@@ -34,6 +34,7 @@ def initializer_config() -> InitializerConfig:

 def test_initialize_sae_from_config(sae_config: SAEConfig, initializer_config: InitializerConfig):
     initializer = Initializer(initializer_config)
+    sae_config.norm_activation = "token-wise"
     sae = initializer.initialize_sae_from_config(sae_config)
     sae_config.norm_activation = "dataset-wise"
     sae = initializer.initialize_sae_from_config(sae_config, activation_norm={"in": 3.0, "out": 2.0})
@@ -54,12 +55,19 @@ def test_initialize_sae_from_config(sae_config: SAEConfig, initializer_config: InitializerConfig):
 def test_initialize_search(
     mocker: MockerFixture, sae_config: SAEConfig, initializer_config: InitializerConfig, generator: torch.Generator
 ):
+    def stream_generator():
+        # Create 20 batches of activations
+        for _ in range(20):
+            yield {
+                "in": torch.ones(4, sae_config.d_model),  # norm will be sqrt(16)
+                "out": torch.ones(4, sae_config.d_model) * 2,  # norm will be sqrt(16) * 2
+            }
+
     sae_config.hook_point_out = sae_config.hook_point_in
     initializer_config.init_search = True
     initializer_config.l1_coefficient = 0.0008
-    activation_stream_iter = mocker.Mock()
     batch = torch.randn((16, sae_config.d_model), generator=generator)
-    activation_stream_iter.__iter__ = lambda x: iter([{"in": batch, "out": batch}])
+    activation_stream_iter = stream_generator()
     initializer = Initializer(initializer_config)
     sae = initializer.initialize_sae_from_config(sae_config, activation_stream=activation_stream_iter)
     assert torch.allclose(sae.decoder_norm(), sae.decoder_norm().mean(), atol=1e-4, rtol=1e-5)
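Replacing the single-batch Mock with a real generator matters because the initializer now routes the stream through calculate_activation_norm, which consumes up to batch_num (default 8) batches before init_search runs. For intuition, what that function returns for this stream when asked for both hook points, assuming d_model = 16 (the sqrt(16) in the test comments suggests that is the fixture value):

    import torch
    from lm_saes.utils.misc import calculate_activation_norm

    d_model = 16  # assumed fixture value

    def stream_generator():
        for _ in range(20):
            yield {
                "in": torch.ones(4, d_model),       # each row has L2 norm sqrt(16) = 4
                "out": torch.ones(4, d_model) * 2,  # each row has L2 norm 8
            }

    norms = calculate_activation_norm(stream_generator(), ["in", "out"])
    assert norms == {"in": 4.0, "out": 8.0}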
tests/unit/test_misc.py (14 changes: 7 additions & 7 deletions)

@@ -24,7 +24,7 @@ def stream_generator():

     def test_basic_functionality(self, mock_activation_stream):
         """Test basic functionality with default batch_num."""
-        result = calculate_activation_norm(mock_activation_stream)
+        result = calculate_activation_norm(mock_activation_stream, ["layer1", "layer2"])

         assert isinstance(result, dict)
         assert "layer1" in result
@@ -38,7 +38,7 @@ def test_basic_functionality(self, mock_activation_stream):

     def test_custom_batch_num(self, mock_activation_stream):
         """Test with custom batch_num parameter."""
-        result = calculate_activation_norm(mock_activation_stream, batch_num=3)
+        result = calculate_activation_norm(mock_activation_stream, batch_num=3, hook_points=["layer1", "layer2"])

         # Should still give same results as we're averaging
         assert pytest.approx(result["layer1"], rel=1e-4) == 4.0
@@ -47,16 +47,16 @@ def test_custom_batch_num(self, mock_activation_stream):
     def test_empty_stream(self):
         """Test behavior with empty activation stream."""
         empty_stream = iter([])
-        with pytest.raises(StopIteration):
-            calculate_activation_norm(empty_stream)
+        with pytest.warns(UserWarning):
+            calculate_activation_norm(empty_stream, hook_points=[""])

     def test_single_batch(self):
         """Test with a single batch of activations."""

         def single_batch_stream():
             yield {"single": torch.ones(2, 4)}  # norm will be 2.0

-        result = calculate_activation_norm(single_batch_stream(), batch_num=1)
+        result = calculate_activation_norm(single_batch_stream(), hook_points=["single", "single"], batch_num=1)
         assert pytest.approx(result["single"], rel=1e-4) == 2.0

     def test_zero_tensors(self):
@@ -66,7 +66,7 @@ def zero_stream():
             for _ in range(10):
                 yield {"zeros": torch.zeros(2, 4)}

-        result = calculate_activation_norm(zero_stream())
+        result = calculate_activation_norm(zero_stream(), hook_points=["zeros"])
         assert result["zeros"] == 0.0

@@ -76,7 +76,7 @@ def mixed_stream():
             for i in range(10):
                 yield {"mixed": torch.tensor([[1.0, -2.0], [3.0, -4.0], [3.0, -2.0], [9.0, -4.0]]) * (i + 1)}

-        result = calculate_activation_norm(mixed_stream(), batch_num=10)
+        result = calculate_activation_norm(mixed_stream(), hook_points=["mixed"], batch_num=10)
         assert (
             pytest.approx(result["mixed"], rel=1e-4) == ((math.sqrt(5) + 5 + math.sqrt(13) + math.sqrt(97)) / 4) * 5.5
         )
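As a sanity check on that last expected value: the base tensor's row norms are sqrt(1+4) = sqrt(5), sqrt(9+16) = 5, sqrt(9+4) = sqrt(13), and sqrt(81+16) = sqrt(97); scaling the tensor by (i + 1) scales every row norm by (i + 1), and the factors 1..10 average to 5.5. A standalone check of the arithmetic:

    import torch

    rows = torch.tensor([[1.0, -2.0], [3.0, -4.0], [3.0, -2.0], [9.0, -4.0]])
    base = rows.norm(p=2, dim=1).mean().item()  # (sqrt(5) + 5 + sqrt(13) + sqrt(97)) / 4
    print(base * 5.5)  # ≈ 28.45, the value the assertion above approximates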
