Skip to content

Commit

Permalink
refactoring: improved names of variables
Browse files Browse the repository at this point in the history
  • Loading branch information
PGS62 committed Jan 30, 2024
1 parent 4a8df7c commit b05a156
Showing 1 changed file with 36 additions and 34 deletions.
70 changes: 36 additions & 34 deletions src/corkendall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ function corkendall(x::RoMVector{T}, y::RoMVector{U}; skipmissing::Symbol=:none)

Base.require_one_based_indexing(x, y)
length(x) == length(y) || throw(DimensionMismatch("x and y have inconsistent dimensions"))
(x isa Vector{Missing} || y isa Vector{Missing}) && return NaN
(x isa Vector{Missing} || y isa Vector{Missing}) && return NaN

x = copy(x)
x, y = handlelistwise(x, y, skipmissing)
Expand Down Expand Up @@ -121,12 +121,12 @@ function corkendall(x::RoMMatrix{T}, y::RoMMatrix{U}=x; skipmissing::Symbol=:non
# good multi-threaded performance. One vector per thread to avoid cross-talk between
# threads.
duplicate(x, n) = [copy(x) for _ in 1:n]
scratchyvectors = duplicate(similar(y, m), n_duplicates)
scratchpermutedys = duplicate(similar(y, m), n_duplicates)
ycolis = duplicate(similar(y, m), n_duplicates)
xcoljsorteds = duplicate(similar(x, m), n_duplicates)
permxs = duplicate(zeros(Int, m), n_duplicates)
txs = duplicate(Vector{T}(undef, m), n_duplicates)
tys = duplicate(Vector{U}(undef, m), n_duplicates)
scratchxs = duplicate(Vector{T}(undef, m), n_duplicates)
scratchys = duplicate(Vector{U}(undef, m), n_duplicates)

#= Use the "static scheduler". This is the "quickfix, but not recommended longterm"
way of avoiding concurrency bugs. See https://julialang.org/blog/2023/07/PSA-dont-use-threadid/#fixing_buggy_code_which_uses_this_pattern
Expand All @@ -142,12 +142,12 @@ function corkendall(x::RoMMatrix{T}, y::RoMMatrix{U}=x; skipmissing::Symbol=:non
id = Threads.threadid()
end

scratchyvector = scratchyvectors[id]
scratchpermutedy = scratchpermutedys[id]
ycoli = ycolis[id]
xcoljsorted = xcoljsorteds[id]
permx = permxs[id]
tx = txs[id]
ty = tys[id]
scratchx = scratchxs[id]
scratchy = scratchys[id]

sortperm!(permx, view(x, :, j))
@inbounds for k in eachindex(xcoljsorted)
Expand All @@ -156,7 +156,7 @@ function corkendall(x::RoMMatrix{T}, y::RoMMatrix{U}=x; skipmissing::Symbol=:non

for i = 1:(symmetric ? j - 1 : nc)
ycoli .= view(y, :, i)
C[j, i] = corkendall_sorted!(xcoljsorted, ycoli, permx, scratchyvector, tx, ty)
C[j, i] = corkendall_sorted!(xcoljsorted, ycoli, permx, scratchpermutedy, scratchx, scratchy)
symmetric && (C[i, j] = C[j, i])
end
end
Expand All @@ -180,36 +180,38 @@ end
# JSTOR, www.jstor.org/stable/2282833.
"""
corkendall_sorted!(sortedx::RoMVector{T}, y::RoMVector{U},
permx::AbstractVector{<:Integer}, scratchyvector::RoMVector,
tx::AbstractVector{T}, ty::AbstractVector{U}) where {T,U}
permx::AbstractVector{<:Integer}, scratchpermutedy::RoMVector,
scratchx::AbstractVector{T}, scratchy::AbstractVector{U}) where {T,U}
Kendall correlation between two vectors but this function omits the initial sorting of
the first argument. So calculating Kendall correlation between `x` and `y` is a two stage
process: a) sort `x` to get `sortedx`; b) call this function on `sortedx` and `y`, with
subsequent arguments being:
- `permx::AbstractVector{<:Integer}`: the permutation that achieved the sorting of `x` to
yield `sortedx`.
- `scratchyvector::RoMVector`: a vector of the same element type and length as `y`; used
- `scratchpermutedy::RoMVector`: a vector of the same element type and length as `y`; used
to permute `y` without allocation.
- `tx, ty`: vectors of the same length as `x` and `y` whose element types match the types
- `scratchx, scratchy`: vectors of the same length as `x` and `y` whose element types match the types
of the non-missing elements of `x` and `y` respectively; used (in the call to
`handlepairwise!`) to avoid allocations.
"""
function corkendall_sorted!(sortedx::RoMVector{T}, y::RoMVector{U},
permx::AbstractVector{<:Integer}, scratchyvector::RoMVector{U},
tx::AbstractVector{T}, ty::AbstractVector{U}) where {T,U}
permx::AbstractVector{<:Integer}, scratchpermutedy::RoMVector{U},
scratchx::AbstractVector{T}, scratchy::AbstractVector{U}) where {T,U}

length(sortedx) >= 2 || return NaN

@inbounds for i in eachindex(y)
scratchyvector[i] = y[permx[i]]
end
if missing isa eltype(sortedx) || missing isa eltype(scratchyvector)
sortedx, scratchyvector = handlepairwise!(sortedx, scratchyvector, tx, ty)
scratchpermutedy[i] = y[permx[i]]
end
length(sortedx) >= 2 || return NaN

shuffledy = scratchyvector
if missing isa eltype(sortedx) || missing isa eltype(scratchpermutedy)
sortedx, permutedy = handlepairwise!(sortedx, scratchpermutedy, scratchx, scratchy)
else
permutedy = scratchpermutedy
end

if any(isnan, sortedx) || any(isnan, shuffledy)
if any(isnan, sortedx) || any(isnan, permutedy)
return NaN
end
n = length(sortedx)
Expand All @@ -223,23 +225,23 @@ function corkendall_sorted!(sortedx::RoMVector{T}, y::RoMVector{U},
if sortedx[i-1] == sortedx[i]
k += 1
elseif k > 0
# Sort the corresponding chunk of shuffledy, so the rows of hcat(sortedx,shuffledy)
# are sorted first on sortedx, then (where sortedx values are tied) on shuffledy.
# Sort the corresponding chunk of permutedy, so the rows of hcat(sortedx,permutedy)
# are sorted first on sortedx, then (where sortedx values are tied) on permutedy.
# Hence double ties can be counted by calling countties.
sort!(view(shuffledy, (i-k-1):(i-1)))
sort!(view(permutedy, (i-k-1):(i-1)))
ntiesx += div(widen(k) * (k + 1), 2) # Must use wide integers here
ndoubleties += countties(shuffledy, i - k - 1, i - 1)
ndoubleties += countties(permutedy, i - k - 1, i - 1)
k = 0
end
end
if k > 0
sort!(view(shuffledy, (n-k):n))
sort!(view(permutedy, (n-k):n))
ntiesx += div(widen(k) * (k + 1), 2)
ndoubleties += countties(shuffledy, n - k, n)
ndoubleties += countties(permutedy, n - k, n)
end

nswaps = merge_sort!(shuffledy, 1, n, ty)
ntiesy = countties(shuffledy, 1, n)
nswaps = merge_sort!(permutedy, 1, n, scratchy)
ntiesy = countties(permutedy, 1, n)

# Calls to float below prevent possible overflow errors when
# length(sortedx) exceeds 77_936 (32 bit) or 5_107_605_667 (64 bit)
Expand Down Expand Up @@ -390,25 +392,25 @@ end

"""
handlepairwise!(x::RoMVector{T}, y::RoMVector{U},
tx::AbstractVector{T}, ty::AbstractVector{U}) where {T,U}
scratchx::AbstractVector{T}, scratchy::AbstractVector{U}) where {T,U}
Return a pair `(a,b)` of views onto the leading elements of `scratchx` and `scratchy`,
which are overwritten with the elements of `(x,y)` that remain after `x[i]` and `y[i]`
are filtered out wherever `ismissing(x[i])||ismissing(y[i])`.
"""
function handlepairwise!(x::RoMVector{T}, y::RoMVector{U},
    scratchx::AbstractVector{T}, scratchy::AbstractVector{U}) where {T,U}

    # Copies every pair (x[i], y[i]) in which neither element is missing to the front of
    # scratchx and scratchy, preserving order, and returns views of the filled prefixes.
    # x and y are left untouched. The returned views alias the scratch vectors, so they
    # are only valid until those vectors are next overwritten.
    # The scratch vectors must be long enough to hold every retained pair (length(x)
    # elements always suffices); the @inbounds writes below do not check this.

    # Number of pairs kept so far, i.e. the last index written in the scratch vectors.
    j = 0

    # eachindex(x, y) throws a DimensionMismatch when x and y differ in length, so the
    # @inbounds reads of y[i] cannot run past the end of a shorter y.
    @inbounds for i in eachindex(x, y)
        if !(ismissing(x[i]) || ismissing(y[i]))
            j += 1
            scratchx[j] = x[i]
            scratchy[j] = y[i]
        end
    end

    return view(scratchx, 1:j), view(scratchy, 1:j)
end

"""
Expand Down

0 comments on commit b05a156

Please sign in to comment.