
[BUG] Bottleneck adapters do not work with the ViT model when original_ln_after = False #764

Closed
julian-fong opened this issue Dec 2, 2024 · 3 comments
Labels: question (Further information is requested)

@julian-fong (Contributor)

It seems like the ViT model does not train well with bottleneck adapter configs when the parameter original_ln_after is set to False.

To reproduce

from datasets import load_dataset
import torch
num_classes = 100
train_dataset = load_dataset("uoft-cs/cifar100", split = "train").select(range(10000))
eval_dataset = load_dataset("uoft-cs/cifar100", split = "test").select(range(1000))

train_dataset.set_format("torch")
eval_dataset.set_format("torch")

model_name_or_path = 'google/vit-base-patch16-224-in21k'

from transformers import ViTImageProcessor
processor = ViTImageProcessor.from_pretrained(model_name_or_path)

def preprocess_image(example):
  image = processor(example["img"], return_tensors='pt')
  image["label"] = example["fine_label"]
  return image


train_dataset = train_dataset.map(preprocess_image)
eval_dataset = eval_dataset.map(preprocess_image)
# remove unnecessary columns
train_dataset = train_dataset.remove_columns(['img', 'fine_label', 'coarse_label'])
eval_dataset = eval_dataset.remove_columns(['img', 'fine_label', 'coarse_label'])


from typing import Any
from dataclasses import dataclass

@dataclass
class DataCollator:
  processor : Any
  def __call__(self, inputs):

    pixel_values = [input["pixel_values"].squeeze() for input in inputs]
    labels = [input["label"] for input in inputs]

    pixel_values = torch.stack(pixel_values)
    labels = torch.stack(labels)
    return {
        'pixel_values': pixel_values,
        'labels': labels,
    }

data_collator = DataCollator(processor = processor)


from adapters import ViTAdapterModel

model = ViTAdapterModel.from_pretrained(model_name_or_path)

from adapters import BnConfig
# bottleneck adapter config that triggers the issue: original_ln_after=False
config = BnConfig(mh_adapter=False, output_adapter=True, reduction_factor=96, non_linearity="relu", original_ln_after=False)
model.add_adapter("bottleneck_adapter", config=config)
model.add_image_classification_head("bottleneck_adapter", num_labels=num_classes)
model.train_adapter("bottleneck_adapter")  # freeze the base model, train only the adapter + head

import numpy as np
import evaluate
accuracy = evaluate.load("accuracy")

def compute_metrics(p):
  return accuracy.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

from adapters import AdapterTrainer
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir='./training_results',
    eval_strategy='epoch',
    learning_rate=10e-3,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=5,
    weight_decay=10e-4,
    report_to = "none",
    remove_unused_columns=False,
)

trainer = AdapterTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=processor,
    compute_metrics = compute_metrics
)

trainer.train()

[screenshot: training output]

julian-fong added the bug (Something isn't working) label Dec 2, 2024
@calpt (Member) commented Jan 2, 2025

I did some experimentation with varying configs, based on the script you provided:
[results table: eval accuracy for the tested config combinations]

In general, at least one of original_ln_before or original_ln_after should be set to True to make sure the original residual connection from pre-training is preserved.
When original_ln_after=False, training only seems to converge if residual_before_ln=False, so these two should be used in combination in the example provided.

Since training does work with certain combinations of config values, I don't believe there's a general issue in the implementation here; we just need to make sure to select the right combination of values. (Maybe we could add these notes as tips in a suitable place in the notebooks/docs.)
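For reference, a config combining both of these recommendations could look like the following (a sketch based on the notes above, reusing the values from the reproduction script; not an exhaustively verified setup):

from adapters import BnConfig

# Variant A: keep the original post-block LayerNorm (the default behaviour)
config = BnConfig(
    mh_adapter=False,
    output_adapter=True,
    reduction_factor=96,
    non_linearity="relu",
    original_ln_after=True,
)

# Variant B: if original_ln_after=False is wanted, preserve the pre-trained
# residual via original_ln_before=True and pair it with residual_before_ln=False,
# the combination that converged when original_ln_after was disabled
config = BnConfig(
    mh_adapter=False,
    output_adapter=True,
    reduction_factor=96,
    non_linearity="relu",
    original_ln_before=True,
    original_ln_after=False,
    residual_before_ln=False,
)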

(edit: replaced results image)

calpt added the question (Further information is requested) label and removed the bug (Something isn't working) label Jan 2, 2025
calpt closed this as completed Jan 2, 2025
@julian-fong (Contributor, Author)

I can provide some updates inside #775 since we are already planning to fix the config. Do you think this is a suitable place to put these notes?

@calpt (Member) commented Jan 4, 2025

Sounds good. Things specific to AdapterPlus fit well in the place you linked. The notes on layer norm apply to bottleneck adapters in general, so we could add them to this section of the docs in one of those blue "Note" boxes to make them easier to discover.
