Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion megatron/core/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class Fp8Recipe(str, enum.Enum):


class Fp4Recipe(str, enum.Enum):
    """FP4 recipe names: nvfp4, mxfp4."""

    # Members also subclass ``str``, so each one compares equal to its plain
    # string value (e.g. ``Fp4Recipe.nvfp4 == "nvfp4"``).
    nvfp4 = "nvfp4"
    mxfp4 = "mxfp4"
38 changes: 22 additions & 16 deletions megatron/core/fp4_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@
pass


# True during MXFP4 phase; set to False by callback at healing (monkey-patches get_fp4_recipe).
# NOTE(review): nothing in this view rebinds the flag -- confirm the callback
# assigns the module attribute rather than a name imported from this module.
_mxfp4_phase = True


def is_mxfp4_phase() -> bool:
    """Return True when in MXFP4 phase; False after healing switches to FP8 DS or MXFP8.

    The module-level flag is looked up each time the function is called, so a
    later rebinding of ``_mxfp4_phase`` on this module is reflected by the
    next call.

    Returns:
        The current value of the module-level ``_mxfp4_phase`` flag.
    """
    return _mxfp4_phase


# Check if Transformer Engine has class for fp4 tensors.
HAVE_TE_FP4_TENSOR_CLASS = False
if HAVE_TE:
Expand Down Expand Up @@ -59,26 +68,23 @@ def dequantize_fp4_tensor(fp4_tensor: torch.Tensor) -> torch.Tensor:
from megatron.core import parallel_state

def get_fp4_recipe(config: TransformerConfig):
    """Return fp4 recipe. Can be monkey-patched by callback at healing.

    Args:
        config: Transformer configuration; only ``config.fp4_recipe`` is read.

    Returns:
        A freshly constructed Transformer Engine recipe object matching
        ``config.fp4_recipe``.

    Raises:
        ValueError: If ``config.fp4_recipe`` is not a supported recipe, if the
            installed Transformer Engine is older than the recipe requires, or
            if the recipe class is missing from
            ``transformer_engine.common.recipe``.
    """
    # Map the requested recipe to its TE class name and minimum TE version.
    # NOTE(review): the MXFP4BlockScaling class name and the 2.8.0 floor come
    # from the change under review -- confirm against the Transformer Engine
    # build this is meant to run with.
    if config.fp4_recipe == Fp4Recipe.nvfp4:
        recipe_name, min_te_version = "NVFP4BlockScaling", "2.7.0.dev0"
    elif config.fp4_recipe == Fp4Recipe.mxfp4:
        recipe_name, min_te_version = "MXFP4BlockScaling", "2.8.0"
    else:
        raise ValueError(
            f"Unsupported FP4 recipe: {config.fp4_recipe}. "
            "Supported: nvfp4, mxfp4."
        )

    if not is_te_min_version(min_te_version):
        raise ValueError(
            f"{recipe_name} requires TransformerEngine >= {min_te_version}."
        )

    # A build can satisfy the version check and still lack the recipe class;
    # report that as ValueError (as the earlier implementation did) instead of
    # leaking a bare AttributeError to the caller.
    try:
        recipe_cls = getattr(transformer_engine.common.recipe, recipe_name)
    except AttributeError as err:
        raise ValueError(
            f"{recipe_name} recipe is not available in this version of "
            f"Transformer Engine. Please make sure you are using TE version "
            f">= {min_te_version}."
        ) from err
    return recipe_cls()

Expand Down
Loading