Skip to content

Commit e6236af

Browse files
uros-dbcloud-fan
authored andcommitted
[SPARK-48000][SQL] Enable hash join support for all collations (StringType)
### What changes were proposed in this pull request? Enable collation support for hash join on StringType. Note: support for complex types will be added separately. - Logical plan is rewritten in analysis to replace non-binary strings with `CollationKey` - `CollationKey` is a unary expression that transforms `StringType` to `BinaryType` - Collation keys allow correct & efficient string comparison under specific collation rules ### Why are the changes needed? Improve JOIN performance for collated strings. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? - Unit tests for `CollationKey` in `CollationExpressionSuite` - E2e SQL tests for `RewriteCollationJoin` in `CollationSuite` - Various queries with JOIN in existing TPCDS collation test suite ### Was this patch authored or co-authored using generative AI tooling? No. Closes apache#46599 from uros-db/hash-join-str. Authored-by: Uros Bojanic <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent a86bca1 commit e6236af

File tree

6 files changed

+264
-35
lines changed

6 files changed

+264
-35
lines changed

common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -817,4 +817,15 @@ public static UTF8String getCollationKey(UTF8String input, int collationId) {
817817
}
818818
}
819819

820+
public static byte[] getCollationKeyBytes(UTF8String input, int collationId) {
821+
Collation collation = fetchCollation(collationId);
822+
if (collation.supportsBinaryEquality) {
823+
return input.getBytes();
824+
} else if (collation.supportsLowercaseEquality) {
825+
return input.toLowerCase().getBytes();
826+
} else {
827+
return collation.collator.getCollationKey(input.toString()).toByteArray();
828+
}
829+
}
830+
820831
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.analysis
19+
20+
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, CollationKey, Equality}
21+
import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan}
22+
import org.apache.spark.sql.catalyst.rules.Rule
23+
import org.apache.spark.sql.catalyst.util.CollationFactory
24+
import org.apache.spark.sql.types.StringType
25+
26+
object RewriteCollationJoin extends Rule[LogicalPlan] {
27+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
28+
case j @ Join(_, _, _, Some(condition), _) =>
29+
val newCondition = condition transform {
30+
case e @ Equality(l: AttributeReference, r: AttributeReference) =>
31+
(l.dataType, r.dataType) match {
32+
case (st: StringType, _: StringType)
33+
if !CollationFactory.fetchCollation(st.collationId).supportsBinaryEquality =>
34+
e.withNewChildren(Seq(CollationKey(l), CollationKey(r)))
35+
case _ =>
36+
e
37+
}
38+
}
39+
if (!newCondition.fastEquals(condition)) {
40+
j.copy(condition = Some(newCondition))
41+
} else {
42+
j
43+
}
44+
}
45+
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions
19+
20+
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
21+
import org.apache.spark.sql.catalyst.util.CollationFactory
22+
import org.apache.spark.sql.internal.types.StringTypeAnyCollation
23+
import org.apache.spark.sql.types._
24+
import org.apache.spark.unsafe.types.UTF8String
25+
26+
case class CollationKey(expr: Expression) extends UnaryExpression with ExpectsInputTypes {
27+
override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation)
28+
override def dataType: DataType = BinaryType
29+
30+
final lazy val collationId: Int = expr.dataType match {
31+
case st: StringType =>
32+
st.collationId
33+
}
34+
35+
override def nullSafeEval(input: Any): Any =
36+
CollationFactory.getCollationKeyBytes(input.asInstanceOf[UTF8String], collationId)
37+
38+
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
39+
defineCodeGen(ctx, ev, c => s"CollationFactory.getCollationKeyBytes($c, $collationId)")
40+
}
41+
42+
override protected def withNewChildInternal(newChild: Expression): Expression = {
43+
copy(expr = newChild)
44+
}
45+
46+
override def child: Expression = expr
47+
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
2020
import org.apache.spark.{SparkException, SparkFunSuite}
2121
import org.apache.spark.sql.catalyst.util.CollationFactory
2222
import org.apache.spark.sql.types._
23+
import org.apache.spark.unsafe.types.UTF8String
2324

2425
class CollationExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
2526
test("validate default collation") {
@@ -163,6 +164,31 @@ class CollationExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
163164
}
164165
}
165166

167+
test("CollationKey generates correct collation key for collated string") {
168+
val testCases = Seq(
169+
("", "UTF8_BINARY", UTF8String.fromString("").getBytes),
170+
("aa", "UTF8_BINARY", UTF8String.fromString("aa").getBytes),
171+
("AA", "UTF8_BINARY", UTF8String.fromString("AA").getBytes),
172+
("aA", "UTF8_BINARY", UTF8String.fromString("aA").getBytes),
173+
("", "UTF8_BINARY_LCASE", UTF8String.fromString("").getBytes),
174+
("aa", "UTF8_BINARY_LCASE", UTF8String.fromString("aa").getBytes),
175+
("AA", "UTF8_BINARY_LCASE", UTF8String.fromString("aa").getBytes),
176+
("aA", "UTF8_BINARY_LCASE", UTF8String.fromString("aa").getBytes),
177+
("", "UNICODE", UTF8String.fromString("").getBytes),
178+
("aa", "UNICODE", UTF8String.fromString("aa").getBytes),
179+
("AA", "UNICODE", UTF8String.fromString("AA").getBytes),
180+
("aA", "UNICODE", UTF8String.fromString("aA").getBytes),
181+
("", "UNICODE_CI", Array[Byte](1, 0)),
182+
("aa", "UNICODE_CI", Array[Byte](42, 42, 1, 6, 0)),
183+
("AA", "UNICODE_CI", Array[Byte](42, 42, 1, 6, 0)),
184+
("aA", "UNICODE_CI", Array[Byte](42, 42, 1, 6, 0))
185+
)
186+
for ((input, collation, expected) <- testCases) {
187+
val str = Literal.create(input, StringType(collation))
188+
checkEvaluation(CollationKey(str), expected)
189+
}
190+
}
191+
166192
test("collation name normalization in collation expression") {
167193
Seq(
168194
("en_USA", "en_USA"),

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql.execution
1919

2020
import org.apache.spark.sql.ExperimentalMethods
21+
import org.apache.spark.sql.catalyst.analysis.RewriteCollationJoin
2122
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
2223
import org.apache.spark.sql.catalyst.optimizer._
2324
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -92,7 +93,8 @@ class SparkOptimizer(
9293
EliminateLimits,
9394
ConstantFolding) :+
9495
Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) :+
95-
Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition)
96+
Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition) :+
97+
Batch("RewriteCollationJoin", Once, RewriteCollationJoin)
9698

9799
override def nonExcludableRules: Seq[String] = super.nonExcludableRules :+
98100
ExtractPythonUDFFromJoinCondition.ruleName :+

sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala

Lines changed: 132 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import scala.jdk.CollectionConverters.MapHasAsJava
2121

2222
import org.apache.spark.SparkException
2323
import org.apache.spark.sql.catalyst.ExtendedAnalysisException
24-
import org.apache.spark.sql.catalyst.expressions.Literal
24+
import org.apache.spark.sql.catalyst.expressions._
2525
import org.apache.spark.sql.catalyst.util.CollationFactory
2626
import org.apache.spark.sql.connector.{DatasourceV2SQLBase, FakeV2ProviderWithCustomSchema}
2727
import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable}
@@ -30,8 +30,8 @@ import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership
3030
import org.apache.spark.sql.errors.DataTypeErrors.toSQLType
3131
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
3232
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
33-
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
34-
import org.apache.spark.sql.internal.SqlApiConf
33+
import org.apache.spark.sql.execution.joins._
34+
import org.apache.spark.sql.internal.{SqlApiConf, SQLConf}
3535
import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeAnyCollation}
3636
import org.apache.spark.sql.types.{MapType, StringType, StructField, StructType}
3737

@@ -769,37 +769,6 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
769769
})
770770
}
771771

772-
test("hash based joins not allowed for non-binary collated strings") {
773-
val in = (('a' to 'z') ++ ('A' to 'Z')).map(_.toString * 3).map(e => Row.apply(e, e))
774-
775-
val schema = StructType(StructField(
776-
"col_non_binary",
777-
StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE"))) ::
778-
StructField("col_binary", StringType) :: Nil)
779-
val df1 = spark.createDataFrame(sparkContext.parallelize(in), schema)
780-
781-
// Binary collations are allowed to use hash join.
782-
assert(collectFirst(
783-
df1.hint("broadcast").join(df1, df1("col_binary") === df1("col_binary"))
784-
.queryExecution.executedPlan) {
785-
case _: BroadcastHashJoinExec => ()
786-
}.nonEmpty)
787-
788-
// Even with hint broadcast, hash join is not used for non-binary collated strings.
789-
assert(collectFirst(
790-
df1.hint("broadcast").join(df1, df1("col_non_binary") === df1("col_non_binary"))
791-
.queryExecution.executedPlan) {
792-
case _: BroadcastHashJoinExec => ()
793-
}.isEmpty)
794-
795-
// Instead they will default to sort merge join.
796-
assert(collectFirst(
797-
df1.hint("broadcast").join(df1, df1("col_non_binary") === df1("col_non_binary"))
798-
.queryExecution.executedPlan) {
799-
case _: SortMergeJoinExec => ()
800-
}.nonEmpty)
801-
}
802-
803772
test("Generated column expressions using collations - errors out") {
804773
checkError(
805774
exception = intercept[AnalysisException] {
@@ -1030,6 +999,135 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
1030999
}
10311000
}
10321001

1002+
test("hash join should be used for collated strings") {
1003+
val t1 = "T_1"
1004+
val t2 = "T_2"
1005+
1006+
case class HashJoinTestCase[R](collation: String, result: R)
1007+
val testCases = Seq(
1008+
HashJoinTestCase("UTF8_BINARY", Seq(Row("aa", 1, "aa", 2))),
1009+
HashJoinTestCase("UTF8_BINARY_LCASE", Seq(Row("aa", 1, "AA", 2), Row("aa", 1, "aa", 2))),
1010+
HashJoinTestCase("UNICODE", Seq(Row("aa", 1, "aa", 2))),
1011+
HashJoinTestCase("UNICODE_CI", Seq(Row("aa", 1, "AA", 2), Row("aa", 1, "aa", 2)))
1012+
)
1013+
1014+
testCases.foreach(t => {
1015+
withTable(t1, t2) {
1016+
sql(s"CREATE TABLE $t1 (x STRING COLLATE ${t.collation}, i int) USING PARQUET")
1017+
sql(s"INSERT INTO $t1 VALUES ('aa', 1)")
1018+
1019+
sql(s"CREATE TABLE $t2 (y STRING COLLATE ${t.collation}, j int) USING PARQUET")
1020+
sql(s"INSERT INTO $t2 VALUES ('AA', 2), ('aa', 2)")
1021+
1022+
val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y")
1023+
checkAnswer(df, t.result)
1024+
1025+
val queryPlan = df.queryExecution.executedPlan
1026+
1027+
// confirm that hash join is used instead of sort merge join
1028+
assert(
1029+
collectFirst(queryPlan) {
1030+
case _: HashJoin => ()
1031+
}.nonEmpty
1032+
)
1033+
assert(
1034+
collectFirst(queryPlan) {
1035+
case _: SortMergeJoinExec => ()
1036+
}.isEmpty
1037+
)
1038+
1039+
// if collation doesn't support binary equality, collation key should be injected
1040+
if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) {
1041+
assert(collectFirst(queryPlan) {
1042+
case b: HashJoin => b.leftKeys.head
1043+
}.head.isInstanceOf[CollationKey])
1044+
}
1045+
}
1046+
})
1047+
}
1048+
1049+
test("rewrite with collationkey should be an excludable rule") {
1050+
val t1 = "T_1"
1051+
val t2 = "T_2"
1052+
val collation = "UTF8_BINARY_LCASE"
1053+
val collationRewriteJoinRule = "org.apache.spark.sql.catalyst.analysis.RewriteCollationJoin"
1054+
withTable(t1, t2) {
1055+
withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> collationRewriteJoinRule) {
1056+
sql(s"CREATE TABLE $t1 (x STRING COLLATE $collation, i int) USING PARQUET")
1057+
sql(s"INSERT INTO $t1 VALUES ('aa', 1)")
1058+
1059+
sql(s"CREATE TABLE $t2 (y STRING COLLATE $collation, j int) USING PARQUET")
1060+
sql(s"INSERT INTO $t2 VALUES ('AA', 2), ('aa', 2)")
1061+
1062+
val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y")
1063+
checkAnswer(df, Seq(Row("aa", 1, "AA", 2), Row("aa", 1, "aa", 2)))
1064+
1065+
val queryPlan = df.queryExecution.executedPlan
1066+
1067+
// confirm that shuffle join is used instead of hash join
1068+
assert(
1069+
collectFirst(queryPlan) {
1070+
case _: HashJoin => ()
1071+
}.isEmpty
1072+
)
1073+
assert(
1074+
collectFirst(queryPlan) {
1075+
case _: SortMergeJoinExec => ()
1076+
}.nonEmpty
1077+
)
1078+
}
1079+
}
1080+
}
1081+
1082+
test("rewrite with collationkey shouldn't disrupt multiple join conditions") {
1083+
val t1 = "T_1"
1084+
val t2 = "T_2"
1085+
1086+
case class HashMultiJoinTestCase[R](
1087+
type1: String,
1088+
type2: String,
1089+
data1: String,
1090+
data2: String,
1091+
result: R
1092+
)
1093+
val testCases = Seq(
1094+
HashMultiJoinTestCase("STRING COLLATE UTF8_BINARY", "INT",
1095+
"'a', 0, 1", "'a', 0, 1", Row("a", 0, 1, "a", 0, 1)),
1096+
HashMultiJoinTestCase("STRING COLLATE UTF8_BINARY", "STRING COLLATE UTF8_BINARY",
1097+
"'a', 'a', 1", "'a', 'a', 1", Row("a", "a", 1, "a", "a", 1)),
1098+
HashMultiJoinTestCase("STRING COLLATE UTF8_BINARY", "STRING COLLATE UTF8_BINARY_LCASE",
1099+
"'a', 'a', 1", "'a', 'A', 1", Row("a", "a", 1, "a", "A", 1)),
1100+
HashMultiJoinTestCase("STRING COLLATE UTF8_BINARY_LCASE", "STRING COLLATE UNICODE_CI",
1101+
"'a', 'a', 1", "'A', 'A', 1", Row("a", "a", 1, "A", "A", 1))
1102+
)
1103+
1104+
testCases.foreach(t => {
1105+
withTable(t1, t2) {
1106+
sql(s"CREATE TABLE $t1 (x ${t.type1}, y ${t.type2}, i int) USING PARQUET")
1107+
sql(s"INSERT INTO $t1 VALUES (${t.data1})")
1108+
sql(s"CREATE TABLE $t2 (x ${t.type1}, y ${t.type2}, i int) USING PARQUET")
1109+
sql(s"INSERT INTO $t2 VALUES (${t.data2})")
1110+
1111+
val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.x AND $t1.y = $t2.y")
1112+
checkAnswer(df, t.result)
1113+
1114+
val queryPlan = df.queryExecution.executedPlan
1115+
1116+
// confirm that hash join is used instead of sort merge join
1117+
assert(
1118+
collectFirst(queryPlan) {
1119+
case _: HashJoin => ()
1120+
}.nonEmpty
1121+
)
1122+
assert(
1123+
collectFirst(queryPlan) {
1124+
case _: SortMergeJoinExec => ()
1125+
}.isEmpty
1126+
)
1127+
}
1128+
})
1129+
}
1130+
10331131
test("hll sketch aggregate should respect collation") {
10341132
case class HllSketchAggTestCase[R](c: String, result: R)
10351133
val testCases = Seq(

0 commit comments

Comments
 (0)