Changes from all commits (39 commits)
3808f13  Add more loss calibration comments for ce loss (jayhenry, Oct 23, 2025)
50a0179  [ENV] add XTUNER_USE_CUTLASS_GROUP_GEMM (jayhenry, Oct 23, 2025)
3706795  fix lint (jayhenry, Oct 24, 2025)
4eee27e  refine debug skip save (jayhenry, Oct 24, 2025)
f733fb8  1) register debug grad_hook to loss and final logits. 2) eager ce los… (jayhenry, Oct 27, 2025)
7835a3b  fake transformers module to run lower version 4.51.0 (jayhenry, Oct 27, 2025)
c5a369e  modify moe_grouped_gemm's trans_b default arg value (jayhenry, Oct 27, 2025)
e2674b2  add more debug_acc breakpoint (jayhenry, Oct 27, 2025)
fcb5a78  add rope comment (jayhenry, Oct 27, 2025)
3db85e6  1) add AccProber.before_grad_clip. 2) remove resolve pad token config… (jayhenry, Oct 27, 2025)
ab4f8de  fix qwen3 pad token id to None (jayhenry, Oct 27, 2025)
efbfdc5  AccProber dump forward activation records (jayhenry, Oct 27, 2025)
dcf5d6d  acc prober add lm_head module and origin ce_loss (jayhenry, Oct 27, 2025)
7324e14  final logits use fp32 (jayhenry, Oct 28, 2025)
2fa3529  turn off clip_grad_norm (jayhenry, Oct 28, 2025)
432f95e  add base prober class (jayhenry, Oct 28, 2025)
64d1755  improve base prober (jayhenry, Oct 28, 2025)
7611c55  add acc prober to prober list (jayhenry, Oct 28, 2025)
54ba5a8  rename prober module (jayhenry, Oct 28, 2025)
70198c0  prober add layer hook (jayhenry, Oct 28, 2025)
42877d9  add prober to trainer args (jayhenry, Oct 28, 2025)
11eba2e  acc prober add hooks for attn block (jayhenry, Oct 29, 2025)
1b2244a  acc prober add moe block (jayhenry, Oct 29, 2025)
24cdfbd  refine liger loss calibration (jayhenry, Oct 30, 2025)
b8e736d  acc prober add balancing loss (jayhenry, Oct 30, 2025)
ff3464f  acc prober add z_loss (jayhenry, Oct 30, 2025)
7a4e0b2  add grad_norm_dtype (jayhenry, Oct 30, 2025)
cba1365  move cal_grad_norm to utils (jayhenry, Oct 30, 2025)
1f5d7fa  rollback unnecessary changes (jayhenry, Oct 31, 2025)
11869f9  remove DEBUG_ACC, use PdbProber instead (jayhenry, Oct 31, 2025)
54052d3  set qwen model's pad token id to None (jayhenry, Oct 31, 2025)
8076da1  fix unit test and lint (jayhenry, Oct 31, 2025)
0337398  fix ut (jayhenry, Oct 31, 2025)
424958d  old prober and new prober mix to check old accuracy profile (jayhenry, Oct 31, 2025)
d24f2aa  All prober use new wrapper format (jayhenry, Oct 31, 2025)
92df4cc  fix lint (jayhenry, Oct 31, 2025)
0be74a6  fix trainer config grad_norm_dtype (jayhenry, Oct 31, 2025)
f4e0319  [Fix] set deepseek v3 pad token id = None (jayhenry, Nov 4, 2025)
4eb6bb9  fix rebase typo (jayhenry, Nov 4, 2025)
4 changes: 2 additions & 2 deletions tests/train/test_trainer.py
@@ -38,7 +38,7 @@ def __init__(self):
self.grad_norm_calls = 0
self.optimizer_step_calls = 0

model = nn.Linear(10, 10)
self.model = model = nn.Linear(10, 10)
self.optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

def grad_accumulation_steps(self, *args, **kwargs):
@@ -63,7 +63,7 @@ def step_optimizer(self, *args, **kwargs):
self.optimizer_step_calls += 1
return 1.0

def clip_grad_norm(self):
def clip_grad_norm(self, do_clip: bool=True, dtype=torch.float32):
self.grad_norm_calls += 1
return torch.tensor(1.0)

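The engine diff below threads several `ProberList` hooks through `train_step` and `clip_grad_norm` (`set_micro_batch_iter`, `after_micro_iter_forward`, `before_clip_grad_norm`, `after_clip_grad_norm`, plus `record_tensor` in the loss files further down). The actual `xtuner.v1.profiler.prober` module is not part of this diff, so the sketch below is only a guess at the minimal dispatcher surface those call sites assume; every name in it should be treated as hypothetical.

```python
# Hypothetical minimal ProberList surface inferred from the call sites in this PR;
# the real xtuner/v1/profiler/prober.py is not shown in this diff.
from typing import Any, ClassVar, List

import torch


class ProberList:
    """Dispatches training-loop events to all registered probers."""

    _probers: ClassVar[List[Any]] = []
    _micro_batch_iter: ClassVar[int] = 0

    @classmethod
    def register(cls, prober: Any) -> None:
        cls._probers.append(prober)

    @classmethod
    def set_micro_batch_iter(cls, i: int) -> None:
        # Called once per intra-layer micro-batch so records can be keyed by iteration.
        cls._micro_batch_iter = i

    @classmethod
    def after_micro_iter_forward(cls) -> None:
        # Called after loss.backward() so recomputed (checkpointed) activations are captured too.
        for p in cls._probers:
            p.dump_forward_records(cls._micro_batch_iter)

    @classmethod
    def before_clip_grad_norm(cls) -> None:
        for p in cls._probers:
            p.before_clip_grad_norm()

    @classmethod
    def after_clip_grad_norm(cls, grad_norm: torch.Tensor) -> None:
        for p in cls._probers:
            p.after_clip_grad_norm(grad_norm)

    @classmethod
    def record_tensor(cls, tensor: torch.Tensor, name: str) -> None:
        for p in cls._probers:
            p.record_tensor(tensor.detach(), name)
```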
95 changes: 25 additions & 70 deletions xtuner/v1/engine/train_engine.py
@@ -3,13 +3,12 @@
import threading
from concurrent.futures import wait
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, cast
from typing import Any, Dict, Optional, cast

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from safetensors import safe_open
from torch.distributed._tensor import DTensor, Replicate, Shard
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
get_model_state_dict,
@@ -18,10 +17,9 @@
set_optimizer_state_dict,
)
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import Placement
from torch.nn.utils.clip_grad import _no_grad
from torch.utils._foreach_utils import (
_device_has_foreach_support,
_has_foreach_support,
)

from xtuner.v1.config import FSDPConfig, OptimConfig
@@ -30,7 +28,9 @@
from xtuner.v1.model.base import BaseModel, ModelItem, TransformerConfig
from xtuner.v1.model.utils import ModelForwardExtraLogInfo
from xtuner.v1.module.router import NoAuxRouterConfig
from xtuner.v1.profiler.prober import ProberList
from xtuner.v1.utils import get_device, get_logger, get_torch_device_module, profile_time_and_memory
from xtuner.v1.utils.grad_norm import cal_grad_norm


logger = get_logger()
@@ -230,7 +230,10 @@ def train_step(self, data_batches: list[ModelItem]):
self._count += 1

train_engine_extra_info = ModelForwardExtraLogInfo()
micro_batch_iter = 0
for i in range(0, len(data_batches), intra_layer_micro_batch):
ProberList.set_micro_batch_iter(micro_batch_iter)
micro_batch_iter += 1
data_batch = data_batches[i : i + intra_layer_micro_batch]
seq_ctx_list = []
loss_ctx_list = []
@@ -282,6 +285,8 @@

del output
loss.backward()
# call dump_forward_records after backward to record the recomputed activations
ProberList.after_micro_iter_forward()
step_loss += loss.detach().clone()

if moe_need_update_bias:
@@ -315,75 +320,25 @@ def from_hf(self, hf_path: str | Path, strict: bool = False):
def init_model_weights(self):
self.model.init_weights()

@staticmethod
def group_tensors_by_device_mesh_and_placements(tensors: List[torch.Tensor]):
grouped_tensors: Dict[Tuple[DeviceMesh, Tuple[Placement, ...]], List[torch.Tensor]] = {}
for tensor in tensors:
assert isinstance(tensor, DTensor)
key = (tensor.device_mesh, tensor.placements)
if key in grouped_tensors:
grouped_tensors[key].append(tensor)
else:
grouped_tensors[key] = [tensor]
return grouped_tensors

def cal_total_norm(self, tensors: List[DTensor], norm_type: float = 2.0, foreach: Optional[bool] = None):
norm_type = float(norm_type)
if len(tensors) == 0:
return torch.tensor(0.0)

device_mesh: DeviceMesh = tensors[0].device_mesh
placements = tensors[0].placements
device = tensors[0].device
norms: Tuple[DTensor, ...]
if (foreach is None and _has_foreach_support(tensors, device)) or ( # type: ignore
foreach and _device_has_foreach_support(device)
):
norms = torch._foreach_norm(tensors, norm_type) # type: ignore
elif foreach:
raise RuntimeError(f"foreach=True was passed, but can't use the foreach API on {device.type} tensors")
else:
norms = tuple(torch.linalg.vector_norm(g, norm_type) for g in tensors)

local_norm = torch.linalg.vector_norm(
torch.stack([norm.to_local() for norm in norms]), norm_type, dtype=torch.float32
)
if norm_type == 2:
local_norm_squared = local_norm**2
for i, placement in enumerate(placements):
if isinstance(placement, Shard):
# When using ep + fsdp, the placement corresponding to fsdp mesh is _StridedShard
# isinstance(_StridedShard, Shard) is True
dist.all_reduce(local_norm_squared, group=device_mesh.get_group(i))
elif isinstance(placement, Replicate):
pass
else:
raise ValueError(f"Unsupported placement type {placement} in clip_grad_norm")
global_norm = local_norm_squared**0.5
else:
raise NotImplementedError
return global_norm

def clip_grad_norm(self):
@_no_grad
def clip_grad_norm(self, do_clip: bool = True, dtype=torch.float32):
ProberList.before_clip_grad_norm()
self.model.scale_and_reduce_grad()
params = self.model.trainable_parameters()
grads = [p.grad for _, p in params if p.grad is not None]
grouped_grads = self.group_tensors_by_device_mesh_and_placements(grads)
total_norms = []
for grads in grouped_grads.values():
total_norm = self.cal_total_norm(grads, norm_type=2.0, foreach=True)
total_norms.append(total_norm)
grad_norm = torch.linalg.vector_norm(torch.stack(total_norms), ord=2.0, dtype=torch.float32)
clip_coef = self.optim_cfg.max_grad_norm / (grad_norm + 1e-6)
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
for grads in grouped_grads.values():
device = grads[0].device
if _device_has_foreach_support(device):
torch._foreach_mul_(grads, clip_coef_clamped.to(device))
else:
clip_coef_clamped_device = clip_coef_clamped.to(device)
for g in grads:
g.mul_(clip_coef_clamped_device)
grad_norm, grouped_grads = cal_grad_norm(grads, dtype=dtype)
if do_clip:
clip_coef = self.optim_cfg.max_grad_norm / (grad_norm + 1e-6)
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
for grads in grouped_grads.values():
device = grads[0].device
if _device_has_foreach_support(device):
torch._foreach_mul_(grads, clip_coef_clamped.to(device))
else:
clip_coef_clamped_device = clip_coef_clamped.to(device)
for g in grads:
g.mul_(clip_coef_clamped_device)
ProberList.after_clip_grad_norm(grad_norm)
return grad_norm

def step_optimizer(self, grad_norm):
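The hunk above replaces the engine's inline DTensor norm logic with `cal_grad_norm` from `xtuner.v1.utils.grad_norm` (commit cba1365, "move cal_grad_norm to utils"). That helper's body is not shown in this diff; a minimal sketch consistent with the removed code and the call site `grad_norm, grouped_grads = cal_grad_norm(grads, dtype=dtype)` might look like this, with all details taken as assumptions rather than the actual implementation:

```python
# Hypothetical sketch of cal_grad_norm, reconstructed from the removed engine code above.
from typing import Dict, List, Tuple

import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Shard
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import Placement


def cal_grad_norm(
    grads: List[torch.Tensor],
    dtype: torch.dtype = torch.float32,
) -> Tuple[torch.Tensor, Dict[Tuple[DeviceMesh, Tuple[Placement, ...]], List[torch.Tensor]]]:
    """Return the global 2-norm of DTensor grads plus the grouping used to clip them."""
    # Group by (device_mesh, placements) so each group is reduced over the right process groups.
    grouped: Dict[Tuple[DeviceMesh, Tuple[Placement, ...]], List[torch.Tensor]] = {}
    for g in grads:
        assert isinstance(g, DTensor)
        grouped.setdefault((g.device_mesh, g.placements), []).append(g)

    group_norms = []
    for (mesh, placements), group in grouped.items():
        local = torch.linalg.vector_norm(
            torch.stack([torch.linalg.vector_norm(g.to_local(), 2.0) for g in group]),
            2.0,
            dtype=dtype,
        )
        local_sq = local**2
        for i, placement in enumerate(placements):
            if isinstance(placement, Shard):
                # _StridedShard (ep + fsdp) is a Shard subclass, so it is summed here as well.
                dist.all_reduce(local_sq, group=mesh.get_group(i))
        group_norms.append(local_sq**0.5)

    grad_norm = torch.linalg.vector_norm(torch.stack(group_norms), ord=2.0, dtype=dtype)
    return grad_norm, grouped
```

Returning the grouping alongside the norm lets `clip_grad_norm` reuse it for the per-group `torch._foreach_mul_` clipping shown above without regrouping the gradients.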
5 changes: 5 additions & 0 deletions xtuner/v1/loss/base_loss_ctx.py
@@ -50,9 +50,12 @@ class BaseLossKwargs(BaseModel):

def chunk(self, chunk_size) -> list["BaseLossKwargs"]:
tensor_fields: dict[str, tuple[torch.Tensor, ...]] = {}
nontensor_fields: dict[str, Any] = {}
for field_name, field_value in self.__dict__.items():
if isinstance(field_value, torch.Tensor):
tensor_fields[field_name] = torch.split(field_value, chunk_size, dim=1)
else:
nontensor_fields[field_name] = field_value

assert len(tensor_fields) > 0, "At least one field should be a tensor to chunk."

@@ -62,6 +65,8 @@ def chunk(self, chunk_size) -> list["BaseLossKwargs"]:
chunk_dict = {}
for field_name, splits in tensor_fields.items():
chunk_dict[field_name] = splits[i]
for field_name, field_value in nontensor_fields.items():
chunk_dict[field_name] = field_value
chunks.append(type(self)(**chunk_dict))
return chunks

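With the change above, `chunk` now copies non-tensor fields (such as the `global_denominator` added to `CELossKwargs` in the next file) into every chunk instead of dropping them. A small usage sketch with illustrative values, assuming `CELossKwargs` can be constructed directly as shown:

```python
import torch

from xtuner.v1.loss.ce_loss import CELossKwargs  # subclass of BaseLossKwargs, shown below

kwargs = CELossKwargs(
    shifted_labels=torch.randint(0, 10, (1, 8)),
    loss_weight=torch.full((1, 8), 1.0 / 8),
    global_denominator=8,
)
chunks = kwargs.chunk(chunk_size=4)

assert len(chunks) == 2
# Tensor fields are split along dim=1; the non-tensor field is replicated into each chunk.
assert chunks[0].shifted_labels.shape == (1, 4)
assert all(c.global_denominator == 8 for c in chunks)
```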
20 changes: 16 additions & 4 deletions xtuner/v1/loss/ce_loss.py
@@ -11,6 +11,7 @@

from xtuner.v1.loss import BaseLossConfig, BaseLossContext, BaseLossKwargs

# from xtuner.v1.profiler.prober import ProberList
from .utils import sp_gather, sp_split


@@ -49,6 +50,7 @@ class CELossKwargs(BaseLossKwargs):

shifted_labels: torch.Tensor
loss_weight: torch.Tensor
global_denominator: int


class CELossContextInputItem(BaseModel):
@@ -158,6 +160,7 @@ def build_batches_loss_kwargs(
loss_kwargs = CELossKwargs(
shifted_labels=shifted_labels,
loss_weight=loss_weight,
global_denominator=int(global_denominator),
)
batches_loss_kwargs.append(loss_kwargs)
return batches_loss_kwargs
@@ -188,6 +191,13 @@ def loss_fn(
# Step 2.b in the loss calculation: sum the loss over all tokens
loss = (loss * loss_weight).sum()

# Below is the old implementation of loss calibration for accuracy regression
# loss = F.cross_entropy(logits, shifted_labels, reduction="sum", ignore_index=self.loss_cfg.ignore_idx)
# ProberList.record_tensor(loss, "[lm_head.ce_loss][before calibration]loss")
# register_grad_hook(loss, "loss")
# print(f"loss_kwargs.global_denominator: {loss_kwargs.global_denominator}")
# loss = loss / loss_kwargs.global_denominator

return loss, (logits, {})

def chunk_mode(
@@ -202,15 +212,17 @@ def chunk_mode(
else:
assert self.liger_loss_fct is not None, "liger_loss_fct must be initialized in liger mode"
shifted_labels = loss_kwargs.shifted_labels # (bs, seq_len)
loss_weight = loss_kwargs.loss_weight # (bs, seq_len)
# loss_weight = loss_kwargs.loss_weight # (bs, seq_len)

bs, seq, dim = hidden_states.shape
hidden_states = hidden_states.reshape(bs * seq, dim)
shifted_labels = shifted_labels.flatten()
# liger kernel doesn't support reduction=="none"
# Step 2.b in the loss calculation: sum the loss over all tokens, then multiply by the loss weight (i.e. divide by the global_denominator)
loss = self.liger_loss_fct(head_weight, hidden_states, shifted_labels)
mask = loss_weight != 0
w = loss_weight.sum() / mask.sum() # equal to the global_denominator
loss = loss * w
# ProberList.record_tensor(loss, "[lm_head.ce_loss][before calibration]loss")
# mask = loss_weight != 0
# w = loss_weight.sum() / mask.sum() # w equals to 1/global_denominator
# loss = loss * w
loss = loss / loss_kwargs.global_denominator
return loss, (None, {})
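Both paths above target the same calibration: per-token cross-entropy summed over tokens and divided by one global denominator (total unmasked tokens across ranks and micro-batches), rather than a per-micro-batch mean. A toy check of that equivalence, assuming `loss_weight` equals `1 / global_denominator` for every unmasked token (the usual global-token-mean setup); all values here are made up:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(6, 32)        # 6 tokens, vocabulary of 32
labels = torch.randint(0, 32, (6,))
global_denominator = 24            # e.g. total unmasked tokens across all ranks / micro-batches

per_token = F.cross_entropy(logits, labels, reduction="none")
loss_weight = torch.full((6,), 1.0 / global_denominator)

eager_path = (per_token * loss_weight).sum()                                   # Step 2.b in loss_fn
chunk_path = F.cross_entropy(logits, labels, reduction="sum") / global_denominator  # chunk_mode path
assert torch.allclose(eager_path, chunk_path)
```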
6 changes: 5 additions & 1 deletion xtuner/v1/loss/moe_loss.py
@@ -76,7 +76,10 @@ def forward(self, router_logits, n_routed_experts, num_experts_per_tok):
routing_weights_mean_global = routing_weights.mean(dim=1)
loss = scale_global * (tokens_per_expert_global * routing_weights_mean_global).sum(-1)
loss = loss.sum()

# from xtuner.v1.profiler.prober import ProberList
# ProberList.record_tensor(routing_weights_mean_global, "[balancing_loss][after]routing_weights_mean_global")
# ProberList.record_tensor(tokens_per_expert_global, "[balancing_loss][after]tokens_per_expert_global")
# ProberList.record_tensor(scale_global, "[balancing_loss][after]scale_global")
return loss * self.loss_weight


@@ -92,6 +95,7 @@ def z_loss(router_logits: torch.Tensor, global_average: bool = False):
unmasked_num_global = all_reduce(unmasked_num_rank, "sum", dist.group.WORLD) # type: ignore
world_size = dist.get_world_size()
z_loss = z_loss * unmasked_num * world_size / unmasked_num_global

return z_loss


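The `global_average` branch shown above rescales each rank's z-loss by `unmasked_num * world_size / unmasked_num_global`, so that a later average of the per-rank losses becomes a true mean over all unmasked tokens. A toy single-process check of that arithmetic; the rank counts and per-token values are made up, and the final averaging stands in for whatever cross-rank loss averaging happens downstream:

```python
# Two ranks with different numbers of unmasked tokens (illustrative values only).
per_rank_unmasked = [3, 5]
per_rank_token_zloss = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0, 8.0]]

world_size = len(per_rank_unmasked)
unmasked_num_global = sum(per_rank_unmasked)

rescaled = []
for n, vals in zip(per_rank_unmasked, per_rank_token_zloss):
    local_mean = sum(vals) / n                               # each rank's own per-token mean
    rescaled.append(local_mean * n * world_size / unmasked_num_global)

# Averaging the rescaled per-rank losses recovers the global per-token mean.
global_mean = sum(sum(v) for v in per_rank_token_zloss) / unmasked_num_global
assert abs(sum(rescaled) / world_size - global_mean) < 1e-9
```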
2 changes: 1 addition & 1 deletion xtuner/v1/model/base.py
@@ -56,7 +56,7 @@ class TransformerConfig(PydanticBaseModel):
vocab_size: Annotated[int, Parameter(group="model")]
max_position_embeddings: Annotated[int, Parameter(group="model")]
eos_token_id: Annotated[int, Parameter(group="model")]
pad_token_id: Annotated[int, Parameter(group="model")]
pad_token_id: Annotated[int | None, Parameter(group="model")] = None
num_hidden_layers: Annotated[int, Parameter(group="model")]
hidden_size: Annotated[int, Parameter(group="model")]
intermediate_size: Annotated[int, Parameter(group="model")]
4 changes: 2 additions & 2 deletions xtuner/v1/model/dense/qwen2.py
@@ -47,7 +47,7 @@ def from_hf(cls, hf_path: str | Path) -> Self:
hf_config=hf_config,
vocab_size=hf_config.vocab_size,
max_position_embeddings=hf_config.max_position_embeddings,
pad_token_id=hf_config.eos_token_id,
pad_token_id=hf_config.get("pad_token_id"),
bos_token_id=hf_config.bos_token_id,
eos_token_id=hf_config.eos_token_id,
num_hidden_layers=hf_config.num_hidden_layers,
@@ -103,7 +103,7 @@ class Qwen2Dense7BConfig(Qwen2DenseConfig):
vocab_size: int = 152064
max_position_embeddings: int = 32768
bos_token_id: int = 151643
pad_token_id: int = 151643 # eos_id
pad_token_id: int | None = None
eos_token_id: int = 151643 # eos_id
num_hidden_layers: int = 28
hidden_size: int = 3584
8 changes: 4 additions & 4 deletions xtuner/v1/model/dense/qwen3.py
@@ -47,7 +47,7 @@ def from_hf(cls, hf_path: str | Path) -> Self:
config = cls(
vocab_size=hf_config.vocab_size,
max_position_embeddings=hf_config.max_position_embeddings,
pad_token_id=hf_config.eos_token_id,
pad_token_id=hf_config.get("pad_token_id"),
bos_token_id=hf_config.bos_token_id,
eos_token_id=hf_config.eos_token_id,
num_hidden_layers=hf_config.num_hidden_layers,
@@ -101,7 +101,7 @@ def hf_config(self) -> HFQwen3DenseConfig:
class Qwen3Dense8BConfig(Qwen3DenseConfig):
vocab_size: int = 151936
max_position_embeddings: int = 40960
pad_token_id: int = 151643
pad_token_id: int | None = None
eos_token_id: int = 151645
bos_token_id: int = 151643
num_hidden_layers: int = 36
@@ -121,7 +121,7 @@ class Qwen3Dense8BConfig(Qwen3DenseConfig):
class Qwen3Dense4BConfig(Qwen3DenseConfig):
vocab_size: int = 151936
max_position_embeddings: int = 262144
pad_token_id: int = 151643
pad_token_id: int | None = None
eos_token_id: int = 151645
bos_token_id: int = 151643
num_hidden_layers: int = 36
@@ -141,7 +141,7 @@ class Qwen3Dense4BConfig(Qwen3DenseConfig):
class Qwen3Dense0P6BConfig(Qwen3DenseConfig):
vocab_size: int = 151936
max_position_embeddings: int = 40960
pad_token_id: int = 151643
pad_token_id: int | None = None
eos_token_id: int = 151645
bos_token_id: int = 151643
num_hidden_layers: int = 28
4 changes: 2 additions & 2 deletions xtuner/v1/model/moe/deepseek_v3.py
@@ -52,7 +52,7 @@ def to_hf_key_list(self, key: str) -> list[str]:
class DeepSeekV3Config(MoEConfig):
vocab_size: int = 129280
max_position_embeddings: int = 163840
pad_token_id: int = 1 # eos_id
pad_token_id: int | None = None
eos_token_id: int = 1
num_hidden_layers: int = 61
first_k_dense_replace: int = 3
@@ -110,7 +110,7 @@ def from_hf(cls, hf_path: str | Path) -> Self:
config = cls(
vocab_size=cfg.vocab_size,
max_position_embeddings=cfg.max_position_embeddings,
pad_token_id=cfg.eos_token_id,
pad_token_id=cfg.get("pad_token_id"),
eos_token_id=cfg.eos_token_id,
num_hidden_layers=cfg.num_hidden_layers,
first_k_dense_replace=cfg.first_k_dense_replace,
1 change: 1 addition & 0 deletions xtuner/v1/model/moe/moe.py
@@ -868,6 +868,7 @@ def patched_emb_forward(self, input):
w = self.weight.to_local()
else:
w = self.weight
# print(f"Embedding forward args: type(self)={type(self)}, padding_idx={self.padding_idx}, max_norm={self.max_norm}, norm_type={self.norm_type}, scale_grad_by_freq={self.scale_grad_by_freq}, sparse={self.sparse}, self={str(self)}")
return F.embedding(
input,
w,
9 changes: 6 additions & 3 deletions xtuner/v1/model/moe/qwen3.py
@@ -57,7 +57,7 @@ def from_hf(cls, hf_path: str | Path) -> Self:
config = cls(
vocab_size=hf_config.vocab_size,
max_position_embeddings=hf_config.max_position_embeddings,
pad_token_id=hf_config.eos_token_id,
pad_token_id=hf_config.get("pad_token_id"),
bos_token_id=hf_config.bos_token_id,
eos_token_id=hf_config.eos_token_id,
num_hidden_layers=hf_config.num_hidden_layers,
@@ -123,7 +123,10 @@ def hf_config(self) -> HFQwen3MoeConfig:
class Qwen3MoE30BA3Config(Qwen3MoEConfig):
vocab_size: int = 151936
max_position_embeddings: int = 40960
pad_token_id: int = 151643
# The Qwen3 models (dense and MoE) do not define a pad_token_id, so we set it to None here.
# When pad_token_id is None, the embedding module applies no special handling for the pad token.
# Note: the Qwen3 model's pad_token_id may differ from the Qwen tokenizer's pad_token_id.
pad_token_id: int | None = None
eos_token_id: int = 151645
bos_token_id: int = 151643
num_hidden_layers: int = 48
@@ -155,7 +158,7 @@ class Qwen3MoE30BA3Config(Qwen3MoEConfig):
class Qwen3MoE235BA22Config(Qwen3MoEConfig):
vocab_size: int = 151936
max_position_embeddings: int = 40960
pad_token_id: int = 151643
pad_token_id: int | None = None
eos_token_id: int = 151645
bos_token_id: int = 151643
num_hidden_layers: int = 94
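The comment added above notes that with `pad_token_id=None` the embedding applies no special handling for the pad token. A minimal standalone illustration of the difference (token id 2 stands in for a pad id; nothing here is Qwen-specific):

```python
import torch
import torch.nn as nn

emb_with_pad = nn.Embedding(10, 4, padding_idx=2)
emb_without_pad = nn.Embedding(10, 4, padding_idx=None)

# With padding_idx set, that row is zero-initialized and its gradient is masked out.
assert torch.all(emb_with_pad.weight[2] == 0)

ids = torch.tensor([[1, 2, 3]])
emb_with_pad(ids).sum().backward()
emb_without_pad(ids).sum().backward()
assert torch.all(emb_with_pad.weight.grad[2] == 0)     # pad row receives no gradient
assert torch.any(emb_without_pad.weight.grad[2] != 0)  # ordinary trainable row
```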
1 change: 0 additions & 1 deletion xtuner/v1/module/attention/mha.py
@@ -279,7 +279,6 @@ def decoding(
_key_states = key_states.transpose(1, 2).squeeze(0)
_value_states = value_states.transpose(1, 2).squeeze(0)

# torch.distributed.breakpoint()
block_index = block_table[:, 0] + (seq_lens_k[:bs] - 1) // block_size
past_key_values[self.layer_idx][0][block_index, (seq_lens_k[:bs] - 1) % block_size] = _key_states
past_key_values[self.layer_idx][1][block_index, (seq_lens_k[:bs] - 1) % block_size] = _value_states