NNop.jl

Flash Attention & friends in pure Julia.

GPU backend CI status:

  • AMDGPU
  • CUDA

Fused kernels (with ChainRules.jl integration):

  • Flash Attention
  • Online Softmax
  • RMS Norm
  • Layer Norm
  • Llama RoPE

Benchmarking

See benchmarks/main.jl for comparison scripts between naïve & fused versions.
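
As a minimal illustration of the kind of comparison done there (assuming BenchmarkTools is available; naive_softmax is a hypothetical helper, not part of NNop):

using AMDGPU, BenchmarkTools, NNop

# Two-pass reference softmax over the first dimension.
function naive_softmax(x)
    e = exp.(x .- maximum(x; dims=1))
    return e ./ sum(e; dims=1)
end

x = ROCArray(rand(Float32, 8192, 1024))

# Synchronize inside the timed expression so GPU kernel time is actually measured.
@btime begin naive_softmax($x); AMDGPU.synchronize() end
@btime begin NNop.online_softmax($x); AMDGPU.synchronize() end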

Flash Attention

Implementation of FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.

using AMDGPU, NNop, Zygote

E, L, H, B = 64, 4096, 4, 4
causal = false

q = ROCArray(rand(Float32, E, L, H, B))
k = ROCArray(rand(Float32, E, L, H, B))
v = ROCArray(rand(Float32, E, L, H, B))

o = NNop.flash_attention(q, k, v; causal)
∇ = Zygote.gradient(q, k, v) do q, k, v
    sum(NNop.flash_attention(q, k, v; causal))
end
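
To sanity-check the fused output against a straightforward formulation, a naive scaled dot-product attention over the same (E, L, H, B) layout can be written as below. naive_attention is a hypothetical reference helper, not part of NNop's API, and it ignores causal masking (fine here since causal = false):

# Naive attention reference for the (E, L, H, B) layout used above.
function naive_attention(q, k, v)
    scale = inv(sqrt(Float32(size(q, 1))))
    o = similar(q)
    for b in axes(q, 4), h in axes(q, 3)
        qh, kh, vh = q[:, :, h, b], k[:, :, h, b], v[:, :, h, b]
        s = (kh' * qh) .* scale                 # (L_k, L_q) scores
        e = exp.(s .- maximum(s; dims=1))
        p = e ./ sum(e; dims=1)                 # softmax over the key dimension
        o[:, :, h, b] = vh * p                  # (E, L_q)
    end
    return o
end

isapprox(Array(o), Array(naive_attention(q, k, v)); atol=1f-3, rtol=1f-3)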

Features:

  • Forward & backward passes.
  • Arbitrary sequence length.
  • FP32, FP16, BF16 support (see the half-precision sketch after this list).
  • Variable sequence length.
  • Causal masking.
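
A minimal half-precision, causal-masked call, assuming the same keyword API as the FP32 example above:

q16 = ROCArray(rand(Float16, E, L, H, B))
k16 = ROCArray(rand(Float16, E, L, H, B))
v16 = ROCArray(rand(Float16, E, L, H, B))
o16 = NNop.flash_attention(q16, k16, v16; causal=true)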

Softmax

Implementation of Online normalizer calculation for softmax.

x = ROCArray(rand(Float32, 8192, 1024))
y = NNop.online_softmax(x)
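
For intuition, the online-normalizer trick keeps a running maximum m and a running denominator d in a single pass, rescaling d whenever m grows. A scalar CPU sketch of that recurrence (purely illustrative, not the fused kernel):

# Single-pass online softmax over a vector.
function online_softmax_ref(x::AbstractVector{T}) where T
    m, d = typemin(T), zero(T)
    for xi in x
        m_new = max(m, xi)
        d = d * exp(m - m_new) + exp(xi - m_new)  # rescale old sum, add new term
        m = m_new
    end
    return exp.(x .- m) ./ d
end

online_softmax_ref(rand(Float32, 8192))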

RMS Norm

x = ROCArray(rand(Float32, 1024, 1024))
w = ROCArray(rand(Float32, 1024))
y = NNop.rms_norm(x, w)
∇ = Zygote.gradient(x, w) do x, w
    sum(NNop.rms_norm(x, w))
end
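
For reference, this corresponds to the standard RMSNorm along the first dimension (the epsilon value below is an assumed default, not taken from NNop):

# Naive RMSNorm reference; eps is an assumed default.
function rms_norm_ref(x, w; eps=1f-6)
    rms = sqrt.(sum(abs2, x; dims=1) ./ size(x, 1) .+ eps)
    return w .* (x ./ rms)
end

y_ref = rms_norm_ref(x, w)  # compare against y = NNop.rms_norm(x, w) above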

Layer Norm

x = ROCArray(rand(Float32, 1024, 1024))
w = ROCArray(rand(Float32, 1024))
b = ROCArray(rand(Float32, 1024))
y = NNop.layer_norm(x, w, b)
∇ = Zygote.gradient(x, w, b) do x, w, b
    sum(NNop.layer_norm(x, w, b))
end
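
Similarly, a naive LayerNorm reference over the first dimension (scale by w, shift by b; the epsilon value is again an assumption):

# Naive LayerNorm reference; eps is an assumed default.
function layer_norm_ref(x, w, b; eps=1f-5)
    μ = sum(x; dims=1) ./ size(x, 1)
    σ2 = sum(abs2, x .- μ; dims=1) ./ size(x, 1)
    return w .* ((x .- μ) ./ sqrt.(σ2 .+ eps)) .+ b
end

y_ref = layer_norm_ref(x, w, b)  # compare against y = NNop.layer_norm(x, w, b) above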

Llama RoPE

using Adapt

# kab: KernelAbstractions GPU backend used with Adapt.adapt to move arrays
# to the device (ROCBackend for AMDGPU; presumably CUDABackend works the same way).
kab = AMDGPU.ROCBackend()

E, L, B = 16, 1024, 1
QH, KH = 16, 16

emb = NNop.LlamaRotaryEmbedding(E)
position_ids = reshape(collect(0f0:Float32(L) - 1f0), :, 1)
position_ids = repeat(position_ids; inner=(1, B))

cos, sin = emb(position_ids)
cos = Adapt.adapt(kab, cos)
sin = Adapt.adapt(kab, sin)

q = Adapt.adapt(kab, ones(Float32, (E, L, QH, B)))
k = Adapt.adapt(kab, ones(Float32, (E, L, KH, B)))
q, k = NNop.llama_rope(q, k; cos, sin)
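
For intuition, Llama-style RoPE rotates each feature pair by a position-dependent angle using the "rotate half" trick. A small CPU sketch of that convention for a single (E, L) slice, assuming the Hugging Face Llama layout and base 10000 (an assumption about what llama_rope matches):

# Rotate-half RoPE on one (E, L) slice; illustrative only.
rope_ref(x, c, s) = x .* c .+ vcat(-x[end ÷ 2 + 1:end, :], x[1:end ÷ 2, :]) .* s

inv_freq = 1f0 ./ (10_000f0 .^ ((0:2:E - 2) ./ Float32(E)))  # (E ÷ 2,)
θ = inv_freq * (0:Float32(L) - 1)'                           # (E ÷ 2, L) angles
# Base.cos / Base.sin: the names `cos`/`sin` are shadowed by the tables above.
c, s = Base.cos.(vcat(θ, θ)), Base.sin.(vcat(θ, θ))          # (E, L)
x_rot = rope_ref(ones(Float32, E, L), c, s)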
