test(sae): fix test SAE fixture; specify test weight for exact-checking forward computation
dest1n1s authored and Frankstein73 committed Jan 20, 2025
1 parent 1cbcd25 commit 4eb40ba
Showing 2 changed files with 61 additions and 43 deletions.
src/lm_saes/sae.py (40 changes: 34 additions & 6 deletions)
@@ -92,7 +92,7 @@ def _decoder_norm(self, decoder: torch.nn.Linear, keepdim: bool = False):
decoder_norm = decoder_norm.redistribute(placements=[Replicate()], async_op=True).to_local()
return decoder_norm

def activation_function_factory(self, cfg: BaseSAEConfig) -> Callable[[torch.Tensor], torch.Tensor]: # type: ignore
def activation_function_factory(self, cfg: BaseSAEConfig) -> Callable[[torch.Tensor], torch.Tensor]:
assert cfg.act_fn.lower() in [
"relu",
"topk",
@@ -127,7 +127,9 @@ def topk_activation(x: torch.Tensor):

return topk_activation

def compute_norm_factor(self, x: torch.Tensor, hook_point: str) -> torch.Tensor: # type: ignore
raise ValueError(f"Not implemented activation function {cfg.act_fn}")

def compute_norm_factor(self, x: torch.Tensor, hook_point: str) -> torch.Tensor:
"""Compute the normalization factor for the activation vectors.
This should be called during forward pass.
There are four modes for norm_activation:
@@ -157,6 +159,7 @@ def compute_norm_factor(self, x: torch.Tensor, hook_point: str) -> torch.Tensor:
)
if self.cfg.norm_activation == "inference":
return torch.tensor(1.0, device=x.device, dtype=x.dtype)
raise ValueError(f"Not implemented norm_activation {self.cfg.norm_activation}")

@torch.no_grad()
def _set_decoder_to_fixed_norm(self, decoder: torch.nn.Linear, value: float, force_exact: bool):
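The compute_norm_factor docstring above mentions four norm_activation modes, most of which fall outside the visible hunks. As a rough illustration only, the sketch below gives one plausible reading of the "dataset-wise" and "inference" modes, consistent with the sqrt(d_model)-scaled dataset norms used in the tests further down. The helper name and the exact scaling are assumptions, not this repository's implementation.

```python
import math
import torch

def norm_factor_sketch(x: torch.Tensor, avg_norm: float, d_model: int, mode: str) -> torch.Tensor:
    # Hypothetical helper, not lm_saes code: illustrates the two modes visible in this diff.
    if mode == "dataset-wise":
        # Rescale activations so their average L2 norm becomes sqrt(d_model) (assumed convention).
        return torch.tensor(math.sqrt(d_model) / avg_norm, device=x.device, dtype=x.dtype)
    if mode == "inference":
        # Normalization already folded into the weights; no runtime scaling.
        return torch.tensor(1.0, device=x.device, dtype=x.dtype)
    raise ValueError(f"Not implemented norm_activation {mode}")
```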
@@ -371,26 +374,51 @@ def encode(
Float[torch.Tensor, "batch seq_len d_sae"],
],
],
]: # should be overridden by subclasses
]:
"""Encode input tensor through the sparse autoencoder.
Args:
x: Input tensor of shape (batch, d_model) or (batch, seq_len, d_model)
return_hidden_pre: If True, also return the pre-activation hidden states
Returns:
If return_hidden_pre is False:
Feature activations tensor of shape (batch, d_sae) or (batch, seq_len, d_sae)
If return_hidden_pre is True:
Tuple of (feature_acts, hidden_pre) where both have shape (batch, d_sae) or (batch, seq_len, d_sae)
"""
# Apply input normalization based on config
input_norm_factor = self.compute_norm_factor(x, hook_point=self.cfg.hook_point_in)
x = x * input_norm_factor

# Optionally subtract decoder bias before encoding
if self.cfg.use_decoder_bias and self.cfg.apply_decoder_bias_to_pre_encoder:
# We need to convert decoder bias to a tensor before subtracting
bias = self.decoder.bias.to_local() if isinstance(self.decoder.bias, DTensor) else self.decoder.bias
x = x - bias

# Pass through encoder
hidden_pre = self.encoder(x)
# Apply GLU if configured
if self.cfg.use_glu_encoder:
hidden_pre_glu = torch.sigmoid(self.encoder_glu(x))
hidden_pre = hidden_pre * hidden_pre_glu

hidden_pre = self.hook_hidden_pre(hidden_pre)

# Scale feature activations by decoder norm if configured
if self.cfg.sparsity_include_decoder_norm:
true_feature_acts = hidden_pre * self._decoder_norm(decoder=self.decoder)
sparsity_scores = hidden_pre * self._decoder_norm(decoder=self.decoder)
else:
true_feature_acts = hidden_pre
activation_mask = self.activation_function(true_feature_acts)
sparsity_scores = hidden_pre

# Apply activation function. The activation function here differs from a common activation function,
# since it computes a scaling of the input tensor, which is, suppose the common activation function
# is $f(x)$, then here it computes $f(x) / x$. For simple ReLU case, it computes a mask of 1s and 0s.
activation_mask = self.activation_function(sparsity_scores)
feature_acts = hidden_pre * activation_mask
feature_acts = self.hook_feature_acts(feature_acts)

if return_hidden_pre:
return feature_acts, hidden_pre
return feature_acts
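The new comment in encode is the behavioural core of this hunk: activation_function returns f(x)/x rather than f(x), i.e. a per-feature mask or scaling. The mask is computed from sparsity_scores (hidden_pre optionally scaled by the decoder column norms), while the surviving activations keep their raw hidden_pre values. A self-contained sketch of that idea for the ReLU and TopK cases, using invented helper names and toy numbers:

```python
import torch

def relu_mask(scores: torch.Tensor) -> torch.Tensor:
    # f(x)/x for ReLU: 1 where the score is positive, 0 elsewhere.
    return (scores > 0).to(scores.dtype)

def topk_mask(scores: torch.Tensor, k: int) -> torch.Tensor:
    # f(x)/x for TopK: 1 on the k largest scores per row, 0 elsewhere.
    top_indices = scores.topk(k, dim=-1).indices
    return torch.zeros_like(scores).scatter(-1, top_indices, 1.0)

hidden_pre = torch.tensor([[0.5, -1.0, 2.0, 3.0]])
decoder_norm = torch.tensor([1.0, 2.0, 0.5, 1.0])    # per-feature decoder column norms
sparsity_scores = hidden_pre * decoder_norm          # only used to decide which features fire
feature_acts = hidden_pre * topk_mask(sparsity_scores, k=2)
print(feature_acts)  # tensor([[0., 0., 2., 3.]]); the raw hidden_pre values survive, not the scores
```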
tests/unit/test_sae.py (64 changes: 27 additions & 37 deletions)
@@ -32,23 +32,30 @@ def generator(sae_config: SAEConfig) -> torch.Generator:
@pytest.fixture
def sae(sae_config: SAEConfig, generator: torch.Generator) -> SparseAutoEncoder:
sae = SparseAutoEncoder(sae_config)
sae.encoder.weight.data = torch.randn(
sae_config.d_sae, sae_config.d_model, generator=generator, device=sae_config.device, dtype=sae_config.dtype
sae.encoder.weight.data = torch.tensor(
[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],
requires_grad=True,
dtype=sae_config.dtype,
device=sae_config.device,
)
sae.decoder.weight.data = torch.randn(
sae_config.d_model, sae_config.d_sae, generator=generator, device=sae_config.device, dtype=sae_config.dtype
sae.encoder.bias.data = torch.tensor(
[3.0, 2.0, 3.0, 4.0],
requires_grad=True,
dtype=sae_config.dtype,
device=sae_config.device,
)
sae.decoder.weight.data = torch.tensor(
[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
requires_grad=True,
dtype=sae_config.dtype,
device=sae_config.device,
)
sae.decoder.bias.data = torch.tensor(
[1.0, 2.0],
requires_grad=True,
dtype=sae_config.dtype,
device=sae_config.device,
)
if sae_config.use_decoder_bias:
sae.decoder.bias.data = torch.randn(
sae_config.d_model, generator=generator, device=sae_config.device, dtype=sae_config.dtype
)
if sae_config.use_glu_encoder:
sae.encoder_glu.weight.data = torch.randn(
sae_config.d_sae, sae_config.d_model, generator=generator, device=sae_config.device, dtype=sae_config.dtype
)
sae.encoder_glu.bias.data = torch.randn(
sae_config.d_sae, generator=generator, device=sae_config.device, dtype=sae_config.dtype
)
return sae
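With the fixture weights pinned to small integers, intermediate values of the forward pass can be checked by hand, which is what the exact-checking test below relies on. For example, ignoring the input normalization and the pre-encoder decoder-bias subtraction that the real encode may also apply, the encoder pre-activation for a single input works out as:

```python
import torch

W_enc = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])  # (d_sae=4, d_model=2)
b_enc = torch.tensor([3.0, 2.0, 3.0, 4.0])
x = torch.tensor([[1.0, 2.0]])                                          # one activation vector

hidden_pre = x @ W_enc.T + b_enc
print(hidden_pre)  # tensor([[ 8., 13., 20., 27.]])
```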


@@ -196,27 +203,8 @@ def test_get_full_state_dict(sae_config: SAEConfig, sae: SparseAutoEncoder):

def test_standardize_parameters_of_dataset_norm(sae_config: SAEConfig, sae: SparseAutoEncoder):
sae_config.norm_activation = "dataset-wise"
sae.encoder.bias.data = torch.tensor(
[[1.0, 2.0]],
requires_grad=True,
dtype=sae_config.dtype,
device=sae_config.device,
)
encoder_bias_data = sae.encoder.bias.data.clone()
sae.decoder.weight.data = torch.tensor(
[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
requires_grad=True,
dtype=sae_config.dtype,
device=sae_config.device,
)
decoder_weight_data = sae.decoder.weight.data.clone()
if sae_config.use_decoder_bias:
sae.decoder.bias.data = torch.tensor(
[[1.0, 2.0, 3.0, 4.0]],
requires_grad=True,
dtype=sae_config.dtype,
device=sae_config.device,
)
decoder_bias_data = sae.decoder.bias.data.clone()
sae.standardize_parameters_of_dataset_norm({"in": 3.0, "out": 2.0})
assert sae.cfg.norm_activation == "inference"
@@ -237,6 +225,8 @@ def test_standardize_parameters_of_dataset_norm(sae_config: SAEConfig, sae: SparseAutoEncoder):


def test_forward(sae_config: SAEConfig, sae: SparseAutoEncoder):
sae.set_dataset_average_activation_norm({"in": 3.0, "out": 2.0})
output = sae.forward(torch.tensor([[1.0, 2.0]], device=sae_config.device, dtype=sae_config.dtype))
assert output.shape == (1, 2)
sae.set_dataset_average_activation_norm(
{"in": 2.0 * math.sqrt(sae_config.d_model), "out": 1.0 * math.sqrt(sae_config.d_model)}
)
output = sae.forward(torch.tensor([[4.0, 4.0]], device=sae_config.device, dtype=sae_config.dtype))
assert torch.allclose(output, torch.tensor([[69.0, 146.0]], device=sae_config.device, dtype=sae_config.dtype))
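test_standardize_parameters_of_dataset_norm expects the dataset-level norms to be folded into the parameters, after which cfg.norm_activation reads "inference" and compute_norm_factor returns 1.0. A hedged sketch of how such a fold could look, under the same sqrt(d_model)/average-norm assumption as above; the actual scaling inside SparseAutoEncoder.standardize_parameters_of_dataset_norm may differ:

```python
import math
import torch

@torch.no_grad()
def fold_dataset_norm_sketch(encoder: torch.nn.Linear, decoder: torch.nn.Linear,
                             norm_in: float, norm_out: float, d_model: int) -> None:
    # Hypothetical: absorb the input scaling sqrt(d_model)/norm_in into the encoder weight,
    # and undo the output scaling on the decoder side, so raw activations can be fed directly.
    encoder.weight.mul_(math.sqrt(d_model) / norm_in)
    decoder.weight.mul_(norm_out / math.sqrt(d_model))
    decoder.bias.mul_(norm_out / math.sqrt(d_model))
```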
