Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
("py:class", "asyncio.events.AbstractEventLoop"),
("py:class", "asyncio.streams.StreamReader"),
("py:class", "asyncio.streams.StreamWriter"),
("py:class", "asyncio.locks.Event"),
# Annoying error:
# docstring of collections.abc.Callable:1: WARNING:
# 'any' reference target not found: self [ref.any]
Expand Down
76 changes: 76 additions & 0 deletions src/fastcs/attributes/attr_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from fastcs.attributes.attribute import Attribute
from fastcs.attributes.attribute_io_ref import AttributeIORefT
from fastcs.attributes.util import AttrValuePredicate, PredicateEvent
from fastcs.datatypes import DataType, DType_T
from fastcs.logging import bind_logger

Expand Down Expand Up @@ -39,6 +40,8 @@ def __init__(
"""Callback to update the value of the attribute with an IO to the source"""
self._on_update_callbacks: list[AttrOnUpdateCallback[DType_T]] | None = None
"""Callbacks to publish changes to the value of the attribute"""
self._on_update_events: set[PredicateEvent[DType_T]] = set()
"""Events to set when the value satisifies some predicate"""

def get(self) -> DType_T:
"""Get the cached value of the attribute."""
Expand All @@ -51,6 +54,9 @@ async def update(self, value: Any) -> None:
generally only be called from an IO or a controller that is updating the value
from some underlying source.

Any update callbacks will be called with the new value and any update events
with predicates satisfied by the new value will be set.

To request a change to the setpoint of the attribute, use the ``put`` method,
which will attempt to apply the change to the underlying source.

Expand All @@ -67,6 +73,10 @@ async def update(self, value: Any) -> None:

self._value = self._datatype.validate(value)

self._on_update_events -= {
e for e in self._on_update_events if e.set(self._value)
}

if self._on_update_callbacks is not None:
try:
await asyncio.gather(
Expand Down Expand Up @@ -115,3 +125,69 @@ async def update_attribute():
raise

return update_attribute

async def wait_for_predicate(
self, predicate: AttrValuePredicate[DType_T], *, timeout: float
):
"""Wait for the predicate to be satisfied when called with the current value

Args:
predicate: The predicate to test - a callable that takes the attribute
value and returns True if the event should be set
timeout: The timeout in seconds

"""
if predicate(self._value):
self.log_event(
"Predicate already satisfied", predicate=predicate, attribute=self
)
return

self._on_update_events.add(update_event := PredicateEvent(predicate))

self.log_event("Waiting for predicate", predicate=predicate, attribute=self)
try:
await asyncio.wait_for(update_event.wait(), timeout)
except TimeoutError:
self._on_update_events.remove(update_event)
raise TimeoutError(
f"Timeout waiting {timeout}s for {self.full_name} predicate {predicate}"
f" - current value: {self._value}"
) from None

self.log_event("Predicate satisfied", predicate=predicate, attribute=self)

async def wait_for_value(self, target_value: DType_T, *, timeout: float):
"""Wait for self._value to equal the target value

Args:
target_value: The target value to wait for
timeout: The timeout in seconds

Raises:
TimeoutError: If the attribute does not reach the target value within the
timeout

"""
if self._value == target_value:
self.log_event(
"Current value already equals target value",
target_value=target_value,
attribute=self,
)
return

def predicate(v: DType_T) -> bool:
return v == target_value

try:
await self.wait_for_predicate(predicate, timeout=timeout)
except TimeoutError:
raise TimeoutError(
f"Timeout waiting {timeout}s for {self.full_name} value {target_value}"
f" - current value: {self._value}"
) from None

self.log_event(
"Value equals target value", target_valuevalue=target_value, attribute=self
)
8 changes: 6 additions & 2 deletions src/fastcs/attributes/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ def name(self) -> str:
def path(self) -> list[str]:
return self._path

@property
def full_name(self) -> str:
return ".".join(self._path + [self._name])

def add_update_datatype_callback(
self, callback: Callable[[DataType[DType_T]], None]
) -> None:
Expand Down Expand Up @@ -102,7 +106,7 @@ def set_path(self, path: list[str]):

def __repr__(self):
name = self.__class__.__name__
path = ".".join(self._path + [self._name]) or None
full_name = self.full_name or None
datatype = self._datatype.__class__.__name__

return f"{name}(path={path}, datatype={datatype}, io_ref={self._io_ref})"
return f"{name}(name={full_name}, datatype={datatype}, io_ref={self._io_ref})"
39 changes: 39 additions & 0 deletions src/fastcs/attributes/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import asyncio
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Generic

from fastcs.datatypes import DType_T

AttrValuePredicate = Callable[[DType_T], bool]


@dataclass(eq=False)
class PredicateEvent(Generic[DType_T]):
"""A wrapper of `asyncio.Event` that only triggers when a predicate is satisfied"""

_predicate: AttrValuePredicate[DType_T]
"""Predicate to filter set calls by"""
_event: asyncio.Event = field(default_factory=asyncio.Event)
"""Event to set"""

def set(self, value: DType_T) -> bool:
"""Set the event if the predicate is satisfied by the value

Returns:
`True` if the predicate was satisfied and the event was set, else `False`

"""
if self._predicate(value):
self._event.set()
return True

return False

async def wait(self):
"""Wait for the event to be set"""
await self._event.wait()

def __hash__(self) -> int:
"""Make instances unique when stored in sets"""
return id(self)
6 changes: 5 additions & 1 deletion src/fastcs/transports/epics/ca/ioc.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,11 @@ def _make_record(
)
attribute_record_metadata = record_metadata_from_attribute(attribute)

update = {"always_update": True, "on_update": on_update} if on_update else {}
update = (
{"on_update": on_update, "always_update": True, "blocking": True}
if on_update
else {}
)

record = builder_callable(
pv, **update, **datatype_record_metadata, **attribute_record_metadata
Expand Down
54 changes: 53 additions & 1 deletion tests/test_attributes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from dataclasses import dataclass
from functools import partial
from typing import Generic, TypeVar
Expand All @@ -12,7 +13,7 @@
NumberT = TypeVar("NumberT", int, float)


def test_attribute():
def test_attr_r():
attr = AttrR(String(), group="test group")

with pytest.raises(RuntimeError):
Expand All @@ -39,6 +40,57 @@ def test_attribute():
assert attr.get() == ""


@pytest.mark.asyncio
async def test_wait_for_predicate(mocker: MockerFixture):
attr = AttrR(Int(), initial_value=0)

async def update(attr: AttrR):
while True:
await asyncio.sleep(0.1)
await attr.update(attr.get() + 3) # 3, 6, 9, 12 != 10

asyncio.create_task(update(attr))

# We won't see exactly 10 so check for greater than
def predicate(v: int) -> bool:
return v > 10

wait_mock = mocker.spy(asyncio, "wait_for")
with pytest.raises(TimeoutError):
await attr.wait_for_predicate(predicate, timeout=0.2)

await attr.wait_for_predicate(predicate, timeout=1)

assert wait_mock.call_count == 2

# Returns immediately without creating event if value already as expected
await attr.wait_for_predicate(predicate, timeout=1)
assert wait_mock.call_count == 2


@pytest.mark.asyncio
async def test_wait_for_value(mocker: MockerFixture):
attr = AttrR(Int(), initial_value=0)

async def update(attr: AttrR):
await asyncio.sleep(0.5)
await attr.update(1)

asyncio.create_task(update(attr))

wait_mock = mocker.spy(asyncio, "wait_for")
with pytest.raises(TimeoutError):
await attr.wait_for_value(10, timeout=0.2)

await attr.wait_for_value(1, timeout=1)

assert wait_mock.call_count == 2

# Returns immediately without creating event if value already as expected
await attr.wait_for_value(1, timeout=1)
assert wait_mock.call_count == 2


@pytest.mark.asyncio
async def test_attributes():
device = {"state": "Idle", "number": 1, "count": False}
Expand Down
8 changes: 7 additions & 1 deletion tests/transports/epics/ca/test_softioc.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def test_make_output_record(

kwargs.update(record_metadata_from_datatype(attribute.datatype, out_record=True))
kwargs.update(record_metadata_from_attribute(attribute))
kwargs.update({"always_update": True, "on_update": update})
kwargs.update({"always_update": True, "on_update": update, "blocking": True})

getattr(builder, record_type).assert_called_once_with(
pv,
Expand Down Expand Up @@ -265,6 +265,7 @@ def test_ioc(mocker: MockerFixture, epics_controller_api: ControllerAPI):
builder.aOut.assert_any_call(
f"{DEVICE}:ReadWriteFloat",
always_update=True,
blocking=True,
on_update=mocker.ANY,
**record_metadata_from_attribute(
epics_controller_api.attributes["read_write_float"]
Expand All @@ -286,6 +287,7 @@ def test_ioc(mocker: MockerFixture, epics_controller_api: ControllerAPI):
builder.longOut.assert_called_with(
f"{DEVICE}:ReadWriteInt",
always_update=True,
blocking=True,
on_update=mocker.ANY,
**record_metadata_from_attribute(
epics_controller_api.attributes["read_write_int"]
Expand All @@ -304,6 +306,7 @@ def test_ioc(mocker: MockerFixture, epics_controller_api: ControllerAPI):
builder.mbbOut.assert_called_once_with(
f"{DEVICE}:Enum",
always_update=True,
blocking=True,
on_update=mocker.ANY,
**record_metadata_from_attribute(epics_controller_api.attributes["enum"]),
**record_metadata_from_datatype(
Expand All @@ -313,6 +316,7 @@ def test_ioc(mocker: MockerFixture, epics_controller_api: ControllerAPI):
builder.boolOut.assert_called_once_with(
f"{DEVICE}:WriteBool",
always_update=True,
blocking=True,
on_update=mocker.ANY,
**record_metadata_from_attribute(epics_controller_api.attributes["write_bool"]),
**record_metadata_from_datatype(
Expand Down Expand Up @@ -460,6 +464,7 @@ def test_long_pv_names_discarded(mocker: MockerFixture):
builder.longOut.assert_called_once_with(
f"{DEVICE}:{short_pv_name}",
always_update=True,
blocking=True,
on_update=mocker.ANY,
**record_metadata_from_datatype(
long_name_controller_api.attributes["attr_rw_short_name"].datatype,
Expand Down Expand Up @@ -500,6 +505,7 @@ def test_long_pv_names_discarded(mocker: MockerFixture):
builder.longOut.assert_called_once_with(
f"{DEVICE}:{long_rw_pv_name}",
always_update=True,
blocking=True,
on_update=mocker.ANY,
)
with pytest.raises(AssertionError):
Expand Down
Loading