Fix VLM train_on_response_only #60

Open: wants to merge 10 commits into main

Conversation

@coding-famer (Contributor)

Make train_on_response_only compatible with VLM. Fix unslothai/unsloth#1396

For VLMs, tokenization occurs within the data_collator rather than in SFTTrainer._prepare_dataset. To keep the train_on_response_only interface consistent, I modified the trainer's data_collator to handle label construction.
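The general idea, as a minimal sketch (illustrative names only, not the exact code in this PR; ResponseOnlyCollatorSketch and find_response_spans are made-up placeholders):

class ResponseOnlyCollatorSketch:
    def __init__(self, collator, find_response_spans):
        self.collator = collator                          # original (vision) data collator
        self.find_response_spans = find_response_spans    # hypothetical helper returning one (start, end) span per example

    def __call__(self, examples):
        batch = self.collator(examples)                   # for VLMs, tokenization happens here
        labels = batch["input_ids"].clone()               # assumes the wrapped collator returns PyTorch tensors
        labels[:] = -100                                  # mask everything by default
        for i, (start, end) in enumerate(self.find_response_spans(batch["input_ids"])):
            labels[i, start:end] = batch["input_ids"][i, start:end]  # keep only response tokens
        batch["labels"] = labels
        return batch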

@patel-zeel commented Mar 3, 2025

Hi @coding-famer! This PR has most of the changes and will be completed with the following:

  • Since the vision tokenizer also has an image_processor, _find_common_token_ids should take tokenizer.tokenizer instead of tokenizer.
  • Change batch["labels"] = _train_on_responses_only(batch) to batch["labels"] = _train_on_responses_only(batch)["labels"], so we don't create a nested dictionary under the key "labels".

This colab will run end-to-end once the above changes are done.

@coding-famer (Contributor Author)

Hi @coding-famer! This PR has most of the changes and will be completed with the following:

  • Since the vision tokenizer also has an image_processor, _find_common_token_ids should take tokenizer.tokenizer instead of tokenizer.
  • Change batch["labels"] = _train_on_responses_only(batch) to batch["labels"] = _train_on_responses_only(batch)["labels"], so we don't create a nested dictionary under the key "labels".

This colab will run end-to-end once the above changes are done.

Hi @patel-zeel, thanks for your suggestions. Updated to fix point 2. For point 1, I believe the vision tokenizer will behave the same as the text tokenizer when image is None.

@patel-zeel commented Mar 5, 2025

@coding-famer Could you please help me understand why we would pass image=None for this case? In normal cases, the following line would need a check for VLM and need to access the text tokenizer explicitly with tokenizer.tokenizer.

tokenizer = trainer.processing_class if hasattr(trainer, "processing_class") else trainer.tokenizer

There might be bugs in the way I am using this feature, so I'd appreciate it if you could point to the required changes in my colab as well.
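A minimal sketch of what that explicit access could look like (illustrative only; it assumes the VLM processor exposes its text tokenizer as a .tokenizer attribute):

tokenizer = trainer.processing_class if hasattr(trainer, "processing_class") else trainer.tokenizer
# VLM processors bundle an image_processor with a text tokenizer;
# fall back to the object itself for text-only models.
text_tokenizer = tokenizer.tokenizer if hasattr(tokenizer, "tokenizer") else tokenizer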

@coding-famer (Contributor Author)

@coding-famer Could you please help me understand why we would pass image=None for this case? In normal cases, the following line would need a check for VLM and need to access the text tokenizer explicitly with tokenizer.tokenizer.

tokenizer = trainer.processing_class if hasattr(trainer, "processing_class") else trainer.tokenizer

There might be bugs in the way I am using this feature, so I'd appreciate it if you could point to the required changes in my colab as well.

@patel-zeel Sorry, I meant that processor(ids) here will produce the same output as processor.tokenizer(ids), so this modification is unnecessary.
And I think the colab notebook works correctly? Please let me know what error occurs if we don't set tokenizer = tokenizer.tokenizer.

@patel-zeel commented Mar 5, 2025

@coding-famer Oh, I got your point and also found the issue. Actually, the first argument to processor is images, so my colab is running into an error. We need to make sure the text part goes to the text argument, but could using a keyword argument break for some tokenizers?

Edit: I verified that the Llama, Pixtral, Qwen and LLaVA processors all accept images as their first argument (perhaps they kept it that way on purpose so that no one mistakenly passes only text to the processor and wonders why the model is so poor!).
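As a small illustration of the pitfall (a sketch; the checkpoint name is only an example of a VLM processor whose first positional argument is images):

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")  # example checkpoint

# processor("Hello!")  # risky: the string lands in the `images` slot and may error
batch = processor(text="Hello!", return_tensors="pt")  # explicit keyword is unambiguous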

@coding-famer (Contributor Author)

@coding-famer Oh, I got your point and also found the issue. Actually, the first argument to processor is images, so my colab is running into an error. We need to make sure the text part goes to the text argument, but could using a keyword argument break for some tokenizers?

Oh yes, you are right. I incorrectly set tokenizer=processor.tokenizer when I initialized the trainer (though that works around this particular issue, let's make it handle general cases). What do you think is the best way to solve it? Set the tokenizer here for VLMs?

tokenizer = trainer.processing_class if hasattr(trainer, "processing_class") else trainer.tokenizer

@patel-zeel commented Mar 5, 2025

@coding-famer IMO, we can continue with your logic of using has_tokenized to detect a VLM and add an if-else condition that calls processor(text=...) explicitly when a VLM is detected. We can also record what we have discovered so far in code comments to save a few hours for future devs!

If we want to make VLM detection stricter, I used the following logic in my earlier (incorrect) PR:

IS_VISION_MODEL = False
# This approach should support all kinds of models irrespective
# of the depth of the vision layer in the parent module.
for module_name, _ in trainer.model.named_modules():
    if any(name in module_name for name in ["visual", "vision_tower", "vision_model"]):
        IS_VISION_MODEL = True
        break
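
The same heuristic could be packaged as a small helper (a sketch; is_vision_model is not an existing Unsloth function):

def is_vision_model(model) -> bool:
    # Same module-name heuristic as above, wrapped for reuse.
    vision_markers = ("visual", "vision_tower", "vision_model")
    return any(
        any(marker in module_name for marker in vision_markers)
        for module_name, _ in model.named_modules()
    )

# e.g. is_vision_model(trainer.model) -> True for Llava/Qwen2-VL style models, False for text-only LLMs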

@coding-famer (Contributor Author)

@patel-zeel After thinking about this, I hope this PR can handle the general case, where inputs haven't been tokenized when the trainer is initialized, instead of just handling VLMs. So I think my has_tokenized flag is good.

@patel-zeel

Yes, I agree, @coding-famer. After resolving the merge conflicts with @danielhanchen's new edits, this PR will be ready for review.

@patel-zeel

@coding-famer A small edit I thought of: changing the tokenizer in place might create an issue if the full tokenizer is needed later. How about something like the following to make it more explicit and readable?

# Get the text tokenizer from the combined tokenizer
if not has_tokenized:
    text_tokenizer = tokenizer.tokenizer
else:
    text_tokenizer = tokenizer
pass

# Get most common tokens since tokenizers can tokenize stuff differently!
Q_must, Q_left, Q_right = _find_common_token_ids(instruction_part, text_tokenizer)
A_must, A_left, A_right = _find_common_token_ids(response_part,    text_tokenizer)

@patel-zeel

@coding-famer A few more thoughts:

  1. I'm not sure what the following code does, but shall we keep it outside of the if-else condition so it applies to VLMs too?

     # Check if all labels randomly got masked to nothing - maybe wrong chat template?
     from .training_utils import fix_zero_training_loss
     fix_zero_training_loss(None, tokenizer, trainer.train_dataset)

  2. I was thinking about what would happen if someone calls train_on_response_only multiple times (say, executing the same cell of a Jupyter notebook multiple times). Will it wrap one more time? I was thinking of something like this to avoid it:
class CustomDataCollator:
    def __init__(self, collator, modifier_fn):
        self.collator = collator
        self.modifier_fn = modifier_fn

    def __call__(self, examples):
        batch = self.collator(examples)
        batch["labels"] = self.modifier_fn(batch)["labels"]
        return batch

from unsloth_zoo.vision_utils import UnslothVisionDataCollator
import warnings

if not hasattr(trainer, "data_collator"):
    warnings.warn(
        "Vision model detected, but data_collator is not set. Setting it to UnslothVisionDataCollator."
    )
    trainer.data_collator = UnslothVisionDataCollator(model=trainer.model, processor=tokenizer)
if not isinstance(trainer.data_collator, CustomDataCollator):
    trainer.data_collator = CustomDataCollator(trainer.data_collator, _train_on_responses_only)

pass

@coding-famer (Contributor Author)

@patel-zeel Good points! For point 1, I'm not sure either; let me have a look at the logic of that function. I'll add point 2.

Inline review thread on these lines of the diff:

    if not isinstance(trainer.data_collator, DataCollatorForSeq2Seq):
        trainer.data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer)
    from unsloth_zoo.vision_utils import UnslothVisionDataCollator
    if hasattr(trainer, "data_collator"):

@coding-famer (Contributor Author)

Do we need to ensure the trainer has a data_collator?

@patel-zeel

If people follow the Unsloth notebooks, it's unlikely to be missed, but I thought it wouldn't hurt existing functionality. Also, it attaches the collator with a warning. In case an advanced user does not want this behavior, they can modify it after applying train_on_response_only.

This was mainly inspired by the recent DataCollatorForSeq2Seq change a few lines above.

@coding-famer (Contributor Author) commented Mar 10, 2025

How about leaving this for @danielhanchen to review?
To be clear, I didn't add your

if not hasattr(trainer, "data_collator"):
    warnings.warn(
        "Vision model detected, but data_collator is not set. Setting it to UnslothVisionDataCollator."
    )
    trainer.data_collator = UnslothVisionDataCollator(model=trainer.model, processor=tokenizer)

part. Also, the recent DataCollatorForSeq2Seq change doesn't ensure the trainer has a data_collator either.
It would be good to add this, but the old logic doesn't do it (or maybe it's better to do it outside train_on_response_only, since it's not really related).

@patel-zeel

Yes, I agree. It's not directly related and not having it is not hurting anything.

@patel-zeel

@coding-famer This looks mostly good, with one caveat: if someone changes instruction_part and response_part and reapplies the function, it'll silently execute and have no effect, because we will simply see that trainer.data_collator already has a collator and keep the same object. One possible solution is:

if hasattr(trainer, "data_collator"):
    # If `UnslothResponseOnlyCollator` is found, extract internal collator to avoid double wrapping
    if hasattr(trainer.data_collator, "collator"):
        trainer.data_collator = trainer.data_collator.collator
    trainer.data_collator = UnslothResponseOnlyCollator(trainer.data_collator, _train_on_responses_only)
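
A tiny self-contained check of that behavior, with toy collator and modifier functions (purely illustrative):

class UnslothResponseOnlyCollator:
    def __init__(self, collator, modifier_fn):
        self.collator = collator
        self.modifier_fn = modifier_fn
    def __call__(self, examples):
        batch = self.collator(examples)
        batch["labels"] = self.modifier_fn(batch)["labels"]
        return batch

base_collator = lambda examples: {"input_ids": examples, "labels": examples}
mask_v1 = lambda batch: {"labels": "masked with parts v1"}
mask_v2 = lambda batch: {"labels": "masked with parts v2"}

collator = UnslothResponseOnlyCollator(base_collator, mask_v1)
# Reapply with different instruction/response parts: unwrap first, then rewrap.
if hasattr(collator, "collator"):
    collator = collator.collator
collator = UnslothResponseOnlyCollator(collator, mask_v2)
assert collator.collator is base_collator and collator.modifier_fn is mask_v2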

@patel-zeel commented Mar 10, 2025

Amazing @coding-famer!

Hi @danielhanchen, this PR is ready for review. This colab is for testing this feature. This is one of the issues given in the Unsloth puzzles colab.

@mikeknapp

@coding-famer @patel-zeel I believe we can get text-only and image examples to work correctly if we do something like this:

class UnslothResponseOnlyCollator:
    def __init__(self, collator, modifier_fn):
        self.collator = collator
        self.modifier_fn = modifier_fn

    def __call__(self, examples):
        image_examples = []
        text_only_examples = []
        for example in examples:
            msgs = example.get("messages", [])
            if any(msg.get("type") == "image" for msg in msgs):
                image_examples.append(example)
            else:
                text_only_examples.append(example)
        batch = {}
        if len(image_examples) > 0:
            image_batch = self.collator(image_examples)
            image_batch["labels"] = self.modifier_fn(image_batch)["labels"]
            batch = image_batch
        if len(text_only_examples) > 0:
            text_batch = self.collator(text_only_examples)
            text_batch["labels"] = self.modifier_fn(text_batch)["labels"]
            batch = {**batch, **text_batch}
        return batch

This probably just needs some massaging to deal with the case where people have a 2 column dataset, with the second column being an image.

@coding-famer (Contributor Author)

For me, I think handling a mixture of text-only and text-image examples should be done in a new data collator, while this PR is trying to handle the train_on_response_only case (imagine someone wants to train on a mixture of text and text-image data but doesn't want to train on responses only). You can open a new PR and I'm willing to contribute to it, too.

@danielhanchen (Contributor)

Oh my wait - apologies I totally missed this

@danielhanchen (Contributor)

I think I actually did this accidentally in parallel during the Gemma 3 release: https://github.com/unslothai/unsloth-zoo/blob/main/unsloth_zoo/vision_utils.py#L259

@danielhanchen (Contributor)

I can accept the PR and provide the bounty for the parts which are not in the current code base - the extra checks, names, etc.

Successfully merging this pull request may close this issue: Train on completions only by fixing the collator inquiry.