From 0cc0b18fca95afa7df954783170c6a23bf63c4a4 Mon Sep 17 00:00:00 2001 From: Philip Swannell Date: Sat, 23 Jan 2021 14:54:11 +0000 Subject: [PATCH] fix type instabilities, more use of @inbounds --- src/rankcorr.jl | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/src/rankcorr.jl b/src/rankcorr.jl index e87d701..ea094c1 100644 --- a/src/rankcorr.jl +++ b/src/rankcorr.jl @@ -22,7 +22,7 @@ function corkendall!(x::RealVector, y::RealVector, permx=sortperm(x)) npairs = div(n * (n - 1), 2) ntiesx, ntiesy, ndoubleties, k, nswaps = 0, 0, 0, 0, 0 - for i ∈ 2:n + @inbounds for i ∈ 2:n if x[i - 1] == x[i] k += 1 elseif k > 0 @@ -30,14 +30,14 @@ function corkendall!(x::RealVector, y::RealVector, permx=sortperm(x)) # sorted first on x, then (where x values are tied) on y. Hence # double ties can be counted by calling countties. mergesort!(y, i - k - 1, i - 1) - ntiesx += k * (k + 1) / 2 + ntiesx += div(k * (k + 1) , 2) ndoubleties += countties(y, i - k - 1, i - 1) k = 0 end end if k > 0 mergesort!(y, n - k, n) - ntiesx += k * (k + 1) / 2 + ntiesx += div(k * (k + 1) , 2) ndoubleties += countties(y, n - k, n) end @@ -55,17 +55,18 @@ Assumes `x` is sorted. Returns the number of ties within `x[lo:hi]`. """ function countties(x::AbstractVector, lo::Integer, hi::Integer) thistiecount, result = 0, 0 - for i ∈ (lo + 1):hi + (lo < 1 || hi > length(x)) && error("Bounds error") + @inbounds for i ∈ (lo + 1):hi if x[i] == x[i - 1] thistiecount += 1 elseif thistiecount > 0 - result += (thistiecount * (thistiecount + 1)) / 2 + result += div(thistiecount * (thistiecount + 1) , 2) thistiecount = 0 end end if thistiecount > 0 - result += (thistiecount * (thistiecount + 1)) / 2 + result += div(thistiecount * (thistiecount + 1), 2) end result end @@ -85,7 +86,7 @@ corkendall(x::RealVector, Y::RealMatrix) = (n = size(Y, 2); permx = sortperm(x); function corkendall(X::RealMatrix) n = size(X, 2) C = ones(float(eltype(X)), n, n)# avoids dependency on LinearAlgebra - for j ∈ 2:n + @inbounds for j ∈ 2:n permx = sortperm(X[:,j]) for i ∈ 1:j - 1 C[j,i] = corkendall!(X[:,j], X[:,i], permx) @@ -99,7 +100,7 @@ function corkendall(X::RealMatrix, Y::RealMatrix) nr = size(X, 2) nc = size(Y, 2) C = zeros(float(eltype(X)), nr, nc) - for j ∈ 1:nr + @inbounds for j ∈ 1:nr permx = sortperm(X[:,j]) for i ∈ 1:nc C[j,i] = corkendall!(X[:,j], Y[:,i], permx) @@ -110,10 +111,11 @@ end # Auxilliary functions for Kendall's rank correlation -# Same value for this constant as in base/sort.jl. Method speedtestmergesort seems -# to show that a value of 64 is fractionaly (2% ?) faster but safer to follow base/sort.jl, which has it at 20. +# Method speedtest_mergesort appears to to show that a value of 64 is optimal, +# but note that the equivalent constant in base/sort.jl is 20. const SMALL_THRESHOLD = 64 +#Copy was from https://github.com/JuliaLang/julia/commit/28330a2fef4d9d149ba0fd3ffa06347b50067647 dated 20 Sep 2020 """ mergesort!(v::AbstractVector, lo::Integer, hi::Integer, t=similar(v, 0)) @@ -160,19 +162,18 @@ function mergesort!(v::AbstractVector, lo::Integer, hi::Integer, t=similar(v, 0) return nswaps end -# This implementation of `midpoint` is performance-optimized but safe -# only if `lo <= hi`. -# This function is copied from base/sort.jl +# This implementation of `midpoint` is performance-optimized but safe only if `lo <= hi`. +# This function is also copied from base/sort.jl midpoint(lo::T, hi::T) where T <: Integer = lo + ((hi - lo) >>> 0x01) midpoint(lo::Integer, hi::Integer) = midpoint(promote(lo, hi)...) - +#Copy was from https://github.com/JuliaLang/julia/commit/28330a2fef4d9d149ba0fd3ffa06347b50067647 dated 20 Sep 2020 """ insertionsort!(v::AbstractVector, lo::Integer, hi::Integer) Mutates `v` by sorting elements `x[lo:hi]` using the insertionsort algorithm. This method is a copy-paste-edit of sort! in base/sort.jl (the method specialised on InsertionSortAlg), -but amended to return the bubblesort distance. +amended to return the bubblesort distance. """ function insertionsort!(v::AbstractVector, lo::Integer, hi::Integer) if lo == hi return 0 end @@ -192,4 +193,4 @@ function insertionsort!(v::AbstractVector, lo::Integer, hi::Integer) v[j] = x end return nswaps -end +end \ No newline at end of file