Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
PGS62 committed Jan 30, 2024
1 parent d99121f commit 0bb5191
Show file tree
Hide file tree
Showing 2 changed files with 256 additions and 254 deletions.
144 changes: 73 additions & 71 deletions src/corkendall.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#######################################
#
#
# Kendall correlation
#
#
#######################################

# RoM = "Real or Missing"
Expand All @@ -15,25 +15,25 @@ https://julialang.org/blog/2023/07/PSA-dont-use-threadid/#another_option_use_a_p

#= TODO 22 Feb 2023
1) Should the default value of skipmissing be :none or :pairwise? :none forces the user to
address the question of how missings should be handled, but at least for REPL use, it's
address the question of how missings should be handled, but at least for REPL use, it's
rather inconvenient.
2) Should the docstring mention that the function is multi-threaded? Currently no function
in StatsBase is multi-threaded... By default, Julia starts up single-threaded...
3) How to get compatibility of corkendall with StatsBase.pairwise? The problem is that
3) How to get compatibility of corkendall with StatsBase.pairwise? The problem is that
pairwise passes vectors to f that don't contain missing but for which missing isa eltype
and corkendall then wants a skipmissing argument.
Option a)
Amend StatsBase._pairwise! to replace line:
dest[i, j] = f(ynm, ynm)
with:
with:
dest[i, j] = f(disallowmissing(ynm), disallowmissing(ynm))
Option b)
Some other way to arrange that arguments passed to f from within StatsBase._pairwise!
do not have Missing as an allowed element type. Using function handlepairwise would do
do not have Missing as an allowed element type. Using function handle_pairwise! would do
that efficiently.
Option c)
Expand All @@ -48,7 +48,7 @@ https://julialang.org/blog/2023/07/PSA-dont-use-threadid/#another_option_use_a_p
end
Option d)
Have a dedicated method of _pairwise! to handle f === corkendall. This has a big
Have a dedicated method of _pairwise! to handle f === corkendall. This has a big
performance advantage, and is maybe along the lines suggested by nalimilan here:
https://github.com/JuliaStats/StatsBase.jl/pull/647#issuecomment-775264454
=#
Expand All @@ -62,22 +62,29 @@ vectors or matrices, with elements that are either real numbers or missing value
# Keyword argument
- `skipmissing::Symbol=:none`: if `:none`, missing values in either `x` or `y`
cause the function to raise an error. Use `:pairwise` to skip entries with a missing
value in either of the two vectors used to calculate (an element of) the return. Use
cause the function to raise an error. Use `:pairwise` to skip entries with a missing
value in either of the two vectors used to calculate (an element of) the return. Use
`:listwise` to skip entries where a missing value appears anywhere in a given row of `x`
or `y`; note that this might drop a high proportion of entries.
"""
function corkendall(x::RoMVector{T}, y::RoMVector{U}; skipmissing::Symbol=:none) where {T,U}

Base.require_one_based_indexing(x, y)

length(x) == length(y) || throw(DimensionMismatch("x and y have inconsistent dimensions"))
(x isa Vector{Missing} || y isa Vector{Missing}) && return NaN

missing_allowed = missing isa eltype(x) || missing isa eltype(y)
validate_skipmissing(skipmissing, missing_allowed)

# Degenerate case - U and/or T not defined.
if x isa Vector{Missing} || y isa Vector{Missing}
return NaN
end

x = copy(x)

if missing isa eltype(x) || missing isa eltype(y)
x, y = handlelistwise(x, y, skipmissing)
x, y = handlepairwise(x, y, similar(x, T), similar(y, U))
if missing_allowed && skipmissing != :none #pairwise and listwise the same for vector-vector case
x, y = handle_pairwise!(x, y, similar(x, T), similar(y, U))
end

permx = sortperm(x)
Expand All @@ -89,29 +96,36 @@ end
function corkendall(x::RoMMatrix{T}, y::RoMMatrix{U}=x; skipmissing::Symbol=:none) where {T,U}

Base.require_one_based_indexing(x, y)
symmetric = x === y
if size(x, 1) != size(y, 1)
throw(DimensionMismatch("x and y have inconsistent dimensions"))
end

# Swap x and y for more efficient threaded loop.
if size(x, 2) < size(y, 2)
return collect(transpose(corkendall(y, x; skipmissing)))
end
size(x, 1) == size(y, 1) || throw(DimensionMismatch("x and y have inconsistent dimensions"))

x, y = handlelistwise(x, y, skipmissing)
m, nr = size(x)
nc = size(y, 2)
symmetric = x === y

missing_allowed = missing isa eltype(x) || missing isa eltype(y)
validate_skipmissing(skipmissing, missing_allowed)

# Handle degenerate case early to simplify subsequent code (U and/or T not defined).
# Degenerate case - U and/or T not defined.
if x isa Matrix{Missing} || y isa Matrix{Missing}
nr, nc = size(x, 2), size(y, 2)
if symmetric
return ifelse.((1:nr) .== (1:nc)', 1.0, NaN)
else
return fill(NaN, nr, nc)
end
end

# Swap x and y for more efficient threaded loop.
if size(x, 2) < size(y, 2)
return collect(transpose(corkendall(y, x; skipmissing)))
end

if missing_allowed && skipmissing == :listwise
x, y = handle_listwise!(x, y)
end

m, nr = size(x)
nc = size(y, 2)

C = ones(Float64, nr, nc)
# Avoid unnecessary allocation when nthreads is large but output matrix is small.
n_duplicates = min(Threads.nthreads(), symmetric ? nr - 1 : nr)
Expand All @@ -121,7 +135,7 @@ function corkendall(x::RoMMatrix{T}, y::RoMMatrix{U}=x; skipmissing::Symbol=:non
a = Threads.Atomic{Int}(1)
end

#= Create scratch vectors so that threaded code can be non-allocating, a requirement for
#= Create scratch vectors so that threaded code can be non-allocating, a requirement for
good multi-threaded performance. One vector per thread to avoid cross-talk between
threads.
=#
Expand Down Expand Up @@ -169,8 +183,9 @@ function corkendall(x::RoMMatrix{T}, y::RoMMatrix{U}=x; skipmissing::Symbol=:non
return C
end

# corkendall returns a vector in this case, inconsistent with Statistics.cor and
# StatsBase.corspearman, but consistent with StatsBase.corkendall.
#= corkendall returns a vector in this case, inconsistent with Statistics.cor and
StatsBase.corspearman, but consistent with StatsBase.corkendall.
=#
function corkendall(x::RoMMatrix, y::RoMVector; skipmissing::Symbol=:none)
    # Treat `y` as a one-column matrix so the matrix-matrix method does the work, then
    # flatten its single-column result back into a vector.
    ycolumn = reshape(y, length(y), 1)
    return vec(corkendall(x, ycolumn; skipmissing=skipmissing))
end
Expand All @@ -186,7 +201,7 @@ end
# JSTOR, www.jstor.org/stable/2282833.
"""
corkendall_sorted!(sortedx::RoMVector{T}, y::RoMVector{U},
permx::AbstractVector{<:Integer}, scratchpermutedy::RoMVector,
permx::AbstractVector{<:Integer}, scratchpermutedy::RoMVector,
scratchx::AbstractVector{T}, scratchy::AbstractVector{U}) where {T,U}
Kendall correlation between two vectors but this function omits the initial sorting of
Expand All @@ -198,8 +213,8 @@ subsequent arguments being:
- `scratchpermutedy::RoMVector`: a vector of the same element type and length as `y`; used
to permute `y` without allocation.
- `scratchx, scratchy`: vectors of the same length as `x` and `y` whose element types match
the types of the non-missing elements of `x` and `y` respectively; used (in the calls to
`handlepairwise` and `merge_sort!`) to avoid allocations.
the types of the non-missing elements of `x` and `y` respectively; used (in the calls to
`handle_pairwise!` and `merge_sort!`) to avoid allocations.
"""
function corkendall_sorted!(sortedx::RoMVector{T}, y::RoMVector{U},
permx::AbstractVector{<:Integer}, scratchpermutedy::RoMVector{U},
Expand All @@ -212,7 +227,7 @@ function corkendall_sorted!(sortedx::RoMVector{T}, y::RoMVector{U},
end

if missing isa eltype(sortedx) || missing isa eltype(scratchpermutedy)
sortedx, permutedy = handlepairwise(sortedx, scratchpermutedy, scratchx, scratchy)
sortedx, permutedy = handle_pairwise!(sortedx, scratchpermutedy, scratchx, scratchy)
else
permutedy = scratchpermutedy
end
Expand Down Expand Up @@ -345,7 +360,7 @@ midpoint(lo::Integer, hi::Integer) = midpoint(promote(lo, hi)...)
insertion_sort!(v::AbstractVector, lo::Integer, hi::Integer)
Mutates `v` by sorting elements `v[lo:hi]` using the insertion sort algorithm.
This method is a copy-paste-edit of sort! in base/sort.jl, amended to return the bubblesort
This method is a copy-paste-edit of sort! in base/sort.jl, amended to return the bubblesort
distance.
"""
function insertion_sort!(v::AbstractVector, lo::Integer, hi::Integer)
Expand All @@ -370,15 +385,14 @@ function insertion_sort!(v::AbstractVector, lo::Integer, hi::Integer)
return nswaps
end


"""
handlepairwise(x::RoMVector{T}, y::RoMVector{U},
handle_pairwise!(x::RoMVector{T}, y::RoMVector{U},
scratchx::AbstractVector{T}, scratchy::AbstractVector{U}) where {T,U}
Return a pair `(a,b)`, filtered copies of `(x,y)`, in which elements `x[i]` and
`y[i]` are filtered out if `ismissing(x[i])||ismissing(y[i])`.
"""
function handlepairwise(x::RoMVector{T}, y::RoMVector{U},
function handle_pairwise!(x::RoMVector{T}, y::RoMVector{U},
scratchx::AbstractVector{T}, scratchy::AbstractVector{U}) where {T,U}

j = 0
Expand All @@ -395,45 +409,12 @@ function handlepairwise(x::RoMVector{T}, y::RoMVector{U},
end

"""
handlelistwise(x::AbstractArray,y::AbstractArray,skipmissing::Symbol)
If `skipmissing` is `:listwise` and `x` and `y` are both matrices then do listwise filtering
of `x` and `y`. Otherwise merely validate `skipmissing` argument.
"""
function handlelistwise(x::AbstractArray, y::AbstractArray, skipmissing::Symbol)
    # Validates `skipmissing` and, only when it is `:listwise` and both inputs are
    # `Matrix`, delegates to the two-argument method for row filtering. In every other
    # accepted case the inputs are returned untouched.
    missing_allowed = missing isa eltype(x) || missing isa eltype(y)

    if skipmissing == :listwise
        if x isa Matrix && y isa Matrix
            return handlelistwise(x, y)
        end
    elseif skipmissing == :pairwise
        # Nothing to do here; pairwise filtering happens later, per pair of columns.
    elseif skipmissing == :none
        # `:none` is only acceptable when neither input can contain `missing`.
        if missing_allowed
            # A space is needed before each line-continuation backslash, otherwise the
            # words on either side of the break are fused together in the message.
            throw(ArgumentError("When missing is an allowed element type \
                                then keyword argument skipmissing must be either \
                                `:pairwise` or `:listwise`, but got `:$skipmissing`"))
        end
    else
        # Unrecognised symbol: list only the options valid for these element types.
        if missing_allowed
            throw(ArgumentError("keyword argument skipmissing must be either \
                                `:pairwise` or `:listwise`, but got `:$skipmissing`"))
        else
            throw(ArgumentError("keyword argument skipmissing must be either \
                                `:pairwise`, `:listwise` or `:none` but got \
                                `:$skipmissing`"))
        end
    end
    return x, y
end

"""
handlelistwise(x::RoMMatrix{T}, y::RoMMatrix{U}) where {T,U}
handle_listwise!(x::RoMMatrix{T}, y::RoMMatrix{U}) where {T,U}
Return a pair `(a,b)`, filtered copies of `(x,y)`, in which the rows `x[i,:]` and
`y[i,:]` are both filtered out if `any(ismissing,x[i,:])||any(ismissing,y[i,:])`.
"""
handlelistwise(x::AbstractMatrix{<:Real}, y::AbstractMatrix{<:Real}) = x, y # no-op: `<:Real` element types cannot hold `missing`

function handlelistwise(x::RoMMatrix{T}, y::RoMMatrix{U}) where {T,U}
function handle_listwise!(x::RoMMatrix{T}, y::RoMMatrix{U}) where {T,U}

nrx = size(x, 1)
nry = size(y, 1)
Expand Down Expand Up @@ -472,3 +453,24 @@ function handlelistwise(x::RoMMatrix{T}, y::RoMMatrix{U}) where {T,U}
return view(a, 1:k, :), view(b, 1:k, :)
end

"""
    validate_skipmissing(skipmissing::Symbol, missing_allowed::Bool)

Check that keyword argument `skipmissing` holds an acceptable value. Return `nothing` if
it does, otherwise throw an `ArgumentError`.

- `:pairwise` and `:listwise` are always accepted.
- `:none` is accepted only when `missing_allowed` is `false`.
- Any other symbol is rejected; the error message lists only the values that are valid for
  the given `missing_allowed`.
"""
function validate_skipmissing(skipmissing::Symbol, missing_allowed::Bool)
    if skipmissing == :pairwise || skipmissing == :listwise
        return nothing
    elseif skipmissing == :none
        if missing_allowed
            # A space is needed before each line-continuation backslash, otherwise the
            # words on either side of the break are fused together in the message.
            throw(ArgumentError("When missing is an allowed element type \
                                then keyword argument skipmissing must be either \
                                `:pairwise` or `:listwise`, but got `:$skipmissing`"))
        end
        return nothing
    elseif missing_allowed
        throw(ArgumentError("keyword argument skipmissing must be either \
                            `:pairwise` or `:listwise`, but got `:$skipmissing`"))
    else
        throw(ArgumentError("keyword argument skipmissing must be either \
                            `:pairwise`, `:listwise` or `:none` but got \
                            `:$skipmissing`"))
    end
end

Loading

0 comments on commit 0bb5191

Please sign in to comment.