@@ -2561,20 +2561,26 @@ case class ArrayPosition(left: Expression, right: Expression)
""",
since = "3.4.0",
group = "array_funcs")
case class Get(
left: Expression,
right: Expression,
replacement: Expression) extends RuntimeReplaceable with InheritAnalysisRules {
case class Get(left: Expression, right: Expression)
extends BinaryExpression with RuntimeReplaceable with ImplicitCastInputTypes {

def this(left: Expression, right: Expression) =
this(left, right, GetArrayItem(left, right, failOnError = false))
override def inputTypes: Seq[AbstractDataType] = left.dataType match {
case _: ArrayType => Seq(ArrayType, IntegerType)
// Do not apply implicit cast if the first argument is not array type.
Contributor Author:
This is to keep the null type error behavior. #50590 improved the error message for it, but also made GetArrayItem fail for null type input. This usually doesn't matter, as GetArrayItem should never receive null type input except from Get, but this PR restores that behavior for consistency with the other ExtractValue expressions.

case _ => Nil
}

override def prettyName: String = "get"
override def checkInputDataTypes(): TypeCheckResult = {
ExpectsInputTypes.checkInputDataTypes(Seq(left, right), Seq(ArrayType, IntegerType))
}

override def parameters: Seq[Expression] = Seq(left, right)
override lazy val replacement: Expression = GetArrayItem(left, right, failOnError = false)

override protected def withNewChildInternal(newChild: Expression): Expression =
this.copy(replacement = newChild)
override def prettyName: String = "get"

override def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = {
copy(left = newLeft, right = newRight)
}
}
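To make the comment above concrete, here is a minimal, self-contained sketch of the pattern the new Get uses. It is a toy model, not Spark's actual Catalyst API; all names (GetTypeCheckSketch, ArrType, implicitCast, and so on) are illustrative. The point it demonstrates: implicit casts are only requested via inputTypes when the first argument is already an array, so a NULL (or otherwise wrong-typed) first argument reaches checkInputDataTypes uncast and the error reports its original type, as in the golden files further down.

```scala
// Toy model of Get's conditional inputTypes; not Spark's real Catalyst classes.
object GetTypeCheckSketch {
  sealed trait DataType
  case object IntType  extends DataType
  case object NullType extends DataType
  case class ArrType(element: DataType) extends DataType

  final case class Expr(sql: String, dataType: DataType)

  // "Implicit cast": identical types pass through, NULL can be cast to any target.
  def implicitCast(e: Expr, target: DataType): Expr = (e.dataType, target) match {
    case (a, b) if a == b => e
    case (NullType, t)    => Expr(s"cast(${e.sql} as $t)", t)
    case _                => e // leave unchanged; the type check below will report it
  }

  // Like Get.inputTypes: only request casts when the first child is already an array.
  def inputTypes(left: Expr): Seq[DataType] = left.dataType match {
    case a: ArrType => Seq(a, IntType)
    case _          => Nil
  }

  // Like Get.checkInputDataTypes: always require an array as the first input.
  def check(left: Expr): Either[String, Unit] = left.dataType match {
    case _: ArrType => Right(())
    case other      => Left(s"UNEXPECTED_INPUT_TYPE: ${left.sql} has type $other, expected ARRAY")
  }

  def main(args: Array[String]): Unit = {
    val arr     = Expr("array(1, 2, 3)", ArrType(IntType))
    val nullLit = Expr("NULL", NullType)

    // Array first argument: the NULL index is implicitly cast to INT,
    // matching the analyzed plan `get(array(1, 2, 3), cast(null as int))`.
    println(implicitCast(nullLit, inputTypes(arr)(1)))  // Expr(cast(NULL as IntType),IntType)

    // NULL first argument: inputTypes is Nil, so nothing gets cast and the
    // error reports the original type, matching the VOID errors in the golden files.
    println(inputTypes(nullLit))                        // List()
    println(check(nullLit))                             // Left(UNEXPECTED_INPUT_TYPE: ...)
  }
}
```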

/**
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.QueryContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.trees.TreePattern.{EXTRACT_VALUE, TreePattern}
import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils}
@@ -146,7 +145,9 @@ trait ExtractValue extends Expression with QueryErrorsBase {
* For example, when get field `yEAr` from `<year: int, month: int>`, we should pass in `yEAr`.
*/
case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None)
extends UnaryExpression with ExtractValue {
extends UnaryExpression with ExtractValue with ExpectsInputTypes {

override def inputTypes: Seq[AbstractDataType] = Seq(StructType, IntegralType)

lazy val childSchema = child.dataType.asInstanceOf[StructType]

@@ -207,6 +208,13 @@ case class GetArrayStructFields(
numFields: Int,
containsNull: Boolean) extends UnaryExpression with ExtractValue {

override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
case ArrayType(_: StructType, _) => TypeCheckResult.TypeCheckSuccess
// This should never happen, unless we hit a bug.
case other => TypeCheckResult.TypeCheckFailure(
"GetArrayStructFields.child must be array of struct type, but got " + other)
}

override def dataType: DataType = ArrayType(field.dataType, containsNull)
override def toString: String = s"$child.${field.name}"
override def sql: String = s"${child.sql}.${quoteIdentifier(field.name)}"
@@ -285,8 +293,7 @@ case class GetArrayItem(
with ExtractValue
with SupportQueryContext {

// We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`.
Member:
This comment also didn't look correct. Actually, we did check the data type of the child in checkInputDataTypes.

override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType)
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegralType)

override def toString: String = s"$child[$ordinal]"
override def sql: String = s"${child.sql}[${ordinal.sql}]"
@@ -355,30 +362,6 @@ case class GetArrayItem(
})
}

override def checkInputDataTypes(): TypeCheckResult = {
(left.dataType, right.dataType) match {
case (_: ArrayType, e2) if !e2.isInstanceOf[IntegralType] =>
DataTypeMismatch(
errorSubClass = "UNEXPECTED_INPUT_TYPE",
messageParameters = Map(
"paramIndex" -> ordinalNumber(1),
"requiredType" -> toSQLType(IntegralType),
"inputSql" -> toSQLExpr(right),
"inputType" -> toSQLType(right.dataType))
)
case (e1, _) if !e1.isInstanceOf[ArrayType] =>
DataTypeMismatch(
errorSubClass = "UNEXPECTED_INPUT_TYPE",
messageParameters = Map(
"paramIndex" -> ordinalNumber(0),
"requiredType" -> toSQLType(TypeCollection(ArrayType)),
"inputSql" -> toSQLExpr(left),
"inputType" -> toSQLType(left.dataType))
)
case _ => TypeCheckResult.TypeCheckSuccess
}
}

override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): GetArrayItem =
copy(child = newLeft, ordinal = newRight)
@@ -507,16 +490,19 @@ case class GetMapValue(child: Expression, key: Expression)

private[catalyst] def keyType = child.dataType.asInstanceOf[MapType].keyType

override def checkInputDataTypes(): TypeCheckResult = {
super.checkInputDataTypes() match {
case f if f.isFailure => f
case TypeCheckResult.TypeCheckSuccess =>
TypeUtils.checkForOrderingExpr(keyType, prettyName)
}
override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
case _: MapType =>
super.checkInputDataTypes() match {
case f if f.isFailure => f
case TypeCheckResult.TypeCheckSuccess =>
TypeUtils.checkForOrderingExpr(keyType, prettyName)
}
// This should never happen, unless we hit a bug.
case other => TypeCheckResult.TypeCheckFailure(
"GetMapValue.child must be map type, but got " + other)
}

// We have done type checking for child in `ExtractValue`, so only need to check the `key`.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType)
override def inputTypes: Seq[AbstractDataType] = Seq(MapType, keyType)
Member (@viirya, Sep 3, 2025):
Hmm, if it is possible that the child data type could be something other than MapType, then keyType cannot directly use asInstanceOf to cast to MapType, because when we call inputTypes the child's data type has not finished type checking yet.

Contributor Author (@cloud-fan, Sep 3, 2025):
Yes, this is why I override checkInputDataTypes: we only access inputTypes when the first child is map type.

Member:
Oh okay.
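To illustrate the ordering concern discussed in this thread, here is a small, self-contained sketch (again a toy model with made-up names, not Spark's Catalyst API). It shows why guarding on the map type in checkInputDataTypes before touching inputTypes matters: keyType performs an unchecked asInstanceOf, which would throw a ClassCastException if it ran for a non-map child.

```scala
// Toy model of the GetMapValue ordering argument; names are illustrative.
object GetMapValueOrderSketch {
  sealed trait DataType
  case object StringType  extends DataType
  case object BooleanType extends DataType
  case class MapTpe(keyType: DataType, valueType: DataType) extends DataType

  final class GetMapValueLike(childType: DataType) {
    // Unsafe on its own: throws ClassCastException for a non-map child.
    private def keyType: DataType = childType.asInstanceOf[MapTpe].keyType

    // Analogue of inputTypes: referencing keyType assumes the child is a map.
    def inputTypes: Seq[DataType] = Seq(childType, keyType)

    // Analogue of the overridden checkInputDataTypes: guard on the map type
    // first, so inputTypes (and thus keyType) is only evaluated when the
    // asInstanceOf above cannot fail.
    def checkInputDataTypes(): Either[String, Seq[DataType]] = childType match {
      case _: MapTpe => Right(inputTypes)
      case other     => Left(s"GetMapValue.child must be map type, but got $other")
    }
  }

  def main(args: Array[String]): Unit = {
    println(new GetMapValueLike(MapTpe(StringType, BooleanType)).checkInputDataTypes())
    // Right(List(MapTpe(StringType,BooleanType), StringType))

    println(new GetMapValueLike(StringType).checkInputDataTypes())
    // Left(GetMapValue.child must be map type, but got StringType)
    // Calling inputTypes directly on this second instance would throw a
    // ClassCastException, which is exactly what the guard avoids.
  }
}
```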


override def toString: String = s"$child[$key]"
override def sql: String = s"${child.sql}[${key.sql}]"
@@ -22,11 +22,13 @@ import org.apache.spark.sql.catalyst.analysis.{EmptyFunctionRegistry, FakeV2Sess
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation, LogicalPlan, OneRowRelation, Project}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LocalRelation, LogicalPlan, OneRowRelation, Project}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, StringType, StructType}


class OptimizerStructuralIntegrityCheckerSuite extends PlanTest {
@@ -39,6 +41,18 @@ class OptimizerStructuralIntegrityCheckerSuite extends PlanTest {
case agg @ Aggregate(Nil, aggregateExpressions, child, _) =>
// Project cannot host AggregateExpression
Project(aggregateExpressions, child)
case Filter(cond, child) =>
val newCond = cond.transform {
case g @ GetStructField(a: AttributeReference, _, _) =>
g.copy(child = a.withDataType(StringType))
case g @ GetArrayStructFields(a: AttributeReference, _, _, _, _) =>
g.copy(child = a.withDataType(StringType))
case g @ GetArrayItem(a: AttributeReference, _, _) =>
g.copy(child = a.withDataType(StringType))
case g @ GetMapValue(a: AttributeReference, _) =>
g.copy(child = a.withDataType(StringType))
}
Filter(newCond, child)
}
}

@@ -79,6 +93,38 @@ class OptimizerStructuralIntegrityCheckerSuite extends PlanTest {
SimpleTestOptimizer.execute(analyzed)
}

test("check for invalid plan after execution of rule - bad ExtractValue") {
val input = LocalRelation(
$"c1".struct(new StructType().add("f1", "boolean")),
$"c2".array(new StructType().add("f1", "boolean")),
$"c3".array(BooleanType),
new DslAttr($"c4").map(StringType, BooleanType)
)

def assertCheckFailed(expr: Expression): Unit = {
val analyzed = Filter(expr, input).analyze
assert(analyzed.resolved)
// Should fail verification with the OptimizeRuleBreakSI rule
val message = intercept[SparkException] {
Optimize.execute(analyzed)
}.getMessage
val ruleName = OptimizeRuleBreakSI.ruleName
assert(message.contains(s"Rule $ruleName in batch OptimizeRuleBreakSI"))
assert(message.contains("generated an invalid plan"))
}

// This resolution validation should be included in the lightweight
// validator so that it's validated in production.
withSQLConf(
SQLConf.PLAN_CHANGE_VALIDATION.key -> "false",
SQLConf.LIGHTWEIGHT_PLAN_CHANGE_VALIDATION.key -> "true") {
assertCheckFailed($"c1.f1")
assertCheckFailed($"c2.f1".getItem(0))
assertCheckFailed($"c3".getItem(0))
assertCheckFailed($"c4".getItem("key"))
}
}

test("check for invalid plan before execution of any rule") {
val analyzed =
Aggregate(Nil, Seq[NamedExpression](max($"id") as "m"),
@@ -416,7 +416,7 @@ Project [get(array(1, 2, 3), 3) AS get(array(1, 2, 3), 3)#x]
-- !query
select get(array(1, 2, 3), null)
-- !query analysis
Project [get(array(1, 2, 3), null) AS get(array(1, 2, 3), NULL)#x]
Project [get(array(1, 2, 3), cast(null as int)) AS get(array(1, 2, 3), NULL)#x]
+- OneRowRelation


@@ -438,8 +438,8 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
"inputSql" : "\"1\"",
"inputType" : "\"INT\"",
"paramIndex" : "first",
"requiredType" : "(\"ARRAY\")",
"sqlExpr" : "\"1[0]\""
"requiredType" : "\"ARRAY\"",
"sqlExpr" : "\"get(1, 0)\""
},
"queryContext" : [ {
"objectType" : "",
@@ -462,8 +462,8 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
"inputSql" : "\"1\"",
"inputType" : "\"INT\"",
"paramIndex" : "first",
"requiredType" : "(\"ARRAY\")",
"sqlExpr" : "\"1[-1]\""
"requiredType" : "\"ARRAY\"",
"sqlExpr" : "\"get(1, -1)\""
},
"queryContext" : [ {
"objectType" : "",
@@ -486,8 +486,8 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
"inputSql" : "\"1\"",
"inputType" : "\"STRING\"",
"paramIndex" : "first",
"requiredType" : "(\"ARRAY\")",
"sqlExpr" : "\"1[0]\""
"requiredType" : "\"ARRAY\"",
"sqlExpr" : "\"get(1, 0)\""
},
"queryContext" : [ {
"objectType" : "",
@@ -510,8 +510,8 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
"inputSql" : "\"1\"",
"inputType" : "\"STRING\"",
"paramIndex" : "first",
"requiredType" : "(\"ARRAY\")",
"sqlExpr" : "\"1[-1]\""
"requiredType" : "\"ARRAY\"",
"sqlExpr" : "\"get(1, -1)\""
},
"queryContext" : [ {
"objectType" : "",
@@ -534,8 +534,8 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
"inputSql" : "\"NULL\"",
"inputType" : "\"VOID\"",
"paramIndex" : "first",
"requiredType" : "(\"ARRAY\")",
"sqlExpr" : "\"NULL[0]\""
"requiredType" : "\"ARRAY\"",
"sqlExpr" : "\"get(NULL, 0)\""
},
"queryContext" : [ {
"objectType" : "",
@@ -558,8 +558,8 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
"inputSql" : "\"NULL\"",
"inputType" : "\"VOID\"",
"paramIndex" : "first",
"requiredType" : "(\"ARRAY\")",
"sqlExpr" : "\"NULL[-1]\""
"requiredType" : "\"ARRAY\"",
"sqlExpr" : "\"get(NULL, -1)\""
},
"queryContext" : [ {
"objectType" : "",
@@ -582,8 +582,8 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
"inputSql" : "\"NULL\"",
"inputType" : "\"VOID\"",
"paramIndex" : "first",
"requiredType" : "(\"ARRAY\")",
"sqlExpr" : "\"NULL[NULL]\""
"requiredType" : "\"ARRAY\"",
"sqlExpr" : "\"get(NULL, NULL)\""
},
"queryContext" : [ {
"objectType" : "",
@@ -606,8 +606,8 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
"inputSql" : "\"CAST(NULL AS STRING)\"",
"inputType" : "\"STRING\"",
"paramIndex" : "first",
"requiredType" : "(\"ARRAY\")",
"sqlExpr" : "\"CAST(NULL AS STRING)[0]\""
"requiredType" : "\"ARRAY\"",
"sqlExpr" : "\"get(CAST(NULL AS STRING), 0)\""
},
"queryContext" : [ {
"objectType" : "",