From 59fffe30204185f8f3981f2dd51047f540eaa6ef Mon Sep 17 00:00:00 2001 From: Kevin Liu Date: Sun, 5 Jan 2025 18:32:23 -0500 Subject: [PATCH] [infra] replace `pycln` with `ruff` (#1485) * pre-commit autoupdate * run ruff linter and formatter * remove pycln * ignore some rules * make lint * poetry add ruff --dev * remove ruff from dev dep * git checkout apache/main poetry.lock * add back --exit-non-zero-on-fix --- .pre-commit-config.yaml | 15 +- pyiceberg/cli/output.py | 12 +- pyiceberg/expressions/visitors.py | 16 +- pyiceberg/io/pyarrow.py | 44 +- pyiceberg/manifest.py | 56 ++- pyiceberg/schema.py | 14 +- pyiceberg/table/__init__.py | 32 +- pyiceberg/table/inspect.py | 480 ++++++++++--------- ruff.toml | 2 +- tests/avro/test_resolver.py | 50 +- tests/avro/test_writer.py | 40 +- tests/catalog/test_rest.py | 48 +- tests/catalog/test_sql.py | 34 +- tests/conftest.py | 296 ++++++------ tests/expressions/test_evaluator.py | 30 +- tests/expressions/test_visitors.py | 480 +++++++++---------- tests/integration/test_add_files.py | 104 ++-- tests/integration/test_deletes.py | 16 +- tests/integration/test_reads.py | 28 +- tests/integration/test_rest_schema.py | 20 +- tests/integration/test_writes/test_writes.py | 180 ++++--- tests/io/test_pyarrow.py | 122 +++-- tests/io/test_pyarrow_visitor.py | 352 +++++++------- tests/table/test_init.py | 114 ++--- tests/table/test_name_mapping.py | 244 +++++----- tests/test_schema.py | 24 +- tests/utils/test_manifest.py | 6 +- 27 files changed, 1535 insertions(+), 1324 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bdd1f362b5..e3dc04bde3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,26 +28,19 @@ repos: - id: check-yaml - id: check-ast - repo: https://github.com/astral-sh/ruff-pre-commit - # Ruff version (Used for linting) - rev: v0.7.4 + rev: v0.8.6 hooks: - id: ruff - args: [ --fix, --exit-non-zero-on-fix, --preview ] + args: [ --fix, --exit-non-zero-on-fix ] - id: ruff-format - args: [ --preview ] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.8.0 + rev: v1.14.1 hooks: - id: mypy args: [--install-types, --non-interactive, --config=pyproject.toml] - - repo: https://github.com/hadialqattan/pycln - rev: v2.4.0 - hooks: - - id: pycln - args: [--config=pyproject.toml] - repo: https://github.com/igorshubovych/markdownlint-cli - rev: v0.42.0 + rev: v0.43.0 hooks: - id: markdownlint args: ["--fix"] diff --git a/pyiceberg/cli/output.py b/pyiceberg/cli/output.py index a4183c32bd..0eb85841bf 100644 --- a/pyiceberg/cli/output.py +++ b/pyiceberg/cli/output.py @@ -242,8 +242,10 @@ def version(self, version: str) -> None: self._out({"version": version}) def describe_refs(self, refs: List[Tuple[str, SnapshotRefType, Dict[str, str]]]) -> None: - self._out([ - {"name": name, "type": type, detail_key: detail_val} - for name, type, detail in refs - for detail_key, detail_val in detail.items() - ]) + self._out( + [ + {"name": name, "type": type, detail_key: detail_val} + for name, type, detail in refs + for detail_key, detail_val in detail.items() + ] + ) diff --git a/pyiceberg/expressions/visitors.py b/pyiceberg/expressions/visitors.py index 26698921b5..768878b068 100644 --- a/pyiceberg/expressions/visitors.py +++ b/pyiceberg/expressions/visitors.py @@ -1228,7 +1228,7 @@ def visit_less_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: # NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more. return ROWS_MIGHT_MATCH - if lower_bound >= literal.value: + if lower_bound >= literal.value: # type: ignore[operator] return ROWS_CANNOT_MATCH return ROWS_MIGHT_MATCH @@ -1249,7 +1249,7 @@ def visit_less_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) -> b # NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more. return ROWS_MIGHT_MATCH - if lower_bound > literal.value: + if lower_bound > literal.value: # type: ignore[operator] return ROWS_CANNOT_MATCH return ROWS_MIGHT_MATCH @@ -1266,7 +1266,7 @@ def visit_greater_than(self, term: BoundTerm[L], literal: Literal[L]) -> bool: if upper_bound_bytes := self.upper_bounds.get(field_id): upper_bound = from_bytes(field.field_type, upper_bound_bytes) - if upper_bound <= literal.value: + if upper_bound <= literal.value: # type: ignore[operator] if self._is_nan(upper_bound): # NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more. return ROWS_MIGHT_MATCH @@ -1287,7 +1287,7 @@ def visit_greater_than_or_equal(self, term: BoundTerm[L], literal: Literal[L]) - if upper_bound_bytes := self.upper_bounds.get(field_id): upper_bound = from_bytes(field.field_type, upper_bound_bytes) - if upper_bound < literal.value: + if upper_bound < literal.value: # type: ignore[operator] if self._is_nan(upper_bound): # NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more. return ROWS_MIGHT_MATCH @@ -1312,7 +1312,7 @@ def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: # NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more. return ROWS_MIGHT_MATCH - if lower_bound > literal.value: + if lower_bound > literal.value: # type: ignore[operator] return ROWS_CANNOT_MATCH if upper_bound_bytes := self.upper_bounds.get(field_id): @@ -1321,7 +1321,7 @@ def visit_equal(self, term: BoundTerm[L], literal: Literal[L]) -> bool: # NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more. return ROWS_MIGHT_MATCH - if upper_bound < literal.value: + if upper_bound < literal.value: # type: ignore[operator] return ROWS_CANNOT_MATCH return ROWS_MIGHT_MATCH @@ -1349,7 +1349,7 @@ def visit_in(self, term: BoundTerm[L], literals: Set[L]) -> bool: # NaN indicates unreliable bounds. See the InclusiveMetricsEvaluator docs for more. return ROWS_MIGHT_MATCH - literals = {lit for lit in literals if lower_bound <= lit} + literals = {lit for lit in literals if lower_bound <= lit} # type: ignore[operator] if len(literals) == 0: return ROWS_CANNOT_MATCH @@ -1359,7 +1359,7 @@ def visit_in(self, term: BoundTerm[L], literals: Set[L]) -> bool: if self._is_nan(upper_bound): return ROWS_MIGHT_MATCH - literals = {lit for lit in literals if upper_bound >= lit} + literals = {lit for lit in literals if upper_bound >= lit} # type: ignore[operator] if len(literals) == 0: return ROWS_CANNOT_MATCH diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index e8c9f64d63..dc41a7d6a1 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -2449,27 +2449,31 @@ def _dataframe_to_data_files( yield from write_file( io=io, table_metadata=table_metadata, - tasks=iter([ - WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=task_schema) - for batches in bin_pack_arrow_table(df, target_file_size) - ]), + tasks=iter( + [ + WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=task_schema) + for batches in bin_pack_arrow_table(df, target_file_size) + ] + ), ) else: partitions = _determine_partitions(spec=table_metadata.spec(), schema=table_metadata.schema(), arrow_table=df) yield from write_file( io=io, table_metadata=table_metadata, - tasks=iter([ - WriteTask( - write_uuid=write_uuid, - task_id=next(counter), - record_batches=batches, - partition_key=partition.partition_key, - schema=task_schema, - ) - for partition in partitions - for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size) - ]), + tasks=iter( + [ + WriteTask( + write_uuid=write_uuid, + task_id=next(counter), + record_batches=batches, + partition_key=partition.partition_key, + schema=task_schema, + ) + for partition in partitions + for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size) + ] + ), ) @@ -2534,10 +2538,12 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T partition_columns: List[Tuple[PartitionField, NestedField]] = [ (partition_field, schema.find_field(partition_field.source_id)) for partition_field in spec.fields ] - partition_values_table = pa.table({ - str(partition.field_id): partition.transform.pyarrow_transform(field.field_type)(arrow_table[field.name]) - for partition, field in partition_columns - }) + partition_values_table = pa.table( + { + str(partition.field_id): partition.transform.pyarrow_transform(field.field_type)(arrow_table[field.name]) + for partition, field in partition_columns + } + ) # Sort by partitions sort_indices = pa.compute.sort_indices( diff --git a/pyiceberg/manifest.py b/pyiceberg/manifest.py index a56da5fc05..5a32a6330c 100644 --- a/pyiceberg/manifest.py +++ b/pyiceberg/manifest.py @@ -292,28 +292,32 @@ def __repr__(self) -> str: def data_file_with_partition(partition_type: StructType, format_version: TableVersion) -> StructType: - data_file_partition_type = StructType(*[ - NestedField( - field_id=field.field_id, - name=field.name, - field_type=field.field_type, - required=field.required, - ) - for field in partition_type.fields - ]) + data_file_partition_type = StructType( + *[ + NestedField( + field_id=field.field_id, + name=field.name, + field_type=field.field_type, + required=field.required, + ) + for field in partition_type.fields + ] + ) - return StructType(*[ - NestedField( - field_id=102, - name="partition", - field_type=data_file_partition_type, - required=True, - doc="Partition data tuple, schema based on the partition spec", - ) - if field.field_id == 102 - else field - for field in DATA_FILE_TYPE[format_version].fields - ]) + return StructType( + *[ + NestedField( + field_id=102, + name="partition", + field_type=data_file_partition_type, + required=True, + doc="Partition data tuple, schema based on the partition spec", + ) + if field.field_id == 102 + else field + for field in DATA_FILE_TYPE[format_version].fields + ] + ) class DataFile(Record): @@ -398,10 +402,12 @@ def __eq__(self, other: Any) -> bool: def manifest_entry_schema_with_data_file(format_version: TableVersion, data_file: StructType) -> Schema: - return Schema(*[ - NestedField(2, "data_file", data_file, required=True) if field.field_id == 2 else field - for field in MANIFEST_ENTRY_SCHEMAS[format_version].fields - ]) + return Schema( + *[ + NestedField(2, "data_file", data_file, required=True) if field.field_id == 2 else field + for field in MANIFEST_ENTRY_SCHEMAS[format_version].fields + ] + ) class ManifestEntry(Record): diff --git a/pyiceberg/schema.py b/pyiceberg/schema.py index cfe3fe3a7b..5a373cb15f 100644 --- a/pyiceberg/schema.py +++ b/pyiceberg/schema.py @@ -1707,12 +1707,14 @@ def list(self, list_type: ListType, element_result: Callable[[], bool]) -> bool: return self._is_field_compatible(list_type.element_field) and element_result() def map(self, map_type: MapType, key_result: Callable[[], bool], value_result: Callable[[], bool]) -> bool: - return all([ - self._is_field_compatible(map_type.key_field), - self._is_field_compatible(map_type.value_field), - key_result(), - value_result(), - ]) + return all( + [ + self._is_field_compatible(map_type.key_field), + self._is_field_compatible(map_type.value_field), + key_result(), + value_result(), + ] + ) def primitive(self, primitive: PrimitiveType) -> bool: return True diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 2469a9ed7b..7bc3fe838b 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -629,18 +629,20 @@ def delete( if len(filtered_df) == 0: replaced_files.append((original_file.file, [])) elif len(df) != len(filtered_df): - replaced_files.append(( - original_file.file, - list( - _dataframe_to_data_files( - io=self._table.io, - df=filtered_df, - table_metadata=self.table_metadata, - write_uuid=commit_uuid, - counter=counter, - ) - ), - )) + replaced_files.append( + ( + original_file.file, + list( + _dataframe_to_data_files( + io=self._table.io, + df=filtered_df, + table_metadata=self.table_metadata, + write_uuid=commit_uuid, + counter=counter, + ) + ), + ) + ) if len(replaced_files) > 0: with self.update_snapshot(snapshot_properties=snapshot_properties).overwrite() as overwrite_snapshot: @@ -680,9 +682,9 @@ def add_files( raise ValueError(f"Cannot add files that are already referenced by table, files: {', '.join(referenced_files)}") if self.table_metadata.name_mapping() is None: - self.set_properties(**{ - TableProperties.DEFAULT_NAME_MAPPING: self.table_metadata.schema().name_mapping.model_dump_json() - }) + self.set_properties( + **{TableProperties.DEFAULT_NAME_MAPPING: self.table_metadata.schema().name_mapping.model_dump_json()} + ) with self.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as update_snapshot: data_files = _parquet_files_to_data_files( table_metadata=self.table_metadata, file_paths=file_paths, io=self._table.io diff --git a/pyiceberg/table/inspect.py b/pyiceberg/table/inspect.py index beee426533..71d38a2279 100644 --- a/pyiceberg/table/inspect.py +++ b/pyiceberg/table/inspect.py @@ -58,14 +58,16 @@ def _get_snapshot(self, snapshot_id: Optional[int] = None) -> Snapshot: def snapshots(self) -> "pa.Table": import pyarrow as pa - snapshots_schema = pa.schema([ - pa.field("committed_at", pa.timestamp(unit="ms"), nullable=False), - pa.field("snapshot_id", pa.int64(), nullable=False), - pa.field("parent_id", pa.int64(), nullable=True), - pa.field("operation", pa.string(), nullable=True), - pa.field("manifest_list", pa.string(), nullable=False), - pa.field("summary", pa.map_(pa.string(), pa.string()), nullable=True), - ]) + snapshots_schema = pa.schema( + [ + pa.field("committed_at", pa.timestamp(unit="ms"), nullable=False), + pa.field("snapshot_id", pa.int64(), nullable=False), + pa.field("parent_id", pa.int64(), nullable=True), + pa.field("operation", pa.string(), nullable=True), + pa.field("manifest_list", pa.string(), nullable=False), + pa.field("summary", pa.map_(pa.string(), pa.string()), nullable=True), + ] + ) snapshots = [] for snapshot in self.tbl.metadata.snapshots: if summary := snapshot.summary: @@ -75,14 +77,16 @@ def snapshots(self) -> "pa.Table": operation = None additional_properties = None - snapshots.append({ - "committed_at": datetime.fromtimestamp(snapshot.timestamp_ms / 1000.0, tz=timezone.utc), - "snapshot_id": snapshot.snapshot_id, - "parent_id": snapshot.parent_snapshot_id, - "operation": str(operation), - "manifest_list": snapshot.manifest_list, - "summary": additional_properties, - }) + snapshots.append( + { + "committed_at": datetime.fromtimestamp(snapshot.timestamp_ms / 1000.0, tz=timezone.utc), + "snapshot_id": snapshot.snapshot_id, + "parent_id": snapshot.parent_snapshot_id, + "operation": str(operation), + "manifest_list": snapshot.manifest_list, + "summary": additional_properties, + } + ) return pa.Table.from_pylist( snapshots, @@ -100,14 +104,16 @@ def entries(self, snapshot_id: Optional[int] = None) -> "pa.Table": def _readable_metrics_struct(bound_type: PrimitiveType) -> pa.StructType: pa_bound_type = schema_to_pyarrow(bound_type) - return pa.struct([ - pa.field("column_size", pa.int64(), nullable=True), - pa.field("value_count", pa.int64(), nullable=True), - pa.field("null_value_count", pa.int64(), nullable=True), - pa.field("nan_value_count", pa.int64(), nullable=True), - pa.field("lower_bound", pa_bound_type, nullable=True), - pa.field("upper_bound", pa_bound_type, nullable=True), - ]) + return pa.struct( + [ + pa.field("column_size", pa.int64(), nullable=True), + pa.field("value_count", pa.int64(), nullable=True), + pa.field("null_value_count", pa.int64(), nullable=True), + pa.field("nan_value_count", pa.int64(), nullable=True), + pa.field("lower_bound", pa_bound_type, nullable=True), + pa.field("upper_bound", pa_bound_type, nullable=True), + ] + ) for field in self.tbl.metadata.schema().fields: readable_metrics_struct.append( @@ -117,35 +123,39 @@ def _readable_metrics_struct(bound_type: PrimitiveType) -> pa.StructType: partition_record = self.tbl.metadata.specs_struct() pa_record_struct = schema_to_pyarrow(partition_record) - entries_schema = pa.schema([ - pa.field("status", pa.int8(), nullable=False), - pa.field("snapshot_id", pa.int64(), nullable=False), - pa.field("sequence_number", pa.int64(), nullable=False), - pa.field("file_sequence_number", pa.int64(), nullable=False), - pa.field( - "data_file", - pa.struct([ - pa.field("content", pa.int8(), nullable=False), - pa.field("file_path", pa.string(), nullable=False), - pa.field("file_format", pa.string(), nullable=False), - pa.field("partition", pa_record_struct, nullable=False), - pa.field("record_count", pa.int64(), nullable=False), - pa.field("file_size_in_bytes", pa.int64(), nullable=False), - pa.field("column_sizes", pa.map_(pa.int32(), pa.int64()), nullable=True), - pa.field("value_counts", pa.map_(pa.int32(), pa.int64()), nullable=True), - pa.field("null_value_counts", pa.map_(pa.int32(), pa.int64()), nullable=True), - pa.field("nan_value_counts", pa.map_(pa.int32(), pa.int64()), nullable=True), - pa.field("lower_bounds", pa.map_(pa.int32(), pa.binary()), nullable=True), - pa.field("upper_bounds", pa.map_(pa.int32(), pa.binary()), nullable=True), - pa.field("key_metadata", pa.binary(), nullable=True), - pa.field("split_offsets", pa.list_(pa.int64()), nullable=True), - pa.field("equality_ids", pa.list_(pa.int32()), nullable=True), - pa.field("sort_order_id", pa.int32(), nullable=True), - ]), - nullable=False, - ), - pa.field("readable_metrics", pa.struct(readable_metrics_struct), nullable=True), - ]) + entries_schema = pa.schema( + [ + pa.field("status", pa.int8(), nullable=False), + pa.field("snapshot_id", pa.int64(), nullable=False), + pa.field("sequence_number", pa.int64(), nullable=False), + pa.field("file_sequence_number", pa.int64(), nullable=False), + pa.field( + "data_file", + pa.struct( + [ + pa.field("content", pa.int8(), nullable=False), + pa.field("file_path", pa.string(), nullable=False), + pa.field("file_format", pa.string(), nullable=False), + pa.field("partition", pa_record_struct, nullable=False), + pa.field("record_count", pa.int64(), nullable=False), + pa.field("file_size_in_bytes", pa.int64(), nullable=False), + pa.field("column_sizes", pa.map_(pa.int32(), pa.int64()), nullable=True), + pa.field("value_counts", pa.map_(pa.int32(), pa.int64()), nullable=True), + pa.field("null_value_counts", pa.map_(pa.int32(), pa.int64()), nullable=True), + pa.field("nan_value_counts", pa.map_(pa.int32(), pa.int64()), nullable=True), + pa.field("lower_bounds", pa.map_(pa.int32(), pa.binary()), nullable=True), + pa.field("upper_bounds", pa.map_(pa.int32(), pa.binary()), nullable=True), + pa.field("key_metadata", pa.binary(), nullable=True), + pa.field("split_offsets", pa.list_(pa.int64()), nullable=True), + pa.field("equality_ids", pa.list_(pa.int32()), nullable=True), + pa.field("sort_order_id", pa.int32(), nullable=True), + ] + ), + nullable=False, + ), + pa.field("readable_metrics", pa.struct(readable_metrics_struct), nullable=True), + ] + ) entries = [] snapshot = self._get_snapshot(snapshot_id) @@ -180,32 +190,34 @@ def _readable_metrics_struct(bound_type: PrimitiveType) -> pa.StructType: for pos, field in enumerate(self.tbl.metadata.specs()[manifest.partition_spec_id].fields) } - entries.append({ - "status": entry.status.value, - "snapshot_id": entry.snapshot_id, - "sequence_number": entry.sequence_number, - "file_sequence_number": entry.file_sequence_number, - "data_file": { - "content": entry.data_file.content, - "file_path": entry.data_file.file_path, - "file_format": entry.data_file.file_format, - "partition": partition_record_dict, - "record_count": entry.data_file.record_count, - "file_size_in_bytes": entry.data_file.file_size_in_bytes, - "column_sizes": dict(entry.data_file.column_sizes), - "value_counts": dict(entry.data_file.value_counts), - "null_value_counts": dict(entry.data_file.null_value_counts), - "nan_value_counts": entry.data_file.nan_value_counts, - "lower_bounds": entry.data_file.lower_bounds, - "upper_bounds": entry.data_file.upper_bounds, - "key_metadata": entry.data_file.key_metadata, - "split_offsets": entry.data_file.split_offsets, - "equality_ids": entry.data_file.equality_ids, - "sort_order_id": entry.data_file.sort_order_id, - "spec_id": entry.data_file.spec_id, - }, - "readable_metrics": readable_metrics, - }) + entries.append( + { + "status": entry.status.value, + "snapshot_id": entry.snapshot_id, + "sequence_number": entry.sequence_number, + "file_sequence_number": entry.file_sequence_number, + "data_file": { + "content": entry.data_file.content, + "file_path": entry.data_file.file_path, + "file_format": entry.data_file.file_format, + "partition": partition_record_dict, + "record_count": entry.data_file.record_count, + "file_size_in_bytes": entry.data_file.file_size_in_bytes, + "column_sizes": dict(entry.data_file.column_sizes), + "value_counts": dict(entry.data_file.value_counts), + "null_value_counts": dict(entry.data_file.null_value_counts), + "nan_value_counts": entry.data_file.nan_value_counts, + "lower_bounds": entry.data_file.lower_bounds, + "upper_bounds": entry.data_file.upper_bounds, + "key_metadata": entry.data_file.key_metadata, + "split_offsets": entry.data_file.split_offsets, + "equality_ids": entry.data_file.equality_ids, + "sort_order_id": entry.data_file.sort_order_id, + "spec_id": entry.data_file.spec_id, + }, + "readable_metrics": readable_metrics, + } + ) return pa.Table.from_pylist( entries, @@ -215,26 +227,30 @@ def _readable_metrics_struct(bound_type: PrimitiveType) -> pa.StructType: def refs(self) -> "pa.Table": import pyarrow as pa - ref_schema = pa.schema([ - pa.field("name", pa.string(), nullable=False), - pa.field("type", pa.dictionary(pa.int32(), pa.string()), nullable=False), - pa.field("snapshot_id", pa.int64(), nullable=False), - pa.field("max_reference_age_in_ms", pa.int64(), nullable=True), - pa.field("min_snapshots_to_keep", pa.int32(), nullable=True), - pa.field("max_snapshot_age_in_ms", pa.int64(), nullable=True), - ]) + ref_schema = pa.schema( + [ + pa.field("name", pa.string(), nullable=False), + pa.field("type", pa.dictionary(pa.int32(), pa.string()), nullable=False), + pa.field("snapshot_id", pa.int64(), nullable=False), + pa.field("max_reference_age_in_ms", pa.int64(), nullable=True), + pa.field("min_snapshots_to_keep", pa.int32(), nullable=True), + pa.field("max_snapshot_age_in_ms", pa.int64(), nullable=True), + ] + ) ref_results = [] for ref in self.tbl.metadata.refs: if snapshot_ref := self.tbl.metadata.refs.get(ref): - ref_results.append({ - "name": ref, - "type": snapshot_ref.snapshot_ref_type.upper(), - "snapshot_id": snapshot_ref.snapshot_id, - "max_reference_age_in_ms": snapshot_ref.max_ref_age_ms, - "min_snapshots_to_keep": snapshot_ref.min_snapshots_to_keep, - "max_snapshot_age_in_ms": snapshot_ref.max_snapshot_age_ms, - }) + ref_results.append( + { + "name": ref, + "type": snapshot_ref.snapshot_ref_type.upper(), + "snapshot_id": snapshot_ref.snapshot_id, + "max_reference_age_in_ms": snapshot_ref.max_ref_age_ms, + "min_snapshots_to_keep": snapshot_ref.min_snapshots_to_keep, + "max_snapshot_age_in_ms": snapshot_ref.max_snapshot_age_ms, + } + ) return pa.Table.from_pylist(ref_results, schema=ref_schema) @@ -243,27 +259,31 @@ def partitions(self, snapshot_id: Optional[int] = None) -> "pa.Table": from pyiceberg.io.pyarrow import schema_to_pyarrow - table_schema = pa.schema([ - pa.field("record_count", pa.int64(), nullable=False), - pa.field("file_count", pa.int32(), nullable=False), - pa.field("total_data_file_size_in_bytes", pa.int64(), nullable=False), - pa.field("position_delete_record_count", pa.int64(), nullable=False), - pa.field("position_delete_file_count", pa.int32(), nullable=False), - pa.field("equality_delete_record_count", pa.int64(), nullable=False), - pa.field("equality_delete_file_count", pa.int32(), nullable=False), - pa.field("last_updated_at", pa.timestamp(unit="ms"), nullable=True), - pa.field("last_updated_snapshot_id", pa.int64(), nullable=True), - ]) + table_schema = pa.schema( + [ + pa.field("record_count", pa.int64(), nullable=False), + pa.field("file_count", pa.int32(), nullable=False), + pa.field("total_data_file_size_in_bytes", pa.int64(), nullable=False), + pa.field("position_delete_record_count", pa.int64(), nullable=False), + pa.field("position_delete_file_count", pa.int32(), nullable=False), + pa.field("equality_delete_record_count", pa.int64(), nullable=False), + pa.field("equality_delete_file_count", pa.int32(), nullable=False), + pa.field("last_updated_at", pa.timestamp(unit="ms"), nullable=True), + pa.field("last_updated_snapshot_id", pa.int64(), nullable=True), + ] + ) partition_record = self.tbl.metadata.specs_struct() has_partitions = len(partition_record.fields) > 0 if has_partitions: pa_record_struct = schema_to_pyarrow(partition_record) - partitions_schema = pa.schema([ - pa.field("partition", pa_record_struct, nullable=False), - pa.field("spec_id", pa.int32(), nullable=False), - ]) + partitions_schema = pa.schema( + [ + pa.field("partition", pa_record_struct, nullable=False), + pa.field("spec_id", pa.int32(), nullable=False), + ] + ) table_schema = pa.unify_schemas([partitions_schema, table_schema]) @@ -329,27 +349,31 @@ def update_partitions_map( def manifests(self) -> "pa.Table": import pyarrow as pa - partition_summary_schema = pa.struct([ - pa.field("contains_null", pa.bool_(), nullable=False), - pa.field("contains_nan", pa.bool_(), nullable=True), - pa.field("lower_bound", pa.string(), nullable=True), - pa.field("upper_bound", pa.string(), nullable=True), - ]) - - manifest_schema = pa.schema([ - pa.field("content", pa.int8(), nullable=False), - pa.field("path", pa.string(), nullable=False), - pa.field("length", pa.int64(), nullable=False), - pa.field("partition_spec_id", pa.int32(), nullable=False), - pa.field("added_snapshot_id", pa.int64(), nullable=False), - pa.field("added_data_files_count", pa.int32(), nullable=False), - pa.field("existing_data_files_count", pa.int32(), nullable=False), - pa.field("deleted_data_files_count", pa.int32(), nullable=False), - pa.field("added_delete_files_count", pa.int32(), nullable=False), - pa.field("existing_delete_files_count", pa.int32(), nullable=False), - pa.field("deleted_delete_files_count", pa.int32(), nullable=False), - pa.field("partition_summaries", pa.list_(partition_summary_schema), nullable=False), - ]) + partition_summary_schema = pa.struct( + [ + pa.field("contains_null", pa.bool_(), nullable=False), + pa.field("contains_nan", pa.bool_(), nullable=True), + pa.field("lower_bound", pa.string(), nullable=True), + pa.field("upper_bound", pa.string(), nullable=True), + ] + ) + + manifest_schema = pa.schema( + [ + pa.field("content", pa.int8(), nullable=False), + pa.field("path", pa.string(), nullable=False), + pa.field("length", pa.int64(), nullable=False), + pa.field("partition_spec_id", pa.int32(), nullable=False), + pa.field("added_snapshot_id", pa.int64(), nullable=False), + pa.field("added_data_files_count", pa.int32(), nullable=False), + pa.field("existing_data_files_count", pa.int32(), nullable=False), + pa.field("deleted_data_files_count", pa.int32(), nullable=False), + pa.field("added_delete_files_count", pa.int32(), nullable=False), + pa.field("existing_delete_files_count", pa.int32(), nullable=False), + pa.field("deleted_delete_files_count", pa.int32(), nullable=False), + pa.field("partition_summaries", pa.list_(partition_summary_schema), nullable=False), + ] + ) def _partition_summaries_to_rows( spec: PartitionSpec, partition_summaries: List[PartitionFieldSummary] @@ -376,12 +400,14 @@ def _partition_summaries_to_rows( if field_summary.upper_bound else None ) - rows.append({ - "contains_null": field_summary.contains_null, - "contains_nan": field_summary.contains_nan, - "lower_bound": lower_bound, - "upper_bound": upper_bound, - }) + rows.append( + { + "contains_null": field_summary.contains_null, + "contains_nan": field_summary.contains_nan, + "lower_bound": lower_bound, + "upper_bound": upper_bound, + } + ) return rows specs = self.tbl.metadata.specs() @@ -390,22 +416,26 @@ def _partition_summaries_to_rows( for manifest in snapshot.manifests(self.tbl.io): is_data_file = manifest.content == ManifestContent.DATA is_delete_file = manifest.content == ManifestContent.DELETES - manifests.append({ - "content": manifest.content, - "path": manifest.manifest_path, - "length": manifest.manifest_length, - "partition_spec_id": manifest.partition_spec_id, - "added_snapshot_id": manifest.added_snapshot_id, - "added_data_files_count": manifest.added_files_count if is_data_file else 0, - "existing_data_files_count": manifest.existing_files_count if is_data_file else 0, - "deleted_data_files_count": manifest.deleted_files_count if is_data_file else 0, - "added_delete_files_count": manifest.added_files_count if is_delete_file else 0, - "existing_delete_files_count": manifest.existing_files_count if is_delete_file else 0, - "deleted_delete_files_count": manifest.deleted_files_count if is_delete_file else 0, - "partition_summaries": _partition_summaries_to_rows(specs[manifest.partition_spec_id], manifest.partitions) - if manifest.partitions - else [], - }) + manifests.append( + { + "content": manifest.content, + "path": manifest.manifest_path, + "length": manifest.manifest_length, + "partition_spec_id": manifest.partition_spec_id, + "added_snapshot_id": manifest.added_snapshot_id, + "added_data_files_count": manifest.added_files_count if is_data_file else 0, + "existing_data_files_count": manifest.existing_files_count if is_data_file else 0, + "deleted_data_files_count": manifest.deleted_files_count if is_data_file else 0, + "added_delete_files_count": manifest.added_files_count if is_delete_file else 0, + "existing_delete_files_count": manifest.existing_files_count if is_delete_file else 0, + "deleted_delete_files_count": manifest.deleted_files_count if is_delete_file else 0, + "partition_summaries": _partition_summaries_to_rows( + specs[manifest.partition_spec_id], manifest.partitions + ) + if manifest.partitions + else [], + } + ) return pa.Table.from_pylist( manifests, @@ -417,13 +447,15 @@ def metadata_log_entries(self) -> "pa.Table": from pyiceberg.table.snapshots import MetadataLogEntry - table_schema = pa.schema([ - pa.field("timestamp", pa.timestamp(unit="ms"), nullable=False), - pa.field("file", pa.string(), nullable=False), - pa.field("latest_snapshot_id", pa.int64(), nullable=True), - pa.field("latest_schema_id", pa.int32(), nullable=True), - pa.field("latest_sequence_number", pa.int64(), nullable=True), - ]) + table_schema = pa.schema( + [ + pa.field("timestamp", pa.timestamp(unit="ms"), nullable=False), + pa.field("file", pa.string(), nullable=False), + pa.field("latest_snapshot_id", pa.int64(), nullable=True), + pa.field("latest_schema_id", pa.int32(), nullable=True), + pa.field("latest_sequence_number", pa.int64(), nullable=True), + ] + ) def metadata_log_entry_to_row(metadata_entry: MetadataLogEntry) -> Dict[str, Any]: latest_snapshot = self.tbl.snapshot_as_of_timestamp(metadata_entry.timestamp_ms) @@ -449,12 +481,14 @@ def metadata_log_entry_to_row(metadata_entry: MetadataLogEntry) -> Dict[str, Any def history(self) -> "pa.Table": import pyarrow as pa - history_schema = pa.schema([ - pa.field("made_current_at", pa.timestamp(unit="ms"), nullable=False), - pa.field("snapshot_id", pa.int64(), nullable=False), - pa.field("parent_id", pa.int64(), nullable=True), - pa.field("is_current_ancestor", pa.bool_(), nullable=False), - ]) + history_schema = pa.schema( + [ + pa.field("made_current_at", pa.timestamp(unit="ms"), nullable=False), + pa.field("snapshot_id", pa.int64(), nullable=False), + pa.field("parent_id", pa.int64(), nullable=True), + pa.field("is_current_ancestor", pa.bool_(), nullable=False), + ] + ) ancestors_ids = {snapshot.snapshot_id for snapshot in ancestors_of(self.tbl.current_snapshot(), self.tbl.metadata)} @@ -464,12 +498,14 @@ def history(self) -> "pa.Table": for snapshot_entry in metadata.snapshot_log: snapshot = metadata.snapshot_by_id(snapshot_entry.snapshot_id) - history.append({ - "made_current_at": datetime.fromtimestamp(snapshot_entry.timestamp_ms / 1000.0, tz=timezone.utc), - "snapshot_id": snapshot_entry.snapshot_id, - "parent_id": snapshot.parent_snapshot_id if snapshot else None, - "is_current_ancestor": snapshot_entry.snapshot_id in ancestors_ids, - }) + history.append( + { + "made_current_at": datetime.fromtimestamp(snapshot_entry.timestamp_ms / 1000.0, tz=timezone.utc), + "snapshot_id": snapshot_entry.snapshot_id, + "parent_id": snapshot.parent_snapshot_id if snapshot else None, + "is_current_ancestor": snapshot_entry.snapshot_id in ancestors_ids, + } + ) return pa.Table.from_pylist(history, schema=history_schema) @@ -483,39 +519,43 @@ def _files(self, snapshot_id: Optional[int] = None, data_file_filter: Optional[S def _readable_metrics_struct(bound_type: PrimitiveType) -> pa.StructType: pa_bound_type = schema_to_pyarrow(bound_type) - return pa.struct([ - pa.field("column_size", pa.int64(), nullable=True), - pa.field("value_count", pa.int64(), nullable=True), - pa.field("null_value_count", pa.int64(), nullable=True), - pa.field("nan_value_count", pa.int64(), nullable=True), - pa.field("lower_bound", pa_bound_type, nullable=True), - pa.field("upper_bound", pa_bound_type, nullable=True), - ]) + return pa.struct( + [ + pa.field("column_size", pa.int64(), nullable=True), + pa.field("value_count", pa.int64(), nullable=True), + pa.field("null_value_count", pa.int64(), nullable=True), + pa.field("nan_value_count", pa.int64(), nullable=True), + pa.field("lower_bound", pa_bound_type, nullable=True), + pa.field("upper_bound", pa_bound_type, nullable=True), + ] + ) for field in self.tbl.metadata.schema().fields: readable_metrics_struct.append( pa.field(schema.find_column_name(field.field_id), _readable_metrics_struct(field.field_type), nullable=False) ) - files_schema = pa.schema([ - pa.field("content", pa.int8(), nullable=False), - pa.field("file_path", pa.string(), nullable=False), - pa.field("file_format", pa.dictionary(pa.int32(), pa.string()), nullable=False), - pa.field("spec_id", pa.int32(), nullable=False), - pa.field("record_count", pa.int64(), nullable=False), - pa.field("file_size_in_bytes", pa.int64(), nullable=False), - pa.field("column_sizes", pa.map_(pa.int32(), pa.int64()), nullable=True), - pa.field("value_counts", pa.map_(pa.int32(), pa.int64()), nullable=True), - pa.field("null_value_counts", pa.map_(pa.int32(), pa.int64()), nullable=True), - pa.field("nan_value_counts", pa.map_(pa.int32(), pa.int64()), nullable=True), - pa.field("lower_bounds", pa.map_(pa.int32(), pa.binary()), nullable=True), - pa.field("upper_bounds", pa.map_(pa.int32(), pa.binary()), nullable=True), - pa.field("key_metadata", pa.binary(), nullable=True), - pa.field("split_offsets", pa.list_(pa.int64()), nullable=True), - pa.field("equality_ids", pa.list_(pa.int32()), nullable=True), - pa.field("sort_order_id", pa.int32(), nullable=True), - pa.field("readable_metrics", pa.struct(readable_metrics_struct), nullable=True), - ]) + files_schema = pa.schema( + [ + pa.field("content", pa.int8(), nullable=False), + pa.field("file_path", pa.string(), nullable=False), + pa.field("file_format", pa.dictionary(pa.int32(), pa.string()), nullable=False), + pa.field("spec_id", pa.int32(), nullable=False), + pa.field("record_count", pa.int64(), nullable=False), + pa.field("file_size_in_bytes", pa.int64(), nullable=False), + pa.field("column_sizes", pa.map_(pa.int32(), pa.int64()), nullable=True), + pa.field("value_counts", pa.map_(pa.int32(), pa.int64()), nullable=True), + pa.field("null_value_counts", pa.map_(pa.int32(), pa.int64()), nullable=True), + pa.field("nan_value_counts", pa.map_(pa.int32(), pa.int64()), nullable=True), + pa.field("lower_bounds", pa.map_(pa.int32(), pa.binary()), nullable=True), + pa.field("upper_bounds", pa.map_(pa.int32(), pa.binary()), nullable=True), + pa.field("key_metadata", pa.binary(), nullable=True), + pa.field("split_offsets", pa.list_(pa.int64()), nullable=True), + pa.field("equality_ids", pa.list_(pa.int32()), nullable=True), + pa.field("sort_order_id", pa.int32(), nullable=True), + pa.field("readable_metrics", pa.struct(readable_metrics_struct), nullable=True), + ] + ) files: list[dict[str, Any]] = [] @@ -553,25 +593,29 @@ def _readable_metrics_struct(bound_type: PrimitiveType) -> pa.StructType: } for field in self.tbl.metadata.schema().fields } - files.append({ - "content": data_file.content, - "file_path": data_file.file_path, - "file_format": data_file.file_format, - "spec_id": data_file.spec_id, - "record_count": data_file.record_count, - "file_size_in_bytes": data_file.file_size_in_bytes, - "column_sizes": dict(data_file.column_sizes) if data_file.column_sizes is not None else None, - "value_counts": dict(data_file.value_counts) if data_file.value_counts is not None else None, - "null_value_counts": dict(data_file.null_value_counts) if data_file.null_value_counts is not None else None, - "nan_value_counts": dict(data_file.nan_value_counts) if data_file.nan_value_counts is not None else None, - "lower_bounds": dict(data_file.lower_bounds) if data_file.lower_bounds is not None else None, - "upper_bounds": dict(data_file.upper_bounds) if data_file.upper_bounds is not None else None, - "key_metadata": data_file.key_metadata, - "split_offsets": data_file.split_offsets, - "equality_ids": data_file.equality_ids, - "sort_order_id": data_file.sort_order_id, - "readable_metrics": readable_metrics, - }) + files.append( + { + "content": data_file.content, + "file_path": data_file.file_path, + "file_format": data_file.file_format, + "spec_id": data_file.spec_id, + "record_count": data_file.record_count, + "file_size_in_bytes": data_file.file_size_in_bytes, + "column_sizes": dict(data_file.column_sizes) if data_file.column_sizes is not None else None, + "value_counts": dict(data_file.value_counts) if data_file.value_counts is not None else None, + "null_value_counts": dict(data_file.null_value_counts) + if data_file.null_value_counts is not None + else None, + "nan_value_counts": dict(data_file.nan_value_counts) if data_file.nan_value_counts is not None else None, + "lower_bounds": dict(data_file.lower_bounds) if data_file.lower_bounds is not None else None, + "upper_bounds": dict(data_file.upper_bounds) if data_file.upper_bounds is not None else None, + "key_metadata": data_file.key_metadata, + "split_offsets": data_file.split_offsets, + "equality_ids": data_file.equality_ids, + "sort_order_id": data_file.sort_order_id, + "readable_metrics": readable_metrics, + } + ) return pa.Table.from_pylist( files, diff --git a/ruff.toml b/ruff.toml index caaa108c84..11fd2a957b 100644 --- a/ruff.toml +++ b/ruff.toml @@ -58,7 +58,7 @@ select = [ "I", # isort "UP", # pyupgrade ] -ignore = ["E501","E203","B024","B028","UP037"] +ignore = ["E501","E203","B024","B028","UP037", "UP035", "UP006"] # Allow autofix for all enabled rules (when `--fix`) is provided. fixable = ["ALL"] diff --git a/tests/avro/test_resolver.py b/tests/avro/test_resolver.py index decd9060a4..b5388b5ebb 100644 --- a/tests/avro/test_resolver.py +++ b/tests/avro/test_resolver.py @@ -322,30 +322,34 @@ def test_resolver_initial_value() -> None: def test_resolve_writer() -> None: actual = resolve_writer(record_schema=MANIFEST_ENTRY_SCHEMAS[2], file_schema=MANIFEST_ENTRY_SCHEMAS[1]) - expected = StructWriter(( - (0, IntegerWriter()), - (1, IntegerWriter()), + expected = StructWriter( ( - 4, - StructWriter(( - (1, StringWriter()), - (2, StringWriter()), - (3, StructWriter(())), - (4, IntegerWriter()), - (5, IntegerWriter()), - (None, DefaultWriter(writer=IntegerWriter(), value=67108864)), - (6, OptionWriter(option=MapWriter(key_writer=IntegerWriter(), value_writer=IntegerWriter()))), - (7, OptionWriter(option=MapWriter(key_writer=IntegerWriter(), value_writer=IntegerWriter()))), - (8, OptionWriter(option=MapWriter(key_writer=IntegerWriter(), value_writer=IntegerWriter()))), - (9, OptionWriter(option=MapWriter(key_writer=IntegerWriter(), value_writer=IntegerWriter()))), - (10, OptionWriter(option=MapWriter(key_writer=IntegerWriter(), value_writer=BinaryWriter()))), - (11, OptionWriter(option=MapWriter(key_writer=IntegerWriter(), value_writer=BinaryWriter()))), - (12, OptionWriter(option=BinaryWriter())), - (13, OptionWriter(option=ListWriter(element_writer=IntegerWriter()))), - (15, OptionWriter(option=IntegerWriter())), - )), - ), - )) + (0, IntegerWriter()), + (1, IntegerWriter()), + ( + 4, + StructWriter( + ( + (1, StringWriter()), + (2, StringWriter()), + (3, StructWriter(())), + (4, IntegerWriter()), + (5, IntegerWriter()), + (None, DefaultWriter(writer=IntegerWriter(), value=67108864)), + (6, OptionWriter(option=MapWriter(key_writer=IntegerWriter(), value_writer=IntegerWriter()))), + (7, OptionWriter(option=MapWriter(key_writer=IntegerWriter(), value_writer=IntegerWriter()))), + (8, OptionWriter(option=MapWriter(key_writer=IntegerWriter(), value_writer=IntegerWriter()))), + (9, OptionWriter(option=MapWriter(key_writer=IntegerWriter(), value_writer=IntegerWriter()))), + (10, OptionWriter(option=MapWriter(key_writer=IntegerWriter(), value_writer=BinaryWriter()))), + (11, OptionWriter(option=MapWriter(key_writer=IntegerWriter(), value_writer=BinaryWriter()))), + (12, OptionWriter(option=BinaryWriter())), + (13, OptionWriter(option=ListWriter(element_writer=IntegerWriter()))), + (15, OptionWriter(option=IntegerWriter())), + ) + ), + ), + ) + ) assert actual == expected diff --git a/tests/avro/test_writer.py b/tests/avro/test_writer.py index 5a531c7748..39b8ecc393 100644 --- a/tests/avro/test_writer.py +++ b/tests/avro/test_writer.py @@ -178,15 +178,17 @@ class MyStruct(Record): construct_writer(schema).write(encoder, my_struct) - assert output.getbuffer() == b"".join([ - b"\x18", - zigzag_encode(len(my_struct.properties)), - zigzag_encode(1), - zigzag_encode(2), - zigzag_encode(3), - zigzag_encode(4), - b"\x00", - ]) + assert output.getbuffer() == b"".join( + [ + b"\x18", + zigzag_encode(len(my_struct.properties)), + zigzag_encode(1), + zigzag_encode(2), + zigzag_encode(3), + zigzag_encode(4), + b"\x00", + ] + ) def test_write_struct_with_list() -> None: @@ -206,15 +208,17 @@ class MyStruct(Record): construct_writer(schema).write(encoder, my_struct) - assert output.getbuffer() == b"".join([ - b"\x18", - zigzag_encode(len(my_struct.properties)), - zigzag_encode(1), - zigzag_encode(2), - zigzag_encode(3), - zigzag_encode(4), - b"\x00", - ]) + assert output.getbuffer() == b"".join( + [ + b"\x18", + zigzag_encode(len(my_struct.properties)), + zigzag_encode(1), + zigzag_encode(2), + zigzag_encode(3), + zigzag_encode(4), + b"\x00", + ] + ) def test_write_decimal() -> None: diff --git a/tests/catalog/test_rest.py b/tests/catalog/test_rest.py index 2a4b3a7a1f..21aa9677bd 100644 --- a/tests/catalog/test_rest.py +++ b/tests/catalog/test_rest.py @@ -323,19 +323,19 @@ def test_properties_sets_headers(requests_mock: Mocker) -> None: **{"header.Content-Type": "application/vnd.api+json", "header.Customized-Header": "some/value"}, ) - assert catalog._session.headers.get("Content-type") == "application/json", ( - "Expected 'Content-Type' default header not to be overwritten" - ) - assert requests_mock.last_request.headers["Content-type"] == "application/json", ( - "Config request did not include expected 'Content-Type' header" - ) + assert ( + catalog._session.headers.get("Content-type") == "application/json" + ), "Expected 'Content-Type' default header not to be overwritten" + assert ( + requests_mock.last_request.headers["Content-type"] == "application/json" + ), "Config request did not include expected 'Content-Type' header" - assert catalog._session.headers.get("Customized-Header") == "some/value", ( - "Expected 'Customized-Header' header to be 'some/value'" - ) - assert requests_mock.last_request.headers["Customized-Header"] == "some/value", ( - "Config request did not include expected 'Customized-Header' header" - ) + assert ( + catalog._session.headers.get("Customized-Header") == "some/value" + ), "Expected 'Customized-Header' header to be 'some/value'" + assert ( + requests_mock.last_request.headers["Customized-Header"] == "some/value" + ), "Config request did not include expected 'Customized-Header' header" def test_config_sets_headers(requests_mock: Mocker) -> None: @@ -352,19 +352,19 @@ def test_config_sets_headers(requests_mock: Mocker) -> None: catalog = RestCatalog("rest", uri=TEST_URI, warehouse="s3://some-bucket") catalog.create_namespace(namespace) - assert catalog._session.headers.get("Content-type") == "application/json", ( - "Expected 'Content-Type' default header not to be overwritten" - ) - assert requests_mock.last_request.headers["Content-type"] == "application/json", ( - "Create namespace request did not include expected 'Content-Type' header" - ) + assert ( + catalog._session.headers.get("Content-type") == "application/json" + ), "Expected 'Content-Type' default header not to be overwritten" + assert ( + requests_mock.last_request.headers["Content-type"] == "application/json" + ), "Create namespace request did not include expected 'Content-Type' header" - assert catalog._session.headers.get("Customized-Header") == "some/value", ( - "Expected 'Customized-Header' header to be 'some/value'" - ) - assert requests_mock.last_request.headers["Customized-Header"] == "some/value", ( - "Create namespace request did not include expected 'Customized-Header' header" - ) + assert ( + catalog._session.headers.get("Customized-Header") == "some/value" + ), "Expected 'Customized-Header' header to be 'some/value'" + assert ( + requests_mock.last_request.headers["Customized-Header"] == "some/value" + ), "Create namespace request did not include expected 'Customized-Header' header" @pytest.mark.filterwarnings( diff --git a/tests/catalog/test_sql.py b/tests/catalog/test_sql.py index 7f72568b41..cffc14d9d7 100644 --- a/tests/catalog/test_sql.py +++ b/tests/catalog/test_sql.py @@ -401,12 +401,14 @@ def test_write_pyarrow_schema(catalog: SqlCatalog, table_identifier: Identifier) pa.array([True, None, False, True]), # 'baz' column pa.array([None, "A", "B", "C"]), # 'large' column ], - schema=pa.schema([ - pa.field("foo", pa.large_string(), nullable=True), - pa.field("bar", pa.int32(), nullable=False), - pa.field("baz", pa.bool_(), nullable=True), - pa.field("large", pa.large_string(), nullable=True), - ]), + schema=pa.schema( + [ + pa.field("foo", pa.large_string(), nullable=True), + pa.field("bar", pa.int32(), nullable=False), + pa.field("baz", pa.bool_(), nullable=True), + pa.field("large", pa.large_string(), nullable=True), + ] + ), ) namespace = Catalog.namespace_from(table_identifier) catalog.create_namespace(namespace) @@ -1426,10 +1428,12 @@ def test_write_and_evolve(catalog: SqlCatalog, format_version: int) -> None: "foo": ["a", None, "z"], "bar": [19, None, 25], }, - schema=pa.schema([ - pa.field("foo", pa.large_string(), nullable=True), - pa.field("bar", pa.int32(), nullable=True), - ]), + schema=pa.schema( + [ + pa.field("foo", pa.large_string(), nullable=True), + pa.field("bar", pa.int32(), nullable=True), + ] + ), ) with tbl.transaction() as txn: @@ -1474,10 +1478,12 @@ def test_create_table_transaction(catalog: SqlCatalog, format_version: int) -> N "foo": ["a", None, "z"], "bar": [19, None, 25], }, - schema=pa.schema([ - pa.field("foo", pa.large_string(), nullable=True), - pa.field("bar", pa.int32(), nullable=True), - ]), + schema=pa.schema( + [ + pa.field("foo", pa.large_string(), nullable=True), + pa.field("bar", pa.int32(), nullable=True), + ] + ), ) with catalog.create_table_transaction( diff --git a/tests/conftest.py b/tests/conftest.py index 22329b3882..ef980f3818 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -353,49 +353,57 @@ def table_schema_with_all_types() -> Schema: def pyarrow_schema_simple_without_ids() -> "pa.Schema": import pyarrow as pa - return pa.schema([ - pa.field("foo", pa.string(), nullable=True), - pa.field("bar", pa.int32(), nullable=False), - pa.field("baz", pa.bool_(), nullable=True), - ]) + return pa.schema( + [ + pa.field("foo", pa.string(), nullable=True), + pa.field("bar", pa.int32(), nullable=False), + pa.field("baz", pa.bool_(), nullable=True), + ] + ) @pytest.fixture(scope="session") def pyarrow_schema_nested_without_ids() -> "pa.Schema": import pyarrow as pa - return pa.schema([ - pa.field("foo", pa.string(), nullable=False), - pa.field("bar", pa.int32(), nullable=False), - pa.field("baz", pa.bool_(), nullable=True), - pa.field("qux", pa.list_(pa.string()), nullable=False), - pa.field( - "quux", - pa.map_( - pa.string(), - pa.map_(pa.string(), pa.int32()), + return pa.schema( + [ + pa.field("foo", pa.string(), nullable=False), + pa.field("bar", pa.int32(), nullable=False), + pa.field("baz", pa.bool_(), nullable=True), + pa.field("qux", pa.list_(pa.string()), nullable=False), + pa.field( + "quux", + pa.map_( + pa.string(), + pa.map_(pa.string(), pa.int32()), + ), + nullable=False, ), - nullable=False, - ), - pa.field( - "location", - pa.list_( - pa.struct([ - pa.field("latitude", pa.float32(), nullable=False), - pa.field("longitude", pa.float32(), nullable=False), - ]), + pa.field( + "location", + pa.list_( + pa.struct( + [ + pa.field("latitude", pa.float32(), nullable=False), + pa.field("longitude", pa.float32(), nullable=False), + ] + ), + ), + nullable=False, ), - nullable=False, - ), - pa.field( - "person", - pa.struct([ - pa.field("name", pa.string(), nullable=True), - pa.field("age", pa.int32(), nullable=False), - ]), - nullable=True, - ), - ]) + pa.field( + "person", + pa.struct( + [ + pa.field("name", pa.string(), nullable=True), + pa.field("age", pa.int32(), nullable=False), + ] + ), + nullable=True, + ), + ] + ) @pytest.fixture(scope="session") @@ -2314,26 +2322,28 @@ def spark() -> "SparkSession": def pa_schema() -> "pa.Schema": import pyarrow as pa - return pa.schema([ - ("bool", pa.bool_()), - ("string", pa.large_string()), - ("string_long", pa.large_string()), - ("int", pa.int32()), - ("long", pa.int64()), - ("float", pa.float32()), - ("double", pa.float64()), - # Not supported by Spark - # ("time", pa.time64('us')), - ("timestamp", pa.timestamp(unit="us")), - ("timestamptz", pa.timestamp(unit="us", tz="UTC")), - ("date", pa.date32()), - # Not supported by Spark - # ("time", pa.time64("us")), - # Not natively supported by Arrow - # ("uuid", pa.fixed(16)), - ("binary", pa.large_binary()), - ("fixed", pa.binary(16)), - ]) + return pa.schema( + [ + ("bool", pa.bool_()), + ("string", pa.large_string()), + ("string_long", pa.large_string()), + ("int", pa.int32()), + ("long", pa.int64()), + ("float", pa.float32()), + ("double", pa.float64()), + # Not supported by Spark + # ("time", pa.time64('us')), + ("timestamp", pa.timestamp(unit="us")), + ("timestamptz", pa.timestamp(unit="us", tz="UTC")), + ("date", pa.date32()), + # Not supported by Spark + # ("time", pa.time64("us")), + # Not natively supported by Arrow + # ("uuid", pa.fixed(16)), + ("binary", pa.large_binary()), + ("fixed", pa.binary(16)), + ] + ) @pytest.fixture(scope="session") @@ -2415,11 +2425,13 @@ def arrow_table_date_timestamps() -> "pa.Table": None, ], }, - schema=pa.schema([ - ("date", pa.date32()), - ("timestamp", pa.timestamp(unit="us")), - ("timestamptz", pa.timestamp(unit="us", tz="UTC")), - ]), + schema=pa.schema( + [ + ("date", pa.date32()), + ("timestamp", pa.timestamp(unit="us")), + ("timestamptz", pa.timestamp(unit="us", tz="UTC")), + ] + ), ) @@ -2438,19 +2450,21 @@ def arrow_table_schema_with_all_timestamp_precisions() -> "pa.Schema": """Pyarrow Schema with all supported timestamp types.""" import pyarrow as pa - return pa.schema([ - ("timestamp_s", pa.timestamp(unit="s")), - ("timestamptz_s", pa.timestamp(unit="s", tz="UTC")), - ("timestamp_ms", pa.timestamp(unit="ms")), - ("timestamptz_ms", pa.timestamp(unit="ms", tz="UTC")), - ("timestamp_us", pa.timestamp(unit="us")), - ("timestamptz_us", pa.timestamp(unit="us", tz="UTC")), - ("timestamp_ns", pa.timestamp(unit="ns")), - ("timestamptz_ns", pa.timestamp(unit="ns", tz="UTC")), - ("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="Etc/UTC")), - ("timestamptz_ns_z", pa.timestamp(unit="ns", tz="Z")), - ("timestamptz_s_0000", pa.timestamp(unit="s", tz="+00:00")), - ]) + return pa.schema( + [ + ("timestamp_s", pa.timestamp(unit="s")), + ("timestamptz_s", pa.timestamp(unit="s", tz="UTC")), + ("timestamp_ms", pa.timestamp(unit="ms")), + ("timestamptz_ms", pa.timestamp(unit="ms", tz="UTC")), + ("timestamp_us", pa.timestamp(unit="us")), + ("timestamptz_us", pa.timestamp(unit="us", tz="UTC")), + ("timestamp_ns", pa.timestamp(unit="ns")), + ("timestamptz_ns", pa.timestamp(unit="ns", tz="UTC")), + ("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="Etc/UTC")), + ("timestamptz_ns_z", pa.timestamp(unit="ns", tz="Z")), + ("timestamptz_s_0000", pa.timestamp(unit="s", tz="+00:00")), + ] + ) @pytest.fixture(scope="session") @@ -2459,51 +2473,53 @@ def arrow_table_with_all_timestamp_precisions(arrow_table_schema_with_all_timest import pandas as pd import pyarrow as pa - test_data = pd.DataFrame({ - "timestamp_s": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], - "timestamptz_s": [ - datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), - None, - datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), - ], - "timestamp_ms": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], - "timestamptz_ms": [ - datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), - None, - datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), - ], - "timestamp_us": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], - "timestamptz_us": [ - datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), - None, - datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), - ], - "timestamp_ns": [ - pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=6), - None, - pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=7), - ], - "timestamptz_ns": [ - datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), - None, - datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), - ], - "timestamptz_us_etc_utc": [ - datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), - None, - datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), - ], - "timestamptz_ns_z": [ - pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=6, tz="UTC"), - None, - pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=7, tz="UTC"), - ], - "timestamptz_s_0000": [ - datetime(2023, 1, 1, 19, 25, 1, tzinfo=timezone.utc), - None, - datetime(2023, 3, 1, 19, 25, 1, tzinfo=timezone.utc), - ], - }) + test_data = pd.DataFrame( + { + "timestamp_s": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], + "timestamptz_s": [ + datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), + None, + datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), + ], + "timestamp_ms": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], + "timestamptz_ms": [ + datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), + None, + datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), + ], + "timestamp_us": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], + "timestamptz_us": [ + datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), + None, + datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), + ], + "timestamp_ns": [ + pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=6), + None, + pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=7), + ], + "timestamptz_ns": [ + datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), + None, + datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), + ], + "timestamptz_us_etc_utc": [ + datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), + None, + datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), + ], + "timestamptz_ns_z": [ + pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=6, tz="UTC"), + None, + pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=7, tz="UTC"), + ], + "timestamptz_s_0000": [ + datetime(2023, 1, 1, 19, 25, 1, tzinfo=timezone.utc), + None, + datetime(2023, 3, 1, 19, 25, 1, tzinfo=timezone.utc), + ], + } + ) return pa.Table.from_pandas(test_data, schema=arrow_table_schema_with_all_timestamp_precisions) @@ -2512,19 +2528,21 @@ def arrow_table_schema_with_all_microseconds_timestamp_precisions() -> "pa.Schem """Pyarrow Schema with all microseconds timestamp.""" import pyarrow as pa - return pa.schema([ - ("timestamp_s", pa.timestamp(unit="us")), - ("timestamptz_s", pa.timestamp(unit="us", tz="UTC")), - ("timestamp_ms", pa.timestamp(unit="us")), - ("timestamptz_ms", pa.timestamp(unit="us", tz="UTC")), - ("timestamp_us", pa.timestamp(unit="us")), - ("timestamptz_us", pa.timestamp(unit="us", tz="UTC")), - ("timestamp_ns", pa.timestamp(unit="us")), - ("timestamptz_ns", pa.timestamp(unit="us", tz="UTC")), - ("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="UTC")), - ("timestamptz_ns_z", pa.timestamp(unit="us", tz="UTC")), - ("timestamptz_s_0000", pa.timestamp(unit="us", tz="UTC")), - ]) + return pa.schema( + [ + ("timestamp_s", pa.timestamp(unit="us")), + ("timestamptz_s", pa.timestamp(unit="us", tz="UTC")), + ("timestamp_ms", pa.timestamp(unit="us")), + ("timestamptz_ms", pa.timestamp(unit="us", tz="UTC")), + ("timestamp_us", pa.timestamp(unit="us")), + ("timestamptz_us", pa.timestamp(unit="us", tz="UTC")), + ("timestamp_ns", pa.timestamp(unit="us")), + ("timestamptz_ns", pa.timestamp(unit="us", tz="UTC")), + ("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="UTC")), + ("timestamptz_ns_z", pa.timestamp(unit="us", tz="UTC")), + ("timestamptz_s_0000", pa.timestamp(unit="us", tz="UTC")), + ] + ) @pytest.fixture(scope="session") @@ -2578,13 +2596,15 @@ def pyarrow_schema_with_promoted_types() -> "pa.Schema": """Pyarrow Schema with longs, doubles and uuid in simple and nested types.""" import pyarrow as pa - return pa.schema(( - pa.field("long", pa.int32(), nullable=True), # can support upcasting integer to long - pa.field("list", pa.list_(pa.int32()), nullable=False), # can support upcasting integer to long - pa.field("map", pa.map_(pa.string(), pa.int32()), nullable=False), # can support upcasting integer to long - pa.field("double", pa.float32(), nullable=True), # can support upcasting float to double - pa.field("uuid", pa.binary(length=16), nullable=True), # can support upcasting float to double - )) + return pa.schema( + ( + pa.field("long", pa.int32(), nullable=True), # can support upcasting integer to long + pa.field("list", pa.list_(pa.int32()), nullable=False), # can support upcasting integer to long + pa.field("map", pa.map_(pa.string(), pa.int32()), nullable=False), # can support upcasting integer to long + pa.field("double", pa.float32(), nullable=True), # can support upcasting float to double + pa.field("uuid", pa.binary(length=16), nullable=True), # can support upcasting float to double + ) + ) @pytest.fixture(scope="session") diff --git a/tests/expressions/test_evaluator.py b/tests/expressions/test_evaluator.py index f8a9a8806d..e2b1f27377 100644 --- a/tests/expressions/test_evaluator.py +++ b/tests/expressions/test_evaluator.py @@ -681,25 +681,25 @@ def data_file_nan() -> DataFile: def test_inclusive_metrics_evaluator_less_than_and_less_than_equal(schema_data_file_nan: Schema, data_file_nan: DataFile) -> None: for operator in [LessThan, LessThanOrEqual]: - should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 1)).eval(data_file_nan) + should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 1)).eval(data_file_nan) # type: ignore[arg-type] assert not should_read, "Should not match: all nan column doesn't contain number" - should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 1)).eval(data_file_nan) + should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 1)).eval(data_file_nan) # type: ignore[arg-type] assert not should_read, "Should not match: 1 is smaller than lower bound" - should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 10)).eval(data_file_nan) + should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 10)).eval(data_file_nan) # type: ignore[arg-type] assert should_read, "Should match: 10 is larger than lower bound" - should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("min_max_nan", 1)).eval(data_file_nan) + should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("min_max_nan", 1)).eval(data_file_nan) # type: ignore[arg-type] assert should_read, "Should match: no visibility" - should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan_null_bounds", 1)).eval(data_file_nan) + should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan_null_bounds", 1)).eval(data_file_nan) # type: ignore[arg-type] assert not should_read, "Should not match: all nan column doesn't contain number" - should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 1)).eval(data_file_nan) + should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 1)).eval(data_file_nan) # type: ignore[arg-type] assert not should_read, "Should not match: 1 is smaller than lower bound" - should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 10)).eval( + should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 10)).eval( # type: ignore[arg-type] data_file_nan ) assert should_read, "Should match: 10 larger than lower bound" @@ -709,30 +709,30 @@ def test_inclusive_metrics_evaluator_greater_than_and_greater_than_equal( schema_data_file_nan: Schema, data_file_nan: DataFile ) -> None: for operator in [GreaterThan, GreaterThanOrEqual]: - should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 1)).eval(data_file_nan) + should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 1)).eval(data_file_nan) # type: ignore[arg-type] assert not should_read, "Should not match: all nan column doesn't contain number" - should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 1)).eval(data_file_nan) + should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 1)).eval(data_file_nan) # type: ignore[arg-type] assert should_read, "Should match: upper bound is larger than 1" - should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 10)).eval(data_file_nan) + should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("max_nan", 10)).eval(data_file_nan) # type: ignore[arg-type] assert should_read, "Should match: upper bound is larger than 10" - should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("min_max_nan", 1)).eval(data_file_nan) + should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("min_max_nan", 1)).eval(data_file_nan) # type: ignore[arg-type] assert should_read, "Should match: no visibility" - should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan_null_bounds", 1)).eval(data_file_nan) + should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan_null_bounds", 1)).eval(data_file_nan) # type: ignore[arg-type] assert not should_read, "Should not match: all nan column doesn't contain number" - should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 1)).eval(data_file_nan) + should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 1)).eval(data_file_nan) # type: ignore[arg-type] assert should_read, "Should match: 1 is smaller than upper bound" - should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 10)).eval( + should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("some_nan_correct_bounds", 10)).eval( # type: ignore[arg-type] data_file_nan ) assert should_read, "Should match: 10 is smaller than upper bound" - should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 30)).eval(data_file_nan) + should_read = _InclusiveMetricsEvaluator(schema_data_file_nan, operator("all_nan", 30)).eval(data_file_nan) # type: ignore[arg-type] assert not should_read, "Should not match: 30 is greater than upper bound" diff --git a/tests/expressions/test_visitors.py b/tests/expressions/test_visitors.py index d61c193719..94bfcf076c 100644 --- a/tests/expressions/test_visitors.py +++ b/tests/expressions/test_visitors.py @@ -947,95 +947,95 @@ def manifest() -> ManifestFile: def test_all_nulls(schema: Schema, manifest: ManifestFile) -> None: - assert not _ManifestEvalVisitor(schema, NotNull(Reference("all_nulls_missing_nan")), case_sensitive=True).eval(manifest), ( - "Should skip: all nulls column with non-floating type contains all null" - ) + assert not _ManifestEvalVisitor(schema, NotNull(Reference("all_nulls_missing_nan")), case_sensitive=True).eval( + manifest + ), "Should skip: all nulls column with non-floating type contains all null" - assert _ManifestEvalVisitor(schema, NotNull(Reference("all_nulls_missing_nan_float")), case_sensitive=True).eval(manifest), ( - "Should read: no NaN information may indicate presence of NaN value" - ) + assert _ManifestEvalVisitor(schema, NotNull(Reference("all_nulls_missing_nan_float")), case_sensitive=True).eval( + manifest + ), "Should read: no NaN information may indicate presence of NaN value" - assert _ManifestEvalVisitor(schema, NotNull(Reference("some_nulls")), case_sensitive=True).eval(manifest), ( - "Should read: column with some nulls contains a non-null value" - ) + assert _ManifestEvalVisitor(schema, NotNull(Reference("some_nulls")), case_sensitive=True).eval( + manifest + ), "Should read: column with some nulls contains a non-null value" - assert _ManifestEvalVisitor(schema, NotNull(Reference("no_nulls")), case_sensitive=True).eval(manifest), ( - "Should read: non-null column contains a non-null value" - ) + assert _ManifestEvalVisitor(schema, NotNull(Reference("no_nulls")), case_sensitive=True).eval( + manifest + ), "Should read: non-null column contains a non-null value" def test_no_nulls(schema: Schema, manifest: ManifestFile) -> None: - assert _ManifestEvalVisitor(schema, IsNull(Reference("all_nulls_missing_nan")), case_sensitive=True).eval(manifest), ( - "Should read: at least one null value in all null column" - ) + assert _ManifestEvalVisitor(schema, IsNull(Reference("all_nulls_missing_nan")), case_sensitive=True).eval( + manifest + ), "Should read: at least one null value in all null column" - assert _ManifestEvalVisitor(schema, IsNull(Reference("some_nulls")), case_sensitive=True).eval(manifest), ( - "Should read: column with some nulls contains a null value" - ) + assert _ManifestEvalVisitor(schema, IsNull(Reference("some_nulls")), case_sensitive=True).eval( + manifest + ), "Should read: column with some nulls contains a null value" - assert not _ManifestEvalVisitor(schema, IsNull(Reference("no_nulls")), case_sensitive=True).eval(manifest), ( - "Should skip: non-null column contains no null values" - ) + assert not _ManifestEvalVisitor(schema, IsNull(Reference("no_nulls")), case_sensitive=True).eval( + manifest + ), "Should skip: non-null column contains no null values" - assert _ManifestEvalVisitor(schema, IsNull(Reference("both_nan_and_null")), case_sensitive=True).eval(manifest), ( - "Should read: both_nan_and_null column contains no null values" - ) + assert _ManifestEvalVisitor(schema, IsNull(Reference("both_nan_and_null")), case_sensitive=True).eval( + manifest + ), "Should read: both_nan_and_null column contains no null values" def test_is_nan(schema: Schema, manifest: ManifestFile) -> None: - assert _ManifestEvalVisitor(schema, IsNaN(Reference("float")), case_sensitive=True).eval(manifest), ( - "Should read: no information on if there are nan value in float column" - ) + assert _ManifestEvalVisitor(schema, IsNaN(Reference("float")), case_sensitive=True).eval( + manifest + ), "Should read: no information on if there are nan value in float column" - assert _ManifestEvalVisitor(schema, IsNaN(Reference("all_nulls_double")), case_sensitive=True).eval(manifest), ( - "Should read: no NaN information may indicate presence of NaN value" - ) + assert _ManifestEvalVisitor(schema, IsNaN(Reference("all_nulls_double")), case_sensitive=True).eval( + manifest + ), "Should read: no NaN information may indicate presence of NaN value" - assert _ManifestEvalVisitor(schema, IsNaN(Reference("all_nulls_missing_nan_float")), case_sensitive=True).eval(manifest), ( - "Should read: no NaN information may indicate presence of NaN value" - ) + assert _ManifestEvalVisitor(schema, IsNaN(Reference("all_nulls_missing_nan_float")), case_sensitive=True).eval( + manifest + ), "Should read: no NaN information may indicate presence of NaN value" - assert not _ManifestEvalVisitor(schema, IsNaN(Reference("all_nulls_no_nans")), case_sensitive=True).eval(manifest), ( - "Should skip: no nan column doesn't contain nan value" - ) + assert not _ManifestEvalVisitor(schema, IsNaN(Reference("all_nulls_no_nans")), case_sensitive=True).eval( + manifest + ), "Should skip: no nan column doesn't contain nan value" - assert _ManifestEvalVisitor(schema, IsNaN(Reference("all_nans")), case_sensitive=True).eval(manifest), ( - "Should read: all_nans column contains nan value" - ) + assert _ManifestEvalVisitor(schema, IsNaN(Reference("all_nans")), case_sensitive=True).eval( + manifest + ), "Should read: all_nans column contains nan value" - assert _ManifestEvalVisitor(schema, IsNaN(Reference("both_nan_and_null")), case_sensitive=True).eval(manifest), ( - "Should read: both_nan_and_null column contains nan value" - ) + assert _ManifestEvalVisitor(schema, IsNaN(Reference("both_nan_and_null")), case_sensitive=True).eval( + manifest + ), "Should read: both_nan_and_null column contains nan value" - assert not _ManifestEvalVisitor(schema, IsNaN(Reference("no_nan_or_null")), case_sensitive=True).eval(manifest), ( - "Should skip: no_nan_or_null column doesn't contain nan value" - ) + assert not _ManifestEvalVisitor(schema, IsNaN(Reference("no_nan_or_null")), case_sensitive=True).eval( + manifest + ), "Should skip: no_nan_or_null column doesn't contain nan value" def test_not_nan(schema: Schema, manifest: ManifestFile) -> None: - assert _ManifestEvalVisitor(schema, NotNaN(Reference("float")), case_sensitive=True).eval(manifest), ( - "Should read: no information on if there are nan value in float column" - ) + assert _ManifestEvalVisitor(schema, NotNaN(Reference("float")), case_sensitive=True).eval( + manifest + ), "Should read: no information on if there are nan value in float column" - assert _ManifestEvalVisitor(schema, NotNaN(Reference("all_nulls_double")), case_sensitive=True).eval(manifest), ( - "Should read: all null column contains non nan value" - ) + assert _ManifestEvalVisitor(schema, NotNaN(Reference("all_nulls_double")), case_sensitive=True).eval( + manifest + ), "Should read: all null column contains non nan value" - assert _ManifestEvalVisitor(schema, NotNaN(Reference("all_nulls_no_nans")), case_sensitive=True).eval(manifest), ( - "Should read: no_nans column contains non nan value" - ) + assert _ManifestEvalVisitor(schema, NotNaN(Reference("all_nulls_no_nans")), case_sensitive=True).eval( + manifest + ), "Should read: no_nans column contains non nan value" - assert not _ManifestEvalVisitor(schema, NotNaN(Reference("all_nans")), case_sensitive=True).eval(manifest), ( - "Should skip: all nans column doesn't contain non nan value" - ) + assert not _ManifestEvalVisitor(schema, NotNaN(Reference("all_nans")), case_sensitive=True).eval( + manifest + ), "Should skip: all nans column doesn't contain non nan value" - assert _ManifestEvalVisitor(schema, NotNaN(Reference("both_nan_and_null")), case_sensitive=True).eval(manifest), ( - "Should read: both_nan_and_null nans column contains non nan value" - ) + assert _ManifestEvalVisitor(schema, NotNaN(Reference("both_nan_and_null")), case_sensitive=True).eval( + manifest + ), "Should read: both_nan_and_null nans column contains non nan value" - assert _ManifestEvalVisitor(schema, NotNaN(Reference("no_nan_or_null")), case_sensitive=True).eval(manifest), ( - "Should read: no_nan_or_null column contains non nan value" - ) + assert _ManifestEvalVisitor(schema, NotNaN(Reference("no_nan_or_null")), case_sensitive=True).eval( + manifest + ), "Should read: no_nan_or_null column contains non nan value" def test_missing_stats(schema: Schema, manifest_no_stats: ManifestFile) -> None: @@ -1053,15 +1053,15 @@ def test_missing_stats(schema: Schema, manifest_no_stats: ManifestFile) -> None: ] for expr in expressions: - assert _ManifestEvalVisitor(schema, expr, case_sensitive=True).eval(manifest_no_stats), ( - f"Should read when missing stats for expr: {expr}" - ) + assert _ManifestEvalVisitor(schema, expr, case_sensitive=True).eval( + manifest_no_stats + ), f"Should read when missing stats for expr: {expr}" def test_not(schema: Schema, manifest: ManifestFile) -> None: - assert _ManifestEvalVisitor(schema, Not(LessThan(Reference("id"), INT_MIN_VALUE - 25)), case_sensitive=True).eval(manifest), ( - "Should read: not(false)" - ) + assert _ManifestEvalVisitor(schema, Not(LessThan(Reference("id"), INT_MIN_VALUE - 25)), case_sensitive=True).eval( + manifest + ), "Should read: not(false)" assert not _ManifestEvalVisitor(schema, Not(GreaterThan(Reference("id"), INT_MIN_VALUE - 25)), case_sensitive=True).eval( manifest @@ -1118,21 +1118,21 @@ def test_or(schema: Schema, manifest: ManifestFile) -> None: def test_integer_lt(schema: Schema, manifest: ManifestFile) -> None: - assert not _ManifestEvalVisitor(schema, LessThan(Reference("id"), INT_MIN_VALUE - 25), case_sensitive=True).eval(manifest), ( - "Should not read: id range below lower bound (5 < 30)" - ) + assert not _ManifestEvalVisitor(schema, LessThan(Reference("id"), INT_MIN_VALUE - 25), case_sensitive=True).eval( + manifest + ), "Should not read: id range below lower bound (5 < 30)" - assert not _ManifestEvalVisitor(schema, LessThan(Reference("id"), INT_MIN_VALUE), case_sensitive=True).eval(manifest), ( - "Should not read: id range below lower bound (30 is not < 30)" - ) + assert not _ManifestEvalVisitor(schema, LessThan(Reference("id"), INT_MIN_VALUE), case_sensitive=True).eval( + manifest + ), "Should not read: id range below lower bound (30 is not < 30)" - assert _ManifestEvalVisitor(schema, LessThan(Reference("id"), INT_MIN_VALUE + 1), case_sensitive=True).eval(manifest), ( - "Should read: one possible id" - ) + assert _ManifestEvalVisitor(schema, LessThan(Reference("id"), INT_MIN_VALUE + 1), case_sensitive=True).eval( + manifest + ), "Should read: one possible id" - assert _ManifestEvalVisitor(schema, LessThan(Reference("id"), INT_MAX_VALUE), case_sensitive=True).eval(manifest), ( - "Should read: may possible ids" - ) + assert _ManifestEvalVisitor(schema, LessThan(Reference("id"), INT_MAX_VALUE), case_sensitive=True).eval( + manifest + ), "Should read: may possible ids" def test_integer_lt_eq(schema: Schema, manifest: ManifestFile) -> None: @@ -1144,13 +1144,13 @@ def test_integer_lt_eq(schema: Schema, manifest: ManifestFile) -> None: manifest ), "Should not read: id range below lower bound (29 < 30)" - assert _ManifestEvalVisitor(schema, LessThanOrEqual(Reference("id"), INT_MIN_VALUE), case_sensitive=True).eval(manifest), ( - "Should read: one possible id" - ) + assert _ManifestEvalVisitor(schema, LessThanOrEqual(Reference("id"), INT_MIN_VALUE), case_sensitive=True).eval( + manifest + ), "Should read: one possible id" - assert _ManifestEvalVisitor(schema, LessThanOrEqual(Reference("id"), INT_MAX_VALUE), case_sensitive=True).eval(manifest), ( - "Should read: many possible ids" - ) + assert _ManifestEvalVisitor(schema, LessThanOrEqual(Reference("id"), INT_MAX_VALUE), case_sensitive=True).eval( + manifest + ), "Should read: many possible ids" def test_integer_gt(schema: Schema, manifest: ManifestFile) -> None: @@ -1158,17 +1158,17 @@ def test_integer_gt(schema: Schema, manifest: ManifestFile) -> None: manifest ), "Should not read: id range above upper bound (85 < 79)" - assert not _ManifestEvalVisitor(schema, GreaterThan(Reference("id"), INT_MAX_VALUE), case_sensitive=True).eval(manifest), ( - "Should not read: id range above upper bound (79 is not > 79)" - ) + assert not _ManifestEvalVisitor(schema, GreaterThan(Reference("id"), INT_MAX_VALUE), case_sensitive=True).eval( + manifest + ), "Should not read: id range above upper bound (79 is not > 79)" - assert _ManifestEvalVisitor(schema, GreaterThan(Reference("id"), INT_MAX_VALUE - 1), case_sensitive=True).eval(manifest), ( - "Should read: one possible id" - ) + assert _ManifestEvalVisitor(schema, GreaterThan(Reference("id"), INT_MAX_VALUE - 1), case_sensitive=True).eval( + manifest + ), "Should read: one possible id" - assert _ManifestEvalVisitor(schema, GreaterThan(Reference("id"), INT_MAX_VALUE - 4), case_sensitive=True).eval(manifest), ( - "Should read: may possible ids" - ) + assert _ManifestEvalVisitor(schema, GreaterThan(Reference("id"), INT_MAX_VALUE - 4), case_sensitive=True).eval( + manifest + ), "Should read: may possible ids" def test_integer_gt_eq(schema: Schema, manifest: ManifestFile) -> None: @@ -1180,133 +1180,133 @@ def test_integer_gt_eq(schema: Schema, manifest: ManifestFile) -> None: manifest ), "Should not read: id range above upper bound (80 > 79)" - assert _ManifestEvalVisitor(schema, GreaterThanOrEqual(Reference("id"), INT_MAX_VALUE), case_sensitive=True).eval(manifest), ( - "Should read: one possible id" - ) + assert _ManifestEvalVisitor(schema, GreaterThanOrEqual(Reference("id"), INT_MAX_VALUE), case_sensitive=True).eval( + manifest + ), "Should read: one possible id" - assert _ManifestEvalVisitor(schema, GreaterThanOrEqual(Reference("id"), INT_MAX_VALUE), case_sensitive=True).eval(manifest), ( - "Should read: may possible ids" - ) + assert _ManifestEvalVisitor(schema, GreaterThanOrEqual(Reference("id"), INT_MAX_VALUE), case_sensitive=True).eval( + manifest + ), "Should read: may possible ids" def test_integer_eq(schema: Schema, manifest: ManifestFile) -> None: - assert not _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MIN_VALUE - 25), case_sensitive=True).eval(manifest), ( - "Should not read: id below lower bound" - ) + assert not _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MIN_VALUE - 25), case_sensitive=True).eval( + manifest + ), "Should not read: id below lower bound" - assert not _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MIN_VALUE - 1), case_sensitive=True).eval(manifest), ( - "Should not read: id below lower bound" - ) + assert not _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MIN_VALUE - 1), case_sensitive=True).eval( + manifest + ), "Should not read: id below lower bound" - assert _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MIN_VALUE), case_sensitive=True).eval(manifest), ( - "Should read: id equal to lower bound" - ) + assert _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MIN_VALUE), case_sensitive=True).eval( + manifest + ), "Should read: id equal to lower bound" - assert _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MAX_VALUE - 4), case_sensitive=True).eval(manifest), ( - "Should read: id between lower and upper bounds" - ) + assert _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MAX_VALUE - 4), case_sensitive=True).eval( + manifest + ), "Should read: id between lower and upper bounds" - assert _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MAX_VALUE), case_sensitive=True).eval(manifest), ( - "Should read: id equal to upper bound" - ) + assert _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MAX_VALUE), case_sensitive=True).eval( + manifest + ), "Should read: id equal to upper bound" - assert not _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MAX_VALUE + 1), case_sensitive=True).eval(manifest), ( - "Should not read: id above upper bound" - ) + assert not _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MAX_VALUE + 1), case_sensitive=True).eval( + manifest + ), "Should not read: id above upper bound" - assert not _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MAX_VALUE + 6), case_sensitive=True).eval(manifest), ( - "Should not read: id above upper bound" - ) + assert not _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MAX_VALUE + 6), case_sensitive=True).eval( + manifest + ), "Should not read: id above upper bound" def test_integer_not_eq(schema: Schema, manifest: ManifestFile) -> None: - assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MIN_VALUE - 25), case_sensitive=True).eval(manifest), ( - "Should read: id below lower bound" - ) + assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MIN_VALUE - 25), case_sensitive=True).eval( + manifest + ), "Should read: id below lower bound" - assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MIN_VALUE - 1), case_sensitive=True).eval(manifest), ( - "Should read: id below lower bound" - ) + assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MIN_VALUE - 1), case_sensitive=True).eval( + manifest + ), "Should read: id below lower bound" - assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MIN_VALUE), case_sensitive=True).eval(manifest), ( - "Should read: id equal to lower bound" - ) + assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MIN_VALUE), case_sensitive=True).eval( + manifest + ), "Should read: id equal to lower bound" - assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MAX_VALUE - 4), case_sensitive=True).eval(manifest), ( - "Should read: id between lower and upper bounds" - ) + assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MAX_VALUE - 4), case_sensitive=True).eval( + manifest + ), "Should read: id between lower and upper bounds" - assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MAX_VALUE), case_sensitive=True).eval(manifest), ( - "Should read: id equal to upper bound" - ) + assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MAX_VALUE), case_sensitive=True).eval( + manifest + ), "Should read: id equal to upper bound" - assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MAX_VALUE + 1), case_sensitive=True).eval(manifest), ( - "Should read: id above upper bound" - ) + assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MAX_VALUE + 1), case_sensitive=True).eval( + manifest + ), "Should read: id above upper bound" - assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MAX_VALUE + 6), case_sensitive=True).eval(manifest), ( - "Should read: id above upper bound" - ) + assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MAX_VALUE + 6), case_sensitive=True).eval( + manifest + ), "Should read: id above upper bound" def test_integer_not_eq_rewritten(schema: Schema, manifest: ManifestFile) -> None: - assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MIN_VALUE - 25)), case_sensitive=True).eval(manifest), ( - "Should read: id below lower bound" - ) + assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MIN_VALUE - 25)), case_sensitive=True).eval( + manifest + ), "Should read: id below lower bound" - assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MIN_VALUE - 1)), case_sensitive=True).eval(manifest), ( - "Should read: id below lower bound" - ) + assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MIN_VALUE - 1)), case_sensitive=True).eval( + manifest + ), "Should read: id below lower bound" - assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MIN_VALUE)), case_sensitive=True).eval(manifest), ( - "Should read: id equal to lower bound" - ) + assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MIN_VALUE)), case_sensitive=True).eval( + manifest + ), "Should read: id equal to lower bound" - assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MAX_VALUE - 4)), case_sensitive=True).eval(manifest), ( - "Should read: id between lower and upper bounds" - ) + assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MAX_VALUE - 4)), case_sensitive=True).eval( + manifest + ), "Should read: id between lower and upper bounds" - assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MAX_VALUE)), case_sensitive=True).eval(manifest), ( - "Should read: id equal to upper bound" - ) + assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MAX_VALUE)), case_sensitive=True).eval( + manifest + ), "Should read: id equal to upper bound" - assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MAX_VALUE + 1)), case_sensitive=True).eval(manifest), ( - "Should read: id above upper bound" - ) + assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MAX_VALUE + 1)), case_sensitive=True).eval( + manifest + ), "Should read: id above upper bound" - assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MAX_VALUE + 6)), case_sensitive=True).eval(manifest), ( - "Should read: id above upper bound" - ) + assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MAX_VALUE + 6)), case_sensitive=True).eval( + manifest + ), "Should read: id above upper bound" def test_integer_not_eq_rewritten_case_insensitive(schema: Schema, manifest: ManifestFile) -> None: - assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("ID"), INT_MIN_VALUE - 25)), case_sensitive=False).eval(manifest), ( - "Should read: id below lower bound" - ) + assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("ID"), INT_MIN_VALUE - 25)), case_sensitive=False).eval( + manifest + ), "Should read: id below lower bound" - assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("ID"), INT_MIN_VALUE - 1)), case_sensitive=False).eval(manifest), ( - "Should read: id below lower bound" - ) + assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("ID"), INT_MIN_VALUE - 1)), case_sensitive=False).eval( + manifest + ), "Should read: id below lower bound" - assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("ID"), INT_MIN_VALUE)), case_sensitive=False).eval(manifest), ( - "Should read: id equal to lower bound" - ) + assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("ID"), INT_MIN_VALUE)), case_sensitive=False).eval( + manifest + ), "Should read: id equal to lower bound" - assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("ID"), INT_MAX_VALUE - 4)), case_sensitive=False).eval(manifest), ( - "Should read: id between lower and upper bounds" - ) + assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("ID"), INT_MAX_VALUE - 4)), case_sensitive=False).eval( + manifest + ), "Should read: id between lower and upper bounds" - assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("ID"), INT_MAX_VALUE)), case_sensitive=False).eval(manifest), ( - "Should read: id equal to upper bound" - ) + assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("ID"), INT_MAX_VALUE)), case_sensitive=False).eval( + manifest + ), "Should read: id equal to upper bound" - assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("ID"), INT_MAX_VALUE + 1)), case_sensitive=False).eval(manifest), ( - "Should read: id above upper bound" - ) + assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("ID"), INT_MAX_VALUE + 1)), case_sensitive=False).eval( + manifest + ), "Should read: id above upper bound" - assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("ID"), INT_MAX_VALUE + 6)), case_sensitive=False).eval(manifest), ( - "Should read: id above upper bound" - ) + assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("ID"), INT_MAX_VALUE + 6)), case_sensitive=False).eval( + manifest + ), "Should read: id above upper bound" def test_integer_in(schema: Schema, manifest: ManifestFile) -> None: @@ -1342,13 +1342,13 @@ def test_integer_in(schema: Schema, manifest: ManifestFile) -> None: manifest ), "Should skip: in on all nulls column" - assert _ManifestEvalVisitor(schema, In(Reference("some_nulls"), ("abc", "def")), case_sensitive=True).eval(manifest), ( - "Should read: in on some nulls column" - ) + assert _ManifestEvalVisitor(schema, In(Reference("some_nulls"), ("abc", "def")), case_sensitive=True).eval( + manifest + ), "Should read: in on some nulls column" - assert _ManifestEvalVisitor(schema, In(Reference("no_nulls"), ("abc", "def")), case_sensitive=True).eval(manifest), ( - "Should read: in on no nulls column" - ) + assert _ManifestEvalVisitor(schema, In(Reference("no_nulls"), ("abc", "def")), case_sensitive=True).eval( + manifest + ), "Should read: in on no nulls column" def test_integer_not_in(schema: Schema, manifest: ManifestFile) -> None: @@ -1384,73 +1384,73 @@ def test_integer_not_in(schema: Schema, manifest: ManifestFile) -> None: manifest ), "Should read: notIn on no nulls column" - assert _ManifestEvalVisitor(schema, NotIn(Reference("some_nulls"), ("abc", "def")), case_sensitive=True).eval(manifest), ( - "Should read: in on some nulls column" - ) + assert _ManifestEvalVisitor(schema, NotIn(Reference("some_nulls"), ("abc", "def")), case_sensitive=True).eval( + manifest + ), "Should read: in on some nulls column" - assert _ManifestEvalVisitor(schema, NotIn(Reference("no_nulls"), ("abc", "def")), case_sensitive=True).eval(manifest), ( - "Should read: in on no nulls column" - ) + assert _ManifestEvalVisitor(schema, NotIn(Reference("no_nulls"), ("abc", "def")), case_sensitive=True).eval( + manifest + ), "Should read: in on no nulls column" def test_string_starts_with(schema: Schema, manifest: ManifestFile) -> None: - assert _ManifestEvalVisitor(schema, StartsWith(Reference("some_nulls"), "a"), case_sensitive=False).eval(manifest), ( - "Should read: range matches" - ) + assert _ManifestEvalVisitor(schema, StartsWith(Reference("some_nulls"), "a"), case_sensitive=False).eval( + manifest + ), "Should read: range matches" - assert _ManifestEvalVisitor(schema, StartsWith(Reference("some_nulls"), "aa"), case_sensitive=False).eval(manifest), ( - "Should read: range matches" - ) + assert _ManifestEvalVisitor(schema, StartsWith(Reference("some_nulls"), "aa"), case_sensitive=False).eval( + manifest + ), "Should read: range matches" - assert _ManifestEvalVisitor(schema, StartsWith(Reference("some_nulls"), "dddd"), case_sensitive=False).eval(manifest), ( - "Should read: range matches" - ) + assert _ManifestEvalVisitor(schema, StartsWith(Reference("some_nulls"), "dddd"), case_sensitive=False).eval( + manifest + ), "Should read: range matches" - assert _ManifestEvalVisitor(schema, StartsWith(Reference("some_nulls"), "z"), case_sensitive=False).eval(manifest), ( - "Should read: range matches" - ) + assert _ManifestEvalVisitor(schema, StartsWith(Reference("some_nulls"), "z"), case_sensitive=False).eval( + manifest + ), "Should read: range matches" - assert _ManifestEvalVisitor(schema, StartsWith(Reference("no_nulls"), "a"), case_sensitive=False).eval(manifest), ( - "Should read: range matches" - ) + assert _ManifestEvalVisitor(schema, StartsWith(Reference("no_nulls"), "a"), case_sensitive=False).eval( + manifest + ), "Should read: range matches" - assert not _ManifestEvalVisitor(schema, StartsWith(Reference("some_nulls"), "zzzz"), case_sensitive=False).eval(manifest), ( - "Should skip: range doesn't match" - ) + assert not _ManifestEvalVisitor(schema, StartsWith(Reference("some_nulls"), "zzzz"), case_sensitive=False).eval( + manifest + ), "Should skip: range doesn't match" - assert not _ManifestEvalVisitor(schema, StartsWith(Reference("some_nulls"), "1"), case_sensitive=False).eval(manifest), ( - "Should skip: range doesn't match" - ) + assert not _ManifestEvalVisitor(schema, StartsWith(Reference("some_nulls"), "1"), case_sensitive=False).eval( + manifest + ), "Should skip: range doesn't match" def test_string_not_starts_with(schema: Schema, manifest: ManifestFile) -> None: - assert _ManifestEvalVisitor(schema, NotStartsWith(Reference("some_nulls"), "a"), case_sensitive=False).eval(manifest), ( - "Should read: range matches" - ) + assert _ManifestEvalVisitor(schema, NotStartsWith(Reference("some_nulls"), "a"), case_sensitive=False).eval( + manifest + ), "Should read: range matches" - assert _ManifestEvalVisitor(schema, NotStartsWith(Reference("some_nulls"), "aa"), case_sensitive=False).eval(manifest), ( - "Should read: range matches" - ) + assert _ManifestEvalVisitor(schema, NotStartsWith(Reference("some_nulls"), "aa"), case_sensitive=False).eval( + manifest + ), "Should read: range matches" - assert _ManifestEvalVisitor(schema, NotStartsWith(Reference("some_nulls"), "dddd"), case_sensitive=False).eval(manifest), ( - "Should read: range matches" - ) + assert _ManifestEvalVisitor(schema, NotStartsWith(Reference("some_nulls"), "dddd"), case_sensitive=False).eval( + manifest + ), "Should read: range matches" - assert _ManifestEvalVisitor(schema, NotStartsWith(Reference("some_nulls"), "z"), case_sensitive=False).eval(manifest), ( - "Should read: range matches" - ) + assert _ManifestEvalVisitor(schema, NotStartsWith(Reference("some_nulls"), "z"), case_sensitive=False).eval( + manifest + ), "Should read: range matches" - assert _ManifestEvalVisitor(schema, NotStartsWith(Reference("no_nulls"), "a"), case_sensitive=False).eval(manifest), ( - "Should read: range matches" - ) + assert _ManifestEvalVisitor(schema, NotStartsWith(Reference("no_nulls"), "a"), case_sensitive=False).eval( + manifest + ), "Should read: range matches" - assert _ManifestEvalVisitor(schema, NotStartsWith(Reference("some_nulls"), "zzzz"), case_sensitive=False).eval(manifest), ( - "Should read: range matches" - ) + assert _ManifestEvalVisitor(schema, NotStartsWith(Reference("some_nulls"), "zzzz"), case_sensitive=False).eval( + manifest + ), "Should read: range matches" - assert _ManifestEvalVisitor(schema, NotStartsWith(Reference("some_nulls"), "1"), case_sensitive=False).eval(manifest), ( - "Should read: range matches" - ) + assert _ManifestEvalVisitor(schema, NotStartsWith(Reference("some_nulls"), "1"), case_sensitive=False).eval( + manifest + ), "Should read: range matches" assert _ManifestEvalVisitor(schema, NotStartsWith(Reference("all_same_value_or_null"), "a"), case_sensitive=False).eval( manifest diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py index 85e626edf4..c1d916e0e0 100644 --- a/tests/integration/test_add_files.py +++ b/tests/integration/test_add_files.py @@ -52,12 +52,14 @@ NestedField(field_id=10, name="qux", field_type=DateType(), required=False), ) -ARROW_SCHEMA = pa.schema([ - ("foo", pa.bool_()), - ("bar", pa.string()), - ("baz", pa.int32()), - ("qux", pa.date32()), -]) +ARROW_SCHEMA = pa.schema( + [ + ("foo", pa.bool_()), + ("bar", pa.string()), + ("baz", pa.int32()), + ("qux", pa.date32()), + ] +) ARROW_TABLE = pa.Table.from_pylist( [ @@ -71,12 +73,14 @@ schema=ARROW_SCHEMA, ) -ARROW_SCHEMA_WITH_IDS = pa.schema([ - pa.field("foo", pa.bool_(), nullable=False, metadata={"PARQUET:field_id": "1"}), - pa.field("bar", pa.string(), nullable=False, metadata={"PARQUET:field_id": "2"}), - pa.field("baz", pa.int32(), nullable=False, metadata={"PARQUET:field_id": "3"}), - pa.field("qux", pa.date32(), nullable=False, metadata={"PARQUET:field_id": "4"}), -]) +ARROW_SCHEMA_WITH_IDS = pa.schema( + [ + pa.field("foo", pa.bool_(), nullable=False, metadata={"PARQUET:field_id": "1"}), + pa.field("bar", pa.string(), nullable=False, metadata={"PARQUET:field_id": "2"}), + pa.field("baz", pa.int32(), nullable=False, metadata={"PARQUET:field_id": "3"}), + pa.field("qux", pa.date32(), nullable=False, metadata={"PARQUET:field_id": "4"}), + ] +) ARROW_TABLE_WITH_IDS = pa.Table.from_pylist( @@ -91,12 +95,14 @@ schema=ARROW_SCHEMA_WITH_IDS, ) -ARROW_SCHEMA_UPDATED = pa.schema([ - ("foo", pa.bool_()), - ("baz", pa.int32()), - ("qux", pa.date32()), - ("quux", pa.int32()), -]) +ARROW_SCHEMA_UPDATED = pa.schema( + [ + ("foo", pa.bool_()), + ("baz", pa.int32()), + ("qux", pa.date32()), + ("quux", pa.int32()), + ] +) ARROW_TABLE_UPDATED = pa.Table.from_pylist( [ @@ -471,12 +477,14 @@ def test_add_files_fails_on_schema_mismatch(spark: SparkSession, session_catalog identifier = f"default.table_schema_mismatch_fails_v{format_version}" tbl = _create_table(session_catalog, identifier, format_version) - WRONG_SCHEMA = pa.schema([ - ("foo", pa.bool_()), - ("bar", pa.string()), - ("baz", pa.string()), # should be integer - ("qux", pa.date32()), - ]) + WRONG_SCHEMA = pa.schema( + [ + ("foo", pa.bool_()), + ("bar", pa.string()), + ("baz", pa.string()), # should be integer + ("qux", pa.date32()), + ] + ) file_path = f"s3://warehouse/default/table_schema_mismatch_fails/v{format_version}/test.parquet" # write parquet files fo = tbl.io.new_output(file_path) @@ -522,12 +530,16 @@ def test_add_files_with_large_and_regular_schema(spark: SparkSession, session_ca identifier = f"default.unpartitioned_with_large_types{format_version}" iceberg_schema = Schema(NestedField(1, "foo", StringType(), required=True)) - arrow_schema = pa.schema([ - pa.field("foo", pa.string(), nullable=False), - ]) - arrow_schema_large = pa.schema([ - pa.field("foo", pa.large_string(), nullable=False), - ]) + arrow_schema = pa.schema( + [ + pa.field("foo", pa.string(), nullable=False), + ] + ) + arrow_schema_large = pa.schema( + [ + pa.field("foo", pa.large_string(), nullable=False), + ] + ) tbl = _create_table(session_catalog, identifier, format_version, schema=iceberg_schema) @@ -576,9 +588,11 @@ def test_add_files_with_large_and_regular_schema(spark: SparkSession, session_ca def test_add_files_with_timestamp_tz_ns_fails(session_catalog: Catalog, format_version: int, mocker: MockerFixture) -> None: nanoseconds_schema_iceberg = Schema(NestedField(1, "quux", TimestamptzType())) - nanoseconds_schema = pa.schema([ - ("quux", pa.timestamp("ns", tz="UTC")), - ]) + nanoseconds_schema = pa.schema( + [ + ("quux", pa.timestamp("ns", tz="UTC")), + ] + ) arrow_table = pa.Table.from_pylist( [ @@ -617,9 +631,11 @@ def test_add_file_with_valid_nullability_diff(spark: SparkSession, session_catal table_schema = Schema( NestedField(field_id=1, name="long", field_type=LongType(), required=False), ) - other_schema = pa.schema(( - pa.field("long", pa.int64(), nullable=False), # can support writing required pyarrow field to optional Iceberg field - )) + other_schema = pa.schema( + ( + pa.field("long", pa.int64(), nullable=False), # can support writing required pyarrow field to optional Iceberg field + ) + ) arrow_table = pa.Table.from_pydict( { "long": [1, 9], @@ -671,13 +687,15 @@ def test_add_files_with_valid_upcast( # table's long field should cast to long on read written_arrow_table = tbl.scan().to_arrow() assert written_arrow_table == pyarrow_table_with_promoted_types.cast( - pa.schema(( - pa.field("long", pa.int64(), nullable=True), - pa.field("list", pa.large_list(pa.int64()), nullable=False), - pa.field("map", pa.map_(pa.large_string(), pa.int64()), nullable=False), - pa.field("double", pa.float64(), nullable=True), - pa.field("uuid", pa.binary(length=16), nullable=True), # can UUID is read as fixed length binary of length 16 - )) + pa.schema( + ( + pa.field("long", pa.int64(), nullable=True), + pa.field("list", pa.large_list(pa.int64()), nullable=False), + pa.field("map", pa.map_(pa.large_string(), pa.int64()), nullable=False), + pa.field("double", pa.float64(), nullable=True), + pa.field("uuid", pa.binary(length=16), nullable=True), # can UUID is read as fixed length binary of length 16 + ) + ) ) lhs = spark.table(f"{identifier}").toPandas() rhs = written_arrow_table.to_pandas() diff --git a/tests/integration/test_deletes.py b/tests/integration/test_deletes.py index f2417bde2d..ae03beea53 100644 --- a/tests/integration/test_deletes.py +++ b/tests/integration/test_deletes.py @@ -746,13 +746,15 @@ def test_delete_after_partition_evolution_from_partitioned(session_catalog: Rest arrow_table = pa.Table.from_arrays( [ pa.array([2, 3, 4, 5, 6]), - pa.array([ - datetime(2021, 5, 19), - datetime(2022, 7, 25), - datetime(2023, 3, 22), - datetime(2024, 7, 17), - datetime(2025, 2, 22), - ]), + pa.array( + [ + datetime(2021, 5, 19), + datetime(2022, 7, 25), + datetime(2023, 3, 22), + datetime(2024, 7, 17), + datetime(2025, 2, 22), + ] + ), ], names=["idx", "ts"], ) diff --git a/tests/integration/test_reads.py b/tests/integration/test_reads.py index 0279c2199a..8d13724087 100644 --- a/tests/integration/test_reads.py +++ b/tests/integration/test_reads.py @@ -833,12 +833,14 @@ def test_table_scan_default_to_large_types(catalog: Catalog) -> None: result_table = tbl.scan().to_arrow() - expected_schema = pa.schema([ - pa.field("string", pa.large_string()), - pa.field("string-to-binary", pa.large_binary()), - pa.field("binary", pa.large_binary()), - pa.field("list", pa.large_list(pa.large_string())), - ]) + expected_schema = pa.schema( + [ + pa.field("string", pa.large_string()), + pa.field("string-to-binary", pa.large_binary()), + pa.field("binary", pa.large_binary()), + pa.field("list", pa.large_list(pa.large_string())), + ] + ) assert result_table.schema.equals(expected_schema) @@ -874,12 +876,14 @@ def test_table_scan_override_with_small_types(catalog: Catalog) -> None: tbl.io.properties[PYARROW_USE_LARGE_TYPES_ON_READ] = "False" result_table = tbl.scan().to_arrow() - expected_schema = pa.schema([ - pa.field("string", pa.string()), - pa.field("string-to-binary", pa.binary()), - pa.field("binary", pa.binary()), - pa.field("list", pa.list_(pa.string())), - ]) + expected_schema = pa.schema( + [ + pa.field("string", pa.string()), + pa.field("string-to-binary", pa.binary()), + pa.field("binary", pa.binary()), + pa.field("list", pa.list_(pa.string())), + ] + ) assert result_table.schema.equals(expected_schema) diff --git a/tests/integration/test_rest_schema.py b/tests/integration/test_rest_schema.py index 8e64142b3f..6a704839e2 100644 --- a/tests/integration/test_rest_schema.py +++ b/tests/integration/test_rest_schema.py @@ -685,11 +685,13 @@ def test_rename_simple(simple_table: Table) -> None: ) # Check that the name mapping gets updated - assert simple_table.name_mapping() == NameMapping([ - MappedField(field_id=1, names=["foo", "vo"]), - MappedField(field_id=2, names=["bar", "var"]), - MappedField(field_id=3, names=["baz"]), - ]) + assert simple_table.name_mapping() == NameMapping( + [ + MappedField(field_id=1, names=["foo", "vo"]), + MappedField(field_id=2, names=["bar", "var"]), + MappedField(field_id=3, names=["baz"]), + ] + ) @pytest.mark.integration @@ -719,9 +721,11 @@ def test_rename_simple_nested(catalog: Catalog) -> None: ) # Check that the name mapping gets updated - assert tbl.name_mapping() == NameMapping([ - MappedField(field_id=1, names=["foo"], fields=[MappedField(field_id=2, names=["bar", "vo"])]), - ]) + assert tbl.name_mapping() == NameMapping( + [ + MappedField(field_id=1, names=["foo"], fields=[MappedField(field_id=2, names=["bar", "vo"])]), + ] + ) @pytest.mark.integration diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index f9c0afd3bc..c23e836554 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -324,20 +324,24 @@ def test_python_writes_special_character_column_with_spark_reads( {"street": "789", "city": "Random", "zip": 10112, column_name_with_special_character: "c"}, ], } - pa_schema = pa.schema([ - pa.field(column_name_with_special_character, pa.string()), - pa.field("id", pa.int32()), - pa.field("name", pa.string()), - pa.field( - "address", - pa.struct([ - pa.field("street", pa.string()), - pa.field("city", pa.string()), - pa.field("zip", pa.int32()), - pa.field(column_name_with_special_character, pa.string()), - ]), - ), - ]) + pa_schema = pa.schema( + [ + pa.field(column_name_with_special_character, pa.string()), + pa.field("id", pa.int32()), + pa.field("name", pa.string()), + pa.field( + "address", + pa.struct( + [ + pa.field("street", pa.string()), + pa.field("city", pa.string()), + pa.field("zip", pa.int32()), + pa.field(column_name_with_special_character, pa.string()), + ] + ), + ), + ] + ) arrow_table_with_special_character_column = pa.Table.from_pydict(TEST_DATA_WITH_SPECIAL_CHARACTER_COLUMN, schema=pa_schema) tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, schema=pa_schema) @@ -357,10 +361,12 @@ def test_python_writes_dictionary_encoded_column_with_spark_reads( "id": [1, 2, 3, 1, 1], "name": ["AB", "CD", "EF", "CD", "EF"], } - pa_schema = pa.schema([ - pa.field("id", pa.dictionary(pa.int32(), pa.int32(), False)), - pa.field("name", pa.dictionary(pa.int32(), pa.string(), False)), - ]) + pa_schema = pa.schema( + [ + pa.field("id", pa.dictionary(pa.int32(), pa.int32(), False)), + pa.field("name", pa.dictionary(pa.int32(), pa.string(), False)), + ] + ) arrow_table = pa.Table.from_pydict(TEST_DATA, schema=pa_schema) tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, schema=pa_schema) @@ -387,20 +393,24 @@ def test_python_writes_with_small_and_large_types_spark_reads( {"street": "789", "city": "Random", "zip": 10112, "bar": "c"}, ], } - pa_schema = pa.schema([ - pa.field("foo", pa.large_string()), - pa.field("id", pa.int32()), - pa.field("name", pa.string()), - pa.field( - "address", - pa.struct([ - pa.field("street", pa.string()), - pa.field("city", pa.string()), - pa.field("zip", pa.int32()), - pa.field("bar", pa.large_string()), - ]), - ), - ]) + pa_schema = pa.schema( + [ + pa.field("foo", pa.large_string()), + pa.field("id", pa.int32()), + pa.field("name", pa.string()), + pa.field( + "address", + pa.struct( + [ + pa.field("street", pa.string()), + pa.field("city", pa.string()), + pa.field("zip", pa.int32()), + pa.field("bar", pa.large_string()), + ] + ), + ), + ] + ) arrow_table = pa.Table.from_pydict(TEST_DATA, schema=pa_schema) tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, schema=pa_schema) @@ -409,20 +419,24 @@ def test_python_writes_with_small_and_large_types_spark_reads( pyiceberg_df = tbl.scan().to_pandas() assert spark_df.equals(pyiceberg_df) arrow_table_on_read = tbl.scan().to_arrow() - assert arrow_table_on_read.schema == pa.schema([ - pa.field("foo", pa.large_string()), - pa.field("id", pa.int32()), - pa.field("name", pa.large_string()), - pa.field( - "address", - pa.struct([ - pa.field("street", pa.large_string()), - pa.field("city", pa.large_string()), - pa.field("zip", pa.int32()), - pa.field("bar", pa.large_string()), - ]), - ), - ]) + assert arrow_table_on_read.schema == pa.schema( + [ + pa.field("foo", pa.large_string()), + pa.field("id", pa.int32()), + pa.field("name", pa.large_string()), + pa.field( + "address", + pa.struct( + [ + pa.field("street", pa.large_string()), + pa.field("city", pa.large_string()), + pa.field("zip", pa.int32()), + pa.field("bar", pa.large_string()), + ] + ), + ), + ] + ) @pytest.mark.integration @@ -718,10 +732,12 @@ def test_write_and_evolve(session_catalog: Catalog, format_version: int) -> None "foo": ["a", None, "z"], "bar": [19, None, 25], }, - schema=pa.schema([ - pa.field("foo", pa.string(), nullable=True), - pa.field("bar", pa.int32(), nullable=True), - ]), + schema=pa.schema( + [ + pa.field("foo", pa.string(), nullable=True), + pa.field("bar", pa.int32(), nullable=True), + ] + ), ) with tbl.transaction() as txn: @@ -761,10 +777,12 @@ def test_create_table_transaction(catalog: Catalog, format_version: int) -> None "foo": ["a", None, "z"], "bar": [19, None, 25], }, - schema=pa.schema([ - pa.field("foo", pa.string(), nullable=True), - pa.field("bar", pa.int32(), nullable=True), - ]), + schema=pa.schema( + [ + pa.field("foo", pa.string(), nullable=True), + pa.field("bar", pa.int32(), nullable=True), + ] + ), ) with catalog.create_table_transaction( @@ -810,9 +828,9 @@ def test_create_table_with_non_default_values(catalog: Catalog, table_schema_wit except NoSuchTableError: pass - iceberg_spec = PartitionSpec(*[ - PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="integer_partition") - ]) + iceberg_spec = PartitionSpec( + *[PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="integer_partition")] + ) sort_order = SortOrder(*[SortField(source_id=2, transform=IdentityTransform(), direction=SortDirection.ASC)]) @@ -1071,9 +1089,11 @@ def test_table_write_schema_with_valid_nullability_diff( table_schema = Schema( NestedField(field_id=1, name="long", field_type=LongType(), required=False), ) - other_schema = pa.schema(( - pa.field("long", pa.int64(), nullable=False), # can support writing required pyarrow field to optional Iceberg field - )) + other_schema = pa.schema( + ( + pa.field("long", pa.int64(), nullable=False), # can support writing required pyarrow field to optional Iceberg field + ) + ) arrow_table = pa.Table.from_pydict( { "long": [1, 9], @@ -1114,13 +1134,15 @@ def test_table_write_schema_with_valid_upcast( # table's long field should cast to long on read written_arrow_table = tbl.scan().to_arrow() assert written_arrow_table == pyarrow_table_with_promoted_types.cast( - pa.schema(( - pa.field("long", pa.int64(), nullable=True), - pa.field("list", pa.large_list(pa.int64()), nullable=False), - pa.field("map", pa.map_(pa.large_string(), pa.int64()), nullable=False), - pa.field("double", pa.float64(), nullable=True), # can support upcasting float to double - pa.field("uuid", pa.binary(length=16), nullable=True), # can UUID is read as fixed length binary of length 16 - )) + pa.schema( + ( + pa.field("long", pa.int64(), nullable=True), + pa.field("list", pa.large_list(pa.int64()), nullable=False), + pa.field("map", pa.map_(pa.large_string(), pa.int64()), nullable=False), + pa.field("double", pa.float64(), nullable=True), # can support upcasting float to double + pa.field("uuid", pa.binary(length=16), nullable=True), # can UUID is read as fixed length binary of length 16 + ) + ) ) lhs = spark.table(f"{identifier}").toPandas() rhs = written_arrow_table.to_pandas() @@ -1510,16 +1532,20 @@ def test_rewrite_manifest_after_partition_evolution(session_catalog: Catalog) -> def test_writing_null_structs(session_catalog: Catalog) -> None: import pyarrow as pa - schema = pa.schema([ - pa.field( - "struct_field_1", - pa.struct([ - pa.field("string_nested_1", pa.string()), - pa.field("int_item_2", pa.int32()), - pa.field("float_item_2", pa.float32()), - ]), - ), - ]) + schema = pa.schema( + [ + pa.field( + "struct_field_1", + pa.struct( + [ + pa.field("string_nested_1", pa.string()), + pa.field("int_item_2", pa.int32()), + pa.field("float_item_2", pa.float32()), + ] + ), + ), + ] + ) records = [ { diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index e4017e1df5..8bb97e150a 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -547,11 +547,13 @@ def test_binary_type_to_pyarrow() -> None: def test_struct_type_to_pyarrow(table_schema_simple: Schema) -> None: - expected = pa.struct([ - pa.field("foo", pa.large_string(), nullable=True, metadata={"field_id": "1"}), - pa.field("bar", pa.int32(), nullable=False, metadata={"field_id": "2"}), - pa.field("baz", pa.bool_(), nullable=True, metadata={"field_id": "3"}), - ]) + expected = pa.struct( + [ + pa.field("foo", pa.large_string(), nullable=True, metadata={"field_id": "1"}), + pa.field("bar", pa.int32(), nullable=False, metadata={"field_id": "2"}), + pa.field("baz", pa.bool_(), nullable=True, metadata={"field_id": "3"}), + ] + ) assert visit(table_schema_simple.as_struct(), _ConvertToArrowSchema()) == expected @@ -1771,11 +1773,13 @@ def test_bin_pack_arrow_table(arrow_table_with_null: pa.Table) -> None: def test_schema_mismatch_type(table_schema_simple: Schema) -> None: - other_schema = pa.schema(( - pa.field("foo", pa.string(), nullable=True), - pa.field("bar", pa.decimal128(18, 6), nullable=False), - pa.field("baz", pa.bool_(), nullable=True), - )) + other_schema = pa.schema( + ( + pa.field("foo", pa.string(), nullable=True), + pa.field("bar", pa.decimal128(18, 6), nullable=False), + pa.field("baz", pa.bool_(), nullable=True), + ) + ) expected = r"""Mismatch in fields: ┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ @@ -1792,11 +1796,13 @@ def test_schema_mismatch_type(table_schema_simple: Schema) -> None: def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None: - other_schema = pa.schema(( - pa.field("foo", pa.string(), nullable=True), - pa.field("bar", pa.int32(), nullable=True), - pa.field("baz", pa.bool_(), nullable=True), - )) + other_schema = pa.schema( + ( + pa.field("foo", pa.string(), nullable=True), + pa.field("bar", pa.int32(), nullable=True), + pa.field("baz", pa.bool_(), nullable=True), + ) + ) expected = """Mismatch in fields: ┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓ @@ -1813,11 +1819,13 @@ def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None: def test_schema_compatible_nullability_diff(table_schema_simple: Schema) -> None: - other_schema = pa.schema(( - pa.field("foo", pa.string(), nullable=True), - pa.field("bar", pa.int32(), nullable=False), - pa.field("baz", pa.bool_(), nullable=False), - )) + other_schema = pa.schema( + ( + pa.field("foo", pa.string(), nullable=True), + pa.field("bar", pa.int32(), nullable=False), + pa.field("baz", pa.bool_(), nullable=False), + ) + ) try: _check_pyarrow_schema_compatible(table_schema_simple, other_schema) @@ -1826,10 +1834,12 @@ def test_schema_compatible_nullability_diff(table_schema_simple: Schema) -> None def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None: - other_schema = pa.schema(( - pa.field("foo", pa.string(), nullable=True), - pa.field("baz", pa.bool_(), nullable=True), - )) + other_schema = pa.schema( + ( + pa.field("foo", pa.string(), nullable=True), + pa.field("baz", pa.bool_(), nullable=True), + ) + ) expected = """Mismatch in fields: ┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓ @@ -1851,9 +1861,11 @@ def test_schema_compatible_missing_nullable_field_nested(table_schema_nested: Sc 6, pa.field( "person", - pa.struct([ - pa.field("age", pa.int32(), nullable=False), - ]), + pa.struct( + [ + pa.field("age", pa.int32(), nullable=False), + ] + ), nullable=True, ), ) @@ -1869,9 +1881,11 @@ def test_schema_mismatch_missing_required_field_nested(table_schema_nested: Sche 6, pa.field( "person", - pa.struct([ - pa.field("name", pa.string(), nullable=True), - ]), + pa.struct( + [ + pa.field("name", pa.string(), nullable=True), + ] + ), nullable=True, ), ) @@ -1920,12 +1934,14 @@ def test_schema_compatible_nested(table_schema_nested: Schema) -> None: def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None: - other_schema = pa.schema(( - pa.field("foo", pa.string(), nullable=True), - pa.field("bar", pa.int32(), nullable=False), - pa.field("baz", pa.bool_(), nullable=True), - pa.field("new_field", pa.date32(), nullable=True), - )) + other_schema = pa.schema( + ( + pa.field("foo", pa.string(), nullable=True), + pa.field("bar", pa.int32(), nullable=False), + pa.field("baz", pa.bool_(), nullable=True), + pa.field("new_field", pa.date32(), nullable=True), + ) + ) with pytest.raises( ValueError, match=r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)." @@ -1942,10 +1958,12 @@ def test_schema_compatible(table_schema_simple: Schema) -> None: def test_schema_projection(table_schema_simple: Schema) -> None: # remove optional `baz` field from `table_schema_simple` - other_schema = pa.schema(( - pa.field("foo", pa.string(), nullable=True), - pa.field("bar", pa.int32(), nullable=False), - )) + other_schema = pa.schema( + ( + pa.field("foo", pa.string(), nullable=True), + pa.field("bar", pa.int32(), nullable=False), + ) + ) try: _check_pyarrow_schema_compatible(table_schema_simple, other_schema) except Exception: @@ -1954,11 +1972,13 @@ def test_schema_projection(table_schema_simple: Schema) -> None: def test_schema_downcast(table_schema_simple: Schema) -> None: # large_string type is compatible with string type - other_schema = pa.schema(( - pa.field("foo", pa.large_string(), nullable=True), - pa.field("bar", pa.int32(), nullable=False), - pa.field("baz", pa.bool_(), nullable=True), - )) + other_schema = pa.schema( + ( + pa.field("foo", pa.large_string(), nullable=True), + pa.field("bar", pa.int32(), nullable=False), + pa.field("baz", pa.bool_(), nullable=True), + ) + ) try: _check_pyarrow_schema_compatible(table_schema_simple, other_schema) @@ -2037,11 +2057,13 @@ def test_identity_partition_on_multi_columns() -> None: assert {table_partition.partition_key.partition for table_partition in result} == expected concatenated_arrow_table = pa.concat_tables([table_partition.arrow_table_partition for table_partition in result]) assert concatenated_arrow_table.num_rows == arrow_table.num_rows - assert concatenated_arrow_table.sort_by([ - ("born_year", "ascending"), - ("n_legs", "ascending"), - ("animal", "ascending"), - ]) == arrow_table.sort_by([("born_year", "ascending"), ("n_legs", "ascending"), ("animal", "ascending")]) + assert concatenated_arrow_table.sort_by( + [ + ("born_year", "ascending"), + ("n_legs", "ascending"), + ("animal", "ascending"), + ] + ) == arrow_table.sort_by([("born_year", "ascending"), ("n_legs", "ascending"), ("animal", "ascending")]) def test__to_requested_schema_timestamps( diff --git a/tests/io/test_pyarrow_visitor.py b/tests/io/test_pyarrow_visitor.py index 9e6df720c6..027fccae7c 100644 --- a/tests/io/test_pyarrow_visitor.py +++ b/tests/io/test_pyarrow_visitor.py @@ -239,11 +239,13 @@ def test_pyarrow_variable_binary_to_iceberg() -> None: def test_pyarrow_struct_to_iceberg() -> None: - pyarrow_struct = pa.struct([ - pa.field("foo", pa.string(), nullable=True, metadata={"PARQUET:field_id": "1", "doc": "foo doc"}), - pa.field("bar", pa.int32(), nullable=False, metadata={"PARQUET:field_id": "2"}), - pa.field("baz", pa.bool_(), nullable=True, metadata={"PARQUET:field_id": "3"}), - ]) + pyarrow_struct = pa.struct( + [ + pa.field("foo", pa.string(), nullable=True, metadata={"PARQUET:field_id": "1", "doc": "foo doc"}), + pa.field("bar", pa.int32(), nullable=False, metadata={"PARQUET:field_id": "2"}), + pa.field("baz", pa.bool_(), nullable=True, metadata={"PARQUET:field_id": "3"}), + ] + ) expected = StructType( NestedField(field_id=1, name="foo", field_type=StringType(), required=False, doc="foo doc"), NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), @@ -344,84 +346,94 @@ def test_round_schema_large_string() -> None: def test_simple_schema_has_missing_ids() -> None: - schema = pa.schema([ - pa.field("foo", pa.string(), nullable=False), - ]) + schema = pa.schema( + [ + pa.field("foo", pa.string(), nullable=False), + ] + ) visitor = _HasIds() has_ids = visit_pyarrow(schema, visitor) assert not has_ids def test_simple_schema_has_missing_ids_partial() -> None: - schema = pa.schema([ - pa.field("foo", pa.string(), nullable=False, metadata={"PARQUET:field_id": "1", "doc": "foo doc"}), - pa.field("bar", pa.int32(), nullable=False), - ]) + schema = pa.schema( + [ + pa.field("foo", pa.string(), nullable=False, metadata={"PARQUET:field_id": "1", "doc": "foo doc"}), + pa.field("bar", pa.int32(), nullable=False), + ] + ) visitor = _HasIds() has_ids = visit_pyarrow(schema, visitor) assert not has_ids def test_nested_schema_has_missing_ids() -> None: - schema = pa.schema([ - pa.field("foo", pa.string(), nullable=False), - pa.field( - "quux", - pa.map_( - pa.string(), - pa.map_(pa.string(), pa.int32()), + schema = pa.schema( + [ + pa.field("foo", pa.string(), nullable=False), + pa.field( + "quux", + pa.map_( + pa.string(), + pa.map_(pa.string(), pa.int32()), + ), + nullable=False, ), - nullable=False, - ), - ]) + ] + ) visitor = _HasIds() has_ids = visit_pyarrow(schema, visitor) assert not has_ids def test_nested_schema_has_ids() -> None: - schema = pa.schema([ - pa.field("foo", pa.string(), nullable=False, metadata={"PARQUET:field_id": "1", "doc": "foo doc"}), - pa.field( - "quux", - pa.map_( - pa.field("key", pa.string(), nullable=False, metadata={"PARQUET:field_id": "7"}), - pa.field( - "value", - pa.map_( - pa.field("key", pa.string(), nullable=False, metadata={"PARQUET:field_id": "9"}), - pa.field("value", pa.int32(), metadata={"PARQUET:field_id": "10"}), + schema = pa.schema( + [ + pa.field("foo", pa.string(), nullable=False, metadata={"PARQUET:field_id": "1", "doc": "foo doc"}), + pa.field( + "quux", + pa.map_( + pa.field("key", pa.string(), nullable=False, metadata={"PARQUET:field_id": "7"}), + pa.field( + "value", + pa.map_( + pa.field("key", pa.string(), nullable=False, metadata={"PARQUET:field_id": "9"}), + pa.field("value", pa.int32(), metadata={"PARQUET:field_id": "10"}), + ), + nullable=False, + metadata={"PARQUET:field_id": "8"}, ), - nullable=False, - metadata={"PARQUET:field_id": "8"}, ), + nullable=False, + metadata={"PARQUET:field_id": "6", "doc": "quux doc"}, ), - nullable=False, - metadata={"PARQUET:field_id": "6", "doc": "quux doc"}, - ), - ]) + ] + ) visitor = _HasIds() has_ids = visit_pyarrow(schema, visitor) assert has_ids def test_nested_schema_has_partial_missing_ids() -> None: - schema = pa.schema([ - pa.field("foo", pa.string(), nullable=False, metadata={"PARQUET:field_id": "1", "doc": "foo doc"}), - pa.field( - "quux", - pa.map_( - pa.field("key", pa.string(), nullable=False, metadata={"PARQUET:field_id": "7"}), - pa.field( - "value", - pa.map_(pa.field("key", pa.string(), nullable=False), pa.field("value", pa.int32())), - nullable=False, + schema = pa.schema( + [ + pa.field("foo", pa.string(), nullable=False, metadata={"PARQUET:field_id": "1", "doc": "foo doc"}), + pa.field( + "quux", + pa.map_( + pa.field("key", pa.string(), nullable=False, metadata={"PARQUET:field_id": "7"}), + pa.field( + "value", + pa.map_(pa.field("key", pa.string(), nullable=False), pa.field("value", pa.int32())), + nullable=False, + ), ), + nullable=False, + metadata={"PARQUET:field_id": "6", "doc": "quux doc"}, ), - nullable=False, - metadata={"PARQUET:field_id": "6", "doc": "quux doc"}, - ), - ]) + ] + ) visitor = _HasIds() has_ids = visit_pyarrow(schema, visitor) assert not has_ids @@ -441,11 +453,13 @@ def test_simple_pyarrow_schema_to_schema_missing_ids_using_name_mapping( pyarrow_schema_simple_without_ids: pa.Schema, iceberg_schema_simple: Schema ) -> None: schema = pyarrow_schema_simple_without_ids - name_mapping = NameMapping([ - MappedField(field_id=1, names=["foo"]), - MappedField(field_id=2, names=["bar"]), - MappedField(field_id=3, names=["baz"]), - ]) + name_mapping = NameMapping( + [ + MappedField(field_id=1, names=["foo"]), + MappedField(field_id=2, names=["bar"]), + MappedField(field_id=3, names=["baz"]), + ] + ) assert pyarrow_to_schema(schema, name_mapping) == iceberg_schema_simple @@ -454,9 +468,11 @@ def test_simple_pyarrow_schema_to_schema_missing_ids_using_name_mapping_partial_ pyarrow_schema_simple_without_ids: pa.Schema, ) -> None: schema = pyarrow_schema_simple_without_ids - name_mapping = NameMapping([ - MappedField(field_id=1, names=["foo"]), - ]) + name_mapping = NameMapping( + [ + MappedField(field_id=1, names=["foo"]), + ] + ) with pytest.raises(ValueError) as exc_info: _ = pyarrow_to_schema(schema, name_mapping) assert "Could not find field with name: bar" in str(exc_info.value) @@ -467,83 +483,89 @@ def test_nested_pyarrow_schema_to_schema_missing_ids_using_name_mapping( ) -> None: schema = pyarrow_schema_nested_without_ids - name_mapping = NameMapping([ - MappedField(field_id=1, names=["foo"]), - MappedField(field_id=2, names=["bar"]), - MappedField(field_id=3, names=["baz"]), - MappedField(field_id=4, names=["qux"], fields=[MappedField(field_id=5, names=["element"])]), - MappedField( - field_id=6, - names=["quux"], - fields=[ - MappedField(field_id=7, names=["key"]), - MappedField( - field_id=8, - names=["value"], - fields=[ - MappedField(field_id=9, names=["key"]), - MappedField(field_id=10, names=["value"]), - ], - ), - ], - ), - MappedField( - field_id=11, - names=["location"], - fields=[ - MappedField( - field_id=12, - names=["element"], - fields=[ - MappedField(field_id=13, names=["latitude"]), - MappedField(field_id=14, names=["longitude"]), - ], - ) - ], - ), - MappedField( - field_id=15, - names=["person"], - fields=[ - MappedField(field_id=16, names=["name"]), - MappedField(field_id=17, names=["age"]), - ], - ), - ]) + name_mapping = NameMapping( + [ + MappedField(field_id=1, names=["foo"]), + MappedField(field_id=2, names=["bar"]), + MappedField(field_id=3, names=["baz"]), + MappedField(field_id=4, names=["qux"], fields=[MappedField(field_id=5, names=["element"])]), + MappedField( + field_id=6, + names=["quux"], + fields=[ + MappedField(field_id=7, names=["key"]), + MappedField( + field_id=8, + names=["value"], + fields=[ + MappedField(field_id=9, names=["key"]), + MappedField(field_id=10, names=["value"]), + ], + ), + ], + ), + MappedField( + field_id=11, + names=["location"], + fields=[ + MappedField( + field_id=12, + names=["element"], + fields=[ + MappedField(field_id=13, names=["latitude"]), + MappedField(field_id=14, names=["longitude"]), + ], + ) + ], + ), + MappedField( + field_id=15, + names=["person"], + fields=[ + MappedField(field_id=16, names=["name"]), + MappedField(field_id=17, names=["age"]), + ], + ), + ] + ) assert pyarrow_to_schema(schema, name_mapping) == iceberg_schema_nested def test_pyarrow_schema_to_schema_missing_ids_using_name_mapping_nested_missing_id() -> None: - schema = pa.schema([ - pa.field("foo", pa.string(), nullable=False), - pa.field( - "quux", - pa.map_( - pa.string(), - pa.map_(pa.string(), pa.int32()), - ), - nullable=False, - ), - ]) - - name_mapping = NameMapping([ - MappedField(field_id=1, names=["foo"]), - MappedField( - field_id=6, - names=["quux"], - fields=[ - MappedField(field_id=7, names=["key"]), - MappedField( - field_id=8, - names=["value"], - fields=[ - MappedField(field_id=10, names=["value"]), - ], + schema = pa.schema( + [ + pa.field("foo", pa.string(), nullable=False), + pa.field( + "quux", + pa.map_( + pa.string(), + pa.map_(pa.string(), pa.int32()), ), - ], - ), - ]) + nullable=False, + ), + ] + ) + + name_mapping = NameMapping( + [ + MappedField(field_id=1, names=["foo"]), + MappedField( + field_id=6, + names=["quux"], + fields=[ + MappedField(field_id=7, names=["key"]), + MappedField( + field_id=8, + names=["value"], + fields=[ + MappedField(field_id=10, names=["value"]), + ], + ), + ], + ), + ] + ) with pytest.raises(ValueError) as exc_info: _ = pyarrow_to_schema(schema, name_mapping) assert "Could not find field with name: quux.value.key" in str(exc_info.value) @@ -562,38 +584,44 @@ def test_pyarrow_schema_to_schema_fresh_ids_nested_schema( def test_pyarrow_schema_ensure_large_types(pyarrow_schema_nested_without_ids: pa.Schema) -> None: - expected_schema = pa.schema([ - pa.field("foo", pa.large_string(), nullable=False), - pa.field("bar", pa.int32(), nullable=False), - pa.field("baz", pa.bool_(), nullable=True), - pa.field("qux", pa.large_list(pa.large_string()), nullable=False), - pa.field( - "quux", - pa.map_( - pa.large_string(), - pa.map_(pa.large_string(), pa.int32()), + expected_schema = pa.schema( + [ + pa.field("foo", pa.large_string(), nullable=False), + pa.field("bar", pa.int32(), nullable=False), + pa.field("baz", pa.bool_(), nullable=True), + pa.field("qux", pa.large_list(pa.large_string()), nullable=False), + pa.field( + "quux", + pa.map_( + pa.large_string(), + pa.map_(pa.large_string(), pa.int32()), + ), + nullable=False, ), - nullable=False, - ), - pa.field( - "location", - pa.large_list( - pa.struct([ - pa.field("latitude", pa.float32(), nullable=False), - pa.field("longitude", pa.float32(), nullable=False), - ]), + pa.field( + "location", + pa.large_list( + pa.struct( + [ + pa.field("latitude", pa.float32(), nullable=False), + pa.field("longitude", pa.float32(), nullable=False), + ] + ), + ), + nullable=False, ), - nullable=False, - ), - pa.field( - "person", - pa.struct([ - pa.field("name", pa.large_string(), nullable=True), - pa.field("age", pa.int32(), nullable=False), - ]), - nullable=True, - ), - ]) + pa.field( + "person", + pa.struct( + [ + pa.field("name", pa.large_string(), nullable=True), + pa.field("age", pa.int32(), nullable=False), + ] + ), + nullable=True, + ), + ] + ) assert _pyarrow_schema_ensure_large_types(pyarrow_schema_nested_without_ids) == expected_schema diff --git a/tests/table/test_init.py b/tests/table/test_init.py index 397fa9f537..bcb2d643dc 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -538,15 +538,15 @@ def test_update_column(table_v1: Table, table_v2: Table) -> None: assert new_schema3.find_field("z").required is False, "failed to update existing field required" # assert the above two updates also works with union_by_name - assert table.update_schema().union_by_name(new_schema)._apply() == new_schema, ( - "failed to update existing field doc with union_by_name" - ) - assert table.update_schema().union_by_name(new_schema2)._apply() == new_schema2, ( - "failed to remove existing field doc with union_by_name" - ) - assert table.update_schema().union_by_name(new_schema3)._apply() == new_schema3, ( - "failed to update existing field required with union_by_name" - ) + assert ( + table.update_schema().union_by_name(new_schema)._apply() == new_schema + ), "failed to update existing field doc with union_by_name" + assert ( + table.update_schema().union_by_name(new_schema2)._apply() == new_schema2 + ), "failed to remove existing field doc with union_by_name" + assert ( + table.update_schema().union_by_name(new_schema3)._apply() == new_schema3 + ), "failed to update existing field required with union_by_name" def test_add_primitive_type_column(table_v2: Table) -> None: @@ -1077,52 +1077,56 @@ def test_assert_default_sort_order_id(table_v2: Table) -> None: def test_correct_schema() -> None: - table_metadata = TableMetadataV2(**{ - "format-version": 2, - "table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1", - "location": "s3://bucket/test/location", - "last-sequence-number": 34, - "last-updated-ms": 1602638573590, - "last-column-id": 3, - "current-schema-id": 1, - "schemas": [ - {"type": "struct", "schema-id": 0, "fields": [{"id": 1, "name": "x", "required": True, "type": "long"}]}, - { - "type": "struct", - "schema-id": 1, - "identifier-field-ids": [1, 2], - "fields": [ - {"id": 1, "name": "x", "required": True, "type": "long"}, - {"id": 2, "name": "y", "required": True, "type": "long"}, - {"id": 3, "name": "z", "required": True, "type": "long"}, - ], - }, - ], - "default-spec-id": 0, - "partition-specs": [{"spec-id": 0, "fields": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}]}], - "last-partition-id": 1000, - "default-sort-order-id": 0, - "sort-orders": [], - "current-snapshot-id": 123, - "snapshots": [ - { - "snapshot-id": 234, - "timestamp-ms": 1515100955770, - "sequence-number": 0, - "summary": {"operation": "append"}, - "manifest-list": "s3://a/b/1.avro", - "schema-id": 10, - }, - { - "snapshot-id": 123, - "timestamp-ms": 1515100955770, - "sequence-number": 0, - "summary": {"operation": "append"}, - "manifest-list": "s3://a/b/1.avro", - "schema-id": 0, - }, - ], - }) + table_metadata = TableMetadataV2( + **{ + "format-version": 2, + "table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1", + "location": "s3://bucket/test/location", + "last-sequence-number": 34, + "last-updated-ms": 1602638573590, + "last-column-id": 3, + "current-schema-id": 1, + "schemas": [ + {"type": "struct", "schema-id": 0, "fields": [{"id": 1, "name": "x", "required": True, "type": "long"}]}, + { + "type": "struct", + "schema-id": 1, + "identifier-field-ids": [1, 2], + "fields": [ + {"id": 1, "name": "x", "required": True, "type": "long"}, + {"id": 2, "name": "y", "required": True, "type": "long"}, + {"id": 3, "name": "z", "required": True, "type": "long"}, + ], + }, + ], + "default-spec-id": 0, + "partition-specs": [ + {"spec-id": 0, "fields": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}]} + ], + "last-partition-id": 1000, + "default-sort-order-id": 0, + "sort-orders": [], + "current-snapshot-id": 123, + "snapshots": [ + { + "snapshot-id": 234, + "timestamp-ms": 1515100955770, + "sequence-number": 0, + "summary": {"operation": "append"}, + "manifest-list": "s3://a/b/1.avro", + "schema-id": 10, + }, + { + "snapshot-id": 123, + "timestamp-ms": 1515100955770, + "sequence-number": 0, + "summary": {"operation": "append"}, + "manifest-list": "s3://a/b/1.avro", + "schema-id": 0, + }, + ], + } + ) t = Table( identifier=("default", "t1"), diff --git a/tests/table/test_name_mapping.py b/tests/table/test_name_mapping.py index bd271f59f8..c567f3ffb4 100644 --- a/tests/table/test_name_mapping.py +++ b/tests/table/test_name_mapping.py @@ -30,49 +30,51 @@ @pytest.fixture(scope="session") def table_name_mapping_nested() -> NameMapping: - return NameMapping([ - MappedField(field_id=1, names=["foo"]), - MappedField(field_id=2, names=["bar"]), - MappedField(field_id=3, names=["baz"]), - MappedField(field_id=4, names=["qux"], fields=[MappedField(field_id=5, names=["element"])]), - MappedField( - field_id=6, - names=["quux"], - fields=[ - MappedField(field_id=7, names=["key"]), - MappedField( - field_id=8, - names=["value"], - fields=[ - MappedField(field_id=9, names=["key"]), - MappedField(field_id=10, names=["value"]), - ], - ), - ], - ), - MappedField( - field_id=11, - names=["location"], - fields=[ - MappedField( - field_id=12, - names=["element"], - fields=[ - MappedField(field_id=13, names=["latitude"]), - MappedField(field_id=14, names=["longitude"]), - ], - ) - ], - ), - MappedField( - field_id=15, - names=["person"], - fields=[ - MappedField(field_id=16, names=["name"]), - MappedField(field_id=17, names=["age"]), - ], - ), - ]) + return NameMapping( + [ + MappedField(field_id=1, names=["foo"]), + MappedField(field_id=2, names=["bar"]), + MappedField(field_id=3, names=["baz"]), + MappedField(field_id=4, names=["qux"], fields=[MappedField(field_id=5, names=["element"])]), + MappedField( + field_id=6, + names=["quux"], + fields=[ + MappedField(field_id=7, names=["key"]), + MappedField( + field_id=8, + names=["value"], + fields=[ + MappedField(field_id=9, names=["key"]), + MappedField(field_id=10, names=["value"]), + ], + ), + ], + ), + MappedField( + field_id=11, + names=["location"], + fields=[ + MappedField( + field_id=12, + names=["element"], + fields=[ + MappedField(field_id=13, names=["latitude"]), + MappedField(field_id=14, names=["longitude"]), + ], + ) + ], + ), + MappedField( + field_id=15, + names=["person"], + fields=[ + MappedField(field_id=16, names=["name"]), + MappedField(field_id=17, names=["age"]), + ], + ), + ] + ) def test_json_mapped_field_deserialization() -> None: @@ -165,26 +167,30 @@ def test_json_name_mapping_deserialization() -> None: ] """ - assert parse_mapping_from_json(name_mapping) == NameMapping([ - MappedField(field_id=1, names=["id", "record_id"]), - MappedField(field_id=2, names=["data"]), - MappedField( - names=["location"], - field_id=3, - fields=[ - MappedField(field_id=4, names=["latitude", "lat"]), - MappedField(field_id=5, names=["longitude", "long"]), - ], - ), - ]) + assert parse_mapping_from_json(name_mapping) == NameMapping( + [ + MappedField(field_id=1, names=["id", "record_id"]), + MappedField(field_id=2, names=["data"]), + MappedField( + names=["location"], + field_id=3, + fields=[ + MappedField(field_id=4, names=["latitude", "lat"]), + MappedField(field_id=5, names=["longitude", "long"]), + ], + ), + ] + ) def test_json_mapped_field_no_field_id_serialization() -> None: - table_name_mapping_nested_no_field_id = NameMapping([ - MappedField(field_id=1, names=["foo"]), - MappedField(field_id=None, names=["bar"]), - MappedField(field_id=2, names=["qux"], fields=[MappedField(field_id=None, names=["element"])]), - ]) + table_name_mapping_nested_no_field_id = NameMapping( + [ + MappedField(field_id=1, names=["foo"]), + MappedField(field_id=None, names=["bar"]), + MappedField(field_id=2, names=["qux"], fields=[MappedField(field_id=None, names=["element"])]), + ] + ) assert ( table_name_mapping_nested_no_field_id.model_dump_json() @@ -200,18 +206,20 @@ def test_json_serialization(table_name_mapping_nested: NameMapping) -> None: def test_name_mapping_to_string() -> None: - nm = NameMapping([ - MappedField(field_id=1, names=["id", "record_id"]), - MappedField(field_id=2, names=["data"]), - MappedField( - names=["location"], - field_id=3, - fields=[ - MappedField(field_id=4, names=["lat", "latitude"]), - MappedField(field_id=5, names=["long", "longitude"]), - ], - ), - ]) + nm = NameMapping( + [ + MappedField(field_id=1, names=["id", "record_id"]), + MappedField(field_id=2, names=["data"]), + MappedField( + names=["location"], + field_id=3, + fields=[ + MappedField(field_id=4, names=["lat", "latitude"]), + MappedField(field_id=5, names=["long", "longitude"]), + ], + ), + ] + ) assert ( str(nm) @@ -294,51 +302,53 @@ def test_update_mapping(table_name_mapping_nested: NameMapping) -> None: 15: [NestedField(19, "name", StringType(), True), NestedField(20, "add_20", StringType(), True)], } - expected = NameMapping([ - MappedField(field_id=1, names=["foo", "foo_update"]), - MappedField(field_id=2, names=["bar"]), - MappedField(field_id=3, names=["baz"]), - MappedField(field_id=4, names=["qux"], fields=[MappedField(field_id=5, names=["element"])]), - MappedField( - field_id=6, - names=["quux"], - fields=[ - MappedField(field_id=7, names=["key"]), - MappedField( - field_id=8, - names=["value"], - fields=[ - MappedField(field_id=9, names=["key"]), - MappedField(field_id=10, names=["value"]), - ], - ), - ], - ), - MappedField( - field_id=11, - names=["location"], - fields=[ - MappedField( - field_id=12, - names=["element"], - fields=[ - MappedField(field_id=13, names=["latitude"]), - MappedField(field_id=14, names=["longitude"]), - ], - ) - ], - ), - MappedField( - field_id=15, - names=["person"], - fields=[ - MappedField(field_id=17, names=["age"]), - MappedField(field_id=19, names=["name"]), - MappedField(field_id=20, names=["add_20"]), - ], - ), - MappedField(field_id=18, names=["add_18"]), - ]) + expected = NameMapping( + [ + MappedField(field_id=1, names=["foo", "foo_update"]), + MappedField(field_id=2, names=["bar"]), + MappedField(field_id=3, names=["baz"]), + MappedField(field_id=4, names=["qux"], fields=[MappedField(field_id=5, names=["element"])]), + MappedField( + field_id=6, + names=["quux"], + fields=[ + MappedField(field_id=7, names=["key"]), + MappedField( + field_id=8, + names=["value"], + fields=[ + MappedField(field_id=9, names=["key"]), + MappedField(field_id=10, names=["value"]), + ], + ), + ], + ), + MappedField( + field_id=11, + names=["location"], + fields=[ + MappedField( + field_id=12, + names=["element"], + fields=[ + MappedField(field_id=13, names=["latitude"]), + MappedField(field_id=14, names=["longitude"]), + ], + ) + ], + ), + MappedField( + field_id=15, + names=["person"], + fields=[ + MappedField(field_id=17, names=["age"]), + MappedField(field_id=19, names=["name"]), + MappedField(field_id=20, names=["add_20"]), + ], + ), + MappedField(field_id=18, names=["add_18"]), + ] + ) assert update_mapping(table_name_mapping_nested, updates, adds) == expected diff --git a/tests/test_schema.py b/tests/test_schema.py index d1fc19df77..daa46dee1f 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1618,11 +1618,13 @@ def test_append_nested_lists() -> None: def test_union_with_pa_schema(primitive_fields: NestedField) -> None: base_schema = Schema(NestedField(field_id=1, name="foo", field_type=StringType(), required=True)) - pa_schema = pa.schema([ - pa.field("foo", pa.string(), nullable=False), - pa.field("bar", pa.int32(), nullable=True), - pa.field("baz", pa.bool_(), nullable=True), - ]) + pa_schema = pa.schema( + [ + pa.field("foo", pa.string(), nullable=False), + pa.field("bar", pa.int32(), nullable=True), + pa.field("baz", pa.bool_(), nullable=True), + ] + ) new_schema = UpdateSchema(transaction=None, schema=base_schema).union_by_name(pa_schema)._apply() # type: ignore @@ -1642,10 +1644,12 @@ def test_arrow_schema() -> None: NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), ) - expected_schema = pa.schema([ - pa.field("foo", pa.large_string(), nullable=False), - pa.field("bar", pa.int32(), nullable=True), - pa.field("baz", pa.bool_(), nullable=True), - ]) + expected_schema = pa.schema( + [ + pa.field("foo", pa.large_string(), nullable=False), + pa.field("bar", pa.int32(), nullable=True), + pa.field("baz", pa.bool_(), nullable=True), + ] + ) assert base_schema.as_arrow() == expected_schema diff --git a/tests/utils/test_manifest.py b/tests/utils/test_manifest.py index 154671c92e..3b1fc6f013 100644 --- a/tests/utils/test_manifest.py +++ b/tests/utils/test_manifest.py @@ -621,9 +621,9 @@ def test_write_manifest_list( def test_file_format_case_insensitive(raw_file_format: str, expected_file_format: FileFormat) -> None: if expected_file_format: parsed_file_format = FileFormat(raw_file_format) - assert parsed_file_format == expected_file_format, ( - f"File format {raw_file_format}: {parsed_file_format} != {expected_file_format}" - ) + assert ( + parsed_file_format == expected_file_format + ), f"File format {raw_file_format}: {parsed_file_format} != {expected_file_format}" else: with pytest.raises(ValueError): _ = FileFormat(raw_file_format)