| GPU Backend | CI Status |
|---|---|
| AMDGPU | |
| CUDA | |
Fused kernels (with ChainRules.jl integration):
See `benchmarks/main.jl` for comparison scripts between naïve & fused versions.
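A minimal benchmark sketch of such a comparison (not the `benchmarks/main.jl` script itself; it assumes BenchmarkTools for timing and uses NNlib.softmax as the naïve baseline):

using AMDGPU, BenchmarkTools, NNlib, NNop

x = ROCArray(rand(Float32, 8192, 1024))
@btime AMDGPU.@sync NNlib.softmax($x)          # naïve multi-pass softmax
@btime AMDGPU.@sync NNop.online_softmax($x)    # fused online-softmax kernel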
Implementation of *FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness*.
using AMDGPU, NNop, Zygote

# E: head dim, L: sequence length, H: number of heads, B: batch size.
E, L, H, B = 64, 4096, 4, 4
causal = false

q = ROCArray(rand(Float32, E, L, H, B))
k = ROCArray(rand(Float32, E, L, H, B))
v = ROCArray(rand(Float32, E, L, H, B))

# Forward pass.
o = NNop.flash_attention(q, k, v; causal)

# Backward pass via Zygote (ChainRules.jl integration).
∇ = Zygote.gradient(q, k, v) do q, k, v
    sum(NNop.flash_attention(q, k, v; causal))
end
- Forward & backward passes.
- Arbitrary sequence length.
- FP32, FP16, BF16 support.
- Variable sequence length.
- Causal masking (see the sketch below).
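For example, combining causal masking with half precision (a sketch reusing the setup above; Float16 inputs are assumed to be accepted, per the list):

qh = ROCArray(rand(Float16, E, L, H, B))
kh = ROCArray(rand(Float16, E, L, H, B))
vh = ROCArray(rand(Float16, E, L, H, B))
oh = NNop.flash_attention(qh, kh, vh; causal=true)  # autoregressive (causal) attention in FP16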
Implementation of *Online normalizer calculation for softmax*.
x = ROCArray(rand(Float32, 8192, 1024))
y = NNop.online_softmax(x)
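A quick cross-check against a naïve two-pass softmax (a sketch; it assumes the kernel normalizes over the first dimension):

y_naive = exp.(x .- maximum(x; dims=1))
y_naive = y_naive ./ sum(y_naive; dims=1)
isapprox(y, y_naive)  # expected to hold under the dims=1 assumption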
Implementation of RMSNorm:

x = ROCArray(rand(Float32, 1024, 1024))
w = ROCArray(rand(Float32, 1024))  # learnable scale

y = NNop.rms_norm(x, w)
∇ = Zygote.gradient(x, w) do x, w
    sum(NNop.rms_norm(x, w))
end
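For reference, a naïve RMSNorm (a sketch; it assumes normalization over the first/feature dimension and ε = 1f-6, which may differ from the kernel's defaults):

using Statistics

rms = sqrt.(mean(abs2, x; dims=1) .+ 1f-6)
y_naive = w .* (x ./ rms)
isapprox(y, y_naive)  # holds only if the dims/ε assumptions match the kernel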
Implementation of LayerNorm:

x = ROCArray(rand(Float32, 1024, 1024))
w = ROCArray(rand(Float32, 1024))  # learnable scale
b = ROCArray(rand(Float32, 1024))  # learnable bias

y = NNop.layer_norm(x, w, b)
∇ = Zygote.gradient(x, w, b) do x, w, b
    sum(NNop.layer_norm(x, w, b))
end
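And a naïve LayerNorm reference (a sketch; same caveats as above about the normalized dimension and ε):

using Statistics

μ = mean(x; dims=1)
σ² = mean(abs2, x .- μ; dims=1)
y_naive = w .* (x .- μ) ./ sqrt.(σ² .+ 1f-6) .+ b
isapprox(y, y_naive)  # holds only if the dims/ε assumptions match the kernel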
Implementation of Llama rotary position embedding (RoPE):

using Adapt

# KernelAbstractions backend used to move arrays to the GPU (ROCm here, matching the other examples).
kab = AMDGPU.ROCBackend()

# E: head dim, L: sequence length, B: batch size; QH/KH: number of query/key heads.
E, L, B = 16, 1024, 1
QH, KH = 16, 16

emb = NNop.LlamaRotaryEmbedding(E)
position_ids = reshape(collect(0f0:Float32(L) - 1f0), :, 1)
position_ids = repeat(position_ids; inner=(1, B))
cos, sin = emb(position_ids)
cos = Adapt.adapt(kab, cos)
sin = Adapt.adapt(kab, sin)

q = Adapt.adapt(kab, ones(Float32, (E, L, QH, B)))
k = Adapt.adapt(kab, ones(Float32, (E, L, KH, B)))
q, k = NNop.llama_rope(q, k; cos, sin)
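Since the query and key head counts match here (QH == KH), the rotated tensors can be fed straight into the flash attention kernel above (a sketch; `v` is created here only for illustration):

v = Adapt.adapt(kab, ones(Float32, (E, L, KH, B)))
o = NNop.flash_attention(q, k, v; causal=true)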