Skip to content

Commit

Permalink
Merge pull request #72 from OpenMOSS/cc_upd
Browse files (browse the repository at this point in the history)
re-implement crosscoders
  • Loading branch information
dest1n1s authored Jan 16, 2025
2 parents a5bc151 + 62c34c5 commit 24dc841
Show file tree
Hide file tree
Showing 4 changed files with 245 additions and 387 deletions.
10 changes: 8 additions & 2 deletions src/lm_saes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,14 @@ class BaseSAEConfig(BaseModelConfig):
So this class should not be used directly but only as a base config class for other SAE variants like SAEConfig, MixCoderConfig, CrossCoderConfig, etc.
"""

sae_type: Literal["sae", "crosscoder", "mixcoder"]
hook_point_in: str
hook_point_out: str = Field(default_factory=lambda validated_model: validated_model["hook_point_in"])
d_model: int
expansion_factor: int
use_decoder_bias: bool = True
use_glu_encoder: bool = False
- act_fn: str = "relu"
+ act_fn: Literal["relu", "jumprelu", "topk", "batchtopk"] = "relu"
jump_relu_threshold: float = 0.0
apply_decoder_bias_to_pre_encoder: bool = True
norm_activation: str = "dataset-wise"
Expand Down Expand Up @@ -94,10 +95,15 @@ def save_hyperparameters(self, sae_path: Path | str, remove_loading_info: bool =


class SAEConfig(BaseSAEConfig):
- pass
+ sae_type: Literal["sae", "crosscoder", "mixcoder"] = 'sae'


class CrossCoderConfig(BaseSAEConfig):
sae_type: Literal["sae", "crosscoder", "mixcoder"] = 'crosscoder'


class MixCoderConfig(BaseSAEConfig):
sae_type: Literal["sae", "crosscoder", "mixcoder"] = 'mixcoder'
d_single_modal: int
d_shared: int
n_modalities: int = 2
Expand Down
Loading

0 comments on commit 24dc841

Please sign in to comment.