[SPARK-47681][FOLLOWUP] Fix schema_of_variant(decimal)
### What changes were proposed in this pull request?

PR #46338 found that `schema_of_variant` sometimes could not handle variant decimals correctly and added a fix. However, that fix is incomplete, and `schema_of_variant` can still fail on some inputs. The reason is that `VariantUtil.getDecimal` calls `stripTrailingZeros`. For the input decimal `10.00`, the resulting scale is -1 and the unscaled value is 1, but Spark does not allow a negative decimal scale. The correct approach is to construct a `Decimal` from the `BigDecimal` and read its precision and scale, as `VariantGet` already does.
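The failure mode can be reproduced with plain `java.math.BigDecimal`, without Spark on the classpath. The sketch below is illustrative only: `StripTrailingZerosDemo` and `oldPrecisionAndScale` are hypothetical names, and the helper merely re-applies the pre-fix formula from `SchemaOfVariant` to a stripped `BigDecimal`.

```scala
import java.math.BigDecimal

object StripTrailingZerosDemo {
  // Hypothetical helper: applies the pre-fix formula
  //   DecimalType(d.precision().max(d.scale()), d.scale())
  // to a BigDecimal whose trailing zeros have been stripped.
  def oldPrecisionAndScale(text: String): (Int, Int) = {
    val d = new BigDecimal(text).stripTrailingZeros()
    (math.max(d.precision(), d.scale()), d.scale())
  }

  def main(args: Array[String]): Unit = {
    val d = new BigDecimal("10.00").stripTrailingZeros()
    // 10.00 is stored as 1000 x 10^-2; stripping zeros turns it into 1 x 10^1,
    // i.e. unscaled value 1 with scale -1.
    println(s"unscaled=${d.unscaledValue()}, scale=${d.scale()}, precision=${d.precision()}")

    // The old formula guards against precision < scale, but not against a negative scale.
    println(oldPrecisionAndScale("0.01"))  // precision is lifted to match the scale
    println(oldPrecisionAndScale("10.00")) // scale stays negative
  }
}
```

For `10.00` the helper yields a negative scale, which is the combination the old `precision.max(scale)` guard did not cover.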

This PR also includes a minor change to `VariantGet`, where the same expression was computed twice.

### Why are the changes needed?

They are bug fixes and are required to process decimals correctly.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added more unit tests. Some of them fail without the change in this PR (e.g., `check("10.00", "DECIMAL(2,0)")`). Others pass either way but still improve test coverage.
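To see why the new expectations look the way they do, here is a minimal sketch of the precision/scale normalization that, as far as I understand it, happens when a Spark `Decimal` is built from a `BigDecimal`. This is an assumption, not a copy of Spark's code: `DecimalTypeSketch` and `precisionAndScale` are hypothetical names, and the sketch assumes negative scales are disallowed.

```scala
import java.math.BigDecimal

object DecimalTypeSketch {
  // Hypothetical helper (not a Spark API): returns the (precision, scale)
  // that a decimal literal is expected to map to after trailing zeros are stripped.
  def precisionAndScale(text: String): (Int, Int) = {
    val d = new BigDecimal(text).stripTrailingZeros()
    if (d.precision() < d.scale()) {
      // e.g. 0.01: one significant digit but scale 2, so precision is raised to the scale.
      (d.scale(), d.scale())
    } else if (d.scale() < 0) {
      // e.g. 10.00 -> 1E+1: fold the negative scale back into the precision.
      (d.precision() - d.scale(), 0)
    } else {
      (d.precision(), d.scale())
    }
  }

  def main(args: Array[String]): Unit = {
    // Inputs taken from the tests added in this PR.
    Seq("1.00", "10.00", "10.10", "0.0", "-0.0", "2147483647.999").foreach { s =>
      val (p, sc) = precisionAndScale(s)
      println(s"$s -> DECIMAL($p,$sc)")
    }
  }
}
```

Under these assumptions the helper agrees with the expected types in the new test cases, including `DECIMAL(2,0)` for `10.00`.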

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #46549 from chenhao-db/fix_decimal_schema.

Authored-by: Chenhao Li <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
chenhao-db authored and cloud-fan committed May 13, 2024
1 parent 42f2132 commit 3456d4f
Showing 2 changed files with 13 additions and 4 deletions.
@@ -341,7 +341,7 @@ case object VariantGet {
   case Type.DOUBLE => Literal(v.getDouble, DoubleType)
   case Type.DECIMAL =>
     val d = Decimal(v.getDecimal)
-    Literal(Decimal(v.getDecimal), DecimalType(d.precision, d.scale))
+    Literal(d, DecimalType(d.precision, d.scale))
   case Type.DATE => Literal(v.getLong.toInt, DateType)
   case Type.TIMESTAMP => Literal(v.getLong, TimestampType)
   case Type.TIMESTAMP_NTZ => Literal(v.getLong, TimestampNTZType)
@@ -682,9 +682,8 @@ object SchemaOfVariant {
   case Type.STRING => SQLConf.get.defaultStringType
   case Type.DOUBLE => DoubleType
   case Type.DECIMAL =>
-    val d = v.getDecimal
-    // Spark doesn't allow `DecimalType` to have `precision < scale`.
-    DecimalType(d.precision().max(d.scale()), d.scale())
+    val d = Decimal(v.getDecimal)
+    DecimalType(d.precision, d.scale)
   case Type.DATE => DateType
   case Type.TIMESTAMP => TimestampType
   case Type.TIMESTAMP_NTZ => TimestampNTZType
@@ -160,6 +160,16 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
   check("1", "BIGINT")
   check("1.0", "DECIMAL(1,0)")
   check("0.01", "DECIMAL(2,2)")
+  check("1.00", "DECIMAL(1,0)")
+  check("10.00", "DECIMAL(2,0)")
+  check("10.10", "DECIMAL(3,1)")
+  check("0.0", "DECIMAL(1,0)")
+  check("-0.0", "DECIMAL(1,0)")
+  check("2147483647.999", "DECIMAL(13,3)")
+  check("9223372036854775808", "DECIMAL(19,0)")
+  check("-9223372036854775808.0", "DECIMAL(19,0)")
+  check("9999999999999999999.9999999999999999999", "DECIMAL(38,19)")
+  check("9999999999999999999.99999999999999999999", "DOUBLE")
   check("1E0", "DOUBLE")
   check("true", "BOOLEAN")
   check("\"2000-01-01\"", "STRING")