diff --git a/src/corkendall.jl b/src/corkendall.jl index 783cb52..a460c47 100644 --- a/src/corkendall.jl +++ b/src/corkendall.jl @@ -71,7 +71,7 @@ function corkendall(x::RoMVector{T}, y::RoMVector{U}; skipmissing::Symbol=:none) Base.require_one_based_indexing(x, y) length(x) == length(y) || throw(DimensionMismatch("x and y have inconsistent dimensions")) - (x isa Vector{Missing} || y isa Vector{Missing}) && return NaN + (x isa Vector{Missing} || y isa Vector{Missing}) && return NaN x = copy(x) x, y = handlelistwise(x, y, skipmissing) @@ -121,12 +121,12 @@ function corkendall(x::RoMMatrix{T}, y::RoMMatrix{U}=x; skipmissing::Symbol=:non # good multi-threaded performance. One vector per thread to avoid cross-talk between # threads. duplicate(x, n) = [copy(x) for _ in 1:n] - scratchyvectors = duplicate(similar(y, m), n_duplicates) + scratchpermutedys = duplicate(similar(y, m), n_duplicates) ycolis = duplicate(similar(y, m), n_duplicates) xcoljsorteds = duplicate(similar(x, m), n_duplicates) permxs = duplicate(zeros(Int, m), n_duplicates) - txs = duplicate(Vector{T}(undef, m), n_duplicates) - tys = duplicate(Vector{U}(undef, m), n_duplicates) + scratchxs = duplicate(Vector{T}(undef, m), n_duplicates) + scratchys = duplicate(Vector{U}(undef, m), n_duplicates) #= Use the "static scheduler". This is the "quickfix, but not recommended longterm" way of avoiding concurrency bugs. See https://julialang.org/blog/2023/07/PSA-dont-use-threadid/#fixing_buggy_code_which_uses_this_pattern @@ -142,12 +142,12 @@ function corkendall(x::RoMMatrix{T}, y::RoMMatrix{U}=x; skipmissing::Symbol=:non id = Threads.threadid() end - scratchyvector = scratchyvectors[id] + scratchpermutedy = scratchpermutedys[id] ycoli = ycolis[id] xcoljsorted = xcoljsorteds[id] permx = permxs[id] - tx = txs[id] - ty = tys[id] + scratchx = scratchxs[id] + scratchy = scratchys[id] sortperm!(permx, view(x, :, j)) @inbounds for k in eachindex(xcoljsorted) @@ -156,7 +156,7 @@ function corkendall(x::RoMMatrix{T}, y::RoMMatrix{U}=x; skipmissing::Symbol=:non for i = 1:(symmetric ? j - 1 : nc) ycoli .= view(y, :, i) - C[j, i] = corkendall_sorted!(xcoljsorted, ycoli, permx, scratchyvector, tx, ty) + C[j, i] = corkendall_sorted!(xcoljsorted, ycoli, permx, scratchpermutedy, scratchx, scratchy) symmetric && (C[i, j] = C[j, i]) end end @@ -180,8 +180,8 @@ end # JSTOR, www.jstor.org/stable/2282833. """ corkendall_sorted!(sortedx::RoMVector{T}, y::RoMVector{U}, - permx::AbstractVector{<:Integer}, scratchyvector::RoMVector, - tx::AbstractVector{T}, ty::AbstractVector{U}) where {T,U} + permx::AbstractVector{<:Integer}, scratchpermutedy::RoMVector, + scratchx::AbstractVector{T}, scratchy::AbstractVector{U}) where {T,U} Kendall correlation between two vectors but this function omits the initial sorting of the first argument. So calculating Kendall correlation between `x` and `y` is a two stage @@ -189,27 +189,29 @@ process: a) sort `x` to get `sortedx`; b) call this function on `sortedx` and `y subsequent arguments being: - `permx::AbstractVector{<:Integer}`: the permutation that achieved the sorting of `x` to yield `sortedx`. -- `scratchyvector::RoMVector`: a vector of the same element type and length as `y`; used +- `scratchpermutedy::RoMVector`: a vector of the same element type and length as `y`; used to permute `y` without allocation. -- `tx, ty`: vectors of the same length as `x` and `y` whose element types match the types +- `scratchx, scratchy`: vectors of the same length as `x` and `y` whose element types match the types of the non-missing elements of `x` and `y` respectively; used (in the call to `handlepairwise!`) to avoid allocations. """ function corkendall_sorted!(sortedx::RoMVector{T}, y::RoMVector{U}, - permx::AbstractVector{<:Integer}, scratchyvector::RoMVector{U}, - tx::AbstractVector{T}, ty::AbstractVector{U}) where {T,U} + permx::AbstractVector{<:Integer}, scratchpermutedy::RoMVector{U}, + scratchx::AbstractVector{T}, scratchy::AbstractVector{U}) where {T,U} + + length(sortedx) >= 2 || return NaN @inbounds for i in eachindex(y) - scratchyvector[i] = y[permx[i]] - end - if missing isa eltype(sortedx) || missing isa eltype(scratchyvector) - sortedx, scratchyvector = handlepairwise!(sortedx, scratchyvector, tx, ty) + scratchpermutedy[i] = y[permx[i]] end - length(sortedx) >= 2 || return NaN - shuffledy = scratchyvector + if missing isa eltype(sortedx) || missing isa eltype(scratchpermutedy) + sortedx, permutedy = handlepairwise!(sortedx, scratchpermutedy, scratchx, scratchy) + else + permutedy = scratchpermutedy + end - if any(isnan, sortedx) || any(isnan, shuffledy) + if any(isnan, sortedx) || any(isnan, permutedy) return NaN end n = length(sortedx) @@ -223,23 +225,23 @@ function corkendall_sorted!(sortedx::RoMVector{T}, y::RoMVector{U}, if sortedx[i-1] == sortedx[i] k += 1 elseif k > 0 - # Sort the corresponding chunk of shuffledy, so the rows of hcat(sortedx,shuffledy) - # are sorted first on sortedx, then (where sortedx values are tied) on shuffledy. + # Sort the corresponding chunk of permutedy, so the rows of hcat(sortedx,permutedy) + # are sorted first on sortedx, then (where sortedx values are tied) on permutedy. # Hence double ties can be counted by calling countties. - sort!(view(shuffledy, (i-k-1):(i-1))) + sort!(view(permutedy, (i-k-1):(i-1))) ntiesx += div(widen(k) * (k + 1), 2) # Must use wide integers here - ndoubleties += countties(shuffledy, i - k - 1, i - 1) + ndoubleties += countties(permutedy, i - k - 1, i - 1) k = 0 end end if k > 0 - sort!(view(shuffledy, (n-k):n)) + sort!(view(permutedy, (n-k):n)) ntiesx += div(widen(k) * (k + 1), 2) - ndoubleties += countties(shuffledy, n - k, n) + ndoubleties += countties(permutedy, n - k, n) end - nswaps = merge_sort!(shuffledy, 1, n, ty) - ntiesy = countties(shuffledy, 1, n) + nswaps = merge_sort!(permutedy, 1, n, scratchy) + ntiesy = countties(permutedy, 1, n) # Calls to float below prevent possible overflow errors when # length(sortedx) exceeds 77_936 (32 bit) or 5_107_605_667 (64 bit) @@ -390,25 +392,25 @@ end """ handlepairwise!(x::RoMVector{T}, y::RoMVector{U}, - tx::AbstractVector{T}, ty::AbstractVector{U}) where {T,U} + scratchx::AbstractVector{T}, scratchy::AbstractVector{U}) where {T,U} Return a pair `(a,b)`, filtered copies of `(x,y)`, in which elements `x[i]` and `y[i]` are filtered out if `ismissing(x[i])||ismissing(y[i])`. """ function handlepairwise!(x::RoMVector{T}, y::RoMVector{U}, - tx::AbstractVector{T}, ty::AbstractVector{U}) where {T,U} + scratchx::AbstractVector{T}, scratchy::AbstractVector{U}) where {T,U} j = 0 @inbounds for i in eachindex(x) if !(ismissing(x[i]) || ismissing(y[i])) j += 1 - tx[j] = x[i] - ty[j] = y[i] + scratchx[j] = x[i] + scratchy[j] = y[i] end end - return view(tx, 1:j), view(ty, 1:j) + return view(scratchx, 1:j), view(scratchy, 1:j) end """