Skip to content

Commit

Permalink
Use ChunkSplitters for improved performance in symmetric case
Browse files Browse the repository at this point in the history
  • Loading branch information
PGS62 committed Feb 14, 2024
1 parent 5d6447d commit 3e3ec4f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 22 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ version = "3.0.2"

[deps]
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
ChunkSplitters = "ae650224-84b6-46f8-82ea-d812ca08434e"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
CSV = "0.10"
ChunkSplitters = "2"
DataFrames = "1"
Statistics = "1"
Tables = "1"
Expand Down
51 changes: 29 additions & 22 deletions src/corkendall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Kendall correlation
#
#######################################
using ChunkSplitters: chunks

"""
corkendall(x, y=x; skipmissing::Symbol=:none)
Expand Down Expand Up @@ -65,29 +66,35 @@ function _corkendall(x::AbstractMatrix{T}, y::AbstractMatrix{U},
intarray = Int[]
nmtx = nonmissingtype(eltype(x))[]
nmty = nonmissingtype(eltype(y))[]
alljs = (symmetric ? 2 : 1):nr

# ChunkSplitters.chunks with split=:scatter provides better load balancing in the
# symmetric case, where the inner loop runs over 1:(j - 1) and so work grows with j.
Threads.@threads for thischunk in chunks(alljs; n=Threads.nthreads() * 4, split=:scatter)

for k in thischunk
j = alljs[k]

sortedxcolj = task_local_vector(:sortedxcolj, x)
scratch_py = task_local_vector(:scratch_py, y)
ycoli = task_local_vector(:ycoli, y)
permx = task_local_vector(:permx, intarray)
# Ensuring missing is not an element type of scratch_sy, scratch_fx, scratch_fy
# gives improved performance.
scratch_sy = task_local_vector(:scratch_sy, nmty)
scratch_fx = task_local_vector(:scratch_fx, nmtx)
scratch_fy = task_local_vector(:scratch_fy, nmty)

sortperm!(permx, view(x, :, j))
@inbounds for k in eachindex(sortedxcolj)
sortedxcolj[k] = x[permx[k], j]
end

Threads.@threads for j = (symmetric ? 2 : 1):nr

sortedxcolj = task_local_vector(:sortedxcolj, x)
scratch_py = task_local_vector(:scratch_py, y)
ycoli = task_local_vector(:ycoli, y)
permx = task_local_vector(:permx, intarray)
# Ensure missing is not an element type of scratch_sy, scratch_fx, scratch_fy for
# improved performance.
scratch_sy = task_local_vector(:scratch_sy, nmty)
scratch_fx = task_local_vector(:scratch_fx, nmtx)
scratch_fy = task_local_vector(:scratch_fy, nmty)

sortperm!(permx, view(x, :, j))
@inbounds for k in eachindex(sortedxcolj)
sortedxcolj[k] = x[permx[k], j]
end

for i = 1:(symmetric ? j - 1 : nc)
ycoli .= view(y, :, i)
C[j, i] = corkendall_kernel!(sortedxcolj, ycoli, permx, skipmissing;
scratch_py, scratch_sy, scratch_fx, scratch_fy)
symmetric && (C[i, j] = C[j, i])
for i = 1:(symmetric ? j - 1 : nc)
ycoli .= view(y, :, i)
C[j, i] = corkendall_kernel!(sortedxcolj, ycoli, permx, skipmissing;
scratch_py, scratch_sy, scratch_fx, scratch_fy)
symmetric && (C[i, j] = C[j, i])
end
end
end
return C
Expand Down

0 comments on commit 3e3ec4f

Please sign in to comment.