Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring expressivity/predict into ExpressiveTranslator. #292

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ Please check out above [section](#seamlessexpressive-models) on how to acquire `
### W2v-BERT 2.0 speech encoder
| Model Name | #params | checkpoint |
| ----------------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| W2v-BERT 2.0 | 600M | [🤗 Model card](https://huggingface.co/facebook/conformer-shaw) - [checkpoint](https://huggingface.co/facebook/conformer-shaw/resolve/main/conformer_shaw.pt)
| W2v-BERT 2.0 | 600M | [🤗 Model card](https://huggingface.co/facebook/w2v-bert-2.0) - [checkpoint](https://huggingface.co/facebook/w2v-bert-2.0/resolve/main/conformer_shaw.pt)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we fix this in another PR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure


Here's how you should do a foward pass through the speech encoder:

Expand Down
2 changes: 1 addition & 1 deletion demo/expressive/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
load_gcmvn_stats,
load_unity_unit_tokenizer,
)
from seamless_communication.cli.expressivity.predict.pretssel_generator import PretsselGenerator
from seamless_communication.inference.pretssel_generator import PretsselGenerator

from typing import Tuple
from utils import LANGUAGE_CODE_TO_NAME
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@
from torch import Tensor
from tqdm import tqdm

from seamless_communication.cli.expressivity.predict.pretssel_generator import (
PretsselGenerator,
)

from seamless_communication.cli.m4t.evaluate.evaluate import (
adjust_output_for_corrupted_inputs,
count_lines,
Expand All @@ -36,6 +34,9 @@
add_inference_arguments,
set_generation_opts,
)
from seamless_communication.inference.pretssel_generator import (
PretsselGenerator,
)
from seamless_communication.inference import BatchedSpeechOutput, Translator
from seamless_communication.models.unity import (
load_gcmvn_stats,
Expand Down
101 changes: 11 additions & 90 deletions src/seamless_communication/cli/expressivity/predict/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,14 @@
import torchaudio
from pathlib import Path

from fairseq2.data import SequenceData
from fairseq2.data.audio import WaveformToFbankConverter

from seamless_communication.cli.expressivity.predict.pretssel_generator import (
PretsselGenerator,
)
from seamless_communication.cli.m4t.predict import (
add_inference_arguments,
set_generation_opts,
)
from seamless_communication.inference import Translator
from seamless_communication.models.unity import (
load_gcmvn_stats,
load_unity_unit_tokenizer,
)
from seamless_communication.inference import ExpressiveTranslator
from seamless_communication.store import add_gated_assets


AUDIO_SAMPLE_RATE = 16000


logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
Expand All @@ -39,13 +26,6 @@
logger = logging.getLogger(__name__)


def remove_prosody_tokens_from_text(text: str) -> str:
# filter out prosody tokens, there is only emphasis '*', and pause '='
text = text.replace("*", "").replace("=", "")
text = " ".join(text.split())
return text


def main() -> None:
parser = argparse.ArgumentParser(description="Running SeamlessExpressive inference.")
parser.add_argument("input", type=str, help="Audio WAV file path.")
Expand Down Expand Up @@ -82,59 +62,11 @@ def main() -> None:

logger.info(f"Running inference on {device=} with {dtype=}.")

unit_tokenizer = load_unity_unit_tokenizer(args.model_name)

translator = Translator(
expressive_translator = ExpressiveTranslator(
args.model_name,
vocoder_name_or_card=None,
device=device,
dtype=dtype,
)

pretssel_generator = PretsselGenerator(
args.vocoder_name,
vocab_info=unit_tokenizer.vocab_info,
device=device,
dtype=dtype,
)

fbank_extractor = WaveformToFbankConverter(
num_mel_bins=80,
waveform_scale=2**15,
channel_last=True,
standardize=False,
device=device,
dtype=dtype,
)

_gcmvn_mean, _gcmvn_std = load_gcmvn_stats(args.vocoder_name)
gcmvn_mean = torch.tensor(_gcmvn_mean, device=device, dtype=dtype)
gcmvn_std = torch.tensor(_gcmvn_std, device=device, dtype=dtype)

wav, sample_rate = torchaudio.load(args.input)
wav = torchaudio.functional.resample(wav, orig_freq=sample_rate, new_freq=16_000)
wav = wav.transpose(0, 1)

data = fbank_extractor(
{
"waveform": wav,
"sample_rate": 16000,
}
)
fbank = data["fbank"]
gcmvn_fbank = fbank.subtract(gcmvn_mean).divide(gcmvn_std)
std, mean = torch.std_mean(fbank, dim=0)
fbank = fbank.subtract(mean).divide(std)

src = SequenceData(
seqs=fbank.unsqueeze(0),
seq_lens=torch.LongTensor([fbank.shape[0]]),
is_ragged=False,
)
src_gcmvn = SequenceData(
seqs=gcmvn_fbank.unsqueeze(0),
seq_lens=torch.LongTensor([gcmvn_fbank.shape[0]]),
is_ragged=False,
device,
dtype
)

text_generation_opts, unit_generation_opts = set_generation_opts(args)
Expand All @@ -145,22 +77,13 @@ def main() -> None:
f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
)

text_output, unit_output = translator.predict(
src,
"s2st",
speech_output, text_output = expressive_translator.predict(
args.input,
args.tgt_lang,
text_generation_opts=text_generation_opts,
unit_generation_opts=unit_generation_opts,
unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
duration_factor=args.duration_factor,
prosody_encoder_input=src_gcmvn,
)

assert unit_output is not None
speech_output = pretssel_generator.predict(
unit_output.units,
tgt_lang=args.tgt_lang,
prosody_encoder_input=src_gcmvn,
text_generation_opts,
unit_generation_opts,
args.unit_generation_ngram_filtering,
args.duration_factor,
)

logger.info(f"Saving expressive translated audio in {args.tgt_lang}")
Expand All @@ -170,9 +93,7 @@ def main() -> None:
sample_rate=speech_output.sample_rate,
)

text_out = remove_prosody_tokens_from_text(str(text_output[0]))

logger.info(f"Translated text in {args.tgt_lang}: {text_out}")
logger.info(f"Translated text in {args.tgt_lang}: {text_output[0]}")


if __name__ == "__main__":
Expand Down
3 changes: 3 additions & 0 deletions src/seamless_communication/inference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
SequenceGeneratorOptions as SequenceGeneratorOptions,
)
from seamless_communication.inference.generator import UnitYGenerator as UnitYGenerator
from seamless_communication.inference.expressive_translator import (
ExpressiveTranslator as ExpressiveTranslator,
)
from seamless_communication.inference.translator import (
BatchedSpeechOutput as BatchedSpeechOutput,
)
Expand Down
153 changes: 153 additions & 0 deletions src/seamless_communication/inference/expressive_translator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.

import torch
import torchaudio

from torch.nn import Module
from typing import List, Optional, Tuple, Union

from fairseq2.assets.card import AssetCard
from fairseq2.data import SequenceData, StringLike
from fairseq2.data.audio import WaveformToFbankConverter
from fairseq2.typing import DataType, Device

from seamless_communication.inference import BatchedSpeechOutput, Translator
from seamless_communication.inference.generator import SequenceGeneratorOptions
from seamless_communication.inference.pretssel_generator import (
PretsselGenerator,
)
from seamless_communication.models.unity import (
load_gcmvn_stats,
load_unity_unit_tokenizer,
)

AUDIO_SAMPLE_RATE = 16000


class ExpressiveTranslator(Module):
def __init__(
self,
model_name_or_card: Union[str, AssetCard],
vocoder_name_or_card: Union[str, AssetCard, None],
device: Device,
dtype: DataType,
):
super().__init__()

unit_tokenizer = load_unity_unit_tokenizer(model_name_or_card)

self.translator = Translator(
model_name_or_card,
vocoder_name_or_card=None,
device=device,
dtype=dtype,
)

self.pretssel_generator = PretsselGenerator(
vocoder_name_or_card,
vocab_info=unit_tokenizer.vocab_info,
device=device,
dtype=dtype,
)

self.fbank_extractor = WaveformToFbankConverter(
num_mel_bins=80,
waveform_scale=2**15,
channel_last=True,
standardize=False,
device=device,
dtype=dtype,
)

_gcmvn_mean, _gcmvn_std = load_gcmvn_stats(vocoder_name_or_card)
self.gcmvn_mean = torch.tensor(_gcmvn_mean, device=device, dtype=dtype)
self.gcmvn_std = torch.tensor(_gcmvn_std, device=device, dtype=dtype)

@staticmethod
def remove_prosody_tokens_from_text(text_output: List[str]) -> List[str]:
modified_text_output = []
for text in text_output:
# filter out prosody tokens, there is only emphasis '*', and pause '='
text = text.replace("*", "").replace("=", "")
text = " ".join(text.split())
modified_text_output.append(text)
return modified_text_output

@torch.inference_mode()
def predict(
self,
audio_path: str,
tgt_lang: str,
text_generation_opts: Optional[SequenceGeneratorOptions] = None,
unit_generation_opts: Optional[SequenceGeneratorOptions] = None,
unit_generation_ngram_filtering: bool = False,
duration_factor: float = 1.0,
) -> Tuple[List[StringLike], Optional[BatchedSpeechOutput]]:
"""
The main method used to perform inference on all tasks.

:param audio_path:
Path to audio waveform.
:param tgt_lang:
Target language to decode into.
:param text_generation_opts:
Text generation hyperparameters for incremental decoding.
:param unit_generation_opts:
Unit generation hyperparameters for incremental decoding.
:param unit_generation_ngram_filtering:
If True, removes consecutive repeated ngrams
from the decoded unit output.

:returns:
- Batched list of Translated text.
- Translated BatchedSpeechOutput.
"""
# TODO: Replace with fairseq2.data once re-sampling is implemented.
wav, sample_rate = torchaudio.load(audio_path)
wav = torchaudio.functional.resample(wav, orig_freq=sample_rate, new_freq=16_000)
wav = wav.transpose(0, 1)

data = self.fbank_extractor(
{
"waveform": wav,
"sample_rate": AUDIO_SAMPLE_RATE,
}
)
fbank = data["fbank"]
gcmvn_fbank = fbank.subtract(self.gcmvn_mean).divide(self.gcmvn_std)
std, mean = torch.std_mean(fbank, dim=0)
fbank = fbank.subtract(mean).divide(std)

src = SequenceData(
seqs=fbank.unsqueeze(0),
seq_lens=torch.LongTensor([fbank.shape[0]]),
is_ragged=False,
)
src_gcmvn = SequenceData(
seqs=gcmvn_fbank.unsqueeze(0),
seq_lens=torch.LongTensor([gcmvn_fbank.shape[0]]),
is_ragged=False,
)

text_output, unit_output = self.translator.predict(
src,
"s2st",
tgt_lang,
text_generation_opts=text_generation_opts,
unit_generation_opts=unit_generation_opts,
unit_generation_ngram_filtering=unit_generation_ngram_filtering,
duration_factor=duration_factor,
prosody_encoder_input=src_gcmvn,
)
text_output = self.remove_prosody_tokens_from_text(text_output)

assert unit_output is not None
speech_output = self.pretssel_generator.predict(
unit_output.units,
tgt_lang=tgt_lang,
prosody_encoder_input=src_gcmvn,
)
return text_output, speech_output
4 changes: 2 additions & 2 deletions src/seamless_communication/inference/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import List, Optional, Tuple, Union, cast

import torch
import torch.nn as nn
from torch.nn import Module
from fairseq2.assets import asset_store
from fairseq2.assets.card import AssetCard
from fairseq2.data import Collater, SequenceData, StringLike
Expand Down Expand Up @@ -75,7 +75,7 @@ class BatchedSpeechOutput:
"""Sample rate of the audio waveforms."""


class Translator(nn.Module):
class Translator(Module):
def __init__(
self,
model_name_or_card: Union[str, AssetCard],
Expand Down