
Conversation

@finbarrtimbers
Collaborator

@finbarrtimbers finbarrtimbers commented Nov 24, 2025

This has a few advantages:

  1. Timing-wise, we won't block on slow reward functions. This setup should scale well, both horizontally and vertically. Horizontally: we can have more actors, and each will manage its own rewards. Vertically: because we call the reward function asynchronously, slower reward functions have minimal impact on overall throughput (see the sketch after this list).
  2. This is conceptually cleaner: once a completion is done, it carries all the information needed to process it.
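
A minimal sketch of the async pattern described above, with hypothetical names (`reward_fn`, `score_completions`); the real code builds its reward function via RewardConfig.build():

```python
import asyncio

# Hypothetical async reward function; the real one is built via RewardConfig.build().
async def reward_fn(response: str) -> float:
    await asyncio.sleep(1.0)  # stand-in for a slow verifier (e.g., a remote judge call)
    return 1.0 if "4" in response else 0.0

async def score_completions(responses: list[str]) -> list[float]:
    # Completions are scored concurrently, so one slow reward call
    # delays only its own result, not the whole batch.
    return await asyncio.gather(*(reward_fn(r) for r in responses))

if __name__ == "__main__":
    print(asyncio.run(score_completions(["2 + 2 = 4", "no idea"])))  # [1.0, 0.0]
```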

Runs:


Note

Moves reward calculation into the vLLM actors using an async reward function built from the new RewardConfig, attaching scores/metrics to GenerationResult and updating the training/eval pipeline to consume them.

  • Runtime/Actors:
    • Compute rewards inside LLMRayActor after generation; build async reward_fn from new RewardConfig and attach reward_scores/reward_metrics to GenerationResult.
    • create_vllm_engines/LLMRayActor now accept reward_config, train_dataset, eval_dataset to enable in-actor reward computation.
  • Training pipeline:
    • Remove reward_fn plumbing from grpo_fast; accumulate_inference_batches and eval now read result.reward_scores/reward_metrics and derive stats from them.
    • create_model_and_optimizer passes RewardConfig and datasets to engines.
  • Ground truth utilities:
    • Add apply_verifiable_reward and RewardConfig.build() producing the reward function; extend metrics (per-verifier averages and correct rates).
  • Data structures:
    • Extend GenerationResult with reward_scores and reward_metrics (a sketch follows this list).
  • Tests:
    • Update tests to construct GenerationResult with reward_scores and remove explicit reward_fn usage where applicable.
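
For reference, a sketch of the extended result shape (illustrative only; the real GenerationResult in open-instruct carries more fields):

```python
from dataclasses import dataclass, field

# Illustrative shape only, not the actual class definition.
@dataclass
class GenerationResult:
    responses: list[list[int]]  # token ids, one list per completion
    dataset_index: int          # row of the train/eval dataset this prompt came from
    reward_scores: list[float] = field(default_factory=list)  # one score per completion
    reward_metrics: dict = field(default_factory=dict)        # e.g., per-verifier averages/correct rates
```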

Written by Cursor Bugbot for commit eff4c59. This will update automatically on new commits.

@github-actions
Contributor

Documentation Changes Detected

📄 sitemap.xml
--- site-base/sitemap.xml	2025-11-25 22:49:39.612567452 +0000
+++ site-pr/sitemap.xml	2025-11-25 22:49:36.235470329 +0000
@@ -9,6 +9,10 @@
          <lastmod>2025-11-25</lastmod>
     </url>
     <url>
+         <loc>https://github.com/allenai/open-instruct/nccl_hang_investigation/</loc>
+         <lastmod>2025-11-25</lastmod>
+    </url>
+    <url>
📄 sitemap.xml.gz
Binary files site-base/sitemap.xml.gz and site-pr/sitemap.xml.gz differ

Showing first 10 lines of diff for each changed file (up to 5 files, excluding search indices).


@finbarrtimbers finbarrtimbers marked this pull request as ready for review December 1, 2025 22:02

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 547 to 551
async def compute_rewards(
    actor: "LLMRayActor", result: GenerationResult, dataset: datasets.Dataset, is_eval: bool
) -> tuple[list[float], dict]:
    example = dataset[result.dataset_index]
    decoded_responses = actor.llm_engine.tokenizer.batch_decode(result.responses)


P1 Badge Decode rewards without stripping special tokens

compute_rewards now decodes completions via actor.llm_engine.tokenizer.batch_decode(result.responses) without skip_special_tokens=True, whereas reward computation previously decoded with special tokens removed (e.g., in accumulate_inference_batches). For tokenizers that inject BOS/EOS markers, those tokens are passed to the verifiers, so format and ground-truth checks see extra tokens and mis-classify otherwise correct responses, distorting reward signals for every request.
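
For concreteness, the difference comes down to one flag on the standard Hugging Face tokenizer API (gpt2 is just an illustrative tokenizer choice here):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative tokenizer choice
token_ids = tokenizer("2 + 2 = 4")["input_ids"] + [tokenizer.eos_token_id]

# Without the flag, special tokens leak into the decoded string...
print(tokenizer.decode(token_ids))                            # '2 + 2 = 4<|endoftext|>'
# ...so an exact-match ground-truth check against "4" at the end would fail.
print(tokenizer.decode(token_ids, skip_special_tokens=True))  # '2 + 2 = 4'
```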


Collaborator

@hamishivi hamishivi left a comment


I think this generally looks good, but one higher-level question: let's say I want to add extra information to my reward computation, e.g. an ongoing buffer of samples that the reward should depend on. How easy is that now that the reward_fn thread has moved into the LLM actors? (I think the move is correct at a high level; I'm just thinking about hackability.)
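
One pattern that would keep this hackable (a sketch of the idea only; `build_reward_fn` and the buffer are hypothetical, not what the PR implements): have RewardConfig.build() return a closure over whatever extra state the reward needs, so the actors never have to know about it.

```python
import asyncio
from collections import deque

def build_reward_fn(buffer_size: int = 128):
    # Hypothetical builder mirroring the RewardConfig.build() idea: the returned
    # closure carries mutable state (a rolling sample buffer) across calls.
    buffer: deque = deque(maxlen=buffer_size)

    async def reward_fn(response: str) -> float:
        # Example of a buffer-dependent reward: penalize recently seen samples.
        score = 0.0 if response in buffer else 1.0
        buffer.append(response)
        return score

    return reward_fn

async def main():
    fn = build_reward_fn()
    print(await fn("hello"))  # 1.0 (novel)
    print(await fn("hello"))  # 0.0 (seen recently)

asyncio.run(main())
```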

@finbarrtimbers
Copy link
Collaborator Author

Fixed the skip_special_tokens=True issue!

dataset = actor.eval_dataset if is_eval else actor.train_dataset
result.reward_scores, result.reward_metrics = await compute_rewards(actor, result, dataset, is_eval)
results_queue = actor.eval_results_queue if is_eval else actor.results_queue
results_queue.put(result)

Bug: Rewards computed before empty response EOS modification

The reward computation now happens in finalize_completed_request on original responses, but the empty response handling (appending EOS token when finish_reason == "stop" and response is empty) happens later in accumulate_inference_batches. Previously, the empty response modification occurred BEFORE reward computation. This means rewards are now computed on potentially empty [] responses, while training uses modified responses containing [eos_token_id]. The TODO comment at line 1790 acknowledges this needs to be moved to LLMRayActor, but in the current state there's a mismatch between what the reward function evaluates and what the model trains on.
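
A minimal sketch of one way to restore the old ordering (hypothetical helper name; not the PR's actual fix): apply the empty-response fix-up inside the actor before compute_rewards runs, so the verifier and the trainer see the same tokens.

```python
# Hypothetical helper mirroring the fix-up described above; running it before
# compute_rewards would keep verifier inputs and training targets identical.
def patch_empty_responses(
    responses: list[list[int]], finish_reasons: list[str], eos_token_id: int
) -> list[list[int]]:
    for i, (resp, reason) in enumerate(zip(responses, finish_reasons)):
        if reason == "stop" and len(resp) == 0:
            responses[i] = [eos_token_id]  # same substitution the trainer applies
    return responses

# patch_empty_responses([[], [101, 102]], ["stop", "length"], eos_token_id=2)
# -> [[2], [101, 102]]
```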


Collaborator

@hamishivi hamishivi left a comment


lgtm :)

actor.request_outputs[base_request_id]["outputs"].append(sub_request["request_output"])

if len(actor.request_outputs[base_request_id]["outputs"]) == expected_n:
    asyncio.run_coroutine_threadsafe(finalize_completed_request(actor, base_request_id), actor.loop)

Bug: Unhandled exceptions in async finalization cause silent failures

The Future returned by asyncio.run_coroutine_threadsafe(finalize_completed_request(...), actor.loop) is discarded, meaning exceptions in compute_rewards or elsewhere in the async function are silently swallowed. Combined with the fact that actor.request_outputs.pop(base_request_id) happens before the await compute_rewards(...) call, if reward computation fails the request data is already removed and the result never reaches the queue. The data preparation thread in accumulate_inference_batches would then hang waiting for results that will never arrive. The previous synchronous approach had exceptions bubble up visibly.
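
One standard way to surface these failures (a sketch, not the PR's fix): keep the concurrent.futures.Future that run_coroutine_threadsafe returns and attach a done-callback that logs any exception.

```python
import concurrent.futures
import logging

logger = logging.getLogger(__name__)

def _surface_failure(future: concurrent.futures.Future) -> None:
    # Runs once the coroutine finishes; logs instead of silently dropping errors.
    exc = future.exception()
    if exc is not None:
        logger.error("finalize_completed_request failed: %r", exc)

# Instead of discarding the Future:
#   future = asyncio.run_coroutine_threadsafe(
#       finalize_completed_request(actor, base_request_id), actor.loop
#   )
#   future.add_done_callback(_surface_failure)
```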


@finbarrtimbers finbarrtimbers added this pull request to the merge queue Dec 2, 2025
Merged via the queue into main with commit 0faba3c Dec 2, 2025
6 checks passed
finbarrtimbers added a commit that referenced this pull request Dec 3, 2025
After PR #1225 moved reward_fn to live inside LLMRayActor, these
references were left behind during the merge. This removes:
- reward_fn parameter from maybe_evaluate() and run_training()
- reward_fn from accumulate_inference_batches() calls
- reward_fn from create_model_and_optimizer return value unpacking
- Unused Callable import

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>