-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathdist.py
34 lines (22 loc) · 815 Bytes
/
dist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import functools
import torch
import torch.distributions
import utils
# Module-level logger obtained through the project's logging helper.
LOGGER = utils.log.getLogger(__name__)
# One-shot guard flag — presumably flipped once custom KL divergences are
# registered with torch.distributions; verify against the rest of the file.
__defined_kl = False
# Numerical margin used to keep probabilities strictly inside (0, 1).
EPS = 1e-5
def clamp_probs(probs):
    """Clamp probabilities away from 0 and 1, then renormalize to the simplex.

    Clamping each entry into [EPS, 1 - EPS] breaks the sum-to-one property
    along the last dimension, so the clamped tensor is divided by its sum.
    """
    bounded = torch.clamp(probs, min=EPS, max=1. - EPS)
    total = bounded.sum(dim=-1, keepdim=True)
    return bounded / total
def grid(h, w, pad=0, device='cpu', dtype=torch.float32, norm=False):
    """Build a dense (x, y) coordinate grid of shape (1, 2, h+2*pad, w+2*pad).

    Channel 0 holds the x (column) coordinate and channel 1 the y (row)
    coordinate. Padding extends the grid on every side, so padded cells get
    coordinates below 0 (or above h-1 / w-1).

    Args:
        h, w: interior height and width of the grid.
        pad: number of extra border cells on each side.
        device: device the coordinate tensor is created on.
        dtype: dtype of the returned tensor.
        norm: if True, divide coordinates by the last padded index so the
            padded extent spans [0, 1].

    Returns:
        Tensor of shape (1, 2, h + 2*pad, w + 2*pad).
    """
    hr = torch.arange(h + 2 * pad, device=device) - pad
    wr = torch.arange(w + 2 * pad, device=device) - pad
    if norm:
        hr = hr / (h + 2 * pad - 1)
        wr = wr / (w + 2 * pad - 1)
    # indexing='ij' pins the historical default: calling meshgrid without it
    # is deprecated (warns on every call) and the default is scheduled to
    # change to 'xy', which would silently transpose the grid.
    ig, jg = torch.meshgrid(hr, wr, indexing='ij')
    g = torch.stack([jg, ig]).to(dtype)[None]
    return g
@functools.lru_cache(2)
def cached_grid(h, w, pad=0, device='cpu', dtype=torch.float32, norm=False):
    """Memoized wrapper around `grid`, keeping the two most recent results.

    NOTE(review): the cached tensor is shared between callers — mutating the
    returned grid in place would corrupt subsequent cache hits.
    """
    return grid(h=h, w=w, pad=pad, device=device, dtype=dtype, norm=norm)