Implements an OLMo-core compatible data loader to load HF datasets from get_cached_dataset_tulu.
#1208
base: main
Conversation
💡 Codex Review
Here are some automated review suggestions for this pull request.
```python
percent_solved = np.mean(scores).item() / args.max_possible_score
# Don't resample prompt that was solved at more than no_resample_positive_rate
if no_resampling_pass_rate is not None and percent_solved >= no_resampling_pass_rate:
    iter_dataloader.exclude_index(result.dataset_index)
    total_no_resampled += 1
    logging.debug(
```
Enforce no_resampling_pass_rate when prompt is solved
The no_resampling_pass_rate flag is now effectively ignored. In accumulate_inference_batches, when a prompt’s average score exceeds the threshold (lines 1678‑1682) the code only increments total_no_resampled and logs, but no longer excludes that prompt from future sampling (the previous iter_dataloader.exclude_index(...) call was removed). With the flag set, solved prompts will keep being fed back through add_prompt_to_generator(next(data_loader), …) and resampled indefinitely instead of being retired, so training continues to spend compute on already-solved items.
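The retirement logic the review asks to restore can be sketched with a toy loader. Only `exclude_index` and the threshold check mirror the PR; `ToyLoader` itself is illustrative, not the real data loader:

```python
import numpy as np

class ToyLoader:
    """Toy stand-in for the PR's data loader: serves indices, skipping excluded ones."""
    def __init__(self, n):
        self.indices = list(range(n))
        self.excluded = set()

    def exclude_index(self, idx):
        self.excluded.add(idx)

    def next_batch(self):
        return [i for i in self.indices if i not in self.excluded]

no_resampling_pass_rate = 0.8
max_possible_score = 1.0
loader = ToyLoader(3)

# Prompt 1's samples were all solved; its pass rate crosses the threshold.
scores = [1.0, 1.0, 1.0]
percent_solved = np.mean(scores).item() / max_possible_score
if no_resampling_pass_rate is not None and percent_solved >= no_resampling_pass_rate:
    loader.exclude_index(1)  # retire the solved prompt instead of only logging it

print(loader.next_batch())  # prompt 1 is no longer sampled: [0, 2]
```

Without the `exclude_index` call, `next_batch` would keep returning prompt 1 and compute would be spent re-solving it, which is the regression described above.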
hamishivi left a comment:
Nice stuff! I like that this PR is actually reducing the code complexity a little. But a few questions before I'm happy with merging...
```python
@property
def total_batches(self) -> int:
    """Return the total number of batches in an epoch."""
```
Is this docstring right? Isn't effective_size the number of samples in the dataset, per line 38?
yes, you're right. Fixed!
```python
percent_solved = np.mean(scores).item() / args.max_possible_score
# Don't resample prompt that was solved at more than no_resample_positive_rate
if no_resampling_pass_rate is not None and percent_solved >= no_resampling_pass_rate:
    iter_dataloader.exclude_index(result.dataset_index)
```
was this change intentional?
Mistake! Fixed.
| """Return an iterable over all batches in the epoch.""" | ||
| for i in range(self.batches_processed, self.effective_size): | ||
| example = self.dataset[i] | ||
| yield example | {"prompt_id": f"{self._epoch}_{example['dataset_index']}"} |
Bug: batches_processed never incremented during iteration
The _iter_batches method yields items but never increments self.batches_processed. This breaks checkpointing because state_dict will always save batches_processed at its initial value (0 or whatever was restored). When resuming from a checkpoint, the data loader will restart from the beginning of the epoch instead of continuing from where it left off. The loop variable i should be used to update batches_processed after each yield, or batches_processed should be incremented in the __next__ method.
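A minimal sketch of the fix. The class here is a toy, not the PR's loader; the point is that `batches_processed` must advance as items are handed out so `state_dict()` captures mid-epoch progress (the counter is bumped before the `yield`, since code after a `yield` only runs on the *next* `next()` call):

```python
class CheckpointableLoader:
    """Toy loader: count each item as processed once it is handed out,
    so a checkpoint taken mid-epoch resumes where iteration left off."""
    def __init__(self, dataset):
        self.dataset = dataset
        self.batches_processed = 0

    def _iter_batches(self):
        for i in range(self.batches_processed, len(self.dataset)):
            self.batches_processed = i + 1  # the missing increment
            yield self.dataset[i]

    def state_dict(self):
        return {"batches_processed": self.batches_processed}

    def load_state_dict(self, state):
        self.batches_processed = state["batches_processed"]

loader = CheckpointableLoader(["a", "b", "c", "d"])
it = loader._iter_batches()
next(it); next(it)              # consume two items
ckpt = loader.state_dict()      # {"batches_processed": 2}

resumed = CheckpointableLoader(["a", "b", "c", "d"])
resumed.load_state_dict(ckpt)
remaining = list(resumed._iter_batches())
print(remaining)                # resumes at "c": ['c', 'd']
```

Without the increment, `ckpt` would always hold the initial value and `resumed` would replay the epoch from the start.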
```python
self.reshuffle()
self._current_iter = self._iter_batches()
return next(self._current_iter)
raise
```
Bug: Evaluation data loader exhausted after first use
When automatic_reshuffle=False (the default), after the first complete iteration through the dataset, _current_iter remains as an exhausted generator and is never reset. Subsequent calls to __next__ will immediately raise StopIteration without yielding any items. This breaks evaluation in grpo_fast.py where eval_data_loader is created without automatic_reshuffle=True and is iterated over multiple times during training. After the first evaluation completes, all subsequent evaluations will process zero examples because the iterator is exhausted.
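One way to address this, sketched with a toy class (names are illustrative, not the PR's API): rebuild the iterator on exhaustion even when `automatic_reshuffle=False`, so the loader still raises `StopIteration` to mark end-of-epoch but the next eval pass sees the full dataset again:

```python
class ReusableEvalLoader:
    """Toy loader: on exhaustion, reset the iterator instead of staying exhausted."""
    def __init__(self, dataset):
        self.dataset = dataset
        self._current_iter = iter(self.dataset)

    def __next__(self):
        try:
            return next(self._current_iter)
        except StopIteration:
            # The real loader reshuffles here when automatic_reshuffle=True; without
            # it, still rebuild the iterator so a later eval pass is not empty.
            self._current_iter = iter(self.dataset)
            raise  # signal end-of-epoch to the caller, but stay reusable

    def full_pass(self):
        out = []
        while True:
            try:
                out.append(next(self))
            except StopIteration:
                return out

loader = ReusableEvalLoader([1, 2, 3])
first = loader.full_pass()
second = loader.full_pass()   # non-empty because the iterator was rebuilt
print(first, second)          # [1, 2, 3] [1, 2, 3]
```

With the buggy behavior (no reset), `second` would be `[]`, which is exactly the "zero examples on later evaluations" failure described above.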
hamishivi left a comment:
don't think the fix got pushed for two comments 😅 I think the new Cursor Bugbot comments are a bit valid, and I left a comment on the automatic reshuffle!
| """Whether to filter out prompts with zero reward std (all samples have the same score).""" | ||
| no_resampling_pass_rate: float | None = None | ||
| """If the response to a prompt is solved at a rate higher than this, do not resample this prompt again""" | ||
| automatic_reshuffle: bool = True |
If this flag is false, then training just errors, right? Is there a reason to have it configurable?
There's not a ton of point to this, right now. But by making this change, we can subsequently:

- move `dpo_tune_cache.py` over to use this, which will (eventually) let that use the Olmo-core trainer
- merge to `main`. This will let us remove `add_prompt_to_generator` and have them operate fully asynchronously. This will make switching the entire `grpo_fast.py` script over to the data_loader/trainer pattern easier.

Runs:
Note

Adds an olmo_core-compatible HF dataset loader and rewires GRPO, queues, and vLLM utils to use prompt_id-based requests with checkpointable reshuffling.

- `DataLoaderBase` wrapper for HuggingFace `Dataset` with sharding, shuffling/automatic reshuffle, exclusion, checkpointable state, and `prompt_id` emission.
- Replaces `ShufflingIterator` and `PendingQueriesMap` with `HFDataLoader`-driven sampling and replenishment.
- Updates accumulation (`accumulate_inference_batches`) to fetch examples from `prompt_dataset` by `dataset_index` and handle no-resample/exclusion via the loader.
- Simplifies `add_prompt_to_generator` to enqueue a single-example `PromptRequest` using `dataset_index` and `prompt_id`.
- `PromptRequest`/`GenerationResult`: require `dataset_index`, add `prompt_id`, remove `epoch_number`/`training_step` fields.
- Request IDs become `train|eval_{prompt_id}`, keyed by `prompt_id`.
- Adds `test_data_loader.py`; refactors GRPO and vLLM utils tests to the new loader/queue semantics and prompt-based flow.
- Pins `ai2-olmo-core==2.3.0` to integrate with `DataLoaderBase`.

Written by Cursor Bugbot for commit 3107302.
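The `train|eval_{prompt_id}` request-id scheme above can be sketched as follows. The helper names are illustrative, not the PR's API; the `prompt_id` format `f"{epoch}_{dataset_index}"` comes from the `_iter_batches` diff quoted earlier:

```python
def make_request_id(split: str, epoch: int, dataset_index: int) -> str:
    """Build a request id like 'train_0_17' from the prompt_id f'{epoch}_{dataset_index}'."""
    assert split in ("train", "eval")
    prompt_id = f"{epoch}_{dataset_index}"
    return f"{split}_{prompt_id}"

def parse_request_id(request_id: str):
    """Recover (split, prompt_id) from a request id; prompt_id may itself contain '_'."""
    split, _, prompt_id = request_id.partition("_")
    return split, prompt_id

rid = make_request_id("train", 0, 17)
print(rid, parse_request_id(rid))  # train_0_17 ('train', '0_17')
```

Because the `prompt_id` carries the epoch, a request id stays unique even when the same `dataset_index` is resampled across epochs.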
DataLoaderBasewrapper for HuggingFaceDatasetwith sharding, shuffling/automatic reshuffle, exclusion, checkpointable state, andprompt_idemission.ShufflingIteratorandPendingQueriesMapwithHFDataLoader-driven sampling and replenishment.accumulate_inference_batches) to fetch examples fromprompt_datasetbydataset_indexand handle no-resample/exclusion via loader.add_prompt_to_generatorto enqueue single-examplePromptRequestusingdataset_indexandprompt_id.PromptRequest/GenerationResult: requiredataset_index, addprompt_id, removeepoch_number/training_stepfields.train|eval_{prompt_id}.prompt_id.test_data_loader.py; refactor GRPO and vLLM utils tests to new loader/queue semantics and prompt-based flow.ai2-olmo-core==2.3.0to integrate withDataLoaderBase.Written by Cursor Bugbot for commit 3107302. This will update automatically on new commits. Configure here.