fix(topk activation): add keepdim=True to enable broadcasting; make dtype consistent without hardcode (#73)

Hzfinfdu authored Jan 17, 2025
1 parent 24dc841 commit 4faf9f6
Showing 4 changed files with 44 additions and 28 deletions.
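For context, a minimal standalone sketch of the per-sample topk branch this commit fixes, in plain PyTorch; the two-row input and current_k = 2 mirror the updated unit test further down and are illustrative only:

import torch

# Two activation rows with 6 features each; keep the top current_k = 2 per row.
x = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
                  [5.0, 6.0, 1.0, 2.0, 3.0, 4.0]])
k = x.shape[-1] - 2 + 1  # k-th largest rewritten as k-th smallest, as in sae.py

# Without keepdim, kthvalue over dim=-1 returns shape (2,); x.ge(threshold) would then
# broadcast (2, 6) against (2,) and raise a size mismatch (or compare along the wrong
# axis whenever the batch size happens to equal the feature count). keepdim=True keeps
# the threshold at shape (2, 1), so each row is compared against its own k-th value.
threshold, _ = torch.kthvalue(x, k=k, dim=-1, keepdim=True)

# ge() yields a bool mask; .to(x.dtype) keeps it in the input dtype (bf16/fp16 in
# mixed-precision runs) instead of hard-coding float32 via .float().
mask = x.ge(threshold).to(x.dtype)
print(mask)  # [[0., 0., 0., 0., 1., 1.], [1., 1., 0., 0., 0., 0.]]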
10 changes: 5 additions & 5 deletions src/lm_saes/config.py
@@ -95,15 +95,15 @@ def save_hyperparameters(self, sae_path: Path | str, remove_loading_info: bool =


 class SAEConfig(BaseSAEConfig):
-    sae_type: Literal["sae", "crosscoder", "mixcoder"] = 'sae'
+    sae_type: Literal["sae", "crosscoder", "mixcoder"] = "sae"


 class CrossCoderConfig(BaseSAEConfig):
-    sae_type: Literal["sae", "crosscoder", "mixcoder"] = 'crosscoder'
+    sae_type: Literal["sae", "crosscoder", "mixcoder"] = "crosscoder"


 class MixCoderConfig(BaseSAEConfig):
-    sae_type: Literal["sae", "crosscoder", "mixcoder"] = 'mixcoder'
+    sae_type: Literal["sae", "crosscoder", "mixcoder"] = "mixcoder"
     d_single_modal: int
     d_shared: int
     n_modalities: int = 2
17 changes: 5 additions & 12 deletions src/lm_saes/crosscoder.py
@@ -20,24 +20,18 @@ class CrossCoder(SparseAutoEncoder):
     def __init__(self, cfg: BaseSAEConfig):
         super(CrossCoder, self).__init__(cfg)

-    def _decoder_norm(
-        self,
-        decoder: torch.nn.Linear,
-        keepdim: bool = False,
-        local_only=True,
-        aggregate="none"
-    ):
+    def _decoder_norm(self, decoder: torch.nn.Linear, keepdim: bool = False, local_only=True, aggregate="none"):
         decoder_norm = super()._decoder_norm(
             decoder=decoder,
             keepdim=keepdim,
         )
         if not local_only:
             decoder_norm = all_reduce_tensor(
-                decoder_norm,
+                decoder_norm,
                 aggregate=aggregate,
             )
         return decoder_norm

     @overload
     def encode(
         self,
@@ -110,7 +104,7 @@ def encode(

         hidden_pre = all_reduce_tensor(hidden_pre, aggregate="sum")
         hidden_pre = self.hook_hidden_pre(hidden_pre)

         if self.cfg.sparsity_include_decoder_norm:
             true_feature_acts = hidden_pre * self._decoder_norm(
                 decoder=self.decoder,
@@ -127,7 +121,7 @@ def encode(
         if return_hidden_pre:
             return feature_acts, hidden_pre
         return feature_acts

     @overload
     def compute_loss(
         self,
@@ -229,4 +223,3 @@ def initialize_with_same_weight_across_layers(self):
         self.encoder.bias.data = get_tensor_from_specific_rank(self.encoder.bias.data.clone(), src=0)
         self.decoder.weight.data = get_tensor_from_specific_rank(self.decoder.weight.data.clone(), src=0)
         self.decoder.bias.data = get_tensor_from_specific_rank(self.decoder.bias.data.clone(), src=0)
-
24 changes: 15 additions & 9 deletions src/lm_saes/sae.py
@@ -92,31 +92,37 @@ def _decoder_norm(self, decoder: torch.nn.Linear, keepdim: bool = False):
         return decoder_norm

     def activation_function_factory(self, cfg: BaseSAEConfig) -> Callable[[torch.Tensor], torch.Tensor]: # type: ignore
-        assert cfg.act_fn.lower() in ["relu", "topk", "jumprelu", "batchtopk"], f"Not implemented activation function {cfg.act_fn}"
+        assert cfg.act_fn.lower() in [
+            "relu",
+            "topk",
+            "jumprelu",
+            "batchtopk",
+        ], f"Not implemented activation function {cfg.act_fn}"
         if cfg.act_fn.lower() == "relu":
-            return lambda x: x.gt(0).float()
+            return lambda x: x.gt(0).to(x.dtype)
         elif cfg.act_fn.lower() == "jumprelu":
-            return lambda x: x.gt(cfg.jump_relu_threshold).float()
+            return lambda x: x.gt(cfg.jump_relu_threshold).to(x.dtype)
         elif cfg.act_fn.lower() == "topk":

             def topk_activation(x: torch.Tensor):
                 x = torch.clamp(x, min=0.0)
                 k = x.shape[-1] - self.current_k + 1
-                k_th_value, _ = torch.kthvalue(x, k=k, dim=-1)
-                return x.ge(k_th_value).float()
+                k_th_value, _ = torch.kthvalue(x, k=k, dim=-1, keepdim=True)
+                return x.ge(k_th_value).to(x.dtype)

             return topk_activation

         elif cfg.act_fn.lower() == "batchtopk":

             def topk_activation(x: torch.Tensor):
                 assert x.dim() == 2
                 batch_size = x.size(0)

                 x = torch.clamp(x, min=0.0)
                 k = x.numel() - self.current_k * batch_size + 1
                 k_th_value, _ = torch.kthvalue(x.flatten(), k=k, dim=-1)
-                return x.ge(k_th_value).float()
+                return x.ge(k_th_value).to(x.dtype)

             return topk_activation

     def compute_norm_factor(self, x: torch.Tensor, hook_point: str) -> torch.Tensor: # type: ignore
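The batchtopk branch above keeps kthvalue without keepdim on purpose: it runs on the flattened batch, where kthvalue returns a 0-d tensor that broadcasts against any shape, so the only change it needed was the dtype cast (.to(x.dtype) instead of .float()). A small sketch of that behaviour with the same illustrative input and current_k = 2:

import torch

# One global threshold over the whole flattened batch, rather than one per row.
x = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
                  [5.0, 6.0, 1.0, 2.0, 3.0, 4.0]])
current_k = 2
k = x.numel() - current_k * x.size(0) + 1  # 12 - 2 * 2 + 1 = 9

# kthvalue on a 1-D tensor returns a 0-d values tensor, so no keepdim is needed here.
threshold, _ = torch.kthvalue(x.flatten(), k=k)  # 9th smallest of the 12 values = 5.0
mask = x.ge(threshold).to(x.dtype)
print(mask)  # [[0., 0., 0., 0., 1., 1.], [1., 1., 0., 0., 0., 0.]]

Because the threshold is shared, each sample keeps current_k active features only on average; a row with unusually large activations can claim more than its share.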
21 changes: 19 additions & 2 deletions tests/unit/test_sae.py
@@ -88,11 +88,28 @@ def set_encoder_norm(norm: float):

 def test_sae_activate_fn(sae_config: SAEConfig, sae: SparseAutoEncoder):
     sae.current_k = 2
+    print(
+        sae.activation_function(
+            torch.tensor(
+                [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]],
+                device=sae_config.device,
+                dtype=sae_config.dtype,
+            )
+        )
+    )
     assert torch.allclose(
         sae.activation_function(
-            torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]], device=sae_config.device, dtype=sae_config.dtype)
+            torch.tensor(
+                [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [5.0, 6.0, 1.0, 2.0, 3.0, 4.0]],
+                device=sae_config.device,
+                dtype=sae_config.dtype,
+            )
         ).to(sae_config.device, sae_config.dtype),
-        torch.tensor([[0.0, 0.0, 0.0, 0.0, 1.0, 1.0]], device=sae_config.device, dtype=sae_config.dtype),
+        torch.tensor(
+            [[0.0, 0.0, 0.0, 0.0, 1.0, 1.0], [1.0, 1.0, 0.0, 0.0, 0.0, 0.0]],
+            device=sae_config.device,
+            dtype=sae_config.dtype,
+        ),
         atol=1e-4,
         rtol=1e-5,
     )
