Feature Request: Support custom hf_model in config / loaders #436

hijohnnylin opened this issue Feb 25, 2025 · 0 comments

Problem:

  1. Some SAEs are trained on models distilled from other models and are not in the TransformerLens model list. For example, the DeepSeek-distilled Llama 8B model has Llama 3.1 8B as its "base" model, but we have to specify a Hugging Face path for the distilled model.
  2. These models must be loaded in a special way in TransformerLens, something like:
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from transformer_lens import HookedTransformer

        # Load the distilled checkpoint's weights and tokenizer directly from Hugging Face.
        hf_model = None
        hf_tokenizer = None
        if custom_hf_model_id is not None:
            logger.info("Loading custom HF model: %s", custom_hf_model_id)
            hf_model = AutoModelForCausalLM.from_pretrained(custom_hf_model_id)
            hf_tokenizer = AutoTokenizer.from_pretrained(custom_hf_model_id)

        # Pass the custom weights/tokenizer into TransformerLens; `args`, `config`,
        # `STR_TO_DTYPE`, and `device_count` come from the surrounding application code.
        model = HookedTransformer.from_pretrained_no_processing(
            transformerlens_model_id,
            device=args.device,
            dtype=STR_TO_DTYPE[config.MODEL_DTYPE],
            n_devices=device_count,
            hf_model=hf_model,
            tokenizer=hf_tokenizer,
            **config.MODEL_KWARGS,
        )
  3. However, the user of the library does not know this when loading the SAE via SAELens, so they will just load the SAE with the wrong base model and probably see bad results (illustrated in the sketch just below).
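
For illustration, this is roughly what the failure mode looks like with the usual SAELens loading pattern. It is a minimal sketch; the release and sae_id strings are placeholders, not confirmed identifiers:

    from sae_lens import SAE
    from transformer_lens import HookedTransformer

    # Load the SAE; its config reports the base model name ("meta-llama/Llama-3.1-8B").
    # The release / sae_id strings below are illustrative placeholders.
    sae, cfg_dict, _ = SAE.from_pretrained(
        release="llama_scope_r1_distill",
        sae_id="800M-Slimpajama-0-OpenR1-Math-220k/L15R",
        device="cuda",
    )

    # The natural next step loads the vanilla Llama 3.1 8B weights,
    # not the distilled checkpoint the SAE was actually trained on.
    model = HookedTransformer.from_pretrained_no_processing(cfg_dict["model_name"])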

Proposed fix:

  • An optional SAELens config property, hf_model, indicating which HF model the SAE was trained on. If it is specified, we either force the custom hf_model when loading with a pretrained loader, or throw an error if it doesn't match (and give some useful example code like the above). A rough sketch follows this list.
  • Possibly an example notebook.
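
A rough sketch of what the check could look like; the hf_model field and the helper below are assumptions for illustration, not existing SAELens API:

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformer_lens import HookedTransformer

    def load_model_for_sae(sae_cfg, device: str = "cuda") -> HookedTransformer:
        """Hypothetical helper: honor an optional `hf_model` field on the SAE config."""
        hf_model = None
        hf_tokenizer = None
        hf_model_id = getattr(sae_cfg, "hf_model", None)  # proposed optional config property
        if hf_model_id is not None:
            # The SAE was trained on a custom HF checkpoint (e.g. a distilled model),
            # so load those weights and hand them to TransformerLens explicitly.
            hf_model = AutoModelForCausalLM.from_pretrained(hf_model_id)
            hf_tokenizer = AutoTokenizer.from_pretrained(hf_model_id)
        return HookedTransformer.from_pretrained_no_processing(
            sae_cfg.model_name,
            device=device,
            hf_model=hf_model,
            tokenizer=hf_tokenizer,
        )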

An example loader that we would like to support is below. Note that it specifies the base Llama model but has nowhere to specify or enforce the custom hf_model.

    def get_llama_scope_r1_distill_config(
        repo_id: str,
        folder_name: str,
        options: SAEConfigLoadOptions,  # noqa: ARG001
    ) -> Dict[str, Any]:
        # Future Llama Scope series SAE by OpenMoss group use this config.
        # repo_id: [
        #     fnlp/Llama-Scope-R1-Distill
        # ]
        # folder_name: [
        #     800M-Slimpajama-0-OpenR1-Math-220k/L{layer}R,
        #     400M-Slimpajama-400M-OpenR1-Math-220k/L{layer}R,
        #     0-Slimpajama-800M-OpenR1-Math-220k/L{layer}R,
        # ]
        config_path = folder_name + "/config.json"
        config_path = hf_hub_download(repo_id, config_path)
        with open(config_path) as f:
            huggingface_cfg_dict = json.load(f)
        # Model specific parameters
        model_name, d_in = "meta-llama/Llama-3.1-8B", huggingface_cfg_dict["d_model"]
        return {
            "architecture": "jumprelu",
            "d_in": d_in,
            "d_sae": d_in * huggingface_cfg_dict["expansion_factor"],
            "dtype": "float32",
            "model_name": model_name,
            "hook_name": huggingface_cfg_dict["hook_point_in"],
            "hook_layer": int(huggingface_cfg_dict["hook_point_in"].split(".")[1]),
            "hook_head_index": None,
            "activation_fn_str": "relu",
            "finetuning_scaling_factor": False,
            "sae_lens_training_version": None,
            "prepend_bos": True,
            "dataset_path": "cerebras/SlimPajama-627B",
            "context_size": 1024,
            "dataset_trust_remote_code": True,
            "apply_b_dec_to_input": False,
            "normalize_activations": "expected_average_only_in",
        }
