
Commit 34e2380

Add unit tests (#27)
* TST: add some tests for `LRFinder`
  This is a draft of unit tests for this package. For details of how test cases are written, please check out `tests/README.md`.
* TST: replace env vars with command line arguments for pytest runner
  Other requested changes mentioned in PR #27 are also done in this commit.
* TST: remove decorator for making metaclass work on Py2k
* TST: remove local import statements
  Local imports in `collect_task_classes()` are not necessary since module `task` has been imported globally.
* STY: format code with black
1 parent 66e23d7 commit 34e2380

File tree

8 files changed: +365, −0 lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -2,3 +2,4 @@ matplotlib
 numpy
 torch>=0.4.1
 tqdm
+pytest
```

tests/README.md

Lines changed: 66 additions & 0 deletions
## Requirements
- pytest

## Run tests
- normal (use GPU if it is available)
```bash
# in the root directory of this package
$ python -m pytest ./tests
```

- forcibly run all tests on CPU
```bash
# in the root directory of this package
$ python -m pytest --cpu_only ./tests
```

## How to add new test cases
To make it convenient to create test cases and reuse settings, the basic elements for running a training task are packaged into objects inheriting `BaseTask` in `task.py`.

A `BaseTask` is composed of these members:
- `batch_size`
- `model`
- `optimizer`
- `criterion` (loss function)
- `device` (`cpu`, `cuda`, etc.)
- `train_loader` (a `torch.utils.data.DataLoader` for the training set)
- `val_loader` (a `torch.utils.data.DataLoader` for the validation set)

If you want to create a new task, just write a new class inheriting `BaseTask` and add your configuration in `__init__` (see the sketch after this file).

Note-1: Any task inheriting `BaseTask` in `task.py` will be collected by the function `test_lr_finder.py::collect_task_classes()`.

Note-2: The model and dataset are instantiated when a task class is **initialized**, so it is not recommended to collect many task **objects** at once.

### Directly use a specific task in a test case
```python
from . import task as mod_task


def test_run():
    task = mod_task.FooTask()
    ...
```

### Use `pytest.mark.parametrize`
- Use specified tasks in a test case
```python
@pytest.mark.parametrize(
    "cls_task, arg",  # names of parameters (see also the signature of the following function)
    [
        (mod_task.FooTask, "foo"),
        (mod_task.BarTask, "bar"),
    ],  # list of parameters
)
def test_run(cls_task, arg):
    ...
```

- Use all existing tasks in a test case
```python
@pytest.mark.parametrize(
    "cls_task",
    collect_task_classes(),
)
def test_run(cls_task):
    ...
```
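For illustration, a new task following the recipe above might look like the sketch below. `FooTask` and its settings are hypothetical (not part of this commit); only the pattern of filling in `BaseTask`'s members inside `__init__` is taken from `task.py`.

```python
# Hypothetical addition to tests/task.py; FooTask and its settings are
# illustrative. It relies on the helpers already imported in that module.
class FooTask(BaseTask):
    def __init__(self, validate=False):
        super(FooTask, self).__init__()
        dataset = XORDataset(256)
        self.batch_size = 8
        self.model = LinearMLP([8, 4, 1])
        self.optimizer = optim.SGD(self.model.parameters(), lr=1e-3)
        self.criterion = nn.MSELoss()
        self.train_loader = DataLoader(dataset, batch_size=self.batch_size)
        self.val_loader = None
        # Reset to None by __post_init__ when CUDA is unavailable
        self.device = torch.device("cuda")
```

Because it inherits `BaseTask` and lives in `task.py`, `collect_task_classes()` would pick it up automatically.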

tests/__init__.py

Whitespace-only changes.

tests/conftest.py

Lines changed: 41 additions & 0 deletions
```python
import pytest


class CustomCommandLineOption(object):
    """An object for storing command line options parsed by pytest.

    Since the `pytest.config` global object is deprecated and removed in
    version 5.0, this class is made to work as a store of command line
    options for those components which are not able to access them via
    `request.config`.
    """

    def __init__(self):
        self._content = {}

    def __str__(self):
        return str(self._content)

    def add(self, key, value):
        self._content.update({key: value})

    def delete(self, key):
        del self._content[key]

    def __getattr__(self, key):
        if key in self._content:
            return self._content[key]
        else:
            # `object` defines no `__getattr__`, so raise the standard
            # missing-attribute error instead of delegating to the parent.
            raise AttributeError(key)


def pytest_addoption(parser):
    parser.addoption(
        "--cpu_only", action="store_true", help="Forcibly run all tests on CPU."
    )


def pytest_configure(config):
    # Bind a config object to the `pytest` module instance
    pytest.custom_cmdopt = CustomCommandLineOption()
    pytest.custom_cmdopt.add("cpu_only", config.getoption("--cpu_only"))
```
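As a minimal sketch of how the stored option is consumed (the test function below is hypothetical; the real consumer in this commit is `use_cuda()` in `tests/task.py`):

```python
# Hypothetical test: reads the option bound by pytest_configure() above.
import pytest
import torch


def test_respects_cpu_only_flag():
    use_gpu = (not pytest.custom_cmdopt.cpu_only) and torch.cuda.is_available()
    device = torch.device("cuda" if use_gpu else "cpu")
    assert device.type in ("cpu", "cuda")
```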

tests/dataset.py

Lines changed: 58 additions & 0 deletions
```python
import numpy as np
import torch
from torch.utils.data import Dataset


class XORDataset(Dataset):
    def __init__(self, length, shape=None):
        """
        Arguments:
            length (int): length of dataset, which equals `len(self)`.
            shape (list, tuple, optional): shape of each sample; the full
                data shape will be `(length,) + tuple(shape)`. If it isn't
                specified, the data shape will be initialized to
                `(length, 8)`. Default: None.
        """
        _shape = (length,) + tuple(shape) if shape else (length, 8)
        # Random binary inputs; the label of each sample is the XOR
        # reduction over its components.
        raw = np.random.randint(0, 2, _shape)
        self.data = torch.FloatTensor(raw)
        self.label = (
            torch.tensor(np.bitwise_xor.reduce(raw, axis=1)).unsqueeze(dim=1).float()
        )

    def __getitem__(self, index):
        return self.data[index], self.label[index]

    def __len__(self):
        return len(self.data)


class ExtraXORDataset(XORDataset):
    """An XOR dataset which is able to return extra values."""

    def __init__(self, length, shape=None, extra_dims=1):
        """
        Arguments:
            length (int): length of dataset, which equals `len(self)`.
            shape (list, tuple, optional): shape of each sample. If it isn't
                specified, the data shape will be initialized to
                `(length, 8)`. Default: None.
            extra_dims (int, optional): dimension of extra values.
                Default: 1.
        """
        super(ExtraXORDataset, self).__init__(length, shape=shape)
        if extra_dims:
            _extra_shape = (length, extra_dims)
            self.extras = torch.randint(0, 2, _extra_shape)
        else:
            self.extras = None

    def __getitem__(self, index):
        if self.extras is not None:
            retval = [self.data[index], self.label[index]]
            retval.extend(self.extras[index])
            return retval
        else:
            return self.data[index], self.label[index]

    def __len__(self):
        return len(self.data)
```
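A quick sanity check of the two datasets (a sketch; the shapes follow from the constructors above, with the default per-sample shape of 8):

```python
# Sketch: inspect the shapes produced by the datasets above.
dataset = XORDataset(32)
x, y = dataset[0]
print(x.shape, y.shape)  # torch.Size([8]) torch.Size([1])

extra = ExtraXORDataset(32, extra_dims=2)
sample = extra[0]
print(len(sample))  # 4: data, label, and two extra values
```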

tests/model.py

Lines changed: 15 additions & 0 deletions
```python
import torch.nn as nn


class LinearMLP(nn.Module):
    """A multilayer perceptron made purely of `nn.Linear` layers, with no
    activations in between."""

    def __init__(self, layer_dim):
        super(LinearMLP, self).__init__()
        # Pair up consecutive dimensions, e.g. [8, 4, 1] -> (8, 4), (4, 1)
        io_pairs = zip(layer_dim[:-1], layer_dim[1:])
        layers = [nn.Linear(idim, odim) for idim, odim in io_pairs]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
```
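A short usage sketch (the shapes are assumptions chosen to match the `LinearMLP([8, 4, 1])` instances used by the tasks in this commit):

```python
# Sketch: 8-dimensional inputs mapped to a single output.
import torch

model = LinearMLP([8, 4, 1])
x = torch.rand(16, 8)  # a batch of 16 samples
out = model(x)         # shape: (16, 1)
```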

tests/task.py

Lines changed: 112 additions & 0 deletions
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import pytest

from .model import LinearMLP
from .dataset import XORDataset, ExtraXORDataset


def use_cuda():
    if pytest.custom_cmdopt.cpu_only:
        return False
    else:
        return torch.cuda.is_available()


class TaskTemplate(type):
    def __call__(cls, *args, **kwargs):
        obj = type.__call__(cls, *args, **kwargs)
        if hasattr(obj, "__post_init__"):
            obj.__post_init__()
        return obj


class BaseTask(metaclass=TaskTemplate):
    def __init__(self):
        self.batch_size = -1
        self.model = None
        self.optimizer = None
        self.criterion = None
        self.device = None
        self.train_loader = None
        self.val_loader = None

    def __post_init__(self):
        # Check whether CUDA is available, and cast `self.device` to
        # `torch.device` here to make sure that operations related to
        # moving tensors will work fine later.
        if not use_cuda():
            self.device = None
        if self.device is None:
            return

        if isinstance(self.device, str):
            self.device = torch.device(self.device)
        elif not isinstance(self.device, torch.device):
            raise TypeError("Invalid type of device.")


class XORTask(BaseTask):
    def __init__(self, validate=False):
        super(XORTask, self).__init__()
        bs, steps = 8, 64
        dataset = XORDataset(bs * steps)
        if validate:
            self.train_loader = DataLoader(Subset(dataset, range(steps - bs)))
            self.val_loader = DataLoader(Subset(dataset, range(steps - bs, steps)))
        else:
            self.train_loader = DataLoader(dataset)
            self.val_loader = None

        self.batch_size = bs
        self.model = LinearMLP([8, 4, 1])
        self.optimizer = optim.SGD(self.model.parameters(), lr=1e-3)
        self.criterion = nn.MSELoss()
        self.device = torch.device("cuda")


class ExtraXORTask(BaseTask):
    def __init__(self, validate=False):
        super(ExtraXORTask, self).__init__()
        bs, steps = 8, 64
        dataset = ExtraXORDataset(bs * steps, extra_dims=2)
        if validate:
            self.train_loader = DataLoader(Subset(dataset, range(steps - bs)))
            self.val_loader = DataLoader(Subset(dataset, range(steps - bs, steps)))
        else:
            self.train_loader = DataLoader(dataset)
            self.val_loader = None

        self.batch_size = bs
        self.model = LinearMLP([8, 4, 1])
        self.optimizer = optim.SGD(self.model.parameters(), lr=1e-3)
        self.criterion = nn.MSELoss()
        self.device = torch.device("cuda")


class DiscriminativeLearningRateTask(BaseTask):
    def __init__(self, validate=False):
        super(DiscriminativeLearningRateTask, self).__init__()
        bs, steps = 8, 64
        dataset = XORDataset(bs * steps)
        if validate:
            self.train_loader = DataLoader(Subset(dataset, range(steps - bs)))
            self.val_loader = DataLoader(Subset(dataset, range(steps - bs, steps)))
        else:
            self.train_loader = DataLoader(dataset)
            self.val_loader = None

        self.batch_size = bs
        self.model = LinearMLP([8, 4, 1])
        # Two parameter groups with different learning rates (discriminative
        # learning rates), one per Linear layer of the model.
        self.optimizer = optim.SGD(
            [
                {"params": self.model.net[0].parameters(), "lr": 0.01},
                {"params": self.model.net[1].parameters(), "lr": 0.001},
            ],
            lr=1e-3,
            momentum=0.5,
        )
        self.criterion = nn.MSELoss()
        self.device = torch.device("cuda")
```
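To make the metaclass explicit: `TaskTemplate.__call__` runs `__post_init__` right after `__init__`, so every task normalizes its own `device` without the subclass having to call anything. A sketch of the equivalent manual sequence (illustrative only):

```python
# What TaskTemplate does implicitly when you write `XORTask()`:
task = XORTask.__new__(XORTask)
task.__init__()        # sets members; device = torch.device("cuda")
task.__post_init__()   # normalizes device (None when CUDA is unavailable)
```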

tests/test_lr_finder.py

Lines changed: 72 additions & 0 deletions
```python
import pytest
from torch_lr_finder import LRFinder

from . import task as mod_task


def collect_task_classes():
    names = [v for v in dir(mod_task) if v.endswith("Task") and v != "BaseTask"]
    attrs = [getattr(mod_task, v) for v in names]
    classes = [v for v in attrs if issubclass(v, mod_task.BaseTask)]
    return classes


def prepare_lr_finder(task, **kwargs):
    model = task.model
    optimizer = task.optimizer
    criterion = task.criterion
    config = {
        "device": kwargs.get("device", None),
        "memory_cache": kwargs.get("memory_cache", True),
        "cache_dir": kwargs.get("cache_dir", None),
    }
    lr_finder = LRFinder(model, optimizer, criterion, **config)
    return lr_finder


def get_optim_lr(optimizer):
    return [grp["lr"] for grp in optimizer.param_groups]


class TestRangeTest:
    @pytest.mark.parametrize("cls_task", collect_task_classes())
    def test_run(self, cls_task):
        task = cls_task()
        init_lrs = get_optim_lr(task.optimizer)

        lr_finder = prepare_lr_finder(task)
        lr_finder.range_test(task.train_loader)

        # check whether the lr actually changed
        assert max(lr_finder.history["lr"]) >= init_lrs[0]

    @pytest.mark.parametrize("cls_task", collect_task_classes())
    def test_run_with_val_loader(self, cls_task):
        task = cls_task(validate=True)
        init_lrs = get_optim_lr(task.optimizer)

        lr_finder = prepare_lr_finder(task)
        lr_finder.range_test(task.train_loader, val_loader=task.val_loader)

        # check whether the lr actually changed
        assert max(lr_finder.history["lr"]) >= init_lrs[0]


class TestReset:
    @pytest.mark.parametrize(
        "cls_task",
        [
            mod_task.XORTask,
            mod_task.DiscriminativeLearningRateTask,
        ],
    )
    def test_reset(self, cls_task):
        task = cls_task()
        init_lrs = get_optim_lr(task.optimizer)

        lr_finder = prepare_lr_finder(task)
        lr_finder.range_test(task.train_loader, val_loader=task.val_loader)
        lr_finder.reset()

        restored_lrs = get_optim_lr(task.optimizer)
        assert init_lrs == restored_lrs
```
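For reference, a rough manual equivalent of what `TestRangeTest.test_run` exercises, written as a standalone sketch: the `end_lr` and `num_iter` values are illustrative, and `LinearMLP`/`XORDataset` are assumed to come from the test modules above.

```python
# Sketch: drive LRFinder by hand (illustrative end_lr/num_iter values).
# LinearMLP and XORDataset are the helpers from tests/model.py and
# tests/dataset.py above.
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch_lr_finder import LRFinder

model = LinearMLP([8, 4, 1])
optimizer = optim.SGD(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()
loader = DataLoader(XORDataset(8 * 64))

lr_finder = LRFinder(model, optimizer, criterion)
lr_finder.range_test(loader, end_lr=10, num_iter=100)
assert max(lr_finder.history["lr"]) >= 1e-3  # the lr actually increased
```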
