From 25e73215d672f178607b61a698881c2b993b3da7 Mon Sep 17 00:00:00 2001
From: Rohan Varma
Date: Mon, 26 Sep 2022 23:29:26 +0000
Subject: [PATCH] HSGD

---
 config/base_config.py |   4 +-
 config/t5_config.py   |   6 +++
 main_training.py      | 108 +++++++++++++++++------------------------
 run_training.sh       |   2 +-
 4 files changed, 52 insertions(+), 68 deletions(-)
 mode change 100644 => 100755 run_training.sh

diff --git a/config/base_config.py b/config/base_config.py
index 42402b8..af4dd5d 100644
--- a/config/base_config.py
+++ b/config/base_config.py
@@ -68,14 +68,14 @@ class base_config:
     num_workers_dataloader: int = 2

     # training
-    batch_size_training: int = 16
+    batch_size_training: int = 2

     # activation checkpointing
     fsdp_activation_checkpointing: bool = True

     # validation
     run_validation: bool = False
-    val_batch_size = 18
+    val_batch_size = 1

     # logging
     track_memory = True
diff --git a/config/t5_config.py b/config/t5_config.py
index ccd226b..e8ea88b 100644
--- a/config/t5_config.py
+++ b/config/t5_config.py
@@ -249,8 +249,14 @@ def train(
         )
         loss = output["loss"]
         loss.backward()
+        assert optimizer
         if optimizer:
             optimizer.step()
+            if hasattr(model, '_averager'):
+                print(" -- averaging --")
+                model._averager.average_parameters(model.parameters())
+            else:
+                print(" --- NOT averaging --")

         if local_rank == 0:
             inner_pbar.update(1)
diff --git a/main_training.py b/main_training.py
index a9476f5..e5d42c2 100644
--- a/main_training.py
+++ b/main_training.py
@@ -14,6 +14,9 @@
     MixedPrecision,
     StateDictType,
 )
+from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import PostLocalSGDState, post_localSGD_hook
+import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD
+from torch.nn.parallel import DistributedDataParallel as DDP

 import model_checkpointing

@@ -54,10 +57,10 @@ def setup_environ_flags(cfg, rank):
         os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
     if cfg.nccl_debug_handler:
         os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
-    if cfg.distributed_debug:
-        os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
-        if rank == 0:
-            print(f"--> running with torch dist debug set to detail")
+    # if cfg.distributed_debug:
+    #     os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
+    #     if rank == 0:
+    #         print(f"--> running with torch dist debug set to detail")


 def cleanup():
@@ -176,22 +179,34 @@ def fsdp_main():
         if rank == 0:
             print(f"--> Model converted to BF16.\nRunning in ** PURE ** BFloat mode")

-    # ----- main FSDP init -----------
-    model = FSDP(
-        model,
-        auto_wrap_policy=my_auto_wrap_policy,
-        mixed_precision=mp_policy,
-        backward_prefetch=prefetch_policy,
-        sharding_strategy=cfg.sharding_strategy,
-        device_id=torch.cuda.current_device(),
-        forward_prefetch=cfg.forward_prefetch,
-        limit_all_gathers=True,
-    )
-
-    if cfg.fsdp_activation_checkpointing:
-        config.fsdp_checkpointing(model)
-        if rank == 0:
-            print(f"--> FSDP activation checkpointing in use")
+    model = model.cuda()
+    model = DDP(model, device_ids=[torch.cuda.current_device()])
+    # Register a post-localSGD communication hook.
+    subgroup, _ = dist.new_subgroups()
+    print(f"WS of subgroups: {dist.get_world_size(subgroup)}")
+    state = PostLocalSGDState(process_group=None, subgroup=subgroup, start_localSGD_iter=100)
+    model.register_comm_hook(state, post_localSGD_hook)
+    from collections import OrderedDict
+    d = OrderedDict([(1, 8)])
+    averager = hierarchicalSGD.HierarchicalModelAverager(period_group_size_dict=d, warmup_steps=10)
+    model._averager = averager
+    # >>> # Register a post-localSGD communication hook.
+    # >>> # Assume that each machine has 4 GPUs, then each intra-machine subgroup has a size of 4.
+    # >>> subgroup, _ = dist.new_subgroups()
+    # >>> state = PostLocalSGDState(subgroup=subgroup, start_localSGD_iter=100)
+    # >>> model.register_comm_hook(state, post_localSGD_hook)
+    # >>>
+    # >>> # Average parameters among each group of 8 processes every 4 iterations, and among all
+    # >>> # the 16 processes every 16 iterations.
+    # >>> averager = hierarchicalSGD.HierarchicalModelAverager(
+    # >>>     period_group_size_dict=OrderedDict([(4, 8), (16, 16)]), warmup_steps=100)
+    # Use Hierarchical SGD
+
+
+#     if cfg.fsdp_activation_checkpointing:
+#         config.fsdp_checkpointing(model)
+#         if rank == 0:
+#             print(f"--> FSDP activation checkpointing in use")

     # print sharding plan?
     if rank == 0 and cfg.print_sharding_plan:
@@ -212,7 +227,7 @@ def fsdp_main():
     data_loader = torch.utils.data.DataLoader(
         dataset,
         batch_size=cfg.batch_size_training,
-        num_workers=cfg.num_workers_dataloader,
+        num_workers=1,
         pin_memory=False,
         sampler=train_sampler,
     )
@@ -221,7 +236,7 @@ def fsdp_main():
         val_loader = torch.utils.data.DataLoader(
             val_dataset,
             batch_size=cfg.val_batch_size,
-            num_workers=cfg.num_workers_dataloader,
+            num_workers=1,
             pin_memory=False,
             sampler=val_sampler,
         )
@@ -234,50 +249,13 @@ def fsdp_main():
     else:
         tracking_duration = None

-    # warmup, this is only used in the non-recursive ParamExecOrderPolicy
-    config.train(
-        model, data_loader, None, None, memmax, local_rank, tracking_duration, 1
-    )
-    if rank == 0:
-        print("Finish warm up")
-    model.zero_grad()
-
-    # optimizer ----------
-    optimizer = None
     lr = 8e-4
     weight_decay = 0.0
-
-    if cfg.optimizer == "int8":
-        import bitsandbytes as bnb
-
-        optimizer = bnb.optim.Adam8bit(
-            model.parameters(), lr=lr, weight_decay=weight_decay, amsgrad=False
-        )
-        if rank == 0:
-            print(f"Running with 8 bit optimizer")
-
-    elif cfg.optimizer == "AnyPrecision":
-        import optimizers
-
-        optimizer = optimizers.AnyPrecisionAdamW(
-            model.parameters(),
-            lr=lr,
-            weight_decay=weight_decay,
-            momentum_dtype=cfg.ap_momentum_dtype,
-            variance_dtype=cfg.ap_variance_dtype,
-            use_kahan_summation=cfg.ap_use_kahan_summation,
-        )
-        if rank == 0:
-            print(
-                f"Running with AnyPrecision Optimizer, momo={cfg.ap_momentum_dtype}, var = {cfg.ap_variance_dtype}, kahan summation = {cfg.ap_use_kahan_summation}"
-            )
-
-    else:
-        optimizer = torch.optim.AdamW(
-            model.parameters(), lr=lr, weight_decay=weight_decay, amsgrad=False
-        )
-        if rank == 0:
-            print(f"Running with AdamW optimizer")
+    optimizer = torch.optim.AdamW(
+        model.parameters(), lr=lr, weight_decay=weight_decay, amsgrad=False
+    )
+    if rank == 0:
+        print(f"Running with AdamW optimizer")

     # optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

@@ -286,7 +264,7 @@ if (
         model_checkpointing.load_optimizer_checkpoint(model, optimizer, rank, cfg)

     torch_profiler = None
-    if cfg.run_profiler and rank == 0:
+    if cfg.run_profiler and rank == 0 and False:
         print(f"Profiling active.  Traces will be saved at {cfg.profile_folder}")

         with torch.profiler.profile(
diff --git a/run_training.sh b/run_training.sh
old mode 100644
new mode 100755
index 62e445a..8280432
--- a/run_training.sh
+++ b/run_training.sh
@@ -1 +1 @@
-torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=101 --rdzv_endpoint="localhost:5970" main_training.py --model deepvit
+torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=101 --rdzv_endpoint="localhost:5970" main_training.py --model t5
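Note on the hierarchical-SGD wiring in this patch: DDP's post_localSGD_hook keeps doing a global gradient all-reduce until start_localSGD_iter and then restricts it to each subgroup, while HierarchicalModelAverager periodically averages parameters after the optimizer step. The standalone sketch below shows the same pattern end to end on a toy model; it assumes an 8-rank single-node torchrun launch, and the Linear model, loss, loop length, and LOCAL_RANK handling are illustrative placeholders rather than code from this patch. Note that HierarchicalModelAverager expects the last group size in period_group_size_dict to match the overall world size, which is why (1, 8) lines up with the 8-GPU run in run_training.sh.

# hsgd_sketch.py -- minimal post-localSGD + hierarchical averaging example (illustrative only).
# Launch: torchrun --nnodes=1 --nproc_per_node=8 hsgd_sketch.py
import os
from collections import OrderedDict

import torch
import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD
from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
    PostLocalSGDState,
    post_localSGD_hook,
)
from torch.nn.parallel import DistributedDataParallel as DDP


def main():
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Toy model stands in for the T5 model used by the patch.
    model = DDP(torch.nn.Linear(32, 32).cuda(), device_ids=[local_rank])

    # Gradients are all-reduced globally until start_localSGD_iter,
    # then only within each intra-machine subgroup created by new_subgroups().
    subgroup, _ = dist.new_subgroups()
    state = PostLocalSGDState(process_group=None, subgroup=subgroup, start_localSGD_iter=100)
    model.register_comm_hook(state, post_localSGD_hook)

    # After 10 warmup steps, average parameters across groups of 8 ranks on every step.
    # The last group size in period_group_size_dict must equal the world size (8 here).
    averager = hierarchicalSGD.HierarchicalModelAverager(
        period_group_size_dict=OrderedDict([(1, 8)]), warmup_steps=10
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=8e-4)
    for _ in range(200):
        inputs = torch.randn(4, 32, device="cuda")
        loss = model(inputs).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Parameter averaging runs after the local optimizer step, as in the patched train() loop.
        averager.average_parameters(model.parameters())

    dist.destroy_process_group()


if __name__ == "__main__":
    main()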