diff --git a/src/fastcs/attributes/attribute_io.py b/src/fastcs/attributes/attribute_io.py index e3648fa3d..d44fb5843 100644 --- a/src/fastcs/attributes/attribute_io.py +++ b/src/fastcs/attributes/attribute_io.py @@ -37,4 +37,4 @@ async def send(self, attr: AttrW[DType_T, AttributeIORefT], value: DType_T) -> N raise NotImplementedError() -AnyAttributeIO = AttributeIO[DType_T, AttributeIORef] +AnyAttributeIO = AttributeIO[Any] diff --git a/src/fastcs/controllers/base_controller.py b/src/fastcs/controllers/base_controller.py index ff7507df4..da053b648 100755 --- a/src/fastcs/controllers/base_controller.py +++ b/src/fastcs/controllers/base_controller.py @@ -5,15 +5,7 @@ from copy import deepcopy from typing import _GenericAlias, get_args, get_origin, get_type_hints # type: ignore -from fastcs.attributes import ( - Attribute, - AttributeIO, - AttributeIORefT, - AttrR, - AttrW, - HintedAttribute, -) -from fastcs.datatypes import DType_T +from fastcs.attributes import AnyAttributeIO, Attribute, AttrR, AttrW, HintedAttribute from fastcs.logging import bind_logger from fastcs.tracer import Tracer @@ -41,7 +33,7 @@ def __init__( self, path: list[str] | None = None, description: str | None = None, - ios: Sequence[AttributeIO[DType_T, AttributeIORefT]] | None = None, + ios: Sequence[AnyAttributeIO] | None = None, ) -> None: super().__init__() @@ -125,7 +117,7 @@ class method and a controller instance, so that it can be called from any elif isinstance(attr, UnboundScan | UnboundCommand): setattr(self, attr_name, attr.bind(self)) - def _validate_io(self, ios: Sequence[AttributeIO[DType_T, AttributeIORefT]]): + def _validate_io(self, ios: Sequence[AnyAttributeIO]): """Validate that there is exactly one AttributeIO class registered to the controller for each type of AttributeIORef belonging to the attributes of the controller""" diff --git a/src/fastcs/controllers/controller.py b/src/fastcs/controllers/controller.py index be9287cf6..f880e1e82 100755 --- a/src/fastcs/controllers/controller.py +++ b/src/fastcs/controllers/controller.py @@ -1,8 +1,7 @@ from collections.abc import Sequence -from fastcs.attributes import AttributeIO, AttributeIORefT +from fastcs.attributes import AnyAttributeIO from fastcs.controllers.base_controller import BaseController -from fastcs.datatypes import DType_T class Controller(BaseController): @@ -11,7 +10,7 @@ class Controller(BaseController): def __init__( self, description: str | None = None, - ios: Sequence[AttributeIO[DType_T, AttributeIORefT]] | None = None, + ios: Sequence[AnyAttributeIO] | None = None, ) -> None: super().__init__(description=description, ios=ios) diff --git a/src/fastcs/controllers/controller_vector.py b/src/fastcs/controllers/controller_vector.py index f3a21a73f..2beed863f 100755 --- a/src/fastcs/controllers/controller_vector.py +++ b/src/fastcs/controllers/controller_vector.py @@ -1,9 +1,8 @@ from collections.abc import Iterator, Mapping, MutableMapping, Sequence -from fastcs.attributes import AttributeIO, AttributeIORefT +from fastcs.attributes import AnyAttributeIO from fastcs.controllers.base_controller import BaseController from fastcs.controllers.controller import Controller -from fastcs.datatypes import DType_T class ControllerVector(MutableMapping[int, Controller], BaseController): @@ -18,7 +17,7 @@ def __init__( self, children: Mapping[int, Controller], description: str | None = None, - ios: Sequence[AttributeIO[DType_T, AttributeIORefT]] | None = None, + ios: Sequence[AnyAttributeIO] | None = None, ) -> None: super().__init__(description=description, ios=ios) self._children: dict[int, Controller] = {} diff --git a/src/fastcs/datatypes/datatype.py b/src/fastcs/datatypes/datatype.py index 1559a5c08..69d743904 100644 --- a/src/fastcs/datatypes/datatype.py +++ b/src/fastcs/datatypes/datatype.py @@ -13,18 +13,11 @@ | enum.Enum # Enum | np.ndarray # Waveform / Table ) - -DType_T = TypeVar( - "DType_T", - int, # Int - float, # Float - bool, # Bool - str, # String - enum.Enum, # Enum - np.ndarray, # Waveform / Table -) """A builtin (or numpy) type supported by a corresponding FastCS Attribute DataType""" +DType_T = TypeVar("DType_T", bound=DType) +"""A TypeVar of `DType` for use in generic classes and functions""" + @dataclass(frozen=True) class DataType(Generic[DType_T]): diff --git a/src/fastcs/transports/epics/ca/util.py b/src/fastcs/transports/epics/ca/util.py index 5173be569..0c483a928 100644 --- a/src/fastcs/transports/epics/ca/util.py +++ b/src/fastcs/transports/epics/ca/util.py @@ -1,10 +1,12 @@ +import enum from dataclasses import asdict from typing import Any from softioc import builder from fastcs.attributes import Attribute, AttrR, AttrRW, AttrW -from fastcs.datatypes import Bool, DataType, DType_T, Enum, Float, Int, String, Waveform +from fastcs.datatypes import Bool, DType_T, Enum, Float, Int, String, Waveform +from fastcs.datatypes.datatype import DataType from fastcs.exceptions import FastCSError _MBB_FIELD_PREFIXES = ( @@ -31,7 +33,7 @@ MBB_MAX_CHOICES = len(_MBB_FIELD_PREFIXES) -EPICS_ALLOWED_DATATYPES = (Bool, DataType, Enum, Float, Int, String, Waveform) +EPICS_ALLOWED_DATATYPES = (Bool, Enum, Float, Int, String, Waveform) DEFAULT_STRING_WAVEFORM_LENGTH = 256 DATATYPE_FIELD_TO_RECORD_FIELD = { @@ -44,9 +46,7 @@ } -def record_metadata_from_attribute( - attribute: Attribute[DType_T], -) -> dict[str, Any]: +def record_metadata_from_attribute(attribute: Attribute[DType_T]) -> dict[str, Any]: """Converts attributes on the `Attribute` to the field name/value in the record metadata.""" metadata: dict[str, Any] = {"DESC": attribute.description} @@ -62,7 +62,7 @@ def record_metadata_from_attribute( def record_metadata_from_datatype( - datatype: DataType[DType_T], out_record: bool = False + datatype: DataType[Any], out_record: bool = False ) -> dict[str, str]: """Converts attributes on the `DataType` to the field name/value in the record metadata.""" @@ -123,9 +123,14 @@ def cast_from_epics_type(datatype: DataType[DType_T], value: object) -> DType_T: raise ValueError(f"Invalid bool value from EPICS record {value}") case Enum(): if len(datatype.members) <= MBB_MAX_CHOICES: + assert isinstance(value, int), "Got non-integer value for Enum" return datatype.validate(datatype.members[value]) else: # enum backed by string record - return datatype.validate(datatype.enum_cls[value]) + assert isinstance(value, str), "Got non-string value for long Enum" + # python typing can't narrow the nested generic enum_cls + assert issubclass(datatype.enum_cls, enum.Enum), "Invalid Enum.enum_cls" + enum_member = datatype.enum_cls[value] + return datatype.validate(enum_member) case datatype if issubclass(type(datatype), EPICS_ALLOWED_DATATYPES): return datatype.validate(value) # type: ignore case _: diff --git a/src/fastcs/transports/epics/pva/types.py b/src/fastcs/transports/epics/pva/types.py index 80f79b503..75e979996 100644 --- a/src/fastcs/transports/epics/pva/types.py +++ b/src/fastcs/transports/epics/pva/types.py @@ -7,7 +7,8 @@ from p4p.nt import NTEnum, NTNDArray, NTScalar, NTTable from fastcs.attributes import Attribute, AttrR, AttrW -from fastcs.datatypes import Bool, DType_T, Enum, Float, Int, String, Table, Waveform +from fastcs.datatypes import Bool, DType, Enum, Float, Int, String, Table, Waveform +from fastcs.datatypes.datatype import DType_T P4P_ALLOWED_DATATYPES = (Int, Float, String, Bool, Enum, Waveform, Table) @@ -90,7 +91,9 @@ def cast_from_p4p_value(attribute: Attribute[DType_T], value: object) -> DType_T """Converts from a p4p value to a FastCS `Attribute` value.""" match attribute.datatype: case Enum(): - return attribute.datatype.validate(attribute.datatype.members[value.index]) + assert hasattr(value, "index"), "Got non-enum p4p.Value for Enum DataType" + index: int = value.index # pyright: ignore[reportAttributeAccessIssue] + return attribute.datatype.validate(attribute.datatype.members[index]) case Waveform(shape=shape): # p4p sends a flattened array assert value.shape == (math.prod(shape),) @@ -154,7 +157,7 @@ def p4p_display(attribute: Attribute) -> dict: return {} -def _p4p_check_numeric_for_alarm_states(datatype: Int | Float, value: DType_T) -> dict: +def _p4p_check_numeric_for_alarm_states(datatype: Int | Float, value: DType) -> dict: low = None if datatype.min_alarm is None else value < datatype.min_alarm # type: ignore high = None if datatype.max_alarm is None else value > datatype.max_alarm # type: ignore severity = ( diff --git a/src/fastcs/transports/graphql/graphql.py b/src/fastcs/transports/graphql/graphql.py index 79eeb7d17..f3c26966a 100644 --- a/src/fastcs/transports/graphql/graphql.py +++ b/src/fastcs/transports/graphql/graphql.py @@ -135,10 +135,10 @@ async def _dynamic_f(value): def _wrap_attr_get( attr_name: str, attribute: AttrR[DType_T] -) -> Callable[[], Coroutine[Any, Any, Any]]: +) -> Callable[[], Coroutine[Any, Any, DType_T]]: """Wrap an attribute in a function with annotations for strawberry""" - async def _dynamic_f() -> Any: + async def _dynamic_f() -> DType_T: return attribute.get() _dynamic_f.__name__ = attr_name diff --git a/src/fastcs/transports/rest/rest.py b/src/fastcs/transports/rest/rest.py index f6600f94f..3bab024ff 100644 --- a/src/fastcs/transports/rest/rest.py +++ b/src/fastcs/transports/rest/rest.py @@ -92,9 +92,9 @@ def _get_response_body(attribute: AttrR[DType_T]): def _wrap_attr_get( attribute: AttrR[DType_T], -) -> Callable[[], Coroutine[Any, Any, Any]]: - async def attr_get() -> Any: # Must be any as response_model is set - value = attribute.get() # type: ignore +) -> Callable[[], Coroutine[Any, Any, dict[str, object]]]: + async def attr_get() -> dict[str, object]: + value = attribute.get() return {"value": cast_to_rest_type(attribute.datatype, value)} return attr_get diff --git a/src/fastcs/transports/rest/util.py b/src/fastcs/transports/rest/util.py index faba0bc50..550d72941 100644 --- a/src/fastcs/transports/rest/util.py +++ b/src/fastcs/transports/rest/util.py @@ -5,7 +5,7 @@ REST_ALLOWED_DATATYPES = (Bool, DataType, Enum, Float, Int, String) -def convert_datatype(datatype: DataType[DType_T]) -> type: +def convert_datatype(datatype: DataType[DType_T]) -> type[DType_T]: """Converts a datatype to a rest serialisable type.""" match datatype: case Waveform(): diff --git a/src/fastcs/transports/tango/util.py b/src/fastcs/transports/tango/util.py index d6785468d..61c287e8f 100644 --- a/src/fastcs/transports/tango/util.py +++ b/src/fastcs/transports/tango/util.py @@ -7,6 +7,7 @@ from fastcs.datatypes import ( Bool, DataType, + DType, DType_T, Enum, Float, @@ -27,7 +28,7 @@ def get_server_metadata_from_attribute( - attribute: Attribute[DType_T], + attribute: Attribute[DType], ) -> dict[str, Any]: """Gets the metadata for a Tango field from an attribute.""" arguments = {} @@ -35,7 +36,7 @@ def get_server_metadata_from_attribute( return arguments -def get_server_metadata_from_datatype(datatype: DataType[DType_T]) -> dict[str, str]: +def get_server_metadata_from_datatype(datatype: DataType[DType]) -> dict[str, str]: """Gets the metadata for a Tango field from a FastCS datatype.""" arguments = { DATATYPE_FIELD_TO_SERVER_FIELD[field]: value @@ -86,6 +87,7 @@ def cast_from_tango_type(datatype: DataType[DType_T], value: object) -> DType_T: """Casts a value from tango to FastCS datatype.""" match datatype: case Enum(): + assert isinstance(value, int), "Got non-integer value for Enum" return datatype.validate(datatype.members[value]) case datatype if issubclass(type(datatype), TANGO_ALLOWED_DATATYPES): return datatype.validate(value) # type: ignore diff --git a/tests/test_attributes.py b/tests/test_attributes.py index c5a246ec0..41de437a4 100644 --- a/tests/test_attributes.py +++ b/tests/test_attributes.py @@ -8,9 +8,7 @@ from fastcs.attributes import AttributeIO, AttributeIORef, AttrR, AttrRW, AttrW from fastcs.controllers import Controller -from fastcs.datatypes import DType_T, Float, Int, String - -NumberT = TypeVar("NumberT", int, float) +from fastcs.datatypes import Float, Int, String def test_attr_r(): @@ -125,7 +123,7 @@ class MyAttributeIORef(AttributeIORef): cool: int class MyAttributeIO(AttributeIO[int, MyAttributeIORef]): - async def update(self, attr: AttrR[DType_T, MyAttributeIORef]): + async def update(self, attr: AttrR[int, MyAttributeIORef]): print("I am updating", self.ref_type, attr.io_ref.cool) class MyController(Controller): @@ -221,6 +219,9 @@ async def set(self, uri: str, value: float | int): self._float_value = value +NumberT = TypeVar("NumberT", int, float) + + @pytest.mark.asyncio() async def test_dynamic_attribute_io_specification(): @dataclass @@ -314,11 +315,9 @@ class MyController(Controller): c = MyController() c._connect_attribute_ios() - class SimpleAttributeIO(AttributeIO[DType_T]): + class SimpleAttributeIO(AttributeIO[int]): async def update(self, attr): - match attr: - case AttrR(datatype=Int()): - await attr.update(100) + await attr.update(100) with pytest.raises( RuntimeError, match="More than one AttributeIO class handles AttributeIORef"