Commit 9603872

minor refactor over EP (#1854)
This PR:
- let `ExpertParallel` handle indices permute / unpermute when EP is used
- move `to_local` to model code to be more explicit
- rename the `expert_parallel` wrapper which does permute / unpermute to `indices_permutation_wrapper` to be more accurate
1 parent dfd0a59 commit 9603872
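To make the second bullet concrete, here is a minimal, hypothetical sketch (not torchtitan's actual `GroupedExperts`) of what "move `to_local` to model code" means: the module itself unwraps `DTensor` parameters into local shards before the dynamic-shape expert computation, instead of the parallelism wrapper doing an `isinstance` check (see the removed `isinstance(w1, DTensor)` branch in the wrapper below). The class and helper names here are illustrative only.

```python
# Illustrative sketch -- names are hypothetical, not torchtitan APIs.
import torch
from torch.distributed.tensor import DTensor


def maybe_to_local(t: torch.Tensor) -> torch.Tensor:
    # DTensor.to_local() returns the local shard as a plain tensor.
    return t.to_local() if isinstance(t, DTensor) else t


class GroupedExpertsSketch(torch.nn.Module):
    def __init__(self, num_experts: int, dim: int, hidden: int):
        super().__init__()
        self.w1 = torch.nn.Parameter(torch.randn(num_experts, hidden, dim))
        self.w2 = torch.nn.Parameter(torch.randn(num_experts, dim, hidden))

    def forward(self, x: torch.Tensor, num_tokens_per_expert: torch.Tensor) -> torch.Tensor:
        # The model decides explicitly when to drop to local shards,
        # instead of the EP wrapper checking isinstance(w, DTensor).
        w1 = maybe_to_local(self.w1)
        w2 = maybe_to_local(self.w2)
        outputs = []
        start = 0
        for e, n in enumerate(num_tokens_per_expert.tolist()):
            tokens = x[start : start + n]                       # (n, dim)
            outputs.append((tokens @ w1[e].t()).relu() @ w2[e].t())
            start += n
        return torch.cat(outputs, dim=0)
```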

File tree

4 files changed, +113 -101 lines


torchtitan/distributed/expert_parallel.py

Lines changed: 79 additions & 79 deletions
@@ -17,7 +17,6 @@
     DeviceMesh,
     distribute_module,
     distribute_tensor,
-    DTensor,
     Shard,
 )
 from torch.distributed.tensor.parallel import ParallelStyle
@@ -85,12 +84,15 @@ def __init__(self):
         super().__init__()
         self.input_splits = None
         self.output_splits = None
+        self.input_shape = None
+        self.permuted_indices = None
 
     # performing all-to-all dispatch on the input
     def _token_dispatch(self, mod, inputs, device_mesh):
         # annotate module input placements/sharding with input_layouts
         routed_input, num_tokens_per_expert = inputs
-        ep_size = device_mesh.shape[0]
+        ep_degree = device_mesh.shape[0]
+        num_local_experts = num_tokens_per_expert.shape[0] // ep_degree
 
         # generate the input splits and output splits for all-to-all
         with torch.no_grad():
@@ -106,13 +108,13 @@ def _token_dispatch(self, mod, inputs, device_mesh):
                 num_tokens_per_expert_group
             )
             input_splits = (
-                num_tokens_per_expert.view(ep_size, -1)
+                num_tokens_per_expert.view(ep_degree, -1)
                 .sum(dim=1)
                 .to(torch.device("cpu"), non_blocking=True)
             )
             # NOTE: this would incur a device-to-host sync
             output_splits = (
-                num_tokens_per_expert_group.view(ep_size, -1)
+                num_tokens_per_expert_group.view(ep_degree, -1)
                 .sum(dim=1)
                 .to(torch.device("cpu"), non_blocking=False)
             )
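As a concrete illustration of the split computation above (toy numbers, not from the PR): with 2 EP ranks and 2 local experts per rank, the per-destination send size is the sum of the per-expert counts within each rank's slice.

```python
# Toy illustration of how input_splits falls out of num_tokens_per_expert.
import torch

ep_degree = 2  # 2 EP ranks, each hosting 2 local experts (4 global experts)
# tokens this rank routes to global experts [e0, e1, e2, e3]
num_tokens_per_expert = torch.tensor([3, 1, 4, 2])

# one send size per destination EP rank: experts 0-1 live on rank 0, 2-3 on rank 1
input_splits = num_tokens_per_expert.view(ep_degree, -1).sum(dim=1)
print(input_splits.tolist())  # [4, 6]

# output_splits is the same reduction applied to num_tokens_per_expert_group
# (the counts received from the other ranks); copying it to CPU with
# non_blocking=False is the device-to-host sync the NOTE above refers to.
```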
@@ -133,9 +135,20 @@ def _token_dispatch(self, mod, inputs, device_mesh):
         # Rather, it is of the format
         # [#tokens for local expert 0 from EP rank 0, #tokens for local expert 1 from EP rank 0, ...,
         #  #tokens for local expert 0 from EP rank 1, #tokens for local expert 1 from EP rank 1, ...]
-        # We need to perform another shuffle to get the correct format -- this is done via the function
-        # generate_permute_indices in moe.py, which also does padding to make sure the number of tokens
-        # each expert gets locally is a multiple of ALIGN_SIZE_M.
+        # We need to perform another shuffle to get the correct layout, via the _permute function
+        # below, which also does padding to make sure the number of tokens each expert gets locally
+        # is a multiple of TOKEN_GROUP_ALIGN_SIZE_M.
+        # Note that this will create side effects when wrapping the for-loop implementation
+        # of GroupedExperts, as it does not need padding.
+
+        (
+            self.input_shape,
+            routed_input,
+            self.permuted_indices,
+            num_tokens_per_expert_group,
+        ) = _permute(
+            routed_input, num_tokens_per_expert_group, ep_degree, num_local_experts
+        )
 
         return routed_input, num_tokens_per_expert_group
 
@@ -148,6 +161,10 @@ def _partition_fn(name, mod, device_mesh):
 
     # performing all-to-all combine on the output
     def _token_combine(self, mod, routed_output, device_mesh):
+        routed_output = _unpermute(
+            routed_output, self.input_shape, self.permuted_indices
+        )
+
         routed_output = all_to_all_single_autograd(
             routed_output,
             self.input_splits,
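The comment in `_token_dispatch` above describes the layout problem that `_permute` solves before the grouped computation, and `_token_combine` undoes it with `_unpermute`. A toy version of the reshuffle (ignoring padding and the actual `generate_permute_indices` kernel) can be written with a stable argsort; the numbers are illustrative.

```python
# Toy version of the layout shuffle that _permute performs (padding omitted).
import torch

ep_degree, num_local_experts = 2, 2
# tokens received per (source EP rank, local expert):
# rank 0 -> (expert 0: 2, expert 1: 1), rank 1 -> (expert 0: 1, expert 1: 2)
num_tokens_per_expert_group = torch.tensor([2, 1, 1, 2])

# tag each received token with its local expert id, in arrival order
expert_id = torch.repeat_interleave(
    torch.arange(num_local_experts).repeat(ep_degree), num_tokens_per_expert_group
)
print(expert_id.tolist())         # [0, 0, 1, 0, 1, 1]

# a stable sort by expert id groups tokens per local expert
permuted_indices = torch.argsort(expert_id, stable=True)
print(permuted_indices.tolist())  # [0, 1, 3, 2, 4, 5]

# generate_permute_indices additionally pads each expert's group to a multiple
# of TOKEN_GROUP_ALIGN_SIZE_M without forcing a device-to-host sync.
```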
@@ -168,20 +185,9 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
 
 # This class is for dp2ep with TP (without TP we can just use ExpertParallel)
 class ExpertTensorParallel(ExpertParallel):
-    def __init__(
-        self,
-        tp_mesh: DeviceMesh,
-        ep_mesh: DeviceMesh,
-    ):
-        super().__init__()
-        # TODO: has to pass in the meshes in addition to the [ep, tp] device_mesh,
-        # as DeviceMesh doesn't support slicing from a submesh.
-        self.tp_mesh = tp_mesh
-        self.ep_mesh = ep_mesh
-
     def _token_dispatch(self, mod, inputs, device_mesh):
         # token dispatch happens on the EP mesh, whereas device_mesh is [ep, tp] mesh
-        return super()._token_dispatch(mod, inputs, self.ep_mesh)
+        return super()._token_dispatch(mod, inputs, device_mesh["ep"])
 
     def _partition_fn_2d(self, name, mod, ep_tp_mesh):
         # w1 shape = (experts, out_dim, in_dim)
@@ -204,7 +210,7 @@ def _partition_fn_2d(self, name, mod, ep_tp_mesh):
 
     def _token_combine(self, mod, routed_output, device_mesh):
         # token combine happens on the EP mesh, whereas device_mesh is [ep, tp] mesh
-        return super()._token_combine(mod, routed_output, self.ep_mesh)
+        return super()._token_combine(mod, routed_output, device_mesh["ep"])
 
     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         return distribute_module(
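The two hunks above drop the stored `tp_mesh` / `ep_mesh`: `ExpertTensorParallel` now slices the EP sub-mesh from the `[ep, tp]` mesh by dimension name. A minimal sketch of that slicing, assuming it runs inside an initialized distributed job (e.g. under torchrun with 8 ranks):

```python
# Sketch: slicing the EP sub-mesh by name, as _token_dispatch/_token_combine now do.
# Must run under torchrun (world size 8 here) with an appropriate backend.
from torch.distributed.device_mesh import init_device_mesh

ep_tp_mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("ep", "tp"))
ep_mesh = ep_tp_mesh["ep"]  # 1-D sub-mesh used for token dispatch/combine
tp_mesh = ep_tp_mesh["tp"]  # 1-D sub-mesh used for the TP shardings
```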
@@ -216,25 +222,42 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         )
 
 
-def expert_parallel(func: Callable) -> Callable:
+def _permute(x, num_tokens_per_expert, ep_degree, num_local_experts):
+    # TODO: move to core
+    from torchtitan.experiments.kernels.moe.indices import generate_permute_indices
+
+    global TOKEN_GROUP_ALIGN_SIZE_M
+    x_padded_per_expert = x.shape[0] + num_local_experts * TOKEN_GROUP_ALIGN_SIZE_M
+    padded_max_len = _round_up(x_padded_per_expert, TOKEN_GROUP_ALIGN_SIZE_M)
+    with torch.no_grad():
+        (permuted_indices, num_tokens_per_expert, _offsets,) = generate_permute_indices(
+            num_tokens_per_expert,
+            num_local_experts,
+            ep_degree,
+            padded_max_len,
+            TOKEN_GROUP_ALIGN_SIZE_M,
+        )
+
+    x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
+    input_shape = x.shape
+    x = x[permuted_indices, :]
+
+    return input_shape, x, permuted_indices, num_tokens_per_expert
+
+
+def _unpermute(out, input_shape, permuted_indices):
+    out_unpermuted = out.new_empty(input_shape)
+    out_unpermuted[permuted_indices, :] = out
+    out = out_unpermuted[:-1]
+    return out
+
+
+def indices_padding_wrapper(func: Callable) -> Callable:
     """
-    This is a wrapper applied to the GroupedExperts computation, serving
-    the following three purposes:
-    1. Convert parameters from DTensors to plain Tensors, to work with
-       dynamic-shape inputs which cannot be easily expressed as DTensors.
-    2. In Expert Parallel, apply the generate_permute_indices kernel to
-       permute the inputs to be ordered by local experts (see the _token_dispatch
-       function in ExpertParallel) and permute the outputs back.
-    3. In order to use torch._grouped_mm, we need to make sure the number of
-       tokens each expert gets is a multiple of ALIGN_SIZE_M. The generate_permute_indices
-       kernel also helps achieve this via padding, without incurring synchronization
-       between device and host. Note that this will create side effects when wrapping
-       the for-loop implementation of GroupedExperts, as it does not need padding.
-
-    Among the above:
-    1 and 2 are needed only when expert_parallel_degree > 1.
-    3 is needed even for single-device computation.
-    2 can be moved to ExpertParallel _token_dispatch if not coupled with 3.
+    In order to use torch._grouped_mm, we need to make sure the number of
+    tokens each expert gets is a multiple of TOKEN_GROUP_ALIGN_SIZE_M. The
+    generate_permute_indices kernel also helps achieve this via padding,
+    without incurring synchronization between device and host.
    """
 
     def wrapper(
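To see why `_permute` appends a row of zeros and `_unpermute` drops the last row, here is a self-contained round trip with hand-picked indices (in the real code, `permuted_indices` comes from the `generate_permute_indices` kernel):

```python
# Round trip of the _permute / _unpermute bookkeeping with toy indices.
import torch

x = torch.arange(12.0).reshape(4, 3)           # 4 tokens, hidden dim 3
# index 4 points at the appended zero row, i.e. a padding slot
permuted_indices = torch.tensor([2, 0, 4, 1, 3, 4])

# _permute: append one zero row so padded slots read zeros, then gather
x_padded = torch.vstack((x, x.new_zeros((x.shape[-1]))))
input_shape = x_padded.shape                   # (5, 3)
x_perm = x_padded[permuted_indices, :]         # (6, 3); rows 2 and 5 are padding

out = x_perm * 10                              # stand-in for the experts' computation

# _unpermute: scatter rows back to their original slots, drop the padding row
out_unpermuted = out.new_empty(input_shape)
out_unpermuted[permuted_indices, :] = out
out = out_unpermuted[:-1]
print(torch.equal(out, x * 10))                # True
```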
@@ -244,45 +267,16 @@ def wrapper(
         x: torch.Tensor,
         num_tokens_per_expert: torch.Tensor,
     ) -> torch.Tensor:
-        global TOKEN_GROUP_ALIGN_SIZE_M
-        if isinstance(w1, DTensor):
-            w1 = w1.to_local()
-            w2 = w2.to_local()
-            w3 = w3.to_local()
+        num_local_experts = w1.shape[0]
+        ep_degree = num_tokens_per_expert.shape[0] // num_local_experts
 
-        from torchtitan.experiments.kernels.moe.indices import generate_permute_indices
-
-        experts_per_ep_rank = w1.shape[0]
-        num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank
-
-        # Make sure max_len of permuted token indicies is divisible by TOKEN_GROUP_ALIGN_SIZE_M,
-        # by padding it to the nearest multiple of TOKEN_GROUP_ALIGN_SIZE_M.
-        x_padded_per_expert = (
-            x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M
+        input_shape, x, permuted_indices, num_tokens_per_expert = _permute(
+            x, num_tokens_per_expert, ep_degree, num_local_experts
         )
-        padded_max_len = _round_up(x_padded_per_expert, TOKEN_GROUP_ALIGN_SIZE_M)
-        with torch.no_grad():
-            (
-                permuted_indices,
-                num_tokens_per_expert,
-                _,  # offsets,
-            ) = generate_permute_indices(
-                num_tokens_per_expert,
-                experts_per_ep_rank,
-                num_ep_ranks,
-                padded_max_len,
-                TOKEN_GROUP_ALIGN_SIZE_M,
-            )
-
-        x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
-        input_shape = x.shape
-        x = x[permuted_indices, :]
 
         out = func(w1, w2, w3, x, num_tokens_per_expert)
 
-        out_unpermuted = out.new_empty(input_shape)
-        out_unpermuted[permuted_indices, :] = out
-        out = out_unpermuted[:-1]
+        out = _unpermute(out, input_shape, permuted_indices)
 
         return out
 
@@ -294,11 +288,12 @@ def wrapper(
 class ReordererSequenceParallel(ParallelStyle):
     def __init__(self):
         super().__init__()
-        self.num_tokens = None
+        self.top_k = None
 
     def _prepare_inputput_fn(self, mod, inputs, device_mesh):
+        # shape (batch_size*seq_len, top_k)
         top_scores, selected_experts_indices = inputs
-        self.num_tokens = top_scores.shape[0]
+        num_tokens, self.top_k = top_scores.shape
 
         # NOTE: If needed, we can pad tokens in case bs*slen is not divisible by TP degree
         # if top_scores.shape[0] % device_mesh.size() != 0:
@@ -310,8 +305,12 @@ def _prepare_inputput_fn(self, mod, inputs, device_mesh):
 
         def _split_along_first_dim(x: torch.Tensor) -> torch.Tensor:
             assert x.is_contiguous()
-            assert self.num_tokens % device_mesh.size() == 0
-            local_num_tokens = self.num_tokens // device_mesh.size()
+            if num_tokens % device_mesh.size() != 0:
+                raise ValueError(
+                    "Uneven split of tokens is not supported yet. "
+                    "Requires EP degree dividing batch size * seq len."
+                )
+            local_num_tokens = num_tokens // device_mesh.size()
             local_rank = device_mesh.get_local_rank()
             offset = local_rank * local_num_tokens
             output = x[offset : offset + local_num_tokens]
@@ -321,17 +320,18 @@ def _split_along_first_dim(x: torch.Tensor) -> torch.Tensor:
         top_scores = _split_along_first_dim(top_scores)
         selected_experts_indices = _split_along_first_dim(selected_experts_indices)
 
+        # shape (batch_size * seq_len // ep_degree, top_k)
         return top_scores, selected_experts_indices
 
     def _prepare_output_fn(self, mod, outputs, device_mesh):
+        # shape (batch_size * seq_len * top_k // ep_degree)
         top_scores, token_indices_experts_sorted, num_tokens_per_expert = outputs
 
         # NOTE: As we shard routed tokens along bs*slen dim across the TP ranks,
         # the MoE gather and scatter still require global token indices.
         local_rank = device_mesh.get_local_rank()
-        token_indices_experts_sorted += (
-            self.num_tokens // device_mesh.size() * local_rank
-        )
+        # fact: top_scores.shape[0] // self.top_k = batch_size * seq_len // ep_degree
+        token_indices_experts_sorted += top_scores.shape[0] // self.top_k * local_rank
 
         return top_scores, token_indices_experts_sorted, num_tokens_per_expert
 
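Worked numbers for the new offset computation in `_prepare_output_fn` (illustrative values): since each rank holds `batch_size * seq_len // mesh_size` tokens and every token contributes `top_k` routed entries, the local token count can be recovered from `top_scores.shape[0] // top_k`, so `self.num_tokens` no longer needs to be stored.

```python
# Illustrative arithmetic for the global-index offset in _prepare_output_fn.
batch_size, seq_len, top_k = 2, 8, 2
mesh_size, local_rank = 4, 2         # device_mesh.size(), device_mesh.get_local_rank()

num_tokens = batch_size * seq_len                   # 16 tokens before sharding
local_num_tokens = num_tokens // mesh_size          # 4 tokens on this rank

# this rank's top_scores has one entry per (local token, chosen expert)
top_scores_len = local_num_tokens * top_k           # 8
assert top_scores_len // top_k == local_num_tokens  # num_tokens is recoverable

# local routed token indices refer to tokens 0..3; add the offset to make them global
offset = top_scores_len // top_k * local_rank       # 4 * 2 = 8
print(offset)                                       # tokens on rank 2 start at global index 8
```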
torchtitan/experiments/kernels/moe/dispatch.py

Lines changed: 0 additions & 1 deletion
@@ -88,7 +88,6 @@ def forward(  # type: ignore[no-untyped-def]
             out_splits_offsets, grad_out_buf, grad_in_buf, grad_in_splits_offsets
         )
         ctx.group_name = group_name
-        return out
 
     @staticmethod
     def backward(  # type: ignore[no-untyped-def]

torchtitan/experiments/llama4/infra/parallelize.py

Lines changed: 6 additions & 7 deletions
@@ -4,7 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-
 import torch
 import torch.nn as nn
 from torch.distributed.device_mesh import DeviceMesh
@@ -22,6 +21,7 @@
 from torchtitan.config.job_config import Compile as CompileConfig
 from torchtitan.distributed import NoParallel, ParallelDims
 from torchtitan.distributed.activation_checkpoint import apply_ac
+
 from torchtitan.distributed.expert_parallel import (
     ExpertParallel,
     ExpertTensorParallel,
@@ -441,6 +441,8 @@ def apply_moe_ep_tp(
     ep_tp_mesh: DeviceMesh | None,
     etp_enabled: bool,
 ):
+    assert ep_mesh is not None or tp_mesh is not None
+
     for transformer_block in model.layers.values():
         if not transformer_block.moe_enabled:
             continue
@@ -486,16 +488,13 @@ def apply_moe_ep_tp(
             experts_mesh = tp_mesh
             # input Replicate, output Partial
             experts_plan = TensorParallel()
-        elif tp_mesh is None:
+        elif tp_mesh is None or not etp_enabled:
             experts_mesh = ep_mesh
             # input / output sharding on the batch / tokens dim
             experts_plan = ExpertParallel()
-        elif etp_enabled:
-            experts_mesh = ep_tp_mesh
-            experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh)
         else:
-            experts_mesh = ep_mesh
-            experts_plan = ExpertParallel()
+            experts_mesh = ep_tp_mesh
+            experts_plan = ExpertTensorParallel()
 
         parallelize_module(
             module=transformer_block.moe.experts,
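The branch restructuring above folds the `not etp_enabled` case into the EP-only branch. A compact sketch of the resulting selection logic (the function and the string return values are illustrative, not the torchtitan API):

```python
# Sketch of the experts mesh/plan selection in apply_moe_ep_tp after this change.
def pick_experts_plan(ep_mesh, tp_mesh, ep_tp_mesh, etp_enabled):
    assert ep_mesh is not None or tp_mesh is not None
    if ep_mesh is None:
        # TP only: input Replicate, output Partial
        return tp_mesh, "TensorParallel"
    elif tp_mesh is None or not etp_enabled:
        # EP only, or EP + TP with expert-TP disabled:
        # shard input/output on the batch / tokens dim
        return ep_mesh, "ExpertParallel"
    else:
        # EP + TP with expert-TP enabled; ExpertTensorParallel now slices
        # device_mesh["ep"] itself, so no meshes are passed to its constructor
        return ep_tp_mesh, "ExpertTensorParallel"


print(pick_experts_plan("ep", None, None, etp_enabled=False))    # ('ep', 'ExpertParallel')
print(pick_experts_plan("ep", "tp", "ep_tp", etp_enabled=True))  # ('ep_tp', 'ExpertTensorParallel')
```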
