|
18 | 18 |
|
19 | 19 | #include "detail/graph_partition_utils.cuh"
|
20 | 20 | #include "prims/edge_bucket.cuh"
|
| 21 | +#include "prims/fill_edge_property.cuh" |
21 | 22 | #include "prims/per_v_pair_dst_nbr_intersection.cuh"
|
22 | 23 | #include "prims/transform_e.cuh"
|
23 | 24 |
|
@@ -130,11 +131,37 @@ edge_property_t<edge_t, edge_t> edge_triangle_count_impl(
|
130 | 131 | bool do_expensive_check)
|
131 | 132 | {
|
132 | 133 | using weight_t = float;
|
| 134 | + |
| 135 | + CUGRAPH_EXPECTS( |
| 136 | + !graph_view.is_multigraph(), |
| 137 | + "Invalid input argument: edge triangle count currently does not support multi-graphs."); |
| 138 | + |
| 139 | + // Exclude self-loops |
| 140 | + |
| 141 | + std::optional<cugraph::edge_property_t<edge_t, bool>> self_loop_edge_mask{std::nullopt}; |
| 142 | + auto cur_graph_view = graph_view; |
| 143 | + if (cur_graph_view.count_self_loops(handle) > edge_t{0}) { |
| 144 | + self_loop_edge_mask = cugraph::edge_property_t<edge_t, bool>(handle, cur_graph_view); |
| 145 | + if (cur_graph_view.has_edge_mask()) { cur_graph_view.clear_edge_mask(); } |
| 146 | + cugraph::fill_edge_property(handle, cur_graph_view, self_loop_edge_mask->mutable_view(), false); |
| 147 | + |
| 148 | + transform_e(handle, |
| 149 | + graph_view, |
| 150 | + edge_src_dummy_property_t{}.view(), |
| 151 | + edge_dst_dummy_property_t{}.view(), |
| 152 | + edge_dummy_property_t{}.view(), |
| 153 | + cuda::proclaim_return_type<bool>( |
| 154 | + [] __device__(auto src, auto dst, auto, auto, auto) { return src != dst; }), |
| 155 | + self_loop_edge_mask->mutable_view()); |
| 156 | + |
| 157 | + cur_graph_view.attach_edge_mask(self_loop_edge_mask->view()); |
| 158 | + } |
| 159 | + |
133 | 160 | rmm::device_uvector<vertex_t> edgelist_srcs(0, handle.get_stream());
|
134 | 161 | rmm::device_uvector<vertex_t> edgelist_dsts(0, handle.get_stream());
|
135 | 162 | std::tie(edgelist_srcs, edgelist_dsts, std::ignore, std::ignore, std::ignore) =
|
136 | 163 | decompress_to_edgelist<vertex_t, edge_t, weight_t, int32_t>(
|
137 |
| - handle, graph_view, std::nullopt, std::nullopt, std::nullopt, std::nullopt); |
| 164 | + handle, cur_graph_view, std::nullopt, std::nullopt, std::nullopt, std::nullopt); |
138 | 165 |
|
139 | 166 | auto edge_first = thrust::make_zip_iterator(edgelist_srcs.begin(), edgelist_dsts.begin());
|
140 | 167 |
|
@@ -162,7 +189,7 @@ edge_property_t<edge_t, edge_t> edge_triangle_count_impl(
|
162 | 189 | // Perform 'nbr_intersection' in chunks to reduce peak memory.
|
163 | 190 | auto [intersection_offsets, intersection_indices] =
|
164 | 191 | per_v_pair_dst_nbr_intersection(handle,
|
165 |
| - graph_view, |
| 192 | + cur_graph_view, |
166 | 193 | edge_first + prev_chunk_size,
|
167 | 194 | edge_first + prev_chunk_size + chunk_size,
|
168 | 195 | do_expensive_check);
|
@@ -272,7 +299,7 @@ edge_property_t<edge_t, edge_t> edge_triangle_count_impl(
|
272 | 299 | std::nullopt,
|
273 | 300 | std::nullopt,
|
274 | 301 | std::nullopt,
|
275 |
| - graph_view.vertex_partition_range_lasts()); |
| 302 | + cur_graph_view.vertex_partition_range_lasts()); |
276 | 303 |
|
277 | 304 | thrust::for_each(
|
278 | 305 | handle.get_thrust_policy(),
|
@@ -335,16 +362,19 @@ edge_property_t<edge_t, edge_t> edge_triangle_count_impl(
|
335 | 362 | prev_chunk_size += chunk_size;
|
336 | 363 | }
|
337 | 364 |
|
338 |
| - cugraph::edge_property_t<edge_t, edge_t> counts(handle, graph_view); |
| 365 | + cugraph::edge_property_t<edge_t, edge_t> counts(handle, cur_graph_view); |
| 366 | + { |
| 367 | + auto unmasked_graph_view = cur_graph_view; |
| 368 | + if (unmasked_graph_view.has_edge_mask()) { unmasked_graph_view.clear_edge_mask(); } |
| 369 | + cugraph::fill_edge_property(handle, unmasked_graph_view, counts.mutable_view(), edge_t{0}); |
| 370 | + } |
339 | 371 |
|
340 | 372 | cugraph::edge_bucket_t<vertex_t, void, true, multi_gpu, true> valid_edges(handle);
|
341 | 373 | valid_edges.insert(edgelist_srcs.begin(), edgelist_srcs.end(), edgelist_dsts.begin());
|
342 | 374 |
|
343 |
| - auto cur_graph_view = graph_view; |
344 |
| - |
345 | 375 | cugraph::transform_e(
|
346 | 376 | handle,
|
347 |
| - graph_view, |
| 377 | + cur_graph_view, |
348 | 378 | valid_edges,
|
349 | 379 | cugraph::edge_src_dummy_property_t{}.view(),
|
350 | 380 | cugraph::edge_dst_dummy_property_t{}.view(),
|
|
0 commit comments