pyiceberg/io/pyarrow.py (61 additions, 5 deletions)
@@ -123,12 +123,14 @@
)
from pyiceberg.partitioning import PartitionField, PartitionFieldValue, PartitionKey, PartitionSpec, partition_record_value
from pyiceberg.schema import (
Accessor,
PartnerAccessor,
PreOrderSchemaVisitor,
Schema,
SchemaVisitorPerPrimitiveType,
SchemaWithPartnerVisitor,
_check_schema_compatible,
build_position_accessors,
pre_order_visit,
promote,
prune_columns,
@@ -138,7 +140,7 @@
)
from pyiceberg.table.metadata import TableMetadata
from pyiceberg.table.name_mapping import NameMapping, apply_name_mapping
from pyiceberg.transforms import TruncateTransform
from pyiceberg.transforms import IdentityTransform, TruncateTransform
from pyiceberg.typedef import EMPTY_DICT, Properties, Record
from pyiceberg.types import (
BinaryType,
@@ -1216,6 +1218,45 @@ def _field_id(self, field: pa.Field) -> int:
return -1


def _get_column_projection_values(
file: DataFile, projected_schema: Schema, partition_spec: Optional[PartitionSpec], file_project_field_ids: Set[int]
) -> Tuple[bool, Dict[str, Any]]:
"""Apply Column Projection rules to File Schema."""
project_schema_diff = projected_schema.field_ids.difference(file_project_field_ids)
should_project_columns = len(project_schema_diff) > 0
projected_missing_fields: Dict[str, Any] = {}

if not should_project_columns:
return False, {}

partition_schema: StructType
accessors: Dict[int, Accessor]

if partition_spec is not None:
partition_schema = partition_spec.partition_type(projected_schema)
accessors = build_position_accessors(partition_schema)
else:
return False, {}

for field_id in project_schema_diff:
for partition_field in partition_spec.fields_by_source_id(field_id):
if isinstance(partition_field.transform, IdentityTransform):
accessor = accessors.get(partition_field.field_id)

if accessor is None:
continue

# The partition field may not exist in the partition record of the data file.
# This can happen when new partition fields are introduced after the file was written.
try:
if partition_value := accessor.get(file.partition):
projected_missing_fields[partition_field.name] = partition_value
except IndexError:
continue

return True, projected_missing_fields
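
# Example: for a table partitioned by an IdentityTransform on `partition_id`, a data file
# that was written without that column, and a manifest entry whose partition record is
# Record(partition_id=1), this helper returns (True, {"partition_id": 1}); the constant is
# then injected into every record batch read from the file in `_task_to_record_batches` below.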


def _task_to_record_batches(
fs: FileSystem,
task: FileScanTask,
@@ -1226,6 +1267,7 @@ def _task_to_record_batches(
case_sensitive: bool,
name_mapping: Optional[NameMapping] = None,
use_large_types: bool = True,
partition_spec: Optional[PartitionSpec] = None,
) -> Iterator[pa.RecordBatch]:
_, _, path = _parse_location(task.file.file_path)
arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
@@ -1237,16 +1279,20 @@
# When V3 support is introduced, we will update `downcast_ns_timestamp_to_us` flag based on
# the table format version.
file_schema = pyarrow_to_schema(physical_schema, name_mapping, downcast_ns_timestamp_to_us=True)

pyarrow_filter = None
if bound_row_filter is not AlwaysTrue():
translated_row_filter = translate_column_names(bound_row_filter, file_schema, case_sensitive=case_sensitive)
bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive)
pyarrow_filter = expression_to_pyarrow(bound_file_filter)

file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)
# Apply column projection rules
# https://iceberg.apache.org/spec/#column-projection
should_project_columns, projected_missing_fields = _get_column_projection_values(
task.file, projected_schema, partition_spec, file_schema.field_ids
)
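        # `should_project_columns` is True only when the projected schema references field IDs
        # that are missing from the file's schema, e.g. identity-partition columns that were
        # never written into the data file itself.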

if file_schema is None:
raise ValueError(f"Missing Iceberg schema in Metadata for file: {path}")
file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)

fragment_scanner = ds.Scanner.from_fragment(
fragment=fragment,
@@ -1286,14 +1332,23 @@
continue
output_batches = arrow_table.to_batches()
for output_batch in output_batches:
yield _to_requested_schema(
result_batch = _to_requested_schema(
projected_schema,
file_project_schema,
output_batch,
downcast_ns_timestamp_to_us=True,
use_large_types=use_large_types,
)

# Inject projected column values if available
if should_project_columns:
for name, value in projected_missing_fields.items():
index = result_batch.schema.get_field_index(name)
if index != -1:
result_batch = result_batch.set_column(index, name, [value])
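                # Each yielded batch now matches the projected schema, with the missing
                # partition columns filled in from the manifest metadata.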

yield result_batch


def _task_to_table(
fs: FileSystem,
@@ -1517,6 +1572,7 @@ def _record_batches_from_scan_tasks_and_deletes(
self._case_sensitive,
self._table_metadata.name_mapping(),
self._use_large_types,
self._table_metadata.spec(),
)
for batch in batches:
if self._limit is not None:
tests/io/test_pyarrow.py (132 additions, 0 deletions)
@@ -69,14 +69,18 @@
_read_deletes,
_to_requested_schema,
bin_pack_arrow_table,
compute_statistics_plan,
data_file_statistics_from_parquet_metadata,
expression_to_pyarrow,
parquet_path_to_id_mapping,
schema_to_pyarrow,
)
from pyiceberg.manifest import DataFile, DataFileContent, FileFormat
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema, make_compatible_name, visit
from pyiceberg.table import FileScanTask, TableProperties
from pyiceberg.table.metadata import TableMetadataV2
from pyiceberg.table.name_mapping import create_mapping_from_schema
from pyiceberg.transforms import IdentityTransform
from pyiceberg.typedef import UTF8, Properties, Record
from pyiceberg.types import (
@@ -99,6 +103,7 @@
TimestamptzType,
TimeType,
)
from tests.catalog.test_base import InMemoryCatalog
from tests.conftest import UNIFIED_AWS_SESSION_PROPERTIES


@@ -1122,6 +1127,133 @@ def test_projection_concat_files(schema_int: Schema, file_int: str) -> None:
assert repr(result_table.schema) == "id: int32"


def test_identity_transform_column_projection(tmp_path: str, catalog: InMemoryCatalog) -> None:
# Test by adding a non-partitioned data file to a partitioned table, verifying partition value projection from manifest metadata.
# TODO: Update to use a data file created by writing data to an unpartitioned table once add_files supports field IDs.
# (context: https://github.com/apache/iceberg-python/pull/1443#discussion_r1901374875)
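    # The parquet file written below contains only `other_field`; `partition_id` exists solely
    # in the manifest's partition record, so the scan has to project its value from metadata.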

schema = Schema(
NestedField(1, "other_field", StringType(), required=False), NestedField(2, "partition_id", IntegerType(), required=False)
)

partition_spec = PartitionSpec(
PartitionField(2, 1000, IdentityTransform(), "partition_id"),
)

table = catalog.create_table(
"default.test_projection_partition",
schema=schema,
partition_spec=partition_spec,
properties={TableProperties.DEFAULT_NAME_MAPPING: create_mapping_from_schema(schema).model_dump_json()},
)

file_data = pa.array(["foo"], type=pa.string())
file_loc = f"{tmp_path}/test.parquet"
pq.write_table(pa.table([file_data], names=["other_field"]), file_loc)

statistics = data_file_statistics_from_parquet_metadata(
parquet_metadata=pq.read_metadata(file_loc),
stats_columns=compute_statistics_plan(table.schema(), table.metadata.properties),
parquet_column_mapping=parquet_path_to_id_mapping(table.schema()),
)

unpartitioned_file = DataFile(
content=DataFileContent.DATA,
file_path=file_loc,
file_format=FileFormat.PARQUET,
# projected value
partition=Record(partition_id=1),
file_size_in_bytes=os.path.getsize(file_loc),
sort_order_id=None,
spec_id=table.metadata.default_spec_id,
equality_ids=None,
key_metadata=None,
**statistics.to_serialized_dict(),
)

with table.transaction() as transaction:
with transaction.update_snapshot().overwrite() as update:
update.append_data_file(unpartitioned_file)

assert (
str(table.scan().to_arrow())
== """pyarrow.Table
other_field: large_string
partition_id: int64
----
other_field: [["foo"]]
partition_id: [[1]]"""
)


def test_identity_transform_columns_projection(tmp_path: str, catalog: InMemoryCatalog) -> None:
# Test by adding a non-partitioned data file to a multi-partitioned table, verifying partition value projection from manifest metadata.
# TODO: Update to use a data file created by writing data to an unpartitioned table once add_files supports field IDs.
# (context: https://github.com/apache/iceberg-python/pull/1443#discussion_r1901374875)
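    # Both identity-partitioned source columns (`field_2`, `field_3`) are absent from the parquet
    # file, so each one must be projected as a constant taken from the manifest's partition record.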
schema = Schema(
NestedField(1, "field_1", StringType(), required=False),
NestedField(2, "field_2", IntegerType(), required=False),
NestedField(3, "field_3", IntegerType(), required=False),
)

partition_spec = PartitionSpec(
PartitionField(2, 1000, IdentityTransform(), "field_2"),
PartitionField(3, 1001, IdentityTransform(), "field_3"),
)

table = catalog.create_table(
"default.test_projection_partitions",
schema=schema,
partition_spec=partition_spec,
properties={TableProperties.DEFAULT_NAME_MAPPING: create_mapping_from_schema(schema).model_dump_json()},
)

file_data = pa.array(["foo"], type=pa.string())
file_loc = f"{tmp_path}/test.parquet"
pq.write_table(pa.table([file_data], names=["field_1"]), file_loc)

statistics = data_file_statistics_from_parquet_metadata(
parquet_metadata=pq.read_metadata(file_loc),
stats_columns=compute_statistics_plan(table.schema(), table.metadata.properties),
parquet_column_mapping=parquet_path_to_id_mapping(table.schema()),
)

unpartitioned_file = DataFile(
content=DataFileContent.DATA,
file_path=file_loc,
file_format=FileFormat.PARQUET,
# projected value
partition=Record(field_2=2, field_3=3),
file_size_in_bytes=os.path.getsize(file_loc),
sort_order_id=None,
spec_id=table.metadata.default_spec_id,
equality_ids=None,
key_metadata=None,
**statistics.to_serialized_dict(),
)

with table.transaction() as transaction:
with transaction.update_snapshot().overwrite() as update:
update.append_data_file(unpartitioned_file)

assert (
str(table.scan().to_arrow())
== """pyarrow.Table
field_1: large_string
field_2: int64
field_3: int64
----
field_1: [["foo"]]
field_2: [[2]]
field_3: [[3]]"""
)


@pytest.fixture
def catalog() -> InMemoryCatalog:
return InMemoryCatalog("test.in_memory.catalog", **{"test.key": "test.value"})


def test_projection_filter(schema_int: Schema, file_int: str) -> None:
result_table = project(schema_int, [file_int], GreaterThan("id", 4))
assert len(result_table.columns[0]) == 0