Skip to content

Commit

Permalink
Rename spliceout to splice_out
Browse files Browse the repository at this point in the history
  • Loading branch information
iver56 committed Jun 29, 2022
1 parent 643f320 commit 62e2f64
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 23 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ classification. It was successfully applied in the paper
### Added

* Add new transform: `Identity`
* Add API for processing of targets alongside inputs. Some transforms experimentally
* Add API for processing targets alongside inputs. Some transforms experimentally
support this feature already.

### Changed
Expand Down
2 changes: 1 addition & 1 deletion scripts/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
Identity,
)
from torch_audiomentations.augmentations.shuffle_channels import ShuffleChannels
from torch_audiomentations.augmentations.spliceout import SpliceOut
from torch_audiomentations.augmentations.splice_out import SpliceOut
from torch_audiomentations.core.transforms_interface import ModeNotSupportedException
from torch_audiomentations.utils.object_dict import ObjectDict

Expand Down
42 changes: 21 additions & 21 deletions tests/test_spliceout.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import torch
import pytest

from torch_audiomentations.augmentations.spliceout import SpliceOut
from torch_audiomentations.augmentations.splice_out import SpliceOut
from torch_audiomentations import Compose


class TestSpliceout(unittest.TestCase):
def test_spliceout(self):
def test_splice_out(self):

audio_samples = torch.rand(size=(8, 1, 32000), dtype=torch.float32)
augment = Compose(
Expand All @@ -17,13 +17,13 @@ def test_spliceout(self):
],
output_type="dict",
)
spliceout_samples = augment(
splice_out_samples = augment(
samples=audio_samples, sample_rate=16000
).samples.numpy()

assert spliceout_samples.dtype == np.float32
assert splice_out_samples.dtype == np.float32

def test_spliceout_odd_hann(self):
def test_splice_out_odd_hann(self):

audio_samples = torch.rand(size=(8, 1, 32000), dtype=torch.float32)
augment = Compose(
Expand All @@ -32,13 +32,13 @@ def test_spliceout_odd_hann(self):
],
output_type="dict",
)
spliceout_samples = augment(
splice_out_samples = augment(
samples=audio_samples, sample_rate=16100
).samples.numpy()

assert spliceout_samples.dtype == np.float32
assert splice_out_samples.dtype == np.float32

def test_spliceout_perbatch(self):
def test_splice_out_per_batch(self):

audio_samples = torch.rand(size=(8, 1, 32000), dtype=torch.float32)
augment = Compose(
Expand All @@ -53,15 +53,15 @@ def test_spliceout_perbatch(self):
],
output_type="dict",
)
spliceout_samples = augment(
splice_out_samples = augment(
samples=audio_samples, sample_rate=16000
).samples.numpy()

assert spliceout_samples.dtype == np.float32
self.assertLess(spliceout_samples.sum(), audio_samples.numpy().sum())
self.assertEqual(spliceout_samples.shape, audio_samples.shape)
assert splice_out_samples.dtype == np.float32
self.assertLess(splice_out_samples.sum(), audio_samples.numpy().sum())
self.assertEqual(splice_out_samples.shape, audio_samples.shape)

def test_spliceout_multichannel(self):
def test_splice_out_multichannel(self):

audio_samples = torch.rand(size=(8, 2, 32000), dtype=torch.float32)
augment = Compose(
Expand All @@ -70,17 +70,17 @@ def test_spliceout_multichannel(self):
],
output_type="dict",
)
spliceout_samples = augment(
splice_out_samples = augment(
samples=audio_samples, sample_rate=16000
).samples.numpy()

assert spliceout_samples.dtype == np.float32
self.assertLess(spliceout_samples.sum(), audio_samples.numpy().sum())
self.assertEqual(spliceout_samples.shape, audio_samples.shape)
assert splice_out_samples.dtype == np.float32
self.assertLess(splice_out_samples.sum(), audio_samples.numpy().sum())
self.assertEqual(splice_out_samples.shape, audio_samples.shape)

@pytest.mark.skip(reason="This test fails and SpliceOut is not released yet")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA")
def test_spliceout_cuda(self):
def test_splice_out_cuda(self):

audio_samples = (
torch.rand(
Expand All @@ -94,9 +94,9 @@ def test_spliceout_cuda(self):
],
output_type="dict",
)
spliceout_samples = (
splice_out_samples = (
augment(samples=audio_samples, sample_rate=16000).samples.cpu().numpy()
)

assert spliceout_samples.dtype == np.float32
self.assertLess(spliceout_samples.sum(), audio_samples.cpu().numpy().sum())
assert splice_out_samples.dtype == np.float32
self.assertLess(splice_out_samples.sum(), audio_samples.cpu().numpy().sum())

0 comments on commit 62e2f64

Please sign in to comment.