Skip to content

Commit

Permalink
Fixing edge cases and adding tests for them
Browse files Browse the repository at this point in the history
Also re-arranged tests so that if ever we port this code to StatsBase then we take over only the tests in test/corkendall, Testing against corkendall_naive probably too complicated...
  • Loading branch information
PGS62 committed Feb 7, 2024
1 parent 9d13d6d commit d52ed2d
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 49 deletions.
43 changes: 29 additions & 14 deletions src/corkendall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@ const RoMMatrix{T<:Real} = AbstractMatrix{<:Union{T,Missing}}
corkendall(x, y=x; skipmissing::Symbol=:none)
Compute Kendall's rank correlation coefficient, τ. `x` and `y` must be either vectors or
matrices, with elements that are either real numbers or `missing`. When either x or y is a
matrix the function uses multiple threads if available.
matrices, with elements that are either real numbers or `missing`.
When either `x` or `y` is a matrix the function uses multiple threads if available.
# Keyword argument
- `skipmissing::Symbol=:none`: If `:none` (the default), then `missing` entries in `x` or
- `skipmissing::Symbol=:none`: If `:none` (the default), `missing` entries in `x` or
`y` give rise to `missing` entries in the return. If `:pairwise`, when either of the
`i`th entries of the vectors required to calculate an element of the return is `missing`,
both entries are skipped. If `:listwise`, when any entry in the `i`th row of `x` or the
Expand All @@ -28,23 +29,25 @@ function corkendall(x::RoMMatrix{T}, y::RoMMatrix{U}=x;
skipmissing::Symbol=:none) where {T,U}

corkendall_validateargs(x, y, skipmissing, true)
symmetric = x===y

missing_allowed = missing isa eltype(x) || missing isa eltype(y)
nr, nc = size(x, 2), size(y, 2)

if missing_allowed && skipmissing == :listwise
x, y = handle_listwise(x, y)
end
(m, nr), nc = size(x), size(y, 2)

#Degenerate case - T and/or U not defined.
if x isa Matrix{Missing} || y isa Matrix{Missing}
offdiag = skipmissing == :none ? missing : NaN
if x === y
offdiag = (m >= 2 && skipmissing == :none) ? missing : NaN
if symmetric
return ifelse.((1:nr) .== (1:nc)', 1.0, offdiag)
else
return fill(offdiag, nr, nc)
end
end

if missing_allowed && skipmissing == :listwise
x, y = handle_listwise(x, y)
end

if skipmissing == :none && missing_allowed
C = ones(Union{Missing,Float64}, nr, nc)
else
Expand Down Expand Up @@ -130,21 +133,24 @@ function corkendall(x::RoMVector{T}, y::RoMVector{U}; skipmissing::Symbol=:none)

corkendall_validateargs(x, y, skipmissing, false)

length(x) >= 2 || return NaN

missing_allowed = missing isa eltype(x) || missing isa eltype(y)

if missing_allowed && skipmissing == :none
if any(ismissing, x) || any(ismissing, y)
return missing
end
elseif x isa Vector{Missing} || y isa Vector{Missing}
# Degenerate case - U and/or T not defined.
#Degenerate case - T and/or U not defined.
return NaN
end

x = copy(x)

if missing_allowed && skipmissing == :pairwise
x, y = handle_pairwise(x, y)
length(x) >= 2 || return NaN
end

permx = sortperm(x)
Expand Down Expand Up @@ -417,11 +423,20 @@ function handle_listwise(x::RoMMatrix{T}, y::RoMMatrix{U}) where {T,U}

axes(x, 1) == axes(y, 1) || throw(DimensionMismatch("x and y have inconsistent dimensions"))

a = similar(x, T)
k = 0

symmetric = x === y

#Degenerate case - T and/or U not defined.
if x isa Matrix{Missing} || y isa Matrix{Missing}
if symmetric
return view(x, [], :), view(x, [], :)
else
return view(x, [], :), view(y, [], :)
end
end

a = similar(x, T)

k = 0
if symmetric
@inbounds for i in axes(x, 1)
if all(j -> !ismissing(x[i, j]), axes(x, 2))
Expand Down
10 changes: 5 additions & 5 deletions src/corkendall_fromfile.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import CSV
using CSV: write, File, getnames
import DataFrames
import Tables
import Statistics
Expand Down Expand Up @@ -59,7 +59,7 @@ function corkendall_fromfile(file1::String, file2::String, outputfile::String,
DataFrames.insertcols!(datatowrite, 1, Symbol("") => String.(names1))
end

filename = CSV.write(outputfile, datatowrite, header=writeheaders)
filename = write(outputfile, datatowrite, header=writeheaders)

if whattoreturn == "filename"
return filename
Expand Down Expand Up @@ -93,15 +93,15 @@ function csvread(filename::String, ignorefirstrow::Bool, ignorefirstcol::Bool;
types = Union{Missing,Float64}
strict = true

filedata = CSV.File(filename; header, drop, missingstring, types, strict)
filedata = File(filename; header, drop, missingstring, types, strict)
data = Tables.matrix(filedata)

#Convert to Array{Float64} if there are in fact no missings
if isnothing(findfirst(ismissing, data))
data = identity.(data)
end

names = CSV.getnames(filedata)
names = getnames(filedata)

return data, names
end
Expand Down Expand Up @@ -163,7 +163,7 @@ function comparecorrelationfiles(file1::String, file2::String)

absdiffs = abs.(data1 .- data2)

return (maximum(absdiffs), median(absdiffs))
return (maximum(absdiffs), Statistics.median(absdiffs))

end

Expand Down
2 changes: 1 addition & 1 deletion test/compare_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,4 +198,4 @@ function myisapprox(x::Union{T,Missing}, y::Union{U,Missing},
end
end

myisequal(x, y) = myisapprox(x, y, 0.0)
myisequal(x, y) = myisapprox(x, y, 0.0)
51 changes: 32 additions & 19 deletions test/corkendall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@ using KendallTau
using Test
using Random

include("corkendall_naive.jl")
include("compare_implementations.jl")

@testset "corkendall_auxiliary_fns" begin

#Auxiliary functions for corkendall
Expand All @@ -25,6 +22,12 @@ include("compare_implementations.jl")
@test KendallTau.handle_pairwise(u, v) == (Int64[], Int64[])
@test KendallTau.handle_listwise(mx, mx) == ([1 2; 5 6], [1 2; 5 6])

#Test handling of symmetric inputs
res1, res2 = KendallTau.handle_listwise(mx, mx)
@test res1 === res2
res1, res2 = KendallTau.handle_listwise(mx, copy(mx))
@test !(res1 === res2)

v = collect(100:-1:1)
KendallTau.insertion_sort!(v, 1, 100)
@test v == 1:100
Expand All @@ -35,6 +38,7 @@ include("compare_implementations.jl")

end


@testset "corkendall" begin

x = Float64[1 0; 2 1; 3 0; 4 1; 5 10]
Expand Down Expand Up @@ -102,17 +106,39 @@ end
@test isnan(f([1, 2, 3, 4, 5], xm, skipmissing=:pairwise))
@test isnan(f(xm, [1, 2, 3, 4, 5], skipmissing=:pairwise))
@test isequal(f(xmm, skipmissing=:pairwise), [1.0 NaN; NaN 1.0])
@test isequal(f(xmm, skipmissing=:none), [1.0 missing; missing 1.0])
@test isequal(f(xmm, xmm, skipmissing=:none), [1.0 missing; missing 1.0])
@test isequal(f(xmm, copy(xmm), skipmissing=:none), [missing missing; missing missing])
@test isequal(f(xmm, xmm, skipmissing=:listwise), [1.0 NaN; NaN 1.0])
@test isequal(f(xmm, copy(xmm), skipmissing=:listwise), [NaN NaN; NaN NaN])

@test isequal(f(xmm, copy(xmm), skipmissing=:pairwise), [NaN NaN; NaN NaN])

@test ismissing(f([1, 2, 3, 4, 5], xm, skipmissing=:none))
@test ismissing(f([1, 2, 3, 4, 5], xm, skipmissing=:none))
@test isequal(f(xmm, skipmissing=:none), [1.0 missing; missing 1.0])
@test isequal(f(xmm, copy(xmm), skipmissing=:none), [missing missing; missing missing])
@test_throws ArgumentError f([1,2,3,4],[4,3,2,1], skipmissing = :listwise)
@test isequal(f(hcat(Y, xm), skipmissing=:none), vcat(hcat(f(Y, skipmissing=:none), [missing, missing, missing]), [missing missing missing 1.0]))
@test_throws ArgumentError f([1, 2, 3, 4], [4, 3, 2, 1], skipmissing=:listwise)

#interaction of NaNs and missing inputs with skipmissing argument
nan_and_missing = hcat(fill(NaN,10,1),fill(missing,10,1))
@test isequal(f(nan_and_missing,skipmissing=:none),[1.0 missing;missing 1.0])
@test isequal(f(nan_and_missing,copy(nan_and_missing),skipmissing=:none),[NaN missing;missing missing])
@test isequal(f(nan_and_missing,skipmissing=:pairwise),[1.0 NaN;NaN 1.0])
@test isequal(f(nan_and_missing,copy(nan_and_missing),skipmissing=:pairwise),[NaN NaN;NaN NaN])
@test isequal(f(nan_and_missing,skipmissing=:listwise),[1.0 NaN;NaN 1.0])
@test isequal(f(nan_and_missing,copy(nan_and_missing),skipmissing=:listwise),[NaN NaN;NaN NaN])

@test_throws ArgumentError f(x; skipmissing=:foo)
@test_throws ArgumentError f(Xm; skipmissing=:foo)

#when inputs have fewer than 2 rows return should be NaN even when inputs are missing
@test isnan(f(Float64[],Float64[]))
@test isnan(f([1],[1]))
@test isnan(f([missing],[missing]))
@test isequal(f([missing],[missing missing]),[NaN NaN])

c11 = f(x1, x1)
c12 = f(x1, x2)
c22 = f(x2, x2)
Expand Down Expand Up @@ -196,21 +222,8 @@ end
n_reps = Threads.nthreads()
@test f(repeat(hcat(a, b), outer=[1, n_reps])) == repeat(f(hcat(a, b)), outer=[n_reps, n_reps])

#= Test functions against corkendall_naive, a "reference implementation" that has the
advantage of simplicity.
=#
if f !== corkendall_naive
@test compare_implementations(f, corkendall_naive, abstol=0.0, maxcols=10, maxrows=10, numtests=200) == true
@test compare_implementations(f, corkendall_naive, abstol=0.0, maxcols=10, maxrows=100, numtests=200) == true
@test compare_implementations(f, corkendall_naive, abstol=1e14, maxcols=2, maxrows=20000, numtests=5) == true
end

end

smallx = randn(MersenneTwister(123), 1000, 3)
indicators = rand(MersenneTwister(456), 1000, 3) .< 0.05
smallx = ifelse.(indicators, missing, smallx)
@test corkendall_naive(smallx, skipmissing=:pairwise) == KendallTau.corkendall(smallx, skipmissing=:pairwise)
@test corkendall_naive(smallx, skipmissing=:listwise) == KendallTau.corkendall(smallx, skipmissing=:listwise)

end
end

31 changes: 22 additions & 9 deletions test/corkendall_naive.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
using KendallTau: corkendall_validateargs, handle_listwise, handle_pairwise
using KendallTau: corkendall_validateargs, handle_listwise, handle_pairwise, RoMVector, RoMMatrix

# RoM = "Real or Missing"
const RoMVector{T<:Real} = AbstractVector{<:Union{T,Missing}}
const RoMMatrix{T<:Real} = AbstractMatrix{<:Union{T,Missing}}
#const RoMVector{T<:Real} = AbstractVector{<:Union{T,Missing}}
#const RoMMatrix{T<:Real} = AbstractMatrix{<:Union{T,Missing}}

"""
corkendall_naive(x, y=x; skipmissing::Symbol=:none)
Expand All @@ -29,10 +29,11 @@ function corkendall_naive(x::RoMMatrix{T}, y::RoMMatrix{U}=x;

missing_allowed = missing isa eltype(x) || missing isa eltype(y)

# Degenerate case - U and/or T not defined.
(m, nr), nc = size(x), size(y, 2)

#Degenerate case - T and/or U not defined.
if x isa Matrix{Missing} || y isa Matrix{Missing}
offdiag = missing_allowed && skipmissing == :none ? missing : NaN
nr, nc = size(x, 2), size(y, 2)
offdiag = (m >= 2 && skipmissing == :none) ? missing : NaN
if symmetric
return ifelse.((1:nr) .== (1:nc)', 1.0, offdiag)
else
Expand Down Expand Up @@ -66,11 +67,17 @@ function corkendall_naive(x::RoMVector{T}, y::RoMVector{U}; skipmissing::Symbol=

corkendall_validateargs(x, y, skipmissing, false)

length(x)>=2 || return(NaN)

missing_allowed = missing isa eltype(x) || missing isa eltype(y)

if x isa Vector{Missing} || y isa Vector{Missing}
# Degenerate case - U and/or T not defined.
return skipmissing == :none ? missing : NaN
if missing_allowed && skipmissing == :none
if any(ismissing, x) || any(ismissing, y)
return missing
end
elseif x isa Vector{Missing} || y isa Vector{Missing}
#Degenerate case - T and/or U not defined.
return NaN
end

x = copy(x)
Expand Down Expand Up @@ -101,6 +108,12 @@ function corkendall_naive_kernel!(x, y, skipmissing::Symbol)
if missing isa eltype(x) || missing isa eltype(y)
x, y = handle_pairwise(x, y)
end
elseif skipmissing == :none
if missing isa eltype(x) || missing isa eltype(y)
if any(ismissing, x) || any(ismissing, y)
return (missing)
end
end
end

n = length(x)
Expand Down
7 changes: 6 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
using KendallTau
using Test

include("corkendall_naive.jl")
include("compare_implementations.jl")


include("corkendall.jl")
include("corkendall_fromfile.jl")
include("corkendall_fromfile.jl")
include("versus_naive.jl")
24 changes: 24 additions & 0 deletions test/versus_naive.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using KendallTau
using Test
using Random

#=
Note that corkendall and corkendall_naive share some subroutines, notably handle_pairwise
and handle_listwise. If those were bugged then this test would likely give a false positive.
=#

@testset "versus_corkendall_naive" begin

@test compare_implementations(corkendall, corkendall_naive, abstol=0.0, maxcols=10, maxrows=10, numtests=200) == true
@test compare_implementations(corkendall, corkendall_naive, abstol=0.0, maxcols=10, maxrows=100, numtests=200) == true
@test compare_implementations(corkendall, corkendall_naive, abstol=1e14, maxcols=2, maxrows=20000, numtests=5) == true

smallx = randn(MersenneTwister(123), 1000, 3)
indicators = rand(MersenneTwister(456), 1000, 3) .< 0.05
smallx = ifelse.(indicators, missing, smallx)
@test corkendall_naive(smallx, skipmissing=:pairwise) == KendallTau.corkendall(smallx, skipmissing=:pairwise)
@test corkendall_naive(smallx, skipmissing=:listwise) == KendallTau.corkendall(smallx, skipmissing=:listwise)
@test isequal(corkendall_naive(smallx, skipmissing=:none) , KendallTau.corkendall(smallx, skipmissing=:none))


end

0 comments on commit d52ed2d

Please sign in to comment.