[Question] More information about the training process: hyperparameters and token number #396

Open
TheMrguiller opened this issue Dec 3, 2024 · 6 comments


@TheMrguiller

Hi,

I am new to mechanistic interpretability and am working on a toy project with GPT-2 using some examples. However, it seems I might be missing something, particularly regarding the number of tokens and the hyperparameters. Let me explain with a graph:
[image: training metrics graph]
The hyperparameters that I am using are the following:

cfg = LanguageModelSAERunnerConfig(
    
    model_name=args.model_name,  # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)
    hook_name=f'blocks.{args.hook_layer}.hook_resid_mid',  # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)
    hook_layer=args.hook_layer,  # Only one layer in the model.
    d_in=768,  # the width of the residual stream (the SAE input dimension).
    dataset_path=args.dataset_path,  # a text dataset hosted on Hugging Face (untokenized, see is_dataset_tokenized below).
    is_dataset_tokenized=False,
    streaming=True,  # we could pre-download the token dataset if it was small.
    # SAE Parameters
    mse_loss_normalization=None,  # We won't normalize the mse loss,
    expansion_factor=64,  # the width of the SAE. Larger will result in better stats but slower training.
    b_dec_init_method="geometric_median",  # The geometric median can be used to initialize the decoder weights.
    apply_b_dec_to_input=False,  # We won't apply the decoder weights to the input.
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=True,
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
    normalize_activations="expected_average_only_in",
    # Training Parameters
    lr=5e-5,  # the lower the better; we'll go fairly high to speed up the tutorial.
    adam_beta1=0.9,  # adam params (default, but once upon a time we experimented with these.)
    adam_beta2=0.999,
    lr_scheduler_name="constant",  # constant learning rate with warmup. Could be better schedules out there.
    lr_warm_up_steps=total_training_steps // 20,  # this can help avoid too many dead features initially.
    lr_decay_steps=total_training_steps // 5,  # this will help us avoid overfitting.
    l1_coefficient=1.7e-3,  # will control how sparse the feature activations are
    l1_warm_up_steps=total_training_steps // 20,  # this can help avoid too many dead features initially.
    lp_norm=1,  # the L1 penalty (and not a Lp for p < 1)
    lr_end=5e-6,  # we'll lower the learning rate at the end of training.
    train_batch_size_tokens=4096,
    context_size=1024,  # will control the length of the prompts we feed to the model. Larger is better but slower, so for the tutorial we'll use a short one.
    # Activation Store Parameters
    n_batches_in_buffer=128,  # controls how many activations we store / shuffle.
    training_tokens=total_training_tokens,  # 100 million tokens is quite a few, but we want to see good stats. Get a coffee, come back.
    store_batch_size_prompts=16,
    # Resampling protocol
    use_ghost_grads=False,  # we don't use ghost grads anymore.
    feature_sampling_window=1000,  # this controls our reporting of feature sparsity stats
    dead_feature_window=1000,  # would affect resampling or ghost grads if we were using it.
    dead_feature_threshold=1e-4,  # would affect resampling or ghost grads if we were using it.
    # WANDB
    log_to_wandb=True,  # always use wandb unless you are just testing code.
    wandb_project=args.wandb_project,
    run_name=args.experiment_name,
    wandb_log_frequency=args.wandb_log_frequency,
    eval_every_n_wandb_logs=args.eval_every_n_wandb_logs,
    # Misc
    device=device,
    seed=args.seed,
    n_checkpoints=args.n_checkpoints,
    checkpoint_path=args.checkpoint_path,
    dtype="float32",
)
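
For reference, the config above assumes total_training_tokens and total_training_steps are defined beforehand; a minimal sketch of the usual derivation (the token budget here is just an illustrative value):

total_training_tokens = 100_000_000   # illustrative budget, not necessarily the value used in this run
train_batch_size_tokens = 4096        # should match train_batch_size_tokens in the config above
total_training_steps = total_training_tokens // train_batch_size_tokens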

It seems I might be misunderstanding some aspects. Is the "max context" the maximum possible length of a single sentence, or something else? Should the "max tokens" refer to the total number of tokens in the entire dataset? How exactly does this affect the training process?

I am observing the same behavior across all my training experiments, regardless of the layer, which has left me quite confused.

Could someone guide me on this? I’ve been reading blogs and papers, but they often assume prior knowledge of these concepts without explaining them in detail.

Best regards,

@chanind
Collaborator

chanind commented Dec 3, 2024

Hi,

What's the issue in the graph you'd like to point out? If you want to plot against tokens on the x-axis, you should be able to change that in wandb.

The context_size is the length of the sequences passed to GPT-2. Batches of activations look like B x C x D, where B is the batch dim, C is the context length, and D is the hidden dimension of the model (assuming you're training on the residual stream). Generally, using the same context_size that the model was trained with is a good idea, but that's not always possible for some of the newer models with absolutely massive context sizes. Using a longer context size will mean you use more memory for a given batch size.

For instance, I believe GPT2 was trained with a context size of 1024 tokens. This means each sequence it was trained on had length 1024. If you use a shorter context size for your SAE, then you should make sure you don't use your SAE on a sequence longer than the sequence the SAE was trained on since you'll likely get bad results.

For instance, if you trained your SAE on context length 2 (just for demonstration purposes), then you could run the SAE on the text "Hi", but if you tried to run the SAE on the final token of "Hi there my name is", you'd likely get bad results since "is" is token position 4, but the SAE was trained with context size 2.
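
To make the B x C x D shape concrete, here's a minimal sketch with TransformerLens (the layer index and prompt are arbitrary):

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
tokens = model.to_tokens("Hi there my name is")  # shape: [batch=1, context]
_, cache = model.run_with_cache(tokens)

# residual-stream activations at one layer: [B, C, D] = [1, context, 768]
acts = cache["blocks.6.hook_resid_mid"]
print(acts.shape)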

The "training_tokens" parameter is the total number of tokens the SAE will train on. Usually, the datasets used for training SAEs are the same as what you would train a normal LLM on, and thus have absolutely crazy numbers of tokens. Realistically, you probably won't need to train on more than 1 - 10 Billion tokens for a real SAE, and you can probably get away with 500M or less.

Not sure if this is what you're asking, but happy to answer if you have further questions!

@TheMrguiller
Author

Thank you for your insights. I will try to explain myself and use this opportunity to learn more. Here are my questions:

  • My problem seems to be that my SAE reaches a plateau after seeing only approximately 600k tokens, which is really strange to me. It is true that the dataset is not that diverse, as I am doing a study on an AI Safety related corpus.
  • In the case of the max context size, do you recommend having sequences of just the max size, or of diverse sizes?
  • Another question related to the dataset size: if 169M tokens is too low, should I add some other text, like FineWeb, to enlarge the corpus?

Again, thank you so much; I am a newbie at this.

@chanind
Collaborator

chanind commented Dec 3, 2024

169M might be OK for GPT2, you can experiment. The things you should be looking at are explained variance and L0 (L0 means the mean number of active latents used to reconstruct an input activation). What is the L0 of this SAE? It looks like the L1 coefficient you set is really low for SAELens; I'd try something closer to 1.0. If the L1 coefficient is too weak, then the SAE can get perfect reconstruction with a really high L0, which might be what's happening here.
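
For reference, L0 as described above can be measured directly from the SAE's feature activations; a minimal sketch, where feature_acts is a stand-in [batch, n_latents] tensor of post-ReLU activations:

import torch

feature_acts = torch.relu(torch.randn(8, 49152))    # stand-in values, not real SAE outputs
l0 = (feature_acts > 0).float().sum(dim=-1).mean()  # mean number of active latents per input
print(f"L0 = {l0.item():.1f}")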

@TheMrguiller
Author

You really know your stuff. It is true that my values are quite high: the explained variance is nearly 1 and the L0 is 13k, taking into account that I have 50k latents. Why is the coefficient higher for SAELens? I have been reading articles on bigger models, and it seems they use very small L1 coefficients.

@chanind
Copy link
Collaborator

chanind commented Dec 3, 2024

I'm not sure why the L1 coefficient in SAELens tends to be much larger than from other sources. I suspect SAELens is normalizing the L1 loss differently (either SAELens is taking the mean across the batch, or across the batch and context, while other sources are not or the inverse or something).

Regardless, if your L0 is higher than the number of dimensions of the SAE input (768 for GPT2), then the SAE can just learn the identity matrix and get perfect reconstruction without doing anything interesting. You should probably be aiming for an L0 between 5 and 150 or so, I would think. There's a trade-off there: the lower the L0, the more sparse the SAE, but the reconstruction will be worse.
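
Similarly, explained variance can be sanity-checked from the SAE's reconstructions; a minimal sketch using one common definition, with placeholder tensors:

import torch

original = torch.randn(8, 768)                        # placeholder residual-stream activations
reconstructed = original + 0.1 * torch.randn(8, 768)  # placeholder SAE reconstructions

residual_var = (original - reconstructed).var(dim=0).sum()
total_var = original.var(dim=0).sum()
print(f"explained variance = {(1 - residual_var / total_var).item():.3f}")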

@TheMrguiller
Author

Thank you so much, I will try to follow your advice and keep you updated.
