take into account the eval mode before reordering commutative operands
db-scnakandala committed Apr 30, 2024
1 parent 0329479 commit 87e0dfa
Showing 3 changed files with 54 additions and 9 deletions.
@@ -1378,6 +1378,20 @@ trait CommutativeExpression extends Expression {
     }
     reorderResult
   }
+
+  /**
+   * Helper method to collect the evaluation mode of the commutative expressions. This is
+   * used by the canonicalized methods of [[Add]] and [[Multiply]] operators to ensure that
+   * all operands have the same evaluation mode before reordering the operands.
+   */
+  protected def collectEvalModes(
+      e: Expression,
+      f: PartialFunction[CommutativeExpression, Seq[EvalMode.Value]]
+  ): Seq[EvalMode.Value] = e match {
+    case c: CommutativeExpression if f.isDefinedAt(c) =>
+      f(c) ++ c.children.flatMap(collectEvalModes(_, f))
+    case _ => Nil
+  }
 }
 
 /**
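The helper added above walks a chain of commutative nodes, applies the supplied partial function at each matching node, and recurses into its children, so callers can check that every collected mode agrees before reordering operands. Below is a minimal standalone sketch of that traversal pattern; the Expr, Lit, Add and EvalMode definitions are simplified stand-ins for illustration only, not Spark's actual classes.

// Illustrative sketch only: simplified stand-ins for Spark's Expression hierarchy.
object CollectEvalModesSketch {
  object EvalMode extends Enumeration { val LEGACY, ANSI, TRY = Value }

  sealed trait Expr { def children: Seq[Expr] }
  case class Lit(v: Int) extends Expr { val children: Seq[Expr] = Nil }
  case class Add(l: Expr, r: Expr, evalMode: EvalMode.Value = EvalMode.LEGACY) extends Expr {
    val children: Seq[Expr] = Seq(l, r)
  }

  // Same shape as the commit's helper: apply `f` where it is defined and recurse
  // into the children; anything else contributes no eval modes.
  def collectEvalModes(
      e: Expr,
      f: PartialFunction[Expr, Seq[EvalMode.Value]]): Seq[EvalMode.Value] = e match {
    case c if f.isDefinedAt(c) => f(c) ++ c.children.flatMap(collectEvalModes(_, f))
    case _ => Nil
  }

  def main(args: Array[String]): Unit = {
    val tree = Add(Add(Lit(1), Lit(2), EvalMode.TRY), Lit(3))
    val modes = collectEvalModes(tree, { case Add(_, _, m) => Seq(m) })
    println(modes)                            // List(LEGACY, TRY)
    println(modes.forall(_ == tree.evalMode)) // false: modes disagree, so skip reordering
  }
}

When the collected modes disagree with the root's mode, as in this example, the canonicalized overrides changed below skip the reordering path.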
@@ -452,13 +452,14 @@ case class Add(
     copy(left = newLeft, right = newRight)
 
   override lazy val canonicalized: Expression = {
-    // TODO: do not reorder consecutive `Add`s with different `evalMode`
-    val reorderResult = buildCanonicalizedPlan(
+    val evalModes = collectEvalModes(this, {case Add(_, _, evalMode) => Seq(evalMode)})
+    lazy val reorderResult = buildCanonicalizedPlan(
       { case Add(l, r, _) => Seq(l, r) },
       { case (l: Expression, r: Expression) => Add(l, r, evalMode)},
       Some(evalMode)
     )
-    if (resolved && reorderResult.resolved && reorderResult.dataType == dataType) {
+    if (resolved && evalModes.forall(_ == evalMode) && reorderResult.resolved &&
+      reorderResult.dataType == dataType) {
       reorderResult
     } else {
       // SPARK-40903: Avoid reordering decimal Add for canonicalization if the result data type is
@@ -608,12 +609,16 @@ case class Multiply(
       newLeft: Expression, newRight: Expression): Multiply = copy(left = newLeft, right = newRight)
 
   override lazy val canonicalized: Expression = {
-    // TODO: do not reorder consecutive `Multiply`s with different `evalMode`
-    buildCanonicalizedPlan(
-      { case Multiply(l, r, _) => Seq(l, r) },
-      { case (l: Expression, r: Expression) => Multiply(l, r, evalMode)},
-      Some(evalMode)
-    )
+    val evalModes = collectEvalModes(this, {case Multiply(_, _, evalMode) => Seq(evalMode)})
+    if (evalModes.forall(_ == evalMode)) {
+      buildCanonicalizedPlan(
+        { case Multiply(l, r, _) => Seq(l, r) },
+        { case (l: Expression, r: Expression) => Multiply(l, r, evalMode)},
+        Some(evalMode)
+      )
+    } else {
+      withCanonicalizedChildren
+    }
   }
 }
 
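As a usage sketch of the intended effect, mirroring the new test added below (it assumes Spark's catalyst module is on the classpath; the object name is illustrative, not part of the commit): with a uniform eval mode the commutative reordering still applies during canonicalization, while a tree that mixes TRY with the default mode no longer canonicalizes to the same form.

import org.apache.spark.sql.catalyst.expressions.{Add, EvalMode, Literal}

object EvalModeCanonicalizationSketch {
  def main(args: Array[String]): Unit = {
    val (l1, l2, l3) = (Literal(1), Literal(2), Literal(3))

    // Uniform eval mode: operand order is still irrelevant after canonicalization.
    val uniformA = Add(Add(l1, l2), l3)
    val uniformB = Add(Add(l3, l2), l1)
    println(uniformA.semanticEquals(uniformB)) // expected: true

    // Mixed eval modes: the TRY node is no longer folded into the reordering.
    val mixed = Add(Add(l2, l1, EvalMode.TRY), l3)
    println(uniformA.semanticEquals(mixed))    // expected: false
  }
}

The negative case is exactly what the new CanonicalizeSuite test below asserts for both Add and Multiply.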
@@ -454,4 +454,30 @@ class CanonicalizeSuite extends SparkFunSuite {
     // different.
     assert(common3.canonicalized != common4.canonicalized)
   }
+
+  test("[SPARK-48035] Add/Multiply operator canonicalization should take into account the" +
+    " evaluation mode of the operands before operand reordering") {
+    Seq(1, 10).foreach {
+      case multiCommutativeOpOptThreshold =>
+        val default = SQLConf.get.getConf(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD)
+        SQLConf.get.setConfString(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD.key,
+          multiCommutativeOpOptThreshold.toString)
+        try {
+          val l1 = Literal(1)
+          val l2 = Literal(2)
+          val l3 = Literal(3)
+
+          val expr1 = Add(Add(l1, l2), l3)
+          val expr2 = Add(Add(l2, l1, EvalMode.TRY), l3)
+          assert(!expr1.semanticEquals(expr2))
+
+          val expr3 = Multiply(Multiply(l1, l2), l3)
+          val expr4 = Multiply(Multiply(l2, l1, EvalMode.TRY), l3)
+          assert(!expr3.semanticEquals(expr4))
+        } finally {
+          SQLConf.get.setConfString(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD.key,
+            default.toString)
+        }
+    }
+  }
 }
