-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtest_sample.py
75 lines (56 loc) · 1.98 KB
/
test_sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
r"""Tests for the azula.sample module."""
import pytest
import torch
import torch.nn as nn
from functools import partial
from torch import Tensor
from typing import Any, Sequence
from azula.denoise import PreconditionedDenoiser
from azula.nn.embedding import SineEncoding
from azula.noise import VPSchedule
from azula.sample import DDIMSampler, DDPMSampler, EulerSampler, HeunSampler, LMSSampler, PCSampler
class Dummy(nn.Module):
    r"""Small backbone used to exercise the samplers.

    Maps `x_t` through a linear layer of width 64, adds the `SineEncoding` of
    `t`, applies a ReLU and projects back to `features` with a second linear
    layer. After computing the output, it asserts that `label` was provided if
    and only if the module was built with `with_label` set.
    """

    def __init__(self, features: int = 5, with_label: bool = False):
        super().__init__()

        hidden = 64

        self.with_label = with_label

        # Submodules are created in the same order as before so that parameter
        # registration (and random initialization) is unchanged.
        self.l1 = nn.Linear(features, hidden)
        self.l2 = nn.Linear(hidden, features)
        self.relu = nn.ReLU()
        self.time_encoding = SineEncoding(hidden)

    def forward(self, x_t: Tensor, t: Tensor, label: Any = None):
        hidden = self.relu(self.l1(x_t) + self.time_encoding(t))
        out = self.l2(hidden)

        # The label must be present exactly when the module expects one.
        has_label = label is not None
        assert has_label if self.with_label else not has_label

        return out
@pytest.mark.parametrize("with_label", [False, True])
@pytest.mark.parametrize("batch", [(), (64,)])
def test_samplers(with_label: bool, batch: Sequence[int]):
    r"""Checks that every sampler yields finite tensors of the requested shape.

    All samplers wrap the same preconditioned denoiser, whose `Dummy` backbone
    is built with the matching `with_label` flag. Each sampler is asked for an
    initial state through `init`, then called on that state, with a `label`
    keyword only when `with_label` is set.
    """

    shape = (*batch, 5)
    call_kwargs = {"label": "cat"} if with_label else {}

    denoiser = PreconditionedDenoiser(
        backbone=Dummy(5, with_label),
        schedule=VPSchedule(),
    )

    factories = (
        partial(DDPMSampler, steps=64),
        partial(DDIMSampler, steps=64, eta=0.0),
        partial(DDIMSampler, steps=64, eta=1.0),
        partial(EulerSampler, steps=64),
        partial(HeunSampler, steps=64),
        partial(LMSSampler, steps=64, order=3),
        partial(PCSampler, steps=64, corrections=1),
    )

    for make_sampler in factories:
        sampler = make_sampler(denoiser)

        # Initial state
        x1 = sampler.init(shape)

        assert x1.shape == shape
        assert torch.isfinite(x1).all()

        # Sampled state
        x0 = sampler(x1, **call_kwargs)

        assert x0.shape == shape
        assert torch.isfinite(x0).all()