Skip to content

Commit 14368e7

Browse files
committed
test: more polars enabled tests
Signed-off-by: Ion Koutsouris <[email protected]>
1 parent 4ef9fb3 commit 14368e7

File tree

1 file changed

+193
-0
lines changed

1 file changed

+193
-0
lines changed

python/tests/test_table_read.py

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,47 @@ def test_read_table_with_edge_timestamps():
5757
assert len(list(dataset.get_fragments(predicate))) == 1
5858

5959

60+
@pytest.mark.polars
61+
def test_read_table_with_edge_timestamps_polars():
62+
os.environ["POLARS_NEW_MULTIFILE"] = "1"
63+
import polars as pl
64+
65+
table_path = "../crates/test/tests/data/table_with_edge_timestamps"
66+
dt = DeltaTable(table_path)
67+
dataset = pl.scan_delta(dt).collect().to_arrow()
68+
assert dataset.to_pydict() == {
69+
"BIG_DATE": [
70+
datetime(9999, 12, 31, 0, 0, 0, tzinfo=timezone.utc),
71+
datetime(9999, 12, 30, 0, 0, 0, tzinfo=timezone.utc),
72+
],
73+
"NORMAL_DATE": [
74+
datetime(2022, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
75+
datetime(2022, 2, 1, 0, 0, 0, tzinfo=timezone.utc),
76+
],
77+
"SOME_VALUE": [1, 2],
78+
}
79+
# Can push down filters to these timestamps.
80+
predicate = ds.field("BIG_DATE") == datetime(
81+
9999, 12, 31, 0, 0, 0, tzinfo=timezone.utc
82+
)
83+
assert len(list(dataset.get_fragments(predicate))) == 1
84+
85+
6086
def test_read_simple_table_to_dict():
6187
table_path = "../crates/test/tests/data/simple_table"
6288
dt = DeltaTable(table_path)
6389
assert dt.to_pyarrow_dataset().to_table().to_pydict() == {"id": [5, 7, 9]}
6490

6591

92+
@pytest.mark.polars
93+
def test_read_simple_table_to_dict_polars():
94+
import polars as pl
95+
96+
table_path = "../crates/test/tests/data/simple_table"
97+
dt = DeltaTable(table_path)
98+
assert pl.scan_delta(dt).collect().to_arrow().to_pydict() == {"id": [5, 7, 9]}
99+
100+
66101
class _SerializableException(BaseException):
67102
pass
68103

@@ -85,6 +120,24 @@ def _recursively_read_simple_table(executor_class: Type[Executor], depth):
85120
future.result()
86121

87122

123+
def _recursively_read_simple_table_polars(executor_class: Type[Executor], depth):
124+
try:
125+
test_read_simple_table_to_dict_polars()
126+
except BaseException as e: # Ideally this would catch `pyo3_runtime.PanicException` but its seems that is not possible.
127+
# Re-raise as something that can be serialized and therefore sent back to parent processes.
128+
raise _SerializableException(f"Seraializatble exception: {e}") from e
129+
130+
if depth == 0:
131+
return
132+
# We use concurrent.futures.Executors instead of `threading.Thread` or `multiprocessing.Process` to that errors
133+
# are re-rasied in the parent process/thread when we call `future.result()`.
134+
with executor_class(max_workers=1) as executor:
135+
future = executor.submit(
136+
_recursively_read_simple_table_polars, executor_class, depth - 1
137+
)
138+
future.result()
139+
140+
88141
@pytest.mark.parametrize(
89142
"executor_class,multiprocessing_start_method,expect_panic",
90143
[
@@ -109,6 +162,42 @@ def test_read_simple_in_threads_and_processes(
109162
_recursively_read_simple_table(executor_class=executor_class, depth=5)
110163

111164

165+
@pytest.mark.polars
166+
@pytest.mark.parametrize(
167+
"executor_class,multiprocessing_start_method,expect_panic",
168+
[
169+
(ThreadPoolExecutor, None, False),
170+
(ProcessPoolExecutor, "forkserver", False),
171+
(ProcessPoolExecutor, "spawn", False),
172+
(ProcessPoolExecutor, "fork", True),
173+
],
174+
)
175+
def test_read_simple_in_threads_and_processes_polars(
176+
executor_class, multiprocessing_start_method, expect_panic
177+
):
178+
if multiprocessing_start_method is not None:
179+
multiprocessing.set_start_method(multiprocessing_start_method, force=True)
180+
if expect_panic:
181+
with pytest.raises(
182+
_SerializableException,
183+
match="The tokio runtime does not support forked processes",
184+
):
185+
_recursively_read_simple_table_polars(
186+
executor_class=executor_class, depth=5
187+
)
188+
else:
189+
_recursively_read_simple_table_polars(executor_class=executor_class, depth=5)
190+
191+
192+
@pytest.mark.polars
193+
def test_read_simple_table_by_version_to_dict_polars():
194+
import polars as pl
195+
196+
table_path = "../crates/test/tests/data/delta-0.2.0"
197+
dt = DeltaTable(table_path, version=2)
198+
assert pl.scan_delta(dt).collect().to_arrow().to_pydict() == {"value": [1, 2, 3]}
199+
200+
112201
def test_read_simple_table_by_version_to_dict():
113202
table_path = "../crates/test/tests/data/delta-0.2.0"
114203
dt = DeltaTable(table_path, version=2)
@@ -218,6 +307,19 @@ def test_read_simple_table_update_incremental():
218307
assert dt.to_pyarrow_dataset().to_table().to_pydict() == {"id": [5, 7, 9]}
219308

220309

310+
@pytest.mark.polars
311+
def test_read_simple_table_update_incremental_polars():
312+
import polars as pl
313+
314+
table_path = "../crates/test/tests/data/simple_table"
315+
dt = DeltaTable(table_path, version=0)
316+
data = pl.scan_delta(dt).collect().to_arrow()
317+
assert data.to_pydict() == {"id": [0, 1, 2, 3, 4]}
318+
dt.update_incremental()
319+
data = pl.scan_delta(dt).collect().to_arrow()
320+
assert data.to_pydict() == {"id": [5, 7, 9]}
321+
322+
221323
def test_read_simple_table_file_sizes_failure(mocker):
222324
table_path = "../crates/test/tests/data/simple_table"
223325
dt = DeltaTable(table_path)
@@ -235,6 +337,22 @@ def test_read_simple_table_file_sizes_failure(mocker):
235337
dt.to_pyarrow_dataset().to_table().to_pydict()
236338

237339

340+
@pytest.mark.polars
341+
def test_read_partitioned_table_to_dict_polars():
342+
os.environ["POLARS_NEW_MULTIFILE"] = "1"
343+
import polars as pl
344+
345+
table_path = "../crates/test/tests/data/delta-0.8.0-partitioned"
346+
dt = DeltaTable(table_path)
347+
expected = {
348+
"value": ["1", "2", "3", "6", "7", "5", "4"],
349+
"year": ["2020", "2020", "2020", "2021", "2021", "2021", "2021"],
350+
"month": ["1", "2", "2", "12", "12", "12", "4"],
351+
"day": ["1", "3", "5", "20", "20", "4", "5"],
352+
}
353+
assert pl.scan_delta(dt).collect().to_arrow().to_pydict() == expected
354+
355+
238356
def test_read_partitioned_table_to_dict():
239357
table_path = "../crates/test/tests/data/delta-0.8.0-partitioned"
240358
dt = DeltaTable(table_path)
@@ -261,6 +379,27 @@ def test_read_partitioned_table_with_partitions_filters_to_dict():
261379
assert dt.to_pyarrow_dataset(partitions).to_table().to_pydict() == expected
262380

263381

382+
@pytest.mark.polars
383+
def test_read_partitioned_table_with_filters_to_dict_polars():
384+
os.environ["POLARS_NEW_MULTIFILE"] = "1"
385+
import polars as pl
386+
387+
table_path = "../crates/test/tests/data/delta-0.8.0-partitioned"
388+
dt = DeltaTable(table_path)
389+
partitions = pl.col("year") == "2021"
390+
expected = {
391+
"value": ["6", "7", "5", "4"],
392+
"year": ["2021", "2021", "2021", "2021"],
393+
"month": ["12", "12", "12", "4"],
394+
"day": ["20", "20", "4", "5"],
395+
}
396+
397+
assert (
398+
pl.scan_delta(dt).filter(partitions).collect().to_arrow().to_pydict()
399+
== expected
400+
)
401+
402+
264403
def test_read_empty_delta_table_after_delete():
265404
table_path = "../crates/test/tests/data/delta-0.8-empty"
266405
dt = DeltaTable(table_path)
@@ -269,6 +408,17 @@ def test_read_empty_delta_table_after_delete():
269408
assert dt.to_pyarrow_dataset().to_table().to_pydict() == expected
270409

271410

411+
@pytest.mark.polars
412+
def test_read_empty_delta_table_after_delete_polars():
413+
import polars as pl
414+
415+
table_path = "../crates/test/tests/data/delta-0.8-empty"
416+
dt = DeltaTable(table_path)
417+
expected = {"column": []}
418+
419+
assert pl.scan_delta(dt).collect().to_arrow().to_pydict() == expected
420+
421+
272422
def test_read_table_with_column_subset():
273423
table_path = "../crates/test/tests/data/delta-0.8.0-partitioned"
274424
dt = DeltaTable(table_path)
@@ -282,6 +432,22 @@ def test_read_table_with_column_subset():
282432
)
283433

284434

435+
@pytest.mark.polars
436+
def test_read_table_with_column_subset_polars():
437+
import polars as pl
438+
439+
table_path = "../crates/test/tests/data/delta-0.8.0-partitioned"
440+
dt = DeltaTable(table_path)
441+
expected = {
442+
"value": ["1", "2", "3", "6", "7", "5", "4"],
443+
"day": ["1", "3", "5", "20", "20", "4", "5"],
444+
}
445+
assert (
446+
pl.scan_delta(dt).select(["value", "day"]).collect().to_arrow().to_pydict()
447+
== expected
448+
)
449+
450+
285451
def test_read_table_as_category():
286452
table_path = "../crates/test/tests/data/delta-0.8.0-partitioned"
287453
dt = DeltaTable(table_path)
@@ -359,6 +525,33 @@ def test_read_special_partition():
359525
assert set(table["x"].to_pylist()) == {"A/A", "B B"}
360526

361527

528+
@pytest.mark.polars
529+
def test_read_special_partition_polars():
530+
os.environ["POLARS_NEW_MULTIFILE"] = "1"
531+
import polars as pl
532+
533+
table_path = "../crates/test/tests/data/delta-0.8.0-special-partition"
534+
535+
dt = DeltaTable(table_path)
536+
537+
file1 = (
538+
r"x=A%2FA/part-00007-b350e235-2832-45df-9918-6cab4f7578f7.c000.snappy.parquet"
539+
)
540+
file2 = (
541+
r"x=B%20B/part-00015-e9abbc6f-85e9-457b-be8e-e9f5b8a22890.c000.snappy.parquet"
542+
)
543+
544+
assert set(dt.files()) == {file1, file2}
545+
546+
assert dt.files([("x", "=", "A/A")]) == [file1]
547+
assert dt.files([("x", "=", "B B")]) == [file2]
548+
assert dt.files([("x", "=", "c")]) == []
549+
550+
table = pl.scan_delta(dt).collect().to_arrow()
551+
552+
assert set(table["x"].to_pylist()) == {"A/A", "B B"}
553+
554+
362555
def test_read_partitioned_table_metadata():
363556
table_path = "../crates/test/tests/data/delta-0.8.0-partitioned"
364557
dt = DeltaTable(table_path)

0 commit comments

Comments
 (0)