diff --git a/docs/conf.py b/docs/conf.py index 5471426f7..db4fcc0fa 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -77,6 +77,7 @@ ("py:class", "asyncio.events.AbstractEventLoop"), ("py:class", "asyncio.streams.StreamReader"), ("py:class", "asyncio.streams.StreamWriter"), + ("py:class", "asyncio.locks.Event"), # Annoying error: # docstring of collections.abc.Callable:1: WARNING: # 'any' reference target not found: self [ref.any] diff --git a/src/fastcs/attributes/attr_r.py b/src/fastcs/attributes/attr_r.py index c6583252a..aec7f5222 100644 --- a/src/fastcs/attributes/attr_r.py +++ b/src/fastcs/attributes/attr_r.py @@ -6,6 +6,7 @@ from fastcs.attributes.attribute import Attribute from fastcs.attributes.attribute_io_ref import AttributeIORefT +from fastcs.attributes.util import AttrValuePredicate, PredicateEvent from fastcs.datatypes import DataType, DType_T from fastcs.logging import bind_logger @@ -39,6 +40,8 @@ def __init__( """Callback to update the value of the attribute with an IO to the source""" self._on_update_callbacks: list[AttrOnUpdateCallback[DType_T]] | None = None """Callbacks to publish changes to the value of the attribute""" + self._on_update_events: set[PredicateEvent[DType_T]] = set() + """Events to set when the value satisifies some predicate""" def get(self) -> DType_T: """Get the cached value of the attribute.""" @@ -51,6 +54,9 @@ async def update(self, value: Any) -> None: generally only be called from an IO or a controller that is updating the value from some underlying source. + Any update callbacks will be called with the new value and any update events + with predicates satisfied by the new value will be set. + To request a change to the setpoint of the attribute, use the ``put`` method, which will attempt to apply the change to the underlying source. @@ -67,6 +73,10 @@ async def update(self, value: Any) -> None: self._value = self._datatype.validate(value) + self._on_update_events -= { + e for e in self._on_update_events if e.set(self._value) + } + if self._on_update_callbacks is not None: try: await asyncio.gather( @@ -115,3 +125,69 @@ async def update_attribute(): raise return update_attribute + + async def wait_for_predicate( + self, predicate: AttrValuePredicate[DType_T], *, timeout: float + ): + """Wait for the predicate to be satisfied when called with the current value + + Args: + predicate: The predicate to test - a callable that takes the attribute + value and returns True if the event should be set + timeout: The timeout in seconds + + """ + if predicate(self._value): + self.log_event( + "Predicate already satisfied", predicate=predicate, attribute=self + ) + return + + self._on_update_events.add(update_event := PredicateEvent(predicate)) + + self.log_event("Waiting for predicate", predicate=predicate, attribute=self) + try: + await asyncio.wait_for(update_event.wait(), timeout) + except TimeoutError: + self._on_update_events.remove(update_event) + raise TimeoutError( + f"Timeout waiting {timeout}s for {self.full_name} predicate {predicate}" + f" - current value: {self._value}" + ) from None + + self.log_event("Predicate satisfied", predicate=predicate, attribute=self) + + async def wait_for_value(self, target_value: DType_T, *, timeout: float): + """Wait for self._value to equal the target value + + Args: + target_value: The target value to wait for + timeout: The timeout in seconds + + Raises: + TimeoutError: If the attribute does not reach the target value within the + timeout + + """ + if self._value == target_value: + self.log_event( + "Current value already equals target value", + target_value=target_value, + attribute=self, + ) + return + + def predicate(v: DType_T) -> bool: + return v == target_value + + try: + await self.wait_for_predicate(predicate, timeout=timeout) + except TimeoutError: + raise TimeoutError( + f"Timeout waiting {timeout}s for {self.full_name} value {target_value}" + f" - current value: {self._value}" + ) from None + + self.log_event( + "Value equals target value", target_valuevalue=target_value, attribute=self + ) diff --git a/src/fastcs/attributes/attribute.py b/src/fastcs/attributes/attribute.py index 546dab165..c737bd28d 100644 --- a/src/fastcs/attributes/attribute.py +++ b/src/fastcs/attributes/attribute.py @@ -70,6 +70,10 @@ def name(self) -> str: def path(self) -> list[str]: return self._path + @property + def full_name(self) -> str: + return ".".join(self._path + [self._name]) + def add_update_datatype_callback( self, callback: Callable[[DataType[DType_T]], None] ) -> None: @@ -102,7 +106,7 @@ def set_path(self, path: list[str]): def __repr__(self): name = self.__class__.__name__ - path = ".".join(self._path + [self._name]) or None + full_name = self.full_name or None datatype = self._datatype.__class__.__name__ - return f"{name}(path={path}, datatype={datatype}, io_ref={self._io_ref})" + return f"{name}(name={full_name}, datatype={datatype}, io_ref={self._io_ref})" diff --git a/src/fastcs/attributes/util.py b/src/fastcs/attributes/util.py new file mode 100644 index 000000000..4739c30aa --- /dev/null +++ b/src/fastcs/attributes/util.py @@ -0,0 +1,39 @@ +import asyncio +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Generic + +from fastcs.datatypes import DType_T + +AttrValuePredicate = Callable[[DType_T], bool] + + +@dataclass(eq=False) +class PredicateEvent(Generic[DType_T]): + """A wrapper of `asyncio.Event` that only triggers when a predicate is satisfied""" + + _predicate: AttrValuePredicate[DType_T] + """Predicate to filter set calls by""" + _event: asyncio.Event = field(default_factory=asyncio.Event) + """Event to set""" + + def set(self, value: DType_T) -> bool: + """Set the event if the predicate is satisfied by the value + + Returns: + `True` if the predicate was satisfied and the event was set, else `False` + + """ + if self._predicate(value): + self._event.set() + return True + + return False + + async def wait(self): + """Wait for the event to be set""" + await self._event.wait() + + def __hash__(self) -> int: + """Make instances unique when stored in sets""" + return id(self) diff --git a/src/fastcs/transports/epics/ca/ioc.py b/src/fastcs/transports/epics/ca/ioc.py index 733fdf62b..f6ca6c5af 100644 --- a/src/fastcs/transports/epics/ca/ioc.py +++ b/src/fastcs/transports/epics/ca/ioc.py @@ -193,7 +193,11 @@ def _make_record( ) attribute_record_metadata = record_metadata_from_attribute(attribute) - update = {"always_update": True, "on_update": on_update} if on_update else {} + update = ( + {"on_update": on_update, "always_update": True, "blocking": True} + if on_update + else {} + ) record = builder_callable( pv, **update, **datatype_record_metadata, **attribute_record_metadata diff --git a/tests/test_attributes.py b/tests/test_attributes.py index 04ced8bd3..c5a246ec0 100644 --- a/tests/test_attributes.py +++ b/tests/test_attributes.py @@ -1,3 +1,4 @@ +import asyncio from dataclasses import dataclass from functools import partial from typing import Generic, TypeVar @@ -12,7 +13,7 @@ NumberT = TypeVar("NumberT", int, float) -def test_attribute(): +def test_attr_r(): attr = AttrR(String(), group="test group") with pytest.raises(RuntimeError): @@ -39,6 +40,57 @@ def test_attribute(): assert attr.get() == "" +@pytest.mark.asyncio +async def test_wait_for_predicate(mocker: MockerFixture): + attr = AttrR(Int(), initial_value=0) + + async def update(attr: AttrR): + while True: + await asyncio.sleep(0.1) + await attr.update(attr.get() + 3) # 3, 6, 9, 12 != 10 + + asyncio.create_task(update(attr)) + + # We won't see exactly 10 so check for greater than + def predicate(v: int) -> bool: + return v > 10 + + wait_mock = mocker.spy(asyncio, "wait_for") + with pytest.raises(TimeoutError): + await attr.wait_for_predicate(predicate, timeout=0.2) + + await attr.wait_for_predicate(predicate, timeout=1) + + assert wait_mock.call_count == 2 + + # Returns immediately without creating event if value already as expected + await attr.wait_for_predicate(predicate, timeout=1) + assert wait_mock.call_count == 2 + + +@pytest.mark.asyncio +async def test_wait_for_value(mocker: MockerFixture): + attr = AttrR(Int(), initial_value=0) + + async def update(attr: AttrR): + await asyncio.sleep(0.5) + await attr.update(1) + + asyncio.create_task(update(attr)) + + wait_mock = mocker.spy(asyncio, "wait_for") + with pytest.raises(TimeoutError): + await attr.wait_for_value(10, timeout=0.2) + + await attr.wait_for_value(1, timeout=1) + + assert wait_mock.call_count == 2 + + # Returns immediately without creating event if value already as expected + await attr.wait_for_value(1, timeout=1) + assert wait_mock.call_count == 2 + + @pytest.mark.asyncio async def test_attributes(): device = {"state": "Idle", "number": 1, "count": False} diff --git a/tests/transports/epics/ca/test_softioc.py b/tests/transports/epics/ca/test_softioc.py index c76ba2aa6..d9eb3b1e5 100644 --- a/tests/transports/epics/ca/test_softioc.py +++ b/tests/transports/epics/ca/test_softioc.py @@ -187,7 +187,7 @@ def test_make_output_record( kwargs.update(record_metadata_from_datatype(attribute.datatype, out_record=True)) kwargs.update(record_metadata_from_attribute(attribute)) - kwargs.update({"always_update": True, "on_update": update}) + kwargs.update({"always_update": True, "on_update": update, "blocking": True}) getattr(builder, record_type).assert_called_once_with( pv, @@ -265,6 +265,7 @@ def test_ioc(mocker: MockerFixture, epics_controller_api: ControllerAPI): builder.aOut.assert_any_call( f"{DEVICE}:ReadWriteFloat", always_update=True, + blocking=True, on_update=mocker.ANY, **record_metadata_from_attribute( epics_controller_api.attributes["read_write_float"] @@ -286,6 +287,7 @@ def test_ioc(mocker: MockerFixture, epics_controller_api: ControllerAPI): builder.longOut.assert_called_with( f"{DEVICE}:ReadWriteInt", always_update=True, + blocking=True, on_update=mocker.ANY, **record_metadata_from_attribute( epics_controller_api.attributes["read_write_int"] @@ -304,6 +306,7 @@ def test_ioc(mocker: MockerFixture, epics_controller_api: ControllerAPI): builder.mbbOut.assert_called_once_with( f"{DEVICE}:Enum", always_update=True, + blocking=True, on_update=mocker.ANY, **record_metadata_from_attribute(epics_controller_api.attributes["enum"]), **record_metadata_from_datatype( @@ -313,6 +316,7 @@ def test_ioc(mocker: MockerFixture, epics_controller_api: ControllerAPI): builder.boolOut.assert_called_once_with( f"{DEVICE}:WriteBool", always_update=True, + blocking=True, on_update=mocker.ANY, **record_metadata_from_attribute(epics_controller_api.attributes["write_bool"]), **record_metadata_from_datatype( @@ -460,6 +464,7 @@ def test_long_pv_names_discarded(mocker: MockerFixture): builder.longOut.assert_called_once_with( f"{DEVICE}:{short_pv_name}", always_update=True, + blocking=True, on_update=mocker.ANY, **record_metadata_from_datatype( long_name_controller_api.attributes["attr_rw_short_name"].datatype, @@ -500,6 +505,7 @@ def test_long_pv_names_discarded(mocker: MockerFixture): builder.longOut.assert_called_once_with( f"{DEVICE}:{long_rw_pv_name}", always_update=True, + blocking=True, on_update=mocker.ANY, ) with pytest.raises(AssertionError): diff --git a/tests/transports/epics/pva/test_p4p.py b/tests/transports/epics/pva/test_p4p.py index 1b5ca9f18..e33c58f98 100644 --- a/tests/transports/epics/pva/test_p4p.py +++ b/tests/transports/epics/pva/test_p4p.py @@ -394,7 +394,8 @@ class SomeController(Controller): } -def test_more_exotic_datatypes(): +@pytest.mark.asyncio +async def test_more_exotic_datatypes(): table_columns: list[tuple[str, DTypeLike]] = [ ("A", "i"), ("B", "i"), @@ -439,7 +440,6 @@ class SomeController(Controller): client_put_enum_value = "C" async def _wait_and_set_attrs(): - await asyncio.sleep(0.1) # This demonstrates an update from hardware, # resulting in only a change in the read back. await asyncio.gather( @@ -449,7 +449,6 @@ async def _wait_and_set_attrs(): ) async def _wait_and_put_pvs(): - await asyncio.sleep(0.3) ctxt = Context("pva") # This demonstrates a client put, # resulting in a change in the demand and read back. @@ -471,72 +470,62 @@ async def _wait_and_put_pvs(): enum_values.append, ) - serve = asyncio.ensure_future(fastcs.serve(interactive=False)) - wait_and_set_attrs = asyncio.ensure_future(_wait_and_set_attrs()) - wait_and_put_pvs = asyncio.ensure_future(_wait_and_put_pvs()) - try: - asyncio.get_event_loop().run_until_complete( - asyncio.wait_for( - asyncio.gather(serve, wait_and_set_attrs, wait_and_put_pvs), - timeout=0.6, - ) - ) - except TimeoutError: - ... - finally: - waveform_monitor.close() - table_monitor.close() - enum_monitor.close() - serve.cancel() - wait_and_set_attrs.cancel() - wait_and_put_pvs.cancel() + serve = asyncio.create_task(fastcs.serve(interactive=False)) + await asyncio.sleep(0.1) # Wait for task to start - expected_waveform_gets = [ - initial_waveform_value, - server_set_waveform_value, - client_put_waveform_value, - ] + await _wait_and_set_attrs() + await _wait_and_put_pvs() + await asyncio.sleep(0.1) # Wait for monitors to return - for expected_waveform, actual_waveform in zip( - expected_waveform_gets, waveform_values, strict=True - ): - np.testing.assert_array_equal( - expected_waveform, actual_waveform.todict()["value"].reshape(10, 10) - ) + waveform_monitor.close() + table_monitor.close() + enum_monitor.close() + serve.cancel() - expected_table_gets = [ - NTTable(columns=table_columns).wrap(initial_table_value), - NTTable(columns=table_columns).wrap(server_set_table_value), - client_put_table_value, - ] - for expected_table, actual_table in zip( - expected_table_gets, table_values, strict=True - ): - expected_table = expected_table.todict()["value"] - actual_table = actual_table.todict()["value"] - for expected_column, actual_column in zip( - expected_table.values(), actual_table.values(), strict=True - ): - if isinstance(expected_column, np.ndarray): - np.testing.assert_array_equal(expected_column, actual_column) - else: - assert expected_column == actual_column and actual_column is None - - expected_enum_gets = [ - initial_enum_value, - server_set_enum_value, - AnEnum.C, - ] + expected_waveform_gets = [ + initial_waveform_value, + server_set_waveform_value, + client_put_waveform_value, + ] + + for expected_waveform, actual_waveform in zip( + expected_waveform_gets, waveform_values, strict=True + ): + np.testing.assert_array_equal( + expected_waveform, actual_waveform.todict()["value"].reshape(10, 10) + ) - for expected_enum, actual_enum in zip( - expected_enum_gets, enum_values, strict=True + expected_table_gets = [ + NTTable(columns=table_columns).wrap(initial_table_value), + NTTable(columns=table_columns).wrap(server_set_table_value), + client_put_table_value, + ] + for expected_table, actual_table in zip( + expected_table_gets, table_values, strict=True + ): + expected_table = expected_table.todict()["value"] + actual_table = actual_table.todict()["value"] + for expected_column, actual_column in zip( + expected_table.values(), actual_table.values(), strict=True ): - assert ( - expected_enum - == controller.some_enum.datatype.members[ # type: ignore - actual_enum.todict()["value"]["index"] - ] - ) + if isinstance(expected_column, np.ndarray): + np.testing.assert_array_equal(expected_column, actual_column) + else: + assert expected_column == actual_column and actual_column is None + + expected_enum_gets = [ + initial_enum_value, + server_set_enum_value, + AnEnum.C, + ] + + for expected_enum, actual_enum in zip(expected_enum_gets, enum_values, strict=True): + assert ( + expected_enum + == controller.some_enum.datatype.members[ # type: ignore + actual_enum.todict()["value"]["index"] + ] + ) @pytest.mark.timeout(4)