
Commit fce604d

Add files via upload
0 parents · commit fce604d

10 files changed: +2267, -0 lines changed

DDPM/__init__.py

Whitespace-only changes.

DDPM/attend.py

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
from functools import wraps
from packaging import version
from collections import namedtuple

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange

# constants

AttentionConfig = namedtuple('AttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def once(fn):
    # wrap fn so it only ever runs on the first call
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

print_once = once(print)

# main class

class Attend(nn.Module):
    def __init__(
        self,
        dropout = 0.,
        flash = False,
        scale = None
    ):
        super().__init__()
        self.dropout = dropout
        self.scale = scale
        self.attn_dropout = nn.Dropout(dropout)

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = AttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = AttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = AttentionConfig(False, True, True)

    def flash_attn(self, q, k, v):
        _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        if exists(self.scale):
            # scaled_dot_product_attention already scales by d ** -0.5 internally,
            # so divide that default out before applying the custom scale
            default_scale = q.shape[-1] ** -0.5
            q = q * (self.scale / default_scale)

        q, k, v = map(lambda t: t.contiguous(), (q, k, v))

        # check if there is a compatible device for flash attention

        config = self.cuda_config if is_cuda else self.cpu_config

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p = self.dropout if self.training else 0.
            )

        return out

    def forward(self, q, k, v):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        if self.flash:
            return self.flash_attn(q, k, v)

        scale = default(self.scale, q.shape[-1] ** -0.5)

        # similarity

        sim = einsum("b h i d, b h j d -> b h i j", q, k) * scale

        # attention

        attn = sim.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        # aggregate values

        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        return out
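
A minimal, hypothetical usage sketch (not part of this commit): it assumes the repository root is on the Python path so that DDPM.attend is importable, and feeds random q/k/v tensors of shape (batch, heads, sequence length, dim per head) through the module.

# hypothetical usage sketch, not part of this commit
import torch
from DDPM.attend import Attend   # assumes the repo root is importable

attend = Attend(dropout = 0.1, flash = False)   # flash = True requires pytorch >= 2.0

q = torch.randn(2, 8, 64, 32)   # (batch, heads, query length, dim per head)
k = torch.randn(2, 8, 64, 32)
v = torch.randn(2, 8, 64, 32)

out = attend(q, k, v)           # softmax(q k^T * d ** -0.5) @ v
print(out.shape)                # torch.Size([2, 8, 64, 32])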
