diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index f2cc14b8..904e0ba0 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -5,7 +5,7 @@ import random import types import typing -from collections.abc import Callable +from collections.abc import Callable, Iterable, Mapping from typing import Annotated, Concatenate, Generic, TypeVar import tree @@ -1108,6 +1108,39 @@ def trace(value: Callable[P, T]) -> Callable[P, T]: return deffn(body, *bound_sig.args, **bound_sig.kwargs) +@defop +def defstream( + body: Annotated[T, Scoped[A | B]], + streams: Annotated[Mapping[Operation[[], S], Iterable[S]], Scoped[B]], +) -> Annotated[Iterable[T], Scoped[A]]: + """A higher-order operation that represents a for-expression.""" + raise NotImplementedError + + +@defdata.register(collections.abc.Iterable) +class _IterableTerm(Generic[T], _BaseTerm[collections.abc.Iterable[T]]): + @defop + def __iter__(self: collections.abc.Iterable[T]) -> collections.abc.Iterator[T]: + if not isinstance(self, Term): + return iter(self) + else: + raise NotImplementedError + + +@defdata.register(collections.abc.Iterator) +class _IteratorTerm(Generic[T], _IterableTerm[T]): + @defop + def __next__(self: collections.abc.Iterator[T]) -> T: + if not isinstance(self, Term): + return next(self) + else: + raise NotImplementedError + + +iter_ = _IterableTerm.__iter__ +next_ = _IteratorTerm.__next__ + + def syntactic_eq(x: Expr[T], other: Expr[T]) -> bool: """Syntactic equality, ignoring the interpretation of the terms. diff --git a/tests/test_ops_syntax.py b/tests/test_ops_syntax.py index 93402f87..81d0be54 100644 --- a/tests/test_ops_syntax.py +++ b/tests/test_ops_syntax.py @@ -1,18 +1,21 @@ import functools import inspect -from collections.abc import Callable, Mapping +from collections.abc import Callable, Iterable, Iterator, Mapping from typing import Annotated, ClassVar, TypeVar import pytest import effectful.handlers.numbers # noqa: F401 -from effectful.ops.semantics import call, evaluate, fvsof, handler +from effectful.ops.semantics import call, evaluate, fvsof, handler, typeof from effectful.ops.syntax import ( Scoped, _CustomSingleDispatchCallable, deffn, defop, + defstream, defterm, + iter_, + next_, ) from effectful.ops.types import Operation, Term @@ -484,3 +487,56 @@ def _(self, x: bool) -> bool: # Test that the method can be called with a handler with handler({MyClass.my_singledispatch: lambda self, x: x + 6}): assert instance.my_singledispatch(5) == 11 + + +def test_defdata_iterable(): + @defop + def cons_iterable(*args: int) -> Iterable[int]: + raise NotImplementedError + + tm = cons_iterable(1, 2, 3) + assert isinstance(tm, Term) + assert isinstance(tm, Iterable) + assert issubclass(typeof(tm), Iterable) + assert tm.op is cons_iterable + assert tm.args == (1, 2, 3) + + tm_iter = iter(tm) + assert isinstance(tm_iter, Term) + assert isinstance(tm_iter, Iterator) + assert issubclass(typeof(tm_iter), Iterator) + assert tm_iter.op is iter_ + + tm_iter_next = next(tm_iter) + assert isinstance(tm_iter_next, Term) + # assert isinstance(tm_iter_next, numbers.Number) # TODO + # assert issubclass(typeof(tm_iter_next), numbers.Number) + assert tm_iter_next.op is next_ + + assert list(tm.args) == [1, 2, 3] + + +def test_defstream_1(): + x = defop(int, name="x") + y = defop(int, name="y") + tm = defstream(x() + y(), {x: [1, 2, 3], y: [x() + 1, x() + 2, x() + 3]}) + + assert isinstance(tm, Term) + assert isinstance(tm, Iterable) + assert issubclass(typeof(tm), Iterable) + assert tm.op is defstream + + assert x not in fvsof(tm) + assert y not in fvsof(tm) + + tm_iter = iter(tm) + assert isinstance(tm_iter, Term) + assert isinstance(tm_iter, Iterator) + assert issubclass(typeof(tm_iter), Iterator) + assert tm_iter.op is iter_ + + tm_iter_next = next(tm_iter) + assert isinstance(tm_iter_next, Term) + # assert isinstance(tm_iter_next, numbers.Number) # TODO + # assert issubclass(typeof(tm_iter_next), numbers.Number) + assert tm_iter_next.op is next_