Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions src/fastcs/datatypes/datatype.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import enum
from abc import abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Generic, TypeVar

Expand Down Expand Up @@ -28,6 +29,11 @@ class DataType(Generic[DType_T]):
def dtype(self) -> type[DType_T]: # Using property due to lack of Generic ClassVars
raise NotImplementedError()

@property
@abstractmethod
def initial_value(self) -> DType_T:
raise NotImplementedError()

def validate(self, value: Any) -> DType_T:
"""Validate a value against the datatype.

Expand Down Expand Up @@ -55,7 +61,32 @@ def validate(self, value: Any) -> DType_T:
except (ValueError, TypeError) as e:
raise ValueError(f"Failed to cast {value} to type {self.dtype}") from e

@property
@abstractmethod
def initial_value(self) -> DType_T:
raise NotImplementedError()
@staticmethod
def equal(value1: DType_T, value2: DType_T) -> bool:
"""Compare two values for equality

Child classes can override this if the underlying type does not implement
``__eq__`` or to define custom logic.

Args:
value1: The first value to compare
value2: The second value to compare

Returns:
`True` if the values are equal

"""
return value1 == value2

@classmethod
def all_equal(cls, values: Sequence[DType_T]) -> bool:
"""Compare a sequence of values for equality

Args:
values: Values to compare

Returns:
`True` if all values are equal, else `False`

"""
return all(cls.equal(values[0], value) for value in values[1:])
4 changes: 4 additions & 0 deletions src/fastcs/datatypes/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,7 @@ def validate(self, value: Any) -> np.ndarray:
)

return _value

@staticmethod
def equal(value1: np.ndarray, value2: np.ndarray) -> bool:
return np.array_equal(value1, value2)
4 changes: 4 additions & 0 deletions src/fastcs/datatypes/waveform.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,7 @@ def validate(self, value: np.ndarray) -> np.ndarray:
)

return _value

@staticmethod
def equal(value1: np.ndarray, value2: np.ndarray) -> bool:
return np.array_equal(value1, value2)
50 changes: 47 additions & 3 deletions tests/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@
import numpy as np
import pytest

from fastcs.datatypes import DataType, Enum, Float, Int, Waveform
from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, Table, Waveform
from fastcs.datatypes._util import numpy_to_fastcs_datatype
from fastcs.datatypes.bool import Bool
from fastcs.datatypes.string import String


def test_base_validate():
Expand Down Expand Up @@ -61,3 +59,49 @@ def test_validate(datatype, init_args, value):
)
def test_numpy_to_fastcs_datatype(numpy_type, fastcs_datatype):
assert fastcs_datatype == numpy_to_fastcs_datatype(numpy_type)


@pytest.mark.parametrize(
"fastcs_datatype, value1, value2, expected",
[
(Int(), 1, 1, True),
(Int(), 1, 2, False),
(Float(), 1.0, 1.0, True),
(Float(), 1.0, 2.0, False),
(Bool(), True, True, True),
(Bool(), True, False, False),
(String(), "foo", "foo", True),
(String(), "foo", "bar", False),
(Waveform(np.int16), np.array([1]), np.array([1]), True),
(Waveform(np.int16), np.array([1]), np.array([2]), False),
(
Table([("int", np.int16), ("bool", np.bool), ("str", np.dtype("S10"))]),
np.array([1, True, "foo"]),
np.array([1, True, "foo"]),
True,
),
(
Table([("int", np.int16), ("bool", np.bool), ("str", np.dtype("S10"))]),
np.array([1, True, "foo"]),
np.array([2, False, "bar"]),
False,
),
],
)
def test_dataset_equal(fastcs_datatype: DataType, value1, value2, expected):
assert fastcs_datatype.equal(value1, value2) is expected


@pytest.mark.parametrize(
"fastcs_datatype, values, expected",
[
(Int(), [1, 1], True),
(Int(), [1, 2], False),
(Float(), [1.0, 1.0], True),
(Float(), [1.0, 2.0], False),
(Bool(), [True, True], True),
(Bool(), [True, False], False),
],
)
def test_dataset_all_equal(fastcs_datatype: DataType, values, expected):
assert fastcs_datatype.all_equal(values) is expected