Skip to content

Commit

Permalink
24kHz model release (#13)
Browse files Browse the repository at this point in the history
* add support for multiple models

* expose model_type arg to cli + docs update

* version updates

* update tests

* expose download args via cli

* doc correction

* minor change

* update readme

* minor change
  • Loading branch information
eeishaan authored Jun 21, 2023
1 parent 3e877db commit 06e8049
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 26 deletions.
13 changes: 8 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,14 @@ pip install git+https://github.com/descriptinc/descript-audio-codec

### Weights
Weights are released as part of this repo under MIT license.
They are automatically downloaded when you first run `encode` or `decode` command. They can be cached locally with
We release weights for models that can natively support 24kHz and 44.1kHz sampling rates.
Weights are automatically downloaded when you first run `encode` or `decode` command. You can cache them using one of the following commands
```bash
python3 -m dac download # downloads the default 44kHz variant
python3 -m dac download --model_type 44khz # downloads the 44kHz variant
python3 -m dac download --model_type 24khz # downloads the 24kHz variant
```
python3 -m dac download
```
We provide a Dockerfile that installs all required dependencies for encoding and decoding. The build process caches model weights inside the image. This allows the image to be used without an internet connection. [Please refer to instructions below.](#docker-image)
We provide a Dockerfile that installs all required dependencies for encoding and decoding. The build process caches the default model weights inside the image. This allows the image to be used without an internet connection. [Please refer to instructions below.](#docker-image)


### Compress audio
Expand Down Expand Up @@ -74,7 +77,7 @@ from audiotools import AudioSignal
model = DAC()

# Load compatible pre-trained model
model = load_model(dac.__model_version__)
model = load_model(tag="latest", model_type="44khz")
model.eval()
model.to('cuda')

Expand Down
7 changes: 5 additions & 2 deletions dac/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
__version__ = "0.0.3"
__model_version__ = "0.0.1"
__version__ = "0.0.4"

# preserved here for legacy reasons
__model_version__ = "latest"

import audiotools

audiotools.ml.BaseModel.INTERN += ["dac.**"]
Expand Down
65 changes: 55 additions & 10 deletions dac/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,68 @@
from pathlib import Path

import argbind
from audiotools import ml

import dac


DAC = dac.model.DAC
Accelerator = ml.Accelerator

__MODEL_LATEST_TAGS__ = {
"44khz": "0.0.1",
"24khz": "0.0.4",
}

__MODEL_URLS__ = {
(
"44khz",
"0.0.1",
): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth",
(
"24khz",
"0.0.4",
): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth",
}


def ensure_default_model(tag: str = dac.__model_version__):
@argbind.bind(group="download", positional=True, without_prefix=True)
def ensure_default_model(tag: str = "latest", model_type: str = "44khz"):
"""
Function that downloads the weights file from URL if a local cache is not
found.
Function that downloads the weights file from URL if a local cache is not found.
Args:
tag (str): The tag of the model to download.
Parameters
----------
tag : str
The tag of the model to download. Defaults to "latest".
model_type : str
The type of model to download. Must be one of "44khz" or "24khz". Defaults to "44khz".
Returns
-------
Path
Directory path required to load model via audiotools.
"""
download_link = f"https://github.com/descriptinc/descript-audio-codec/releases/download/{tag}/weights.pth"
local_path = Path.home() / ".cache" / "descript" / tag / "dac" / f"weights.pth"
model_type = model_type.lower()
tag = tag.lower()

assert model_type in [
"44khz",
"24khz",
], "model_type must be one of '44khz' or '24khz'"

if tag == "latest":
tag = __MODEL_LATEST_TAGS__[model_type]

download_link = __MODEL_URLS__.get((model_type, tag), None)

if download_link is None:
raise ValueError(
f"Could not find model with tag {tag} and model type {model_type}"
)

local_path = (
Path.home() / ".cache" / "descript" / model_type / tag / "dac" / f"weights.pth"
)
if not local_path.exists():
local_path.parent.mkdir(parents=True, exist_ok=True)

Expand All @@ -38,11 +82,12 @@ def ensure_default_model(tag: str = dac.__model_version__):


def load_model(
tag: str,
tag: str = "latest",
load_path: str = "",
model_type: str = "44khz",
):
if not load_path:
load_path = ensure_default_model(tag)
load_path = ensure_default_model(tag, model_type)
kwargs = {
"folder": load_path,
"map_location": "cpu",
Expand Down
26 changes: 24 additions & 2 deletions dac/utils/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from audiotools import AudioSignal
from tqdm import tqdm

import dac
from dac.utils import load_model

warnings.filterwarnings("ignore", category=UserWarning)
Expand Down Expand Up @@ -99,13 +98,36 @@ def decode(
input: str,
output: str = "",
weights_path: str = "",
model_tag: str = dac.__model_version__,
model_tag: str = "latest",
preserve_sample_rate: bool = False,
device: str = "cuda",
model_type: str = "44khz",
):
"""Decode audio from codes.
Parameters
----------
input : str
Path to input directory or file
output : str, optional
Path to output directory, by default "".
If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
weights_path : str, optional
Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
model_tag and model_type.
model_tag : str, optional
Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
preserve_sample_rate : bool, optional
If True, return audio will have the same sample rate as the original
device : str, optional
Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU.
model_type : str, optional
The type of model to download. Must be one of "44khz" or "24khz". Defaults to "44khz". Ignored if `weights_path` is specified.
"""
generator = load_model(
tag=model_tag,
load_path=weights_path,
model_type=model_type,
)
generator.to(device)
generator.eval()
Expand Down
25 changes: 23 additions & 2 deletions dac/utils/encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from audiotools.core import util
from tqdm import tqdm

import dac
from dac.utils import load_model

warnings.filterwarnings("ignore", category=UserWarning)
Expand Down Expand Up @@ -124,13 +123,35 @@ def encode(
input: str,
output: str = "",
weights_path: str = "",
model_tag: str = dac.__model_version__,
model_tag: str = "latest",
n_quantizers: int = None,
device: str = "cuda",
model_type: str = "44khz",
):
"""Encode audio files in input path to .dac format.
Parameters
----------
input : str
Path to input audio file or directory
output : str, optional
Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
weights_path : str, optional
Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
model_tag and model_type.
model_tag : str, optional
Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
n_quantizers : int, optional
Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate.
device : str, optional
Device to use, by default "cuda"
model_type : str, optional
The type of model to download. Must be one of "44khz" or "24khz". Defaults to "44khz". Ignored if `weights_path` is specified.
"""
generator = load_model(
tag=model_tag,
load_path=weights_path,
model_type=model_type,
)
generator.to(device)
generator.eval()
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

setup(
name="descript-audio-codec",
version="0.0.3",
version="0.0.4",
classifiers=[
"Intended Audience :: Developers",
"Natural Language :: English",
Expand Down
18 changes: 14 additions & 4 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import argbind
import numpy as np
import pytest
import torch
from audiotools import AudioSignal

from dac.__main__ import run
Expand All @@ -28,20 +30,23 @@ def teardown_module(module):
subprocess.check_output(["rm", "-rf", f"{repo_root}/tests/assets"])


def test_reconstruction():
@pytest.mark.parametrize("model_type", ["44khz", "24khz"])
def test_reconstruction(model_type):
# Test encoding
input_dir = Path(__file__).parent / "assets" / "input"
output_dir = input_dir.parent / "encoded_output"
output_dir = input_dir.parent / model_type / "encoded_output"
args = {
"input": str(input_dir),
"output": str(output_dir),
"device": "cuda" if torch.cuda.is_available() else "cpu",
"model_type": model_type,
}
with argbind.scope(args):
run("encode")

# Test decoding
input_dir = output_dir
output_dir = input_dir.parent / "decoded_output"
output_dir = input_dir.parent / model_type / "decoded_output"
args = {
"input": str(input_dir),
"output": str(output_dir),
Expand All @@ -54,7 +59,12 @@ def test_compression():
# Test encoding
input_dir = Path(__file__).parent / "assets" / "input"
output_dir = input_dir.parent / "encoded_output_quantizers"
args = {"input": str(input_dir), "output": str(output_dir), "n_quantizers": 3}
args = {
"input": str(input_dir),
"output": str(output_dir),
"n_quantizers": 3,
"device": "cuda" if torch.cuda.is_available() else "cpu",
}
with argbind.scope(args):
run("encode")

Expand Down

0 comments on commit 06e8049

Please sign in to comment.