Trainer._load_from_checkpoint test multiple adapters
Clara Luise Pohland committed May 2, 2024
1 parent 16bd0ef commit 15429ad
Showing 1 changed file with 59 additions and 0 deletions.
59 changes: 59 additions & 0 deletions tests/trainer/test_trainer.py
@@ -964,6 +964,65 @@ def test_bnb_compile(self):
with self.assertRaises(ValueError):
_ = Trainer(tiny_model, args, train_dataset=train_dataset) # noqa

@require_peft
def test_multiple_peft_adapters(self):
from peft import LoraConfig, get_peft_model

        # Tests that Trainer._load_from_checkpoint correctly restores a PEFT model with
        # multiple adapters when training is resumed from a checkpoint, by comparing the
        # resumed run against an uninterrupted run.

MODEL_ID = "hf-internal-testing/tiny-random-LlamaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tiny_model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

peft_config = LoraConfig(
r=4,
lora_alpha=16,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
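        # Attach two adapters ("adapter1" and "adapter2") that share this LoRA config;
        # a checkpoint therefore has to save and restore the weights of both.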
tiny_model = get_peft_model(tiny_model, peft_config, "adapter1")
tiny_model.add_adapter("adapter2", peft_config)

train_dataset = LineByLineTextDataset(
tokenizer=tokenizer,
file_path=PATH_SAMPLE_TEXT,
block_size=tokenizer.max_len_single_sentence,
)
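        # For causal language modeling, reuse the input ids as labels so the Trainer can compute a loss.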
for example in train_dataset.examples:
example["labels"] = example["input_ids"]

tokenizer.pad_token = tokenizer.eos_token

with tempfile.TemporaryDirectory() as tmpdir:
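            # save_steps=5 with max_steps=10 writes an intermediate "checkpoint-5" halfway
            # through the run; the tiny learning rate keeps the weight updates negligible.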
args = TrainingArguments(
tmpdir,
per_device_train_batch_size=1,
learning_rate=1e-9,
save_steps=5,
                logging_steps=5,
max_steps=10,
use_cpu=True,
)
trainer = Trainer(tiny_model, args, tokenizer=tokenizer, train_dataset=train_dataset)

trainer.train()
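            # Snapshot the parameters and trainer state at the end of the uninterrupted 10-step run.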
            parameters = {name: param.detach().clone() for name, param in tiny_model.named_parameters()}
state = dataclasses.asdict(trainer.state)

# Reinitialize trainer
trainer = Trainer(tiny_model, args, tokenizer=tokenizer, train_dataset=train_dataset)

checkpoint = os.path.join(tmpdir, "checkpoint-5")
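            # Resume from the step-5 checkpoint; this goes through Trainer._load_from_checkpoint,
            # which must restore the weights of both adapters.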

trainer.train(resume_from_checkpoint=checkpoint)
            parameters1 = {name: param.detach().clone() for name, param in tiny_model.named_parameters()}
            state1 = dataclasses.asdict(trainer.state)
            self.assertEqual(set(parameters), set(parameters1))
            for name in parameters:
                torch.testing.assert_close(parameters[name], parameters1[name])
self.check_trainer_state_are_the_same(state, state1)

@require_bitsandbytes
def test_rmsprop_bnb(self):
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
