@@ -481,7 +481,7 @@ def _get_fused_hf_param(
         dtype: torch.dtype,
         device="cpu",
         bucket_size=None,
-        return_full_key_per_rank: bool = False,
+        update_weights_for_rl: bool = False,
     ) -> Generator[tuple[list[str], list[torch.Tensor]], None, None]:
         if not params:
             return
@@ -506,63 +506,58 @@ def _get_hf_params(
         for load_spec, fsdp_unshared_tensor in zip(spec_list, fsdp_unshard_tensor_list):
             hf_keys = load_spec.hf_keys
 
-            if load_spec.group is not None:
-                all_hf_keys_list: list[None] | list[list[str]] = [None for _ in range(load_spec.group.size())]
-                dist.all_gather_object(all_hf_keys_list, hf_keys, group=load_spec.group)
-                all_hf_keys_list = cast(list[list[str]], all_hf_keys_list)
-                all_hf_keys = list(chain(*all_hf_keys_list))
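+            # RL weight-sync path: keep each rank's local hf keys and
+            # FSDP-unsharded tensor as-is instead of re-sharding across ranks.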
+            if update_weights_for_rl:
+                hf_keys_list.append(hf_keys)
+                saved_fused_tensor_list.append(fsdp_unshared_tensor)
             else:
-                all_hf_keys = hf_keys
-
-            current_rank = dist.get_rank()
-            fused_save_ranks = self._get_ranks_to_save_fused_tensor(len(all_hf_keys))
-            key_per_rank = len(all_hf_keys) / len(fused_save_ranks)
-            assert key_per_rank.is_integer(), (
-                f"XTuner Internal Error, size of all_hf_keys: {len(all_hf_keys)}, "
-                f"size of `fused_save_ranks` {len(fused_save_ranks)}"
-            )
+                if load_spec.group is not None:
+                    all_hf_keys_list: list[None] | list[list[str]] = [None for _ in range(load_spec.group.size())]
+                    dist.all_gather_object(all_hf_keys_list, hf_keys, group=load_spec.group)
+                    all_hf_keys_list = cast(list[list[str]], all_hf_keys_list)
+                    all_hf_keys = list(chain(*all_hf_keys_list))
+                else:
+                    all_hf_keys = hf_keys
+
+                current_rank = dist.get_rank()
+                fused_save_ranks = self._get_ranks_to_save_fused_tensor(len(all_hf_keys))
+                key_per_rank = len(all_hf_keys) / len(fused_save_ranks)
+                assert key_per_rank.is_integer(), (
+                    f"XTuner Internal Error, size of all_hf_keys: {len(all_hf_keys)}, "
+                    f"size of `fused_save_ranks` {len(fused_save_ranks)}"
+                )
 
-            # 1. When return_full_key_per_rank is False, we intends to save hf models across ranks,
-            #    each rank only saves part of hf keys and tensors
-            # 2. When return_full_key_per_rank is True, we intends to generate full tensors on each
-            #    rank for ipc updating weights in RL training.
-            if not return_full_key_per_rank:
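+                # Shard the save work: each rank keeps an equal, contiguous
+                # slice of the hf keys and the matching rows of the fused tensor.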
                 start = int(current_rank * key_per_rank)
                 end = int(start + key_per_rank)
-            else:
-                start = 0
-                end = len(all_hf_keys)
 
-            _hf_key_list = all_hf_keys[start:end]
+                _hf_key_list = all_hf_keys[start:end]
 
-            if not _hf_key_list:
-                continue
+                if not _hf_key_list:
+                    continue
 
-            hf_keys_list.append(_hf_key_list)
+                hf_keys_list.append(_hf_key_list)
 
-            assert load_spec.dim is not None
-            if load_spec.group is not None:
                 assert load_spec.dim is not None
-                _gathered_tensor_list = [
-                    torch.zeros_like(fsdp_unshared_tensor) for _ in range(load_spec.group.size())
-                ]
-                dist.all_gather(_gathered_tensor_list, fsdp_unshared_tensor, group=load_spec.group)
-                _gathered_tensor = torch.cat(_gathered_tensor_list, dim=load_spec.dim)
-            else:
-                _gathered_tensor = fsdp_unshared_tensor
-
-            hf_tensor_size = _gathered_tensor.shape[load_spec.dim] / len(all_hf_keys)
-            _saved_fused_tensor = torch.index_select(
-                _gathered_tensor,
-                dim=load_spec.dim,
-                index=torch.arange(
-                    int(start * hf_tensor_size),
-                    int(end * hf_tensor_size),
-                    dtype=torch.int64,
-                    device=_gathered_tensor.device,
-                ),
-            )
-            saved_fused_tensor_list.append(_saved_fused_tensor)
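+                # Gather the pieces sharded across `load_spec.group` (if any)
+                # into one fused tensor before slicing out this rank's share.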
+                if load_spec.group is not None:
+                    assert load_spec.dim is not None
+                    _gathered_tensor_list = [
+                        torch.zeros_like(fsdp_unshared_tensor) for _ in range(load_spec.group.size())
+                    ]
+                    dist.all_gather(_gathered_tensor_list, fsdp_unshared_tensor, group=load_spec.group)
+                    _gathered_tensor = torch.cat(_gathered_tensor_list, dim=load_spec.dim)
+                else:
+                    _gathered_tensor = fsdp_unshared_tensor
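+                # Each hf key owns an equal slice along `load_spec.dim`; select
+                # the rows belonging to keys [start, end).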
+                hf_tensor_size = _gathered_tensor.shape[load_spec.dim] / len(all_hf_keys)
+                _saved_fused_tensor = torch.index_select(
+                    _gathered_tensor,
+                    dim=load_spec.dim,
+                    index=torch.arange(
+                        int(start * hf_tensor_size),
+                        int(end * hf_tensor_size),
+                        dtype=torch.int64,
+                        device=_gathered_tensor.device,
+                    ),
+                )
+                saved_fused_tensor_list.append(_saved_fused_tensor)
 
         # Split the fused tensor into hf tensors
         hf_tensor_list: list[torch.Tensor] = []
@@ -1141,6 +1136,14 @@ def _fsdp_foreach_allgather(
 
         # Concatenate the tensors along the FSDP shard dim
         for tensors, size in zip(_fsdp_unsharded_tensor_list, origin_fsdp_size):
+            # Fast path: fuse chunks that are already back-to-back in storage into a zero-copy view
+            fused_tensor = self.fuse_contiguous_chunks_without_alloc(tensors)
+            if fused_tensor is not None and fused_tensor.shape[self.FSDP_SHARD_DIM] == size:
+                fsdp_unsharded_tensor_list.append(fused_tensor)
+                continue
+            elif fused_tensor is not None:
+                # discard the unusable view before falling back to concatenation
+                del fused_tensor
             tensor = torch.cat(tensors, dim=self.FSDP_SHARD_DIM)
             cat_tensor = torch.index_select(
                 tensor,
@@ -1157,6 +1160,48 @@ def _fsdp_foreach_allgather(
 
         return fsdp_unsharded_tensor_list
 
+    @staticmethod
+    def fuse_contiguous_chunks_without_alloc(tensors: list[torch.Tensor]) -> torch.Tensor | None:
+        """Fuse chunks that are adjacent in the same storage into one zero-copy view.
+
+        Return None if the chunks cannot be fused without a copy.
+        """
+        if not tensors:
+            return None
+        base = tensors[0]
+        storage = base.untyped_storage()
+        dtype = base.dtype
+        device = base.device
+        stride = base.stride()
+
+        inner_stride = stride[1:]
+        inner_elems = math.prod(base.shape[1:]) if base.dim() > 1 else 1
+
+        chunks = []
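+        # All chunks must share the base tensor's storage, dtype, device, and
+        # inner (non-leading-dim) layout; otherwise fusing is impossible.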
+        for t in tensors:
+            if (
+                t.untyped_storage().data_ptr() != storage.data_ptr()
+                or t.dtype != dtype
+                or t.device != device
+                or t.stride()[1:] != inner_stride
+            ):
+                return None
+            chunks.append((t.storage_offset(), t.shape[0], t))
+        chunks.sort(key=lambda x: x[0])
+
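+        # Chunks fuse only if, sorted by storage offset, each one starts
+        # exactly where the previous one ends (no gaps, no overlap).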
+        expected_offset = chunks[0][0]
+        total_rows = 0
+        for offset, rows, _ in chunks:
+            if offset != expected_offset:
+                return None
+            expected_offset += rows * inner_elems
+            total_rows += rows
+
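+        # Build a zero-copy view over the combined region of the shared storage.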
+        size = (total_rows, *base.shape[1:])
+        flat = torch.empty(0, dtype=dtype, device=device)
+        flat.set_(storage, chunks[0][0], size, stride)
+        return flat
+
     def _maybe_compile_layers(self):
         if self.fsdp_config is not None:
             if self.fsdp_config.torch_compile: