Skip to content

Commit

Permalink
Merge pull request #10 from PGS62/task_local_storage_experiment
Browse files Browse the repository at this point in the history
Use `Base.task_local_storage` and stop using `Threads.threadid`
  • Loading branch information
PGS62 authored Feb 12, 2024
2 parents 3ad06e1 + f999efa commit 832095c
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 41 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
julia = "1"
CSV = "0.10 - 0"
DataFrames = "1"
Statistics = "1"
Tables = "1"
julia = "1"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
1 change: 1 addition & 0 deletions src/KendallTau.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module KendallTau


include("corkendall.jl")
include("corkendall_fromfile.jl")

Expand Down
58 changes: 18 additions & 40 deletions src/corkendall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ function corkendall(x::AbstractMatrix, y::AbstractMatrix=x;

end

function _corkendall(x::AbstractMatrix, y::AbstractMatrix,
C::AbstractMatrix, skipmissing::Symbol)
function _corkendall(x::AbstractMatrix{T}, y::AbstractMatrix{U},
C::AbstractMatrix, skipmissing::Symbol) where {T,U}

symmetric = x === y

Expand All @@ -56,47 +56,25 @@ function _corkendall(x::AbstractMatrix, y::AbstractMatrix,

(m, nr), nc = size(x), size(y, 2)

# Avoid unnecessary allocation when nthreads is large but output matrix is small.
n_duplicates = min(Threads.nthreads(), symmetric ? nr - 1 : nr)
Threads.@threads for j = (symmetric ? 2 : 1):nr

use_atomic = n_duplicates < Threads.nthreads()
if use_atomic
a = Threads.Atomic{Int}(1)
end

#= Create scratch vectors so that threaded code can be non-allocating, a requirement for
good multi-threaded performance. One vector per thread to avoid cross-talk between
threads.
=#
duplicate(x, n) = [copy(x) for _ in 1:n]
scratch_pys = duplicate(similar(y, m), n_duplicates)
ycolis = duplicate(similar(y, m), n_duplicates)
sortedxcoljs = duplicate(similar(x, m), n_duplicates)
permxs = duplicate(zeros(Int, m), n_duplicates)
scratch_fxs = duplicate(similar(x, m), n_duplicates)
scratch_fys = duplicate(similar(y, m), n_duplicates)
scratch_sys = duplicate(similar(y, m), n_duplicates)

#= Use the "static scheduler". This is the "quickfix, but not recommended longterm" way
of avoiding concurrency bugs when using Threads.threadid().
https://julialang.org/blog/2023/07/PSA-dont-use-threadid/#fixing_buggy_code_which_uses_this_pattern
TODO Adopt a "better fix" as outlined in that blog.
=#
Threads.@threads :static for j = (symmetric ? 2 : 1):nr

if use_atomic
id = Threads.atomic_add!(a, 1)[]
else
id = Threads.threadid()
if !haskey(task_local_storage(), :scratch_py)
task_local_storage(:scratch_py, similar(y, m))
task_local_storage(:scratch_sy, similar(y, m))
task_local_storage(:ycoli, similar(y, m))
task_local_storage(:sortedxcolj, similar(x, m))
task_local_storage(:permx, zeros(Int, m))
task_local_storage(:scratch_fx, similar(x, m))
task_local_storage(:scratch_fy, similar(y, m))
end

scratch_py = scratch_pys[id]
scratch_sy = scratch_sys[id]
ycoli = ycolis[id]
sortedxcolj = sortedxcoljs[id]
permx = permxs[id]
scratch_fx = scratch_fxs[id]
scratch_fy = scratch_fys[id]
scratch_py::Vector{U} = task_local_storage(:scratch_py)
scratch_sy::Vector{U} = task_local_storage(:scratch_sy)
ycoli::Vector{U} = task_local_storage(:ycoli)
sortedxcolj::Vector{T} = task_local_storage(:sortedxcolj)
permx::Vector{Int} = task_local_storage(:permx)
scratch_fx::Vector{T} = task_local_storage(:scratch_fx)
scratch_fy::Vector{U} = task_local_storage(:scratch_fy)

sortperm!(permx, view(x, :, j))
@inbounds for k in eachindex(sortedxcolj)
Expand Down

0 comments on commit 832095c

Please sign in to comment.