Trainer._load_from_checkpoint - support loading multiple Peft adapters #30505

Merged
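Resuming from a checkpoint previously restored only a single PEFT adapter saved at the checkpoint root; when a model carries several adapters, each one is saved in its own sub-directory and could not be found on resume. A minimal sketch of the multi-adapter setup this PR targets (model ID and adapter names are illustrative, mirroring the new test below):

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Illustrative multi-adapter setup, mirroring the test added in this PR.
base = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
lora = LoraConfig(r=4, lora_alpha=16, task_type="CAUSAL_LM")

model = get_peft_model(base, lora, "adapter1")  # "adapter1" becomes the active adapter
model.add_adapter("adapter2", lora)             # additional adapters get their own checkpoint sub-directory

With the change below, Trainer._load_from_checkpoint detects those sub-directories, reloads every adapter, keeps only the active one trainable, and restores it as the active adapter.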
24 changes: 23 additions & 1 deletion src/transformers/trainer.py
@@ -2432,6 +2432,20 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
             # this checks the FSDP state dict when `FULL_STATE_DICT` is used
             or os.path.isfile(os.path.join(resume_from_checkpoint, f"{FSDP_MODEL_NAME}.bin"))
         )
+        # if multiple adapters exist, they get saved in sub directories
+        adapter_subdirs = (
+            [
+                folder_name
+                for folder_name in os.listdir(resume_from_checkpoint)
+                if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
+                and (
+                    os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_WEIGHTS_NAME))
+                    or os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_SAFE_WEIGHTS_NAME))
+                )
+            ]
+            if os.path.isdir(resume_from_checkpoint)
+            else []
+        )

         if is_fsdp_ckpt and not self.is_fsdp_enabled:
             raise ValueError(f"Checkpoint found at {resume_from_checkpoint} is only supported when using PyTorch FSDP")
@@ -2449,6 +2463,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
                 ]
             )
             or is_fsdp_ckpt
+            or adapter_subdirs
         ):
             raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

@@ -2522,7 +2537,14 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
             # If train a model using PEFT & LoRA, assume that adapter have been saved properly.
             if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
                 if os.path.exists(resume_from_checkpoint):
-                    model.load_adapter(resume_from_checkpoint, model.active_adapter, is_trainable=True)
+                    if adapter_subdirs:
+                        active_adapter = model.active_adapter
+                        for subdir_name in adapter_subdirs:
+                            peft_id = os.path.join(resume_from_checkpoint, subdir_name)
+                            model.load_adapter(peft_id, subdir_name, is_trainable=(subdir_name == active_adapter))
+                        model.set_adapter(active_adapter)
+                    else:
+                        model.load_adapter(resume_from_checkpoint, model.active_adapter, is_trainable=True)
                 else:
                     logger.warning(
                         "The intermediate checkpoints of PEFT may not be saved correctly, "
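For reference, a standalone sketch of the sub-directory detection introduced above; the ADAPTER_WEIGHTS_NAME / ADAPTER_SAFE_WEIGHTS_NAME values below are assumptions matching PEFT's usual adapter filenames, and the checkpoint layout is illustrative:

import os

# Assumed values of the transformers constants used in the diff (PEFT's usual filenames).
ADAPTER_WEIGHTS_NAME = "adapter_model.bin"
ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"

def find_adapter_subdirs(checkpoint_dir):
    """Standalone rewrite of the adapter_subdirs expression above.

    Expected layout for a two-adapter checkpoint (illustrative):
        checkpoint-5/
            adapter1/adapter_model.safetensors   # active adapter -> loaded with is_trainable=True
            adapter2/adapter_model.safetensors   # other adapters -> loaded with is_trainable=False
            trainer_state.json, optimizer.pt, ...
    """
    if not os.path.isdir(checkpoint_dir):
        return []
    return [
        name
        for name in os.listdir(checkpoint_dir)
        if os.path.isdir(os.path.join(checkpoint_dir, name))
        and (
            os.path.isfile(os.path.join(checkpoint_dir, name, ADAPTER_WEIGHTS_NAME))
            or os.path.isfile(os.path.join(checkpoint_dir, name, ADAPTER_SAFE_WEIGHTS_NAME))
        )
    ]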
57 changes: 57 additions & 0 deletions tests/trainer/test_trainer.py
@@ -964,6 +964,63 @@ def test_bnb_compile(self):
         with self.assertRaises(ValueError):
             _ = Trainer(tiny_model, args, train_dataset=train_dataset) # noqa

+    @require_peft
+    def test_multiple_peft_adapters(self):
+        from peft import LoraConfig, get_peft_model
+
+        # Tests if resuming from checkpoint works if the model has multiple adapters
+
+        MODEL_ID = "hf-internal-testing/tiny-random-LlamaForCausalLM"
+        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+        tiny_model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
+
+        peft_config = LoraConfig(
+            r=4,
+            lora_alpha=16,
+            lora_dropout=0.05,
+            bias="none",
+            task_type="CAUSAL_LM",
+        )
+        tiny_model = get_peft_model(tiny_model, peft_config, "adapter1")
+        tiny_model.add_adapter("adapter2", peft_config)
+
+        train_dataset = LineByLineTextDataset(
+            tokenizer=tokenizer,
+            file_path=PATH_SAMPLE_TEXT,
+            block_size=tokenizer.max_len_single_sentence,
+        )
+        for example in train_dataset.examples:
+            example["labels"] = example["input_ids"]
+
+        tokenizer.pad_token = tokenizer.eos_token
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            args = TrainingArguments(
+                tmpdir,
+                per_device_train_batch_size=1,
+                learning_rate=1e-9,
+                save_steps=5,
+                logging_steps=5,
+                max_steps=10,
+                use_cpu=True,
+            )
+            trainer = Trainer(tiny_model, args, tokenizer=tokenizer, train_dataset=train_dataset)
+
+            trainer.train()
+            parameters = dict(tiny_model.named_parameters())
+            state = dataclasses.asdict(trainer.state)
+
+            # Reinitialize trainer
+            trainer = Trainer(tiny_model, args, tokenizer=tokenizer, train_dataset=train_dataset)
+
+            checkpoint = os.path.join(tmpdir, "checkpoint-5")
+
+            trainer.train(resume_from_checkpoint=checkpoint)
+            parameters1 = dict(tiny_model.named_parameters())
+            state1 = dataclasses.asdict(trainer.state)
+            self.assertEqual(parameters, parameters1)
+            self.check_trainer_state_are_the_same(state, state1)
+
     @require_bitsandbytes
     def test_rmsprop_bnb(self):
         config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
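A possible follow-up sanity check (hypothetical, not part of this PR): after resuming, both adapters should still be registered on the model and the originally active one restored.

# Hypothetical post-resume assertions, reusing the adapter names from the test above.
assert set(tiny_model.peft_config.keys()) == {"adapter1", "adapter2"}
assert tiny_model.active_adapter == "adapter1"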