
Commit 403d530

Auto feature extractor (huggingface#11097)
* AutoFeatureExtractor
* Init and first tests
* Tests
* Damn you gitignore
* Quality
* Defensive test for when not all backends are here
* Use pattern for Speech2Text models
1 parent 520198f commit 403d530

18 files changed (+309, -34 lines)
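The new AutoFeatureExtractor class mirrors AutoConfig and AutoTokenizer: given a model identifier, it instantiates the matching feature extractor class. A minimal usage sketch, assuming torchaudio is installed (the checkpoint name is illustrative, not taken from this commit):

from transformers import AutoFeatureExtractor

# Resolves to Speech2TextFeatureExtractor for a Speech2Text checkpoint
# (illustrative checkpoint; needs the speech backend, i.e. torchaudio).
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/s2t-small-librispeech-asr")
print(type(feature_extractor).__name__)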

.gitignore

Lines changed: 1 addition & 2 deletions
@@ -9,8 +9,7 @@ __pycache__/
 *.so

 # tests and logs
-tests/fixtures/*
-!tests/fixtures/sample_text_no_unicode.txt
+tests/fixtures/cached_*_text.txt
 logs/
 lightning_logs/
 lang_code_data/

docs/source/model_doc/auto.rst

Lines changed: 7 additions & 0 deletions
@@ -44,6 +44,13 @@ AutoTokenizer
     :members:


+AutoFeatureExtractor
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.AutoFeatureExtractor
+    :members:
+
+
 AutoModel
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

src/transformers/__init__.py

Lines changed: 34 additions & 8 deletions
@@ -45,6 +45,7 @@
     _BaseLazyModule,
     is_flax_available,
     is_sentencepiece_available,
+    is_speech_available,
     is_tf_available,
     is_tokenizers_available,
     is_torch_available,
@@ -102,6 +103,7 @@
         "is_py3nvml_available",
         "is_sentencepiece_available",
         "is_sklearn_available",
+        "is_speech_available",
         "is_tf_available",
         "is_tokenizers_available",
         "is_torch_available",
@@ -133,9 +135,11 @@
     "models.auto": [
         "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP",
         "CONFIG_MAPPING",
+        "FEATURE_EXTRACTOR_MAPPING",
         "MODEL_NAMES_MAPPING",
         "TOKENIZER_MAPPING",
         "AutoConfig",
+        "AutoFeatureExtractor",
         "AutoTokenizer",
     ],
     "models.bart": ["BartConfig", "BartTokenizer"],
@@ -202,7 +206,6 @@
     "models.speech_to_text": [
         "SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP",
         "Speech2TextConfig",
-        "Speech2TextFeatureExtractor",
     ],
     "models.squeezebert": ["SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "SqueezeBertConfig", "SqueezeBertTokenizer"],
     "models.t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config"],
@@ -288,7 +291,6 @@
     _import_structure["models.pegasus"].append("PegasusTokenizer")
     _import_structure["models.reformer"].append("ReformerTokenizer")
     _import_structure["models.speech_to_text"].append("Speech2TextTokenizer")
-    _import_structure["models.speech_to_text"].append("Speech2TextProcessor")
     _import_structure["models.t5"].append("T5Tokenizer")
     _import_structure["models.xlm_prophetnet"].append("XLMProphetNetTokenizer")
     _import_structure["models.xlm_roberta"].append("XLMRobertaTokenizer")
@@ -339,13 +341,28 @@

 if is_sentencepiece_available():
     _import_structure["convert_slow_tokenizer"] = ["SLOW_TO_FAST_CONVERTERS", "convert_slow_tokenizer"]
+
 else:
     from .utils import dummy_tokenizers_objects

     _import_structure["utils.dummy_tokenizers_objects"] = [
         name for name in dir(dummy_tokenizers_objects) if not name.startswith("_")
     ]

+# Speech-specific objects
+if is_speech_available():
+    _import_structure["models.speech_to_text"].append("Speech2TextFeatureExtractor")
+
+    if is_sentencepiece_available():
+        _import_structure["models.speech_to_text"].append("Speech2TextProcessor")
+
+else:
+    from .utils import dummy_speech_objects
+
+    _import_structure["utils.dummy_speech_objects"] = [
+        name for name in dir(dummy_speech_objects) if not name.startswith("_")
+    ]
+
 # Vision-specific objects
 if is_vision_available():
     _import_structure["image_utils"] = ["ImageFeatureExtractionMixin"]
@@ -1394,6 +1411,7 @@
         is_py3nvml_available,
         is_sentencepiece_available,
         is_sklearn_available,
+        is_speech_available,
         is_tf_available,
         is_tokenizers_available,
         is_torch_available,
@@ -1429,9 +1447,11 @@
     from .models.auto import (
         ALL_PRETRAINED_CONFIG_ARCHIVE_MAP,
         CONFIG_MAPPING,
+        FEATURE_EXTRACTOR_MAPPING,
         MODEL_NAMES_MAPPING,
         TOKENIZER_MAPPING,
         AutoConfig,
+        AutoFeatureExtractor,
         AutoTokenizer,
     )
     from .models.bart import BartConfig, BartTokenizer
@@ -1494,11 +1514,7 @@
     from .models.reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
     from .models.retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig, RetriBertTokenizer
     from .models.roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig, RobertaTokenizer
-    from .models.speech_to_text import (
-        SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        Speech2TextConfig,
-        Speech2TextFeatureExtractor,
-    )
+    from .models.speech_to_text import SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, Speech2TextConfig
     from .models.squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig, SqueezeBertTokenizer
     from .models.t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
     from .models.tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig, TapasTokenizer
@@ -1585,7 +1601,7 @@
     from .models.mt5 import MT5Tokenizer
     from .models.pegasus import PegasusTokenizer
     from .models.reformer import ReformerTokenizer
-    from .models.speech_to_text import Speech2TextProcessor, Speech2TextTokenizer
+    from .models.speech_to_text import Speech2TextTokenizer
     from .models.t5 import T5Tokenizer
     from .models.xlm_prophetnet import XLMProphetNetTokenizer
     from .models.xlm_roberta import XLMRobertaTokenizer
@@ -1627,9 +1643,19 @@

     if is_sentencepiece_available():
         from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, convert_slow_tokenizer
+
     else:
         from .utils.dummy_tokenizers_objects import *

+    if is_speech_available():
+        from .models.speech_to_text import Speech2TextFeatureExtractor
+
+        if is_sentencepiece_available():
+            from .models.speech_to_text import Speech2TextProcessor
+
+    else:
+        from .utils.dummy_speech_objects import *
+
     if is_vision_available():
         from .image_utils import ImageFeatureExtractionMixin
         from .models.vit import ViTFeatureExtractor
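Speech2TextFeatureExtractor and Speech2TextProcessor now sit behind the new is_speech_available() guard instead of being exported unconditionally. A sketch of how downstream code can use the flag (only is_speech_available and the Speech2Text classes come from this commit; the rest is illustrative):

from transformers import is_speech_available

if is_speech_available():
    # torchaudio is installed, so the real class is exported.
    from transformers import Speech2TextFeatureExtractor
    extractor = Speech2TextFeatureExtractor()
else:
    # Without torchaudio the name resolves to a dummy object that raises an
    # informative ImportError when instantiated, so we skip it here.
    extractor = None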

src/transformers/dependency_versions_table.py

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@
     "sphinx-copybutton": "sphinx-copybutton",
     "sphinx-markdown-tables": "sphinx-markdown-tables",
     "sphinx-rtd-theme": "sphinx-rtd-theme==0.4.3",
+    "sphinxext-opengraph": "sphinxext-opengraph==0.4.1",
     "sphinx": "sphinx==3.2.1",
     "starlette": "starlette",
     "tensorflow-cpu": "tensorflow-cpu>=2.3",

src/transformers/feature_extraction_utils.py

Lines changed: 9 additions & 0 deletions
@@ -325,6 +325,13 @@ def get_feature_extractor_dict(
         local_files_only = kwargs.pop("local_files_only", False)
         revision = kwargs.pop("revision", None)

+        from_pipeline = kwargs.pop("_from_pipeline", None)
+        from_auto_class = kwargs.pop("_from_auto", False)
+
+        user_agent = {"file_type": "feature extractor", "from_auto_class": from_auto_class}
+        if from_pipeline is not None:
+            user_agent["using_pipeline"] = from_pipeline
+
         if is_offline_mode() and not local_files_only:
             logger.info("Offline mode: forcing local_files_only=True")
             local_files_only = True
@@ -349,6 +356,7 @@ def get_feature_extractor_dict(
                 resume_download=resume_download,
                 local_files_only=local_files_only,
                 use_auth_token=use_auth_token,
+                user_agent=user_agent,
             )
             # Load feature_extractor dict
             with open(resolved_feature_extractor_file, "r", encoding="utf-8") as reader:
@@ -426,6 +434,7 @@ def to_dict(self) -> Dict[str, Any]:
             :obj:`Dict[str, Any]`: Dictionary of all the attributes that make up this feature extractor instance.
         """
         output = copy.deepcopy(self.__dict__)
+        output["feature_extractor_type"] = self.__class__.__name__

         return output
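Recording feature_extractor_type in to_dict() is what lets the Auto class recover the concrete class after a save/load round trip. A sketch of that intent (the local path is illustrative, and torchaudio is assumed to be installed):

from transformers import AutoFeatureExtractor, Speech2TextFeatureExtractor

extractor = Speech2TextFeatureExtractor()
extractor.save_pretrained("./s2t-extractor")  # saved JSON now carries "feature_extractor_type"
reloaded = AutoFeatureExtractor.from_pretrained("./s2t-extractor")
assert type(reloaded) is Speech2TextFeatureExtractor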

src/transformers/file_utils.py

Lines changed: 18 additions & 0 deletions
@@ -397,6 +397,11 @@ def is_torchaudio_available():
     return _torchaudio_available


+def is_speech_available():
+    # For now this depends on torchaudio but the exact dependency might evolve in the future.
+    return _torchaudio_available
+
+
 def torch_only_method(fn):
     def wrapper(*args, **kwargs):
         if not _torch_available:
@@ -513,6 +518,13 @@ def wrapper(*args, **kwargs):
 """


+# docstyle-ignore
+SPEECH_IMPORT_ERROR = """
+{0} requires the torchaudio library but it was not found in your environment. You can install it with pip:
+`pip install torchaudio`
+"""
+
+
 # docstyle-ignore
 VISION_IMPORT_ERROR = """
 {0} requires the PIL library but it was not found in your environment. You can install it with pip:
@@ -586,6 +598,12 @@ def requires_scatter(obj):
         raise ImportError(SCATTER_IMPORT_ERROR.format(name))


+def requires_speech(obj):
+    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
+    if not is_speech_available():
+        raise ImportError(SPEECH_IMPORT_ERROR.format(name))
+
+
 def requires_vision(obj):
     name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
     if not is_vision_available():
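requires_speech follows the same pattern as requires_vision and requires_scatter: dummy placeholder objects call it so a missing torchaudio install fails with a readable message instead of a bare ImportError at import time. A sketch of the kind of placeholder utils/dummy_speech_objects.py would contain (the generated file may differ in detail):

from transformers.file_utils import requires_speech


class Speech2TextFeatureExtractor:
    # Stand-in used when torchaudio is absent; raises SPEECH_IMPORT_ERROR on use.
    def __init__(self, *args, **kwargs):
        requires_speech(self)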

src/transformers/models/auto/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@

 _import_structure = {
     "configuration_auto": ["ALL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CONFIG_MAPPING", "MODEL_NAMES_MAPPING", "AutoConfig"],
+    "feature_extraction_auto": ["FEATURE_EXTRACTOR_MAPPING", "AutoFeatureExtractor"],
     "tokenization_auto": ["TOKENIZER_MAPPING", "AutoTokenizer"],
 }

@@ -104,6 +105,7 @@

 if TYPE_CHECKING:
     from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, AutoConfig
+    from .feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
     from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer

     if is_torch_available():
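The auto module now lazily exposes the two new names alongside AutoConfig and AutoTokenizer. A quick illustrative check that the exports resolve:

from transformers.models.auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor

print(AutoFeatureExtractor)            # class defined in feature_extraction_auto
print(len(FEATURE_EXTRACTOR_MAPPING))  # number of registered feature extractor types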
