Skip to content

Commit

Permalink
Refactor to reduce allocations
Browse files Browse the repository at this point in the history
Needed a function barrier to avoid type instability
  • Loading branch information
PGS62 committed Feb 6, 2024
1 parent 9cbed12 commit adffb75
Showing 1 changed file with 52 additions and 31 deletions.
83 changes: 52 additions & 31 deletions src/corkendall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,28 @@
# Type aliases for vectors/matrices whose element type is a subtype of `Union{T,Missing}`
# for some `T<:Real` (the "RoM" prefix presumably stands for "Real or Missing" - confirm).
# They match element types `T`, `Missing` and `Union{T,Missing}`, so a single method
# signature accepts data both with and without missing values.
const RoMVector{T<:Real} = AbstractVector{<:Union{T,Missing}}
const RoMMatrix{T<:Real} = AbstractMatrix{<:Union{T,Missing}}

"""
    corkendall_degenerate(x, y, skipmissing::Symbol=:none)

Build the result matrix for the degenerate case in which no pairwise correlation is
computed from the data (`corkendall` calls this when `x` or `y` is a `Matrix{Missing}`).

The arguments are first checked with `corkendall_validateargs`. The result has
`size(x, 2)` rows and `size(y, 2)` columns. Every entry is `missing` when
`skipmissing == :none` and `NaN` otherwise, except that when `x` and `y` are the very
same object (`x === y`) the diagonal entries are `1.0`.
"""
function corkendall_degenerate(x, y, skipmissing::Symbol=:none)
    corkendall_validateargs(x, y, skipmissing)

    filler = skipmissing == :none ? missing : NaN
    ncolsx = size(x, 2)
    ncolsy = size(y, 2)

    # Distinct inputs: there is no diagonal to special-case.
    x === y || return fill(filler, ncolsx, ncolsy)

    # Same object on both sides: a column compared with itself gets 1.0.
    ondiagonal = (1:ncolsx) .== (1:ncolsy)'
    return ifelse.(ondiagonal, 1.0, filler)
end

"""
    corkendall_validateargs(x, y, skipmissing)

Validate the arguments shared by the `corkendall` methods.

Checks, in order, that `x` and `y` use one-based indexing, that they agree in their first
dimension (`size(x, 1) == size(y, 1)`), and that `skipmissing` is one of `:none`,
`:pairwise` or `:listwise`. Returns `true` when every check passes.

# Throws
- `DimensionMismatch` if `size(x, 1) != size(y, 1)`.
- `ArgumentError` if `skipmissing` is not an allowed value (and whatever
  `Base.require_one_based_indexing` throws for arrays with offset axes).
"""
function corkendall_validateargs(x, y, skipmissing)
    Base.require_one_based_indexing(x, y)

    # The message is kept on a single line: a raw line break inside the string literal
    # would embed a newline in the error text shown to the user.
    size(x, 1) == size(y, 1) ||
        throw(DimensionMismatch("x and y have inconsistent dimensions"))

    # Tuple membership avoids allocating and keeps the allowed values in one place.
    return skipmissing in (:none, :pairwise, :listwise) ||
        throw(ArgumentError("skipmissing must be one of :none, :pairwise or :listwise, \
                             but got :$skipmissing"))
end


"""
corkendall(x, y=x; skipmissing::Symbol=:none)
Expand All @@ -24,49 +46,50 @@ available.
`i`th row of `y` is `missing` then the entire `i`th rows are skipped; note that
this might skip a high proportion of entries. Only allowed when `x` or `y` is a matrix.
"""

function corkendall(x::RoMMatrix{T}, y::RoMMatrix{U}=x;
skipmissing::Symbol=:none) where {T,U}

Base.require_one_based_indexing(x, y)

size(x, 1) == size(y, 1) || throw(DimensionMismatch("x and y have
inconsistent dimensions"))
missing_allowed = missing isa eltype(x) || missing isa eltype(y)
nr, nc = size(x, 2), size(y, 2)

symmetric = x === y

missing_allowed = missing isa eltype(x) || missing isa eltype(y)
skipmissing in [:none, :pairwise, :listwise] || throw(ArgumentError("skipmissing must \
be one of :none, :pairwise or :listwise, but got :$skipmissing"))

# Degenerate case - U and/or T not defined.
if x isa Matrix{Missing} || y isa Matrix{Missing}
offdiag = missing_allowed && skipmissing == :none ? missing : NaN
nr, nc = size(x, 2), size(y, 2)
if missing_allowed && skipmissing == :listwise
x, y = handle_listwise(x, y)
if symmetric
return ifelse.((1:nr) .== (1:nc)', 1.0, offdiag)
else
return fill(offdiag, nr, nc)
y = x
end
end

if x isa Matrix{Missing} || y isa Matrix{Missing}
return corkendall_degenerate(x, y, skipmissing)
end

if skipmissing == :none && missing_allowed
C = ones(Union{Missing,Float64}, nr, nc)
else
C = ones(Float64, nr, nc)
end

return (_corkendall(x, y; C, skipmissing))

end

function _corkendall(x::RoMMatrix{T}, y::RoMMatrix{U}=x; skipmissing::Symbol=:none, C) where {T,U}

corkendall_validateargs(x, y, skipmissing)

symmetric = x === y

# Swap x and y for more efficient threaded loop.
if size(x, 2) < size(y, 2)
return collect(transpose(corkendall(y, x; skipmissing)))
end

if missing_allowed && skipmissing == :listwise
x, y = handle_listwise(x, y)
end

m, nr = size(x)
nc = size(y, 2)

if missing_allowed && skipmissing == :none
C = ones(Union{Missing,Float64}, nr, nc)
else
C = ones(Float64, nr, nc)
end

# Avoid unnecessary allocation when nthreads is large but output matrix is small.
n_duplicates = min(Threads.nthreads(), symmetric ? nr - 1 : nr)

Expand Down Expand Up @@ -126,14 +149,9 @@ end

function corkendall(x::RoMVector{T}, y::RoMVector{U}; skipmissing::Symbol=:none) where {T,U}

Base.require_one_based_indexing(x, y)

length(x) == length(y) || throw(DimensionMismatch("x and y have \
inconsistent dimensions"))
corkendall_validateargs(x, y, skipmissing)

missing_allowed = missing isa eltype(x) || missing isa eltype(y)
skipmissing in [:none, :pairwise] || throw(ArgumentError("skipmissing must be one of \
:none or :pairwise, but got :$skipmissing"))

if missing_allowed && skipmissing == :none
if any(ismissing, x) || any(ismissing, y)
Expand Down Expand Up @@ -394,6 +412,9 @@ function handle_pairwise(x::RoMVector{T}, y::RoMVector{U};
return view(scratch_fx, 1:j), view(scratch_fy, 1:j)
end

# Listwise deletion is a no-op when neither matrix's element type admits `missing`:
# both inputs are handed back unchanged (no copies) as a 2-tuple.
function handle_listwise(x::AbstractMatrix{<:Real}, y::AbstractMatrix{<:Real})
    return (x, y)
end


"""
handle_listwise(x::RoMMatrix{T}, y::RoMMatrix{U}) where {T,U}
Expand Down

0 comments on commit adffb75

Please sign in to comment.