
Save adapter config and remapped adapter weights for loading into PEFT #933

Merged
ebsmothers merged 18 commits into pytorch:main from peft-integration on May 21, 2024

Conversation

@ebsmothers (Contributor) commented on May 3, 2024

PEFT integration

This is a PR for integration with PEFT. With this integration, you can take a fine-tuned checkpoint from torchtune and load it into a PEFT model using from_pretrained for continued fine-tuning or inference.

First, finetune a model in torchtune. Using the tune CLI:

tune run lora_finetune_single_device --config llama2/7B_lora_single_device \
checkpointer.output_dir=/my/output/dir

Then in Python:

from transformers import AutoModelForCausalLM
from peft import PeftModel

# hub ID of the base model from the above fine-tune
model_id = "meta-llama/Llama-2-7b-hf" 

# output_dir from tune command
checkpoint_dir = "/my/output/dir" 

model = AutoModelForCausalLM.from_pretrained(model_id)
peft_model = PeftModel.from_pretrained(model, checkpoint_dir)

And that's it! You can now use peft_model as you would any other PEFT model class.
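
Downstream usage might then look like the following (a minimal sketch; the tokenizer is assumed to come from the same base model, and the prompt and generation arguments are purely illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Tell me a fun fact about llamas.", return_tensors="pt")
outputs = peft_model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))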

We automatically output the necessary files/formats for PEFT integration whenever using torchtune's HF checkpointer, so make sure to use that in your fine-tuning config if you want to load your torchtune checkpoints into PEFT (example config usage).

Implementation

We save a file adapter_config.json, along with adapter_model.bin to match the format expected by PEFT. We also remap the LoRA weights to match the HF format (due to differences in RoPE implementations).
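
For reference, the adapter config records the LoRA hyperparameters PEFT needs to reconstruct the adapter. A hedged illustration of its contents, shown as a Python dict (the values correspond to the default Llama2 LoRA config and will differ for other runs):

# Illustrative only: the exact fields/values depend on your fine-tuning config.
adapter_config = {
    "target_modules": ["q_proj", "v_proj"],
    "r": 8,
    "lora_alpha": 16,
    "peft_type": "LORA",
}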

The save logic differs depending on checkpointer and model. In summary:

  • For the Meta and tune checkpointers we make no changes and continue to save adapter weights in the tune format. This is consistent with the principle of "same input format, same output format" (the PEFT output format only matches the HF format).
  • For the HF checkpointer we still output tune-format adapter weights (to allow resuming from intermediate checkpoints), but we also output HF-mapped adapter weights, except in the case of phi-3 models. See this comment for the rationale.
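
On the RoPE-related remapping mentioned above: the LoRA update is B @ A, and the HF permutation only reorders the output rows of the Q/K projections, so permuting the LoRA B matrix alone is equivalent to permuting the full update (this is also what manual test (1) below checks). A minimal sketch of the idea, using a hypothetical helper and toy shapes rather than the actual torchtune implementation:

import torch

def permute_rows_for_hf(w, n_heads, head_dim):
    # Reorder each head's output rows into the interleaving expected for RoPE.
    out_dim, in_dim = w.shape
    return (
        w.view(n_heads, head_dim // 2, 2, in_dim)
        .transpose(1, 2)
        .reshape(out_dim, in_dim)
    )

# A row permutation commutes with right-multiplication, so permuting B alone
# reproduces the permutation of the full LoRA update B @ A.
n_heads, head_dim, in_dim, rank = 4, 8, 32, 2
B = torch.randn(n_heads * head_dim, rank)
A = torch.randn(rank, in_dim)
assert torch.allclose(
    permute_rows_for_hf(B, n_heads, head_dim) @ A,
    permute_rows_for_hf(B @ A, n_heads, head_dim),
)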

Testing:

Unit tests

Added unit test in test_checkpointer.py to verify that adapter config and PEFT-compatible weights are saved as expected

pytest tests/torchtune/utils/test_checkpointer.py
...
======= 6 passed in 1.19s ==========

Recipe tests

pytest -m integration_test tests/recipes
...
==== 18 passed, 1 deselected, 3 warnings in 167.99s (0:02:47) =======

Manual E2E test

First create the file test_peft_integration.py as in this gist.

(1) ✅ Permute of LoRA weights works as expected (i.e. _permute_lora_matrix(B) * A = _permute(B*A), which I think is what we want).
(2) ✅ Uploaded adapter weights can be loaded into a transformers model via from_pretrained
(3) ✅ Model forwards match within a reasonable tolerance across PEFT-loaded and torchtune-loaded checkpoints

For (3):

Test case 1: default config (Q and V only)

tune run lora_finetune_single_device --config llama2/7B_lora_single_device gradient_accumulation_steps=1 \
max_steps_per_epoch=500 dtype=fp32 checkpointer.output_dir=/data/users/ebs/test_peft_integration

to save a fine-tuned LoRA checkpoint with the adapter config and adapter weights in PEFT format. Then, to compare the forward pass when loading our fine-tuned checkpoint into PEFT vs. into torchtune:

python3 test_peft_integration.py --checkpoint-dir=/data/users/ebs/test_peft_integration
...
Maximum difference: 9.298324584960938e-05

Test case 2: all layers, custom LoRA rank and alpha

tune run lora_finetune_single_device --config llama2/7B_lora_single_device \
model.lora_attn_modules=['q_proj','k_proj','v_proj','output_proj'] model.apply_lora_to_mlp=True \
model.apply_lora_to_output=True gradient_accumulation_steps=1 max_steps_per_epoch=100 model.lora_rank=16 \
model.lora_alpha=64 dtype=fp32 checkpointer.output_dir=/data/users/ebs/test_peft_integration_full

Then

python3 test_peft_integration.py --checkpoint-dir=/data/users/ebs/test_peft_integration_full
...
Maximum difference: 0.000152587890625

pytorch-bot commented on May 3, 2024

🔗 Helpful Links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/933. Note: links to docs will display an error until the docs builds have completed.

✅ No failures as of commit 9eb9b68 with merge base 29ae975. 💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on May 3, 2024
kartikayk (Contributor) left a comment:

Generally looks good - just needs some clean up and comments to make this easy to understand

Comment on lines +262 to +263
self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp
self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False)
kartikayk:

not related to this PR, but maybe at some point we should consider replacing the apply_lora_to_* flags with just adding mlp and output to the lora_modules?

ebsmothers (author):

Yeah agreed, I think this is likely where we'll head eventually. One thing is that we will probably want to make LoRA in MLP more configurable (i.e. use w1, w2, w3 (or hopefully more descriptive names) instead of mlp). Otherwise the relationship between e.g. q_proj (nn.Linear) and mlp (FeedForward) being in the same config is a bit confusing. Anyways this shouldn't be a huge effort to change

BenjaminBossan:

I agree that a single list is more intuitive, since, AFAICT, this is just consolidated into a single list under the hood.

or hopefully more descriptive names

Changing names later on can invalidate the saved checkpoints, so would require some versioning for backwards compatibility.

kartikayk:

I guess versioning or some sort of a convertor/mapping? It would be great to figure this change out soon, but this point about checkpoint invalidation is a good one and something we should have a general solution for. I suspect this will come up many times

self._apply_lora_to_mlp,
self._apply_lora_to_output,
),
"peft_type": "LORA",
kartikayk:

Nice!

BenjaminBossan:

Not sure about this, but if the base model used for training was loaded from HF in the HF format (i.e. a transformers PretrainedModel), it should have a name_or_path attribute. This could be stored and if it exists, we could add it to the config here as base_model_name_or_path. This is not a required attribute for the adapter_config.json but would be nice to have for a few situations.

ebsmothers (author):

Ah good point. I was trying to avoid this initially since it may necessitate some changes to our load_checkpoint method, as right now we really only retrieve and remap model weights. If it's more of a nice-to-have, I may punt on it for this particular PR to keep things more isolated to save_checkpoint. Lmk if this makes sense. Also cc @kartikayk if you have any general thoughts on loading state/metadata through load_checkpointer and passing through our recipe. I imagine this is something we may want to start supporting more for various integrations anyways.

kartikayk:

Can you expand a bit more on why we would need base_model_name_or_path? Is this to make sure there are no bugs related to selecting the right base model for further training in HF land? If so, I wonder if this is something which is a "must have" rather than a "good to have"? or let me know if I misunderstand?

If it's a must have, then is this something we can read from one of the json files or do we need to pass this information along through the recipe?

BenjaminBossan:

We don't strictly need base_model_name_or_path, but not having it means that the burden is on the user to figure out which base model this adapter belongs to. Of course, this can be solved with good documentation, but having it automatically in the adapter_config.json would be quite convenient.

Other points to consider:

  • When shared on HF Hub, this metadata can be used for other things (I'm not an expert on this though)
  • If base_model_name_or_path is present, users can load the adapter + base model in a single line of code (e.g. AutoModelForCausalLM.from_pretrained(<path-to-adapter>)).
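
As a hedged illustration of that last point (this assumes a recent transformers version with the peft package installed, and that base_model_name_or_path is present in adapter_config.json; the path is the output_dir from the example above):

from transformers import AutoModelForCausalLM

# transformers detects adapter_config.json, loads the base model named in
# base_model_name_or_path, and injects the adapter weights.
model = AutoModelForCausalLM.from_pretrained("/my/output/dir")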

@@ -198,3 +198,78 @@ def _permute(t, n_heads):
converted_state_dict[new_key] = value

return converted_state_dict


_TO_PEFT_KEYS = {
kartikayk:

Maybe some quick comments on what these dicts refer to?

Comment on lines +247 to +256
for k, v in _TO_PEFT_KEYS.items():
    full_mapping.update(
        {
            vv.replace(".weight", f".{k}.weight"): kk.replace(
                ".weight", f".{v}.weight"
            )
            for kk, vv in _FROM_HF.items()
            if vv is not None
        }
    )
kartikayk:

This block can use some comments explaining what's going on here


head_dim = dim // num_heads

def _permute_lora_matrix(t, n_heads):
kartikayk:

So these are permuted as well - nice find!

ebsmothers (author):

Only B matrices though 😃

BenjaminBossan left a comment:

Great work, this should hopefully be useful to many users. I don't have any critical comments, just a couple of smaller ones.

It might also be a good idea to add a test for each supported architecture, just to be sure that the re-mappings of the keys are the same for all of them.



if __name__ == "__main__":
    # test_permute()

BenjaminBossan:

Is this test still required? As is, it only prints something at the end, no asserts.

ebsmothers (author):

Yeah actually this whole test is probably going to be scrapped. As you point out in another comment, it is pretty expensive to run. Really I think I am gonna adopt a version of your other suggestion and will just add a unit test to confirm that key conversions etc are done correctly.

with torch.no_grad():
    peft_out = peft_model(inputs)
    tt_out = tt_model(inputs)
print(f"Maximum difference: {torch.max(torch.abs(peft_out.logits - tt_out))}")

BenjaminBossan:

Should this be changed to an assert?

ebsmothers (author):

Yeah, will probably wind up removing this anyways. But in the forthcoming unit test I will use asserts
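
For reference, a hedged sketch of the kind of assertion that could replace the print above (peft_out and tt_out as in the snippet being discussed; the tolerances are illustrative, not the values the test will actually use):

import torch

torch.testing.assert_close(peft_out.logits, tt_out, rtol=1e-4, atol=2e-4)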


# Initialize Llama2 and load merged checkpoint
# (just testing that forward lines up)
tt_model = llama2_7b()

BenjaminBossan:

You could first create the outputs of the PEFT model, then delete, and then load the tune model to save memory. But probably not necessary as you probably have much beefier CI runners than us :)

ebsmothers (author):

Yeah, this is prob the better way. I was taking advantage of the extra memory for debugging: defining attributes on each model class for intermediate values, then comparing each step along the way. But this is gonna get scrapped anyways

"lora_b": "lora_B",
}

_TO_PEFT_TARGET_MODULES = {

BenjaminBossan:

I wonder if a single mapping can be maintained for all supported architectures. I haven't actually tried if it works, but just checked the key names for the supported models and Phi3 seems to use gate_up_proj (https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/tree/main?show_file_info=model-00001-of-00002.safetensors). So I wonder if one mapping per architecture is required (with this being the default mapping).

ebsmothers (author):

Good point. I've actually only tested for Llama2 so far, I think you're right that we'll need a separate mapping at least for Phi-3. We do have something here for the full checkpoint mapping already, will just need to adapt it for PEFT purposes.

ebsmothers (author):

Update: there are other challenges with loading fine-tuned phi-3 checkpoints into PEFT from torchtune related to fused vs non-fused QKV. Namely, if someone fine-tunes in torchtune only on e.g. Q and K, they will not really be able to continue fine-tuning in PEFT in the way they would expect. In that case we can of course zero out the weights of the V chunk of the PEFT QKV LoRA matrix to get something that is in spirit correct, but (a) the user would probably expect only Q and K to remain trainable, which would not be the case, and (b) the learned LoRA weights from the torchtune finetune based on Q and K only may put any subsequent PEFT fine-tune using V as well in a suboptimal initial parameter space.

We could enforce up front that phi-3 LoRA is all-or-nothing on Q, K, and V for PEFT integration but I feel that's a bit messy. So for the time being I am opting to just raise a warning on checkpoint save that phi-3 adapter weights cannot be loaded into PEFT, and save just the usual torchtune adapter weights in that case.

BenjaminBossan:

I see, yes I think giving a warning is the best solution in this situation.

The only issue I have with the warning is that it is only given during checkpointing. I would be afraid that a user starts an expensive training run only to find out the next day that the checkpoint was not saved as expected. Would it be possible to give the warning already at model initialization time?

@ebsmothers marked this pull request as ready for review on May 18, 2024 00:22
@ebsmothers changed the title from "[WIP] Save adapter config and remapped adapter weights for loading into PEFT" to "Save adapter config and remapped adapter weights for loading into PEFT" on May 18, 2024
@@ -477,6 +477,19 @@ def save_checkpoint(
# to be sent to the checkpointer and ultimately written to file
if self._is_rank_zero:

# if training is in-progress, checkpoint the optimizer state and recipe state
kartikayk:

Dumb q: Why move this up?

ebsmothers (author):

Mainly to align with the ordering in the single-device recipe, but realistically it doesn't matter too much. Maybe I'll just leave as-is to not mix in extra complexity with the current set of changes. Also I realized somehow my changes to save PEFT config did not get pushed to this recipe, will update that now.

kartikayk (Contributor) left a comment:

Generally looks good - thanks for persevering through all of the unknowns! A bunch of future-facing questions which would be good to think about.

def tune_to_peft_adapter_config(
    adapter_config: Dict[str, Any],
):
    expected_keys = ["target_modules", "r", "lora_alpha"]
kartikayk:

Does this need to be a constant like _TO_PEFT_TARGET_MODULES?

return adapter_config


def tune_to_peft_adapter_weights(
kartikayk:

@BenjaminBossan I'm curious what your thoughts are on this function. It seems like this (along with other similar conversion functions) are fairly brittle and susceptible to breakages resulting from changes in PEFT/Transformers. A couple of questions:

  • How brittle is this in practice? Do we expect changes in these keys or permutation logic often?
  • Are the unit tests enough to capture this? Do we need to add similar tests on the PEFT side as well?

BenjaminBossan:

  • How brittle is this in practice? Do we expect changes in these keys or permutation logic often?

No, there shouldn't be any frequent changes in this regard, as that would result in incompatibilities of old HF checkpoints as well. Generally, when something changes in the modeling code, we try to preserve the format of the checkpoint and re-map while loading the state_dict. I won't say it never happened in the past but I think it would generally be considered a bug and we'd fix it if notified.

  • Are the unit tests enough to capture this? Do we need to add similar tests on the PEFT side as well?

This probably wouldn't hurt. I could imagine that if you push a converted checkpoint to the HF Hub (ideally a small model), we can add a test to check if we can load it successfully.
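
Such a PEFT-side test might look roughly like this (a sketch only; the repo IDs are hypothetical placeholders for a tiny base model and a small torchtune-converted adapter pushed to the Hub):

from transformers import AutoModelForCausalLM
from peft import PeftModel

def test_load_torchtune_lora_adapter():
    base = AutoModelForCausalLM.from_pretrained("org/tiny-random-llama")
    peft_model = PeftModel.from_pretrained(base, "org/tiny-torchtune-lora-adapter")
    # LoRA weights should have been injected into the wrapped model.
    assert any("lora_" in name for name, _ in peft_model.named_parameters())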

@@ -482,12 +484,57 @@ def save_checkpoint(
f"{os.path.getsize(output_path) / 1000**3:.2f} GB "
f"saved to {output_path}"
)
# Phi-3-mini uses fused QKV in PEFT, this will not work as expected
# if only a proper subset of Q, K, V have been fine-tuned
if self._model_type == ModelType.PHI3_MINI:
kartikayk:

Definitely getting very unwieldy. As we discussed, there is probably an opportunity to refactor the checkpointers with a focus on adapters

@ebsmothers merged commit dc9c697 into pytorch:main on May 21, 2024
29 checks passed
@ebsmothers deleted the peft-integration branch on May 21, 2024 20:21