@@ -21,7 +21,7 @@ import scala.jdk.CollectionConverters.MapHasAsJava
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.ExtendedAnalysisException
-import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.CollationFactory
 import org.apache.spark.sql.connector.{DatasourceV2SQLBase, FakeV2ProviderWithCustomSchema}
 import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable}
@@ -30,8 +30,8 @@ import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership
 import org.apache.spark.sql.errors.DataTypeErrors.toSQLType
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
-import org.apache.spark.sql.internal.SqlApiConf
+import org.apache.spark.sql.execution.joins._
+import org.apache.spark.sql.internal.{SqlApiConf, SQLConf}
 import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeAnyCollation}
 import org.apache.spark.sql.types.{MapType, StringType, StructField, StructType}
 
@@ -769,37 +769,6 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
     })
   }
 
-  test("hash based joins not allowed for non-binary collated strings") {
-    val in = (('a' to 'z') ++ ('A' to 'Z')).map(_.toString * 3).map(e => Row.apply(e, e))
-
-    val schema = StructType(StructField(
-      "col_non_binary",
-      StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE"))) ::
-      StructField("col_binary", StringType) :: Nil)
-    val df1 = spark.createDataFrame(sparkContext.parallelize(in), schema)
-
-    // Binary collations are allowed to use hash join.
-    assert(collectFirst(
-      df1.hint("broadcast").join(df1, df1("col_binary") === df1("col_binary"))
-        .queryExecution.executedPlan) {
-      case _: BroadcastHashJoinExec => ()
-    }.nonEmpty)
-
-    // Even with hint broadcast, hash join is not used for non-binary collated strings.
-    assert(collectFirst(
-      df1.hint("broadcast").join(df1, df1("col_non_binary") === df1("col_non_binary"))
-        .queryExecution.executedPlan) {
-      case _: BroadcastHashJoinExec => ()
-    }.isEmpty)
-
-    // Instead they will default to sort merge join.
-    assert(collectFirst(
-      df1.hint("broadcast").join(df1, df1("col_non_binary") === df1("col_non_binary"))
-        .queryExecution.executedPlan) {
-      case _: SortMergeJoinExec => ()
-    }.nonEmpty)
-  }
-
   test("Generated column expressions using collations - errors out") {
     checkError(
       exception = intercept[AnalysisException] {
@@ -1030,6 +999,135 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("hash join should be used for collated strings") {
+    val t1 = "T_1"
+    val t2 = "T_2"
+
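+    // Each test case pairs a collation with the rows the join below is expected to return.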
+    case class HashJoinTestCase[R](collation: String, result: R)
+    val testCases = Seq(
+      HashJoinTestCase("UTF8_BINARY", Seq(Row("aa", 1, "aa", 2))),
+      HashJoinTestCase("UTF8_BINARY_LCASE", Seq(Row("aa", 1, "AA", 2), Row("aa", 1, "aa", 2))),
+      HashJoinTestCase("UNICODE", Seq(Row("aa", 1, "aa", 2))),
+      HashJoinTestCase("UNICODE_CI", Seq(Row("aa", 1, "AA", 2), Row("aa", 1, "aa", 2)))
+    )
+
+    testCases.foreach(t => {
+      withTable(t1, t2) {
+        sql(s"CREATE TABLE $t1 (x STRING COLLATE ${t.collation}, i int) USING PARQUET")
+        sql(s"INSERT INTO $t1 VALUES ('aa', 1)")
+
+        sql(s"CREATE TABLE $t2 (y STRING COLLATE ${t.collation}, j int) USING PARQUET")
+        sql(s"INSERT INTO $t2 VALUES ('AA', 2), ('aa', 2)")
+
+        val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y")
+        checkAnswer(df, t.result)
+
+        val queryPlan = df.queryExecution.executedPlan
+
+        // confirm that hash join is used instead of sort merge join
+        assert(
+          collectFirst(queryPlan) {
+            case _: HashJoin => ()
+          }.nonEmpty
+        )
+        assert(
+          collectFirst(queryPlan) {
+            case _: SortMergeJoinExec => ()
+          }.isEmpty
+        )
+
+        // if collation doesn't support binary equality, collation key should be injected
+        if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) {
+          assert(collectFirst(queryPlan) {
+            case b: HashJoin => b.leftKeys.head
+          }.head.isInstanceOf[CollationKey])
+        }
+      }
+    })
+  }
+
+  test("rewrite with collationkey should be an excludable rule") {
+    val t1 = "T_1"
+    val t2 = "T_2"
+    val collation = "UTF8_BINARY_LCASE"
+    val collationRewriteJoinRule = "org.apache.spark.sql.catalyst.analysis.RewriteCollationJoin"
+    withTable(t1, t2) {
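+      // With RewriteCollationJoin excluded, the plan should fall back to sort merge join
+      // for this non-binary collation, as the assertions below verify.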
+      withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> collationRewriteJoinRule) {
+        sql(s"CREATE TABLE $t1 (x STRING COLLATE $collation, i int) USING PARQUET")
+        sql(s"INSERT INTO $t1 VALUES ('aa', 1)")
+
+        sql(s"CREATE TABLE $t2 (y STRING COLLATE $collation, j int) USING PARQUET")
+        sql(s"INSERT INTO $t2 VALUES ('AA', 2), ('aa', 2)")
+
+        val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y")
+        checkAnswer(df, Seq(Row("aa", 1, "AA", 2), Row("aa", 1, "aa", 2)))
+
+        val queryPlan = df.queryExecution.executedPlan
+
+        // confirm that sort merge join is used instead of hash join
+        assert(
+          collectFirst(queryPlan) {
+            case _: HashJoin => ()
+          }.isEmpty
+        )
+        assert(
+          collectFirst(queryPlan) {
+            case _: SortMergeJoinExec => ()
+          }.nonEmpty
+        )
+      }
+    }
+  }
+
+  test("rewrite with collationkey shouldn't disrupt multiple join conditions") {
+    val t1 = "T_1"
+    val t2 = "T_2"
+
+    case class HashMultiJoinTestCase[R](
+      type1: String,
+      type2: String,
+      data1: String,
+      data2: String,
+      result: R
+    )
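+    // Join conditions mix binary and non-binary collations, plus one non-string (INT) key.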
+    val testCases = Seq(
+      HashMultiJoinTestCase("STRING COLLATE UTF8_BINARY", "INT",
+        "'a', 0, 1", "'a', 0, 1", Row("a", 0, 1, "a", 0, 1)),
+      HashMultiJoinTestCase("STRING COLLATE UTF8_BINARY", "STRING COLLATE UTF8_BINARY",
+        "'a', 'a', 1", "'a', 'a', 1", Row("a", "a", 1, "a", "a", 1)),
+      HashMultiJoinTestCase("STRING COLLATE UTF8_BINARY", "STRING COLLATE UTF8_BINARY_LCASE",
+        "'a', 'a', 1", "'a', 'A', 1", Row("a", "a", 1, "a", "A", 1)),
+      HashMultiJoinTestCase("STRING COLLATE UTF8_BINARY_LCASE", "STRING COLLATE UNICODE_CI",
+        "'a', 'a', 1", "'A', 'A', 1", Row("a", "a", 1, "A", "A", 1))
+    )
+
+    testCases.foreach(t => {
+      withTable(t1, t2) {
+        sql(s"CREATE TABLE $t1 (x ${t.type1}, y ${t.type2}, i int) USING PARQUET")
+        sql(s"INSERT INTO $t1 VALUES (${t.data1})")
+        sql(s"CREATE TABLE $t2 (x ${t.type1}, y ${t.type2}, i int) USING PARQUET")
+        sql(s"INSERT INTO $t2 VALUES (${t.data2})")
+
+        val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.x AND $t1.y = $t2.y")
+        checkAnswer(df, t.result)
+
+        val queryPlan = df.queryExecution.executedPlan
+
+        // confirm that hash join is used instead of sort merge join
+        assert(
+          collectFirst(queryPlan) {
+            case _: HashJoin => ()
+          }.nonEmpty
+        )
+        assert(
+          collectFirst(queryPlan) {
+            case _: SortMergeJoinExec => ()
+          }.isEmpty
+        )
+      }
+    })
+  }
+
   test("hll sketch aggregate should respect collation") {
     case class HllSketchAggTestCase[R](c: String, result: R)
     val testCases = Seq(