
Training loss while fine-tuning Llama 3.1 with LoRA is very high compared to RTX 3090 #721

Open
4 tasks done
anilozlu opened this issue Oct 18, 2024 · 1 comment
Labels
bug Something isn't working

Comments


anilozlu commented Oct 18, 2024

System Info

Using the Hugging Face AMI from the AWS Marketplace, with Ubuntu 22.04
optimum-neuron 0.0.25
transformers 4.45.2
peft 0.13.0
trl 0.11.4
accelerate 0.29.2
torch 2.1.2

Who can help?

@michaelbenayoun

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction (minimal, reproducible, runnable)

I am following the tutorial here: https://huggingface.co/docs/optimum-neuron/en/training_tutorials/sft_lora_finetune_llm
I have been using the training script found here: https://github.com/huggingface/optimum-neuron/blob/main/docs/source/training_tutorials/sft_lora_finetune_llm.py
I used a trn1.2xlarge instance with 2 NeuronCores to fine-tune Llama 3.1 8B with LoRA, using tensor parallelism with a degree of 2. However, the training loss is much higher than when the same model is trained with the same parameters on a single RTX 3090. The training losses look like this:
[Image: combined training-loss plot]
I ran these experiments using the databricks/databricks-dolly-15k and timdettmers/openassistant-guanaco datasets.
I also changed the tokenize function inside _prepare_non_packed_dataloader in trl/trainer/sft_trainer.py so that it pads every sample to max_length, making it behave the same way as optimum-neuron.
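To make the effect of padding every sample to max_length concrete, here is a minimal pure-Python sketch of one padded training sample. It assumes a causal-LM collator that, like transformers' DataCollatorForLanguageModeling with mlm=False, copies input_ids into labels and replaces every position equal to pad_token_id with -100. The helper pad_and_mask and the token ids are hypothetical. Because both scripts set pad_token to eos_token, any genuine EOS token in the text would be masked out of the loss as well.

```python
from typing import Dict, List

IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss ignores by default


def pad_and_mask(token_ids: List[int], max_length: int, pad_id: int) -> Dict[str, List[int]]:
    """Hypothetical helper: right-pad (or truncate) one sample to max_length and build
    attention_mask and labels the way a causal-LM collator with pad masking would."""
    ids = token_ids[:max_length]                  # truncate samples that are too long
    n_pad = max_length - len(ids)
    input_ids = ids + [pad_id] * n_pad            # right padding up to max_length
    attention_mask = [1] * len(ids) + [0] * n_pad
    # Labels copy input_ids; every position whose id equals pad_id is ignored in the loss,
    # including a real EOS token when pad_token is set to eos_token.
    labels = [IGNORE_INDEX if t == pad_id else t for t in input_ids]
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


if __name__ == "__main__":
    PAD = 2  # toy id standing in for eos_token_id, reused as pad_token_id
    sample = pad_and_mask([1, 15, 27, 2], max_length=8, pad_id=PAD)
    print(sample["labels"])
    # Number of positions that actually contribute to the loss:
    print(sum(1 for t in sample["labels"] if t != IGNORE_INDEX))
```

If the Trainium pipeline and the modified trl pipeline disagree on which positions are masked, the reported losses are averaged over different token sets and are not directly comparable.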
My training script for the trn1.2xlarge instance (shown for the Dolly dataset; for the OpenAssistant dataset I change the formatting function so it returns examples["text"] directly):

train.py
from dataclasses import dataclass, field

from datasets import load_from_disk, load_dataset, Dataset
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    set_seed,
)

from optimum.neuron import NeuronHfArgumentParser as HfArgumentParser
from optimum.neuron import NeuronSFTConfig, NeuronSFTTrainer, NeuronTrainingArguments
from optimum.neuron.distributed import lazy_load_for_parallelism

import torch
from huggingface_hub import login

import os

os.environ["WANDB_PROJECT"] = "my_project"
os.environ["WANDB_LOG_MODEL"] = "false"
os.environ["WANDB_WATCH"] = "all"


def format_dolly(examples):
    output_text = []
    for i in range(len(examples["instruction"])):
        instruction = f"### Instruction\n{examples['instruction'][i]}"
        context = f"### Context\n{examples['context'][i]}" if len(examples["context"][i]) > 0 else None
        response = f"### Answer\n{examples['response'][i]}"
        prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None])
        output_text.append(prompt)
    return output_text



def training_function(script_args, training_args):
    # split="train" returns a Dataset; without it, load_dataset returns a DatasetDict
    dataset = load_dataset("databricks/databricks-dolly-15k", split="train")

    tokenizer = AutoTokenizer.from_pretrained(script_args.model_id)
    tokenizer.pad_token = tokenizer.eos_token

    config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "gate_proj", "v_proj", "o_proj", "k_proj", "up_proj", "down_proj"],
        bias="none",
        task_type="CAUSAL_LM",
    )

    args = training_args.to_dict()

    sft_config = NeuronSFTConfig(
        #max_seq_length=1024,
        #packing=False,
        **args,
    )

    with lazy_load_for_parallelism(tensor_parallel_size=training_args.tensor_parallel_size):
        model = AutoModelForCausalLM.from_pretrained(script_args.model_id)

    trainer = NeuronSFTTrainer(
        args=sft_config,
        model=model,
        peft_config=config,
        tokenizer=tokenizer,
        train_dataset=dataset,
        formatting_func=format_dolly
    )
    

    # Start training
    #print(trainer.evaluate())
    trainer.train()

    trainer.save_model()  # Saves the tokenizer too for easy upload


@dataclass
class ScriptArguments:
    model_id: str = field(
        default="meta-llama/Llama-3.1-8B",
        metadata={"help": "The model that you want to train from the Hugging Face hub."},
    )


def main():
    parser = HfArgumentParser([ScriptArguments, NeuronTrainingArguments])
    script_args, training_args = parser.parse_args_into_dataclasses()

    # set seed
    set_seed(training_args.seed)

    # run training function
    training_function(script_args, training_args)


if __name__ == "__main__":
    main()
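Both the Trainium script and the RTX 3090 script use the same LoraConfig (r=16, lora_alpha=32, bias="none", seven projection modules), so both runs should train the same number of adapter parameters. A quick sanity check of that count, assuming the published Llama 3.1 8B dimensions (hidden size 4096, MLP size 14336, 32 layers, 8 KV heads of size 128), none of which are stated in this issue:

```python
# Assumed Llama 3.1 8B dimensions (from its published config); not stated in the issue.
HIDDEN = 4096
INTERMEDIATE = 14336
NUM_LAYERS = 32
NUM_KV_HEADS = 8
HEAD_DIM = 128
R = 16
LORA_ALPHA = 32

KV_DIM = NUM_KV_HEADS * HEAD_DIM  # 1024: k_proj/v_proj output size with grouped-query attention

# (in_features, out_features) for every module listed in target_modules
MODULES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_params(in_f: int, out_f: int, r: int) -> int:
    """A LoRA adapter adds A (r x in_f) and B (out_f x r); bias="none" adds nothing else."""
    return r * in_f + out_f * r


per_layer = sum(lora_params(i, o, R) for i, o in MODULES.values())
total = per_layer * NUM_LAYERS
scaling = LORA_ALPHA / R  # factor applied to B @ A in standard (non-rsLoRA) LoRA

print(f"per layer: {per_layer:,}")  # per layer: 1,310,720
print(f"total:     {total:,}")      # total:     41,943,040
print(f"scaling:   {scaling}")      # scaling:   2.0
```

If the trainable-parameter count logged on either machine differs noticeably from this figure (summed across ranks on Trainium), that would point at a configuration mismatch rather than a numerical one.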

My bash script for graph compilation:

compile.sh
#!/bin/bash
set -ex

MODEL_NAME="meta-llama/Llama-3.1-8B"
huggingface-cli download $MODEL_NAME --exclude "original/*" --token TOKEN

export NEURON_FUSE_SOFTMAX=1
export NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=3
export MALLOC_ARENA_MAX=64
export NEURON_CC_FLAGS="--model-type=transformer --distribution-strategy=llm-training --enable-saturate-infinity --cache_dir=/home/ubuntu/cache_dir_neuron/"

PROCESSES_PER_NODE=2

NUM_EPOCHS=1
TP_DEGREE=2
PP_DEGREE=1
BS=1
GRADIENT_ACCUMULATION_STEPS=8
LOGGING_STEPS=10

OUTPUT_DIR="trn1.2xlarge_databricks-dolly-15k"

MAX_STEPS=25

XLA_USE_BF16=1 neuron_parallel_compile torchrun --nproc_per_node $PROCESSES_PER_NODE train.py \
  --model_id $MODEL_NAME \
  --num_train_epochs $NUM_EPOCHS \
  --do_train \
  --learning_rate 5e-5 \
  --warmup_ratio 0.03 \
  --max_steps $MAX_STEPS \
  --per_device_train_batch_size $BS \
  --per_device_eval_batch_size $BS \
  --gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \
  --gradient_checkpointing true \
  --bf16 \
  --zero_1 false \
  --tensor_parallel_size $TP_DEGREE \
  --pipeline_parallel_size $PP_DEGREE \
  --logging_steps $LOGGING_STEPS \
  --save_total_limit 1 \
  --output_dir $OUTPUT_DIR \
  --lr_scheduler_type "constant" \
  --overwrite_output_dir \
  --report_to "none"

rm -rf $OUTPUT_DIR
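With TP_DEGREE=2 on the two NeuronCores, both processes hold shards of the same model replica, so, assuming the usual Megatron-style layout where the data-parallel size is world size / (TP x PP), there is a single data-parallel replica. The arithmetic below checks that the Trainium run and the RTX 3090 run then see the same number of sequences per optimizer step; the helper names are illustrative only.

```python
# Values copied from compile.sh / train.sh and from the RTX 3090 SFTConfig.
PROCESSES_PER_NODE = 2  # one process per NeuronCore on trn1.2xlarge
TP_DEGREE = 2
PP_DEGREE = 1
BS = 1
GRADIENT_ACCUMULATION_STEPS = 8


def data_parallel_size(world_size: int, tp: int, pp: int) -> int:
    """Ranks not consumed by tensor/pipeline parallelism form data-parallel replicas."""
    assert world_size % (tp * pp) == 0, "world size must be divisible by tp * pp"
    return world_size // (tp * pp)


def global_batch_size(per_device_bs: int, grad_accum: int, dp: int) -> int:
    """Sequences contributing to one optimizer step."""
    return per_device_bs * grad_accum * dp


dp = data_parallel_size(PROCESSES_PER_NODE, TP_DEGREE, PP_DEGREE)
trn_batch = global_batch_size(BS, GRADIENT_ACCUMULATION_STEPS, dp)
gpu_batch = global_batch_size(1, 8, 1)  # single RTX 3090, no model parallelism

print(dp, trn_batch, gpu_batch)  # 1 8 8
```

Under that assumption, the effective batch size is not what differs between the two runs.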

And my bash script for training:

train.sh
#!/bin/bash
set -ex

MODEL_NAME="meta-llama/Llama-3.1-8B"
HF_TOKEN="TOKEN"

export NEURON_FUSE_SOFTMAX=1
export NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=3
export MALLOC_ARENA_MAX=64
export NEURON_CC_FLAGS="--model-type=transformer --distribution-strategy=llm-training --enable-saturate-infinity --cache_dir=/home/ubuntu/cache_dir_neuron/"

PROCESSES_PER_NODE=2

NUM_EPOCHS=1
TP_DEGREE=2
PP_DEGREE=1
BS=1
GRADIENT_ACCUMULATION_STEPS=8
LOGGING_STEPS=10

OUTPUT_DIR="trn1.2xlarge_databricks-dolly-15k"

MAX_STEPS=200

XLA_USE_BF16=1 torchrun --nproc_per_node $PROCESSES_PER_NODE train.py \
  --model_id $MODEL_NAME \
  --num_train_epochs $NUM_EPOCHS \
  --do_train \
  --learning_rate 5e-5 \
  --warmup_ratio 0.03 \
  --max_steps $MAX_STEPS \
  --per_device_train_batch_size $BS \
  --per_device_eval_batch_size $BS \
  --gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \
  --gradient_checkpointing true \
  --bf16 \
  --zero_1 false \
  --tensor_parallel_size $TP_DEGREE \
  --pipeline_parallel_size $PP_DEGREE \
  --logging_steps $LOGGING_STEPS \
  --save_total_limit 1 \
  --output_dir $OUTPUT_DIR \
  --lr_scheduler_type "constant" \
  --overwrite_output_dir \
  --report_to "wandb" \
  --run_name $OUTPUT_DIR
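A detail that applies to both setups: in transformers, lr_scheduler_type="constant" maps to a schedule with no warmup, so --warmup_ratio 0.03 here (and warmup_ratio=0.03 in the RTX 3090 script) should have no effect; only "constant_with_warmup" uses the warmup length. The pure-Python sketch below mirrors the multipliers those two schedules apply, with the warmup length computed as ceil(max_steps * warmup_ratio). The function names are illustrative, not the library's.

```python
import math

MAX_STEPS = 200
WARMUP_RATIO = 0.03


def warmup_steps(max_steps: int, ratio: float) -> int:
    # Warmup length derived from a ratio, rounded up.
    return math.ceil(max_steps * ratio)


def constant_multiplier(step: int) -> float:
    # "constant": the learning rate is never scaled, so any warmup setting is ignored.
    return 1.0


def constant_with_warmup_multiplier(step: int, n_warmup: int) -> float:
    # "constant_with_warmup": linear ramp from 0 to 1 over n_warmup steps, then flat.
    if step < n_warmup:
        return step / max(1, n_warmup)
    return 1.0


n_warmup = warmup_steps(MAX_STEPS, WARMUP_RATIO)
print(n_warmup)  # 6
print([constant_multiplier(s) for s in range(8)])
print([round(constant_with_warmup_multiplier(s, n_warmup), 3) for s in range(8)])
```

Since both runs use the same scheduler settings, this does not explain the gap by itself, but it means neither run actually warms up.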

The script I use to train on RTX 3090:

train.py
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="1"

os.environ["WANDB_PROJECT"] = "my_project"
os.environ["WANDB_LOG_MODEL"] = "false"
os.environ["WANDB_WATCH"] = "all"

from datasets import load_dataset
from peft import LoraConfig, TaskType, AutoPeftModelForCausalLM
from trl import SFTTrainer, SFTConfig
from transformers import AutoTokenizer, TrainingArguments, AutoModelForCausalLM
from accelerate import init_empty_weights
import torch

model_path = "meta-llama/Llama-3.1-8B"
dataset_path = "databricks/databricks-dolly-15k"

output_dir = model_path.split("/")[-1] + "-" + dataset_path.split("/")[-1]
run_name = "3090" + "_" + model_path.split("/")[-1] + "_" + dataset_path.split("/")[-1]

tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token

def format_dolly(example):
    instruction = f"### Instruction\n{example['instruction']}"
    context = f"### Context\n{example['context']}" if len(example["context"]) > 0 else None
    response = f"### Answer\n{example['response']}"
    prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None])
    return {"text": prompt}

dataset = load_dataset(dataset_path, split="train")
dataset = dataset.map(format_dolly)

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "gate_proj", "v_proj", "o_proj", "k_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)

sft_config = SFTConfig(
    do_train=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    save_total_limit=1,
    bf16=True,
    max_seq_length=1024,
    output_dir=run_name,
    dataset_text_field="text",
    learning_rate=5e-05,
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
    gradient_checkpointing=True,
    logging_steps=10,
    report_to="wandb",
    run_name=run_name,
    num_train_epochs=1,
    max_steps=200
)

with init_empty_weights():
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", torch_dtype=torch.bfloat16)

trainer = SFTTrainer(
    model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    args=sft_config,
    peft_config=lora_config,
    packing=False
)

trainer.train()
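The two scripts format Dolly differently: a batched formatting_func on Trainium and a per-example dataset.map on the RTX 3090. One way to rule out a prompt-format difference is to run both functions on the same rows and compare their output. The sketch below copies the two functions from the scripts above (renamed, with the comprehension variable renamed for readability) and checks them on hypothetical toy rows that use the Dolly column names.

```python
from typing import Dict, List


def format_dolly_batched(examples: Dict[str, List[str]]) -> List[str]:
    # Trainium script: batched formatting_func returning a list of prompts.
    output_text = []
    for idx in range(len(examples["instruction"])):
        instruction = f"### Instruction\n{examples['instruction'][idx]}"
        context = f"### Context\n{examples['context'][idx]}" if len(examples["context"][idx]) > 0 else None
        response = f"### Answer\n{examples['response'][idx]}"
        output_text.append("\n\n".join(p for p in [instruction, context, response] if p is not None))
    return output_text


def format_dolly_single(example: Dict[str, str]) -> Dict[str, str]:
    # RTX 3090 script: per-example function used with dataset.map(), writing a "text" column.
    instruction = f"### Instruction\n{example['instruction']}"
    context = f"### Context\n{example['context']}" if len(example["context"]) > 0 else None
    response = f"### Answer\n{example['response']}"
    return {"text": "\n\n".join(p for p in [instruction, context, response] if p is not None)}


# Hypothetical toy rows; the second one has an empty context.
batch = {
    "instruction": ["What is LoRA?", "Name a color."],
    "context": ["LoRA adds low-rank adapters.", ""],
    "response": ["A fine-tuning method.", "Blue."],
}

batched = format_dolly_batched(batch)
single = [
    format_dolly_single({k: batch[k][i] for k in batch})["text"]
    for i in range(len(batch["instruction"]))
]
print(batched == single)  # True
print(batched[1])
```

Identical strings here only show that the two functions agree; how each trainer then tokenizes those strings still has to be checked separately.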

Disabling embedding parallelization on the Trainium instance lowers the training loss, but it is still consistently higher than the loss on the RTX 3090. Also, with embedding parallelization enabled, the model is saved incorrectly. The model trained with embedding parallelization contains two additional weights, base_model.model.lm_head.weight and base_model.model.model.embed_tokens.weight. Additionally, only half of base_model.model.model.embed_tokens.weight is saved (its shape is (64128, 4096) instead of (128256, 4096)), but perhaps this should be a separate issue.
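The reported shape is consistent with a vocab-parallel embedding: with TP=2, splitting the 128,256-row vocabulary row-wise gives 128,256 / 2 = 64,128 rows per rank, so the saved tensor looks like a single rank's shard rather than the gathered matrix. This is an inference from the numbers, not something the issue confirms. The pure-Python sketch below illustrates the shard/gather round trip on a toy matrix; shard_vocab and gather_vocab are hypothetical helpers, not optimum-neuron APIs.

```python
from typing import List

Row = List[float]


def shard_vocab(weight: List[Row], tp_size: int) -> List[List[Row]]:
    """Split an embedding matrix row-wise (vocabulary dimension) into tp_size equal shards."""
    vocab = len(weight)
    assert vocab % tp_size == 0, "vocab size must divide evenly across ranks"
    rows = vocab // tp_size
    return [weight[r * rows:(r + 1) * rows] for r in range(tp_size)]


def gather_vocab(shards: List[List[Row]]) -> List[Row]:
    """Concatenate per-rank shards back along the vocabulary dimension."""
    return [row for shard in shards for row in shard]


VOCAB, TP = 128256, 2
print(VOCAB // TP)  # 64128: rows per rank, matching the saved (64128, 4096) tensor

# Toy check on a tiny 6-token x 3-dim matrix instead of the real 128256 x 4096 one.
toy = [[float(t * 10 + d) for d in range(3)] for t in range(6)]
shards = shard_vocab(toy, tp_size=2)
print([len(s) for s in shards])     # [3, 3]
print(gather_vocab(shards) == toy)  # True
```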

Expected behavior

I expect the training loss on the two Trainium NeuronCores to be much closer to the loss I get when I train the same model on an RTX 3090.

anilozlu (Author) commented:

Sorry if this is a double ping, but I think I made a typo in your handle the first time. @michaelbenayoun, can you help with this?
