### Information

- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)

### Reproduction (minimal, reproducible, runnable)
I am following the tutorial here: https://huggingface.co/docs/optimum-neuron/en/training_tutorials/sft_lora_finetune_llm
I have been using the training script found here: https://github.com/huggingface/optimum-neuron/blob/main/docs/source/training_tutorials/sft_lora_finetune_llm.py
I used a trn1.2xlarge instance with 2 Neuron cores to train Llama 3.1 8B with LoRA, using tensor parallelism with a degree of 2. However, the training loss is much higher than for the same model trained with the same parameters on a single RTX 3090. The training losses look like this:
I ran these experiments using databricks/databricks-dolly-15k and timdettmers/openassistant-guanaco.
I also changed the `tokenize` function under `_prepare_non_packed_dataloader` in `trl/trainer/sft_trainer` so that it pads every sample to `max_length`, making it behave the same as optimum-neuron.
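For reference, the padding behavior forced in TRL can be sketched as follows. `pad_to_max_length` is an illustrative helper operating on plain token-id lists, not TRL's actual code; in practice the change amounts to tokenizing with `padding="max_length"` rather than dynamic padding:

```python
def pad_to_max_length(input_ids, max_length, pad_token_id):
    """Right-pad (or truncate) a list of token ids to exactly max_length,
    mirroring the fixed-length batches that optimum-neuron produces."""
    ids = list(input_ids[:max_length])
    # Real tokens get attention 1, padding gets 0.
    attention_mask = [1] * len(ids) + [0] * (max_length - len(ids))
    ids = ids + [pad_token_id] * (max_length - len(ids))
    return {"input_ids": ids, "attention_mask": attention_mask}
```

With identical padding on both setups, the per-step losses are averaged over the same number of (masked) positions, so the curves should be directly comparable.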
My training script for the trn1.2xlarge instance (shown for the dolly dataset; for the openassistant dataset I change the formatting function so it just returns `examples["text"]` directly):

train.py
```python
from dataclasses import dataclass, field

from datasets import load_from_disk, load_dataset, Dataset
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    set_seed,
)

from optimum.neuron import NeuronHfArgumentParser as HfArgumentParser
from optimum.neuron import NeuronSFTConfig, NeuronSFTTrainer, NeuronTrainingArguments
from optimum.neuron.distributed import lazy_load_for_parallelism
import torch
from huggingface_hub import login
import os

os.environ["WANDB_PROJECT"] = "my_project"
os.environ["WANDB_LOG_MODEL"] = "false"
os.environ["WANDB_WATCH"] = "all"


def format_dolly(examples):
    output_text = []
    for i in range(len(examples["instruction"])):
        instruction = f"### Instruction\n{examples['instruction'][i]}"
        context = f"### Context\n{examples['context'][i]}" if len(examples["context"][i]) > 0 else None
        response = f"### Answer\n{examples['response'][i]}"
        prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None])
        output_text.append(prompt)
    return output_text


def training_function(script_args, training_args):
    dataset = load_dataset("databricks/databricks-dolly-15k")
    tokenizer = AutoTokenizer.from_pretrained(script_args.model_id)
    tokenizer.pad_token = tokenizer.eos_token

    config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "gate_proj", "v_proj", "o_proj", "k_proj", "up_proj", "down_proj"],
        bias="none",
        task_type="CAUSAL_LM",
    )

    args = training_args.to_dict()
    sft_config = NeuronSFTConfig(
        # max_seq_length=1024,
        # packing=False,
        **args,
    )

    with lazy_load_for_parallelism(tensor_parallel_size=training_args.tensor_parallel_size):
        model = AutoModelForCausalLM.from_pretrained(script_args.model_id)

    trainer = NeuronSFTTrainer(
        args=sft_config,
        model=model,
        peft_config=config,
        tokenizer=tokenizer,
        train_dataset=dataset,
        formatting_func=format_dolly,
    )

    # Start training
    # print(trainer.evaluate())
    trainer.train()
    trainer.save_model()  # Saves the tokenizer too for easy upload


@dataclass
class ScriptArguments:
    model_id: str = field(
        default="meta-llama/Llama-3.1-8B",
        metadata={"help": "The model that you want to train from the Hugging Face hub."},
    )


def main():
    parser = HfArgumentParser([ScriptArguments, NeuronTrainingArguments])
    script_args, training_args = parser.parse_args_into_dataclasses()

    # set seed
    set_seed(training_args.seed)

    # run training function
    training_function(script_args, training_args)


if __name__ == "__main__":
    main()
```
Disabling embedding parallelization on the Trainium instance lowers the training loss, but it is still consistently higher than the loss on the RTX 3090. Also, with embedding parallelization enabled, the model is saved incorrectly: the trained model has additional layers `base_model.model.lm_head.weight` and `base_model.model.model.embed_tokens.weight`. Additionally, only half of `base_model.model.model.embed_tokens.weight` is saved (shape is (64128, 4096) instead of (128256, 4096)), but perhaps this should be another issue.
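Incidentally, the (64128, 4096) shape is exactly one tensor-parallel shard of the embedding: with a TP degree of 2, the vocabulary dimension is split in half across the two Neuron cores, which suggests only the local rank's shard was consolidated at save time. A minimal sketch of the arithmetic (`shard_rows` is a hypothetical helper, not an optimum-neuron API):

```python
def shard_rows(vocab_size, tp_degree):
    """Number of embedding rows each rank holds when the (vocab, hidden)
    weight is sharded along the vocabulary dimension."""
    assert vocab_size % tp_degree == 0, "vocab must divide evenly across ranks"
    return vocab_size // tp_degree

# Llama 3.1's vocabulary split across tp_degree=2 Neuron cores:
print(shard_rows(128256, 2))  # 64128 -- matches the (64128, 4096) shard that was saved
```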
### Expected behavior
I expect the training loss to be much closer to the loss I get when I train the model on an RTX 3090 instead of 2 trainium neuron cores.
### System Info

### Who can help?

@michaelbenayoun
My bash script for graph compilation: compile.sh, and my bash script for training: train.sh.
The script I use to train on the RTX 3090: train.py.