
Commit 00e159b

Merge branch 'main' into training_args_extension
2 parents: bb2d303 + 340500b

File tree

11 files changed: +183 −36 lines

src/transformers/cache_utils.py

Lines changed: 1 addition & 1 deletion

@@ -585,7 +585,7 @@ def _flatten_dynamic_cache_for_fx(cache, spec):
     return torch.utils._pytree.tree_flatten(dictionary)[0]
 
 
-if is_torch_greater_or_equal("2.2"):
+if is_torch_greater_or_equal("2.3"):
    torch.utils._pytree.register_pytree_node(
        DynamicCache,
        _flatten_dynamic_cache,

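The bumped version guard matters because DynamicCache is registered as a pytree node only when torch supports it. A minimal, self-contained sketch of what such a registration enables; the Point class is a hypothetical stand-in for DynamicCache, not code from this commit.

import torch.utils._pytree as pytree

class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y

# Register flatten/unflatten handlers, mirroring the DynamicCache registration.
pytree.register_pytree_node(
    Point,
    lambda p: ([p.x, p.y], None),            # flatten: (children, context)
    lambda children, ctx: Point(*children),  # unflatten: rebuild from children
)

leaves, spec = pytree.tree_flatten(Point(1, 2))  # leaves == [1, 2]
restored = pytree.tree_unflatten(leaves, spec)   # a Point again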
src/transformers/modeling_rope_utils.py

Lines changed: 1 addition & 1 deletion

@@ -425,7 +425,7 @@ def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set]
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
     required_keys = {"rope_type", "factor"}
-    optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
+    optional_keys = {"attention_factor", "beta_fast", "beta_slow", "original_max_position_embeddings"}
     received_keys = set(rope_scaling.keys())
     _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)

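With original_max_position_embeddings added to the optional set, a YaRN rope_scaling dict carrying that key now validates instead of being flagged as unexpected. An illustrative config dict (the values are made up, not taken from the commit):

rope_scaling = {
    "rope_type": "yarn",  # required
    "factor": 4.0,        # required
    "attention_factor": 1.0,
    "beta_fast": 32.0,
    "beta_slow": 1.0,
    "original_max_position_embeddings": 8192,  # newly accepted optional key
}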
src/transformers/models/chameleon/modeling_chameleon.py

Lines changed: 4 additions & 4 deletions

@@ -1289,13 +1289,10 @@ def forward(
                 "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
             )
 
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
-
         if pixel_values is not None:
             image_tokens = self.get_image_tokens(pixel_values)
             special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
-            if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_tokens.numel():
+            if not is_torchdynamo_compiling() and input_ids[special_image_mask].numel() != image_tokens.numel():
                 n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum()
                 n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
                 raise ValueError(
@@ -1304,6 +1301,9 @@ def forward(
             image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
             input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
 
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
         # torch.jit.trace() doesn't support cache objects in the output
         if use_cache and past_key_values is None and not torch.jit.is_tracing():
             past_key_values = DynamicCache()

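The reorder moves the embed_tokens lookup after the image tokens are spliced into input_ids, so a single embedding call covers both text and image positions, and the mismatch check now counts ids rather than embeddings. A toy illustration of the masked_scatter step, with made-up token ids rather than real model vocabulary:

import torch

input_ids = torch.tensor([[5, 9, 9, 7]])  # 9 = image placeholder token id
special_image_mask = input_ids == 9
image_tokens = torch.tensor([[11, 12]])   # ids produced by the image tokenizer
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
# input_ids is now [[5, 11, 12, 7]]; embedding it afterwards covers every position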
src/transformers/models/gemma3/configuration_gemma3.py

Lines changed: 6 additions & 9 deletions

@@ -19,7 +19,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
+from typing import Any, Dict, Optional, Union
 
 from ...configuration_utils import PretrainedConfig
 from ...modeling_rope_utils import rope_config_validation
@@ -292,8 +292,8 @@ class Gemma3Config(PretrainedConfig):
 
     def __init__(
         self,
-        text_config: Optional[Gemma3TextConfig] = None,
-        vision_config: Optional[SiglipVisionConfig] = None,
+        text_config: Optional[Union[Gemma3TextConfig, Dict[str, Any]]] = None,
+        vision_config: Optional[Union[SiglipVisionConfig, Dict[str, Any]]] = None,
         mm_tokens_per_image: int = 256,
         boi_token_index: int = 255_999,
         eoi_token_index: int = 256_000,
@@ -303,18 +303,15 @@ def __init__(
     ):
         if text_config is None:
             text_config = Gemma3TextConfig()
-            logger.info("text_config is None, using default Gemma3TextConfig vision config.")
+            logger.info("text_config is None, using default Gemma3TextConfig text config.")
         elif isinstance(text_config, dict):
             text_config = Gemma3TextConfig(**text_config)
 
         if isinstance(vision_config, dict):
             vision_config = SiglipVisionConfig(**vision_config)
-        else:
+        elif vision_config is None:
             vision_config = SiglipVisionConfig()
-            logger.info(
-                "vision_config is None or incompatible with Gemma3VisionConfig intialization. Gemma3 will be limited "
-                "to text tasks."
-            )
+            logger.info("vision_config is None, using default SiglipVisionConfig vision config.")
 
         self.text_config = text_config
         self.vision_config = vision_config

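The widened annotations match what the body already does: sub-configs may be passed as plain dicts and are coerced to config objects. A hedged usage sketch; the kwargs shown are assumed Gemma3TextConfig fields, not values from the commit:

from transformers import Gemma3Config

config = Gemma3Config(
    text_config={"hidden_size": 256, "num_hidden_layers": 2},  # dict -> Gemma3TextConfig
    vision_config=None,  # falls back to a default SiglipVisionConfig
)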
src/transformers/models/gemma3/modular_gemma3.py

Lines changed: 6 additions & 9 deletions

@@ -16,7 +16,7 @@
 import copy
 from collections.abc import Callable
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -266,8 +266,8 @@ class Gemma3Config(PretrainedConfig):
 
     def __init__(
         self,
-        text_config: Optional[Gemma3TextConfig] = None,
-        vision_config: Optional[SiglipVisionConfig] = None,
+        text_config: Optional[Union[Gemma3TextConfig, Dict[str, Any]]] = None,
+        vision_config: Optional[Union[SiglipVisionConfig, Dict[str, Any]]] = None,
         mm_tokens_per_image: int = 256,
         boi_token_index: int = 255_999,
         eoi_token_index: int = 256_000,
@@ -277,18 +277,15 @@ def __init__(
     ):
         if text_config is None:
             text_config = Gemma3TextConfig()
-            logger.info("text_config is None, using default Gemma3TextConfig vision config.")
+            logger.info("text_config is None, using default Gemma3TextConfig text config.")
         elif isinstance(text_config, dict):
             text_config = Gemma3TextConfig(**text_config)
 
         if isinstance(vision_config, dict):
             vision_config = SiglipVisionConfig(**vision_config)
-        else:
+        elif vision_config is None:
             vision_config = SiglipVisionConfig()
-            logger.info(
-                "vision_config is None or incompatible with Gemma3VisionConfig intialization. Gemma3 will be limited "
-                "to text tasks."
-            )
+            logger.info("vision_config is None, using default SiglipVisionConfig vision config.")
 
         self.text_config = text_config
         self.vision_config = vision_config

src/transformers/testing_utils.py

Lines changed: 2 additions & 2 deletions

@@ -2367,8 +2367,8 @@ def tee(line, sink, pipe, label=""):
     # XXX: the timeout doesn't seem to make any difference here
     await asyncio.wait(
         [
-            _read_stream(p.stdout, lambda l: tee(l, out, sys.stdout, label="stdout:")),
-            _read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:")),
+            asyncio.create_task(_read_stream(p.stdout, lambda l: tee(l, out, sys.stdout, label="stdout:"))),
+            asyncio.create_task(_read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:"))),
         ],
         timeout=timeout,
     )

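This tracks a stdlib change: passing bare coroutines to asyncio.wait() was deprecated in Python 3.8 and is rejected as of 3.11, so they must be wrapped in tasks explicitly. A minimal self-contained sketch of the corrected pattern:

import asyncio

async def work(n):
    await asyncio.sleep(0)
    return n

async def main():
    # asyncio.wait() wants tasks (or futures), not bare coroutines
    done, pending = await asyncio.wait(
        [asyncio.create_task(work(1)), asyncio.create_task(work(2))],
        timeout=5,
    )
    return sorted(task.result() for task in done)

print(asyncio.run(main()))  # [1, 2]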
src/transformers/utils/quantization_config.py

Lines changed: 1 addition & 1 deletion

@@ -1360,7 +1360,7 @@ def to_diff_dict(self) -> Dict[str, Any]:
 
         # only serialize values that differ from the default config
         for key, value in config_dict.items():
-            if value != default_config_dict[key]:
+            if key not in default_config_dict or value != default_config_dict[key]:
                 serializable_config_dict[key] = value
 
         return serializable_config_dict

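The added `key not in default_config_dict` clause keeps to_diff_dict from raising a KeyError when an instance carries a key absent from the default config; such keys are now always serialized. A standalone sketch of the guarded comparison, with made-up fields:

default_config_dict = {"bits": 4, "group_size": 128}
config_dict = {"bits": 8, "group_size": 128, "new_field": True}

serializable_config_dict = {
    key: value
    for key, value in config_dict.items()
    if key not in default_config_dict or value != default_config_dict[key]
}
# -> {"bits": 8, "new_field": True}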
tests/generation/test_utils.py

Lines changed: 1 addition & 0 deletions

@@ -126,6 +126,7 @@
     "ayavision",
     "gemma3",
     "mistral3",
+    "chameleon",
 ]
 
 
tests/models/chameleon/test_modeling_chameleon.py

Lines changed: 153 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515
"""Testing suite for the PyTorch chameleon model."""
1616

17+
import copy
1718
import unittest
1819

1920
import requests
@@ -30,7 +31,7 @@
3031

3132
from ...generation.test_utils import GenerationTesterMixin
3233
from ...test_configuration_common import ConfigTester
33-
from ...test_modeling_common import ModelTesterMixin, ids_tensor
34+
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
3435
from ...test_pipeline_mixin import PipelineTesterMixin
3536

3637

@@ -52,12 +53,12 @@ def __init__(
5253
self,
5354
parent,
5455
batch_size=13,
55-
seq_length=7,
56+
seq_length=35,
5657
is_training=False,
5758
use_input_mask=True,
5859
use_labels=True,
5960
vocab_size=99,
60-
image_token_id=98,
61+
image_token_id=4,
6162
hidden_size=32,
6263
num_hidden_layers=2,
6364
num_attention_heads=2,
@@ -73,9 +74,9 @@ def __init__(
7374
num_labels=3,
7475
num_choices=4,
7576
pad_token_id=0,
76-
vq_num_embeds=12,
77-
vq_embed_dim=12,
78-
vq_channel_multiplier=[1, 2],
77+
vq_num_embeds=5,
78+
vq_embed_dim=5,
79+
vq_channel_multiplier=[1, 4],
7980
vq_img_token_start_id=10, # has to be less than vocab size when added with vq_num_embeds
8081
scope=None,
8182
):
@@ -138,7 +139,9 @@ def get_config(self):
138139
start = self.vq_img_token_start_id
139140
end = self.vq_img_token_start_id + self.vq_num_embeds
140141
for i in range(start, end):
141-
vocab_map[i] = f"IMGIMGBS{i}" # dummy str for each token, anything starting with IMGIMG
142+
image_token_infix = "".join(chr(ord("A") + int(c)) for c in str(i))
143+
# dummy str for each image token, anything starting with IMGIMG
144+
vocab_map[i] = f"IMGIMG{image_token_infix}Z"
142145

143146
return ChameleonConfig(
144147
vocab_size=self.vocab_size,
@@ -275,7 +278,6 @@ class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
275278
{
276279
"feature-extraction": ChameleonModel,
277280
"text-generation": ChameleonForConditionalGeneration,
278-
"image-text-to-text": ChameleonForConditionalGeneration,
279281
}
280282
if is_torch_available()
281283
else {}
@@ -330,6 +332,149 @@ def test_model_rope_scaling(self, scaling_type):
330332
def test_batching_equivalence(self):
331333
pass
332334

335+
@unittest.skip("Chameleon VQ model cannot be squishes more due to hardcoded layer params in model code")
336+
def test_model_is_small(self):
337+
pass
338+
339+
340+
class ChameleonVision2SeqModelTester(ChameleonModelTester):
341+
def __init__(self, parent, image_size=10, **kwargs):
342+
super().__init__(parent, **kwargs)
343+
self.image_size = image_size
344+
self.image_seq_length = 25
345+
346+
def prepare_config_and_inputs(self):
347+
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
348+
input_ids[input_ids == self.image_token_id] = self.pad_token_id
349+
input_ids[:, : self.image_seq_length] = self.image_token_id
350+
attention_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
351+
pixel_values = floats_tensor([self.batch_size, 3, self.image_size, self.image_size])
352+
353+
config = self.get_config()
354+
355+
return config, input_ids, attention_mask, pixel_values
356+
357+
def prepare_config_and_inputs_for_common(self):
358+
config_and_inputs = self.prepare_config_and_inputs()
359+
config, input_ids, attention_mask, pixel_values = config_and_inputs
360+
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": pixel_values}
361+
return config, inputs_dict
362+
363+
364+
@require_torch
365+
class ChameleonVision2SeqModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
366+
all_model_classes = (ChameleonModel, ChameleonForConditionalGeneration) if is_torch_available() else ()
367+
pipeline_model_mapping = (
368+
{
369+
"image-text-to-text": ChameleonForConditionalGeneration,
370+
}
371+
if is_torch_available()
372+
else {}
373+
)
374+
test_headmasking = False
375+
test_pruning = False
376+
fx_compatible = False
377+
378+
def setUp(self):
379+
self.model_tester = ChameleonVision2SeqModelTester(self)
380+
self.config_tester = ConfigTester(self, config_class=ChameleonConfig, hidden_size=37)
381+
382+
def test_config(self):
383+
self.config_tester.run_common_tests()
384+
385+
@unittest.skip("Chameleon forces some token ids to be -inf!")
386+
def test_batching_equivalence(self):
387+
pass
388+
389+
@unittest.skip("Chameleon cannot do offload because it uses `self.linear.weight` in forward")
390+
def test_cpu_offload(self):
391+
pass
392+
393+
@unittest.skip("Chameleon cannot do offload because it uses `self.linear.weight` in forward")
394+
def test_disk_offload_bin(self):
395+
pass
396+
397+
@unittest.skip("Chameleon cannot do offload because it uses `self.linear.weight` in forward")
398+
def test_disk_offload_safetensors(self):
399+
pass
400+
401+
@unittest.skip("Chameleon VQ model cannot be squishes more due to hardcoded layer params in model code")
402+
def test_model_is_small(self):
403+
pass
404+
405+
def test_mismatching_num_image_tokens(self):
406+
"""
407+
Tests that VLMs through an error with explicit message saying what is wrong
408+
when number of images don't match number of image tokens in the text.
409+
Also we need to test multi-image cases when one prompr has multiple image tokens.
410+
"""
411+
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
412+
for model_class in self.all_model_classes:
413+
model = model_class(config).to(torch_device)
414+
curr_input_dict = copy.deepcopy(input_dict) # the below tests modify dict in-place
415+
_ = model(**curr_input_dict) # successful forward with no modifications
416+
417+
# remove one image but leave the image token in text
418+
curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-1:, ...]
419+
with self.assertRaises(ValueError):
420+
_ = model(**curr_input_dict)
421+
422+
# simulate multi-image case by concatenating inputs where each has exactly one image/image-token
423+
input_ids = curr_input_dict["input_ids"][:1]
424+
pixel_values = curr_input_dict["pixel_values"][:1]
425+
input_ids = torch.cat([input_ids, input_ids], dim=0)
426+
427+
# one image and two image tokens raise an error
428+
with self.assertRaises(ValueError):
429+
_ = model(input_ids=input_ids, pixel_values=pixel_values)
430+
431+
# two images and two image tokens don't raise an error
432+
pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
433+
_ = model(input_ids=input_ids, pixel_values=pixel_values)
434+
435+
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
436+
def test_inputs_embeds(self):
437+
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
438+
439+
for model_class in self.all_model_classes:
440+
model = model_class(config)
441+
model.to(torch_device)
442+
model.eval()
443+
444+
inputs = self._prepare_for_class(inputs_dict, model_class)
445+
446+
input_ids = inputs["input_ids"]
447+
del inputs["input_ids"]
448+
del inputs["pixel_values"]
449+
450+
wte = model.get_input_embeddings()
451+
inputs["inputs_embeds"] = wte(input_ids)
452+
453+
with torch.no_grad():
454+
model(**inputs)
455+
456+
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
457+
# while some other models require pixel_values to be present
458+
def test_inputs_embeds_matches_input_ids(self):
459+
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
460+
461+
for model_class in self.all_model_classes:
462+
model = model_class(config)
463+
model.to(torch_device)
464+
model.eval()
465+
466+
inputs = self._prepare_for_class(inputs_dict, model_class)
467+
input_ids = inputs["input_ids"]
468+
del inputs["input_ids"]
469+
del inputs["pixel_values"]
470+
471+
inputs_embeds = model.get_input_embeddings()(input_ids)
472+
473+
with torch.no_grad():
474+
out_ids = model(input_ids=input_ids, **inputs)[0]
475+
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
476+
torch.testing.assert_close(out_embeds, out_ids)
477+
333478

334479
@require_torch
335480
class ChameleonIntegrationTest(unittest.TestCase):

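One detail worth unpacking from the get_config change in the test file: each dummy image-token string encodes the token id with its digits mapped to letters (0 -> A, ..., 9 -> J) between the IMGIMG prefix and a Z suffix. A standalone rendition of that naming rule:

def image_token_name(i: int) -> str:
    # map each decimal digit of the id to a letter: 0 -> A, 1 -> B, ..., 9 -> J
    infix = "".join(chr(ord("A") + int(c)) for c in str(i))
    return f"IMGIMG{infix}Z"

assert image_token_name(12) == "IMGIMGBCZ"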