@@ -412,6 +412,12 @@ def test_dynamic_partition_overwrite_unpartitioned_evolve_to_identity_transform(
     spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, part_col: str, format_version: int
 ) -> None:
     identifier = f"default.unpartitioned_table_v{format_version}_evolve_into_identity_transformed_partition_field_{part_col}"
+
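+    # Drop any table left over from a previous run so the test starts from a clean slate.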
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
     tbl = session_catalog.create_table(
         identifier=identifier,
         schema=TABLE_SCHEMA,
@@ -756,6 +762,55 @@ def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog) -> Non
         tbl.append("not a df")
 
 
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "spec",
+    [
+        (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=TruncateTransform(2), name="int_trunc"))),
+        (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=TruncateTransform(2), name="long_trunc"))),
+        (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=TruncateTransform(2), name="string_trunc"))),
+    ],
+)
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_truncate_transform(
+    spec: PartitionSpec,
+    spark: SparkSession,
+    session_catalog: Catalog,
+    arrow_table_with_null: pa.Table,
+    format_version: int,
+) -> None:
+    identifier = "default.truncate_transform"
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    tbl = _create_table(
+        session_catalog=session_catalog,
+        identifier=identifier,
+        properties={"format-version": str(format_version)},
+        data=[arrow_table_with_null],
+        partition_spec=spec,
+    )
+
+    assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
+    df = spark.table(identifier)
+    assert df.count() == 3, f"Expected 3 total rows for {identifier}"
+    for col in arrow_table_with_null.column_names:
+        assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}"
+        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"
+
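+    # 3 partitions expected: each of the two non-null rows truncates to its own prefix, plus one null partition.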
+    assert tbl.inspect.partitions().num_rows == 3
+    files_df = spark.sql(
+        f"""
+            SELECT *
+            FROM {identifier}.files
+        """
+    )
+    assert files_df.count() == 3
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize(
     "spec",
@@ -767,18 +822,52 @@ def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog) -> Non
                 PartitionField(source_id=1, field_id=1002, transform=IdentityTransform(), name="bool"),
             )
         ),
-        # none of non-identity is supported
-        (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket"))),
-        (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=BucketTransform(2), name="long_bucket"))),
-        (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=BucketTransform(2), name="date_bucket"))),
-        (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=BucketTransform(2), name="timestamp_bucket"))),
-        (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=BucketTransform(2), name="timestamptz_bucket"))),
-        (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=BucketTransform(2), name="string_bucket"))),
-        (PartitionSpec(PartitionField(source_id=12, field_id=1001, transform=BucketTransform(2), name="fixed_bucket"))),
-        (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=BucketTransform(2), name="binary_bucket"))),
-        (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=TruncateTransform(2), name="int_trunc"))),
-        (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=TruncateTransform(2), name="long_trunc"))),
-        (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=TruncateTransform(2), name="string_trunc"))),
+    ],
+)
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_identity_and_bucket_transform_spec(
+    spec: PartitionSpec,
+    spark: SparkSession,
+    session_catalog: Catalog,
+    arrow_table_with_null: pa.Table,
+    format_version: int,
+) -> None:
+    identifier = "default.identity_and_bucket_transform"
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    tbl = _create_table(
+        session_catalog=session_catalog,
+        identifier=identifier,
+        properties={"format-version": str(format_version)},
+        data=[arrow_table_with_null],
+        partition_spec=spec,
+    )
+
+    assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
+    df = spark.table(identifier)
+    assert df.count() == 3, f"Expected 3 total rows for {identifier}"
+    for col in arrow_table_with_null.column_names:
+        assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}"
+        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"
+
+    assert tbl.inspect.partitions().num_rows == 3
+    files_df = spark.sql(
+        f"""
+            SELECT *
+            FROM {identifier}.files
+        """
+    )
+    assert files_df.count() == 3
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "spec",
+    [
         (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=TruncateTransform(2), name="binary_trunc"))),
     ],
 )
@@ -801,11 +890,66 @@ def test_unsupported_transform(
 
     with pytest.raises(
         ValueError,
-        match="Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: *",
+        match="FeatureUnsupported => Unsupported data type for truncate transform: LargeBinary",
     ):
         tbl.append(arrow_table_with_null)
 
 
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "spec, expected_rows",
+    [
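+        # expected_rows is the number of partitions (and data files) written for the 3-row fixture;
+        # a value of 2 means the two non-null values for that column fall into the same bucket.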
+        (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=BucketTransform(2), name="long_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=BucketTransform(2), name="date_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=BucketTransform(2), name="timestamp_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=BucketTransform(2), name="timestamptz_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=BucketTransform(2), name="string_bucket")), 3),
+        (PartitionSpec(PartitionField(source_id=12, field_id=1001, transform=BucketTransform(2), name="fixed_bucket")), 2),
+        (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=BucketTransform(2), name="binary_bucket")), 2),
+    ],
+)
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_bucket_transform(
+    spark: SparkSession,
+    session_catalog: Catalog,
+    arrow_table_with_null: pa.Table,
+    spec: PartitionSpec,
+    expected_rows: int,
+    format_version: int,
+) -> None:
+    identifier = "default.bucket_transform"
+
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass
+
+    tbl = _create_table(
+        session_catalog=session_catalog,
+        identifier=identifier,
+        properties={"format-version": str(format_version)},
+        data=[arrow_table_with_null],
+        partition_spec=spec,
+    )
+
+    assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
+    df = spark.table(identifier)
+    assert df.count() == 3, f"Expected 3 total rows for {identifier}"
+    for col in arrow_table_with_null.column_names:
+        assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}"
+        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"
+
+    assert tbl.inspect.partitions().num_rows == expected_rows
+    files_df = spark.sql(
+        f"""
+            SELECT *
+            FROM {identifier}.files
+        """
+    )
+    assert files_df.count() == expected_rows
+
+
 @pytest.mark.integration
 @pytest.mark.parametrize(
     "transform,expected_rows",