make sure last_da is only used for the first backwards segment, when the number of checkpoint segments is greater than 1
lucidrains committed Nov 14, 2023
1 parent 2018f5d commit e116bde
Showing 2 changed files with 14 additions and 7 deletions.
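
The change in context: when the backward pass is checkpointed into more than one segment, the segments are walked in reverse, and the seed gradient last_da (computed once from the final output as ds.sum(dim = -1)) is only valid for the first segment processed; every later segment must start its accumulation from zero. A minimal sketch of that pattern in plain PyTorch, with hypothetical shapes and the Triton kernel call elided:

    import torch

    n_segments = 3               # number of checkpoint segments (hypothetical)
    last_da = torch.randn(4, 8)  # gradient seed from the final output (hypothetical shape)

    for ind in range(n_segments):  # segments walked in reverse time order
        is_first = ind == 0
        # only the first segment processed (the final one in forward time)
        # receives the seed; later segments start from zero
        seed = last_da if is_first else torch.zeros_like(last_da)
        # ... the backward kernel for this segment would consume `seed` here ...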
19 changes: 13 additions & 6 deletions colt5_attention/triton_coor_descent.py
@@ -157,6 +157,7 @@ def coor_descent_kernel_backward(
     ds_ptr,
     db_ptr,
     k_ptr,
+    last_da_ptr,
     input_row_stride,
     b_row_stride,
     mask_row_stride,
@@ -207,6 +208,12 @@ def coor_descent_kernel_backward(
     k = tl.load(k_ptr)
     logk = tl.log(k)
 
+    # load last da
+
+    last_da_ptr = last_da_ptr + row_idx
+
+    last_da = tl.load(last_da_ptr)
+
     # load initial ds
 
     ds_row_start_ptr = ds_ptr + row_idx * ds_row_stride
@@ -226,10 +233,6 @@ def coor_descent_kernel_backward(
     dk_ptr = dk_ptr + row_idx
     dk = tl.load(dk_ptr)
 
-    # temp variables
-
-    last_da = tl.sum(ds, axis = 0)
-
     # backwards
 
     for ind in range(n_iters):
@@ -289,7 +292,7 @@ def coor_descent_kernel_backward(
         ds += dsb
         db = dsb
 
-        last_da = 0.
+        last_da *= 0.
 
     # store dk
 
@@ -430,6 +433,7 @@ def backward(
         ds = grad_probs * y / last_eps
         db = ds.clone()
         dk = torch.zeros_like(k)
+        last_da = ds.sum(dim = -1)
 
         mask_int = mask.int()
 
@@ -440,7 +444,9 @@ def backward(
             reversed(epsilons)
         )
 
-        for init_a, init_b, segment_iters, eps_init, in items:
+        for ind, (init_a, init_b, segment_iters, eps_init) in enumerate(items):
+            is_first = ind == 0
+
             coor_descent_kernel_backward[(n_rows,)](
                 dk,
                 x,
@@ -450,6 +456,7 @@ def backward(
                 ds,
                 db,
                 k,
+                last_da if is_first else torch.zeros_like(last_da),
                 x.stride(0),
                 init_b.stride(0),
                 mask_int.stride(0),
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'CoLT5-attention',
   packages = find_packages(),
-  version = '0.10.16',
+  version = '0.10.17',
   license='MIT',
   description = 'Conditionally Routed Attention',
   long_description_content_type = 'text/markdown',