
[BUG] Adapters-initialized tensors sometimes conflict with model dtypes, no way to set automatically #766

Open
killershrimp opened this issue Dec 10, 2024 · 0 comments · May be fixed by #767
Labels
bug Something isn't working

Comments


killershrimp commented Dec 10, 2024

Environment info

  • adapters version: 1.0.1
  • Platform: Linux
  • Python version: 3.11.10
  • PyTorch version (GPU?): 2.5.1 (1x RTX A4000)

Information

Model I am using (Bert, XLNet ...): Llama 3.2 1B (English)

Adapter setup I am using (if any): LoRA/ReFT

The task I am working on is: finetuning on boolq

To reproduce

(note: the same error occurs if the LoRA config is swapped for a ReFT config; see the sketch after the script below)

import adapters
import datasets
import torch
import transformers
import trl

base_prompt = """
### Context:
{}

### Question:
{}

### Response:
{}
"""

device="cuda"
model_max_length = 2048
model_name_or_path = "meta-llama/Llama-3.2-1B"

def model_init(trial):
    # Load the base model in bfloat16, then attach and activate a LoRA adapter.
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_name_or_path, torch_dtype=torch.bfloat16, device_map=device
    )
    adapters.init(model)
    lora_config = adapters.LoRAConfig()
    model.add_adapter("lora", lora_config)
    model.train_adapter("lora")
    return model

tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name_or_path, padding_side="left", device_map=device
)
tokenizer.pad_token = tokenizer.eos_token
dataset = datasets.load_dataset("google/boolq")

training_args = transformers.TrainingArguments(
    output_dir="test_trainer",
    per_device_train_batch_size=1,
)

def format_prompts(e):
    return [
        base_prompt.format(p, q, a)
        for p, q, a in zip(e["passage"], e["question"], e["answer"])
    ]

def tokenize_prompt(e):
    return tokenizer(format_prompts(e), truncation=True, max_length=model_max_length)

small_train_dataset = dataset["train"].select(range(1000))
small_train_dataset = small_train_dataset.map(
    tokenize_prompt,
    batched=True,
    remove_columns=["passage", "question", "answer"])
tokenizer.pad_token = tokenizer.eos_token

response_template = "### Response:\n"
collator = trl.DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

trainer = adapters.AdapterTrainer(
    # model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    # eval_dataset=small_eval_dataset,
    model_init=model_init,
    data_collator=collator,
)

trainer.train()
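
For reference, the ReFT variant mentioned above only swaps the adapter config inside model_init. A sketch, assuming the LoReftConfig preset for ReFT (the exact class name is my assumption, not taken from the repro):

def model_init_reft(trial):
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_name_or_path, torch_dtype=torch.bfloat16, device_map=device
    )
    adapters.init(model)
    reft_config = adapters.LoReftConfig()  # assumed ReFT preset, swapped in for LoRAConfig()
    model.add_adapter("reft", reft_config)
    model.train_adapter("reft")
    return model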

Stack trace (the adapter-initialized matrices have a tensor dtype incompatible with the model's)
Traceback (most recent call last):
  File "/mnt/data/alex/lora_on_reft/adapters_issue.py", line 70, in <module>
    trainer.train()
  File "/mnt/data/alex/miniconda3/envs/adapters_baseline/lib/python3.11/site-packages/transformers/trainer.py", line 2052, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/mnt/data/alex/miniconda3/envs/adapters_baseline/lib/python3.11/site-packages/transformers/trainer.py", line 2388, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/data/alex/miniconda3/envs/adapters_baseline/lib/python3.11/site-packages/transformers/trainer.py", line 3485, in training_step
    loss = self.compute_loss(model, inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/data/alex/miniconda3/envs/adapters_baseline/lib/python3.11/site-packages/transformers/trainer.py", line 3532, in compute_loss
    outputs = model(**inputs)
              ^^^^^^^^^^^^^^^
  File "/mnt/data/alex/miniconda3/envs/adapters_baseline/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/data/alex/miniconda3/envs/adapters_baseline/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/data/alex/miniconda3/envs/adapters_baseline/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1189, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/mnt/data/alex/miniconda3/envs/adapters_baseline/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/data/alex/miniconda3/envs/adapters_baseline/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/data/alex/miniconda3/envs/adapters_baseline/lib/python3.11/site-packages/adapters/context.py", line 116, in wrapper_func
    results = f(self, *args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/data/alex/miniconda3/envs/adapters_baseline/lib/python3.11/site-packages/adapters/model_mixin.py", line 1470, in forward
    return super().forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/data/alex/miniconda3/envs/adapters_baseline/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1000, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/mnt/data/alex/miniconda3/envs/adapters_baseline/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/data/alex/miniconda3/envs/adapters_baseline/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
    return inner()
           ^^^^^^^
  File "/mnt/data/alex/miniconda3/envs/adapters_baseline/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in inner
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/data/alex/miniconda3/envs/adapters_baseline/lib/python3.11/site-packages/adapters/models/llama/modeling_llama.py", line 437, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
                                                          ^^^^^^^^^^^^^^^
  File "/mnt/data/alex/miniconda3/envs/adapters_baseline/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/data/alex/miniconda3/envs/adapters_baseline/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/data/alex/miniconda3/envs/adapters_baseline/lib/python3.11/site-packages/adapters/models/llama/modeling_llama.py", line 322, in forward
    query_states = self.q_proj(hidden_states)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/data/alex/miniconda3/envs/adapters_baseline/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/data/alex/miniconda3/envs/adapters_baseline/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/data/alex/miniconda3/envs/adapters_baseline/lib/python3.11/site-packages/adapters/methods/lora.py", line 521, in forward
    state = self.compose(adapter_setup, state)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/data/alex/miniconda3/envs/adapters_baseline/lib/python3.11/site-packages/adapters/methods/adapter_layer_base.py", line 520, in compose
    state = composition_func(adapter_setup, state, lvl=0)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/data/alex/miniconda3/envs/adapters_baseline/lib/python3.11/site-packages/adapters/methods/adapter_layer_base.py", line 354, in compose_stack
    state = self.compose_single(adapter_stack_layer, state, lvl=lvl + 1)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/data/alex/miniconda3/envs/adapters_baseline/lib/python3.11/site-packages/adapters/methods/lora.py", line 503, in compose_single
    hidden_states, gate = lora(state.hidden_states, state.layer_input)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/data/alex/miniconda3/envs/adapters_baseline/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/data/alex/miniconda3/envs/adapters_baseline/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/data/alex/miniconda3/envs/adapters_baseline/lib/python3.11/site-packages/adapters/methods/lora.py", line 95, in forward
    hidden_states = self.lora_dropout(hidden_states) @ torch.t(self.lora_A) @ torch.t(self.lora_B)
                    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::BFloat16 != float
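
The mismatch appears to be between the bfloat16 hidden states of the base model and the freshly initialized LoRA matrices, which are created in float32. A quick, purely illustrative check (run inside model_init right after add_adapter; matching on the "lora" substring is an assumption about the parameter naming):

for name, param in model.named_parameters():
    if "lora" in name:
        # the lora_A / lora_B weights show up as torch.float32,
        # while the base model weights are torch.bfloat16
        print(name, param.dtype)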

Expected behavior

Adapters should automatically adjust the adapter weights' dtype to match the model's dtype, or offer an option to configure it manually in the adapter config; neither currently seems to exist.
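
As an interim workaround (just a sketch using plain PyTorch, not an API the adapters library provides), casting the whole model back to bfloat16 after the adapter is added should make the new LoRA matrices pick up the model dtype:

def model_init(trial):
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_name_or_path, torch_dtype=torch.bfloat16, device_map=device
    )
    adapters.init(model)
    model.add_adapter("lora", adapters.LoRAConfig())
    model.train_adapter("lora")
    # Workaround: Module.to(dtype) casts all floating-point parameters,
    # including the freshly added float32 LoRA matrices, to bfloat16.
    model.to(torch.bfloat16)
    return model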

@killershrimp killershrimp added the bug Something isn't working label Dec 10, 2024
@killershrimp killershrimp linked a pull request Dec 10, 2024 that will close this issue