Skip to content

Commit

Permalink
more test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jgreener64 committed Jan 29, 2025
1 parent a6e394e commit 1022af7
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/kernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ function pairwise_force_gpu!(buffers, sys::System{D, AT, T},
nbs = @view neighbors.list[1:neighbors.n]
end
if length(neighbors) > 0
backend = get_backend(coords)
backend = get_backend(sys.coords)
n_threads_gpu = gpu_threads_pairwise(length(nbs))
kernel! = pairwise_force_kernel_nl!(backend, n_threads_gpu)
kernel!(buffers.fs_mat, sys.coords, sys.velocities, sys.atoms, sys.boundary, pairwise_inters,
Expand Down
18 changes: 8 additions & 10 deletions src/setup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -887,15 +887,14 @@ function System(coord_file::AbstractString,
end
coords = wrap_coords.(coords, (boundary_used,))

if AT <: AbstractGPUArray
if Symbol(AT) == :CuArray
neighbor_finder = GPUNeighborFinder(
eligible=AT(eligible),
dist_cutoff=T(dist_neighbors),
special=AT(special),
n_steps_reorder=10,
initialized=false,
)
elseif use_cell_list
elseif use_cell_list && !(AT <: AbstractGPUArray)
neighbor_finder = CellListMapNeighborFinder(
eligible=eligible,
special=special,
Expand All @@ -906,8 +905,8 @@ function System(coord_file::AbstractString,
)
else
neighbor_finder = DistanceNeighborFinder(
eligible=eligible,
special=special,
eligible=AT(eligible),
special=AT(special),
n_steps=10,
dist_cutoff=T(dist_neighbors),
)
Expand Down Expand Up @@ -1280,15 +1279,14 @@ function System(T::Type,
end
specific_inter_lists = tuple(specific_inter_array...)

if AT <: AbstractGPUArray
if Symbol(AT) == :CuArray
neighbor_finder = GPUNeighborFinder(
eligible=AT(eligible),
dist_cutoff=T(dist_neighbors),
special=AT(special),
n_steps_reorder=10,
initialized=false,
)
elseif use_cell_list
elseif use_cell_list && !(AT <: AbstractGPUArray)
neighbor_finder = CellListMapNeighborFinder(
eligible=eligible,
special=special,
Expand All @@ -1299,8 +1297,8 @@ function System(T::Type,
)
else
neighbor_finder = DistanceNeighborFinder(
eligible=eligible,
special=special,
eligible=AT(eligible),
special=AT(special),
n_steps=10,
dist_cutoff=T(dist_neighbors),
)
Expand Down
6 changes: 3 additions & 3 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,14 @@
# Mark all pairs as ineligible for pairwise interactions and check that the
# potential energy from the specific interactions does not change on scaling
no_nbs = falses(length(sys), length(sys))
if AT <: AbstractGPUArray
if AT <: CuArray
sys.neighbor_finder = GPUNeighborFinder(
eligible=AT(no_nbs),
dist_cutoff=1.0u"nm",
)
else
else
sys.neighbor_finder = DistanceNeighborFinder(
eligible=no_nbs,
eligible=(AT <: Array ? no_nbs : AT(no_nbs)),
dist_cutoff=1.0u"nm",
)
end
Expand Down

0 comments on commit 1022af7

Please sign in to comment.