Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
No longer have two functions handlepairwise and handlepairwise!, but just one: handlepairwise
  • Loading branch information
PGS62 committed Jan 30, 2024
1 parent b05a156 commit 58af14a
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 44 deletions.
61 changes: 21 additions & 40 deletions src/corkendall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ https://julialang.org/blog/2023/07/PSA-dont-use-threadid/#another_option_use_a_p
Option b)
Some other way to arrange that arguments passed to f from within StatsBase._pairwise!
do not have Missing as an allowed element type. Using function handlepairwise! would do
do not have Missing as an allowed element type. Using function handlepairwise would do
that efficiently.
Option c)
Expand Down Expand Up @@ -74,12 +74,16 @@ function corkendall(x::RoMVector{T}, y::RoMVector{U}; skipmissing::Symbol=:none)
(x isa Vector{Missing} || y isa Vector{Missing}) && return NaN

x = copy(x)
x, y = handlelistwise(x, y, skipmissing)
x, y = handlepairwise(x, y)

if missing isa eltype(x) || missing isa eltype(y)
x, y = handlelistwise(x, y, skipmissing)
x, y = handlepairwise(x, y, similar(x, T), similar(y, U))
end

permx = sortperm(x)
permute!(x, permx)

return corkendall_sorted!(x, y, permx, similar(y), T[], U[])
return corkendall_sorted!(x, y, permx, similar(y), similar(x, T), similar(y, U))
end

function corkendall(x::RoMMatrix{T}, y::RoMMatrix{U}=x; skipmissing::Symbol=:none) where {T,U}
Expand Down Expand Up @@ -117,9 +121,10 @@ function corkendall(x::RoMMatrix{T}, y::RoMMatrix{U}=x; skipmissing::Symbol=:non
a = Threads.Atomic{Int}(1)
end

# Create scratch vectors so that threaded code can be non-allocating, a requirement for
# good multi-threaded performance. One vector per thread to avoid cross-talk between
# threads.
#= Create scratch vectors so that threaded code can be non-allocating, a requirement for
good multi-threaded performance. One vector per thread to avoid cross-talk between
threads.
=#
duplicate(x, n) = [copy(x) for _ in 1:n]
scratchpermutedys = duplicate(similar(y, m), n_duplicates)
ycolis = duplicate(similar(y, m), n_duplicates)
Expand All @@ -128,8 +133,9 @@ function corkendall(x::RoMMatrix{T}, y::RoMMatrix{U}=x; skipmissing::Symbol=:non
scratchxs = duplicate(Vector{T}(undef, m), n_duplicates)
scratchys = duplicate(Vector{U}(undef, m), n_duplicates)

#= Use the "static scheduler". This is the "quickfix, but not recommended longterm"
way of avoiding concurrency bugs. See https://julialang.org/blog/2023/07/PSA-dont-use-threadid/#fixing_buggy_code_which_uses_this_pattern
#= Use the "static scheduler". This is the "quickfix, but not recommended longterm" way
of avoiding concurrency bugs when using threadid.
https://julialang.org/blog/2023/07/PSA-dont-use-threadid/#fixing_buggy_code_which_uses_this_pattern
TODO Adopt a "better fix" as outlined in that blog.
=#
Threads.@threads :static for j = (symmetric ? 2 : 1):nr
Expand Down Expand Up @@ -191,9 +197,9 @@ subsequent arguments being:
yield `sortedx`.
- `scratchpermutedy::RoMVector`: a vector of the same element type and length as `y`; used
to permute `y` without allocation.
- `scratchx, scratchy`: vectors of the same length as `x` and `y` whose element types match the types
of the non-missing elements of `x` and `y` respectively; used (in the call to
`handlepairwise!`) to avoid allocations.
- `scratchx, scratchy`: vectors of the same length as `x` and `y` whose element types match
the types of the non-missing elements of `x` and `y` respectively; used (in the calls to
`handlepairwise` and `merge_sort!`) to avoid allocations.
"""
function corkendall_sorted!(sortedx::RoMVector{T}, y::RoMVector{U},
permx::AbstractVector{<:Integer}, scratchpermutedy::RoMVector{U},
Expand All @@ -206,7 +212,7 @@ function corkendall_sorted!(sortedx::RoMVector{T}, y::RoMVector{U},
end

if missing isa eltype(sortedx) || missing isa eltype(scratchpermutedy)
sortedx, permutedy = handlepairwise!(sortedx, scratchpermutedy, scratchx, scratchy)
sortedx, permutedy = handlepairwise(sortedx, scratchpermutedy, scratchx, scratchy)
else
permutedy = scratchpermutedy
end
Expand Down Expand Up @@ -364,40 +370,15 @@ function insertion_sort!(v::AbstractVector, lo::Integer, hi::Integer)
return nswaps
end

"""
handlepairwise(x::RoMVector{T}, y::RoMVector{U}) where {T,U}
Return a pair `(a,b)`, filtered copies of `x` and `y`, in which elements `x[i]` and `y[i]`
are filtered out if `ismissing(x[i])||ismissing(y[i])`.
"""
function handlepairwise(x::AbstractVector{<:Real}, y::AbstractVector{<:Real})
    # Fast path: neither element type admits `missing`, so there is nothing to filter.
    # The inputs themselves (not copies) are handed back as a tuple.
    return (x, y)
end

function handlepairwise(x::RoMVector{T}, y::RoMVector{U}) where {T,U}
    # Output buffers are sized for the worst case (no pair is dropped) and trimmed to
    # the number of retained pairs just before returning.
    n = length(x)
    a = Vector{T}(undef, n)
    b = Vector{U}(undef, n)
    j = 0

    # `eachindex(x, y)` checks that both vectors share the same indices and throws a
    # `DimensionMismatch` otherwise; that check is what makes `@inbounds` safe for
    # `y[i]` as well as `x[i]`.
    @inbounds for i in eachindex(x, y)
        xi = x[i]
        yi = y[i]
        # Keep the pair only when both observations are present.
        if !(ismissing(xi) || ismissing(yi))
            j += 1
            a[j] = xi
            b[j] = yi
        end
    end

    return resize!(a, j), resize!(b, j)
end

"""
handlepairwise!(x::RoMVector{T}, y::RoMVector{U},
handlepairwise(x::RoMVector{T}, y::RoMVector{U},
scratchx::AbstractVector{T}, scratchy::AbstractVector{U}) where {T,U}
Return a pair `(a,b)`, filtered copies of `(x,y)`, in which elements `x[i]` and
`y[i]` are filtered out if `ismissing(x[i])||ismissing(y[i])`.
"""
function handlepairwise!(x::RoMVector{T}, y::RoMVector{U},
function handlepairwise(x::RoMVector{T}, y::RoMVector{U},
scratchx::AbstractVector{T}, scratchy::AbstractVector{U}) where {T,U}

j = 0
Expand Down
8 changes: 4 additions & 4 deletions test/rankcorr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,10 @@ mx = [1 2
missing missing
5 6]

@test KendallTau.handlepairwise!(x, y, similar(x), similar(y)) == ([2, 3, 4], [1, 2, 4])
@test KendallTau.handlepairwise!(float.(x), y, similar(float.(x)), similar(y)) == ([2.0, 3.0, 4.0], [1, 2, 4])
@test KendallTau.handlepairwise!(x, float.(y), similar(x), similar(float.(y))) == ([2, 3, 4], [1.0, 2.0, 4.0])
@test KendallTau.handlepairwise!(u, v, similar(u), similar(v)) == (Int64[], Int64[])
@test KendallTau.handlepairwise(x, y, similar(x), similar(y)) == ([2, 3, 4], [1, 2, 4])
@test KendallTau.handlepairwise(float.(x), y, similar(float.(x)), similar(y)) == ([2.0, 3.0, 4.0], [1, 2, 4])
@test KendallTau.handlepairwise(x, float.(y), similar(x), similar(float.(y))) == ([2, 3, 4], [1.0, 2.0, 4.0])
@test KendallTau.handlepairwise(u, v, similar(u), similar(v)) == (Int64[], Int64[])
@test KendallTau.handlelistwise(mx, mx) == ([1 2; 5 6], [1 2; 5 6])

end #testset

0 comments on commit 58af14a

Please sign in to comment.