
Commit 2115023

gaogaotiantian authored and HyukjinKwon committed
[SPARK-54384][PYTHON] Modernize the _batched method for BatchedSerializer
### What changes were proposed in this pull request?

Use modern `itertools` to implement the `_batched` method of `BatchedSerializer`, making the code cleaner and faster. The new code is about 170% faster than the original implementation.

<details>
<summary>Result with the following code</summary>

```
Batching batch_original took 0.3086 seconds
Batching batch_after took 0.1159 seconds
```

```python
import itertools
import time


def batch_original(iterator, batch_size):
    items = []
    count = 0
    for item in iterator:
        items.append(item)
        count += 1
        if count == batch_size:
            yield items
            items = []
            count = 0
    if items:
        yield items


def batch_list(iterator, batch_size):
    n = len(iterator)
    for i in range(0, n, batch_size):
        yield iterator[i : i + batch_size]


def batch_after(iterator, batch_size):
    it = iter(iterator)
    while batch := list(itertools.islice(it, batch_size)):
        yield batch


def do_test(iterator, batch):
    result = []
    start = time.perf_counter_ns()
    for b in batch(iterator, 10000):
        result.append(b)
    end = time.perf_counter_ns()
    print(f"Batching {batch.__name__} took {(end - start)/1e9:.4f} seconds")
    return result


if __name__ == "__main__":
    data = range(10000005)
    result_original = do_test(data, batch_original)
    result_after = do_test(data, batch_after)
    assert result_original == result_after

    data = list(range(10000005))
    result_list = do_test(data, batch_list)
    result_after = do_test(data, batch_after)
    assert result_list == result_after
```

</details>

Note that `__getslice__` was **removed** in Python 3.0, so the old optimization for known-size iterators like lists never fires. There is no simple way to know whether an iterator supports slicing now; the most straightforward check is to just try it, e.g. `iterator[:1]`. I don't know how frequently we deal with lists; if the iterator is often a list, we could do that. Raw `[:]` slicing is about 22% faster than this implementation, but I like the simplicity of avoiding a `try ... except ...` block.

### Why are the changes needed?

Most importantly, the code is less verbose. It is also much faster.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

The script above checks that the result is the same as before. We also have CI.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #53086 from gaogaotiantian/modernize-batch.

Authored-by: Tian Gao <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
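For illustration, the slice-probe alternative mentioned above (trying `iterator[:1]` and catching `TypeError`) could look like the hypothetical sketch below. `batch_maybe_sliced` is an invented name, and the PR deliberately does not adopt this extra branch:

```python
import itertools


def batch_maybe_sliced(iterator, batch_size):
    # Hypothetical sketch: probe for slice support at runtime, since
    # __getslice__ is gone in Python 3 and there is no reliable
    # hasattr() check for slicing.
    try:
        iterator[:1]  # lists/tuples succeed; generators raise TypeError
    except TypeError:
        it = iter(iterator)
        while batch := list(itertools.islice(it, batch_size)):
            yield batch
    else:
        for i in range(0, len(iterator), batch_size):
            yield iterator[i : i + batch_size]
```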
1 parent 1db267e · commit 2115023

File tree: 1 file changed, +3 −15 lines


python/pyspark/serializers.py

Lines changed: 3 additions & 15 deletions
```diff
@@ -203,22 +203,10 @@ def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE):
     def _batched(self, iterator):
         if self.batchSize == self.UNLIMITED_BATCH_SIZE:
             yield list(iterator)
-        elif hasattr(iterator, "__len__") and hasattr(iterator, "__getslice__"):
-            n = len(iterator)
-            for i in range(0, n, self.batchSize):
-                yield iterator[i : i + self.batchSize]
         else:
-            items = []
-            count = 0
-            for item in iterator:
-                items.append(item)
-                count += 1
-                if count == self.batchSize:
-                    yield items
-                    items = []
-                    count = 0
-            if items:
-                yield items
+            it = iter(iterator)
+            while batch := list(itertools.islice(it, self.batchSize)):
+                yield batch
 
     def dump_stream(self, iterator, stream):
         self.serializer.dump_stream(self._batched(iterator), stream)
```
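To see the new `_batched` behavior in isolation, here is a minimal standalone sketch of the same loop (not Spark code; the function name, `batch_size`, and the sample data are arbitrary):

```python
import itertools


def batched(iterator, batch_size):
    # islice pulls up to batch_size items per pass; the walrus-assigned
    # list is empty (falsy) once the source is exhausted, ending the loop.
    # A final partial batch is still yielded, and empty input yields nothing.
    it = iter(iterator)
    while batch := list(itertools.islice(it, batch_size)):
        yield batch


print(list(batched(range(7), 3)))  # [[0, 1, 2], [3, 4, 5], [6]]
print(list(batched([], 3)))        # []
```

On Python 3.12+, the standard library's `itertools.batched` provides the same grouping but yields tuples rather than lists, so it would not be a drop-in replacement here.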
