Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion effectful/ops/syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import random
import types
import typing
from collections.abc import Callable
from collections.abc import Callable, Iterable, Mapping
from typing import Annotated, Concatenate, Generic, TypeVar

import tree
Expand Down Expand Up @@ -1108,6 +1108,39 @@ def trace(value: Callable[P, T]) -> Callable[P, T]:
return deffn(body, *bound_sig.args, **bound_sig.kwargs)


@defop
def defstream(
body: Annotated[T, Scoped[A | B]],
streams: Annotated[Mapping[Operation[[], S], Iterable[S]], Scoped[B]],
) -> Annotated[Iterable[T], Scoped[A]]:
"""A higher-order operation that represents a for-expression."""
raise NotImplementedError


@defdata.register(collections.abc.Iterable)
class _IterableTerm(Generic[T], _BaseTerm[collections.abc.Iterable[T]]):
@defop
def __iter__(self: collections.abc.Iterable[T]) -> collections.abc.Iterator[T]:
if not isinstance(self, Term):
return iter(self)
else:
raise NotImplementedError


@defdata.register(collections.abc.Iterator)
class _IteratorTerm(Generic[T], _IterableTerm[T]):
@defop
def __next__(self: collections.abc.Iterator[T]) -> T:
if not isinstance(self, Term):
return next(self)
else:
raise NotImplementedError


iter_ = _IterableTerm.__iter__
next_ = _IteratorTerm.__next__


def syntactic_eq(x: Expr[T], other: Expr[T]) -> bool:
"""Syntactic equality, ignoring the interpretation of the terms.

Expand Down
60 changes: 58 additions & 2 deletions tests/test_ops_syntax.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
import functools
import inspect
from collections.abc import Callable, Mapping
from collections.abc import Callable, Iterable, Iterator, Mapping
from typing import Annotated, ClassVar, TypeVar

import pytest

import effectful.handlers.numbers # noqa: F401
from effectful.ops.semantics import call, evaluate, fvsof, handler
from effectful.ops.semantics import call, evaluate, fvsof, handler, typeof
from effectful.ops.syntax import (
Scoped,
_CustomSingleDispatchCallable,
deffn,
defop,
defstream,
defterm,
iter_,
next_,
)
from effectful.ops.types import Operation, Term

Expand Down Expand Up @@ -484,3 +487,56 @@ def _(self, x: bool) -> bool:
# Test that the method can be called with a handler
with handler({MyClass.my_singledispatch: lambda self, x: x + 6}):
assert instance.my_singledispatch(5) == 11


def test_defdata_iterable():
@defop
def cons_iterable(*args: int) -> Iterable[int]:
raise NotImplementedError

tm = cons_iterable(1, 2, 3)
assert isinstance(tm, Term)
assert isinstance(tm, Iterable)
assert issubclass(typeof(tm), Iterable)
assert tm.op is cons_iterable
assert tm.args == (1, 2, 3)

tm_iter = iter(tm)
assert isinstance(tm_iter, Term)
assert isinstance(tm_iter, Iterator)
assert issubclass(typeof(tm_iter), Iterator)
assert tm_iter.op is iter_

tm_iter_next = next(tm_iter)
assert isinstance(tm_iter_next, Term)
# assert isinstance(tm_iter_next, numbers.Number) # TODO
# assert issubclass(typeof(tm_iter_next), numbers.Number)
assert tm_iter_next.op is next_

assert list(tm.args) == [1, 2, 3]


def test_defstream_1():
x = defop(int, name="x")
y = defop(int, name="y")
tm = defstream(x() + y(), {x: [1, 2, 3], y: [x() + 1, x() + 2, x() + 3]})

assert isinstance(tm, Term)
assert isinstance(tm, Iterable)
assert issubclass(typeof(tm), Iterable)
assert tm.op is defstream

assert x not in fvsof(tm)
assert y not in fvsof(tm)

tm_iter = iter(tm)
assert isinstance(tm_iter, Term)
assert isinstance(tm_iter, Iterator)
assert issubclass(typeof(tm_iter), Iterator)
assert tm_iter.op is iter_

tm_iter_next = next(tm_iter)
assert isinstance(tm_iter_next, Term)
# assert isinstance(tm_iter_next, numbers.Number) # TODO
# assert issubclass(typeof(tm_iter_next), numbers.Number)
assert tm_iter_next.op is next_