Fix VLM train_on_response_only #60
Conversation
Hi @coding-famer! This PR has most of the changes and will be completed with the following: …
This colab will run end-to-end once the above changes are done.
Hi @patel-zeel, thanks for your suggestion. Updated to fix 2. For 1, I believe the vision tokenizer will behave the same way as the text tokenizer when the image is None.
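A quick way to sanity-check that claim (a sketch only; the model name is just an example and some processors may still require non-empty images) is to compare the combined processor's output against its underlying text tokenizer:

from transformers import AutoProcessor

# Example VLM; any processor exposing a .tokenizer attribute should behave similarly.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

text = "Describe the image."
proc_out = processor(text=[text], images=None)  # combined processor, no images
tok_out = processor.tokenizer([text])           # plain text tokenizer

# If the claim holds, the token ids match for text-only inputs.
print(proc_out["input_ids"] == tok_out["input_ids"])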
@coding-famer Could you please help me understand why we would pass image=None in this case? In normal cases, the following line would need a VLM check and would need to access the text tokenizer explicitly (unsloth-zoo/unsloth_zoo/dataset_utils.py, line 184 at 0389d45).
There might be bugs in the way I am using this feature, so I'd appreciate it if you could also point out the required changes in my colab.
@patel-zeel Sorry, I meant here.
@coding-famer Oh, I got your point and also found the issue. Actually, the first argument to processor is … Edit: I verified that all …
Oh yes, you are right. I set it incorrectly (unsloth-zoo/unsloth_zoo/dataset_utils.py, line 184 at 0389d45).
@coding-famer IMO, we can continue with your detection logic. If we want to make VLM detection stricter, I used the following logic in my earlier (wrong) PR:

IS_VISION_MODEL = False
# This approach should support all kinds of models irrespective
# of the depth of the vision layers in the parent module.
for module_name, _ in trainer.model.named_modules():
    if any(name in module_name for name in ["visual", "vision_tower", "vision_model"]):
        IS_VISION_MODEL = True
        break
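As a self-contained illustration of that check (a sketch with dummy modules, not the actual Unsloth model classes), the same scan over named_modules() can be wrapped in a helper:

import torch.nn as nn

def is_vision_model(model: nn.Module) -> bool:
    # Scan every submodule name; this works regardless of how deeply
    # the vision tower is nested inside the parent module.
    vision_markers = ("visual", "vision_tower", "vision_model")
    return any(
        any(marker in module_name for marker in vision_markers)
        for module_name, _ in model.named_modules()
    )

class DummyVLM(nn.Module):          # stand-in for a vision-language model
    def __init__(self):
        super().__init__()
        self.vision_tower = nn.Linear(4, 4)
        self.lm_head = nn.Linear(4, 4)

class DummyLM(nn.Module):           # stand-in for a text-only model
    def __init__(self):
        super().__init__()
        self.lm_head = nn.Linear(4, 4)

print(is_vision_model(DummyVLM()))  # True
print(is_vision_model(DummyLM()))   # False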
@patel-zeel After thinking about this, I hope this PR can handle the general case, where inputs haven't been tokenized when the trainer is initialized, instead of just handling VLMs. So I think my …
Yes, I agree, @coding-famer. After resolving the merge conflicts with @danielhanchen's new edits, this PR will be ready for review.
@coding-famer A small edit I thought about: changing that block to

# Get the text tokenizer from the combined tokenizer
if not has_tokenized:
    text_tokenizer = tokenizer.tokenizer
else:
    text_tokenizer = tokenizer
pass

# Get most common tokens since tokenizers can tokenize stuff differently!
Q_must, Q_left, Q_right = _find_common_token_ids(instruction_part, text_tokenizer)
A_must, A_left, A_right = _find_common_token_ids(response_part, text_tokenizer)
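The comment about tokenizers tokenizing things differently refers to the fact that the same marker string can map to different token ids depending on its surrounding context (leading whitespace, a preceding newline, and so on). A minimal sketch of the effect, using the gpt2 tokenizer purely as an example:

from transformers import AutoTokenizer

# Example tokenizer; the exact ids do not matter, only that they can differ.
tok = AutoTokenizer.from_pretrained("gpt2")

marker = "### Response:"
print(tok(marker, add_special_tokens=False).input_ids)
print(tok(" " + marker, add_special_tokens=False).input_ids)   # with a leading space
print(tok("\n" + marker, add_special_tokens=False).input_ids)  # after a newline
# For BPE tokenizers these sequences typically differ at the boundary, which is
# why _find_common_token_ids keeps only the "must" tokens shared by the variants
# (plus optional left/right context) rather than matching one fixed id sequence.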
@coding-famer A few more thoughts:
# Check if all labels randomly got masked to nothing - maybe wrong chat template?
from .training_utils import fix_zero_training_loss
fix_zero_training_loss(None, tokenizer, trainer.train_dataset)

class CustomDataCollator:
    def __init__(self, collator, modifier_fn):
        self.collator = collator
        self.modifier_fn = modifier_fn
    def __call__(self, examples):
        batch = self.collator(examples)
        batch["labels"] = self.modifier_fn(batch)["labels"]
        return batch

from unsloth_zoo.vision_utils import UnslothVisionDataCollator
import warnings
if not hasattr(trainer, "data_collator"):
    warnings.warn(
        "Vision model detected, but data_collator is not set. Setting it to UnslothVisionDataCollator."
    )
    trainer.data_collator = UnslothVisionDataCollator(model=trainer.model, processor=tokenizer)
if not isinstance(trainer.data_collator, CustomDataCollator):
    trainer.data_collator = CustomDataCollator(trainer.data_collator, _train_on_responses_only)
pass
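For context, the modifier_fn passed into the wrapper is the function that rewrites labels so that only the assistant response contributes to the loss. A much simplified stand-in (not the real _train_on_responses_only, which matches tokenized instruction/response markers) would look like this:

import torch

IGNORE_INDEX = -100  # ignored by PyTorch's cross-entropy loss

def mask_everything_before(batch, response_start_positions):
    # Simplified illustration: copy input_ids into labels and mask all
    # tokens before the response start (and all padding) with IGNORE_INDEX.
    labels = batch["input_ids"].clone()
    for row, start in enumerate(response_start_positions):
        labels[row, :start] = IGNORE_INDEX
    labels[batch["attention_mask"] == 0] = IGNORE_INDEX
    batch["labels"] = labels
    return batch

# Tiny fake batch: 2 sequences of 6 tokens, responses starting at positions 3 and 2.
batch = {
    "input_ids": torch.tensor([[5, 6, 7, 8, 9, 10], [5, 6, 8, 9, 0, 0]]),
    "attention_mask": torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]]),
}
print(mask_everything_before(batch, [3, 2])["labels"])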
@patel-zeel Good points! For 1, I'm not sure either. Let me have a look at the logic of that function. I'll add 2.
unsloth_zoo/dataset_utils.py (Outdated)

not isinstance(trainer.data_collator, DataCollatorForSeq2Seq):
    trainer.data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer)
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
if hasattr(trainer, "data_collator"):
Do we need to ensure the trainer has a data_collator?
If people follow the Unsloth notebooks, it's unlikely to be missing, but I thought it doesn't hurt existing functionality. Also, it attaches the collator with a warning. In case an advanced user does not want this behavior, they can modify it after applying train_on_response_only. This was mainly inspired by the recent DataCollatorForSeq2Seq change a few lines above.
How about leaving this for @danielhanchen to review?
I mean, I didn't add your

if not hasattr(trainer, "data_collator"):
    warnings.warn(
        "Vision model detected, but data_collator is not set. Setting it to UnslothVisionDataCollator."
    )
    trainer.data_collator = UnslothVisionDataCollator(model=trainer.model, processor=tokenizer)

part. And the recent DataCollatorForSeq2Seq change doesn't ensure the trainer has a data_collator either. It's good to add this, but the old logic doesn't do it (or maybe it is better to do it outside train_on_response_only since this is not quite related).
Yes, I agree. It's not directly related, and not having it doesn't hurt anything.
@coding-famer This looks mostly good with one caveat: if someone changes or re-applies this, the existing wrapper should be unwrapped first so the collator doesn't get double-wrapped:

if hasattr(trainer, "data_collator"):
    # If `UnslothResponseOnlyCollator` is found, extract the internal collator to avoid double wrapping
    if hasattr(trainer.data_collator, "collator"):
        trainer.data_collator = trainer.data_collator.collator
    trainer.data_collator = UnslothResponseOnlyCollator(trainer.data_collator, _train_on_responses_only)
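A self-contained sketch of why that unwrap step matters (dummy classes only, not the real Unsloth collators): without it, applying the wrapper twice would nest two label-rewriting layers.

class BaseCollator:
    def __call__(self, examples):
        return {"input_ids": examples, "labels": examples}

class UnslothResponseOnlyCollator:
    def __init__(self, collator, modifier_fn):
        self.collator = collator
        self.modifier_fn = modifier_fn
    def __call__(self, examples):
        batch = self.collator(examples)
        batch["labels"] = self.modifier_fn(batch)["labels"]
        return batch

def attach(collator, modifier_fn):
    # Unwrap an existing response-only wrapper before wrapping again,
    # so repeated calls stay idempotent.
    if hasattr(collator, "collator"):
        collator = collator.collator
    return UnslothResponseOnlyCollator(collator, modifier_fn)

identity = lambda batch: batch
c = attach(BaseCollator(), identity)
c = attach(c, identity)                   # second call re-uses the inner BaseCollator
print(isinstance(c.collator, BaseCollator))  # True - only one wrapper layer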
Amazing @coding-famer! Hi @danielhanchen, this PR is ready for review. This colab is for testing this feature. This is one of the issues given in the Unsloth puzzles colab.
@coding-famer @patel-zeel I believe we can get text_only and image examples to work correctly if we do something like this:

class UnslothResponseOnlyCollator:
    def __init__(self, collator, modifier_fn):
        self.collator = collator
        self.modifier_fn = modifier_fn
    def __call__(self, examples):
        image_examples = []
        text_only_examples = []
        for example in examples:
            msgs = example.get("messages", [])
            if any(msg.get("type") == "image" for msg in msgs):
                image_examples.append(example)
            else:
                text_only_examples.append(example)
        batch = {}
        if len(image_examples) > 0:
            image_batch = self.collator(image_examples)
            image_batch["labels"] = self.modifier_fn(image_batch)["labels"]
            batch = image_batch
        if len(text_only_examples) > 0:
            text_batch = self.collator(text_only_examples)
            text_batch["labels"] = self.modifier_fn(text_batch)["labels"]
            batch = {**batch, **text_batch}
        return batch

This probably just needs some massaging to deal with the case where people have a 2-column dataset, with the second column being an image.
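One possible way to do that massaging (a sketch; the "image" column name and the nested content format are assumptions about the dataset layout, not something this PR defines) is to make the image check look both at a dedicated image column and inside each message's content list:

def example_has_image(example):
    # Case 1: two-column dataset with a dedicated image column.
    if example.get("image") is not None:
        return True
    # Case 2: chat-style messages whose content is a list of
    # {"type": "text" | "image", ...} parts.
    for msg in example.get("messages", []):
        content = msg.get("content", [])
        if isinstance(content, list) and any(
            part.get("type") == "image" for part in content
        ):
            return True
    return False

# Tiny illustration with both layouts.
two_column = {"image": object(), "messages": [{"role": "user", "content": "hi"}]}
chat_style = {"messages": [{"role": "user",
                            "content": [{"type": "image"},
                                        {"type": "text", "text": "describe"}]}]}
text_only = {"messages": [{"role": "user", "content": [{"type": "text", "text": "hi"}]}]}
print(example_has_image(two_column), example_has_image(chat_style), example_has_image(text_only))
# True True False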
For me, I think the mixture of text-only and text-image cases should be handled in a new data collator; this issue is trying to handle the train_on_response_only case (imagine someone wants to train on a mixture of text and text-image but doesn't want to train on responses only). You can open a new PR and I'm willing to contribute to it, too.
Oh my, wait - apologies, I totally missed this.
I think I actually did this accidentally, in parallel, during the Gemma 3 release: https://github.com/unslothai/unsloth-zoo/blob/main/unsloth_zoo/vision_utils.py#L259
I can accept the PR and provide the bounty for the parts which are not in the current code base - the extra checks, names, etc.
Make train_on_response_only compatible with VLMs. Fixes unslothai/unsloth#1396.

In VLMs, tokenization occurs within the data_collator rather than in SFTTrainer._prepare_dataset. To ensure consistency in the train_on_response_only interface, I modified the trainer's data_collator to handle label construction.
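To make that description concrete, here is a toy end-to-end sketch (fake character-level "tokenization", not the real processor or Unsloth collator) of why the label masking has to live inside the data_collator when tokenization only happens there:

# Toy "tokenization" that only happens at collation time, mimicking VLM collators.
def fake_vlm_collate(examples):
    batch_ids = []
    for ex in examples:
        prompt_ids = [ord(c) for c in ex["prompt"]]
        response_ids = [ord(c) for c in ex["response"]]
        ex["response_start"] = len(prompt_ids)
        batch_ids.append(prompt_ids + response_ids)
    return {"input_ids": batch_ids,
            "labels": [ids[:] for ids in batch_ids],
            "response_starts": [ex["response_start"] for ex in examples]}

def response_only_labels(batch):
    # Mask everything before the response with -100 so the loss ignores it.
    for row, start in enumerate(batch["response_starts"]):
        batch["labels"][row][:start] = [-100] * start
    return batch

class ResponseOnlyCollator:
    def __init__(self, collator, modifier_fn):
        self.collator = collator
        self.modifier_fn = modifier_fn
    def __call__(self, examples):
        # The dataset is still raw text/images here; token ids only exist after
        # the inner collator runs, so labels can only be fixed up at this point.
        batch = self.collator(examples)
        return self.modifier_fn(batch)

collator = ResponseOnlyCollator(fake_vlm_collate, response_only_labels)
out = collator([{"prompt": "Q: hi ", "response": "A: hello"}])
print(out["labels"])  # the prompt positions are -100, the response keeps its ids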