Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Making set_intersection work with zip-like iterators #6351

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,24 @@
namespace hpx::parallel::detail {
/// \cond NOINTERNAL

template <typename T>
struct decay_tuple
{
using type = T;
};

template <typename T>
struct decay_tuple<std::tuple<T>>
{
using type = std::tuple<std::remove_const_t<T>>;
};

template <typename T>
struct decay_tuple<std::tuple<T&>>
{
using type = std::tuple<std::remove_const_t<T>>;
};

///////////////////////////////////////////////////////////////////////////
template <typename FwdIter>
struct set_operations_buffer
Expand All @@ -43,17 +61,24 @@ namespace hpx::parallel::detail {
public:
rewritable_ref() = default;

explicit constexpr rewritable_ref(T const& item) noexcept
explicit constexpr rewritable_ref(T& item) noexcept
: item_(&item)
{
}

rewritable_ref& operator=(T const& item)
rewritable_ref& operator=(T& item)
{
item_ = &item;
return *this;
}

template <typename U>
rewritable_ref& operator=(U const& item)
{
*item_ = item;
return *this;
}

// different versions of clang-format produce different results
// clang-format off
operator T const&() const
Expand All @@ -64,7 +89,7 @@ namespace hpx::parallel::detail {
// clang-format on

private:
T const* item_ = nullptr;
T* item_ = nullptr;
};

using value_type = typename std::iterator_traits<FwdIter>::value_type;
Expand Down Expand Up @@ -149,8 +174,13 @@ namespace hpx::parallel::detail {
bool const first_partition = start1 == 0;
bool const last_partition = end1 == static_cast<std::size_t>(len1);

auto start_value = HPX_INVOKE(proj1, first1[start1]);
auto end_value = HPX_INVOKE(proj1, first1[end1]);
using result_type =
std::invoke_result_t<Proj1, hpx::traits::iter_value_t<Iter1>>;
using element_type =
typename decay_tuple<std::decay_t<result_type>>::type;

element_type start_value = HPX_INVOKE(proj1, first1[start1]);
element_type end_value = HPX_INVOKE(proj1, first1[end1]);

// all but the last chunk require special handling
if (!last_partition)
Expand All @@ -166,7 +196,8 @@ namespace hpx::parallel::detail {
// last element of the current chunk
if (end1 != 0)
{
auto end_value1 = HPX_INVOKE(proj1, first1[end1 - 1]);
element_type end_value1 =
HPX_INVOKE(proj1, first1[end1 - 1]);

while (!HPX_INVOKE(f, end_value1, end_value) && --end1 != 0)
{
Expand All @@ -180,7 +211,8 @@ namespace hpx::parallel::detail {
// first element of the current chunk
if (start1 != 0)
{
auto start_value1 = HPX_INVOKE(proj1, first1[start1 - 1]);
element_type start_value1 =
HPX_INVOKE(proj1, first1[start1 - 1]);

while (
!HPX_INVOKE(f, start_value1, start_value) && --start1 != 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,6 @@ namespace hpx::parallel {
HPX_MOVE(first1), HPX_MOVE(first2), HPX_MOVE(dest)});
}

using buffer_type = typename set_operations_buffer<Iter3>::type;
using func_type = std::decay_t<F>;

// calculate approximate destination index
Expand All @@ -286,8 +285,8 @@ namespace hpx::parallel {

// perform required set operation for one chunk
auto f2 = [proj1, proj2](Iter1 part_first1, Sent1 part_last1,
Iter2 part_first2, Sent2 part_last2,
buffer_type* d, func_type const& f) {
Iter2 part_first2, Sent2 part_last2, auto* d,
func_type const& f) {
return sequential_set_intersection(part_first1, part_last1,
part_first2, part_last2, d, f, proj1, proj2);
};
Expand Down