[SPARK-47681][SQL][FOLLOWUP] Fix variant decimal handling. #46338

Closed
@@ -392,21 +392,32 @@ public static double getDouble(byte[] value, int pos) {
return Double.longBitsToDouble(readLong(value, pos + 1, 8));
}

// Check whether the precision and scale of the decimal are within the limit.
private static void checkDecimal(BigDecimal d, int maxPrecision) {
if (d.precision() > maxPrecision || d.scale() > maxPrecision) {
Contributor:

shall we also check precision > scale?

Contributor Author (@chenhao-db), May 6, 2024:

This is where we do need to distinguish between the variant decimal spec and Spark decimal limitations. precision >= scale is a requirement in Spark, not in the variant spec. In the new test case I added, 0.01 has a precision of 1 and a scale of 2. It is a valid decimal in the variant spec but requires some special handling in schema_of_variant.
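To make the precision/scale distinction concrete, here is a minimal sketch (illustrative only, not part of the PR): java.math.BigDecimal reports 0.01 as precision 1 and scale 2, so mapping it to a Spark DecimalType requires widening the precision to max(precision, scale), which is what the SchemaOfVariant change later in this diff does.

```java
import java.math.BigDecimal;

public class DecimalPrecisionScaleDemo {
    public static void main(String[] args) {
        BigDecimal d = new BigDecimal("0.01");
        // BigDecimal counts only significant digits, so 0.01 has precision 1 but scale 2.
        System.out.println(d.precision()); // 1
        System.out.println(d.scale());     // 2
        // Spark's DecimalType requires precision >= scale, so a valid mapping
        // widens the precision to max(precision, scale): DECIMAL(2,2) here.
        int precision = Math.max(d.precision(), d.scale());
        int scale = d.scale();
        System.out.println("DECIMAL(" + precision + "," + scale + ")"); // DECIMAL(2,2)
    }
}
```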

throw malformedVariant();
}
}

// Get a decimal value from variant value `value[pos...]`.
// Throw `MALFORMED_VARIANT` if the variant is malformed.
public static BigDecimal getDecimal(byte[] value, int pos) {
checkIndex(pos, value.length);
int basicType = value[pos] & BASIC_TYPE_MASK;
int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
if (basicType != PRIMITIVE) throw unexpectedType(Type.DECIMAL);
int scale = value[pos + 1];
// Interpret the scale byte as unsigned. If it is a negative byte, the unsigned value must be
// greater than `MAX_DECIMAL16_PRECISION` and will trigger an error in `checkDecimal`.
Contributor (@cloud-fan), May 5, 2024:

We should distinguish between the variant decimal spec and Spark decimal limitations. Some databases (and Java libraries) support a negative decimal scale, which means the value range of decimal(precision, 0) * 10^(-scale). To read such decimals, e.g. decimal(10, -5), we can turn it into the Spark decimal decimal(15, 0).

Contributor Author:

We don't allow negative scale in the variant spec either:

| decimal4 | `8` | DECIMAL(precision, scale) | 1 byte scale in range [0, 38], followed by little-endian unscaled value (see decimal table) |
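To illustrate the unsigned read of the scale byte (a minimal sketch, not part of the patch): masking with 0xFF turns any negative scale byte into a value above MAX_DECIMAL16_PRECISION, so a checkDecimal-style bound rejects it instead of treating it as a negative scale.

```java
public class UnsignedScaleByteDemo {
    public static void main(String[] args) {
        byte rawScale = (byte) 0xFB;     // a negative byte in two's complement (-5)
        int signed = rawScale;           // sign-extended interpretation: -5
        int unsigned = rawScale & 0xFF;  // unsigned interpretation: 251
        System.out.println(signed);      // -5
        System.out.println(unsigned);    // 251, far above the max precision of 38,
                                         // so the decimal is rejected as malformed
    }
}
```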

int scale = value[pos + 1] & 0xFF;
BigDecimal result;
switch (typeInfo) {
case DECIMAL4:
result = BigDecimal.valueOf(readLong(value, pos + 2, 4), scale);
checkDecimal(result, MAX_DECIMAL4_PRECISION);
Contributor:

Does the pyspark side also need to be updated?

Contributor Author:

I think so.

Contributor Author:

The python side is also fixed. During the fix, I found that the python error reporting was not correctly implemented and also fixed it.

break;
case DECIMAL8:
result = BigDecimal.valueOf(readLong(value, pos + 2, 8), scale);
checkDecimal(result, MAX_DECIMAL8_PRECISION);
break;
case DECIMAL16:
checkIndex(pos + 17, value.length);
@@ -417,6 +428,7 @@ public static BigDecimal getDecimal(byte[] value, int pos) {
bytes[i] = value[pos + 17 - i];
}
result = new BigDecimal(new BigInteger(bytes), scale);
checkDecimal(result, MAX_DECIMAL16_PRECISION);
break;
default:
throw unexpectedType(Type.DECIMAL);
5 changes: 5 additions & 0 deletions python/pyspark/errors/error-conditions.json
@@ -482,6 +482,11 @@
"<arg1> and <arg2> should be of the same length, got <arg1_length> and <arg2_length>."
]
},
"MALFORMED_VARIANT" : {
"message" : [
"Variant binary is malformed. Please check the data source is valid."
]
},
"MASTER_URL_NOT_SET": {
"message": [
"A master URL must be set in your configuration."
13 changes: 13 additions & 0 deletions python/pyspark/sql/tests/test_types.py
@@ -1577,6 +1577,19 @@ def test_variant_type(self):
# check repr
self.assertEqual(str(variants[0]), str(eval(repr(variants[0]))))

metadata = bytes([1, 0, 0])
self.assertEqual(str(VariantVal(bytes([32, 0, 1, 0, 0, 0]), metadata)), "1")
self.assertEqual(str(VariantVal(bytes([32, 1, 2, 0, 0, 0]), metadata)), "0.2")
self.assertEqual(str(VariantVal(bytes([32, 2, 3, 0, 0, 0]), metadata)), "0.03")
self.assertEqual(str(VariantVal(bytes([32, 0, 1, 0, 0, 0]), metadata)), "1")
self.assertEqual(str(VariantVal(bytes([32, 0, 255, 201, 154, 59]), metadata)), "999999999")
self.assertRaises(
PySparkValueError, lambda: str(VariantVal(bytes([32, 0, 0, 202, 154, 59]), metadata))
)
self.assertRaises(
PySparkValueError, lambda: str(VariantVal(bytes([32, 10, 1, 0, 0, 0]), metadata))
)

def test_from_ddl(self):
self.assertEqual(DataType.fromDDL("long"), LongType())
self.assertEqual(
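As an aside, the new decimal test vectors above follow the variant DECIMAL4 layout: a header byte of 32 (a DECIMAL4 primitive), one scale byte, then a 4-byte little-endian unscaled value. A hedged Java sketch decoding one of them, assuming only that layout:

```java
import java.math.BigDecimal;

public class Decimal4LayoutDemo {
    public static void main(String[] args) {
        // Mirrors the test vector bytes([32, 0, 255, 201, 154, 59]):
        // header 32 = DECIMAL4 primitive, scale 0, then little-endian unscaled value.
        int[] v = {32, 0, 255, 201, 154, 59};
        int scale = v[1] & 0xFF;
        long unscaled = 0;
        for (int i = 0; i < 4; i++) {
            unscaled |= ((long) (v[2 + i] & 0xFF)) << (8 * i); // little-endian read
        }
        BigDecimal d = BigDecimal.valueOf(unscaled, scale);
        System.out.println(d); // 999999999, within the 9-digit DECIMAL4 limit
    }
}
```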
59 changes: 39 additions & 20 deletions python/pyspark/sql/variant_utils.py
@@ -115,6 +115,13 @@ class VariantUtils:
)
EPOCH_NTZ = datetime.datetime(year=1970, month=1, day=1, hour=0, minute=0, second=0)

MAX_DECIMAL4_PRECISION = 9
MAX_DECIMAL4_VALUE = 10**MAX_DECIMAL4_PRECISION
MAX_DECIMAL8_PRECISION = 18
MAX_DECIMAL8_VALUE = 10**MAX_DECIMAL8_PRECISION
MAX_DECIMAL16_PRECISION = 38
MAX_DECIMAL16_VALUE = 10**MAX_DECIMAL16_PRECISION

@classmethod
def to_json(cls, value: bytes, metadata: bytes, zone_id: str = "UTC") -> str:
"""
@@ -142,7 +149,7 @@ def _read_long(cls, data: bytes, pos: int, num_bytes: int, signed: bool) -> int:
@classmethod
def _check_index(cls, pos: int, length: int) -> None:
if pos < 0 or pos >= length:
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})

@classmethod
def _get_type_info(cls, value: bytes, pos: int) -> Tuple[int, int]:
@@ -162,14 +169,14 @@ def _get_metadata_key(cls, metadata: bytes, id: int) -> str:
offset_size = ((metadata[0] >> 6) & 0x3) + 1
dict_size = cls._read_long(metadata, 1, offset_size, signed=False)
if id >= dict_size:
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
string_start = 1 + (dict_size + 2) * offset_size
offset = cls._read_long(metadata, 1 + (id + 1) * offset_size, offset_size, signed=False)
next_offset = cls._read_long(
metadata, 1 + (id + 2) * offset_size, offset_size, signed=False
)
if offset > next_offset:
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
cls._check_index(string_start + next_offset - 1, len(metadata))
return metadata[string_start + offset : (string_start + next_offset)].decode("utf-8")

@@ -180,15 +187,15 @@ def _get_boolean(cls, value: bytes, pos: int) -> bool:
if basic_type != VariantUtils.PRIMITIVE or (
type_info != VariantUtils.TRUE and type_info != VariantUtils.FALSE
):
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
return type_info == VariantUtils.TRUE

@classmethod
def _get_long(cls, value: bytes, pos: int) -> int:
cls._check_index(pos, len(value))
basic_type, type_info = cls._get_type_info(value, pos)
if basic_type != VariantUtils.PRIMITIVE:
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
if type_info == VariantUtils.INT1:
return cls._read_long(value, pos + 1, 1, signed=True)
elif type_info == VariantUtils.INT2:
@@ -197,25 +204,25 @@ def _get_long(cls, value: bytes, pos: int) -> int:
return cls._read_long(value, pos + 1, 4, signed=True)
elif type_info == VariantUtils.INT8:
return cls._read_long(value, pos + 1, 8, signed=True)
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})

@classmethod
def _get_date(cls, value: bytes, pos: int) -> datetime.date:
cls._check_index(pos, len(value))
basic_type, type_info = cls._get_type_info(value, pos)
if basic_type != VariantUtils.PRIMITIVE:
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
if type_info == VariantUtils.DATE:
days_since_epoch = cls._read_long(value, pos + 1, 4, signed=True)
return datetime.date.fromordinal(VariantUtils.EPOCH.toordinal() + days_since_epoch)
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})

@classmethod
def _get_timestamp(cls, value: bytes, pos: int, zone_id: str) -> datetime.datetime:
cls._check_index(pos, len(value))
basic_type, type_info = cls._get_type_info(value, pos)
if basic_type != VariantUtils.PRIMITIVE:
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
if type_info == VariantUtils.TIMESTAMP_NTZ:
microseconds_since_epoch = cls._read_long(value, pos + 1, 8, signed=True)
return VariantUtils.EPOCH_NTZ + datetime.timedelta(
@@ -226,7 +233,7 @@ def _get_timestamp(cls, value: bytes, pos: int, zone_id: str) -> datetime.datetime:
return (
VariantUtils.EPOCH + datetime.timedelta(microseconds=microseconds_since_epoch)
).astimezone(ZoneInfo(zone_id))
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})

@classmethod
def _get_string(cls, value: bytes, pos: int) -> str:
@@ -245,47 +252,59 @@ def _get_string(cls, value: bytes, pos: int) -> str:
length = cls._read_long(value, pos + 1, VariantUtils.U32_SIZE, signed=False)
cls._check_index(start + length - 1, len(value))
return value[start : start + length].decode("utf-8")
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})

@classmethod
def _get_double(cls, value: bytes, pos: int) -> float:
cls._check_index(pos, len(value))
basic_type, type_info = cls._get_type_info(value, pos)
if basic_type != VariantUtils.PRIMITIVE:
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
if type_info == VariantUtils.FLOAT:
cls._check_index(pos + 4, len(value))
return struct.unpack("<f", value[pos + 1 : pos + 5])[0]
elif type_info == VariantUtils.DOUBLE:
cls._check_index(pos + 8, len(value))
return struct.unpack("<d", value[pos + 1 : pos + 9])[0]
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})

@classmethod
def _check_decimal(cls, unscaled: int, scale: int, max_unscaled: int, max_scale: int) -> None:
# max_unscaled == 10**max_scale, but we pass a literal parameter to avoid redundant
# computation.
if unscaled >= max_unscaled or unscaled <= -max_unscaled or scale > max_scale:
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})

@classmethod
def _get_decimal(cls, value: bytes, pos: int) -> decimal.Decimal:
cls._check_index(pos, len(value))
basic_type, type_info = cls._get_type_info(value, pos)
if basic_type != VariantUtils.PRIMITIVE:
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
scale = value[pos + 1]
unscaled = 0
if type_info == VariantUtils.DECIMAL4:
unscaled = cls._read_long(value, pos + 2, 4, signed=True)
cls._check_decimal(unscaled, scale, cls.MAX_DECIMAL4_VALUE, cls.MAX_DECIMAL4_PRECISION)
elif type_info == VariantUtils.DECIMAL8:
unscaled = cls._read_long(value, pos + 2, 8, signed=True)
cls._check_decimal(unscaled, scale, cls.MAX_DECIMAL8_VALUE, cls.MAX_DECIMAL8_PRECISION)
elif type_info == VariantUtils.DECIMAL16:
cls._check_index(pos + 17, len(value))
unscaled = int.from_bytes(value[pos + 2 : pos + 18], byteorder="little", signed=True)
cls._check_decimal(
unscaled, scale, cls.MAX_DECIMAL16_VALUE, cls.MAX_DECIMAL16_PRECISION
)
else:
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
return decimal.Decimal(unscaled) * (decimal.Decimal(10) ** (-scale))

@classmethod
def _get_binary(cls, value: bytes, pos: int) -> bytes:
cls._check_index(pos, len(value))
basic_type, type_info = cls._get_type_info(value, pos)
if basic_type != VariantUtils.PRIMITIVE or type_info != VariantUtils.BINARY:
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
start = pos + 1 + VariantUtils.U32_SIZE
length = cls._read_long(value, pos + 1, VariantUtils.U32_SIZE, signed=False)
cls._check_index(start + length - 1, len(value))
@@ -331,7 +350,7 @@ def _get_type(cls, value: bytes, pos: int) -> Any:
return datetime.datetime
elif type_info == VariantUtils.LONG_STR:
return str
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})

@classmethod
def _to_json(cls, value: bytes, metadata: bytes, pos: int, zone_id: str) -> str:
Expand Down Expand Up @@ -419,7 +438,7 @@ def _get_scalar(
elif variant_type == datetime.datetime:
return cls._get_timestamp(value, pos, zone_id)
else:
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})

@classmethod
def _handle_object(
Expand All @@ -432,7 +451,7 @@ def _handle_object(
cls._check_index(pos, len(value))
basic_type, type_info = cls._get_type_info(value, pos)
if basic_type != VariantUtils.OBJECT:
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
large_size = ((type_info >> 4) & 0x1) != 0
size_bytes = VariantUtils.U32_SIZE if large_size else 1
num_fields = cls._read_long(value, pos + 1, size_bytes, signed=False)
@@ -461,7 +480,7 @@ def _handle_array(cls, value: bytes, pos: int, func: Callable[[List[int]], Any]):
cls._check_index(pos, len(value))
basic_type, type_info = cls._get_type_info(value, pos)
if basic_type != VariantUtils.ARRAY:
raise PySparkValueError(error_class="MALFORMED_VARIANT")
raise PySparkValueError(error_class="MALFORMED_VARIANT", message_parameters={})
large_size = ((type_info >> 2) & 0x1) != 0
size_bytes = VariantUtils.U32_SIZE if large_size else 1
num_fields = cls._read_long(value, pos + 1, size_bytes, signed=False)
@@ -680,7 +680,8 @@ object SchemaOfVariant {
case Type.DOUBLE => DoubleType
case Type.DECIMAL =>
val d = v.getDecimal
DecimalType(d.precision(), d.scale())
// Spark doesn't allow `DecimalType` to have `precision < scale`.
DecimalType(d.precision().max(d.scale()), d.scale())
case Type.DATE => DateType
case Type.TIMESTAMP => TimestampType
case Type.TIMESTAMP_NTZ => TimestampNTZType
@@ -58,6 +58,9 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
check(Array(primitiveHeader(INT8), 0, 0, 0, 0, 0, 0, 0), emptyMetadata)
// DECIMAL16 only has 15 byte content.
check(Array(primitiveHeader(DECIMAL16)) ++ Array.fill(16)(0.toByte), emptyMetadata)
// 1e38 has a precision of 39. Even if it still fits into 16 bytes, it is not a valid decimal.
check(Array[Byte](primitiveHeader(DECIMAL16), 0) ++
BigDecimal(1e38).toBigInt.toByteArray.reverse, emptyMetadata)
// Short string content too short.
check(Array(shortStrHeader(2), 'x'), emptyMetadata)
// Long string length too short (requires 4 bytes).
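A hedged aside on the 16-byte boundary noted in the new test comment above: 10^38 fits in a signed 16-byte two's-complement value but has 39 decimal digits, which is why the length check alone is not enough and the decimal bound check must also reject it.

```java
import java.math.BigDecimal;
import java.math.BigInteger;

public class Decimal16BoundaryDemo {
    public static void main(String[] args) {
        BigInteger tenPow38 = BigInteger.TEN.pow(38);
        // 10^38 < 2^127, so its two's-complement encoding fits in 16 bytes...
        System.out.println(tenPow38.toByteArray().length); // 16
        // ...but it has 39 significant digits, one more than MAX_DECIMAL16_PRECISION (38).
        System.out.println(new BigDecimal(tenPow38).precision()); // 39
    }
}
```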
@@ -157,6 +157,7 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
check("null", "VOID")
check("1", "BIGINT")
check("1.0", "DECIMAL(1,0)")
check("0.01", "DECIMAL(2,2)")
check("1E0", "DOUBLE")
check("true", "BOOLEAN")
check("\"2000-01-01\"", "STRING")