-
Notifications
You must be signed in to change notification settings - Fork 3
Add a sample function that computes a batch of template results in parallel
#417
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: staging-llm
Are you sure you want to change the base?
Changes from all commits
cefaf7e
27cdbed
0d8532b
a03370f
427a349
4dab7aa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,10 +1,14 @@ | ||
| import functools | ||
| import threading | ||
| from collections import Counter | ||
| from collections.abc import Callable, Sequence | ||
| from concurrent import futures | ||
| from concurrent.futures.thread import ThreadPoolExecutor | ||
|
|
||
| from effectful.handlers.llm import Template | ||
| from effectful.handlers.llm.providers import completion, tool_call | ||
| from effectful.internals.runtime import get_interpretation, interpreter | ||
| from effectful.ops.semantics import fwd | ||
| from effectful.ops.semantics import fwd, handler | ||
| from effectful.ops.syntax import ObjectInterpretation, implements | ||
|
|
||
|
|
||
|
|
@@ -45,3 +49,51 @@ def n_votes_ahead(): | |
| tasks.append(executor.submit(interpreter(intp)(fwd), *args, **kwargs)) | ||
| executor.shutdown() | ||
| return self.votes.most_common(1)[0][0] | ||
|
|
||
|
|
||
| def sample[**P, T](template: Template[P, T], n: int) -> Callable[P, Sequence[T]]: | ||
| """sample returns a function with the same signature as `template` except | ||
| that `n` samples are returned. | ||
|
|
||
| When computing a batch of samples, calls to `completion` (and handlers of | ||
| `completion`) proceed in parallel, but other calls (e.g. tool calls) proceed | ||
| synchronously. | ||
|
|
||
| """ | ||
| lock = threading.Lock() | ||
|
|
||
| def _completion(*args, **kwargs): | ||
| lock.release() | ||
| result = fwd() | ||
| lock.acquire() | ||
| return result | ||
|
|
||
| def _tool_call(*args, **kwargs): | ||
| lock.acquire() | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this call to |
||
| try: | ||
| result = fwd() | ||
| except Exception as e: | ||
| lock.release() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if fwd() doesn't raise an exception this won't release the lock no? with lock:
return fwd()
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also it might be good to have some test for this functionality, mocking the llm, just checking that we're doing concurrency correctly
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
We actually don't want to. If the tool call completes successfully, we keep the lock up to the point that we start another This raises a good point though: what should we do when one of the threads fails? Right now we reraise the exception from the failed thread and you get no samples. Maybe instead you should get the successful samples?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed on tests. |
||
| raise e | ||
| return result | ||
|
|
||
| @functools.wraps(template) | ||
| @handler({completion: _completion, tool_call: _tool_call}) | ||
| def wrapper(*args, **kwargs): | ||
| with ThreadPoolExecutor() as executor: | ||
|
|
||
| @interpreter(get_interpretation()) | ||
| def do_work(): | ||
| lock.acquire() | ||
| try: | ||
| result = template(*args, **kwargs) | ||
| finally: | ||
| assert lock.locked() | ||
| lock.release() | ||
| return result | ||
|
|
||
| tasks = [executor.submit(do_work) for _ in range(n)] | ||
| completed = futures.wait(tasks, return_when=futures.ALL_COMPLETED) | ||
| return [t.result() for t in completed.done] | ||
|
|
||
| return wrapper | ||
Uh oh!
There was an error while loading. Please reload this page.