Handle granite 4 as MoE models in training #669
base: main
Conversation
Walkthrough: Extended MoE handling, adding detection for granitemoehybrid (Granite 4) models.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Main as main_ds.py
    participant Model as model.py
    participant Utils as utils.py
    participant Loss as batch_loss_manager.py
    Main->>Model: Initialize(model_path)
    Model->>Model: detect is_gpt_oss / is_granitemoehybrid (is_known_model)
    Main->>Utils: freeze_router_params(model)
    alt Router params frozen
        Utils-->>Main: return True
        Main->>Main: set fsdp_use_orig_params / log frozen router
    else No router params frozen
        Utils-->>Main: return False
    end
    Main->>Model: forward batch
    Model-->>Loss: pass outputs (may include aux_loss)
    Loss->>Loss: _compute_average_loss()
    alt accumulated_aux_loss not None
        Loss->>Loss: add accumulated_aux_loss to total_batch_loss (no model flag gate)
    end
    Loss-->>Main: return reduced loss
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Actionable comments posted: 0
🧹 Nitpick comments (2)
src/instructlab/training/main_ds.py (1)
349-349: Consider clarifying or removing the "NOTE is this guard needed?" comment.

The guard `if m.is_gpt_oss or m.is_granitemoehybrid:` appears necessary to ensure that router parameter freezing and `fsdp_should_use_orig_params` configuration only apply to MoE models. Without this guard, non-MoE models would unnecessarily go through the router freezing logic, and `freeze_router_params` would always return `False` for them (since they have no router parameters).

If the guard is intentional and necessary, consider removing the "NOTE is this guard needed?" comment to avoid confusion. If there's genuine uncertainty about whether this guard is required, please clarify the intended behavior.
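For reference, a minimal sketch of the guarded call site this comment refers to, using the names cited in the review (`freeze_router_params`, `fsdp_should_use_orig_params`); the wrapper function and logging below are illustrative, not the actual `main_ds.py` code:

```python
# Illustrative sketch only: names follow the review; main_ds.py's real structure differs.
import logging

from instructlab.training.utils import freeze_router_params

logger = logging.getLogger(__name__)


def configure_moe_router_freezing(m) -> bool:
    """Freeze router params for MoE models; return whether FSDP needs use_orig_params."""
    fsdp_should_use_orig_params = False

    # Guard under discussion: only MoE models (GPT-OSS, granitemoehybrid) have routers.
    if m.is_gpt_oss or m.is_granitemoehybrid:
        # freeze_router_params returns True only when parameters were actually frozen.
        # (The real code may pass the underlying HF module rather than the wrapper.)
        if freeze_router_params(m):
            # Mixing frozen and trainable params in one FSDP unit generally requires
            # use_orig_params=True, hence the flag.
            fsdp_should_use_orig_params = True
            logger.info("Router parameters frozen for MoE training")

    return fsdp_should_use_orig_params
```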
src/instructlab/training/model.py (1)
421-428: Consider clarifying the guard necessity.

The guard `if (self.is_gpt_oss or self.is_granitemoehybrid)` before checking for `output.aux_loss` appears to be an optimization to avoid unnecessary `hasattr` checks on non-MoE models. However, the guard may be redundant since the subsequent checks (`hasattr(output, "aux_loss") and output.aux_loss is not None`) would safely handle non-MoE models anyway.

The "NOTE is this guard needed?" comment suggests uncertainty. Consider either of the following (see the sketch after this list):
- Removing the guard if the `hasattr` and `is not None` checks are sufficient
- Removing the comment if the guard provides meaningful performance benefits or code clarity
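To make the two options concrete, here is a small, self-contained sketch of the aux-loss handling being discussed, with the guard kept but clearly separable; this is an assumption-laden illustration, not the project's actual `model.py` method:

```python
# Illustrative helper, not the project's actual method; attribute names follow the review.
def add_aux_loss(output, primary_loss, is_gpt_oss: bool, is_granitemoehybrid: bool):
    """Return primary_loss plus the router auxiliary loss when one is present."""
    # Guard in question: skips the attribute checks for non-MoE models.
    if is_gpt_oss or is_granitemoehybrid:
        # These checks alone would also handle non-MoE outputs safely,
        # which is why the guard may be redundant.
        if hasattr(output, "aux_loss") and output.aux_loss is not None:
            return primary_loss + output.aux_loss.float()
    return primary_loss
```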
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- src/instructlab/training/batch_loss_manager.py (1 hunks)
- src/instructlab/training/gpt_oss_utils_correct.py (1 hunks)
- src/instructlab/training/main_ds.py (1 hunks)
- src/instructlab/training/model.py (4 hunks)
- src/instructlab/training/utils.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/instructlab/training/model.py (1)
src/instructlab/training/gpt_oss_utils_correct.py (2)
- is_gpt_oss (397-411)
- is_known_model (414-429)
src/instructlab/training/main_ds.py (2)
src/instructlab/training/gpt_oss_utils_correct.py (1)
- is_gpt_oss (397-411)

src/instructlab/training/utils.py (1)
- freeze_router_params (903-926)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: unit: 3.13 on ubuntu-latest
- GitHub Check: unit: 3.11 on ubuntu-latest
- GitHub Check: unit: 3.12 on ubuntu-latest
- GitHub Check: pylint
- GitHub Check: Summary
🔇 Additional comments (6)
src/instructlab/training/gpt_oss_utils_correct.py (1)
414-430: LGTM! Well-structured utility function.

The `is_known_model` function provides a clean generalization of the `is_gpt_oss` pattern, enabling support for multiple model types including granitemoehybrid. The implementation properly handles both string and list inputs for `known_model_type`, and follows the same validation pattern as `is_gpt_oss`.
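A minimal sketch of what `is_known_model` is described as doing, assuming it reads `model_type` from the checkpoint's `config.json`; the real implementation in `gpt_oss_utils_correct.py` may validate inputs differently:

```python
# Sketch under assumptions: the real is_known_model may differ in validation details.
import json
import os


def is_known_model(model_path: str, known_model_type) -> bool:
    """Return True if config.json declares one of the given model_type values."""
    # Accept a single model_type string or a list of them.
    if isinstance(known_model_type, str):
        known_model_type = [known_model_type]

    config_path = os.path.join(model_path, "config.json")
    if not os.path.isfile(config_path):
        return False

    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    return config.get("model_type") in known_model_type
```

With this shape, the GPT-OSS check could delegate to something like `is_known_model(path, "gpt_oss")`, though that is an assumption about the existing helper rather than a confirmed detail.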
src/instructlab/training/utils.py (1)

903-927: LGTM! Improved return semantics for MoE models.

The updated `freeze_router_params` function now correctly returns `True` only when router parameters were actually frozen, rather than always returning `True`. The docstring and log messages have been appropriately generalized from "GPT-OSS" to "MoE models," aligning with the PR's objective to support granitemoehybrid MoE models.

The caller in `main_ds.py` (lines 350-355) properly handles the new return value.
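A hedged sketch of the return semantics described here; the parameter-name matching below is an assumption, since the review only tells us that the function freezes router parameters and reports whether any were frozen:

```python
# Sketch: the name-matching heuristic below is assumed, not taken from utils.py.
import logging

logger = logging.getLogger(__name__)


def freeze_router_params(model) -> bool:
    """Freeze MoE router parameters; return True only if any were actually frozen."""
    frozen = 0
    for name, param in model.named_parameters():
        if "router" in name:  # assumed naming convention for router/gating weights
            param.requires_grad = False
            frozen += 1

    if frozen:
        logger.info("Froze %d router parameters for MoE training", frozen)
        return True
    # Non-MoE models reach this point: nothing to freeze, so the caller can skip
    # the FSDP use_orig_params adjustment.
    return False
```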
src/instructlab/training/batch_loss_manager.py (1)

177-178: LGTM! Generalized auxiliary loss handling.

Removing the `is_gpt_oss` gate and applying auxiliary loss whenever `accumulated_aux_loss is not None` correctly generalizes the logic to support both GPT-OSS and granitemoehybrid MoE models. This change aligns with the broader PR objective to extend MoE support beyond GPT-OSS.
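In sketch form, the generalized accumulation described above looks roughly like this (simplified; the actual batch-loss bookkeeping in the manager is more involved):

```python
# Simplified sketch of the aux-loss accumulation pattern; not the actual class.
import torch


def reduce_batch_loss(microbatch_losses, accumulated_aux_loss=None):
    """Sum per-microbatch losses and fold in any accumulated router aux loss."""
    total_batch_loss = torch.stack(microbatch_losses).sum()

    # No is_gpt_oss gate: any MoE model that produced an auxiliary loss contributes it.
    if accumulated_aux_loss is not None:
        total_batch_loss = total_batch_loss + accumulated_aux_loss

    return total_batch_loss
```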
src/instructlab/training/main_ds.py (1)

349-355: LGTM! Correctly extends MoE support to granitemoehybrid.

The gate expansion to include `is_granitemoehybrid` properly extends router parameter freezing to support granite 4 MoE models. The code correctly captures and uses the return value from `freeze_router_params` to conditionally set `fsdp_should_use_orig_params`, which aligns with the updated return semantics in `utils.py`.
src/instructlab/training/model.py (2)

46-46: LGTM! Proper initialization of granitemoehybrid detection.

The import of `is_known_model` and initialization of `self.is_granitemoehybrid` using `is_known_model(model_path, "granitemoehybrid")` correctly follows the established pattern used for `is_gpt_oss`. This enables proper detection of granite 4 MoE models throughout the training pipeline.

Also applies to: 68-68
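The detection pattern being approved, shown as a stripped-down sketch (the class and constructor here are illustrative, not the project's actual model class):

```python
# Illustrative constructor; only the two detection lines mirror what the review describes.
from instructlab.training.gpt_oss_utils_correct import is_gpt_oss, is_known_model


class MoEAwareModel:
    def __init__(self, model_path: str):
        self.model_path = model_path
        # Existing GPT-OSS detection.
        self.is_gpt_oss = is_gpt_oss(model_path)
        # New: Granite 4 MoE detection via the HF model_type "granitemoehybrid".
        self.is_granitemoehybrid = is_known_model(model_path, "granitemoehybrid")
```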
433-434: LGTM! Simplified auxiliary loss application.

Removing the `is_gpt_oss` check and applying auxiliary loss whenever `aux_loss is not None` correctly generalizes the logic for all MoE models. This change is consistent with the similar update in `batch_loss_manager.py` (line 177) and aligns with the PR's objective to support granitemoehybrid MoE models.
The first failure is
The second through fifth failures are due to a missing EC2 credential.
@mergify rebase

✅ Branch has been successfully rebased
e8f8922 to c97c9ba
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- src/instructlab/training/batch_loss_manager.py (1 hunks)
- src/instructlab/training/gpt_oss_utils_correct.py (1 hunks)
- src/instructlab/training/main_ds.py (1 hunks)
- src/instructlab/training/model.py (4 hunks)
- src/instructlab/training/utils.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/instructlab/training/model.py (1)
src/instructlab/training/gpt_oss_utils_correct.py (2)
- is_gpt_oss (397-411)
- is_known_model (414-433)
src/instructlab/training/main_ds.py (2)
src/instructlab/training/gpt_oss_utils_correct.py (1)
- is_gpt_oss (397-411)

src/instructlab/training/utils.py (1)
- freeze_router_params (903-926)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: unit: 3.12 on ubuntu-latest
- GitHub Check: unit: 3.13 on ubuntu-latest
- GitHub Check: unit: 3.11 on ubuntu-latest
- GitHub Check: pylint
- GitHub Check: Summary
🔇 Additional comments (7)
src/instructlab/training/utils.py (1)
903-927: LGTM! Improved return semantics and broader MoE support.

The updated docstring and return semantics accurately reflect the broader MoE context. Returning `True` only when router parameters are actually frozen (rather than unconditionally) is more precise and aligns well with the updated usage in `src/instructlab/training/main_ds.py` (line 351).

src/instructlab/training/model.py (4)
46-46: LGTM! Proper import of the new utility.

The import of `is_known_model` is appropriate for detecting granitemoehybrid models.
68-68: LGTM! New attribute for granitemoehybrid detection.

The new `is_granitemoehybrid` attribute follows the same pattern as `is_gpt_oss` and enables proper MoE model detection.
421-428: LGTM! Expanded MoE auxiliary loss gating.

The condition correctly includes both GPT-OSS and granitemoehybrid models when checking for auxiliary loss presence.
433-434: LGTM! Unconditional auxiliary loss application.

The removal of the GPT-OSS-specific gate aligns with the broader MoE support. The auxiliary loss is now applied whenever present, which is appropriate for multiple MoE model types.
src/instructlab/training/batch_loss_manager.py (1)
177-178: Auxiliary loss compatibility verified; the change is correct.

Both GPT-OSS and granitemoehybrid models extract `aux_loss` through the identical path in `model.py` (lines 421-427), converting it via `output.aux_loss.float()`. The unconditional check at lines 177-178 is valid because both model types produce `aux_loss` in the same format, and `accumulated_aux_loss` remains `0.0` when no auxiliary loss exists. No compatibility issues found.

Note: The comment at `model.py:420` referencing only GPT-OSS is outdated and should be updated to reflect that both MoE model types now produce auxiliary loss.
src/instructlab/training/main_ds.py (1)

349-355: Code changes verified and approved.

The model_type string "granitemoehybrid" is confirmed as correct for IBM Granite 4 MoE models in HuggingFace transformers. The logic correctly extends MoE router parameter freezing to granitemoehybrid models and appropriately uses the return value from `freeze_router_params`.
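As a quick way to sanity-check that model_type claim for a checkpoint, one can inspect its configuration with transformers; the path below is a placeholder, not a checkpoint used in this repository:

```python
# Illustrative check of the HF model_type for a checkpoint (placeholder path).
from transformers import AutoConfig

config = AutoConfig.from_pretrained("/path/to/granite-4-checkpoint")
print(config.model_type)  # expected to be "granitemoehybrid" for Granite 4 MoE models
```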
RobotSail left a comment
LGTM
Granite 4 models are MoE models. They have router parameters which should be frozen in the training phase, and the auxiliary losses should be accumulated. This PR applies the existing code for freezing the parameters and accumulating the losses for the GPT-OSS model to Granite 4 models as well.
Summary by CodeRabbit
- New Features
- Refactor
- Documentation