Skip to content

Commit

Permalink
wrote Equal_sum_subsets iterator
Browse files Browse the repository at this point in the history
  • Loading branch information
PGS62 committed Apr 3, 2024
1 parent 7193fa3 commit a66f137
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 4 deletions.
68 changes: 66 additions & 2 deletions src/pairwise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ function _pairwise!(::Val{:none}, f, dest::AbstractMatrix{V}, x, y,
#cov(x) is faster than cov(x, x)
(f == cov) && (f = ((x, y) -> x === y ? cov(x) : cov(x, y)))

Threads.@threads for subset in equal_sum_subsets(nc, Threads.nthreads())
Threads.@threads for subset in Equal_sum_subsets(nc, Threads.nthreads())
for j in subset
for i = (symmetric ? j : 1):nr
# For performance, diagonal is special-cased
Expand Down Expand Up @@ -96,7 +96,7 @@ function _pairwise!(::Val{:pairwise}, f, dest::AbstractMatrix{V}, x, y, symmetri
nmtx = promoted_nmtype(x)[]
nmty = promoted_nmtype(y)[]

Threads.@threads for subset in equal_sum_subsets(nc, Threads.nthreads())
Threads.@threads for subset in Equal_sum_subsets(nc, Threads.nthreads())
scratch_fx = task_local_vector(:scratch_fx, nmtx, m)
scratch_fy = task_local_vector(:scratch_fy, nmty, m)
for j in subset
Expand Down Expand Up @@ -393,6 +393,8 @@ function handle_pairwise(x::AbstractVector, y::AbstractVector;
return view(scratch_fx, lb:(j-1)), view(scratch_fy, lb:(j-1))
end

#TODO remove tests for equal_sum_subsets, add tests for Equal_sum_subsets, rename to Equal_sum_vectors?

#=Condition a) makes equal_sum_subsets useful for load-balancing the multi-threaded section
of _pairwise! in the non-symmetric case, and condition b) for the symmetric case.=#
"""
Expand Down Expand Up @@ -440,4 +442,66 @@ function task_local_vector(key::Symbol, similarto::AbstractArray{V},
length::Int)::Vector{V} where {V}
haskey(task_local_storage(), key) || task_local_storage(key, similar(similarto, length))
return task_local_storage(key)
end

#Alternative approach - use an iterator.

"""
    Equal_sum_subsets(n, num_subsets)

An iterator enabling the partition of the integers 1 to `n` into `min(n, num_subsets)`
vectors such that a) each subset has (approximately) the same number of elements; and b)
the sum of the elements in each subset is nearly equal. If `n` is a multiple of
`2 * num_subsets` both conditions hold exactly.

The integers are dealt out in descending order, left-to-right then right-to-left
alternately, so the `i`th subset starts at `n - i + 1` and each subset is returned in
descending order.

The type implements `length`, `eltype`, `firstindex`, `getindex` and `iterate`.

## Example
```julia-repl
julia> for s in KendallTau.Equal_sum_subsets(30,5)
           println((s, sum(s)))
       end
([30, 21, 20, 11, 10, 1], 93)
([29, 22, 19, 12, 9, 2], 93)
([28, 23, 18, 13, 8, 3], 93)
([27, 24, 17, 14, 7, 4], 93)
([26, 25, 16, 15, 6, 5], 93)
```
"""
struct Equal_sum_subsets
    n::Int64
    num_subsets::Int64
end

# Number of subsets produced: never more subsets than there are integers to distribute.
Base.length(x::Equal_sum_subsets) = min(x.n, x.num_subsets)

# Every element yielded by iteration is a vector of integers.
Base.eltype(::Type{Equal_sum_subsets}) = Vector{Int64}

Base.firstindex(x::Equal_sum_subsets) = 1

# Iteration protocol: `state` is the index of the next subset to yield.
function Base.iterate(ess::Equal_sum_subsets, state::Int64=1)
    state > length(ess) && return nothing
    return (getindex(ess, state), state + 1)
end

# Print each subset of `Equal_sum_subsets(n, num_subsets)` immediately followed by its sum.
function demo(n, num_subsets)
    for s in Equal_sum_subsets(n, num_subsets)
        println(s, sum(s))
    end
end

"""
    getindex(ess::Equal_sum_subsets, i)

Return the `i`th subset of `ess` as a `Vector{Int64}` in descending order, or `nothing` if
`i` is outside `1:length(ess)`.
"""
function Base.getindex(ess::Equal_sum_subsets, i::Int64=1)
    (1 <= i <= length(ess)) || return nothing
    n, nss = ess.n, ess.num_subsets
    # Successive elements of subset i are separated alternately by gaps b and s, where
    # s + b == 2nss. Both are >= 1 because 1 <= i <= nss, so the loop below terminates.
    s = 2i - 1
    b = 2nss - s
    # Grow the result rather than pre-computing its length: which subsets receive the
    # "left over" elements when nss does not divide n depends on the parity of
    # div(n, nss), and getting that wrong yields trailing zeros or a BoundsError.
    result = Int64[]
    sizehint!(result, cld(n, nss))
    x = n - i + 1
    k = 1
    while x > 0
        push!(result, x)
        x -= iseven(k) ? s : b
        k += 1
    end
    return result
end
4 changes: 2 additions & 2 deletions src/rankcorr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ function _pairwise!(::Val{:pairwise}, f::typeof(corspearman),
fl64 = Float64[]
nmtx = promoted_nmtype(x)[]
nmty = promoted_nmtype(y)[]
Threads.@threads for subset in equal_sum_subsets(nr, Threads.nthreads())
Threads.@threads for subset in Equal_sum_subsets(nr, Threads.nthreads())

for i in subset

Expand Down Expand Up @@ -482,7 +482,7 @@ function corkendall_loop!(skipmissing::Symbol, f::typeof(corkendall), dest::Abst

symmetric = x === y

Threads.@threads for subset in equal_sum_subsets(nr, Threads.nthreads())
Threads.@threads for subset in Equal_sum_subsets(nr, Threads.nthreads())

for i in subset

Expand Down

0 comments on commit a66f137

Please sign in to comment.