From 7dbe280db75f0ba1407cb603dad14fdd3b203cad Mon Sep 17 00:00:00 2001 From: Arthur Besen Soprana Date: Fri, 4 Oct 2019 17:12:10 -0300 Subject: [PATCH] Add correct comparison support for Scalar Modified comparison behavior of ``Scalar``. The previous behavior assumes that ``Scalar(1, "m") != Scalar(100, "cm")`` and not it has been changed to ``Scalar(1, "m") == Scalar(100, "cm")``. This may affect users that rely on the previous behavior. BARRIL-26 --- CHANGELOG.rst | 3 +++ src/barril/units/_quantity.py | 14 ++++++++++--- src/barril/units/_scalar.py | 23 ++++++++++++++------ src/barril/units/_tests/test_scalar.py | 29 ++++++++++++++++++++++++++ 4 files changed, 60 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index c56afdb..2b7f702 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -9,6 +9,9 @@ UNRELEASED * Fix division ``1.0 / a`` where ``a`` is a ``Scalar`` or ``Array`` and also add support for floor division, i.e., operations like ``a // b`` where ``a`` and ``b`` are ``Scalar`` or ``Array`` (and combinations with ``float`` or ``int``). +* Modified comparison behavior of ``Scalar``. The previous behavior assumes that ``Scalar(1, "m") != Scalar(100, "cm")`` + and not it has been changed to ``Scalar(1, "m") == Scalar(100, "cm")``. This may affect users that rely on the previous + behavior. 1.7.1 (2019-10-03) ------------------ diff --git a/src/barril/units/_quantity.py b/src/barril/units/_quantity.py index c754ea0..1d24f98 100644 --- a/src/barril/units/_quantity.py +++ b/src/barril/units/_quantity.py @@ -677,9 +677,17 @@ def Convert(self, value, to_unit): :returns: An object with values to the passed unit. """ - return self._unit_database.Convert( - self._composing_categories, self._composing_units, to_unit, value - ) + try: + return self._unit_database.Convert( + self._composing_categories, self._composing_units, to_unit, value + ) + except: + return self._unit_database.Convert( + self._category, + self._CreateUnitsWithJoinedExponentsString(), + to_unit, + value, + ) @classmethod def _GetComparison(cls, operator, use_literals=False): diff --git a/src/barril/units/_scalar.py b/src/barril/units/_scalar.py index 81fe445..e9f3429 100644 --- a/src/barril/units/_scalar.py +++ b/src/barril/units/_scalar.py @@ -10,7 +10,7 @@ from ._abstractvaluewithquantity import AbstractValueWithQuantityObject from ._quantity import ObtainQuantity, Quantity from .interfaces import IQuantity, IScalar -from .unit_database import UnitDatabase +from .unit_database import InvalidQuantityTypeError, UnitDatabase __all__ = ["Scalar"] @@ -238,23 +238,34 @@ def GetFormatted(self, unit=None, value_format=None): ) # Compare -------------------------------------------------------------------------------------- + def _GetValueInDefaultUnit(self): + try: + return self.GetValue( + self._unit_database.GetDefaultUnit(self._quantity._category) + ) + except InvalidQuantityTypeError: + return self._value def __eq__(self, other): return ( type(self) is type(other) - and self._value == other.value - and self._quantity == other._quantity + and self._quantity._category == other._quantity._category + and self._GetValueInDefaultUnit() == other._GetValueInDefaultUnit() ) def AlmostEqual(self, other, precision): return ( type(self) is type(other) - and round(self._value - other.value, precision) == 0 - and self._quantity == other._quantity + and self._quantity._category == other._quantity._category + and round( + self._GetValueInDefaultUnit() - other._GetValueInDefaultUnit(), + precision, + ) + == 0 ) def __hash__(self, *args, **kwargs): - return hash((self._value, self._quantity)) + return hash((self._quantity._category, self._GetValueInDefaultUnit())) def __lt__(self, other): if self.quantity_type != other.quantity_type: diff --git a/src/barril/units/_tests/test_scalar.py b/src/barril/units/_tests/test_scalar.py index 07f675e..5a0e587 100644 --- a/src/barril/units/_tests/test_scalar.py +++ b/src/barril/units/_tests/test_scalar.py @@ -630,3 +630,32 @@ class Fluid: assert fluid.density.GetValue("lbm/galUS") == 10 assert fluid.concentration.GetValue("%") == 1.0 + + +def testComparison(): + a = Scalar(1, "m") + b = Scalar(100, "cm") + c = Scalar(99, "cm") + + assert a == b + assert a <= b + assert b >= a + assert c <= b + assert c <= a + + # Test set creation with scalars + assert {b, a, c} == {a, b, c} == {c, a, b} + + # Check Scalars with different categories + assert Scalar(99, "psi") != Scalar(100, "cm") + + assert a.AlmostEqual(b, precision=10) + assert a.AlmostEqual(c, precision=1) + + # Testing derived quantities + q = Quantity.CreateDerived( + OrderedDict([("length", ["m", 1]), ("length", ["m", 1]), ("time", ["s", -2])]) + ) + a = Scalar(q, 1.0) + b = Scalar(q * q, 1.0) + assert {a, b, b / a} == {b, a}