-
Notifications
You must be signed in to change notification settings - Fork 25.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Image + text + audio uniform processors #30511
base: main
Are you sure you want to change the base?
Changes from 51 commits
42ecf48
ccb2147
f999e0c
8fb3a6b
a90c766
49cb6cc
3ac1c7e
7a819fd
9cc38b7
ff6a950
7db64a0
41674d9
618a687
f39cdc1
9a6f97d
380f82f
75f15d3
3df5faa
5ad0694
69e5a2d
68c2f40
e1e4084
bfa81e5
4b557b0
b7fc377
270bb9e
94a1b75
6603bf0
004c961
c2e49f5
79958b5
a36f524
3238dd3
3afde22
eb99e29
c6afd63
3b824b5
d8c2a6e
78433a1
0cd3d66
d9c51c1
43ba3bd
dcfa8db
7048fee
25ba7ba
36ba3cc
05fea5d
e05b14b
46a1fe8
13a1909
7bf3615
a285284
1fe9eb5
931f68c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,8 +17,11 @@ | |
""" | ||
|
||
|
||
from ...processing_utils import ProcessorMixin | ||
from ...tokenization_utils_base import BatchEncoding | ||
from typing import List, Union | ||
|
||
from ...image_utils import ImageInput | ||
from ...processing_utils import ProcessingKwargs, ProcessorMixin | ||
from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput | ||
|
||
|
||
class AlignProcessor(ProcessorMixin): | ||
|
@@ -33,6 +36,7 @@ class AlignProcessor(ProcessorMixin): | |
The image processor is a required input. | ||
tokenizer ([`BertTokenizer`, `BertTokenizerFast`]): | ||
The tokenizer is a required input. | ||
""" | ||
|
||
attributes = ["image_processor", "tokenizer"] | ||
|
@@ -41,12 +45,59 @@ class AlignProcessor(ProcessorMixin): | |
|
||
def __init__(self, image_processor, tokenizer): | ||
super().__init__(image_processor, tokenizer) | ||
|
||
def __call__(self, text=None, images=None, padding="max_length", max_length=64, return_tensors=None, **kwargs): | ||
self.processing_kwargs: ProcessingKwargs = { | ||
"common_kwargs": {"return_tensors": None}, | ||
"text_kwargs": { | ||
"text_pair": None, | ||
"text_target": None, | ||
"text_pair_target": None, | ||
"add_special_tokens": True, | ||
"padding": "max_length", | ||
"truncation": True, | ||
"max_length": 64, | ||
"stride": 0, | ||
"is_split_into_words": False, | ||
"pad_to_multiple_of": None, | ||
"return_token_type_ids": None, | ||
"return_attention_mask": None, | ||
"return_overflowing_tokens": False, | ||
"return_special_tokens_mask": False, | ||
"return_offsets_mapping": False, | ||
"return_length": False, | ||
"verbose": True, | ||
}, | ||
"images_kwargs": { | ||
"do_crop_margin": None, | ||
"do_resize": None, | ||
"size": None, | ||
"resample": None, | ||
"do_thumbnail": None, | ||
"do_align_long_axis": None, | ||
"do_pad": None, | ||
"do_rescale": None, | ||
"rescale_factor": None, | ||
"do_normalize": None, | ||
"image_mean": None, | ||
"image_std": None, | ||
"data_format": "channels_first", | ||
"input_data_format": None, | ||
}, | ||
"audio_kwargs": {}, | ||
"videos_kwargs": {}, | ||
} | ||
|
||
def __call__( | ||
self, | ||
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, | ||
images: ImageInput = None, | ||
audio=None, | ||
videos=None, | ||
**kwargs, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My main question is how are type hints and signature scan done for this? Is there a way to say that the kwargs include text kwargs and image kwargs + benefit from having the doc as well? |
||
) -> BatchEncoding: | ||
""" | ||
Main method to prepare text(s) and image(s) to be fed as input to the model. This method forwards the `text` | ||
and `kwargs` arguments to BertTokenizerFast's [`~BertTokenizerFast.__call__`] if `text` is not `None` to encode | ||
the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to | ||
arguments to BertTokenizerFast's [`~BertTokenizerFast.__call__`] if `text` is not `None` to encode | ||
the text. To prepare the image(s), this method forwards the `images` arguments to | ||
EfficientNetImageProcessor's [`~EfficientNetImageProcessor.__call__`] if `images` is not `None`. Please refer | ||
to the doctsring of the above two methods for more information. | ||
|
@@ -58,19 +109,12 @@ def __call__(self, text=None, images=None, padding="max_length", max_length=64, | |
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): | ||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch | ||
tensor. Both channels-first and channels-last formats are supported. | ||
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `max_length`): | ||
Activates and controls padding for tokenization of input text. Choose between [`True` or `'longest'`, | ||
`'max_length'`, `False` or `'do_not_pad'`] | ||
max_length (`int`, *optional*, defaults to `max_length`): | ||
Maximum padding value to use to pad the input text during tokenization. | ||
return_tensors (`str` or [`~utils.TensorType`], *optional*): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. that's a good point, in general anything that is |
||
If set, will return tensors of a particular framework. Acceptable values are: | ||
- `'tf'`: Return TensorFlow `tf.constant` objects. | ||
- `'pt'`: Return PyTorch `torch.Tensor` objects. | ||
- `'np'`: Return NumPy `np.ndarray` objects. | ||
- `'jax'`: Return JAX `jnp.ndarray` objects. | ||
- `'tf'`: Return TensorFlow `tf.constant` objects. | ||
- `'pt'`: Return PyTorch `torch.Tensor` objects. | ||
- `'np'`: Return NumPy `np.ndarray` objects. | ||
- `'jax'`: Return JAX `jnp.ndarray` objects. | ||
Returns: | ||
[`BatchEncoding`]: A [`BatchEncoding`] with the following fields: | ||
|
@@ -82,15 +126,28 @@ def __call__(self, text=None, images=None, padding="max_length", max_length=64, | |
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. | ||
""" | ||
if text is None and images is None: | ||
raise ValueError("You have to specify either text or images. Both cannot be none.") | ||
raise ValueError("You must specify either text or images.") | ||
|
||
if text is not None: | ||
encoding = self.tokenizer( | ||
text, padding=padding, max_length=max_length, return_tensors=return_tensors, **kwargs | ||
) | ||
text_kwargs = { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🤩 |
||
**self.processing_kwargs["text_kwargs"], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So this is basically creating a dictionary with the default, and is updates with the kwargs.
|
||
**self.processing_kwargs["common_kwargs"], | ||
**kwargs, | ||
} | ||
encoding = self.tokenizer(text, **text_kwargs) | ||
|
||
if images is not None: | ||
image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs) | ||
images_kwargs = { | ||
**self.processing_kwargs["images_kwargs"], | ||
**self.processing_kwargs["common_kwargs"], | ||
**kwargs, | ||
} | ||
image_features = self.image_processor(images, **images_kwargs) | ||
|
||
# BC for explicit return_tensors | ||
common_kwargs = {**self.processing_kwargs["common_kwargs"], **kwargs} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we prepare common_kwargs here and not use them? Is it for consistency in the processor patterns? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. will correct this part in |
||
if "return_tensors" in common_kwargs: | ||
return_tensors = common_kwargs.pop("return_tensors", None) | ||
|
||
if text is not None and images is not None: | ||
encoding["pixel_values"] = image_features.pixel_values | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The only thing I'd say about the current design is the default behaviour of a processor isn't configurable. As users who use multimodal models just want to use the processor, and processors might share the default processing classes they bundle together e.g. CLIP, we might want to be able to override this in the init, such that I can save and load processors with custom behaviour. For example:
The correct way, might be to say that the e.g. tokenizer behaviour should be configured instead. At the moment, I think if I change the tokenizer, and build the processor, it will use the processor's max_length default, rather than the tokenizer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a good point - I tried the second example with
model_max_length
and it seems to work fine: