Skip to content

Commit 12fd5ad

Browse files
authored
Try #269:
2 parents 045fab2 + 4e49edf commit 12fd5ad

File tree

8 files changed

+46
-10
lines changed

8 files changed

+46
-10
lines changed

examples/matmul.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,9 @@ function matmul!(a, b, c)
2323
println("Matrix size mismatch!")
2424
return nothing
2525
end
26-
if isa(a, Array)
27-
kernel! = matmul_kernel!(CPU(),4)
28-
else
29-
kernel! = matmul_kernel!(CUDADevice(),256)
30-
end
26+
device = KernelAbstractions.get_device(a)
27+
n = device isa GPU ? 256 : 4
28+
kernel! = matmul_kernel!(device, n)
3129
kernel!(a, b, c, ndrange=size(c))
3230
end
3331

examples/naive_transpose.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,9 @@ function naive_transpose!(a, b)
1616
println("Matrix size mismatch!")
1717
return nothing
1818
end
19-
if isa(a, Array)
20-
kernel! = naive_transpose_kernel!(CPU(),4)
21-
else
22-
kernel! = naive_transpose_kernel!(CUDADevice(),256)
23-
end
19+
device = KernelAbstractions.get_device(a)
20+
n = device isa GPU ? 256 : 4
21+
kernel! = naive_transpose_kernel!(device, n)
2422
kernel!(a, b, ndrange=size(a))
2523
end
2624

lib/CUDAKernels/src/CUDAKernels.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ import KernelAbstractions
1010

1111
export CUDADevice
1212

13+
KernelAbstractions.get_device(::Type{<:CUDA.CuArray}) = CUDADevice()
14+
KernelAbstractions.get_device(::Type{<:CUDA.CUSPARSE.AbstractCuSparseArray}) = CUDADevice()
15+
16+
1317
const FREE_STREAMS = CUDA.CuStream[]
1418
const STREAMS = CUDA.CuStream[]
1519
const STREAM_GC_THRESHOLD = Ref{Int}(16)

lib/CUDAKernels/test/runtests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@ using Test
88
include(joinpath(dirname(pathof(KernelAbstractions)), "..", "test", "testsuite.jl"))
99
include(joinpath(dirname(pathof(KernelGradients)), "..", "test", "testsuite.jl"))
1010

11+
@testset "get_device" begin
12+
@test @inferred(KernelAbstractions.get_device(CUDA.CuArray{Float32,3})) == CUDADevice()
13+
@test @inferred(KernelAbstractions.get_device(CUDA.CUSPARSE.CuSparseMatrixCSC{Float32})) == CUDADevice()
14+
end
15+
1116
if parse(Bool, get(ENV, "CI", "false"))
1217
default = "CPU"
1318
else

lib/ROCKernels/src/ROCKernels.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ import KernelAbstractions
1111

1212
export ROCDevice
1313

14+
KernelAbstractions.get_device(::Type{<:AMDGPU.ROCArray}) = ROCDevice()
15+
16+
1417
const FREE_QUEUES = HSAQueue[]
1518
const QUEUES = HSAQueue[]
1619
const QUEUE_GC_THRESHOLD = Ref{Int}(16)

lib/ROCKernels/test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ using Test
88
include(joinpath(dirname(pathof(KernelAbstractions)), "..", "test", "testsuite.jl"))
99
include(joinpath(dirname(pathof(KernelGradients)), "..", "test", "testsuite.jl"))
1010

11+
@testset "get_device" begin
12+
@test @inferred(KernelAbstractions.get_device(AMDGPU.ROCArray{Float32,3})) == ROCDevice()
13+
end
14+
1115
if parse(Bool, get(ENV, "CI", "false"))
1216
default = "CPU"
1317
else

src/KernelAbstractions.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,24 @@ abstract type GPU <: Device end
337337

338338
struct CPU <: Device end
339339

340+
341+
"""
342+
KernelAbstractions.get_device(A::AbstractArray)::KernelAbstractions.Device
343+
KernelAbstractions.get_device(TA::Type{<:AbstractArray})::KernelAbstractions.Device
344+
345+
Get a `KernelAbstractions.Device` instance suitable for array `A` resp. array
346+
type `TA`.
347+
"""
348+
function get_device end
349+
350+
get_device(A::AbstractArray) = get_device(typeof(A))
351+
352+
get_device(::Type{<:AbstractArray}) = CPU()
353+
354+
# Would require dependency on GPUArrays:
355+
# get_device(TA::Type{<:GPUArrays.AbstractGPUArray}) = throw(ArgumentError("No KernelAbstractions.Device type defined for arrays of type $(TA.name.name)"))
356+
357+
340358
include("nditeration.jl")
341359
using .NDIteration
342360
import .NDIteration: get

test/test.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,12 @@ end
6464
A[I] = i
6565
end
6666

67+
@testset "get_device" begin
68+
A = rand(5)
69+
@test @inferred(KernelAbstractions.get_device(typeof(A))) == CPU()
70+
@test @inferred(KernelAbstractions.get_device(A)) == KernelAbstractions.get_device(typeof(A))
71+
end
72+
6773
@testset "indextest" begin
6874
# TODO: add test for _group and _local_cartesian
6975
A = ArrayT{Int}(undef, 16, 16)

0 commit comments

Comments
 (0)