[SPARK-47681][SQL][FOLLOWUP] Fix variant decimal handling
### What changes were proposed in this pull request?

There are two issues with the current variant decimal handling:
1. The precision and scale of the `BigDecimal` returned by `getDecimal` are not checked. Per the variant spec, they must be within the corresponding limit for DECIMAL4/8/16. An out-of-range decimal can lead to failures in downstream Spark operations.
2. The current `schema_of_variant` implementation doesn't correctly handle the case where the precision is smaller than the scale; Spark's `DecimalType` requires `precision >= scale`. (A short sketch of both fixes follows this list.)
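
For illustration, a minimal Python sketch (hypothetical helper names, not the actual Spark code) of what the two fixes amount to:

```python
import decimal

# Variant spec limits: DECIMAL4/8/16 allow at most 9/18/38 digits.
MAX_PRECISION = {"DECIMAL4": 9, "DECIMAL8": 18, "DECIMAL16": 38}

def check_decimal(d: decimal.Decimal, max_precision: int) -> None:
    # Fix 1: both precision and scale must stay within the per-type limit;
    # otherwise the variant is malformed.
    _, digits, exponent = d.as_tuple()
    scale = -exponent if exponent < 0 else 0
    precision = len(digits)
    if precision > max_precision or scale > max_precision:
        raise ValueError("MALFORMED_VARIANT")

def schema_of_decimal(d: decimal.Decimal) -> str:
    # Fix 2: Spark's DecimalType requires precision >= scale, so widen the
    # precision to the scale when needed.
    _, digits, exponent = d.as_tuple()
    scale = -exponent if exponent < 0 else 0
    precision = max(len(digits), scale)
    return f"DECIMAL({precision},{scale})"

print(schema_of_decimal(decimal.Decimal("0.01")))  # DECIMAL(2,2), not the invalid DECIMAL(1,2)
```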

The Python side requires a similar fix for issue 1. While making that fix, I found that the Python error reporting was not implemented correctly (and was never tested), so this PR fixes that as well.
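
As a rough usage sketch of the corrected Python error reporting (assuming the standard `PySparkException.getErrorClass()` accessor; the payload bytes mirror the new test below):

```python
from pyspark.errors import PySparkValueError
from pyspark.sql.types import VariantVal

metadata = bytes([1, 0, 0])
# DECIMAL4 payload whose unscaled value is 1000000000 (10 digits), which exceeds
# the 9-digit DECIMAL4 limit and must now be rejected as malformed.
try:
    str(VariantVal(bytes([32, 0, 0, 202, 154, 59]), metadata))
except PySparkValueError as e:
    print(e.getErrorClass())  # MALFORMED_VARIANT
```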

### Why are the changes needed?

They are bug fixes and are required to process decimals correctly.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #46338 from chenhao-db/fix_variant_decimal.

Authored-by: Chenhao Li <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
chenhao-db authored and cloud-fan committed May 7, 2024
1 parent 4c68842 commit d67752a
Showing 7 changed files with 76 additions and 22 deletions.
@@ -392,21 +392,32 @@ public static double getDouble(byte[] value, int pos) {
return Double.longBitsToDouble(readLong(value, pos + 1, 8));
}

// Check whether the precision and scale of the decimal are within the limit.
private static void checkDecimal(BigDecimal d, int maxPrecision) {
if (d.precision() > maxPrecision || d.scale() > maxPrecision) {
throw malformedVariant();
}
}

// Get a decimal value from variant value `value[pos...]`.
// Throw `MALFORMED_VARIANT` if the variant is malformed.
public static BigDecimal getDecimal(byte[] value, int pos) {
checkIndex(pos, value.length);
int basicType = value[pos] & BASIC_TYPE_MASK;
int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
if (basicType != PRIMITIVE) throw unexpectedType(Type.DECIMAL);
int scale = value[pos + 1];
// Interpret the scale byte as unsigned. If it is a negative byte, the unsigned value must be
// greater than `MAX_DECIMAL16_PRECISION` and will trigger an error in `checkDecimal`.
int scale = value[pos + 1] & 0xFF;
BigDecimal result;
switch (typeInfo) {
case DECIMAL4:
result = BigDecimal.valueOf(readLong(value, pos + 2, 4), scale);
checkDecimal(result, MAX_DECIMAL4_PRECISION);
break;
case DECIMAL8:
result = BigDecimal.valueOf(readLong(value, pos + 2, 8), scale);
checkDecimal(result, MAX_DECIMAL8_PRECISION);
break;
case DECIMAL16:
checkIndex(pos + 17, value.length);
@@ -417,6 +428,7 @@ public static BigDecimal getDecimal(byte[] value, int pos) {
bytes[i] = value[pos + 17 - i];
}
result = new BigDecimal(new BigInteger(bytes), scale);
checkDecimal(result, MAX_DECIMAL16_PRECISION);
break;
default:
throw unexpectedType(Type.DECIMAL);
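
The unsigned read of the scale byte matters because a Java `byte` is signed; a small Python sketch of the effect (illustrative only):

```python
raw = 0xFF  # scale byte as stored in the variant value
as_signed_java_byte = raw - 256 if raw >= 128 else raw  # -1: old behavior, a bogus negative scale
as_unsigned = raw & 0xFF                                 # 255: new behavior
# 255 exceeds MAX_DECIMAL16_PRECISION (38), so checkDecimal now reports
# MALFORMED_VARIANT instead of building a decimal with scale -1.
```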
5 changes: 5 additions & 0 deletions python/pyspark/errors/error-conditions.json
@@ -482,6 +482,11 @@
"<arg1> and <arg2> should be of the same length, got <arg1_length> and <arg2_length>."
]
},
"MALFORMED_VARIANT" : {
"message" : [
"Variant binary is malformed. Please check the data source is valid."
]
},
"MASTER_URL_NOT_SET": {
"message": [
"A master URL must be set in your configuration."
13 changes: 13 additions & 0 deletions python/pyspark/sql/tests/test_types.py
@@ -1577,6 +1577,19 @@ def test_variant_type(self):
# check repr
self.assertEqual(str(variants[0]), str(eval(repr(variants[0]))))

metadata = bytes([1, 0, 0])
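# Each payload below is a DECIMAL4: header byte 32, one scale byte, then the unscaled
# value as a 4-byte little-endian integer.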
self.assertEqual(str(VariantVal(bytes([32, 0, 1, 0, 0, 0]), metadata)), "1")
self.assertEqual(str(VariantVal(bytes([32, 1, 2, 0, 0, 0]), metadata)), "0.2")
self.assertEqual(str(VariantVal(bytes([32, 2, 3, 0, 0, 0]), metadata)), "0.03")
self.assertEqual(str(VariantVal(bytes([32, 0, 1, 0, 0, 0]), metadata)), "1")
self.assertEqual(str(VariantVal(bytes([32, 0, 255, 201, 154, 59]), metadata)), "999999999")
self.assertRaises(
PySparkValueError, lambda: str(VariantVal(bytes([32, 0, 0, 202, 154, 59]), metadata))
)
self.assertRaises(
PySparkValueError, lambda: str(VariantVal(bytes([32, 10, 1, 0, 0, 0]), metadata))
)

def test_from_ddl(self):
self.assertEqual(DataType.fromDDL("long"), LongType())
self.assertEqual(
59 changes: 39 additions & 20 deletions python/pyspark/sql/variant_utils.py
@@ -115,6 +115,13 @@ class VariantUtils:
)
EPOCH_NTZ = datetime.datetime(year=1970, month=1, day=1, hour=0, minute=0, second=0)

MAX_DECIMAL4_PRECISION = 9
MAX_DECIMAL4_VALUE = 10**MAX_DECIMAL4_PRECISION
MAX_DECIMAL8_PRECISION = 18
MAX_DECIMAL8_VALUE = 10**MAX_DECIMAL8_PRECISION
MAX_DECIMAL16_PRECISION = 38
MAX_DECIMAL16_VALUE = 10**MAX_DECIMAL16_PRECISION

@classmethod
def to_json(cls, value: bytes, metadata: bytes, zone_id: str = "UTC") -> str:
"""
@@ -142,7 +149,7 @@ def _read_long(cls, data: bytes, pos: int, num_bytes: int, signed: bool) -> int:
@classmethod
def _check_index(cls, pos: int, length: int) -> None:
if pos < 0 or pos >= length:
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})

@classmethod
def _get_type_info(cls, value: bytes, pos: int) -> Tuple[int, int]:
@@ -162,14 +169,14 @@ def _get_metadata_key(cls, metadata: bytes, id: int) -> str:
offset_size = ((metadata[0] >> 6) & 0x3) + 1
dict_size = cls._read_long(metadata, 1, offset_size, signed=False)
if id >= dict_size:
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
string_start = 1 + (dict_size + 2) * offset_size
offset = cls._read_long(metadata, 1 + (id + 1) * offset_size, offset_size, signed=False)
next_offset = cls._read_long(
metadata, 1 + (id + 2) * offset_size, offset_size, signed=False
)
if offset > next_offset:
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
cls._check_index(string_start + next_offset - 1, len(metadata))
return metadata[string_start + offset : (string_start + next_offset)].decode("utf-8")

Expand All @@ -180,15 +187,15 @@ def _get_boolean(cls, value: bytes, pos: int) -> bool:
if basic_type != VariantUtils.PRIMITIVE or (
type_info != VariantUtils.TRUE and type_info != VariantUtils.FALSE
):
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
return type_info == VariantUtils.TRUE

@classmethod
def _get_long(cls, value: bytes, pos: int) -> int:
cls._check_index(pos, len(value))
basic_type, type_info = cls._get_type_info(value, pos)
if basic_type != VariantUtils.PRIMITIVE:
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
if type_info == VariantUtils.INT1:
return cls._read_long(value, pos + 1, 1, signed=True)
elif type_info == VariantUtils.INT2:
@@ -197,25 +204,25 @@ def _get_long(cls, value: bytes, pos: int) -> int:
return cls._read_long(value, pos + 1, 4, signed=True)
elif type_info == VariantUtils.INT8:
return cls._read_long(value, pos + 1, 8, signed=True)
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})

@classmethod
def _get_date(cls, value: bytes, pos: int) -> datetime.date:
cls._check_index(pos, len(value))
basic_type, type_info = cls._get_type_info(value, pos)
if basic_type != VariantUtils.PRIMITIVE:
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
if type_info == VariantUtils.DATE:
days_since_epoch = cls._read_long(value, pos + 1, 4, signed=True)
return datetime.date.fromordinal(VariantUtils.EPOCH.toordinal() + days_since_epoch)
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})

@classmethod
def _get_timestamp(cls, value: bytes, pos: int, zone_id: str) -> datetime.datetime:
cls._check_index(pos, len(value))
basic_type, type_info = cls._get_type_info(value, pos)
if basic_type != VariantUtils.PRIMITIVE:
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
if type_info == VariantUtils.TIMESTAMP_NTZ:
microseconds_since_epoch = cls._read_long(value, pos + 1, 8, signed=True)
return VariantUtils.EPOCH_NTZ + datetime.timedelta(
@@ -226,7 +233,7 @@ def _get_timestamp(cls, value: bytes, pos: int, zone_id: str) -> datetime.datetime:
return (
VariantUtils.EPOCH + datetime.timedelta(microseconds=microseconds_since_epoch)
).astimezone(ZoneInfo(zone_id))
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})

@classmethod
def _get_string(cls, value: bytes, pos: int) -> str:
@@ -245,47 +252,59 @@ def _get_string(cls, value: bytes, pos: int) -> str:
length = cls._read_long(value, pos + 1, VariantUtils.U32_SIZE, signed=False)
cls._check_index(start + length - 1, len(value))
return value[start : start + length].decode("utf-8")
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})

@classmethod
def _get_double(cls, value: bytes, pos: int) -> float:
cls._check_index(pos, len(value))
basic_type, type_info = cls._get_type_info(value, pos)
if basic_type != VariantUtils.PRIMITIVE:
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
if type_info == VariantUtils.FLOAT:
cls._check_index(pos + 4, len(value))
return struct.unpack("<f", value[pos + 1 : pos + 5])[0]
elif type_info == VariantUtils.DOUBLE:
cls._check_index(pos + 8, len(value))
return struct.unpack("<d", value[pos + 1 : pos + 9])[0]
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})

@classmethod
def _check_decimal(cls, unscaled: int, scale: int, max_unscaled: int, max_scale: int) -> None:
# max_unscaled == 10**max_scale, but we pass a literal parameter to avoid redundant
# computation.
if unscaled >= max_unscaled or unscaled <= -max_unscaled or scale > max_scale:
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})

@classmethod
def _get_decimal(cls, value: bytes, pos: int) -> decimal.Decimal:
cls._check_index(pos, len(value))
basic_type, type_info = cls._get_type_info(value, pos)
if basic_type != VariantUtils.PRIMITIVE:
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
scale = value[pos + 1]
unscaled = 0
if type_info == VariantUtils.DECIMAL4:
unscaled = cls._read_long(value, pos + 2, 4, signed=True)
cls._check_decimal(unscaled, scale, cls.MAX_DECIMAL4_VALUE, cls.MAX_DECIMAL4_PRECISION)
elif type_info == VariantUtils.DECIMAL8:
unscaled = cls._read_long(value, pos + 2, 8, signed=True)
cls._check_decimal(unscaled, scale, cls.MAX_DECIMAL8_VALUE, cls.MAX_DECIMAL8_PRECISION)
elif type_info == VariantUtils.DECIMAL16:
cls._check_index(pos + 17, len(value))
unscaled = int.from_bytes(value[pos + 2 : pos + 18], byteorder="little", signed=True)
cls._check_decimal(
unscaled, scale, cls.MAX_DECIMAL16_VALUE, cls.MAX_DECIMAL16_PRECISION
)
else:
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
return decimal.Decimal(unscaled) * (decimal.Decimal(10) ** (-scale))

@classmethod
def _get_binary(cls, value: bytes, pos: int) -> bytes:
cls._check_index(pos, len(value))
basic_type, type_info = cls._get_type_info(value, pos)
if basic_type != VariantUtils.PRIMITIVE or type_info != VariantUtils.BINARY:
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
start = pos + 1 + VariantUtils.U32_SIZE
length = cls._read_long(value, pos + 1, VariantUtils.U32_SIZE, signed=False)
cls._check_index(start + length - 1, len(value))
@@ -331,7 +350,7 @@ def _get_type(cls, value: bytes, pos: int) -> Any:
return datetime.datetime
elif type_info == VariantUtils.LONG_STR:
return str
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})

@classmethod
def _to_json(cls, value: bytes, metadata: bytes, pos: int, zone_id: str) -> str:
@@ -419,7 +438,7 @@ def _get_scalar(
elif variant_type == datetime.datetime:
return cls._get_timestamp(value, pos, zone_id)
else:
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})

@classmethod
def _handle_object(
@@ -432,7 +451,7 @@ def _handle_object(
cls._check_index(pos, len(value))
basic_type, type_info = cls._get_type_info(value, pos)
if basic_type != VariantUtils.OBJECT:
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
large_size = ((type_info >> 4) & 0x1) != 0
size_bytes = VariantUtils.U32_SIZE if large_size else 1
num_fields = cls._read_long(value, pos + 1, size_bytes, signed=False)
@@ -461,7 +480,7 @@ def _handle_array(cls, value: bytes, pos: int, func: Callable[[List[int]], Any])
cls._check_index(pos, len(value))
basic_type, type_info = cls._get_type_info(value, pos)
if basic_type != VariantUtils.ARRAY:
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
large_size = ((type_info >> 2) & 0x1) != 0
size_bytes = VariantUtils.U32_SIZE if large_size else 1
num_fields = cls._read_long(value, pos + 1, size_bytes, signed=False)
@@ -680,7 +680,8 @@ object SchemaOfVariant {
case Type.DOUBLE => DoubleType
case Type.DECIMAL =>
val d = v.getDecimal
DecimalType(d.precision(), d.scale())
// Spark doesn't allow `DecimalType` to have `precision < scale`.
DecimalType(d.precision().max(d.scale()), d.scale())
case Type.DATE => DateType
case Type.TIMESTAMP => TimestampType
case Type.TIMESTAMP_NTZ => TimestampNTZType
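
A quick end-to-end check of this change (a PySpark sketch; the expected type matches the new `VariantEndToEndSuite` case below):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# parse_json('0.01') stores unscaled value 1 with scale 2 (precision 1 < scale 2);
# schema_of_variant now reports DECIMAL(2,2) instead of trying to construct an
# invalid DecimalType(1, 2).
print(spark.sql("SELECT schema_of_variant(parse_json('0.01'))").head()[0])  # DECIMAL(2,2)
```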
@@ -58,6 +58,9 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
check(Array(primitiveHeader(INT8), 0, 0, 0, 0, 0, 0, 0), emptyMetadata)
// DECIMAL16 only has 15 byte content.
check(Array(primitiveHeader(DECIMAL16)) ++ Array.fill(16)(0.toByte), emptyMetadata)
// 1e38 has a precision of 39. Even if it still fits into 16 bytes, it is not a valid decimal.
check(Array[Byte](primitiveHeader(DECIMAL16), 0) ++
BigDecimal(1e38).toBigInt.toByteArray.reverse, emptyMetadata)
// Short string content too short.
check(Array(shortStrHeader(2), 'x'), emptyMetadata)
// Long string length too short (requires 4 bytes).
@@ -157,6 +157,7 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
check("null", "VOID")
check("1", "BIGINT")
check("1.0", "DECIMAL(1,0)")
check("0.01", "DECIMAL(2,2)")
check("1E0", "DOUBLE")
check("true", "BOOLEAN")
check("\"2000-01-01\"", "STRING")
