-
Notifications
You must be signed in to change notification settings - Fork 28k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-47681][SQL][FOLLOWUP] Fix variant decimal handling. #46338
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -392,21 +392,32 @@ public static double getDouble(byte[] value, int pos) { | |||
return Double.longBitsToDouble(readLong(value, pos + 1, 8)); | ||||
} | ||||
|
||||
// Check whether the precision and scale of the decimal are within the limit. | ||||
private static void checkDecimal(BigDecimal d, int maxPrecision) { | ||||
if (d.precision() > maxPrecision || d.scale() > maxPrecision) { | ||||
throw malformedVariant(); | ||||
} | ||||
} | ||||
|
||||
// Get a decimal value from variant value `value[pos...]`. | ||||
// Throw `MALFORMED_VARIANT` if the variant is malformed. | ||||
public static BigDecimal getDecimal(byte[] value, int pos) { | ||||
checkIndex(pos, value.length); | ||||
int basicType = value[pos] & BASIC_TYPE_MASK; | ||||
int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK; | ||||
if (basicType != PRIMITIVE) throw unexpectedType(Type.DECIMAL); | ||||
int scale = value[pos + 1]; | ||||
// Interpret the scale byte as unsigned. If it is a negative byte, the unsigned value must be | ||||
// greater than `MAX_DECIMAL16_PRECISION` and will trigger an error in `checkDecimal`. | ||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
We should distinguish between the variant decimal spec and Spark decimal limitations. Some databases (and Java libraries) support negative decimal scale, which means the value range of […]
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
We don't allow negative scale in the variant spec either: see spark/common/variant/README.md, line 348 (at commit bf2e254).
|
||||
int scale = value[pos + 1] & 0xFF; | ||||
BigDecimal result; | ||||
switch (typeInfo) { | ||||
case DECIMAL4: | ||||
result = BigDecimal.valueOf(readLong(value, pos + 2, 4), scale); | ||||
checkDecimal(result, MAX_DECIMAL4_PRECISION); | ||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Does the PySpark side also need to be updated?
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
I think so.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
The Python side is also fixed. During the fix, I found that the Python error reporting was not correctly implemented, and I fixed that as well. |
||||
break; | ||||
case DECIMAL8: | ||||
result = BigDecimal.valueOf(readLong(value, pos + 2, 8), scale); | ||||
checkDecimal(result, MAX_DECIMAL8_PRECISION); | ||||
break; | ||||
case DECIMAL16: | ||||
checkIndex(pos + 17, value.length); | ||||
|
@@ -417,6 +428,7 @@ public static BigDecimal getDecimal(byte[] value, int pos) { | |||
bytes[i] = value[pos + 17 - i]; | ||||
} | ||||
result = new BigDecimal(new BigInteger(bytes), scale); | ||||
checkDecimal(result, MAX_DECIMAL16_PRECISION); | ||||
break; | ||||
default: | ||||
throw unexpectedType(Type.DECIMAL); | ||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -142,7 +142,7 @@ def _read_long(cls, data: bytes, pos: int, num_bytes: int, signed: bool) -> int: | |
@classmethod | ||
def _check_index(cls, pos: int, length: int) -> None: | ||
if pos < 0 or pos >= length: | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT") | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={}) | ||
|
||
@classmethod | ||
def _get_type_info(cls, value: bytes, pos: int) -> Tuple[int, int]: | ||
|
@@ -162,14 +162,14 @@ def _get_metadata_key(cls, metadata: bytes, id: int) -> str: | |
offset_size = ((metadata[0] >> 6) & 0x3) + 1 | ||
dict_size = cls._read_long(metadata, 1, offset_size, signed=False) | ||
if id >= dict_size: | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT") | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={}) | ||
string_start = 1 + (dict_size + 2) * offset_size | ||
offset = cls._read_long(metadata, 1 + (id + 1) * offset_size, offset_size, signed=False) | ||
next_offset = cls._read_long( | ||
metadata, 1 + (id + 2) * offset_size, offset_size, signed=False | ||
) | ||
if offset > next_offset: | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT") | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={}) | ||
cls._check_index(string_start + next_offset - 1, len(metadata)) | ||
return metadata[string_start + offset : (string_start + next_offset)].decode("utf-8") | ||
|
||
|
@@ -180,15 +180,15 @@ def _get_boolean(cls, value: bytes, pos: int) -> bool: | |
if basic_type != VariantUtils.PRIMITIVE or ( | ||
type_info != VariantUtils.TRUE and type_info != VariantUtils.FALSE | ||
): | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT") | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={}) | ||
return type_info == VariantUtils.TRUE | ||
|
||
@classmethod | ||
def _get_long(cls, value: bytes, pos: int) -> int: | ||
cls._check_index(pos, len(value)) | ||
basic_type, type_info = cls._get_type_info(value, pos) | ||
if basic_type != VariantUtils.PRIMITIVE: | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT") | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={}) | ||
if type_info == VariantUtils.INT1: | ||
return cls._read_long(value, pos + 1, 1, signed=True) | ||
elif type_info == VariantUtils.INT2: | ||
|
@@ -197,25 +197,25 @@ def _get_long(cls, value: bytes, pos: int) -> int: | |
return cls._read_long(value, pos + 1, 4, signed=True) | ||
elif type_info == VariantUtils.INT8: | ||
return cls._read_long(value, pos + 1, 8, signed=True) | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT") | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={}) | ||
|
||
@classmethod | ||
def _get_date(cls, value: bytes, pos: int) -> datetime.date: | ||
cls._check_index(pos, len(value)) | ||
basic_type, type_info = cls._get_type_info(value, pos) | ||
if basic_type != VariantUtils.PRIMITIVE: | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT") | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={}) | ||
if type_info == VariantUtils.DATE: | ||
days_since_epoch = cls._read_long(value, pos + 1, 4, signed=True) | ||
return datetime.date.fromordinal(VariantUtils.EPOCH.toordinal() + days_since_epoch) | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT") | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={}) | ||
|
||
@classmethod | ||
def _get_timestamp(cls, value: bytes, pos: int, zone_id: str) -> datetime.datetime: | ||
cls._check_index(pos, len(value)) | ||
basic_type, type_info = cls._get_type_info(value, pos) | ||
if basic_type != VariantUtils.PRIMITIVE: | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT") | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={}) | ||
if type_info == VariantUtils.TIMESTAMP_NTZ: | ||
microseconds_since_epoch = cls._read_long(value, pos + 1, 8, signed=True) | ||
return VariantUtils.EPOCH_NTZ + datetime.timedelta( | ||
|
@@ -226,7 +226,7 @@ def _get_timestamp(cls, value: bytes, pos: int, zone_id: str) -> datetime.dateti | |
return ( | ||
VariantUtils.EPOCH + datetime.timedelta(microseconds=microseconds_since_epoch) | ||
).astimezone(ZoneInfo(zone_id)) | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT") | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={}) | ||
|
||
@classmethod | ||
def _get_string(cls, value: bytes, pos: int) -> str: | ||
|
@@ -245,47 +245,57 @@ def _get_string(cls, value: bytes, pos: int) -> str: | |
length = cls._read_long(value, pos + 1, VariantUtils.U32_SIZE, signed=False) | ||
cls._check_index(start + length - 1, len(value)) | ||
return value[start : start + length].decode("utf-8") | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT") | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={}) | ||
|
||
@classmethod | ||
def _get_double(cls, value: bytes, pos: int) -> float: | ||
cls._check_index(pos, len(value)) | ||
basic_type, type_info = cls._get_type_info(value, pos) | ||
if basic_type != VariantUtils.PRIMITIVE: | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT") | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={}) | ||
if type_info == VariantUtils.FLOAT: | ||
cls._check_index(pos + 4, len(value)) | ||
return struct.unpack("<f", value[pos + 1 : pos + 5])[0] | ||
elif type_info == VariantUtils.DOUBLE: | ||
cls._check_index(pos + 8, len(value)) | ||
return struct.unpack("<d", value[pos + 1 : pos + 9])[0] | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT") | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={}) | ||
|
||
@classmethod | ||
def _check_decimal(cls, unscaled: int, scale: int, max_unscaled: int, max_scale: int): | ||
# max_unscaled == 10**max_scale, but we pass a literal parameter to avoid redundant | ||
# computation. | ||
if unscaled >= max_unscaled or unscaled <= -max_unscaled or scale > max_scale: | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={}) | ||
|
||
@classmethod | ||
def _get_decimal(cls, value: bytes, pos: int) -> decimal.Decimal: | ||
cls._check_index(pos, len(value)) | ||
basic_type, type_info = cls._get_type_info(value, pos) | ||
if basic_type != VariantUtils.PRIMITIVE: | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT") | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={}) | ||
scale = value[pos + 1] | ||
unscaled = 0 | ||
if type_info == VariantUtils.DECIMAL4: | ||
unscaled = cls._read_long(value, pos + 2, 4, signed=True) | ||
cls._check_decimal(unscaled, scale, 1000000000, 9) | ||
elif type_info == VariantUtils.DECIMAL8: | ||
unscaled = cls._read_long(value, pos + 2, 8, signed=True) | ||
cls._check_decimal(unscaled, scale, 1000000000000000000, 18) | ||
elif type_info == VariantUtils.DECIMAL16: | ||
cls._check_index(pos + 17, len(value)) | ||
unscaled = int.from_bytes(value[pos + 2 : pos + 18], byteorder="little", signed=True) | ||
cls._check_decimal(unscaled, scale, 100000000000000000000000000000000000000, 38) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Let's make these literals into constants.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Done. |
||
else: | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT") | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={}) | ||
return decimal.Decimal(unscaled) * (decimal.Decimal(10) ** (-scale)) | ||
|
||
@classmethod | ||
def _get_binary(cls, value: bytes, pos: int) -> bytes: | ||
cls._check_index(pos, len(value)) | ||
basic_type, type_info = cls._get_type_info(value, pos) | ||
if basic_type != VariantUtils.PRIMITIVE or type_info != VariantUtils.BINARY: | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT") | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={}) | ||
start = pos + 1 + VariantUtils.U32_SIZE | ||
length = cls._read_long(value, pos + 1, VariantUtils.U32_SIZE, signed=False) | ||
cls._check_index(start + length - 1, len(value)) | ||
|
@@ -331,7 +341,7 @@ def _get_type(cls, value: bytes, pos: int) -> Any: | |
return datetime.datetime | ||
elif type_info == VariantUtils.LONG_STR: | ||
return str | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT") | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={}) | ||
|
||
@classmethod | ||
def _to_json(cls, value: bytes, metadata: bytes, pos: int, zone_id: str) -> str: | ||
|
@@ -419,7 +429,7 @@ def _get_scalar( | |
elif variant_type == datetime.datetime: | ||
return cls._get_timestamp(value, pos, zone_id) | ||
else: | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT") | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={}) | ||
|
||
@classmethod | ||
def _handle_object( | ||
|
@@ -432,7 +442,7 @@ def _handle_object( | |
cls._check_index(pos, len(value)) | ||
basic_type, type_info = cls._get_type_info(value, pos) | ||
if basic_type != VariantUtils.OBJECT: | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT") | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={}) | ||
large_size = ((type_info >> 4) & 0x1) != 0 | ||
size_bytes = VariantUtils.U32_SIZE if large_size else 1 | ||
num_fields = cls._read_long(value, pos + 1, size_bytes, signed=False) | ||
|
@@ -461,7 +471,7 @@ def _handle_array(cls, value: bytes, pos: int, func: Callable[[List[int]], Any]) | |
cls._check_index(pos, len(value)) | ||
basic_type, type_info = cls._get_type_info(value, pos) | ||
if basic_type != VariantUtils.ARRAY: | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT") | ||
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={}) | ||
large_size = ((type_info >> 2) & 0x1) != 0 | ||
size_bytes = VariantUtils.U32_SIZE if large_size else 1 | ||
num_fields = cls._read_long(value, pos + 1, size_bytes, signed=False) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shall we also check `precision > scale`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is where we do need to distinguish between the variant decimal spec and Spark decimal limitations. `precision >= scale` is a requirement in Spark, not in the variant spec. In the new test case I added, `0.01` has a precision of 1 and a scale of 2. It is a valid decimal in the variant spec but requires some special handling in `schema_of_variant`.