Skip to content

Commit 9380e91

Browse files
committed
fix
1 parent fd683ce commit 9380e91

File tree

3 files changed

+42
-0
lines changed

3 files changed

+42
-0
lines changed

python/pyspark/ml/tests/test_linalg.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,9 @@ def test_unwrap_udt(self):
364364
]
365365
self.assertEqual(results, expected)
366366

367+
def test_hashable(self):
368+
_ = hash(VectorUDT())
369+
367370

368371
class MatrixUDTTests(MLlibTestCase):
369372
dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10])
@@ -394,6 +397,9 @@ def test_infer_schema(self):
394397
else:
395398
raise ValueError("Expected a matrix but got type %r" % type(m))
396399

400+
def test_hashable(self):
401+
_ = hash(MatrixUDT())
402+
397403

398404
if __name__ == "__main__":
399405
from pyspark.ml.tests.test_linalg import * # noqa: F401

python/pyspark/sql/tests/test_types.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2076,6 +2076,38 @@ def test_repr(self):
20762076
for instance in instances:
20772077
self.assertEqual(eval(repr(instance)), instance)
20782078

2079+
def test_hashable(self):
2080+
for dt in [
2081+
NullType(),
2082+
StringType(),
2083+
StringType("UTF8_BINARY"),
2084+
StringType("UTF8_LCASE"),
2085+
StringType("UNICODE"),
2086+
StringType("UNICODE_CI"),
2087+
CharType(10),
2088+
VarcharType(10),
2089+
BinaryType(),
2090+
BooleanType(),
2091+
DateType(),
2092+
TimeType(),
2093+
TimestampType(),
2094+
DecimalType(),
2095+
DoubleType(),
2096+
FloatType(),
2097+
ByteType(),
2098+
IntegerType(),
2099+
LongType(),
2100+
ShortType(),
2101+
CalendarIntervalType(),
2102+
ArrayType(StringType()),
2103+
MapType(StringType(), IntegerType()),
2104+
StructField("f1", StringType(), True),
2105+
StructType([StructField("f1", StringType(), True)]),
2106+
VariantType(),
2107+
ExamplePointUDT(),
2108+
]:
2109+
_ = hash(dt)
2110+
20792111
def test_daytime_interval_type_constructor(self):
20802112
# SPARK-37277: Test constructors in day time interval.
20812113
self.assertEqual(DayTimeIntervalType().simpleString(), "interval day to second")

python/pyspark/sql/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1995,6 +1995,10 @@ def fromJson(cls, json: Dict[str, Any]) -> "UserDefinedType":
19951995
def __eq__(self, other: Any) -> bool:
19961996
return type(self) == type(other)
19971997

1998+
# __hash__ should be defined together with __eq__, otherwise it is not hashable
1999+
def __hash__(self) -> int:
2000+
return hash(str(self))
2001+
19982002

19992003
class VariantVal:
20002004
"""

0 commit comments

Comments
 (0)