Skip to content

Commit

Permalink
Use n_quantizers and save codes as uint16 (#12)
Browse files Browse the repository at this point in the history
* restrict full quantizer usage during inference

* save/load uint16 codes

* add test

* bump version
  • Loading branch information
eeishaan authored Jun 19, 2023
1 parent 7cb5f5b commit 3e877db
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 4 deletions.
2 changes: 1 addition & 1 deletion dac/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.0.2"
__version__ = "0.0.3"
__model_version__ = "0.0.1"
import audiotools

Expand Down
3 changes: 3 additions & 0 deletions dac/nn/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,9 @@ def forward(self, z, n_quantizers: int = None):
n_quantizers = n_quantizers.to(z.device)

for i, quantizer in enumerate(self.quantizers):
if self.training is False and i >= n_quantizers:
break

z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
residual
)
Expand Down
4 changes: 3 additions & 1 deletion dac/utils/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def process(
"""
if isinstance(generator, torch.nn.DataParallel):
generator = generator.module
audio_signal = AudioSignal(artifacts["codes"], generator.sample_rate)
audio_signal = AudioSignal(
artifacts["codes"].astype(np.int64), generator.sample_rate
)
metadata = artifacts["metadata"]

# Decode chunks
Expand Down
2 changes: 1 addition & 1 deletion dac/utils/encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def process(
codebook_indices = torch.cat(codebook_indices, dim=0)

return {
"codes": codebook_indices.numpy(),
"codes": codebook_indices.numpy().astype(np.uint16),
"metadata": {
"original_db": input_db,
"overlap_hop_duration": overlap_hop_duration,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

setup(
name="descript-audio-codec",
version="0.0.2",
version="0.0.3",
classifiers=[
"Intended Audience :: Developers",
"Natural Language :: English",
Expand Down
20 changes: 20 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,24 @@ def test_reconstruction():
run("decode")


def test_compression():
    """Encode the sample assets with ``n_quantizers=3`` and inspect the codes.

    Runs the ``encode`` entry point under an argbind scope, then loads the
    resulting ``sample_0.dac`` file and asserts that the stored ``codes``
    array has 3 entries along axis 1 and is stored as ``uint16``.
    """
    assets_dir = Path(__file__).parent / "assets"
    source_dir = assets_dir / "input"
    target_dir = assets_dir / "encoded_output_quantizers"

    # Encode step: arguments are passed to `run` through the argbind scope.
    encode_args = {
        "input": str(source_dir),
        "output": str(target_dir),
        "n_quantizers": 3,
    }
    with argbind.scope(encode_args):
        run("encode")

    # Load the encoded file. `[()]` pulls the stored object out of the array
    # returned by np.load (presumably a 0-d object array wrapping the
    # artifacts dict -- confirm against the encoder's save call).
    # NOTE(review): allow_pickle=True is used here on a file this test just
    # produced; do not reuse this pattern on untrusted files.
    loaded = np.load(target_dir / "sample_0.dac", allow_pickle=True)
    codes = loaded[()]["codes"]

    # The quantizer axis must match the requested n_quantizers.
    assert codes.shape[1] == 3

    # Codes must be stored in the compact uint16 representation.
    assert codes.dtype == np.uint16


# CUDA_VISIBLE_DEVICES=0 python -m pytest tests/test_cli.py -s

0 comments on commit 3e877db

Please sign in to comment.