     DeviceMesh,
     distribute_module,
     distribute_tensor,
-    DTensor,
     Shard,
 )
 from torch.distributed.tensor.parallel import ParallelStyle
@@ -85,12 +84,15 @@ def __init__(self):
         super().__init__()
         self.input_splits = None
         self.output_splits = None
+        self.input_shape = None
+        self.permuted_indices = None

     # performing all-to-all dispatch on the input
     def _token_dispatch(self, mod, inputs, device_mesh):
         # annotate module input placements/sharding with input_layouts
         routed_input, num_tokens_per_expert = inputs
-        ep_size = device_mesh.shape[0]
+        ep_degree = device_mesh.shape[0]
+        num_local_experts = num_tokens_per_expert.shape[0] // ep_degree

         # generate the input splits and output splits for all-to-all
         with torch.no_grad():
@@ -106,13 +108,13 @@ def _token_dispatch(self, mod, inputs, device_mesh):
                 num_tokens_per_expert_group
             )
             input_splits = (
-                num_tokens_per_expert.view(ep_size, -1)
+                num_tokens_per_expert.view(ep_degree, -1)
                 .sum(dim=1)
                 .to(torch.device("cpu"), non_blocking=True)
             )
             # NOTE: this would incur a device-to-host sync
             output_splits = (
-                num_tokens_per_expert_group.view(ep_size, -1)
+                num_tokens_per_expert_group.view(ep_degree, -1)
                 .sum(dim=1)
                 .to(torch.device("cpu"), non_blocking=False)
             )
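On the splits math above: viewing the per-expert token counts as `(ep_degree, num_local_experts)` and summing over the expert dimension gives the per-rank send sizes that `all_to_all_single` expects; `output_splits` is derived the same way from the exchanged `num_tokens_per_expert_group`, and moving both to CPU is what forces the device-to-host sync noted in the comment. A minimal standalone sketch with toy counts (2 EP ranks, 2 local experts, no distributed runtime involved):

```python
import torch

# Toy counts of tokens routed from this rank to each of the 4 global experts,
# laid out as [rank0-expert0, rank0-expert1, rank1-expert0, rank1-expert1].
ep_degree = 2
num_tokens_per_expert = torch.tensor([3, 1, 4, 2])

# Group by destination EP rank and sum over its local experts.
input_splits = num_tokens_per_expert.view(ep_degree, -1).sum(dim=1)
print(input_splits.tolist())  # [4, 6]: send 4 rows to rank 0, 6 rows to rank 1
```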
@@ -133,9 +135,20 @@ def _token_dispatch(self, mod, inputs, device_mesh):
         # Rather, it is of the format
         # [#tokens for local expert 0 from EP rank 0, #tokens for local expert 1 from EP rank 0, ...,
         #  #tokens for local expert 0 from EP rank 1, #tokens for local expert 1 from EP rank 1, ...]
-        # We need to perform another shuffle to get the correct format -- this is done via the function
-        # generate_permute_indices in moe.py, which also does padding to make sure the number of tokens
-        # each expert gets locally is a multiple of ALIGN_SIZE_M.
+        # We need to perform another shuffle to get the correct layout, via the _permute function
+        # below, which also does padding to make sure the number of tokens each expert gets locally
+        # is a multiple of TOKEN_GROUP_ALIGN_SIZE_M.
+        # Note that this will create side effects when wrapping the for-loop implementation
+        # of GroupedExperts, as it does not need padding.
+
+        (
+            self.input_shape,
+            routed_input,
+            self.permuted_indices,
+            num_tokens_per_expert_group,
+        ) = _permute(
+            routed_input, num_tokens_per_expert_group, ep_degree, num_local_experts
+        )

         return routed_input, num_tokens_per_expert_group

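To make the layout problem concrete: after the all-to-all, rows arrive grouped by source EP rank, while the grouped GEMM wants them contiguous per local expert. A toy sketch of the re-grouping, in which a plain stable sort stands in for the `generate_permute_indices` kernel and the TOKEN_GROUP_ALIGN_SIZE_M padding is omitted:

```python
import torch

ep_degree, num_local_experts = 2, 2
# Received counts, ordered [rank0->expert0, rank0->expert1, rank1->expert0, rank1->expert1].
num_tokens_per_expert_group = torch.tensor([2, 1, 1, 3])

# Local-expert id of each received row, in arrival order.
expert_ids = torch.repeat_interleave(
    torch.arange(num_local_experts).repeat(ep_degree), num_tokens_per_expert_group
)
print(expert_ids.tolist())  # [0, 0, 1, 0, 1, 1, 1]

# A stable argsort groups rows by local expert, which is the kind of permutation
# _permute applies to routed_input (plus padding in the real kernel).
permuted_indices = torch.argsort(expert_ids, stable=True)
print(permuted_indices.tolist())  # [0, 1, 3, 2, 4, 5, 6]
```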
@@ -148,6 +161,10 @@ def _partition_fn(name, mod, device_mesh):

     # performing all-to-all combine on the output
     def _token_combine(self, mod, routed_output, device_mesh):
+        routed_output = _unpermute(
+            routed_output, self.input_shape, self.permuted_indices
+        )
+
         routed_output = all_to_all_single_autograd(
             routed_output,
             self.input_splits,
@@ -168,20 +185,9 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:

 # This class is for dp2ep with TP (without TP we can just use ExpertParallel)
 class ExpertTensorParallel(ExpertParallel):
-    def __init__(
-        self,
-        tp_mesh: DeviceMesh,
-        ep_mesh: DeviceMesh,
-    ):
-        super().__init__()
-        # TODO: has to pass in the meshes in addition to the [ep, tp] device_mesh,
-        # as DeviceMesh doesn't support slicing from a submesh.
-        self.tp_mesh = tp_mesh
-        self.ep_mesh = ep_mesh
-
     def _token_dispatch(self, mod, inputs, device_mesh):
         # token dispatch happens on the EP mesh, whereas device_mesh is [ep, tp] mesh
-        return super()._token_dispatch(mod, inputs, self.ep_mesh)
+        return super()._token_dispatch(mod, inputs, device_mesh["ep"])

     def _partition_fn_2d(self, name, mod, ep_tp_mesh):
         # w1 shape = (experts, out_dim, in_dim)
@@ -204,7 +210,7 @@ def _partition_fn_2d(self, name, mod, ep_tp_mesh):

     def _token_combine(self, mod, routed_output, device_mesh):
         # token combine happens on the EP mesh, whereas device_mesh is [ep, tp] mesh
-        return super()._token_combine(mod, routed_output, self.ep_mesh)
+        return super()._token_combine(mod, routed_output, device_mesh["ep"])

     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         return distribute_module(
@@ -216,25 +222,42 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         )


-def expert_parallel(func: Callable) -> Callable:
+def _permute(x, num_tokens_per_expert, ep_degree, num_local_experts):
+    # TODO: move to core
+    from torchtitan.experiments.kernels.moe.indices import generate_permute_indices
+
+    global TOKEN_GROUP_ALIGN_SIZE_M
+    x_padded_per_expert = x.shape[0] + num_local_experts * TOKEN_GROUP_ALIGN_SIZE_M
+    padded_max_len = _round_up(x_padded_per_expert, TOKEN_GROUP_ALIGN_SIZE_M)
+    with torch.no_grad():
+        (permuted_indices, num_tokens_per_expert, _offsets,) = generate_permute_indices(
+            num_tokens_per_expert,
+            num_local_experts,
+            ep_degree,
+            padded_max_len,
+            TOKEN_GROUP_ALIGN_SIZE_M,
+        )
+
+    x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
+    input_shape = x.shape
+    x = x[permuted_indices, :]
+
+    return input_shape, x, permuted_indices, num_tokens_per_expert
+
+
+def _unpermute(out, input_shape, permuted_indices):
+    out_unpermuted = out.new_empty(input_shape)
+    out_unpermuted[permuted_indices, :] = out
+    out = out_unpermuted[:-1]
+    return out
+
+
+def indices_padding_wrapper(func: Callable) -> Callable:
     """
-    This is a wrapper applied to the GroupedExperts computation, serving
-    the following three purposes:
-    1. Convert parameters from DTensors to plain Tensors, to work with
-    dynamic-shape inputs which cannot be easily expressed as DTensors.
-    2. In Expert Parallel, apply the generate_permute_indices kernel to
-    permute the inputs to be ordered by local experts (see the _token_dispatch
-    function in ExpertParallel) and permute the outputs back.
-    3. In order to use torch._grouped_mm, we need to make sure the number of
-    tokens each expert gets is a multiple of ALIGN_SIZE_M. The generate_permute_indices
-    kernel also helps achieve this via padding, without incurring synchronization
-    between device and host. Note that this will create side effects when wrapping
-    the for-loop implementation of GroupedExperts, as it does not need padding.
-
-    Among the above:
-    1 and 2 are needed only when expert_parallel_degree > 1.
-    3 is needed even for single-device computation.
-    2 can be moved to ExpertParallel _token_dispatch if not coupled with 3.
+    In order to use torch._grouped_mm, we need to make sure the number of
+    tokens each expert gets is a multiple of TOKEN_GROUP_ALIGN_SIZE_M. The
+    generate_permute_indices kernel also helps achieve this via padding,
+    without incurring synchronization between device and host.
     """

     def wrapper(
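The `_permute`/`_unpermute` pair above is essentially a gather by `permuted_indices` (after appending one zero padding row) and the inverse scatter that drops that row again. A self-contained sketch of the round-trip, with a hand-written permutation standing in for the kernel output:

```python
import torch

x = torch.arange(12.0).reshape(4, 3)
x = torch.vstack((x, x.new_zeros((x.shape[-1],))))  # append the padding row
input_shape = x.shape

# Pretend kernel output: index 4 points at the padding row.
permuted_indices = torch.tensor([2, 0, 4, 1, 3])
permuted = x[permuted_indices, :]                   # what the experts would consume

# The unpermute step: scatter rows back to their original slots, drop the padding row.
out_unpermuted = permuted.new_empty(input_shape)
out_unpermuted[permuted_indices, :] = permuted
restored = out_unpermuted[:-1]
assert torch.equal(restored, x[:-1])
```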
@@ -244,45 +267,16 @@ def wrapper(
         x: torch.Tensor,
         num_tokens_per_expert: torch.Tensor,
     ) -> torch.Tensor:
-        global TOKEN_GROUP_ALIGN_SIZE_M
-        if isinstance(w1, DTensor):
-            w1 = w1.to_local()
-            w2 = w2.to_local()
-            w3 = w3.to_local()
+        num_local_experts = w1.shape[0]
+        ep_degree = num_tokens_per_expert.shape[0] // num_local_experts

-        from torchtitan.experiments.kernels.moe.indices import generate_permute_indices
-
-        experts_per_ep_rank = w1.shape[0]
-        num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank
-
-        # Make sure max_len of permuted token indicies is divisible by TOKEN_GROUP_ALIGN_SIZE_M,
-        # by padding it to the nearest multiple of TOKEN_GROUP_ALIGN_SIZE_M.
-        x_padded_per_expert = (
-            x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M
+        input_shape, x, permuted_indices, num_tokens_per_expert = _permute(
+            x, num_tokens_per_expert, ep_degree, num_local_experts
         )
-        padded_max_len = _round_up(x_padded_per_expert, TOKEN_GROUP_ALIGN_SIZE_M)
-        with torch.no_grad():
-            (
-                permuted_indices,
-                num_tokens_per_expert,
-                _,  # offsets,
-            ) = generate_permute_indices(
-                num_tokens_per_expert,
-                experts_per_ep_rank,
-                num_ep_ranks,
-                padded_max_len,
-                TOKEN_GROUP_ALIGN_SIZE_M,
-            )
-
-        x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
-        input_shape = x.shape
-        x = x[permuted_indices, :]

         out = func(w1, w2, w3, x, num_tokens_per_expert)

-        out_unpermuted = out.new_empty(input_shape)
-        out_unpermuted[permuted_indices, :] = out
-        out = out_unpermuted[:-1]
+        out = _unpermute(out, input_shape, permuted_indices)

         return out

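For intuition on the alignment the docstring mentions, here is a toy sketch of the round-up arithmetic (align size 8 is assumed purely for illustration): each expert's token group is padded up to the next multiple of `TOKEN_GROUP_ALIGN_SIZE_M`, with the extra slots filled from the appended zero row and later discarded by `_unpermute`.

```python
import torch

TOKEN_GROUP_ALIGN_SIZE_M = 8  # illustrative value, not the configured one
num_tokens_per_expert = torch.tensor([5, 13, 9])

# Round each group size up to a multiple of the alignment.
padded = (
    (num_tokens_per_expert + TOKEN_GROUP_ALIGN_SIZE_M - 1)
    // TOKEN_GROUP_ALIGN_SIZE_M
    * TOKEN_GROUP_ALIGN_SIZE_M
)
print(padded.tolist())  # [8, 16, 16]
```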
@@ -294,11 +288,12 @@ def wrapper(
 class ReordererSequenceParallel(ParallelStyle):
     def __init__(self):
         super().__init__()
-        self.num_tokens = None
+        self.top_k = None

     def _prepare_inputput_fn(self, mod, inputs, device_mesh):
+        # shape (batch_size*seq_len, top_k)
         top_scores, selected_experts_indices = inputs
-        self.num_tokens = top_scores.shape[0]
+        num_tokens, self.top_k = top_scores.shape

         # NOTE: If needed, we can pad tokens in case bs*slen is not divisible by TP degree
         # if top_scores.shape[0] % device_mesh.size() != 0:
@@ -310,8 +305,12 @@ def _prepare_inputput_fn(self, mod, inputs, device_mesh):

         def _split_along_first_dim(x: torch.Tensor) -> torch.Tensor:
             assert x.is_contiguous()
-            assert self.num_tokens % device_mesh.size() == 0
-            local_num_tokens = self.num_tokens // device_mesh.size()
+            if num_tokens % device_mesh.size() != 0:
+                raise ValueError(
+                    "Uneven split of tokens is not supported yet. "
+                    "Requires EP degree dividing batch size * seq len."
+                )
+            local_num_tokens = num_tokens // device_mesh.size()
             local_rank = device_mesh.get_local_rank()
             offset = local_rank * local_num_tokens
             output = x[offset : offset + local_num_tokens]
@@ -321,17 +320,18 @@ def _split_along_first_dim(x: torch.Tensor) -> torch.Tensor:
         top_scores = _split_along_first_dim(top_scores)
         selected_experts_indices = _split_along_first_dim(selected_experts_indices)

+        # shape (batch_size * seq_len // ep_degree, top_k)
         return top_scores, selected_experts_indices

     def _prepare_output_fn(self, mod, outputs, device_mesh):
+        # shape (batch_size * seq_len * top_k // ep_degree)
         top_scores, token_indices_experts_sorted, num_tokens_per_expert = outputs

         # NOTE: As we shard routed tokens along bs*slen dim across the TP ranks,
         # the MoE gather and scatter still require global token indices.
         local_rank = device_mesh.get_local_rank()
-        token_indices_experts_sorted += (
-            self.num_tokens // device_mesh.size() * local_rank
-        )
+        # fact: top_scores.shape[0] // self.top_k = batch_size * seq_len // ep_degree
+        token_indices_experts_sorted += top_scores.shape[0] // self.top_k * local_rank

         return top_scores, token_indices_experts_sorted, num_tokens_per_expert

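The offset arithmetic in `_prepare_output_fn` can be sanity-checked with plain numbers: each rank holds `batch_size * seq_len // ep_degree` tokens, recoverable locally as `top_scores.shape[0] // top_k`, so shifting by that amount times the local rank turns shard-local token indices back into global ones. A toy check with assumed sizes (16 tokens, top_k = 2, degree 4):

```python
top_k, degree, num_tokens = 2, 4, 16

local_scores_len = num_tokens * top_k // degree   # per-rank top_scores.shape[0] == 8
tokens_per_rank = local_scores_len // top_k       # == num_tokens // degree == 4

offsets = [tokens_per_rank * rank for rank in range(degree)]
print(offsets)  # [0, 4, 8, 12] -> added to shard-local token indices
```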