
Commit 50c33aa

feat: Support Bucket and Truncate transforms on write (apache#1345)
* introduce bucket transform
* include pyiceberg-core
* introduce bucket transform
* include pyiceberg-core
* resolve poetry conflict
* support truncate transforms
* Remove stale comment
* fix poetry hash
* avoid codespell error for truncate transform
* adopt nits
1 parent 0a3a886 commit 50c33aa
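In user-facing terms, a PyArrow-backed append can now target a table whose partition spec uses BucketTransform or TruncateTransform; partition values are computed through the Rust-backed pyiceberg_core bindings instead of being rejected for writes. A minimal sketch of the new write path (the catalog name, table identifier, and schema below are illustrative assumptions, not part of this commit):

import pyarrow as pa
from pyiceberg.catalog import load_catalog
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.transforms import BucketTransform
from pyiceberg.types import LongType, NestedField, StringType

catalog = load_catalog("default")  # assumes a configured catalog

schema = Schema(
    NestedField(field_id=1, name="id", field_type=LongType(), required=False),
    NestedField(field_id=2, name="name", field_type=StringType(), required=False),
)
# Hash-bucket the `id` column into two partitions; appends to such specs previously failed.
spec = PartitionSpec(
    PartitionField(source_id=1, field_id=1001, transform=BucketTransform(num_buckets=2), name="id_bucket")
)

tbl = catalog.create_table("default.bucketed_demo", schema=schema, partition_spec=spec)
# Appending requires PyArrow plus the optional pyiceberg-core package added in this commit.
tbl.append(pa.table({"id": [1, 2, 3], "name": ["a", "b", "c"]}))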

File tree (5 files changed: +259 -20 lines)

poetry.lock
pyiceberg/transforms.py
pyproject.toml
tests/integration/test_writes/test_partitioned_writes.py
tests/test_transforms.py

poetry.lock

Lines changed: 17 additions & 1 deletion
Generated file; diff not rendered.

pyiceberg/transforms.py

Lines changed: 37 additions & 2 deletions
@@ -85,6 +85,8 @@
 if TYPE_CHECKING:
     import pyarrow as pa
 
+    ArrayLike = TypeVar("ArrayLike", pa.Array, pa.ChunkedArray)
+
 S = TypeVar("S")
 T = TypeVar("T")

@@ -193,6 +195,27 @@ def supports_pyarrow_transform(self) -> bool:
     @abstractmethod
     def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": ...
 
+    def _pyiceberg_transform_wrapper(
+        self, transform_func: Callable[["ArrayLike", Any], "ArrayLike"], *args: Any
+    ) -> Callable[["ArrayLike"], "ArrayLike"]:
+        try:
+            import pyarrow as pa
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError("For bucket/truncate transforms, PyArrow needs to be installed") from e
+
+        def _transform(array: "ArrayLike") -> "ArrayLike":
+            if isinstance(array, pa.Array):
+                return transform_func(array, *args)
+            elif isinstance(array, pa.ChunkedArray):
+                result_chunks = []
+                for arr in array.iterchunks():
+                    result_chunks.append(transform_func(arr, *args))
+                return pa.chunked_array(result_chunks)
+            else:
+                raise ValueError(f"PyArrow array can only be of type pa.Array or pa.ChunkedArray, but found {type(array)}")
+
+        return _transform
+
 
 class BucketTransform(Transform[S, int]):
     """Base Transform class to transform a value into a bucket partition value.

@@ -309,7 +332,13 @@ def __repr__(self) -> str:
         return f"BucketTransform(num_buckets={self._num_buckets})"
 
     def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
-        raise NotImplementedError()
+        from pyiceberg_core import transform as pyiceberg_core_transform
+
+        return self._pyiceberg_transform_wrapper(pyiceberg_core_transform.bucket, self._num_buckets)
+
+    @property
+    def supports_pyarrow_transform(self) -> bool:
+        return True
 
 
 class TimeResolution(IntEnum):

@@ -827,7 +856,13 @@ def __repr__(self) -> str:
         return f"TruncateTransform(width={self._width})"
 
     def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
-        raise NotImplementedError()
+        from pyiceberg_core import transform as pyiceberg_core_transform
+
+        return self._pyiceberg_transform_wrapper(pyiceberg_core_transform.truncate, self._width)
+
+    @property
+    def supports_pyarrow_transform(self) -> bool:
+        return True
 
 
 @singledispatch
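With the wrapper above, pyarrow_transform now returns a callable that dispatches to pyiceberg_core.transform and accepts either a pa.Array or a pa.ChunkedArray (chunk boundaries are preserved). A rough illustration, with expected values taken from the unit tests added below and pyiceberg_core assumed to be installed:

import pyarrow as pa
from pyiceberg.transforms import BucketTransform, TruncateTransform
from pyiceberg.types import IntegerType, StringType

# Bucket transform over plain and chunked arrays.
bucket = BucketTransform(num_buckets=10).pyarrow_transform(IntegerType())
print(bucket(pa.array([1, 2])))  # [6, 2] as int32
print(bucket(pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])])))  # chunks kept: [6, 2] and [5, 0]

# Truncate transform on strings.
truncate = TruncateTransform(width=3).pyarrow_transform(StringType())
print(truncate(pa.array(["developer", "iceberg"])))  # ["dev", "ice"]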

pyproject.toml

Lines changed: 6 additions & 0 deletions
@@ -79,6 +79,7 @@ psycopg2-binary = { version = ">=2.9.6", optional = true }
 sqlalchemy = { version = "^2.0.18", optional = true }
 getdaft = { version = ">=0.2.12", optional = true }
 cachetools = "^5.5.0"
+pyiceberg-core = { version = "^0.4.0", optional = true }
 
 [tool.poetry.group.dev.dependencies]
 pytest = "7.4.4"

@@ -842,6 +843,10 @@ ignore_missing_imports = true
 module = "daft.*"
 ignore_missing_imports = true
 
+[[tool.mypy.overrides]]
+module = "pyiceberg_core.*"
+ignore_missing_imports = true
+
 [[tool.mypy.overrides]]
 module = "pyparsing.*"
 ignore_missing_imports = true

@@ -1206,6 +1211,7 @@ sql-postgres = ["sqlalchemy", "psycopg2-binary"]
 sql-sqlite = ["sqlalchemy"]
 gcsfs = ["gcsfs"]
 rest-sigv4 = ["boto3"]
+pyiceberg-core = ["pyiceberg-core"]
 
 [tool.pytest.ini_options]
 markers = [
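pyiceberg-core is declared optional, so the new write path only works when that extra is installed; the wrapper's explicit import guard covers PyArrow, while a missing pyiceberg_core surfaces as an import error when pyarrow_transform is called. A small availability check, included here purely as an illustration:

try:
    # Module provided by the optional pyiceberg-core dependency added in this commit.
    from pyiceberg_core import transform as pyiceberg_core_transform
except ModuleNotFoundError:
    print("pyiceberg-core not installed; bucket/truncate partitioned writes will fail at pyarrow_transform()")
else:
    print("pyiceberg-core available:", pyiceberg_core_transform)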

tests/integration/test_writes/test_partitioned_writes.py

Lines changed: 157 additions & 13 deletions
@@ -412,6 +412,12 @@ def test_dynamic_partition_overwrite_unpartitioned_evolve_to_identity_transform(
     spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, part_col: str, format_version: int
 ) -> None:
     identifier = f"default.unpartitioned_table_v{format_version}_evolve_into_identity_transformed_partition_field_{part_col}"
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
     tbl = session_catalog.create_table(
         identifier=identifier,
         schema=TABLE_SCHEMA,

@@ -756,6 +762,55 @@ def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog) -> Non
         tbl.append("not a df")
 
 
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "spec",
+    [
+        (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=TruncateTransform(2), name="int_trunc"))),
+        (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=TruncateTransform(2), name="long_trunc"))),
+        (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=TruncateTransform(2), name="string_trunc"))),
+    ],
+)
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_truncate_transform(
+    spec: PartitionSpec,
+    spark: SparkSession,
+    session_catalog: Catalog,
+    arrow_table_with_null: pa.Table,
+    format_version: int,
+) -> None:
+    identifier = "default.truncate_transform"
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    tbl = _create_table(
+        session_catalog=session_catalog,
+        identifier=identifier,
+        properties={"format-version": str(format_version)},
+        data=[arrow_table_with_null],
+        partition_spec=spec,
+    )
+
+    assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
+    df = spark.table(identifier)
+    assert df.count() == 3, f"Expected 3 total rows for {identifier}"
+    for col in arrow_table_with_null.column_names:
+        assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}"
+        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"
+
+    assert tbl.inspect.partitions().num_rows == 3
+    files_df = spark.sql(
+        f"""
+        SELECT *
+        FROM {identifier}.files
+        """
+    )
+    assert files_df.count() == 3
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize(
     "spec",

@@ -767,18 +822,52 @@ def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog) -> Non
                 PartitionField(source_id=1, field_id=1002, transform=IdentityTransform(), name="bool"),
             )
         ),
-        # none of non-identity is supported
-        (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket"))),
-        (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=BucketTransform(2), name="long_bucket"))),
-        (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=BucketTransform(2), name="date_bucket"))),
-        (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=BucketTransform(2), name="timestamp_bucket"))),
-        (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=BucketTransform(2), name="timestamptz_bucket"))),
-        (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=BucketTransform(2), name="string_bucket"))),
-        (PartitionSpec(PartitionField(source_id=12, field_id=1001, transform=BucketTransform(2), name="fixed_bucket"))),
-        (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=BucketTransform(2), name="binary_bucket"))),
-        (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=TruncateTransform(2), name="int_trunc"))),
-        (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=TruncateTransform(2), name="long_trunc"))),
-        (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=TruncateTransform(2), name="string_trunc"))),
+    ],
+)
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_identity_and_bucket_transform_spec(
+    spec: PartitionSpec,
+    spark: SparkSession,
+    session_catalog: Catalog,
+    arrow_table_with_null: pa.Table,
+    format_version: int,
+) -> None:
+    identifier = "default.identity_and_bucket_transform"
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    tbl = _create_table(
+        session_catalog=session_catalog,
+        identifier=identifier,
+        properties={"format-version": str(format_version)},
+        data=[arrow_table_with_null],
+        partition_spec=spec,
+    )
+
+    assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
+    df = spark.table(identifier)
+    assert df.count() == 3, f"Expected 3 total rows for {identifier}"
+    for col in arrow_table_with_null.column_names:
+        assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}"
+        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"
+
+    assert tbl.inspect.partitions().num_rows == 3
+    files_df = spark.sql(
+        f"""
+        SELECT *
+        FROM {identifier}.files
+        """
+    )
+    assert files_df.count() == 3
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "spec",
+    [
         (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=TruncateTransform(2), name="binary_trunc"))),
     ],
 )

@@ -801,11 +890,66 @@ def test_unsupported_transform(
 
     with pytest.raises(
         ValueError,
-        match="Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: *",
+        match="FeatureUnsupported => Unsupported data type for truncate transform: LargeBinary",
     ):
         tbl.append(arrow_table_with_null)
 
 
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "spec, expected_rows",
+    [
+        (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=BucketTransform(2), name="long_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=BucketTransform(2), name="date_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=BucketTransform(2), name="timestamp_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=BucketTransform(2), name="timestamptz_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=BucketTransform(2), name="string_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=12, field_id=1001, transform=BucketTransform(2), name="fixed_bucket")), 2),
+        (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=BucketTransform(2), name="binary_bucket")), 2),
+    ],
+)
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_bucket_transform(
+    spark: SparkSession,
+    session_catalog: Catalog,
+    arrow_table_with_null: pa.Table,
+    spec: PartitionSpec,
+    expected_rows: int,
+    format_version: int,
+) -> None:
+    identifier = "default.bucket_transform"
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    tbl = _create_table(
+        session_catalog=session_catalog,
+        identifier=identifier,
+        properties={"format-version": str(format_version)},
+        data=[arrow_table_with_null],
+        partition_spec=spec,
+    )
+
+    assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
+    df = spark.table(identifier)
+    assert df.count() == 3, f"Expected 3 total rows for {identifier}"
+    for col in arrow_table_with_null.column_names:
+        assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}"
+        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"
+
+    assert tbl.inspect.partitions().num_rows == expected_rows
+    files_df = spark.sql(
+        f"""
+        SELECT *
+        FROM {identifier}.files
+        """
+    )
+    assert files_df.count() == expected_rows
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize(
     "transform,expected_rows",

tests/test_transforms.py

Lines changed: 42 additions & 4 deletions
@@ -18,10 +18,11 @@
 # pylint: disable=eval-used,protected-access,redefined-outer-name
 from datetime import date
 from decimal import Decimal
-from typing import TYPE_CHECKING, Any, Callable, Optional
+from typing import Any, Callable, Optional, Union
 from uuid import UUID
 
 import mmh3 as mmh3
+import pyarrow as pa
 import pytest
 from pydantic import (
     BeforeValidator,

@@ -116,9 +117,6 @@
     timestamptz_to_micros,
 )
 
-if TYPE_CHECKING:
-    import pyarrow as pa
-
 
 @pytest.mark.parametrize(
     "test_input,test_type,expected",

@@ -1563,3 +1561,43 @@ def test_ymd_pyarrow_transforms(
     else:
         with pytest.raises(ValueError):
             transform.pyarrow_transform(DateType())(arrow_table_date_timestamps[source_col])
+
+
+@pytest.mark.parametrize(
+    "source_type, input_arr, expected, num_buckets",
+    [
+        (IntegerType(), pa.array([1, 2]), pa.array([6, 2], type=pa.int32()), 10),
+        (
+            IntegerType(),
+            pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])]),
+            pa.chunked_array([pa.array([6, 2], type=pa.int32()), pa.array([5, 0], type=pa.int32())]),
+            10,
+        ),
+        (IntegerType(), pa.array([1, 2]), pa.array([6, 2], type=pa.int32()), 10),
+    ],
+)
+def test_bucket_pyarrow_transforms(
+    source_type: PrimitiveType,
+    input_arr: Union[pa.Array, pa.ChunkedArray],
+    expected: Union[pa.Array, pa.ChunkedArray],
+    num_buckets: int,
+) -> None:
+    transform: Transform[Any, Any] = BucketTransform(num_buckets=num_buckets)
+    assert expected == transform.pyarrow_transform(source_type)(input_arr)
+
+
+@pytest.mark.parametrize(
+    "source_type, input_arr, expected, width",
+    [
+        (StringType(), pa.array(["developer", "iceberg"]), pa.array(["dev", "ice"]), 3),
+        (IntegerType(), pa.array([1, -1]), pa.array([0, -10]), 10),
+    ],
+)
+def test_truncate_pyarrow_transforms(
+    source_type: PrimitiveType,
+    input_arr: Union[pa.Array, pa.ChunkedArray],
+    expected: Union[pa.Array, pa.ChunkedArray],
+    width: int,
+) -> None:
+    transform: Transform[Any, Any] = TruncateTransform(width=width)
+    assert expected == transform.pyarrow_transform(source_type)(input_arr)
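The integer cases in test_truncate_pyarrow_transforms follow the Iceberg spec's truncate definition, v - (v % W) with a non-negative remainder, which is why a width of 10 maps 1 to 0 and -1 to -10; strings are simply cut to their first W characters ("developer" -> "dev"). A rough reference implementation of the integer case, for intuition only and not part of this change:

def truncate_int(value: int, width: int) -> int:
    # Python's % already yields a non-negative remainder for a positive width,
    # matching the spec's requirement that v % W be adjusted to be positive.
    return value - (value % width)

assert truncate_int(1, 10) == 0
assert truncate_int(-1, 10) == -10
assert truncate_int(34, 10) == 30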
