Train llama 3.1 with GRIT #60

Open
ThisisXXZ opened this issue Nov 9, 2024 · 5 comments

Comments

@ThisisXXZ

I'm now trying to train Llama 3.1 with the GRIT pipeline.

At first I directly changed --model_name_or_path and ran the training code (the training script I used is as follows):

#!/bin/bash
#SBATCH --time=6:00:00
#SBATCH --job-name=grit_train
#SBATCH --gres=gpu:h100-96:2
#SBATCH --mem=60G
#SBATCH --output=/home/e/e1347696/unified_encoder_decoder/logs/grit_train_out.log
#SBATCH --error=/home/e/e1347696/unified_encoder_decoder/logs/grit_train_err.log

source ~/.bashrc
conda activate grit_eval

export CUDA_HOME='/usr/local/cuda-12.1'
# CUDA_VISIBLE_DEVICES=$(python train/gritlm/mig_uuid_setup.py)
export CUDA_VISIBLE_DEVICES=0,1

cd /home/e/e1347696/unified_encoder_decoder

# nvidia-smi 

deepspeed \
    --num_gpus=2 \
    --module train.gritlm.training.run \
    --output_dir results/GritLM-7B-training \
    --model_name_or_path model/Llama-3.1-8B \
    --train_data data/grit_training_data \
    --max_example_num_per_dataset 1000 \
    --learning_rate 2e-5 \
    --lr_scheduler_type linear \
    --warmup_ratio 0.03 \
    --max_steps 1253 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 256 \
    --per_device_generative_bs 32 \
    --dataloader_drop_last \
    --normalized \
    --temperature 0.02 \
    --train_group_size 2 \
    --negatives_cross_device \
    --query_max_len 256 \
    --passage_max_len 1024 \
    --mode unified \
    --logging_steps 1 \
    --bf16 \
    --pooling_method mean \
    --use_unique_indices \
    --loss_gen_type mixed \
    --attn bbcc \
    --attn_implementation sdpa \
    --no_gen_gas \
    --gradient_checkpointing \
    --save_steps 1000 \
    --split_emb \
    --deepspeed scripts/configs/config_8gpusds_m7.json

But I got an error: TypeError: LlamaModel.forward() got an unexpected keyword argument 'is_causal'. I looked into it and found several related issues: #34, #32 and #19.
Just to confirm: if I want to train a Llama 3.1 model with GRIT, can I just

  • reuse the provided modeling file directly by putting modeling_gritlm7b.py into the Llama 3.1 model folder,
    or do I need to
  • change the modeling file for Llama 3.1 so that it accepts the is_causal arg and thus influences attention behavior?

Thank you so much!
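
For context, here is a minimal sketch of what the second option (accepting an is_causal arg that toggles bidirectional attention) could look like. SimpleSelfAttention below is a toy, illustrative module, not the actual transformers LlamaAttention or the GritLM patch:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleSelfAttention(nn.Module):
    """Toy single-head self-attention whose forward accepts `is_causal`."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, attention_mask=None, is_causal=True):
        # hidden_states: (batch, seq_len, hidden)
        # attention_mask: optional additive mask (0 = keep, large negative = padding)
        seq_len = hidden_states.shape[1]
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        mask = attention_mask
        if is_causal:
            # Generative pass: add the usual upper-triangular causal mask.
            causal = torch.full((seq_len, seq_len), float("-inf"), device=hidden_states.device)
            causal = torch.triu(causal, diagonal=1)
            mask = causal if mask is None else mask + causal
        # Embedding pass (is_causal=False): no causal mask, attention is bidirectional.

        attn_out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        return self.o_proj(attn_out)

layer = SimpleSelfAttention(hidden_size=64)
x = torch.randn(2, 8, 64)
emb_repr = layer(x, is_causal=False)  # bidirectional, for the embedding objective
gen_repr = layer(x, is_causal=True)   # causal, for the generative objective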

@Muennighoff
Collaborator

> change the modeling file for Llama 3.1 so that it accepts the is_causal arg and thus influences attention behavior?

@ThisisXXZ
Author

> change the modeling file for Llama 3.1 so that it accepts the is_causal arg and thus influences attention behavior?

I thought is_causal is an argument controlling whether or not we use bidirectional attention in the model. Since the original modeling file does not accept such an argument, do we need to implement it ourselves?

I’m still learning, so please kindly correct me if I’m mistaken. Thank you so much!

@Muennighoff
Collaborator

Yes, you're right; I meant to say that that is the option you have to go with. Sorry, I should have removed the "?".

@ThisisXXZ
Author

> Yes, you're right; I meant to say that that is the option you have to go with. Sorry, I should have removed the "?".

Thank you so much! Nah, you don't need to remove the "?"; it's just that I don't have much confidence in this :)

I've checked the code of modeling_mistral_gritlm.py and had a few other questions:

  • Do I need to modify any other settings besides implementing the is_causal arg?
  • I noticed that flash attention is used in modeling_mistral_gritlm.py but not in the Llama 3 modeling file. So if I'm going to implement my own version of the is_causal arg, can I simply add the bidirectional component in the Attention class?

Really appreciate your prompt guidance, even on weekends! You are indeed one of the most helpful and fastest-responding authors I have contacted.

@Muennighoff
Collaborator

  1. Only is_causal.
  2. Yeah, any attention mechanism should be fine as long as you implement the masking for it (a sketch of that masking follows below).
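
On point 2, here is a rough sketch of what "implement the masking" could mean for a plain (non-flash) attention path; build_attention_bias and plain_attention are illustrative helper names, not functions from the GritLM repo:

import torch

def build_attention_bias(padding_mask: torch.Tensor, is_causal: bool) -> torch.Tensor:
    """padding_mask: (batch, seq_len), 1 for real tokens, 0 for padding.

    Returns an additive bias of shape (batch, 1, seq_len, seq_len):
    0 where attention is allowed, a large negative value where it is blocked.
    """
    bsz, seq_len = padding_mask.shape
    # A large negative instead of -inf avoids NaNs on fully padded query rows.
    neg = torch.finfo(torch.float32).min
    bias = torch.zeros(bsz, 1, seq_len, seq_len)
    # Always block attention to padded key positions.
    bias = bias.masked_fill(~padding_mask[:, None, None, :].bool(), neg)
    if is_causal:
        # Generative pass: additionally block attention to future positions.
        future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
        bias = bias.masked_fill(future, neg)
    return bias

def plain_attention(q, k, v, bias):
    # q, k, v: (batch, heads, seq_len, head_dim); bias broadcasts over heads.
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    return torch.softmax(scores + bias, dim=-1) @ v

padding_mask = torch.tensor([[1, 1, 1, 0]])
emb_bias = build_attention_bias(padding_mask, is_causal=False)  # bidirectional + padding
gen_bias = build_attention_bias(padding_mask, is_causal=True)   # causal + padding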
