
I can only use its encoder to extract audio features, right? How should I use it? Could you provide an example? #67

Open
wvinzh opened this issue Dec 22, 2023 · 1 comment

Comments

@wvinzh

wvinzh commented Dec 22, 2023

I can only use its encoder to extract audio features, right? How should I use it? Could you provide an example?

@sanchit-gandhi
Collaborator

Yes - you can do so with the following:

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from datasets import load_dataset

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# load model + processor
model_id = "distil-whisper/distil-large-v2"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
encoder = model.get_encoder()

processor = AutoProcessor.from_pretrained(model_id)

# load dataset
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]

# preprocess inputs (pass the sampling rate so the feature extractor can validate the input)
input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features
input_features = input_features.to(device, dtype=torch_dtype)

# forward pass to get encoder hidden states
with torch.no_grad():
    encoder_hidden_states = encoder(input_features).last_hidden_state
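
The encoder output encoder_hidden_states has shape (batch, num_frames, hidden_size), i.e. one feature vector per audio frame. If you want a single fixed-size embedding per clip, a minimal sketch is to mean-pool over the frame axis (this pooling choice is just an illustration, not something prescribed by Distil-Whisper):

# encoder_hidden_states: (batch, num_frames, hidden_size)
# mean-pool over the frame axis to get one embedding per clip (illustrative choice)
audio_embedding = encoder_hidden_states.mean(dim=1)
print(audio_embedding.shape)  # (batch, hidden_size)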
