[fairseq, pytorch] use pretrained model #72

Open
nguyenvulong opened this issue Jan 5, 2024 · 0 comments
Labels
documentation Improvements or additions to documentation


@nguyenvulong
Owner

nguyenvulong commented Jan 5, 2024

There are several ways to load a pretrained model with fairseq. Sadly, there is no better documentation than reading the code.
I collected some recipes from tutorials on the internet. Some worked and some didn't, so give them a try.

There is also a newer library, fairseq2.
Please help me upvote this question so that someone can explain how the two differ: https://stackoverflow.com/questions/77757228/what-are-the-differences-between-fairseq-and-fairseq2

FAIRSEQ 1

ONE

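import torch
from fairseq import checkpoint_utils
from fairseq.models.wav2vec import Wav2Vec2Model

# load the checkpoint on CPU; unlike a bare torch.load, this also turns the stored
# 'cfg' dict into an OmegaConf config that the model constructor can read
wav2vec2_checkpoint_path = '/path/to/checkpoint.pt'
checkpoint = checkpoint_utils.load_checkpoint_to_cpu(wav2vec2_checkpoint_path)

# build the architecture from the model section of the config
# (this targets a pretrained, not fine-tuned, wav2vec 2.0 checkpoint)
wav2vec2_encoder = Wav2Vec2Model.build_model(checkpoint['cfg']['model'])

# load the weights and switch to inference mode
wav2vec2_encoder.load_state_dict(checkpoint['model'])
wav2vec2_encoder.eval()

# test with random audio of shape (batch, samples), 16 kHz
audio = torch.randn(1, 10000)
with torch.no_grad():
    features = wav2vec2_encoder(audio, features_only=True, mask=False)['x']  # (batch, frames, dim)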
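As a sanity check on the test above: the convolutional feature encoder is a stack of unpadded 1-D convolutions, so the number of frames in `features` follows from the input length alone. The sketch below assumes the default wav2vec 2.0 `conv_feature_layers` layout (one (512, 10, 5) layer, four (512, 3, 2) layers and two (512, 2, 2) layers, as (channels, kernel, stride)); a checkpoint trained with a different setting needs its own list, which is stored under `conv_feature_layers` in the model config. The helper names are hypothetical.

```python
# Assumed default wav2vec 2.0 feature-encoder layout: (channels, kernel_size, stride)
DEFAULT_CONV_LAYERS = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2


def num_output_frames(num_samples: int, conv_layers=DEFAULT_CONV_LAYERS) -> int:
    """Frames produced by a stack of unpadded Conv1d layers for a given input length."""
    length = num_samples
    for _, kernel, stride in conv_layers:
        if length < kernel:
            raise ValueError(f"input of {num_samples} samples is too short for this encoder")
        # standard Conv1d output length with padding=0 and dilation=1
        length = (length - kernel) // stride + 1
    return length


def total_stride(conv_layers=DEFAULT_CONV_LAYERS) -> int:
    """Hop between consecutive output frames, in input samples."""
    hop = 1
    for _, _, stride in conv_layers:
        hop *= stride
    return hop


if __name__ == "__main__":
    print(num_output_frames(10000))  # → 31 (the random test input above)
    print(num_output_frames(16000))  # → 49 (one second of 16 kHz audio)
    print(total_stride() / 16000 * 1000, "ms per frame")  # → 20.0 ms per frame
```

With this layout a single frame needs at least 400 samples (25 ms at 16 kHz) and consecutive frames are 320 samples (20 ms) apart, which is why one second of audio gives 49 frames.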

TWO

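import torch
import fairseq

# this recipe is for the original wav2vec (1.0) checkpoints, e.g. wav2vec_large.pt;
# wav2vec 2.0 models have no feature_aggregator
cp_path = '/path/to/wav2vec.pt'
models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
model = models[0]  # the function returns an ensemble (list) of models
model.eval()

wav_input_16khz = torch.randn(1, 10000)  # (batch, samples) at 16 kHz
with torch.no_grad():
    z = model.feature_extractor(wav_input_16khz)  # latent features from the conv encoder
    c = model.feature_aggregator(z)  # context representations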

FAIRSEQ2

Example

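import torch
import torch.nn as nn
# load_conformer_shaw_model ships with seamless_communication, not with fairseq2 itself
from seamless_communication.models.conformer_shaw import load_conformer_shaw_model


class SSLModel(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.dtype = torch.float32
        self.device = device
        self.out_dim = 1024  # hidden size of the w2v-BERT 2.0 ("conformer_shaw") encoder
        self.model = load_conformer_shaw_model("conformer_shaw", device=device, dtype=self.dtype)
        self.model.eval()

    def extract_feat(self, seqs, padding_mask=None):
        # seqs: 80-bin fbank features of shape (batch, frames, 80), not raw waveforms
        # torch.inference_mode() also works here
        with torch.no_grad():
            seqs, padding_mask = self.model.encoder_frontend(seqs, padding_mask)
            seqs, padding_mask = self.model.encoder(seqs, padding_mask)
        return seqs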

HUGGINGFACE (work in progress)

See the Wav2Vec2 page of the Transformers documentation: https://huggingface.co/docs/transformers/main/model_doc/wav2vec2

Use cases

Speech classification with XLSR wav2vec2 (emotion recognition in Greek speech, Colab notebook):
https://colab.research.google.com/github/m3hrdadfi/soxan/blob/main/notebooks/Emotion_recognition_in_Greek_speech_using_Wav2Vec2.ipynb#scrollTo=bqF4rNMzI1M5
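Utterance-level tasks such as emotion recognition need one vector per clip, while the encoders above return one vector per frame. A common choice is to average only the frames that are not padding. The sketch below uses plain Python lists so it runs without extra dependencies; `features` is assumed to have shape (batch, frames, dim) and `lengths` to hold the number of valid frames per clip (the function name is hypothetical, and the linked notebook may pool differently).

```python
from typing import List


def masked_mean_pool(features: List[List[List[float]]], lengths: List[int]) -> List[List[float]]:
    """Average each clip's first `length` frames, ignoring the padded frames after them."""
    pooled = []
    for clip, n_valid in zip(features, lengths):
        if n_valid < 1 or n_valid > len(clip):
            raise ValueError(f"invalid length {n_valid} for a clip with {len(clip)} frames")
        sums = [0.0] * len(clip[0])
        for frame in clip[:n_valid]:  # only the valid (non-padding) frames
            for d, value in enumerate(frame):
                sums[d] += value
        pooled.append([s / n_valid for s in sums])
    return pooled


if __name__ == "__main__":
    # two clips padded to 3 frames of 2-dim features; the second has 1 valid frame
    batch = [
        [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
        [[2.0, 2.0], [9.0, 9.0], [9.0, 9.0]],
    ]
    print(masked_mean_pool(batch, [3, 1]))  # → [[3.0, 4.0], [2.0, 2.0]]
```

With PyTorch tensors the same idea is `(x * mask).sum(1) / mask.sum(1)` for a float mask of shape (batch, frames, 1).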

Others

To override the task stored in the checkpoint (for example, to switch it to audio_finetuning), pass arg_overrides when loading:

import os
import fairseq

BASE_DIR = '/path/to/project'  # placeholder: your project root
model_override_rules = {'task': {'_name': 'audio_finetuning'}}  # replaces task._name only
cp_path = os.path.join(BASE_DIR, 'pretrained/w2v-large.pt')
models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
    [cp_path], arg_overrides=model_override_rules)
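The overrides are applied on top of the configuration stored in the checkpoint, and a nested dict only replaces the keys it names, so `{'task': {'_name': 'audio_finetuning'}}` switches the task name while keeping the other task settings. The following is a simplified, dependency-free model of that merge using plain dicts; `apply_overrides` and the sample config are hypothetical. As far as I can tell from the source, fairseq's own helper handles more cases (for instance, a flat key such as `data` is also matched inside nested sections), so treat this only as an illustration.

```python
import copy
from typing import Any, Dict


def apply_overrides(cfg: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of cfg where nested override dicts replace only the keys they name."""
    merged = copy.deepcopy(cfg)  # leave the stored config untouched
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # recurse so sibling settings in the same section are preserved
            merged[key] = apply_overrides(merged[key], value)
        else:
            merged[key] = value
    return merged


if __name__ == "__main__":
    # hypothetical, heavily trimmed checkpoint config
    stored = {
        "task": {"_name": "audio_pretraining", "data": "/old/data", "sample_rate": 16000},
        "model": {"_name": "wav2vec2"},
    }
    overrides = {"task": {"_name": "audio_finetuning"}}
    print(apply_overrides(stored, overrides)["task"])
    # → {'_name': 'audio_finetuning', 'data': '/old/data', 'sample_rate': 16000}
```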