
Commit 966e053

szehon-ho authored and dongjoon-hyun committed
[SPARK-54289][SQL] Allow MERGE INTO to preserve existing struct fields for UPDATE SET * when source struct has fewer nested fields than target struct
### What changes were proposed in this pull request?

Introduce a new flag, spark.sql.merge.nested.type.assign.by.field, that makes the UPDATE SET * action in MERGE INTO shorthand for assigning every nested struct field to its source counterpart (i.e., UPDATE SET a.b.c = source.a.b.c). As a consequence, existing struct fields in the target table that have no source equivalent are preserved when the corresponding source struct has fewer fields than the target. Additional code prevents null expansion in this case (i.e., a null source struct expanding into a struct of nulls).

### Why are the changes needed?

Following #52347, MERGE INTO now allows a source table struct with fewer nested fields than the target table struct. In this scenario, UPDATE SET * has two possible interpretations. The user may read UPDATE SET * as shorthand for assigning every top-level column (i.e., UPDATE SET struct = source.struct), in which case the target struct is set to the source struct object as-is, with missing fields becoming NULL. This is the current behavior. The user may instead read UPDATE SET * as shorthand for assigning every nested struct field (i.e., UPDATE SET struct.a.b = source.struct.a.b), in which case target struct fields missing from the source are retained. This mirrors how UPDATE SET * leaves existing target columns untouched when they are missing from the source. The new flag enables this second interpretation.

### Does this PR introduce _any_ user-facing change?

No. Support for source structs with fewer fields than target structs in MERGE INTO is not yet released (#52347), and in any case a flag toggles this functionality.

### How was this patch tested?

Unit tests, especially around cases where the source struct is null.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #53149 from szehon-ho/merge_schema_evolution_update_nested.

Authored-by: Szehon Ho <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
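To make the two interpretations concrete, here is a hedged sketch against a hypothetical pair of tables (the flag name is the one introduced above; the table names, schemas, and a running `spark` session are assumptions for illustration):

```scala
// Hypothetical schemas:
//   target(id INT, s STRUCT<a: INT, b: INT>)
//   source(id INT, s STRUCT<a: INT>)   -- one nested field fewer than target
spark.conf.set("spark.sql.merge.nested.type.assign.by.field", "true")

spark.sql("""
  MERGE INTO target t
  USING source src
  ON t.id = src.id
  WHEN MATCHED THEN UPDATE SET *
""")

// Flag off (default): UPDATE SET * behaves like t.s = src.s, so the
//   target's s.b is overwritten with NULL.
// Flag on: UPDATE SET * behaves like t.s.a = src.s.a, so the target's
//   existing s.b value is preserved.
```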
1 parent 371b00a · commit 966e053

File tree

11 files changed: +1075, -207 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 8 additions & 5 deletions
```diff
@@ -1709,14 +1709,15 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
         val resolvedDeleteCondition = deleteCondition.map(
           resolveExpressionByPlanChildren(_, m))
         DeleteAction(resolvedDeleteCondition)
-      case UpdateAction(updateCondition, assignments) =>
+      case UpdateAction(updateCondition, assignments, fromStar) =>
         val resolvedUpdateCondition = updateCondition.map(
           resolveExpressionByPlanChildren(_, m))
         UpdateAction(
           resolvedUpdateCondition,
           // The update value can access columns from both target and source tables.
           resolveAssignments(assignments, m, MergeResolvePolicy.BOTH,
-            throws = throws))
+            throws = throws),
+          fromStar)
       case UpdateStarAction(updateCondition) =>
         // Expand star to top level source columns. If source has less columns than target,
         // assignments will be added by ResolveRowLevelCommandAssignments later.
@@ -1738,7 +1739,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
           updateCondition.map(resolveExpressionByPlanChildren(_, m)),
           // For UPDATE *, the value must be from source table.
           resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE,
-            throws = throws))
+            throws = throws),
+          fromStar = true)
       case o => o
     }
     val newNotMatchedActions = m.notMatchedActions.map {
@@ -1783,14 +1785,15 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
         val resolvedDeleteCondition = deleteCondition.map(
           resolveExpressionByPlanOutput(_, targetTable))
         DeleteAction(resolvedDeleteCondition)
-      case UpdateAction(updateCondition, assignments) =>
+      case UpdateAction(updateCondition, assignments, fromStar) =>
        val resolvedUpdateCondition = updateCondition.map(
           resolveExpressionByPlanOutput(_, targetTable))
         UpdateAction(
           resolvedUpdateCondition,
           // The update value can access columns from the target table only.
           resolveAssignments(assignments, m, MergeResolvePolicy.TARGET,
-            throws = throws))
+            throws = throws),
+          fromStar)
       case o => o
     }
```
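The matches above assume `UpdateAction` now carries a third `fromStar` field. Its real definition lives in Spark's logical v2 command plans and is not part of this diff; the following is only a self-contained toy model of the shape being threaded through resolution:

```scala
// Toy model, NOT Spark's actual classes; it just shows why every pattern
// match on UpdateAction in this commit gained a third binding or wildcard.
sealed trait Expr
final case class Assignment(key: Expr, value: Expr)

sealed trait MergeAction
final case class DeleteAction(condition: Option[Expr]) extends MergeAction
final case class UpdateAction(
    condition: Option[Expr],
    assignments: Seq[Assignment],
    fromStar: Boolean = false) // new: records that UPDATE SET * produced the assignments
  extends MergeAction

// Two-field patterns no longer compile; call sites now write, for example:
//   case UpdateAction(cond, assignments, fromStar) => ...  // keep the flag
//   case UpdateAction(cond, assignments, _) => ...         // ignore it
```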

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala

Lines changed: 189 additions & 6 deletions
```diff
@@ -21,13 +21,15 @@ import scala.collection.mutable

 import org.apache.spark.sql.catalyst.SQLConfHelper
 import org.apache.spark.sql.catalyst.analysis.TableOutputResolver.DefaultValueFillMode.{NONE, RECURSE}
-import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, Expression, GetStructField, Literal}
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, CreateNamedStruct, Expression, GetStructField, If, IsNull, Literal}
+import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
 import org.apache.spark.sql.catalyst.plans.logical.Assignment
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getDefaultValueExprOrNullLit
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
 import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.util.ArrayImplicits._

@@ -50,13 +52,17 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
    *
    * @param attrs table attributes
    * @param assignments assignments to align
+   * @param fromStar whether the assignments were resolved from an UPDATE SET * clause.
+   *                 These updates may assign struct fields individually
+   *                 (preserving existing fields).
    * @param coerceNestedTypes whether to coerce nested types to match the target type
    *                          for complex types
    * @return aligned update assignments that match table attributes
    */
   def alignUpdateAssignments(
       attrs: Seq[Attribute],
       assignments: Seq[Assignment],
+      fromStar: Boolean,
       coerceNestedTypes: Boolean): Seq[Assignment] = {

     val errors = new mutable.ArrayBuffer[String]()
@@ -68,7 +74,8 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
         assignments,
         addError = err => errors += err,
         colPath = Seq(attr.name),
-        coerceNestedTypes)
+        coerceNestedTypes,
+        fromStar)
     }

     if (errors.nonEmpty) {
@@ -152,7 +159,8 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
       assignments: Seq[Assignment],
       addError: String => Unit,
       colPath: Seq[String],
-      coerceNestedTypes: Boolean = false): Expression = {
+      coerceNestedTypes: Boolean = false,
+      updateStar: Boolean = false): Expression = {

     val (exactAssignments, otherAssignments) = assignments.partition { assignment =>
       assignment.key.semanticEquals(colExpr)
@@ -174,9 +182,31 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
     } else if (exactAssignments.isEmpty && fieldAssignments.isEmpty) {
       TableOutputResolver.checkNullability(colExpr, col, conf, colPath)
     } else if (exactAssignments.nonEmpty) {
-      val value = exactAssignments.head.value
-      val coerceMode = if (coerceNestedTypes) RECURSE else NONE
-      TableOutputResolver.resolveUpdate("", value, col, conf, addError, colPath, coerceMode)
+      if (SQLConf.get.mergeUpdateStructsByField && updateStar) {
+        val value = exactAssignments.head.value
+        col.dataType match {
+          case structType: StructType =>
+            // Expand assignments to leaf fields
+            val structAssignment =
+              applyNestedFieldAssignments(col, colExpr, value, addError, colPath,
+                coerceNestedTypes)
+
+            // Wrap with null check for missing source fields
+            fixNullExpansion(col, value, structType, structAssignment,
+              colPath, addError)
+          case _ =>
+            // For non-struct types, resolve directly
+            val coerceMode = if (coerceNestedTypes) RECURSE else NONE
+            TableOutputResolver.resolveUpdate("", value, col, conf, addError, colPath,
+              coerceMode)
+        }
+      } else {
+        val value = exactAssignments.head.value
+        val coerceMode = if (coerceNestedTypes) RECURSE else NONE
+        val resolvedValue = TableOutputResolver.resolveUpdate("", value, col, conf, addError,
+          colPath, coerceMode)
+        resolvedValue
+      }
     } else {
       applyFieldAssignments(col, colExpr, fieldAssignments, addError, colPath, coerceNestedTypes)
     }
```
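`SQLConf.get.mergeUpdateStructsByField` presumably reads the `spark.sql.merge.nested.type.assign.by.field` flag from the description; the conf definition is in one of the eleven changed files but outside this excerpt. A sketch of what such an entry conventionally looks like in SQLConf.scala (the doc text, version, and default below are assumptions):

```scala
// Assumed shape, following existing SQLConf conventions:
val MERGE_UPDATE_STRUCTS_BY_FIELD =
  buildConf("spark.sql.merge.nested.type.assign.by.field")
    .doc("When true, UPDATE SET * in MERGE INTO assigns nested struct fields " +
      "individually, preserving target struct fields that are missing from the source.")
    .version("4.1.0") // assumption
    .booleanConf
    .createWithDefault(false) // assumption

def mergeUpdateStructsByField: Boolean = getConf(MERGE_UPDATE_STRUCTS_BY_FIELD)
```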
```diff
@@ -210,13 +240,165 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
     }
   }

+  private def applyNestedFieldAssignments(
+      col: Attribute,
+      colExpr: Expression,
+      value: Expression,
+      addError: String => Unit,
+      colPath: Seq[String],
+      coerceNestedTypes: Boolean): Expression = {
+
+    col.dataType match {
+      case structType: StructType =>
+        val fieldAttrs = DataTypeUtils.toAttributes(structType)
+
+        val updatedFieldExprs = fieldAttrs.zipWithIndex.map { case (fieldAttr, ordinal) =>
+          val fieldPath = colPath :+ fieldAttr.name
+          val targetFieldExpr = GetStructField(colExpr, ordinal, Some(fieldAttr.name))
+
+          // Try to find a corresponding field in the source value by name
+          val sourceFieldValue: Expression = value.dataType match {
+            case valueStructType: StructType =>
+              valueStructType.fields.find(f => conf.resolver(f.name, fieldAttr.name)) match {
+                case Some(matchingField) =>
+                  // Found matching field in source, extract it
+                  val fieldIndex = valueStructType.fieldIndex(matchingField.name)
+                  GetStructField(value, fieldIndex, Some(matchingField.name))
+                case None =>
+                  // Field doesn't exist in source, use target's current value with null check
+                  TableOutputResolver.checkNullability(targetFieldExpr, fieldAttr, conf, fieldPath)
+              }
+            case _ =>
+              // Value is not a struct, cannot extract field
+              addError(s"Cannot assign non-struct value to struct field '${fieldPath.quoted}'")
+              Literal(null, fieldAttr.dataType)
+          }
+
+          // Recurse or resolve based on field type
+          fieldAttr.dataType match {
+            case _: StructType =>
+              // Field is a struct, recurse
+              applyNestedFieldAssignments(fieldAttr, targetFieldExpr, sourceFieldValue,
+                addError, fieldPath, coerceNestedTypes)
+            case _ =>
+              // Field is not a struct, resolve with TableOutputResolver
+              val coerceMode = if (coerceNestedTypes) RECURSE else NONE
+              TableOutputResolver.resolveUpdate("", sourceFieldValue, fieldAttr, conf, addError,
+                fieldPath, coerceMode)
+          }
+        }
+        toNamedStruct(structType, updatedFieldExprs)
+
+      case otherType =>
+        addError(
+          "Updating nested fields is only supported for StructType but " +
+          s"'${colPath.quoted}' is of type $otherType")
+        colExpr
+    }
+  }
+
   private def toNamedStruct(structType: StructType, fieldExprs: Seq[Expression]): Expression = {
     val namedStructExprs = structType.fields.zip(fieldExprs).flatMap { case (field, expr) =>
       Seq(Literal(field.name), expr)
     }.toImmutableArraySeq
     CreateNamedStruct(namedStructExprs)
   }

+  private def getMissingSourcePaths(targetType: StructType,
+      sourceType: DataType,
+      colPath: Seq[String],
+      addError: String => Unit): Seq[Seq[String]] = {
+    val nestedTargetPaths = DataTypeUtils.extractLeafFieldPaths(targetType, Seq.empty)
+    val nestedSourcePaths = sourceType match {
+      case sourceStructType: StructType =>
+        DataTypeUtils.extractLeafFieldPaths(sourceStructType, Seq.empty)
+      case _ =>
+        addError(s"Value for struct type: " +
+          s"${colPath.quoted} must be a struct but was ${sourceType.simpleString}")
+        Seq()
+    }
+    nestedTargetPaths.diff(nestedSourcePaths)
+  }
+
+  /**
+   * Creates a null check for a field at the given path within a struct expression.
+   * Navigates through the struct hierarchy following the path and returns an IsNull check
+   * for the final field.
+   *
+   * @param rootExpr the root expression to navigate from
+   * @param path the field path to navigate (sequence of field names)
+   * @return an IsNull expression checking if the field at the path is null
+   */
+  private def createNullCheckForFieldPath(
+      rootExpr: Expression,
+      path: Seq[String]): Expression = {
+    var currentExpr: Expression = rootExpr
+    path.foreach { fieldName =>
+      currentExpr.dataType match {
+        case st: StructType =>
+          st.fields.find(f => conf.resolver(f.name, fieldName)) match {
+            case Some(field) =>
+              val fieldIndex = st.fieldIndex(field.name)
+              currentExpr = GetStructField(currentExpr, fieldIndex, Some(field.name))
+            case None =>
+              // Field not found, shouldn't happen
+          }
+        case _ =>
+          // Not a struct, shouldn't happen
+      }
+    }
+    IsNull(currentExpr)
+  }
+
+  /**
+   * As UPDATE SET * can assign struct fields individually (preserving existing fields),
+   * this can lead to null expansion, i.e., a struct is created where all fields are null.
+   * Wraps a struct assignment with null checks for the source and for the missing source
+   * fields, returning null if all of them are null.
+   *
+   * @param col the target column attribute
+   * @param value the source value expression
+   * @param structType the target struct type
+   * @param structAssignment the struct assignment result to wrap
+   * @param colPath the column path for error reporting
+   * @param addError error reporting function
+   * @return the wrapped expression with null checks
+   */
+  private def fixNullExpansion(
+      col: Attribute,
+      value: Expression,
+      structType: StructType,
+      structAssignment: Expression,
+      colPath: Seq[String],
+      addError: String => Unit): Expression = {
+    // As StoreAssignmentPolicy.LEGACY is not allowed in DSv2, always add null check for
+    // non-nullable column
+    if (!col.nullable) {
+      AssertNotNull(value)
+    } else {
+      // Check if source struct is null
+      val valueIsNull = IsNull(value)
+
+      // Check whether the missing source paths (paths in target but not in source) are null;
+      // these are the target fields preserved by UPDATE SET *
+      val missingSourcePaths = getMissingSourcePaths(structType, value.dataType, colPath, addError)
+      val condition = if (missingSourcePaths.nonEmpty) {
+        // Check if all target attributes at missing source paths are null
+        val missingFieldNullChecks = missingSourcePaths.map { path =>
+          createNullCheckForFieldPath(col, path)
+        }
+        // Combine all null checks with AND
+        val allMissingFieldsNull = missingFieldNullChecks.reduce[Expression]((a, b) => And(a, b))
+        And(valueIsNull, allMissingFieldsNull)
+      } else {
+        valueIsNull
+      }
+
+      // Return: If (condition) THEN NULL ELSE structAssignment
+      If(condition, Literal(null, structAssignment.dataType), structAssignment)
+    }
+  }
+
   /**
    * Checks whether assignments are aligned and compatible with table columns.
    *
```
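Taken together, `applyNestedFieldAssignments` and `fixNullExpansion` turn `UPDATE SET *` on a struct column into a per-field rewrite guarded against null expansion. The guard's effect can be reproduced directly in SQL; this self-contained sketch uses hypothetical schemas (target struct `t` with fields `a` and `b`; source struct `s` with only `a`) and assumes a running `spark` session:

```scala
// Field `b` is missing from the source, so it is read from the target;
// the CASE guard keeps a NULL source struct from expanding into struct(NULL, NULL).
spark.sql("""
  SELECT CASE WHEN s IS NULL AND t.b IS NULL THEN NULL
              ELSE named_struct('a', s.a, 'b', t.b)
         END AS updated
  FROM VALUES
    (named_struct('a', 1, 'b', 2), named_struct('a', 10)),
    (named_struct('a', 1, 'b', CAST(NULL AS INT)), CAST(NULL AS STRUCT<a: INT>))
    AS v(t, s)
""").show(truncate = false)
// Row 1: updated = {10, 2}   -- source field a wins, target field b preserved
// Row 2: updated = NULL      -- no null expansion into {NULL, NULL}
```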

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala

Lines changed: 6 additions & 5 deletions
```diff
@@ -44,7 +44,7 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] {
       validateStoreAssignmentPolicy()
       val newTable = cleanAttrMetadata(u.table)
       val newAssignments = AssignmentUtils.alignUpdateAssignments(u.table.output, u.assignments,
-        coerceNestedTypes = false)
+        fromStar = false, coerceNestedTypes = false)
       u.copy(table = newTable, assignments = newAssignments)

     case u: UpdateTable if !u.skipSchemaResolution && u.resolved && !u.aligned =>
@@ -53,10 +53,11 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] {
     case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved && m.rewritable && !m.aligned &&
         !m.needSchemaEvolution =>
       validateStoreAssignmentPolicy()
-      val coerceNestedTypes = SQLConf.get.coerceMergeNestedTypes
+      val coerceNestedTypes = SQLConf.get.mergeCoerceNestedTypes
       m.copy(
         targetTable = cleanAttrMetadata(m.targetTable),
-        matchedActions = alignActions(m.targetTable.output, m.matchedActions, coerceNestedTypes),
+        matchedActions = alignActions(m.targetTable.output, m.matchedActions,
+          coerceNestedTypes),
         notMatchedActions = alignActions(m.targetTable.output, m.notMatchedActions,
           coerceNestedTypes),
         notMatchedBySourceActions = alignActions(m.targetTable.output, m.notMatchedBySourceActions,
@@ -117,9 +118,9 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] {
       actions: Seq[MergeAction],
       coerceNestedTypes: Boolean): Seq[MergeAction] = {
     actions.map {
-      case u @ UpdateAction(_, assignments) =>
+      case u @ UpdateAction(_, assignments, fromStar) =>
         u.copy(assignments = AssignmentUtils.alignUpdateAssignments(attrs, assignments,
-          coerceNestedTypes))
+          fromStar, coerceNestedTypes))
       case d: DeleteAction =>
         d
       case i @ InsertAction(_, assignments) =>
```
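The call sites here change mechanically, but "aligning" assignments is worth a picture: alignment reorders and completes the per-column assignments so the output matches the target schema exactly. A hedged, self-contained toy of the idea (not Spark's code):

```scala
// Toy illustration of aligning assignments to a target schema: every target
// column gets exactly one value expression, in schema order; columns with no
// assignment keep their current value.
final case class Assignment(column: String, value: String)

def align(targetCols: Seq[String], assignments: Seq[Assignment]): Seq[Assignment] =
  targetCols.map { col =>
    assignments.find(_.column == col).getOrElse(Assignment(col, col)) // keep current value
  }

// align(Seq("id", "name", "age"), Seq(Assignment("age", "src.age")))
//   == Seq(Assignment("id", "id"), Assignment("name", "name"), Assignment("age", "src.age"))
```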

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala

Lines changed: 5 additions & 5 deletions
```diff
@@ -334,7 +334,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
     // original row ID values must be preserved and passed back to the table to encode updates
     // if there are any assignments to row ID attributes, add extra columns for original values
     val updateAssignments = (matchedActions ++ notMatchedBySourceActions).flatMap {
-      case UpdateAction(_, assignments) => assignments
+      case UpdateAction(_, assignments, _) => assignments
       case _ => Nil
     }
     buildOriginalRowIdValues(rowIdAttrs, updateAssignments)
@@ -434,7 +434,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
   // converts a MERGE action into an instruction on top of the joined plan for group-based plans
   private def toInstruction(action: MergeAction, metadataAttrs: Seq[Attribute]): Instruction = {
     action match {
-      case UpdateAction(cond, assignments) =>
+      case UpdateAction(cond, assignments, _) =>
         val rowValues = assignments.map(_.value)
         val metadataValues = nullifyMetadataOnUpdate(metadataAttrs)
         val output = Seq(Literal(WRITE_WITH_METADATA_OPERATION)) ++ rowValues ++ metadataValues
@@ -466,12 +466,12 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
       splitUpdates: Boolean): Instruction = {

     action match {
-      case UpdateAction(cond, assignments) if splitUpdates =>
+      case UpdateAction(cond, assignments, _) if splitUpdates =>
         val output = deltaDeleteOutput(rowAttrs, rowIdAttrs, metadataAttrs, originalRowIdValues)
         val otherOutput = deltaReinsertOutput(assignments, metadataAttrs, originalRowIdValues)
         Split(cond.getOrElse(TrueLiteral), output, otherOutput)

-      case UpdateAction(cond, assignments) =>
+      case UpdateAction(cond, assignments, _) =>
         val output = deltaUpdateOutput(assignments, metadataAttrs, originalRowIdValues)
         Keep(Update, cond.getOrElse(TrueLiteral), output)

@@ -495,7 +495,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
     val actions = merge.matchedActions ++ merge.notMatchedActions ++ merge.notMatchedBySourceActions
     actions.foreach {
       case DeleteAction(Some(cond)) => checkMergeIntoCondition("DELETE", cond)
-      case UpdateAction(Some(cond), _) => checkMergeIntoCondition("UPDATE", cond)
+      case UpdateAction(Some(cond), _, _) => checkMergeIntoCondition("UPDATE", cond)
       case InsertAction(Some(cond), _) => checkMergeIntoCondition("INSERT", cond)
       case _ => // OK
     }
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala

Lines changed: 2 additions & 1 deletion
```diff
@@ -149,7 +149,8 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {

   private def replaceNullWithFalse(mergeActions: Seq[MergeAction]): Seq[MergeAction] = {
     mergeActions.map {
-      case u @ UpdateAction(Some(cond), _) => u.copy(condition = Some(replaceNullWithFalse(cond)))
+      case u @ UpdateAction(Some(cond), _, _) =>
+        u.copy(condition = Some(replaceNullWithFalse(cond)))
       case u @ UpdateStarAction(Some(cond)) => u.copy(condition = Some(replaceNullWithFalse(cond)))
       case d @ DeleteAction(Some(cond)) => d.copy(condition = Some(replaceNullWithFalse(cond)))
       case i @ InsertAction(Some(cond), _) => i.copy(condition = Some(replaceNullWithFalse(cond)))
```
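The pattern here likewise just gains a wildcard, but the rule it belongs to is easy to demonstrate: in a filter or MERGE condition, a predicate that evaluates to NULL behaves exactly like FALSE, which is what lets the optimizer substitute one for the other. A small self-contained check (assumes a running `spark` session):

```scala
// Both queries return zero rows: a NULL condition filters a row out
// just like FALSE does, so replacing NULL with FALSE is safe here.
spark.sql("SELECT * FROM VALUES (1), (2) AS t(a) WHERE a > NULL").show()
spark.sql("SELECT * FROM VALUES (1), (2) AS t(a) WHERE false").show()
```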
