support all_entries in pyiceberg #1608

Open · wants to merge 4 commits into base: main
2 changes: 2 additions & 0 deletions mkdocs/docs/api.md
@@ -716,6 +716,8 @@ readable_metrics: [
[6.0989]]
```

To show the table's manifest entries across all snapshots, for both data and delete files, use `table.inspect.all_entries()`.

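For example, a minimal usage sketch (the catalog name `default` and the table identifier `docs_example.bids` below are placeholders for illustration):

```python
from pyiceberg.catalog import load_catalog

# Placeholder catalog and table names; substitute your own.
catalog = load_catalog("default")
table = catalog.load_table("docs_example.bids")

# Returns a PyArrow table with one row per manifest entry across all snapshots,
# covering both data files and delete files.
all_entries = table.inspect.all_entries()
print(all_entries.column_names)
```
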
### References

To show a table's known snapshot references:
168 changes: 109 additions & 59 deletions pyiceberg/table/inspect.py
@@ -20,7 +20,7 @@
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Set, Tuple

from pyiceberg.conversions import from_bytes
from pyiceberg.manifest import DataFile, DataFileContent, ManifestContent, PartitionFieldSummary
from pyiceberg.manifest import DataFile, DataFileContent, ManifestContent, ManifestFile, PartitionFieldSummary
from pyiceberg.partitioning import PartitionSpec
from pyiceberg.table.snapshots import Snapshot, ancestors_of
from pyiceberg.types import PrimitiveType
@@ -94,7 +94,7 @@ def snapshots(self) -> "pa.Table":
schema=snapshots_schema,
)

def entries(self, snapshot_id: Optional[int] = None) -> "pa.Table":
def _get_entries_schema(self) -> "pa.Schema":
import pyarrow as pa

from pyiceberg.io.pyarrow import schema_to_pyarrow
@@ -137,6 +137,7 @@ def _readable_metrics_struct(bound_type: PrimitiveType) -> pa.StructType:
pa.field("content", pa.int8(), nullable=False),
pa.field("file_path", pa.string(), nullable=False),
pa.field("file_format", pa.string(), nullable=False),
pa.field("spec_id", pa.int32(), nullable=False),
pa.field("partition", pa_record_struct, nullable=False),
pa.field("record_count", pa.int64(), nullable=False),
pa.field("file_size_in_bytes", pa.int64(), nullable=False),
@@ -157,74 +158,96 @@ def _readable_metrics_struct(bound_type: PrimitiveType) -> pa.StructType:
pa.field("readable_metrics", pa.struct(readable_metrics_struct), nullable=True),
]
)
return entries_schema

def _get_entries(self, schema: "Schema", manifest: ManifestFile, discard_deleted: bool = True) -> "pa.Table":
import pyarrow as pa

entries_schema = self._get_entries_schema()
entries = []
snapshot = self._get_snapshot(snapshot_id)
for manifest in snapshot.manifests(self.tbl.io):
for entry in manifest.fetch_manifest_entry(io=self.tbl.io):
column_sizes = entry.data_file.column_sizes or {}
value_counts = entry.data_file.value_counts or {}
null_value_counts = entry.data_file.null_value_counts or {}
nan_value_counts = entry.data_file.nan_value_counts or {}
lower_bounds = entry.data_file.lower_bounds or {}
upper_bounds = entry.data_file.upper_bounds or {}
readable_metrics = {
schema.find_column_name(field.field_id): {
"column_size": column_sizes.get(field.field_id),
"value_count": value_counts.get(field.field_id),
"null_value_count": null_value_counts.get(field.field_id),
"nan_value_count": nan_value_counts.get(field.field_id),
# Makes them readable
"lower_bound": from_bytes(field.field_type, lower_bound)
if (lower_bound := lower_bounds.get(field.field_id))
else None,
"upper_bound": from_bytes(field.field_type, upper_bound)
if (upper_bound := upper_bounds.get(field.field_id))
else None,
}
for field in self.tbl.metadata.schema().fields
for entry in manifest.fetch_manifest_entry(io=self.tbl.io, discard_deleted=discard_deleted):
column_sizes = entry.data_file.column_sizes or {}
value_counts = entry.data_file.value_counts or {}
null_value_counts = entry.data_file.null_value_counts or {}
nan_value_counts = entry.data_file.nan_value_counts or {}
lower_bounds = entry.data_file.lower_bounds or {}
upper_bounds = entry.data_file.upper_bounds or {}
readable_metrics = {
schema.find_column_name(field.field_id): {
"column_size": column_sizes.get(field.field_id),
"value_count": value_counts.get(field.field_id),
"null_value_count": null_value_counts.get(field.field_id),
"nan_value_count": nan_value_counts.get(field.field_id),
# Makes them readable
"lower_bound": from_bytes(field.field_type, lower_bound)
if (lower_bound := lower_bounds.get(field.field_id))
else None,
"upper_bound": from_bytes(field.field_type, upper_bound)
if (upper_bound := upper_bounds.get(field.field_id))
else None,
}
for field in self.tbl.metadata.schema().fields
}

partition = entry.data_file.partition
partition_record_dict = {
field.name: partition[pos]
for pos, field in enumerate(self.tbl.metadata.specs()[manifest.partition_spec_id].fields)
}
partition = entry.data_file.partition
partition_record_dict = {
field.name: partition[pos]
for pos, field in enumerate(self.tbl.metadata.specs()[manifest.partition_spec_id].fields)
}

entries.append(
{
"status": entry.status.value,
"snapshot_id": entry.snapshot_id,
"sequence_number": entry.sequence_number,
"file_sequence_number": entry.file_sequence_number,
"data_file": {
"content": entry.data_file.content,
"file_path": entry.data_file.file_path,
"file_format": entry.data_file.file_format,
"partition": partition_record_dict,
"record_count": entry.data_file.record_count,
"file_size_in_bytes": entry.data_file.file_size_in_bytes,
"column_sizes": dict(entry.data_file.column_sizes),
"value_counts": dict(entry.data_file.value_counts),
"null_value_counts": dict(entry.data_file.null_value_counts),
"nan_value_counts": dict(entry.data_file.nan_value_counts),
"lower_bounds": entry.data_file.lower_bounds,
"upper_bounds": entry.data_file.upper_bounds,
"key_metadata": entry.data_file.key_metadata,
"split_offsets": entry.data_file.split_offsets,
"equality_ids": entry.data_file.equality_ids,
"sort_order_id": entry.data_file.sort_order_id,
"spec_id": entry.data_file.spec_id,
},
"readable_metrics": readable_metrics,
}
)
entries.append(
{
"status": entry.status.value,
"snapshot_id": entry.snapshot_id,
"sequence_number": entry.sequence_number,
"file_sequence_number": entry.file_sequence_number,
"data_file": {
"content": entry.data_file.content,
"file_path": entry.data_file.file_path,
"file_format": entry.data_file.file_format,
"partition": partition_record_dict,
"record_count": entry.data_file.record_count,
"file_size_in_bytes": entry.data_file.file_size_in_bytes,
"column_sizes": dict(entry.data_file.column_sizes) or None,
"value_counts": dict(entry.data_file.value_counts) if entry.data_file.value_counts is not None else None,
"null_value_counts": dict(entry.data_file.null_value_counts)
if entry.data_file.null_value_counts is not None
else None,
"nan_value_counts": dict(entry.data_file.nan_value_counts)
if entry.data_file.nan_value_counts is not None
else None,
"lower_bounds": entry.data_file.lower_bounds,
"upper_bounds": entry.data_file.upper_bounds,
"key_metadata": entry.data_file.key_metadata,
"split_offsets": entry.data_file.split_offsets,
"equality_ids": entry.data_file.equality_ids,
"sort_order_id": entry.data_file.sort_order_id,
"spec_id": entry.data_file.spec_id,
},
"readable_metrics": readable_metrics,
}
)

return pa.Table.from_pylist(
entries,
schema=entries_schema,
)

def entries(self, snapshot_id: Optional[int] = None) -> "pa.Table":
import pyarrow as pa

entries = []
snapshot = self._get_snapshot(snapshot_id)

if snapshot.schema_id is None:
raise ValueError(f"Cannot find schema_id for snapshot {snapshot.snapshot_id}")

schema = self.tbl.schemas().get(snapshot.schema_id)
for manifest in snapshot.manifests(self.tbl.io):
manifest_entries = self._get_entries(schema=schema, manifest=manifest, discard_deleted=True)
entries.append(manifest_entries)
return pa.concat_tables(entries)

def refs(self) -> "pa.Table":
import pyarrow as pa

@@ -657,3 +680,30 @@ def all_manifests(self) -> "pa.Table":
lambda args: self._generate_manifests_table(*args), [(snapshot, True) for snapshot in snapshots]
)
return pa.concat_tables(manifests_by_snapshots)

def all_entries(self) -> "pa.Table":
import pyarrow as pa

snapshots = self.tbl.snapshots()
if not snapshots:
return pa.Table.from_pylist([], self._get_entries_schema())

schemas = self.tbl.schemas()
snapshot_schemas: Dict[int, "Schema"] = {}
for snapshot in snapshots:
if snapshot.schema_id is None:
raise ValueError(f"Cannot find schema_id for snapshot: {snapshot.snapshot_id}")
else:
snapshot_schemas[snapshot.snapshot_id] = schemas.get(snapshot.schema_id)

executor = ExecutorFactory.get_or_create()
all_manifests: Iterator[List[ManifestFile]] = executor.map(lambda snapshot: snapshot.manifests(self.tbl.io), snapshots)
unique_flattened_manifests = list(
{manifest.manifest_path: manifest for manifest_list in all_manifests for manifest in manifest_list}.values()
)

entries: Iterator["pa.Table"] = executor.map(
lambda manifest: self._get_entries(snapshot_schemas[manifest.added_snapshot_id], manifest, discard_deleted=True),
unique_flattened_manifests,
)
return pa.concat_tables(entries)
102 changes: 102 additions & 0 deletions tests/integration/test_inspect_table.py
@@ -938,3 +938,105 @@ def test_inspect_all_manifests(spark: SparkSession, session_catalog: Catalog, fo
lhs = spark.table(f"{identifier}.all_manifests").toPandas()
rhs = df.to_pandas()
assert_frame_equal(lhs, rhs, check_dtype=False)


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_inspect_all_entries(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
identifier = "default.table_metadata_all_entries"
try:
session_catalog.drop_table(identifier=identifier)
except NoSuchTableError:
pass

spark.sql(
f"""
CREATE TABLE {identifier} (
id int,
data string
)
PARTITIONED BY (data)
TBLPROPERTIES ('write.update.mode'='merge-on-read',
'write.delete.mode'='merge-on-read')
"""
)
tbl = session_catalog.load_table(identifier)

spark.sql(f"INSERT INTO {identifier} VALUES (1, 'a')")
spark.sql(f"INSERT INTO {identifier} VALUES (2, 'b')")

spark.sql(f"UPDATE {identifier} SET data = 'c' WHERE id = 1")

spark.sql(f"DELETE FROM {identifier} WHERE id = 2")

spark.sql(f"INSERT OVERWRITE {identifier} VALUES (1, 'a')")

def check_pyiceberg_df_equals_spark_df(df: pa.Table, spark_df: DataFrame) -> None:
assert df.column_names == [
"status",
"snapshot_id",
"sequence_number",
"file_sequence_number",
"data_file",
"readable_metrics",
]

# Check first 4 columns are of the correct type
for int_column in ["status", "snapshot_id", "sequence_number", "file_sequence_number"]:
for value in df[int_column]:
assert isinstance(value.as_py(), int)

# The rest of the code checks the data_file and readable_metrics columns
# Convert both dataframes to pandas and sort them the same way for comparison
lhs = df.to_pandas()
rhs = spark_df.toPandas()
for df_to_check in [lhs, rhs]:
df_to_check["content"] = df_to_check["data_file"].apply(lambda x: x.get("content"))
df_to_check["file_path"] = df_to_check["data_file"].apply(lambda x: x.get("file_path"))
df_to_check.sort_values(["status", "snapshot_id", "sequence_number", "content", "file_path"], inplace=True)
df_to_check.drop(columns=["file_path", "content"], inplace=True)

for column in df.column_names:
for left, right in zip(lhs[column].to_list(), rhs[column].to_list()):
if column == "data_file":
for df_column in left.keys():
if df_column == "partition":
# Spark leaves out the partition if the table is unpartitioned
continue

df_lhs = left[df_column]
df_rhs = right[df_column]
if isinstance(df_rhs, dict):
# Arrow turns dicts into lists of tuples
df_lhs = dict(df_lhs)

assert df_lhs == df_rhs, f"Difference in data_file column {df_column}: {df_lhs} != {df_rhs}"
elif column == "readable_metrics":
assert list(left.keys()) == ["id", "data"]

assert left.keys() == right.keys()

for rm_column in left.keys():
rm_lhs = left[rm_column]
rm_rhs = right[rm_column]

assert rm_lhs["column_size"] == rm_rhs["column_size"]
assert rm_lhs["value_count"] == rm_rhs["value_count"]
assert rm_lhs["null_value_count"] == rm_rhs["null_value_count"]
assert rm_lhs["nan_value_count"] == rm_rhs["nan_value_count"]

if rm_column == "timestamptz":
# PySpark does not correctly set the timestamptz
rm_rhs["lower_bound"] = rm_rhs["lower_bound"].replace(tzinfo=pytz.utc)
rm_rhs["upper_bound"] = rm_rhs["upper_bound"].replace(tzinfo=pytz.utc)

assert rm_lhs["lower_bound"] == rm_rhs["lower_bound"]
assert rm_lhs["upper_bound"] == rm_rhs["upper_bound"]
else:
assert left == right, f"Difference in column {column}: {left} != {right}"

tbl.refresh()

df = tbl.inspect.all_entries()
spark_df = spark.table(f"{identifier}.all_entries")
check_pyiceberg_df_equals_spark_df(df, spark_df)