From a66f137c05be8b747f70d53e6fd6a942815640c5 Mon Sep 17 00:00:00 2001 From: Philip Swannell <18028484+PGS62@users.noreply.github.com> Date: Wed, 3 Apr 2024 19:13:09 +0100 Subject: [PATCH] wrote Equal_sum_subsets iterator --- src/pairwise.jl | 68 +++++++++++++++++++++++++++++++++++++++++++++++-- src/rankcorr.jl | 4 +-- 2 files changed, 68 insertions(+), 4 deletions(-) diff --git a/src/pairwise.jl b/src/pairwise.jl index 11dbbcd..0013058 100644 --- a/src/pairwise.jl +++ b/src/pairwise.jl @@ -17,7 +17,7 @@ function _pairwise!(::Val{:none}, f, dest::AbstractMatrix{V}, x, y, #cov(x) is faster than cov(x, x) (f == cov) && (f = ((x, y) -> x === y ? cov(x) : cov(x, y))) - Threads.@threads for subset in equal_sum_subsets(nc, Threads.nthreads()) + Threads.@threads for subset in Equal_sum_subsets(nc, Threads.nthreads()) for j in subset for i = (symmetric ? j : 1):nr # For performance, diagonal is special-cased @@ -96,7 +96,7 @@ function _pairwise!(::Val{:pairwise}, f, dest::AbstractMatrix{V}, x, y, symmetri nmtx = promoted_nmtype(x)[] nmty = promoted_nmtype(y)[] - Threads.@threads for subset in equal_sum_subsets(nc, Threads.nthreads()) + Threads.@threads for subset in Equal_sum_subsets(nc, Threads.nthreads()) scratch_fx = task_local_vector(:scratch_fx, nmtx, m) scratch_fy = task_local_vector(:scratch_fy, nmty, m) for j in subset @@ -393,6 +393,8 @@ function handle_pairwise(x::AbstractVector, y::AbstractVector; return view(scratch_fx, lb:(j-1)), view(scratch_fy, lb:(j-1)) end +#TODO remove tsts for equal_sum_subsets, add tests for Equal_sum_subsets, rename to Equal_sum_vectors? + #=Condition a) makes equal_sum_subsets useful for load-balancing the multi-threaded section of _pairwise! in the non-symmetric case, and condition b) for the symmetric case.=# """ @@ -440,4 +442,66 @@ function task_local_vector(key::Symbol, similarto::AbstractArray{V}, length::Int)::Vector{V} where {V} haskey(task_local_storage(), key) || task_local_storage(key, similar(similarto, length)) return task_local_storage(key) +end + +#Alternative approach - use an iterator. + +""" + Equal_sum_subsets + +An iterator enabling the partition of the integers 1 to n into `num_subsets` vectors such +that a) each subset has (approximately) the same number of elements; and b) the sum of the +elements in each subset is nearly equal. If `n` is a multiple of `2 * num_subsets` both +conditions hold exactly. + +## Example +```julia-repl +julia> for s in KendallTau.Equal_sum_subsets(30,5) +println((s, sum(s))) +end +([30, 21, 20, 11, 10, 1], 93) +([29, 22, 19, 12, 9, 2], 93) +([28, 23, 18, 13, 8, 3], 93) +([27, 24, 17, 14, 7, 4], 93) +([26, 25, 16, 15, 6, 5], 93) +``` +""" +struct Equal_sum_subsets + n::Int64 + num_subsets::Int64 +end + +Base.length(x::Equal_sum_subsets) = min(x.n, x.num_subsets) + +Base.firstindex(x::Equal_sum_subsets) = 1 + +function Base.iterate(ess::Equal_sum_subsets, state::Int64=1) + state > length(ess) && return (nothing) + return (getindex(ess, state), state + 1) +end + +function demo(n, num_subsets) + for s in Equal_sum_subsets(n, num_subsets) + println(s, sum(s)) + end +end + +function Base.getindex(ess::Equal_sum_subsets, i::Int64=1) + i > length(ess) && return (nothing) + n, nss = ess.n, ess.num_subsets + s = 2i - 1 + b = 2nss - s + result = zeros(Int64, div(n, nss) + ((i <= rem(n, nss)) ? 1 : 0)) + result[1] = n - i + 1 + k = 1 + while true + x = result[k] - (mod(k, 2) == 0 ? s : b) + if x > 0 + k += 1 + result[k] = x + else + break + end + end + return result end \ No newline at end of file diff --git a/src/rankcorr.jl b/src/rankcorr.jl index 2b06d98..42535ad 100644 --- a/src/rankcorr.jl +++ b/src/rankcorr.jl @@ -112,7 +112,7 @@ function _pairwise!(::Val{:pairwise}, f::typeof(corspearman), fl64 = Float64[] nmtx = promoted_nmtype(x)[] nmty = promoted_nmtype(y)[] - Threads.@threads for subset in equal_sum_subsets(nr, Threads.nthreads()) + Threads.@threads for subset in Equal_sum_subsets(nr, Threads.nthreads()) for i in subset @@ -482,7 +482,7 @@ function corkendall_loop!(skipmissing::Symbol, f::typeof(corkendall), dest::Abst symmetric = x === y - Threads.@threads for subset in equal_sum_subsets(nr, Threads.nthreads()) + Threads.@threads for subset in Equal_sum_subsets(nr, Threads.nthreads()) for i in subset