Skip to content

Commit

Permalink
fix type instabilities, more use of @inbounds
Browse files Browse the repository at this point in the history
  • Loading branch information
PGS62 committed Jan 23, 2021
1 parent 26c1aa4 commit 0cc0b18
Showing 1 changed file with 17 additions and 16 deletions.
33 changes: 17 additions & 16 deletions src/rankcorr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,22 @@ function corkendall!(x::RealVector, y::RealVector, permx=sortperm(x))
npairs = div(n * (n - 1), 2)
ntiesx, ntiesy, ndoubleties, k, nswaps = 0, 0, 0, 0, 0

for i 2:n
@inbounds for i 2:n
if x[i - 1] == x[i]
k += 1
elseif k > 0
# Sort the corresponding chunk of y, so the rows of hcat(x,y) are
# sorted first on x, then (where x values are tied) on y. Hence
# double ties can be counted by calling countties.
mergesort!(y, i - k - 1, i - 1)
ntiesx += k * (k + 1) / 2
ntiesx += div(k * (k + 1) , 2)
ndoubleties += countties(y, i - k - 1, i - 1)
k = 0
end
end
if k > 0
mergesort!(y, n - k, n)
ntiesx += k * (k + 1) / 2
ntiesx += div(k * (k + 1) , 2)
ndoubleties += countties(y, n - k, n)
end

Expand All @@ -55,17 +55,18 @@ Assumes `x` is sorted. Returns the number of ties within `x[lo:hi]`.
"""
function countties(x::AbstractVector, lo::Integer, hi::Integer)
thistiecount, result = 0, 0
for i (lo + 1):hi
(lo < 1 || hi > length(x)) && error("Bounds error")
@inbounds for i (lo + 1):hi
if x[i] == x[i - 1]
thistiecount += 1
elseif thistiecount > 0
result += (thistiecount * (thistiecount + 1)) / 2
result += div(thistiecount * (thistiecount + 1) , 2)
thistiecount = 0
end
end

if thistiecount > 0
result += (thistiecount * (thistiecount + 1)) / 2
result += div(thistiecount * (thistiecount + 1), 2)
end
result
end
Expand All @@ -85,7 +86,7 @@ corkendall(x::RealVector, Y::RealMatrix) = (n = size(Y, 2); permx = sortperm(x);
function corkendall(X::RealMatrix)
n = size(X, 2)
C = ones(float(eltype(X)), n, n)# avoids dependency on LinearAlgebra
for j 2:n
@inbounds for j 2:n
permx = sortperm(X[:,j])
for i 1:j - 1
C[j,i] = corkendall!(X[:,j], X[:,i], permx)
Expand All @@ -99,7 +100,7 @@ function corkendall(X::RealMatrix, Y::RealMatrix)
nr = size(X, 2)
nc = size(Y, 2)
C = zeros(float(eltype(X)), nr, nc)
for j 1:nr
@inbounds for j 1:nr
permx = sortperm(X[:,j])
for i 1:nc
C[j,i] = corkendall!(X[:,j], Y[:,i], permx)
Expand All @@ -110,10 +111,11 @@ end

# Auxilliary functions for Kendall's rank correlation

# Same value for this constant as in base/sort.jl. Method speedtestmergesort seems
# to show that a value of 64 is fractionaly (2% ?) faster but safer to follow base/sort.jl, which has it at 20.
# Method speedtest_mergesort appears to to show that a value of 64 is optimal,
# but note that the equivalent constant in base/sort.jl is 20.
const SMALL_THRESHOLD = 64

#Copy was from https://github.com/JuliaLang/julia/commit/28330a2fef4d9d149ba0fd3ffa06347b50067647 dated 20 Sep 2020
"""
mergesort!(v::AbstractVector, lo::Integer, hi::Integer, t=similar(v, 0))
Expand Down Expand Up @@ -160,19 +162,18 @@ function mergesort!(v::AbstractVector, lo::Integer, hi::Integer, t=similar(v, 0)
return nswaps
end

# This implementation of `midpoint` is performance-optimized but safe
# only if `lo <= hi`.
# This function is copied from base/sort.jl
# This implementation of `midpoint` is performance-optimized but safe only if `lo <= hi`.
# This function is also copied from base/sort.jl
midpoint(lo::T, hi::T) where T <: Integer = lo + ((hi - lo) >>> 0x01)
midpoint(lo::Integer, hi::Integer) = midpoint(promote(lo, hi)...)


#Copy was from https://github.com/JuliaLang/julia/commit/28330a2fef4d9d149ba0fd3ffa06347b50067647 dated 20 Sep 2020
"""
insertionsort!(v::AbstractVector, lo::Integer, hi::Integer)
Mutates `v` by sorting elements `x[lo:hi]` using the insertionsort algorithm.
This method is a copy-paste-edit of sort! in base/sort.jl (the method specialised on InsertionSortAlg),
but amended to return the bubblesort distance.
amended to return the bubblesort distance.
"""
function insertionsort!(v::AbstractVector, lo::Integer, hi::Integer)
if lo == hi return 0 end
Expand All @@ -192,4 +193,4 @@ function insertionsort!(v::AbstractVector, lo::Integer, hi::Integer)
v[j] = x
end
return nswaps
end
end

0 comments on commit 0cc0b18

Please sign in to comment.