Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion effectful/handlers/llm/sampling.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import functools
import threading
from collections import Counter
from collections.abc import Callable, Sequence
from concurrent import futures
from concurrent.futures.thread import ThreadPoolExecutor

from effectful.handlers.llm import Template
from effectful.handlers.llm.providers import completion, tool_call
from effectful.internals.runtime import get_interpretation, interpreter
from effectful.ops.semantics import fwd
from effectful.ops.semantics import fwd, handler
from effectful.ops.syntax import ObjectInterpretation, implements


Expand Down Expand Up @@ -45,3 +49,51 @@ def n_votes_ahead():
tasks.append(executor.submit(interpreter(intp)(fwd), *args, **kwargs))
executor.shutdown()
return self.votes.most_common(1)[0][0]


def sample[**P, T](template: Template[P, T], n: int) -> Callable[P, Sequence[T]]:
"""sample returns a function with the same signature as `template` except
that `n` samples are returned.

When computing a batch of samples, calls to `completion` (and handlers of
`completion`) proceed in parallel, but other calls (e.g. tool calls) proceed
synchronously.

"""
lock = threading.Lock()

def _completion(*args, **kwargs):
lock.release()
result = fwd()
lock.acquire()
return result

def _tool_call(*args, **kwargs):
lock.acquire()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this call to acquire isn't necessary.

try:
result = fwd()
except Exception as e:
lock.release()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if fwd() doesn't raise an exception this won't release the lock no?

with lock:
    return fwd()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also it might be good to have some test for this functionality, mocking the llm, just checking that we're doing concurrency correctly

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if fwd() doesn't raise an exception this won't release the lock no?

with lock:
    return fwd()

We actually don't want to. If the tool call completes successfully, we keep the lock up to the point that we start another completion. Releasing the lock here allows the remaining threads to proceed as this thread dies. The fwd in the completion handler can fail safely because we don't hold the lock.

This raises a good point though: what should we do when one of the threads fails? Right now we reraise the exception from the failed thread and you get no samples. Maybe instead you should get the successful samples?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed on tests.

raise e
return result

@functools.wraps(template)
@handler({completion: _completion, tool_call: _tool_call})
def wrapper(*args, **kwargs):
with ThreadPoolExecutor() as executor:

@interpreter(get_interpretation())
def do_work():
lock.acquire()
try:
result = template(*args, **kwargs)
finally:
assert lock.locked()
lock.release()
return result

tasks = [executor.submit(do_work) for _ in range(n)]
completed = futures.wait(tasks, return_when=futures.ALL_COMPLETED)
return [t.result() for t in completed.done]

return wrapper
Loading