
Add trainer integration test for llava to ensure accelerate autocasting works correctly #30489


Open · wants to merge 5 commits into base: main

Conversation

frasermince
Contributor

What does this PR do?

This PR adds a new integration test to ensure that accelerate autocasting works correctly. It came out of a discussion found here; that PR should probably be merged first (or this one merged into it).

Fixes # (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@@ -0,0 +1,120 @@
import unittest
Contributor Author

This file should be the only new change; that will be clearer once the other PR is merged. Usually I would set it up to merge this one into the first, but this feels a bit different since the changes are on a fork.

@frasermince frasermince force-pushed the frasermince/trainer-integration-test branch from b723f1f to 908ff93 Compare April 25, 2024 17:32

output = model(**inputs)
expected_slice = torch.tensor(
[[-3.5664, -3.5625, -0.4309], [-5.8242, -5.6914, -1.3242], [-5.4805, -5.9375, 1.1465]],
Contributor Author
@frasermince frasermince Apr 25, 2024

I saw this pattern in the llava-next tests. These values came from running this model before training. I'm not quite sure this is correct, so please let me know if there is something else we want to test. Perhaps instead we want the trained model before applying the downcasting change?

Contributor Author

Also note these do not yet pass allclose. I wanted to go ahead and open this PR to generate discussion around what the right thing to test is.

Contributor Author

@muellerzr How are these allclose values found? Is it literally what the model is outputting now, to detect whether the logits change in the future? Or is it based on some original implementation?
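For reference: in the transformers test suite, such expected slices are typically recorded from a trusted run of the model and act as regression anchors. The tolerance rule behind `torch.allclose` is `|a - b| <= atol + rtol * |b|`. A minimal pure-Python sketch of that check (illustrative only; the real tests call `torch.allclose` on tensors):

```python
# Illustrative sketch of the closeness check used when comparing a logits
# slice against hard-coded expected values. The real tests use torch.allclose;
# the tolerance rule is the same: |a - b| <= atol + rtol * |b|.

def allclose(actual, expected, rtol=1e-5, atol=1e-8):
    """Element-wise closeness check over nested lists of floats."""
    flat_a = [x for row in actual for x in row]
    flat_e = [x for row in expected for x in row]
    if len(flat_a) != len(flat_e):
        return False
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(flat_a, flat_e))

# Expected slice taken from the diff above.
expected_slice = [
    [-3.5664, -3.5625, -0.4309],
    [-5.8242, -5.6914, -1.3242],
    [-5.4805, -5.9375, 1.1465],
]

# A matching slice passes; a drifted one fails.
print(allclose(expected_slice, expected_slice))  # True
drifted = [[v + 0.1 for v in row] for row in expected_slice]
print(allclose(expected_slice, drifted))         # False
```

If a code change shifts the logits beyond the tolerance, the test fails, which is exactly the regression signal wanted here.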

"llava-hf/bakLlava-v1-hf", quantization_config=bits_and_bytes_config
)
adapter_name = "lora_default"
peft_config = LoraConfig(
Contributor Author

I am somewhat unclear on where the slow tests run, but I assume there is some limit on memory, so I tried to pick a reasonable LoRA config for this test. If you think there is a simpler or more idiomatic way to do this test, please let me know.
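For context, a memory-conscious LoRA setup for a test like this might look as follows. This is a sketch assuming peft's `LoraConfig`; the rank, alpha, and target modules shown are illustrative choices, not the values in this PR:

```python
from peft import LoraConfig

# Illustrative, memory-conscious LoRA configuration for a slow CI test.
# Rank/alpha/target_modules here are assumptions, not the PR's actual values.
peft_config = LoraConfig(
    r=8,                                  # small rank keeps the adapter tiny
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # attention projections only
    task_type="CAUSAL_LM",
)
```

Restricting `target_modules` to a couple of attention projections keeps the number of trainable parameters, and therefore optimizer state, small, which matters on memory-limited CI runners.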

Contributor

I think this is fine for what we're doing!

@frasermince frasermince force-pushed the frasermince/trainer-integration-test branch from 908ff93 to 1cd13e0 Compare April 25, 2024 17:35
@slow
@require_bitsandbytes
def test_model_trainer_integration_test(self):
def image_prompt_generator():
Contributor Author

Not entirely sure this is the simplest or most idiomatic way to create this test dataset, so please let me know if there is a better way.
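One common approach is to feed a small generator of prompt/image records to `datasets.Dataset.from_generator`. A minimal pure-Python sketch of the generator shape (the field names and prompt text here are hypothetical, not the PR's actual test data):

```python
# Hypothetical sketch of a generator-backed test dataset: a handful of
# prompt/image records, of the kind one might pass to
# datasets.Dataset.from_generator in a real test.
def image_prompt_generator(num_examples=4):
    for i in range(num_examples):
        yield {
            "prompt": f"USER: <image>\nWhat is in this picture? ASSISTANT: example {i}",
            "image_id": i,  # stand-in for an actual PIL image in a real test
        }

examples = list(image_prompt_generator())
print(len(examples))               # 4
print(sorted(examples[0].keys()))  # ['image_id', 'prompt']
```

Keeping the dataset tiny (a handful of examples) bounds both runtime and memory for a slow CI test while still exercising the full training path.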

@frasermince frasermince force-pushed the frasermince/trainer-integration-test branch 4 times, most recently from 7ce15e3 to 9440dd6 Compare April 25, 2024 18:13
@amyeroberts
Collaborator

cc @muellerzr for first review

@frasermince frasermince force-pushed the frasermince/trainer-integration-test branch 2 times, most recently from e9e3feb to bc31529 Compare May 2, 2024 22:07
@frasermince
Contributor Author

Updated this now that the previous PR is merged! However, I am very concerned that OOMs will be an issue here. I think there are some open questions around:

  1. How we ensure models are compatible with the Trainer and accelerate
  2. How we test training a model in CI, given how memory-intensive it can be

@frasermince frasermince force-pushed the frasermince/trainer-integration-test branch from 8f35687 to 3c7e7e1 Compare May 5, 2024 19:10
@huggingface huggingface deleted a comment from github-actions bot May 30, 2024
@amyeroberts
Collaborator

Gentle ping @muellerzr, or possibly @SunMarc?

@huggingface huggingface deleted a comment from github-actions bot Jun 24, 2024
@huggingface huggingface deleted a comment from github-actions bot Jul 19, 2024
@amyeroberts
Collaborator

Another ping @muellerzr @SunMarc

Contributor
@muellerzr muellerzr left a comment

Overall this is a very good test. Since autocasting is done "automagically" via accelerate, this tests it exactly how you should!

@muellerzr muellerzr requested a review from amyeroberts July 31, 2024 20:12
Collaborator
@amyeroberts amyeroberts left a comment

Looks great - thanks for adding this!

Once the conflicts are resolved I think we're good to go

cc @zucchini-nlp for reference

@@ -0,0 +1,116 @@
import gc
Collaborator

missing copyright header


def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
Collaborator

My (somewhat sparse) knowledge of empty_cache is that it's not meant to be used manually and can cause unintended / surprising behaviour: #31372 (comment)

Contributor

While normally I’d agree, if it’s in the tests it should be fine. That comment was in reference to it being in the actual Trainer code.

Member
@zucchini-nlp zucchini-nlp left a comment

Interesting finding, thanks for fixing and adding tests!

Out of curiosity: if I understand correctly, training llava and llava-next with the HF Trainer and fp16/bf16 flags failed with dtype errors before the fix was done. I am wondering about other llava-based models; this should mean that vipllava, llava-next-video and video-llava fail with the same error, because they all follow a similar architecture. But for llava-next-video I had a script with fp16 running without errors, so I would love to know your opinion on this.

@frasermince
Contributor Author

> Interesting finding, thanks for fixing and adding tests!
>
> Out of curiosity: if I understand correctly, training llava and llava-next with the HF Trainer and fp16/bf16 flags failed with dtype errors before the fix was done. I am wondering about other llava-based models; this should mean that vipllava, llava-next-video and video-llava fail with the same error, because they all follow a similar architecture. But for llava-next-video I had a script with fp16 running without errors, so I would love to know your opinion on this.

Interesting, it's possible that this error has since been fixed in those other models. It's been a couple of months since I looked at this, but I can definitely check whether there are any nuances that would let those models work with a half-precision flag while llava and llava-next do not. I would have to do a bit more research on this.

@frasermince frasermince force-pushed the frasermince/trainer-integration-test branch from 3c7e7e1 to 56bb12e Compare August 8, 2024 16:21
@frasermince frasermince force-pushed the frasermince/trainer-integration-test branch from 56bb12e to cdad616 Compare August 8, 2024 16:24
4 participants