Implement column projection #1443

Open · wants to merge 11 commits into base: main
66 changes: 61 additions & 5 deletions pyiceberg/io/pyarrow.py
@@ -123,12 +123,14 @@
)
from pyiceberg.partitioning import PartitionField, PartitionFieldValue, PartitionKey, PartitionSpec, partition_record_value
from pyiceberg.schema import (
Accessor,
PartnerAccessor,
PreOrderSchemaVisitor,
Schema,
SchemaVisitorPerPrimitiveType,
SchemaWithPartnerVisitor,
_check_schema_compatible,
build_position_accessors,
pre_order_visit,
promote,
prune_columns,
@@ -138,7 +140,7 @@
)
from pyiceberg.table.metadata import TableMetadata
from pyiceberg.table.name_mapping import NameMapping, apply_name_mapping
from pyiceberg.transforms import TruncateTransform
from pyiceberg.transforms import IdentityTransform, TruncateTransform
from pyiceberg.typedef import EMPTY_DICT, Properties, Record
from pyiceberg.types import (
BinaryType,
@@ -1216,6 +1218,45 @@ def _field_id(self, field: pa.Field) -> int:
return -1


def _get_column_projection_values(
file: DataFile, projected_schema: Schema, partition_spec: Optional[PartitionSpec], file_project_field_ids: Set[int]
) -> Tuple[bool, Dict[str, Any]]:
"""Apply Column Projection rules to File Schema."""
project_schema_diff = projected_schema.field_ids.difference(file_project_field_ids)
should_project_columns = len(project_schema_diff) > 0
projected_missing_fields: Dict[str, Any] = {}

if not should_project_columns:
return False, {}

partition_schema: StructType
accessors: Dict[int, Accessor]

if partition_spec is not None:
partition_schema = partition_spec.partition_type(projected_schema)
accessors = build_position_accessors(partition_schema)
else:
return False, {}

for field_id in project_schema_diff:
for partition_field in partition_spec.fields_by_source_id(field_id):
if isinstance(partition_field.transform, IdentityTransform):
accessor = accessors.get(partition_field.field_id)

if accessor is None:
continue

# The partition field may not exist in the partition record of the data file.
# This can happen when new partition fields are introduced after the file was written.
try:
if partition_value := accessor.get(file.partition):
projected_missing_fields[partition_field.name] = partition_value
except IndexError:
continue

return True, projected_missing_fields
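
A minimal, library-free sketch of the rule this helper implements (the function and argument names below are illustrative, not part of the PyIceberg API): field IDs that are requested but absent from the file schema are filled from the data file's partition record whenever an identity transform ties them to a partition field.

def project_missing_fields(requested_ids, file_ids, identity_sources, partition_record):
    # identity_sources: source field id -> partition field name (identity transforms only)
    # partition_record: partition values recorded in the manifest entry for this data file
    missing = set(requested_ids) - set(file_ids)
    projected = {}
    for field_id in missing:
        name = identity_sources.get(field_id)
        if name is not None and name in partition_record:
            projected[name] = partition_record[name]
    return bool(projected), projected

# The file only wrote field 1, the reader asks for fields 1 and 2, and field 2 is an
# identity partition column recorded as partition_id=42 in the manifest:
print(project_missing_fields({1, 2}, {1}, {2: "partition_id"}, {"partition_id": 42}))
# -> (True, {'partition_id': 42})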


def _task_to_record_batches(
fs: FileSystem,
task: FileScanTask,
@@ -1226,6 +1267,7 @@ def _task_to_record_batches(
case_sensitive: bool,
name_mapping: Optional[NameMapping] = None,
use_large_types: bool = True,
partition_spec: Optional[PartitionSpec] = None,
) -> Iterator[pa.RecordBatch]:
_, _, path = _parse_location(task.file.file_path)
arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
@@ -1237,16 +1279,20 @@
# When V3 support is introduced, we will update `downcast_ns_timestamp_to_us` flag based on
# the table format version.
file_schema = pyarrow_to_schema(physical_schema, name_mapping, downcast_ns_timestamp_to_us=True)

pyarrow_filter = None
if bound_row_filter is not AlwaysTrue():
translated_row_filter = translate_column_names(bound_row_filter, file_schema, case_sensitive=case_sensitive)
bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive)
pyarrow_filter = expression_to_pyarrow(bound_file_filter)

file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)
# Apply column projection rules
# https://iceberg.apache.org/spec/#column-projection
should_project_columns, projected_missing_fields = _get_column_projection_values(
task.file, projected_schema, partition_spec, file_schema.field_ids
)

if file_schema is None:
raise ValueError(f"Missing Iceberg schema in Metadata for file: {path}")
file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)

fragment_scanner = ds.Scanner.from_fragment(
fragment=fragment,
@@ -1286,14 +1332,23 @@
continue
output_batches = arrow_table.to_batches()
for output_batch in output_batches:
yield _to_requested_schema(
result_batch = _to_requested_schema(
projected_schema,
file_project_schema,
output_batch,
downcast_ns_timestamp_to_us=True,
use_large_types=use_large_types,
)

# Inject projected column values if available
if should_project_columns:
for name, value in projected_missing_fields.items():
index = result_batch.schema.get_field_index(name)
if index != -1:
result_batch = result_batch.set_column(index, name, [value])

yield result_batch
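
The injection above relies on pyarrow's RecordBatch.set_column to place the constant partition value into the batch. A hedged, standalone illustration of the same idea (column names and data are made up; a pyarrow version exposing RecordBatch.set_column is assumed, as the diff already relies on it, and the scalar is repeated to the batch length):

import pyarrow as pa

# A two-row batch whose "partition_id" column is still empty after schema projection.
batch = pa.RecordBatch.from_pydict(
    {"other_field": ["foo", "bar"], "partition_id": pa.nulls(2, pa.int64())}
)
index = batch.schema.get_field_index("partition_id")
# Repeat the manifest's partition value so the column length matches the batch.
filled = batch.set_column(index, "partition_id", pa.repeat(pa.scalar(1, pa.int64()), batch.num_rows))
print(filled.column(index))  # -> an int64 array holding [1, 1]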


def _task_to_table(
fs: FileSystem,
@@ -1517,6 +1572,7 @@ def _record_batches_from_scan_tasks_and_deletes(
self._case_sensitive,
self._table_metadata.name_mapping(),
self._use_large_types,
self._table_metadata.spec(),
)
for batch in batches:
if self._limit is not None:
132 changes: 132 additions & 0 deletions tests/io/test_pyarrow.py
@@ -69,14 +69,18 @@
_read_deletes,
_to_requested_schema,
bin_pack_arrow_table,
compute_statistics_plan,
data_file_statistics_from_parquet_metadata,
expression_to_pyarrow,
parquet_path_to_id_mapping,
schema_to_pyarrow,
)
from pyiceberg.manifest import DataFile, DataFileContent, FileFormat
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema, make_compatible_name, visit
from pyiceberg.table import FileScanTask, TableProperties
from pyiceberg.table.metadata import TableMetadataV2
from pyiceberg.table.name_mapping import create_mapping_from_schema
from pyiceberg.transforms import IdentityTransform
from pyiceberg.typedef import UTF8, Properties, Record
from pyiceberg.types import (
@@ -99,6 +103,7 @@
TimestamptzType,
TimeType,
)
from tests.catalog.test_base import InMemoryCatalog
from tests.conftest import UNIFIED_AWS_SESSION_PROPERTIES


@@ -1122,6 +1127,133 @@ def test_projection_concat_files(schema_int: Schema, file_int: str) -> None:
assert repr(result_table.schema) == "id: int32"


def test_identity_transform_column_projection(tmp_path: str, catalog: InMemoryCatalog) -> None:
# Test by adding a non-partitioned data file to a partitioned table, verifying partition value projection from manifest metadata.
# TODO: Update to use a data file created by writing data to an unpartitioned table once add_files supports field IDs.
# (context: https://github.com/apache/iceberg-python/pull/1443#discussion_r1901374875)

schema = Schema(
NestedField(1, "other_field", StringType(), required=False), NestedField(2, "partition_id", IntegerType(), required=False)
)

partition_spec = PartitionSpec(
PartitionField(2, 1000, IdentityTransform(), "partition_id"),
)

table = catalog.create_table(
"default.test_projection_partition",
schema=schema,
partition_spec=partition_spec,
properties={TableProperties.DEFAULT_NAME_MAPPING: create_mapping_from_schema(schema).model_dump_json()},
)

file_data = pa.array(["foo"], type=pa.string())
file_loc = f"{tmp_path}/test.parquet"
pq.write_table(pa.table([file_data], names=["other_field"]), file_loc)

statistics = data_file_statistics_from_parquet_metadata(
parquet_metadata=pq.read_metadata(file_loc),
stats_columns=compute_statistics_plan(table.schema(), table.metadata.properties),
parquet_column_mapping=parquet_path_to_id_mapping(table.schema()),
)

unpartitioned_file = DataFile(
content=DataFileContent.DATA,
file_path=file_loc,
file_format=FileFormat.PARQUET,
# projected value
partition=Record(partition_id=1),
file_size_in_bytes=os.path.getsize(file_loc),
sort_order_id=None,
spec_id=table.metadata.default_spec_id,
equality_ids=None,
key_metadata=None,
**statistics.to_serialized_dict(),
)

with table.transaction() as transaction:
with transaction.update_snapshot().overwrite() as update:
update.append_data_file(unpartitioned_file)

assert (
str(table.scan().to_arrow())
== """pyarrow.Table
other_field: large_string
partition_id: int64
----
other_field: [["foo"]]
partition_id: [[1]]"""
)


def test_identity_transform_columns_projection(tmp_path: str, catalog: InMemoryCatalog) -> None:
# Test by adding a non-partitioned data file to a multi-partitioned table, verifying partition value projection from manifest metadata.
# TODO: Update to use a data file created by writing data to an unpartitioned table once add_files supports field IDs.
# (context: https://github.com/apache/iceberg-python/pull/1443#discussion_r1901374875)
schema = Schema(
NestedField(1, "field_1", StringType(), required=False),
NestedField(2, "field_2", IntegerType(), required=False),
NestedField(3, "field_3", IntegerType(), required=False),
)

partition_spec = PartitionSpec(
PartitionField(2, 1000, IdentityTransform(), "field_2"),
PartitionField(3, 1001, IdentityTransform(), "field_3"),
)

table = catalog.create_table(
"default.test_projection_partitions",
schema=schema,
partition_spec=partition_spec,
properties={TableProperties.DEFAULT_NAME_MAPPING: create_mapping_from_schema(schema).model_dump_json()},
)

file_data = pa.array(["foo"], type=pa.string())
file_loc = f"{tmp_path}/test.parquet"
pq.write_table(pa.table([file_data], names=["field_1"]), file_loc)

statistics = data_file_statistics_from_parquet_metadata(
parquet_metadata=pq.read_metadata(file_loc),
stats_columns=compute_statistics_plan(table.schema(), table.metadata.properties),
parquet_column_mapping=parquet_path_to_id_mapping(table.schema()),
)

unpartitioned_file = DataFile(
content=DataFileContent.DATA,
file_path=file_loc,
file_format=FileFormat.PARQUET,
# projected value
partition=Record(field_2=2, field_3=3),
file_size_in_bytes=os.path.getsize(file_loc),
sort_order_id=None,
spec_id=table.metadata.default_spec_id,
equality_ids=None,
key_metadata=None,
**statistics.to_serialized_dict(),
)

with table.transaction() as transaction:
with transaction.update_snapshot().overwrite() as update:
update.append_data_file(unpartitioned_file)

assert (
str(table.scan().to_arrow())
== """pyarrow.Table
field_1: large_string
field_2: int64
field_3: int64
----
field_1: [["foo"]]
field_2: [[2]]
field_3: [[3]]"""
)


@pytest.fixture
def catalog() -> InMemoryCatalog:
return InMemoryCatalog("test.in_memory.catalog", **{"test.key": "test.value"})


def test_projection_filter(schema_int: Schema, file_int: str) -> None:
result_table = project(schema_int, [file_int], GreaterThan("id", 4))
assert len(result_table.columns[0]) == 0