diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..244bdde6 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: patch + changes: + changed: + - Replaced unsafe numpy-Python comparison with use of numpy dtype to convert byte-string arrays to Unicode ones within enums \ No newline at end of file diff --git a/policyengine_core/enums/enum.py b/policyengine_core/enums/enum.py index 584b1bd4..b55ec347 100644 --- a/policyengine_core/enums/enum.py +++ b/policyengine_core/enums/enum.py @@ -49,8 +49,11 @@ def encode(cls, array: Union[EnumArray, np.ndarray]) -> EnumArray: if isinstance(array, EnumArray): return array - # if array.dtype.kind == "b": - if isinstance(array == 0, bool): + # First, convert byte-string arrays to Unicode-string arrays + # Confusingly, Numpy uses "S" to refer to byte-string arrays + # and "U" to refer to Unicode-string arrays, which are also + # referred to as the "str" type + if array.dtype.kind == "S": # Convert boolean array to string array array = array.astype(str) diff --git a/tests/core/enums/test_enum.py b/tests/core/enums/test_enum.py new file mode 100644 index 00000000..0dde9bac --- /dev/null +++ b/tests/core/enums/test_enum.py @@ -0,0 +1,39 @@ +import pytest +import numpy as np +from policyengine_core.enums.enum import Enum +from policyengine_core.enums.enum_array import EnumArray + + +def test_enum_creation(): + """ + Test to make sure that various types of numpy arrays + are correctly encoded to int-typed EnumArray instances; + check enum_array.py to see why int-typed + """ + + test_simple_array = ["MAXWELL", "DWORKIN", "MAXWELL"] + + class Sample(Enum): + MAXWELL = "maxwell" + DWORKIN = "dworkin" + + sample_string_array = np.array(test_simple_array) + sample_item_array = np.array( + [Sample.MAXWELL, Sample.DWORKIN, Sample.MAXWELL] + ) + explicit_s_array = np.array(test_simple_array, "S") + + encoded_array = Sample.encode(sample_string_array) + assert len(encoded_array) == 3 + assert isinstance(encoded_array, EnumArray) + assert encoded_array.dtype.kind == "i" + + encoded_array = Sample.encode(sample_item_array) + assert len(encoded_array) == 3 + assert isinstance(encoded_array, EnumArray) + assert encoded_array.dtype.kind == "i" + + encoded_array = Sample.encode(explicit_s_array) + assert len(encoded_array) == 3 + assert isinstance(encoded_array, EnumArray) + assert encoded_array.dtype.kind == "i"