diff --git a/src/fastcs/datatypes/datatype.py b/src/fastcs/datatypes/datatype.py index 69d743904..22d1d2fdb 100644 --- a/src/fastcs/datatypes/datatype.py +++ b/src/fastcs/datatypes/datatype.py @@ -1,5 +1,6 @@ import enum from abc import abstractmethod +from collections.abc import Sequence from dataclasses import dataclass from typing import Any, Generic, TypeVar @@ -28,6 +29,11 @@ class DataType(Generic[DType_T]): def dtype(self) -> type[DType_T]: # Using property due to lack of Generic ClassVars raise NotImplementedError() + @property + @abstractmethod + def initial_value(self) -> DType_T: + raise NotImplementedError() + def validate(self, value: Any) -> DType_T: """Validate a value against the datatype. @@ -55,7 +61,32 @@ def validate(self, value: Any) -> DType_T: except (ValueError, TypeError) as e: raise ValueError(f"Failed to cast {value} to type {self.dtype}") from e - @property - @abstractmethod - def initial_value(self) -> DType_T: - raise NotImplementedError() + @staticmethod + def equal(value1: DType_T, value2: DType_T) -> bool: + """Compare two values for equality + + Child classes can override this if the underlying type does not implement + ``__eq__`` or to define custom logic. + + Args: + value1: The first value to compare + value2: The second value to compare + + Returns: + `True` if the values are equal + + """ + return value1 == value2 + + @classmethod + def all_equal(cls, values: Sequence[DType_T]) -> bool: + """Compare a sequence of values for equality + + Args: + values: Values to compare + + Returns: + `True` if all values are equal, else `False` + + """ + return all(cls.equal(values[0], value) for value in values[1:]) diff --git a/src/fastcs/datatypes/table.py b/src/fastcs/datatypes/table.py index 4c9757793..7bbb0cabf 100644 --- a/src/fastcs/datatypes/table.py +++ b/src/fastcs/datatypes/table.py @@ -30,3 +30,7 @@ def validate(self, value: Any) -> np.ndarray: ) return _value + + @staticmethod + def equal(value1: np.ndarray, value2: np.ndarray) -> bool: + return np.array_equal(value1, value2) diff --git a/src/fastcs/datatypes/waveform.py b/src/fastcs/datatypes/waveform.py index 24d2ba1e6..e9ca0ada5 100644 --- a/src/fastcs/datatypes/waveform.py +++ b/src/fastcs/datatypes/waveform.py @@ -38,3 +38,7 @@ def validate(self, value: np.ndarray) -> np.ndarray: ) return _value + + @staticmethod + def equal(value1: np.ndarray, value2: np.ndarray) -> bool: + return np.array_equal(value1, value2) diff --git a/tests/test_datatypes.py b/tests/test_datatypes.py index 45157f2f2..b3a41f31d 100644 --- a/tests/test_datatypes.py +++ b/tests/test_datatypes.py @@ -3,10 +3,8 @@ import numpy as np import pytest -from fastcs.datatypes import DataType, Enum, Float, Int, Waveform +from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, Table, Waveform from fastcs.datatypes._util import numpy_to_fastcs_datatype -from fastcs.datatypes.bool import Bool -from fastcs.datatypes.string import String def test_base_validate(): @@ -61,3 +59,49 @@ def test_validate(datatype, init_args, value): ) def test_numpy_to_fastcs_datatype(numpy_type, fastcs_datatype): assert fastcs_datatype == numpy_to_fastcs_datatype(numpy_type) + + +@pytest.mark.parametrize( + "fastcs_datatype, value1, value2, expected", + [ + (Int(), 1, 1, True), + (Int(), 1, 2, False), + (Float(), 1.0, 1.0, True), + (Float(), 1.0, 2.0, False), + (Bool(), True, True, True), + (Bool(), True, False, False), + (String(), "foo", "foo", True), + (String(), "foo", "bar", False), + (Waveform(np.int16), np.array([1]), np.array([1]), True), + (Waveform(np.int16), np.array([1]), np.array([2]), False), + ( + Table([("int", np.int16), ("bool", np.bool), ("str", np.dtype("S10"))]), + np.array([1, True, "foo"]), + np.array([1, True, "foo"]), + True, + ), + ( + Table([("int", np.int16), ("bool", np.bool), ("str", np.dtype("S10"))]), + np.array([1, True, "foo"]), + np.array([2, False, "bar"]), + False, + ), + ], +) +def test_dataset_equal(fastcs_datatype: DataType, value1, value2, expected): + assert fastcs_datatype.equal(value1, value2) is expected + + +@pytest.mark.parametrize( + "fastcs_datatype, values, expected", + [ + (Int(), [1, 1], True), + (Int(), [1, 2], False), + (Float(), [1.0, 1.0], True), + (Float(), [1.0, 2.0], False), + (Bool(), [True, True], True), + (Bool(), [True, False], False), + ], +) +def test_dataset_all_equal(fastcs_datatype: DataType, values, expected): + assert fastcs_datatype.all_equal(values) is expected