Skip to content

Commit

Permalink
Utility for sharded filtering by key (#531)
Browse files Browse the repository at this point in the history
  • Loading branch information
dvadym authored Jan 14, 2025
1 parent 88a8acb commit b0ef34c
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 0 deletions.
45 changes: 45 additions & 0 deletions pipeline_dp/pipeline_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
their implementation is framework-agnostic because they use only other primitive
operations declared in PipelineBackend interface."""

import random
from typing import Type, Dict, Any, Callable, TypeVar, Iterable, List

from pipeline_dp import pipeline_backend
Expand Down Expand Up @@ -113,3 +114,47 @@ def min_max_per_key(backend: pipeline_backend.PipelineBackend, col,
)
# col: (key, (min, max))
return col


def filter_by_key_with_sharding(backend: pipeline_backend.PipelineBackend, col,
                                keys_to_keep, sharding_factor: int,
                                stage_name: str):
    """Keeps only the elements of `col` whose key is in `keys_to_keep`.

    Works like `backend.filter_by_key`. When `sharding_factor` is greater than
    1, every key is additionally split into `sharding_factor` subkeys before
    filtering (for scalability), and the subkey is stripped again afterwards,
    so the output elements have the same (key, data) form as the input.

    Args:
      backend: backend to use to perform the computation.
      col: collection with elements (key, data).
      keys_to_keep: collection of keys to keep, both local (currently `list`
        and `set`) and distributed collections are supported.
      sharding_factor: number of subkeys to split each key into. Values not
        greater than 1 disable sharding and the call is forwarded directly to
        `backend.filter_by_key`.
      stage_name: name of the stage (passed to `backend.filter_by_key`).

    Returns:
      A filtered collection containing only data belonging to keys_to_keep.
    """
    needs_sharding = sharding_factor > 1
    if not needs_sharding:
        # No subkeys requested: plain filtering is all that is required.
        return backend.filter_by_key(col, keys_to_keep, stage_name)

    def attach_random_shard(key, value):
        # random.randint is inclusive on both ends, so the shard index is
        # drawn from 0 .. sharding_factor - 1.
        shard = random.randint(0, sharding_factor - 1)
        return (key, shard), value

    def expand_to_all_shards(key):
        # A kept key must match whichever shard its elements were assigned
        # to, so emit one (key, shard) pair per possible shard index.
        return tuple((key, shard) for shard in range(sharding_factor))

    def drop_shard(sharded_key, value):
        # sharded_key is (key, shard); restore the original key.
        return sharded_key[0], value

    sharded_col = backend.map_tuple(
        col,
        attach_random_shard,
        f"Sharding each key into {sharding_factor} subkeys",
    )
    sharded_keys = backend.flat_map(
        keys_to_keep,
        expand_to_all_shards,
        f"Shard partitions into {sharding_factor} keys",
    )
    # Per the original author's note: to_multi_transformable_collection is a
    # no-op outside of local mode; in local mode it turns the iterable into a
    # list, which is needed because filter_by_key requires a list there.
    sharded_keys = backend.to_multi_transformable_collection(sharded_keys)

    sharded_result = backend.filter_by_key(sharded_col, sharded_keys,
                                           stage_name)

    return backend.map_tuple(sharded_result, drop_shard,
                             "Remove sharding factor")
9 changes: 9 additions & 0 deletions tests/pipeline_functions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,15 @@ def test_min_max_per_key(self, col, expected_min_max):

self.assertEqual(expected_min_max, list(result))

@parameterized.parameters(1, 5)
def test_local_filter_by_key(self, sharding_factor):
    # Runs once without sharding (factor 1) and once with sharding (factor 5);
    # the expected output is the same in both cases.
    input_col = [(7, 1), (2, 1), (3, 9), (4, 1), (9, 10), (7, 4), (7, 5)]
    kept_keys = [7, 9]
    # Only the elements keyed by 7 or 9 should survive the filtering.
    expected = [(7, 1), (7, 4), (7, 5), (9, 10)]

    filtered = composite_funcs.filter_by_key_with_sharding(
        self.backend, input_col, kept_keys, sharding_factor, "filter_by_key")

    # The output order is not asserted, so compare after sorting.
    self.assertEqual(sorted(filtered), expected)


@unittest.skipIf(sys.platform == 'win32' or sys.platform == 'darwin',
"Problems with serialisation on Windows and macOS")
Expand Down

0 comments on commit b0ef34c

Please sign in to comment.