Add trainer integration test for llava to ensure accelerate autocasting works correctly #30489
Conversation
@@ -0,0 +1,120 @@
import unittest
This file should be the only new change; that will be clearer once the other PR is merged. Usually I would set this one up to merge into the first, but this feels a bit different since the changes are on a fork.
output = model(**inputs)
expected_slice = torch.tensor(
    [[-3.5664, -3.5625, -0.4309], [-5.8242, -5.6914, -1.3242], [-5.4805, -5.9375, 1.1465]],
I saw this pattern in the llava-next tests. These values came from running this model before training. I'm not quite sure this is correct, so please let me know if there is something else we want to test. Perhaps instead we want the trained model before applying the downcasting change?
Also note that these do not yet pass allclose. I wanted to go ahead and open this PR to generate discussion around what the right thing to test is.
@muellerzr How are these allclose values found? Are they literally what the model outputs now, recorded so we can tell if the logits change in the future? Or are they based on some original implementation?
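For context, a minimal sketch of how these expected-slice checks generally work in the llava / llava-next tests, as I understand them: the slice is captured once from a known-good run of the reference checkpoint, and later runs are compared against it within a tolerance, so it acts as a regression guard rather than coming from a specification. The `model` / `inputs` objects and the tolerance below are assumptions, not the exact test code:

```python
import torch

# Reference slice recorded from a known-good run (values copied from the diff above).
expected_slice = torch.tensor(
    [[-3.5664, -3.5625, -0.4309], [-5.8242, -5.6914, -1.3242], [-5.4805, -5.9375, 1.1465]]
)

with torch.no_grad():
    output = model(**inputs)  # `model` and `inputs` come from earlier in the test (assumed)

# Compare a small corner of the logits against the recorded values; the tolerance
# (assumed here) absorbs minor numerical differences across devices.
assert torch.allclose(output.logits[0, :3, :3].float().cpu(), expected_slice, atol=1e-3)
```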
"llava-hf/bakLlava-v1-hf", quantization_config=bits_and_bytes_config | ||
) | ||
adapter_name = "lora_default" | ||
peft_config = LoraConfig( |
I am somewhat unclear on where the slow tests run, but I assume there is some limit on memory, so I tried to give a reasonable LoRA config for this test. If you think there is a simpler or more idiomatic way to do this test, please let me know.
I think this is fine for what we're doing!
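For anyone skimming later, a rough sketch of what this setup amounts to: a 4-bit quantized LLaVA checkpoint wrapped with a deliberately small LoRA adapter so the slow test stays within memory limits. The specific LoRA hyperparameters and target modules here are illustrative assumptions, not necessarily the ones used in the test:

```python
import torch
from peft import LoraConfig, get_peft_model
from transformers import BitsAndBytesConfig, LlavaForConditionalGeneration

# 4-bit quantization so the 7B-scale model fits on a single test GPU.
bits_and_bytes_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)
model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/bakLlava-v1-hf", quantization_config=bits_and_bytes_config
)

# A small LoRA adapter keeps trainable parameters (and memory) low.
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # illustrative choice of modules
)
model = get_peft_model(model, peft_config, adapter_name="lora_default")
```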
@slow
@require_bitsandbytes
def test_model_trainer_integration_test(self):
    def image_prompt_generator():
Not entirely sure this is the simplest or most idiomatic way to create this test dataset, so please let me know if there is a better way.
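A minimal sketch of one way such a generator-backed dataset can be built with `datasets.Dataset.from_generator`; the synthetic images, prompt template, and column names are assumptions for illustration, not the test's actual data:

```python
from datasets import Dataset, Features, Image, Value
from PIL import Image as PILImage


def image_prompt_generator():
    # A few tiny synthetic image/prompt pairs keep the slow test cheap.
    for color in ("red", "green", "blue"):
        image = PILImage.new("RGB", (64, 64), color=color)
        yield {"image": image, "prompt": f"USER: <image>\nWhat color is this? ASSISTANT: {color}"}


features = Features({"image": Image(), "prompt": Value("string")})
train_dataset = Dataset.from_generator(image_prompt_generator, features=features)
```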
cc @muellerzr for first review
Updated this now that the previous PR is merged! I am very concerned about OOMs being an issue here, however. I think there are some open questions around:
Gentle ping @muellerzr, or possibly @SunMarc?
Another ping @muellerzr @SunMarc
Overall this is a very good test. Since autocasting is done "automagically" via accelerate, this tests it exactly how you should!
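For readers unfamiliar with the mechanism: with `fp16=True` the Trainer hands mixed precision off to accelerate, which wraps the forward pass in its autocast context, so running a couple of training steps is enough to exercise the half-precision code path the earlier fix touched. A rough sketch of that shape, assuming the `model`, `train_dataset`, and a `collate_fn` from the surrounding test (all illustrative, not the exact test body):

```python
import tempfile

from transformers import Trainer, TrainingArguments

with tempfile.TemporaryDirectory() as tmp_dir:
    training_args = TrainingArguments(
        output_dir=tmp_dir,
        fp16=True,  # the flag that routes mixed precision through accelerate's autocast
        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,
        max_steps=2,  # a couple of steps is enough to hit the half-precision forward path
        remove_unused_columns=False,
    )
    trainer = Trainer(
        model=model,                  # the quantized + LoRA-wrapped model (assumed)
        args=training_args,
        train_dataset=train_dataset,  # the generator-backed dataset sketched earlier (assumed)
        data_collator=collate_fn,     # hypothetical collator turning image/prompt pairs into tensors
    )
    trainer.train()
```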
Looks great - thanks for adding this!
Once the conflicts are resolved I think we're good to go
cc @zucchini-nlp for reference
@@ -0,0 +1,116 @@
import gc
missing copyright header
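For reference, a sketch of the transformers-style Apache 2.0 header that test files typically start with (the year and exact wording should match whatever the repo currently uses; assumed here):

```python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```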
def tearDown(self):
    gc.collect()
    torch.cuda.empty_cache()
My (somewhat sparse) knowledge of empty_cache is that it's not meant to be used manually and can cause unintended / surprising behaviour: #31372 (comment)
While normally I’d agree, if it’s in the tests it should be fine. That was in reference to it being in the actual Trainer code
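To make the intent of that teardown explicit, a commented version of the same pattern (class name is hypothetical):

```python
import gc
import unittest

import torch


class LlavaTrainerIntegrationTest(unittest.TestCase):  # hypothetical class name
    def tearDown(self):
        # Drop Python-side references (model, trainer, datasets) first...
        gc.collect()
        # ...then release cached CUDA blocks so the next slow test starts from a
        # clean GPU. This is acceptable in tests; the linked concern was about
        # calling empty_cache() inside the Trainer code itself.
        torch.cuda.empty_cache()
```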
Interesting finding, thanks for fixing and adding tests!
Out of curiosity: IIUC, training llava and llava-next with the HF Trainer and the fp16/bf16 flags failed with dtype errors before the fix was done. I am wondering about other llava-based models: this should mean that vipllava, llava-next-video and video-llava fail with the same error, because they all follow a similar architecture. But for llava-next-video I had a script with fp16 running without errors, so I would love to know your opinion on this.
Interesting, it's possible that this error has since been fixed for those other models. It's been a couple of months since I have looked at this, but I could definitely check whether there are any other nuances that let those models work with a half-precision flag while llava and llava-next do not. I would have to do a bit more research on this.
What does this PR do?
This PR adds a new integration test to ensure the accelerate autocasting is working correctly. It came out of a discussion found here, and that PR should probably be merged first (or this one merged into it).
Fixes # (issue)
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  - Not an issue, but another PR that should probably be merged first: Fix llava half precision and autocast issues #29721 (comment)
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.