diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index bbb6e123..d73b0c86 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -29,7 +29,6 @@ #include "example/common/tiny_shakespeare_dataset.h" #include "example/common/tokenizer.h" -#include "example/common/utils.h" #include "example/gpt2/net.h" // I/O @@ -128,14 +127,14 @@ void Train(const nn::parallel::Rank &rank) { device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank()); if (ddp_world_size > 1) { - ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.thread_rank()), - GetDataParallelGroupRanks(rank.thread_rank())); + ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()), + GetDataParallelGroupRanks(rank.GlobalRank())); ddp_rank = ddp_pg->GetGroupRank(rank.thread_rank()); } if (tp_world_size > 1) { - tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.thread_rank()), - GetTensorParallelGroupRanks(rank.thread_rank())); + tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()), + GetTensorParallelGroupRanks(rank.GlobalRank())); tp_rank = tp_pg->GetGroupRank(rank.thread_rank()); // NOTE(zbl): Reserved for VocabParallelEmbedding nn::parallel::tp_rank = tp_rank; @@ -146,7 +145,7 @@ void Train(const nn::parallel::Rank &rank) { GetPipelineParallelProcessGroupName(rank.thread_rank()), GetPipelineParallelGroupRanks(pp_world_size)); pp_rank = pp_pg->GetGroupRank(rank.thread_rank()); - nn::parallel::pp_rank = pp_rank; + nn::parallel::pp_rank_tls = pp_rank; } } else { device = FLAGS_device == kDeviceCPU ? DeviceManager::Instance()->GetDefaultDevice() @@ -232,7 +231,8 @@ void Train(const nn::parallel::Rank &rank) { auto shapes = std::vector>{{FLAGS_batch_size, FLAGS_sequence_length, model_config.n_embd}}; model = std::make_shared(model, pp_world_size, num_micro_batches, shapes, - pp_rank, std::make_shared(optimizer)); + pp_rank, std::make_shared(optimizer), + rank.thread_rank()); } LOG(INFO) << "start training"; @@ -321,7 +321,7 @@ void Train(const nn::parallel::Rank &rank) { const double duration_us = std::chrono::duration(iter_end - iter_start).count(); const double tps = FLAGS_total_batch_size / (duration_us / 1e6); - if (rank.thread_rank() == pp_world_size - 1) { + if (rank.GlobalRank() == pp_world_size - 1) { LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, " "DP={}, TP={}, SP={}, PP={})", step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, @@ -340,6 +340,10 @@ void Train(const nn::parallel::Rank &rank) { Profiler::Instance().Report("gpt2.report", Profiler::SortBy::DeviceTimePercentage); Profiler::Instance().PrintRecords("gpt2.records.log"); #endif + + if (pp_world_size > 1 && rank.IsMainRank()) { + pp_pg->Barrier(); + } } int main(int argc, char *argv[]) { @@ -355,7 +359,7 @@ int main(int argc, char *argv[]) { if (FLAGS_nthread_per_process > 1) { std::vector threads; for (int idx = 0; idx < FLAGS_nthread_per_process; ++idx) { - nn::parallel::Rank rank(nn::parallel::global::GetLocalProcRank(), idx, + nn::parallel::Rank rank(nn::parallel::global::GetGlobalProcRank(), idx, nn::parallel::global::GetNprocPerNode(), FLAGS_nthread_per_process); threads.emplace_back(Train, rank); } diff --git a/example/gpt2/net.cc b/example/gpt2/net.cc index 259439c0..8217cda6 100644 --- a/example/gpt2/net.cc +++ b/example/gpt2/net.cc @@ -179,7 +179,7 @@ 
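The hunks above switch process-group lookup from the per-process thread rank to the global rank. A minimal standalone sketch of that mapping follows; the GlobalRank() formula matches the one added to Rank in this patch, while the modulo group-id math is an illustrative assumption based on the DP -> TP -> PP ordering shown in the ProcessGroupOverview example later in the patch.

#include <string>

// Sketch: flatten (process_rank, thread_rank) into a global rank and derive a
// data-parallel group name from it.
struct RankInfo {
    int process_rank = 0; // global process index across all nodes
    int thread_rank = 0;  // worker-thread (GPU) index inside the process
    int thread_size = 1;  // threads per process
};

// Same formula as Rank::GlobalRank() introduced by this patch.
inline int GlobalRank(const RankInfo &r) { return r.process_rank * r.thread_size + r.thread_rank; }

inline std::string DataParallelGroupName(int global_rank, int tp_size, int pp_size) {
    const int dp_group_id = global_rank % (tp_size * pp_size); // ranks sharing (tp, pp) form one DP group
    return "DP" + std::to_string(dp_group_id);
}

// e.g. process 1 of a 4-GPU-per-process job, thread 2 -> global rank 6;
// with TP=4, PP=1 that rank falls into group "DP2".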
Block::Forward(const std::vector> &x) { GPT2::GPT2(const GPT2Config &config) : config_(config) { int pp_size = nn::parallel::global::GetPipelineParallelSize(); auto [is_first_stage, is_last_stage, start_layer, end_layer] - = nn::parallel::PipelineParallel::GetStageInfo(config_.n_layer, pp_size); + = nn::parallel::PipelineParallel::GetStageInfo(config_.n_layer, pp_size, nn::parallel::pp_rank_tls); auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); @@ -230,7 +230,7 @@ GPT2::GPT2(const GPT2Config &config) : config_(config) { std::vector> GPT2::Forward(const std::vector> &x) { - int pp_rank = nn::parallel::pp_rank; + int pp_rank = nn::parallel::pp_rank_tls; int pp_size = nn::parallel::global::GetPipelineParallelSize(); bool is_first_stage = (pp_rank == 0); bool is_last_stage = (pp_rank == pp_size - 1); @@ -353,7 +353,7 @@ std::shared_ptr GPT2::FromLLMC(const std::string &filepath) { int pp_size = nn::parallel::global::GetPipelineParallelSize(); auto [is_first_stage, is_last_stage, start_layer, end_layer] - = nn::parallel::PipelineParallel::GetStageInfo(n_layer, pp_size); + = nn::parallel::PipelineParallel::GetStageInfo(n_layer, pp_size, nn::parallel::pp_rank_tls); auto tp_rank = nn::parallel::tp_rank; // calculate xx_size_per_partition diff --git a/example/llama3/main.cc b/example/llama3/main.cc index e67fbf75..2aa59730 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -28,7 +28,6 @@ #include "example/common/tiny_shakespeare_dataset.h" #include "example/common/tokenizer.h" -#include "example/common/utils.h" #include "example/llama3/net.h" // I/O @@ -110,25 +109,25 @@ void Train(const nn::parallel::Rank &rank) { device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank()); if (ddp_world_size > 1) { - ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.thread_rank()), - GetDataParallelGroupRanks(rank.thread_rank())); + ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()), + GetDataParallelGroupRanks(rank.GlobalRank())); ddp_rank = ddp_pg->GetGroupRank(rank.thread_rank()); } if (tp_world_size > 1) { - tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.thread_rank()), - GetTensorParallelGroupRanks(rank.thread_rank())); + tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()), + GetTensorParallelGroupRanks(rank.GlobalRank())); tp_rank = tp_pg->GetGroupRank(rank.thread_rank()); // NOTE(zbl): Reserved for VocabParallelEmbedding nn::parallel::tp_rank = tp_rank; } if (pp_world_size > 1) { - pp_pg = ProcessGroupFactory::Instance()->GetOrCreate( - GetPipelineParallelProcessGroupName(rank.thread_rank()), GetPipelineParallelGroupRanks(pp_world_size)); + pp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()), + GetPipelineParallelGroupRanks(rank.GlobalRank())); pp_rank = pp_pg->GetGroupRank(rank.thread_rank()); - nn::parallel::pp_rank = pp_rank; + nn::parallel::pp_rank_tls = pp_rank; } } else { device = FLAGS_device == kDeviceCPU ? 
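GetStageInfo() now receives the pipeline rank explicitly instead of reading the thread-local. The partition body itself is not part of these hunks, so the even-split-with-remainder below is only a plausible sketch of a (total_layers, pp_size, pp_rank) -> (is_first, is_last, start, end) mapping, not the repository's implementation.

#include <algorithm>
#include <tuple>

// Sketch: one plausible layer partition across pipeline stages (assumption).
std::tuple<bool, bool, int, int> GetStageInfoSketch(int total_layers, int pp_size, int pp_rank) {
    const bool is_first_stage = (pp_rank == 0);
    const bool is_last_stage = (pp_rank == pp_size - 1);

    const int base = total_layers / pp_size;      // layers every stage gets
    const int remainder = total_layers % pp_size; // extra layers go to the first `remainder` stages
    const int start_layer = pp_rank * base + std::min(pp_rank, remainder);
    const int end_layer = start_layer + base + (pp_rank < remainder ? 1 : 0); // exclusive

    return {is_first_stage, is_last_stage, start_layer, end_layer};
}

// e.g. 12 layers on 4 stages -> [0,3), [3,6), [6,9), [9,12); stage 3 is the last stage.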
DeviceManager::Instance()->GetDefaultDevice() @@ -210,8 +209,9 @@ void Train(const nn::parallel::Rank &rank) { if (pp_world_size > 1) { auto shapes = std::vector>{{FLAGS_batch_size, FLAGS_sequence_length, model_config.n_embd}}; - model = std::make_shared( - model, pp_world_size, num_micro_batches, shapes, pp_rank, std::make_shared(optimizer)); + model = std::make_shared(model, pp_world_size, num_micro_batches, shapes, + pp_rank, std::make_shared(optimizer), + rank.thread_rank()); } for (int step = 0; step < FLAGS_num_iteration + 1; ++step) { @@ -274,6 +274,8 @@ void Train(const nn::parallel::Rank &rank) { LOG(INFO) << "Rank " << rank.thread_rank() << ": finish loss forward"; if (ddp_world_size > 1) { + // FIXME(dcj): should only allreduce lossf, not the entire loss tensor + // FIXME(dcj): should use ddp_pg function::AllReduce(loss, function::ReduceOpType::kAvg); } auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice()); @@ -293,13 +295,25 @@ void Train(const nn::parallel::Rank &rank) { y = std::make_shared(y->To(device)); lossf = model->TrainStep({x}, {y}, loss_fn); + + // FIXME(dcj): refactor this logic into a separate function + if (ddp_world_size > 1) { + auto loss_tensor = std::make_shared(std::vector{}, DataType::kFLOAT32); + static_cast(loss_tensor->DataPtr())[0] = lossf; + auto loss_device_ptr = std::make_shared(loss_tensor->To(device)); + function::AllReduce(loss_device_ptr, function::ReduceOpType::kAvg, ddp_pg); + lossf = static_cast( + loss_device_ptr->To(DeviceManager::Instance()->GetDefaultDevice()).DataPtr())[0]; + } } const auto iter_end = std::chrono::high_resolution_clock::now(); const double duration_us = std::chrono::duration(iter_end - iter_start).count(); const double tps = FLAGS_total_batch_size / (duration_us / 1e6); - if (rank.thread_rank() == pp_world_size - 1) { + if ((pp_world_size == 1 && rank.IsMainRank()) + || (global::GetGroupId(global::PP, rank.GlobalRank()) == 0 && pp_world_size > 1 + && pp_rank == pp_world_size - 1)) { LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, " "DP={}, TP={}, SP={}, PP={})", step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, @@ -318,6 +332,10 @@ void Train(const nn::parallel::Rank &rank) { Profiler::Instance().Report("llama3.report", Profiler::SortBy::DeviceTimePercentage); Profiler::Instance().PrintRecords("llama3.records.log"); #endif + + if (pp_world_size > 1 && rank.IsMainRank()) { + pp_pg->Barrier(); + } } int main(int argc, char *argv[]) { @@ -333,7 +351,7 @@ int main(int argc, char *argv[]) { if (FLAGS_nthread_per_process > 1) { std::vector threads; for (int idx = 0; idx < FLAGS_nthread_per_process; ++idx) { - nn::parallel::Rank rank(nn::parallel::global::GetLocalProcRank(), idx, + nn::parallel::Rank rank(nn::parallel::global::GetGlobalProcRank(), idx, nn::parallel::global::GetNprocPerNode(), FLAGS_nthread_per_process); threads.emplace_back(Train, rank); } diff --git a/example/llama3/net.cc b/example/llama3/net.cc index 3e7becfa..107d19df 100644 --- a/example/llama3/net.cc +++ b/example/llama3/net.cc @@ -328,7 +328,7 @@ std::vector> Block::Forward(const std::vector> transformer; if (is_first_stage) { @@ -356,7 +356,7 @@ LLaMA3::LLaMA3(const LLaMA3Config &config) : config_(config) { } std::vector> LLaMA3::Forward(const std::vector> &x) { - int pp_rank = nn::parallel::pp_rank; + int pp_rank = nn::parallel::pp_rank_tls; int pp_size = nn::parallel::global::GetPipelineParallelSize(); bool is_first_stage = (pp_rank == 0); bool 
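The DP branch above averages the scalar loss by packing it into a one-element tensor, moving it to the device, and all-reducing with kAvg on ddp_pg. The same pattern expressed directly against NCCL, assuming the comm and stream come from the DP process group:

#include <cuda_runtime.h>
#include <nccl.h>

// Sketch: average a host-side scalar loss across data-parallel ranks.
float AllReduceAverageLoss(float local_loss, ncclComm_t dp_comm, cudaStream_t stream) {
    float *device_loss = nullptr;
    cudaMalloc(&device_loss, sizeof(float));
    cudaMemcpy(device_loss, &local_loss, sizeof(float), cudaMemcpyHostToDevice);

    // ncclAvg needs NCCL >= 2.10; older versions would use ncclSum and divide by the DP world size.
    ncclAllReduce(device_loss, device_loss, 1, ncclFloat32, ncclAvg, dp_comm, stream);
    cudaStreamSynchronize(stream);

    float averaged = 0.0f;
    cudaMemcpy(&averaged, device_loss, sizeof(float), cudaMemcpyDeviceToHost);
    cudaFree(device_loss);
    return averaged;
}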
is_last_stage = (pp_rank == pp_size - 1); @@ -467,7 +467,7 @@ std::shared_ptr LLaMA3::FromLLMC(const std::string &filepath) { .max_gen_batch_size = max_gen_bs}); int pp_size = nn::parallel::global::GetPipelineParallelSize(); auto [is_first_stage, is_last_stage, start_layer, end_layer] - = nn::parallel::PipelineParallel::GetStageInfo(n_layer, pp_size); + = nn::parallel::PipelineParallel::GetStageInfo(n_layer, pp_size, nn::parallel::pp_rank_tls); const int tp_size = nn::parallel::global::GetTensorParallelSize(); const int tp_rank = nn::parallel::tp_rank; diff --git a/infini_train/include/nn/parallel/global.h b/infini_train/include/nn/parallel/global.h index 1a6e22fd..c07b62e8 100644 --- a/infini_train/include/nn/parallel/global.h +++ b/infini_train/include/nn/parallel/global.h @@ -29,14 +29,16 @@ class GlobalEnv { void Init(int threads_per_process, int tensor_parallel_size, bool sequence_parallel_enabled, int pipeline_parallel_size); + int nnodes() const; + + int nproc_per_node() const; + int world_size() const; int global_proc_rank() const; int local_proc_rank() const; - int nproc_per_node() const; - int nthread_per_process() const; int tensor_parallel_size() const; @@ -57,9 +59,11 @@ class GlobalEnv { GlobalEnv &operator=(const GlobalEnv &) = delete; private: - int world_size_ = 1; + int nnodes_ = 1; int nproc_per_node_ = 1; int nthread_per_process_ = 1; + int world_size_ = 1; + int global_proc_rank_ = 0; int local_proc_rank_ = 0; @@ -82,6 +86,7 @@ inline void InitAllEnv(int nthread_per_process, int tensor_parallel_size, bool s pipeline_parallel_size); } +inline int GetNnodes() { return GlobalEnv::Instance().nnodes(); } inline int GetWorldSize() { return GlobalEnv::Instance().world_size(); } inline int GetNprocPerNode() { return GlobalEnv::Instance().nproc_per_node(); } inline int GetNthreadPerProc() { return GlobalEnv::Instance().nthread_per_process(); } @@ -93,28 +98,82 @@ inline bool GetSequenceParallelEnabled() { return GlobalEnv::Instance().sequence inline int GetDataParallelSize() { return GlobalEnv::Instance().data_parallel_size(); } inline int GetPipelineParallelSize() { return GlobalEnv::Instance().pipeline_parallel_size(); } +// ========================= // Layout Helper Functions +// ========================= + +/** + * @brief Get the global rank corresponding to the given (dp, tp, pp) coordinate. + */ inline int GetRankOf(int dp, int tp, int pp) { return GlobalEnv::Instance().layout().RankOf(dp, tp, pp); } +/** + * @brief Get the (dp, tp, pp) coordinate corresponding to the given global rank. + */ inline void GetCoordOf(int rank, int &dp, int &tp, int &pp) { return GlobalEnv::Instance().layout().CoordOf(rank, dp, tp, pp); } + +/** + * @brief Get the group ID that the (dp, tp, pp) coordinate belongs to along a given parallel axis. + */ inline int GetGroupId(Axis target, int dp, int tp, int pp) { return GlobalEnv::Instance().layout().GroupId(target, dp, tp, pp); } +/** + * @brief Get the group ID that a given rank belongs to along a specific parallel axis. + */ inline int GetGroupId(Axis target, int rank) { int dp, tp, pp; GetCoordOf(rank, dp, tp, pp); return GlobalEnv::Instance().layout().GroupId(target, dp, tp, pp); } + +/** + * @brief Get all ranks that belong to the same group as the given (dp, tp, pp) coordinate + * along a specified parallel axis (e.g., all ranks in the same TP group). 
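These helpers forward to a Layout that stores per-axis sizes and strides. A minimal standalone sketch of the rank/coordinate/group arithmetic, assuming the fixed DP -> TP -> PP ordering used in the ProcessGroupOverview example quoted below (the in-tree Layout carries its own configurable axis order):

// Sketch: rank <-> (dp, tp, pp) arithmetic for a fixed DP -> TP -> PP order.
struct LayoutSketch {
    int dp_size = 1, tp_size = 1, pp_size = 1;

    int RankOf(int dp, int tp, int pp) const { return (dp * tp_size + tp) * pp_size + pp; }

    void CoordOf(int rank, int &dp, int &tp, int &pp) const {
        pp = rank % pp_size;
        tp = (rank / pp_size) % tp_size;
        dp = rank / (pp_size * tp_size);
    }

    // The group id along one axis is the flattened index over the remaining axes.
    int DpGroupId(int tp, int pp) const { return tp * pp_size + pp; }
    int TpGroupId(int dp, int pp) const { return dp * pp_size + pp; }
    int PpGroupId(int dp, int tp) const { return dp * tp_size + tp; }
};

// With {DP=2, TP=4, PP=1}: RankOf(1, 2, 0) == 6, CoordOf(6) -> (dp=1, tp=2, pp=0),
// and DpGroupId(2, 0) == 2, i.e. DP group [2, 6], matching the overview example in this header.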
+ */ inline std::vector GetGroupRanks(Axis target, int dp, int tp, int pp) { return GlobalEnv::Instance().layout().GroupRanks(target, dp, tp, pp); } + +/** + * @brief Get all ranks that belong to the same group as the given rank + * along a specified parallel axis (e.g., all ranks in the same DP group). + */ inline std::vector GetGroupRanks(Axis target, int rank) { int dp, tp, pp; GetCoordOf(rank, dp, tp, pp); return GlobalEnv::Instance().layout().GroupRanks(target, dp, tp, pp); } +/** + * @brief Generate a human-readable overview of all parallel communication groups. + * + * The output is intended for debugging, logging, and runtime verification of + * distributed parallelism configuration. + * + * @param L The Layout describing DP / TP / PP sizes and axis ordering. + * @param skip_trivial_axes + * If true, axes whose size <= 1(i.e. parallel strategy that is not enabled) + * will be marked as "unenabled" and their detailed group listing will be skipped. + * + * @return A formatted string containing the full overview of process groups. + * + * Example: + * === Parallel Communication Groups === + * world_size = 8, config: {DP=2, TP=4, PP=1}, order: {DP -> TP -> PP} + * [DP] size=2, num_groups=4 + * - DP 0 (dp=-, tp=0, pp=0): [0, 4] + * - DP 1 (dp=-, tp=1, pp=0): [1, 5] + * - DP 2 (dp=-, tp=2, pp=0): [2, 6] + * - DP 3 (dp=-, tp=3, pp=0): [3, 7] + * + * [TP] size=4, num_groups=2 + * - TP 0 (dp=0, tp=-, pp=0): [0, 1, 2, 3] + * - TP 1 (dp=1, tp=-, pp=0): [4, 5, 6, 7] + * + * [PP] size=1, unenabled + */ std::string ProcessGroupOverview(const Layout &L = GlobalEnv::Instance().layout(), bool skip_trivial_axes = true); } // namespace infini_train::nn::parallel::global diff --git a/infini_train/include/nn/parallel/pp/pipeline_parallel.h b/infini_train/include/nn/parallel/pp/pipeline_parallel.h index 4a09d519..fa605edf 100644 --- a/infini_train/include/nn/parallel/pp/pipeline_parallel.h +++ b/infini_train/include/nn/parallel/pp/pipeline_parallel.h @@ -16,18 +16,18 @@ namespace infini_train::nn::parallel { class PipelineStage; class PipelineSchedule; -extern thread_local int pp_rank; +extern thread_local int pp_rank_tls; class PipelineParallel : public Module { public: PipelineParallel(const std::shared_ptr module, int num_stages, int num_micro_batches, const std::vector> &recv_shape, int rank, - const std::shared_ptr &optimizer); + const std::shared_ptr &optimizer, int device_id); float TrainStep(const std::vector> &input, const std::vector> &target, const std::shared_ptr &loss_fn); - static std::tuple GetStageInfo(int total_layers, int pp_size); + static std::tuple GetStageInfo(int total_layers, int pp_size, int pp_rank); private: int num_stages_ = -1; @@ -36,7 +36,7 @@ class PipelineParallel : public Module { std::shared_ptr schedule_ = nullptr; void BuildPipelineStage(const std::shared_ptr &model, const std::shared_ptr &optimizer, - const std::vector> &recv_shape); + const std::vector> &recv_shape, int device_id); void SetupSchedule(int num_micro_batches); }; diff --git a/infini_train/include/nn/parallel/pp/pipeline_schedule.h b/infini_train/include/nn/parallel/pp/pipeline_schedule.h index de24ba3e..50c44ba1 100644 --- a/infini_train/include/nn/parallel/pp/pipeline_schedule.h +++ b/infini_train/include/nn/parallel/pp/pipeline_schedule.h @@ -16,8 +16,8 @@ class PipelineStage; class PipelineSchedule { public: - PipelineSchedule(std::shared_ptr stage, int num_stages, int num_micro_batches, int stage_index) - : stage_(std::move(stage)), num_micro_batches_(num_micro_batches), stage_index_(stage_index) 
{} + PipelineSchedule(std::shared_ptr stage, int num_stages, int num_micro_batches) + : stage_(std::move(stage)), num_micro_batches_(num_micro_batches) {} virtual ~PipelineSchedule() = default; @@ -34,14 +34,13 @@ class PipelineSchedule { protected: int num_micro_batches_ = -1; - int stage_index_ = -1; std::shared_ptr stage_ = nullptr; }; class ScheduleGPipe : public PipelineSchedule { public: - ScheduleGPipe(std::shared_ptr stage, int num_stages, int num_micro_batches, int stage_index) - : PipelineSchedule(std::move(stage), num_stages, num_micro_batches, stage_index){}; + ScheduleGPipe(std::shared_ptr stage, int num_stages, int num_micro_batches) + : PipelineSchedule(std::move(stage), num_stages, num_micro_batches){}; float StepMicroBatches(const std::vector> &arg_mbs, const std::vector> &target_mbs, @@ -50,8 +49,8 @@ class ScheduleGPipe : public PipelineSchedule { class Schedule1F1B : public PipelineSchedule { public: - Schedule1F1B(std::shared_ptr stage, int num_stages, int num_micro_batches, int stage_index) - : PipelineSchedule(std::move(stage), num_stages, num_micro_batches, stage_index){}; + Schedule1F1B(std::shared_ptr stage, int num_stages, int num_micro_batches) + : PipelineSchedule(std::move(stage), num_stages, num_micro_batches){}; float StepMicroBatches(const std::vector> &arg_mbs, const std::vector> &target_mbs, diff --git a/infini_train/include/nn/parallel/pp/pipeline_stage.h b/infini_train/include/nn/parallel/pp/pipeline_stage.h index ac7df6f0..b7679d8c 100644 --- a/infini_train/include/nn/parallel/pp/pipeline_stage.h +++ b/infini_train/include/nn/parallel/pp/pipeline_stage.h @@ -17,7 +17,8 @@ namespace infini_train::nn::parallel { class PipelineStage { public: PipelineStage(const std::shared_ptr &model, int stage_index, int num_stages, - const std::vector> &recv_shape, std::shared_ptr optimizer); + const std::vector> &recv_shape, std::shared_ptr optimizer, + int device_id); std::vector> ForwardOneChunk(const std::vector> &inputs); diff --git a/infini_train/include/nn/parallel/process_group.h b/infini_train/include/nn/parallel/process_group.h index 5919e054..f807b3e3 100644 --- a/infini_train/include/nn/parallel/process_group.h +++ b/infini_train/include/nn/parallel/process_group.h @@ -1,8 +1,10 @@ #pragma once +#include #include #include #include +#include #include #include @@ -26,7 +28,8 @@ namespace infini_train::nn::parallel { #ifdef USE_NCCL class ProcessGroup { public: - explicit ProcessGroup(const std::vector &device_indices); + explicit ProcessGroup(const std::string &process_group_name, const std::vector &device_indices); + ~ProcessGroup(); int GetGroupRank(int thread_rank) const; @@ -52,6 +55,13 @@ class ProcessGroup { std::vector> NcclRecv(std::vector> tensors, int src_rank) const; + void Barrier() const; + +private: + void InitSingleProcess(const std::vector &ranks); + + void InitMultiProcess(const std::vector &ranks); + private: std::vector comms_; std::vector devices_; @@ -59,7 +69,11 @@ class ProcessGroup { std::unordered_map device_comm_map_; std::unordered_map thread_group_rank_map_; // thread_rank : group_rank - int comm_size_ = 0; + int world_size_ = 0; + + const std::string name_ = ""; + + bool is_main_process_ = false; }; #endif @@ -79,8 +93,29 @@ class ProcessGroupFactory { private: ProcessGroupFactory(); + + template >> + const ProcessGroup *GetOrCreate(const std::string &name, Creator &&creator) { + std::unique_lock lock(mutex_); + auto [it, inserted] = name_to_group_.emplace(name, nullptr); + if (!inserted) { + while (it->second == nullptr) { 
cond_.wait(lock); } + return it->second.get(); + } + + lock.unlock(); + auto new_group = creator(); + lock.lock(); + + it->second = std::move(new_group); + cond_.notify_all(); + return it->second.get(); + } + +private: // TODO(dcj): maybe RWLock later? mutable std::mutex mutex_; + std::condition_variable cond_; std::unordered_map> name_to_group_; }; } // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/rank.h b/infini_train/include/nn/parallel/rank.h index 56b76b97..c5d9e185 100644 --- a/infini_train/include/nn/parallel/rank.h +++ b/infini_train/include/nn/parallel/rank.h @@ -10,6 +10,8 @@ class Rank { int process_size() const; int thread_size() const; + int GlobalRank() const; + bool IsParallel() const; bool IsMainRank() const; diff --git a/infini_train/include/nn/parallel/utils.h b/infini_train/include/nn/parallel/utils.h index 3eb3960d..300cf0f2 100644 --- a/infini_train/include/nn/parallel/utils.h +++ b/infini_train/include/nn/parallel/utils.h @@ -4,15 +4,15 @@ #include namespace infini_train::nn::parallel { -std::string GetDataParallelProcessGroupName(int thread_rank); +std::string GetDataParallelProcessGroupName(int global_rank); -std::string GetTensorParallelProcessGroupName(int thread_rank); +std::string GetTensorParallelProcessGroupName(int global_rank); -std::string GetPipelineParallelProcessGroupName(int thread_rank); +std::string GetPipelineParallelProcessGroupName(int global_rank); -std::vector GetDataParallelGroupRanks(int rank); +std::vector GetDataParallelGroupRanks(int global_rank); -std::vector GetTensorParallelGroupRanks(int rank); +std::vector GetTensorParallelGroupRanks(int global_rank); -std::vector GetPipelineParallelGroupRanks(int pp_world_size); +std::vector GetPipelineParallelGroupRanks(int global_rank); } // namespace infini_train::nn::parallel diff --git a/infini_train/src/device.cc b/infini_train/src/device.cc index 99bfc52e..4271ff97 100644 --- a/infini_train/src/device.cc +++ b/infini_train/src/device.cc @@ -64,7 +64,7 @@ nn::parallel::Rank CudaDevice::rank() const { return rank_; } CudaDevice::CudaDevice(int8_t index) : Device(DeviceType::kCUDA, index), - rank_({nn::parallel::global::GetLocalProcRank(), index, nn::parallel::global::GetNprocPerNode(), + rank_({nn::parallel::global::GetGlobalProcRank(), index, nn::parallel::global::GetNprocPerNode(), nn::parallel::global::GetNthreadPerProc()}) { // TODO(dcj): make CudaDevice initialization lazy to avoid allocating memory on all GPUs in single-GPU mode SetDevice(); diff --git a/infini_train/src/nn/parallel/distributed_data_parallel.cc b/infini_train/src/nn/parallel/distributed_data_parallel.cc index 26ca62c5..b7fc5d91 100644 --- a/infini_train/src/nn/parallel/distributed_data_parallel.cc +++ b/infini_train/src/nn/parallel/distributed_data_parallel.cc @@ -23,7 +23,8 @@ DistributedDataParallel::DistributedDataParallel(std::shared_ptr mod CHECK_EQ(device->Index(), device_id) << "All parameters must be on the same device as the module"; auto ddp_pg - = ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(device->rank().thread_rank())); + = ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(device->rank().GlobalRank())); + // FIXME(dcj): use multi-node ddp_pg here auto hook = std::make_unique(function::ReduceOpType::kAvg, ddp_pg); param->RegisterPostAccumulateGradHook(std::move(hook)); diff --git a/infini_train/src/nn/parallel/global.cc b/infini_train/src/nn/parallel/global.cc index 8c97b225..527505a3 100644 --- 
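The templated GetOrCreate above lets the first caller construct the group outside the lock while later callers block on a condition variable until the pointer is published. A condensed standalone sketch of that pattern; Resource and Factory are placeholder names, not repository types:

#include <condition_variable>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

struct Resource { /* stands in for ProcessGroup */ };

class Factory {
public:
    const Resource *GetOrCreate(const std::string &name, const std::function<std::unique_ptr<Resource>()> &creator) {
        std::unique_lock lock(mutex_);
        const bool inserted = map_.emplace(name, nullptr).second;
        if (!inserted) {
            // Another thread owns creation; wait until it publishes the pointer.
            cond_.wait(lock, [&] { return map_.at(name) != nullptr; });
            return map_.at(name).get();
        }
        lock.unlock(); // build outside the lock: creation may block on a NCCL rendezvous
        auto created = creator();
        lock.lock();
        map_[name] = std::move(created);
        cond_.notify_all();
        return map_.at(name).get();
    }

private:
    std::mutex mutex_;
    std::condition_variable cond_;
    std::unordered_map<std::string, std::unique_ptr<Resource>> map_;
};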
a/infini_train/src/nn/parallel/global.cc +++ b/infini_train/src/nn/parallel/global.cc @@ -2,6 +2,7 @@ #include #include +#include #include #include "glog/logging.h" @@ -13,6 +14,11 @@ int GetEnvAsInt(const std::string &name, int default_value) { return value ? std::atoi(value) : default_value; } +std::string GetEnvAsStr(const std::string &name, const std::string &default_value) { + const char *value = std::getenv(name.c_str()); + return value ? std::string(value) : default_value; +} + } // namespace namespace infini_train::nn::parallel::global { @@ -90,8 +96,9 @@ void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool seq CHECK(!initialized_) << "Repeated initialization of GlobalEnv!"; - world_size_ = GetEnvAsInt("PROC_WORLD_SIZE", 1) * nthread_per_process; + nnodes_ = GetEnvAsInt("NNODES", 1); nproc_per_node_ = GetEnvAsInt("NPROC_PER_NODE", 1); + world_size_ = GetEnvAsInt("PROC_WORLD_SIZE", 1) * nthread_per_process; global_proc_rank_ = GetEnvAsInt("GLOBAL_PROC_RANK", 0); local_proc_rank_ = GetEnvAsInt("LOCAL_PROC_RANK", 0); @@ -104,36 +111,40 @@ void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool seq layout_.sizes[DP] = data_parallel_size_; layout_.sizes[TP] = tensor_parallel_size_; - // FIXME(zbl): set PP size - layout_.sizes[PP] = 1; + layout_.sizes[PP] = pipeline_parallel_size_; layout_.InitStrides(); initialized_ = true; } -int GlobalEnv::world_size() const { +int GlobalEnv::nnodes() const { CHECK(initialized_) << "GlobalEnv is not initialized!"; - return world_size_; + return nnodes_; } -int GlobalEnv::global_proc_rank() const { +int GlobalEnv::nproc_per_node() const { CHECK(initialized_) << "GlobalEnv is not initialized!"; - return global_proc_rank_; + return nproc_per_node_; } -int GlobalEnv::local_proc_rank() const { +int GlobalEnv::nthread_per_process() const { CHECK(initialized_) << "GlobalEnv is not initialized!"; - return local_proc_rank_; + return nthread_per_process_; } -int GlobalEnv::nproc_per_node() const { +int GlobalEnv::world_size() const { CHECK(initialized_) << "GlobalEnv is not initialized!"; - return nproc_per_node_; + return world_size_; } -int GlobalEnv::nthread_per_process() const { +int GlobalEnv::global_proc_rank() const { CHECK(initialized_) << "GlobalEnv is not initialized!"; - return nthread_per_process_; + return global_proc_rank_; +} + +int GlobalEnv::local_proc_rank() const { + CHECK(initialized_) << "GlobalEnv is not initialized!"; + return local_proc_rank_; } int GlobalEnv::tensor_parallel_size() const { @@ -208,34 +219,6 @@ inline void AppendAxisGroups(std::ostringstream &oss, const Layout &L, Axis targ } } -/** - * @brief Generate a human-readable overview of all parallel communication groups. - * - * The output is intended for debugging, logging, and runtime verification of - * distributed parallelism configuration. - * - * @param L The Layout describing DP / TP / PP sizes and axis ordering. - * @param skip_trivial_axes - * If true, axes whose size <= 1(i.e. parallel strategy that is not enabled) - * will be marked as "unenabled" and their detailed group listing will be skipped. - * - * @return A formatted string containing the full overview of process groups. 
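GlobalEnv::Init now derives nnodes_ and the world size from environment variables exported by the launcher. A trimmed sketch of that wiring, using the same variable names (NNODES, NPROC_PER_NODE, PROC_WORLD_SIZE, GLOBAL_PROC_RANK, LOCAL_PROC_RANK):

#include <cstdlib>
#include <string>

namespace {
int GetEnvAsInt(const std::string &name, int default_value) {
    const char *value = std::getenv(name.c_str());
    return value ? std::atoi(value) : default_value;
}
} // namespace

// Sketch: environment-driven process topology, mirroring GlobalEnv::Init.
struct EnvTopology {
    int nnodes = 1;
    int nproc_per_node = 1;
    int nthread_per_process = 1;
    int world_size = 1; // total worker threads (GPUs) across all processes
    int global_proc_rank = 0;
    int local_proc_rank = 0;
};

EnvTopology ReadTopology(int nthread_per_process) {
    EnvTopology t;
    t.nnodes = GetEnvAsInt("NNODES", 1);
    t.nproc_per_node = GetEnvAsInt("NPROC_PER_NODE", 1);
    t.nthread_per_process = nthread_per_process;
    t.world_size = GetEnvAsInt("PROC_WORLD_SIZE", 1) * nthread_per_process;
    t.global_proc_rank = GetEnvAsInt("GLOBAL_PROC_RANK", 0);
    t.local_proc_rank = GetEnvAsInt("LOCAL_PROC_RANK", 0);
    return t;
}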
- * - * Example: - * === Parallel Communication Groups === - * world_size = 8, config: {DP=2, TP=4, PP=1}, order: {DP -> TP -> PP} - * [DP] size=2, num_groups=4 - * - DP 0 (dp=-, tp=0, pp=0): [0, 4] - * - DP 1 (dp=-, tp=1, pp=0): [1, 5] - * - DP 2 (dp=-, tp=2, pp=0): [2, 6] - * - DP 3 (dp=-, tp=3, pp=0): [3, 7] - * - * [TP] size=4, num_groups=2 - * - TP 0 (dp=0, tp=-, pp=0): [0, 1, 2, 3] - * - TP 1 (dp=1, tp=-, pp=0): [4, 5, 6, 7] - * - * [PP] size=1, unenabled - */ std::string ProcessGroupOverview(const Layout &L, bool skip_trivial_axes) { std::ostringstream oss; oss << std::format("\n=== Parallel Communication Groups ===\n" diff --git a/infini_train/src/nn/parallel/pp/pipeline_parallel.cc b/infini_train/src/nn/parallel/pp/pipeline_parallel.cc index ef14f666..5d447662 100644 --- a/infini_train/src/nn/parallel/pp/pipeline_parallel.cc +++ b/infini_train/src/nn/parallel/pp/pipeline_parallel.cc @@ -15,16 +15,16 @@ namespace { constexpr char kModuleName[] = "module"; } // namespace -thread_local int pp_rank = 0; +thread_local int pp_rank_tls = 0; void PipelineParallel::BuildPipelineStage(const std::shared_ptr &module, const std::shared_ptr &optimizer, - const std::vector> &recv_shape) { - pipeline_stage_ = std::make_shared(module, rank_, num_stages_, recv_shape, optimizer); + const std::vector> &recv_shape, int device_id) { + pipeline_stage_ = std::make_shared(module, rank_, num_stages_, recv_shape, optimizer, device_id); } void PipelineParallel::SetupSchedule(int num_micro_batches) { - schedule_ = std::make_shared(pipeline_stage_, num_stages_, num_micro_batches, rank_); + schedule_ = std::make_shared(pipeline_stage_, num_stages_, num_micro_batches); } float PipelineParallel::TrainStep(const std::vector> &input, @@ -39,8 +39,7 @@ float PipelineParallel::TrainStep(const std::vector> &in return schedule_->Step(stage_input, stage_target, loss_fn); } -std::tuple PipelineParallel::GetStageInfo(int total_layers, int pp_size) { - int rank = pp_rank; +std::tuple PipelineParallel::GetStageInfo(int total_layers, int pp_size, int pp_rank) { bool is_first_stage = (pp_rank == 0); bool is_last_stage = (pp_rank == pp_size - 1); @@ -59,12 +58,12 @@ std::tuple PipelineParallel::GetStageInfo(int total_layers } PipelineParallel::PipelineParallel(const std::shared_ptr module, int num_stages, int num_micro_batches, - const std::vector> &recv_shape, int rank, - const std::shared_ptr &optimizer) - : num_stages_(num_stages), rank_(rank) { + const std::vector> &recv_shape, int pp_rank, + const std::shared_ptr &optimizer, int device_id) + : num_stages_(num_stages), rank_(pp_rank) { modules_[kModuleName] = std::move(module); - BuildPipelineStage(module, optimizer, recv_shape); + BuildPipelineStage(module, optimizer, recv_shape, device_id); SetupSchedule(num_micro_batches); } diff --git a/infini_train/src/nn/parallel/pp/pipeline_schedule.cc b/infini_train/src/nn/parallel/pp/pipeline_schedule.cc index 03caeb61..cb94de44 100644 --- a/infini_train/src/nn/parallel/pp/pipeline_schedule.cc +++ b/infini_train/src/nn/parallel/pp/pipeline_schedule.cc @@ -2,12 +2,10 @@ #include "infini_train/include/nn/parallel/pp/pipeline_schedule.h" #include -#include #include #include "glog/logging.h" -#include "infini_train/include/autograd/grad_mode.h" #include "infini_train/include/device.h" #include "infini_train/include/nn/init.h" #include "infini_train/include/nn/modules/module.h" @@ -90,9 +88,11 @@ float ScheduleGPipe::StepMicroBatches(const std::vector> for (int mb = 0; mb < n; ++mb) { auto out_tensor = outputs[mb][0]; - auto gradient = 
std::make_shared(out_tensor->Dims(), out_tensor->Dtype(), out_tensor->GetDevice()); + auto dummy_gradient + = std::make_shared(out_tensor->Dims(), out_tensor->Dtype(), out_tensor->GetDevice()); - out_tensor->Backward(gradient); + out_tensor->Backward(dummy_gradient); + cudaStreamSynchronize(dynamic_cast(stage_->device())->Stream()); } } else { for (int mb = 0; mb < n; ++mb) { diff --git a/infini_train/src/nn/parallel/pp/pipeline_stage.cc b/infini_train/src/nn/parallel/pp/pipeline_stage.cc index dc9d160c..bdb8cba9 100644 --- a/infini_train/src/nn/parallel/pp/pipeline_stage.cc +++ b/infini_train/src/nn/parallel/pp/pipeline_stage.cc @@ -6,16 +6,18 @@ #include "infini_train/include/device.h" #include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/parallel/process_group.h" namespace infini_train::nn::parallel { -PipelineStage::PipelineStage(const std::shared_ptr &model, int stage_index, int num_stages, - const std::vector> &recv_shape, std::shared_ptr optimizer) +PipelineStage::PipelineStage(const std::shared_ptr &model, int stage_index /* pp_rank */, + int num_stages /* pp_size */, const std::vector> &recv_shape, + std::shared_ptr optimizer, int device_id) : model_(model), stage_index_(stage_index), num_stages_(num_stages), prev_rank_(stage_index > 0 ? stage_index - 1 : -1), next_rank_(stage_index < num_stages - 1 ? stage_index + 1 : -1), recv_shape_(recv_shape), optimizer_(std::move(optimizer)), - device_(DeviceManager::Instance()->GetAllAvailableDevices(DeviceType::kCUDA).at(stage_index)) {} + device_(DeviceManager::Instance()->GetAllAvailableDevices(DeviceType::kCUDA).at(device_id)) {} std::vector> PipelineStage::ForwardOneChunk(const std::vector> &inputs) { diff --git a/infini_train/src/nn/parallel/pp/send_recv.cc b/infini_train/src/nn/parallel/pp/send_recv.cc index 9c8c5917..6a24a0e7 100644 --- a/infini_train/src/nn/parallel/pp/send_recv.cc +++ b/infini_train/src/nn/parallel/pp/send_recv.cc @@ -1,6 +1,5 @@ #include "infini_train/include/nn/parallel/pp/send_recv.h" -#include #include #include @@ -61,8 +60,8 @@ std::vector> ISend::Forward(const std::vectorGetDevice(); - auto pp_group = ProcessGroupFactory::Instance()->Get( - GetPipelineParallelProcessGroupName(input_device_->rank().thread_rank())); + auto pp_group + = ProcessGroupFactory::Instance()->Get(GetPipelineParallelProcessGroupName(input_device_->rank().GlobalRank())); pp_group->NcclSend(input_tensors, peer_rank_); @@ -77,8 +76,8 @@ std::vector> ISend::Backward(const std::vectorGet( - GetPipelineParallelProcessGroupName(input_device_->rank().thread_rank())); + auto pp_group + = ProcessGroupFactory::Instance()->Get(GetPipelineParallelProcessGroupName(input_device_->rank().GlobalRank())); return pp_group->NcclRecv(recv_tensors, peer_rank_); } @@ -86,7 +85,7 @@ std::vector> ISend::Backward(const std::vector> IRecv::Forward(const std::vector> &recv_tensors) { CHECK_NOTNULL(src_device_); auto pp_group - = ProcessGroupFactory::Instance()->Get(GetPipelineParallelProcessGroupName(src_device_->rank().thread_rank())); + = ProcessGroupFactory::Instance()->Get(GetPipelineParallelProcessGroupName(src_device_->rank().GlobalRank())); pp_group->NcclRecv(recv_tensors, peer_rank_); return recv_tensors; @@ -102,7 +101,7 @@ void IRecv::SetupContext(const std::vector> &input_tenso std::vector> IRecv::Backward(const std::vector> &grad_outputs) { auto pp_group - = ProcessGroupFactory::Instance()->Get(GetPipelineParallelProcessGroupName(cur_device_->rank().thread_rank())); + = 
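ISend/IRecv resolve the pipeline process group by global rank and then delegate to NcclSend/NcclRecv. For reference, a bare-NCCL sketch of the point-to-point calls behind one boundary transfer; the comm and stream are assumed to come from that PP group:

#include <cuda_runtime.h>
#include <nccl.h>

// Send `count` floats that already live on the current device to `peer_rank`
// (the peer's rank inside the PP communicator).
void SendActivation(const float *device_buffer, size_t count, int peer_rank, ncclComm_t pp_comm,
                    cudaStream_t stream) {
    ncclGroupStart();
    ncclSend(device_buffer, count, ncclFloat32, peer_rank, pp_comm, stream);
    ncclGroupEnd();
}

// Receive `count` floats from `peer_rank` into a pre-allocated device buffer.
void RecvActivation(float *device_buffer, size_t count, int peer_rank, ncclComm_t pp_comm, cudaStream_t stream) {
    ncclGroupStart();
    ncclRecv(device_buffer, count, ncclFloat32, peer_rank, pp_comm, stream);
    ncclGroupEnd();
}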
ProcessGroupFactory::Instance()->Get(GetPipelineParallelProcessGroupName(cur_device_->rank().GlobalRank())); return pp_group->NcclSend(grad_outputs, peer_rank_); } } // namespace functions diff --git a/infini_train/src/nn/parallel/process_group.cc b/infini_train/src/nn/parallel/process_group.cc index 3271331c..05386b87 100644 --- a/infini_train/src/nn/parallel/process_group.cc +++ b/infini_train/src/nn/parallel/process_group.cc @@ -1,6 +1,14 @@ #include "infini_train/include/nn/parallel/process_group.h" +#include +#include +#include +#include +#include +#include +#include #include +#include #include #ifdef USE_NCCL @@ -18,6 +26,9 @@ namespace infini_train { namespace { +using nn::parallel::function::ReduceOpType; + +#ifdef USE_NCCL const std::unordered_map kNcclDtypeMap = { {DataType::kUINT8, ncclUint8}, {DataType::kINT8, ncclInt8}, {DataType::kUINT32, ncclUint32}, {DataType::kINT32, ncclInt32}, {DataType::kUINT64, ncclUint64}, {DataType::kINT64, ncclInt64}, @@ -25,14 +36,49 @@ const std::unordered_map kNcclDtypeMap = { {DataType::kFLOAT64, ncclFloat64}, }; -using nn::parallel::function::ReduceOpType; - const std::unordered_map kNcclReduceOpMap = { {ReduceOpType::kSum, ncclSum}, {ReduceOpType::kProd, ncclProd}, {ReduceOpType::kMax, ncclMax}, {ReduceOpType::kAvg, ncclAvg}, }; + +inline std::string NcclFileName(const std::string &name, bool tmp = false) { + return std::format("ncclUniqueId_{}.{}", name, tmp ? "tmp" : "bin"); +} + +void WriteNcclUniqueId(const ncclUniqueId &nccl_id, const std::string &pg_name) { + std::string tmp_path = NcclFileName(pg_name, true); + + std::ofstream ofs(tmp_path, std::ios::binary); + ofs.write(reinterpret_cast(&nccl_id), sizeof(nccl_id)); + ofs.close(); + + std::rename(tmp_path.c_str(), NcclFileName(pg_name).c_str()); +} + +void ReadNcclUniqueId(ncclUniqueId &nccl_id, const std::string &pg_name) { + std::string file_path = NcclFileName(pg_name); + + while (std::filesystem::exists(file_path) == false) { + std::this_thread::sleep_for(std::chrono::microseconds(1000)); + } + + std::ifstream ifs(file_path, std::ios::binary); + ifs.read(reinterpret_cast(&nccl_id), sizeof(nccl_id)); + ifs.close(); +} + +void CleanupNcclIdFile(const std::string &pg_name) { + const std::filesystem::path cwd = std::filesystem::current_path(); + std::string file_path = NcclFileName(pg_name); + + if (std::filesystem::exists(file_path)) { + std::filesystem::remove(file_path); + } +} +#endif + } // namespace } // namespace infini_train @@ -40,18 +86,73 @@ const std::unordered_map kNcclReduceOpMap = { namespace infini_train::nn::parallel { #ifdef USE_NCCL -ProcessGroup::ProcessGroup(const std::vector &device_indices) : comm_size_(device_indices.size()) { - comms_.resize(comm_size_); - NCCL_CHECK(ncclCommInitAll(comms_.data(), comm_size_, device_indices.data())); +ProcessGroup::ProcessGroup(const std::string &process_group_name, const std::vector &ranks) + : world_size_(ranks.size()), name_(process_group_name) { + if (global::GetNnodes() == 1 && global::GetNprocPerNode() == 1) { + InitSingleProcess(ranks); + } else { + InitMultiProcess(ranks); + } +} + +ProcessGroup::~ProcessGroup() { + if (is_main_process_) { + CleanupNcclIdFile(name_); + } +} + +void ProcessGroup::InitSingleProcess(const std::vector &ranks) { + comms_.resize(world_size_); + NCCL_CHECK(ncclCommInitAll(comms_.data(), world_size_, ranks.data())); - for (int i = 0; i < comm_size_; ++i) { - auto device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, device_indices[i]); + for (int i = 0; i < ranks.size(); ++i) { + auto 
device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, ranks[i]); devices_.push_back(device); device_comm_map_[device] = comms_[i]; thread_group_rank_map_[device->rank().thread_rank()] = i; } } +void ProcessGroup::InitMultiProcess(const std::vector &ranks) { + int n_threads = global::GetNthreadPerProc(); + int global_proc_rank = global::GetGlobalProcRank(); + int lower_rank = global_proc_rank * n_threads; + int upper_rank = (global_proc_rank + 1) * n_threads; + + ncclUniqueId nccl_id; + + int min_rank = std::ranges::min(ranks); + if (min_rank < upper_rank && min_rank >= lower_rank) { + is_main_process_ = true; + + ncclGetUniqueId(&nccl_id); + WriteNcclUniqueId(nccl_id, name_); + } else { + ReadNcclUniqueId(nccl_id, name_); + } + + std::vector device_indices; + NCCL_CHECK(ncclGroupStart()); + for (int i = 0; i < n_threads; ++i) { + int global_thread_rank = lower_rank + i; + auto it = std::ranges::find(ranks, global_thread_rank); + if (it != ranks.end()) { + cudaSetDevice(i); + + ncclComm_t comm; + int group_rank = std::distance(ranks.begin(), it); + NCCL_CHECK(ncclCommInitRank(&comm, world_size_, nccl_id, group_rank)); + comms_.push_back(comm); + + auto device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, i); + thread_group_rank_map_[device->rank().thread_rank()] = group_rank; + devices_.push_back(device); + device_comm_map_[device] = comm; + } + } + NCCL_CHECK(ncclGroupEnd()); +} + int ProcessGroup::GetGroupRank(int thread_rank) const { return thread_group_rank_map_.at(thread_rank); } void ProcessGroup::AllReduce(const std::shared_ptr &tensor, function::ReduceOpType reduce_op) const { @@ -92,7 +193,9 @@ ProcessGroup::BroadCast(const std::vector> &input_tensor std::vector comms; std::vector devices; - for (size_t i = 0; i < comm_size_; ++i) { + CHECK_EQ(world_size_, comms_.size()); + + for (size_t i = 0; i < world_size_; ++i) { auto device = devices_[i]; for (const auto &input_tensor : input_tensors) { outputs.push_back(std::make_shared(input_tensor->Dims(), input_tensor->Dtype(), device)); @@ -269,9 +372,11 @@ std::vector> ProcessGroup::NcclSend(std::vector(tensor->GetDevice()); - cudaStream_t stream = device_ptr->Stream(); - ncclComm_t comm = device_comm_map_.at(device_ptr); + auto device = tensor->GetDevice(); + device->SetDevice(); + + cudaStream_t stream = dynamic_cast(device)->Stream(); + ncclComm_t comm = device_comm_map_.at(device); CHECK_NE(dest_rank, -1) << "Destination device not found in input tensors's devices"; @@ -292,9 +397,11 @@ std::vector> ProcessGroup::NcclRecv(std::vector(tensor->GetDevice()); - cudaStream_t stream = device_ptr->Stream(); - ncclComm_t comm = device_comm_map_.at(device_ptr); + auto device = tensor->GetDevice(); + device->SetDevice(); + + cudaStream_t stream = dynamic_cast(device)->Stream(); + ncclComm_t comm = device_comm_map_.at(device); CHECK_NE(src_rank, -1) << "Source device not found in input devices"; @@ -308,6 +415,22 @@ std::vector> ProcessGroup::NcclRecv(std::vector results(1, 0); + + NCCL_CHECK(ncclGroupStart()); + for (const auto &device : devices_) { + device->SetDevice(); + auto comm = device_comm_map_.at(device); + auto cuda_dev = dynamic_cast(device); + NCCL_CHECK(ncclAllReduce(&dummy, &dummy, 1, ncclInt, ncclSum, comm, cuda_dev->Stream())); + } + NCCL_CHECK(ncclGroupEnd()); +} #endif ProcessGroupFactory *ProcessGroupFactory::Instance() { @@ -323,30 +446,13 @@ ProcessGroupFactory *ProcessGroupFactory::Instance() { } const ProcessGroup *ProcessGroupFactory::GetOrCreate(const std::string &name, int comm_size) { - 
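InitMultiProcess rendezvouses through an ncclUniqueId that the group's lowest-ranked process writes to a file and every other process polls for before calling ncclCommInitRank. A compact sketch of that flow with one rank per process; the file name and polling interval are illustrative, and the in-tree helper additionally writes to a temporary file and renames it so readers never observe a partially written id.

#include <chrono>
#include <filesystem>
#include <fstream>
#include <string>
#include <thread>

#include <cuda_runtime.h>
#include <nccl.h>

// Sketch: file-based ncclUniqueId exchange followed by ncclCommInitRank.
ncclComm_t BootstrapComm(const std::string &id_file, int group_rank, int group_size, int cuda_device) {
    ncclUniqueId nccl_id;
    if (group_rank == 0) {
        ncclGetUniqueId(&nccl_id);
        std::ofstream ofs(id_file, std::ios::binary);
        ofs.write(reinterpret_cast<const char *>(&nccl_id), sizeof(nccl_id));
    } else {
        while (!std::filesystem::exists(id_file)) {
            std::this_thread::sleep_for(std::chrono::milliseconds(1)); // poll until rank 0 publishes the id
        }
        std::ifstream ifs(id_file, std::ios::binary);
        ifs.read(reinterpret_cast<char *>(&nccl_id), sizeof(nccl_id));
    }

    cudaSetDevice(cuda_device);
    ncclComm_t comm;
    ncclCommInitRank(&comm, group_size, nccl_id, group_rank);
    return comm;
}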
std::vector devices(comm_size); - std::iota(devices.begin(), devices.end(), 0); - const std::vector &device_indices = devices; - - return GetOrCreate(name, device_indices); + std::vector device_indices(comm_size); + std::iota(device_indices.begin(), device_indices.end(), 0); + return GetOrCreate(name, [&]() { return std::make_unique(name, device_indices); }); } const ProcessGroup *ProcessGroupFactory::GetOrCreate(const std::string &name, const std::vector &device_indices) { - { - std::lock_guard lock(mutex_); - auto it = name_to_group_.find(name); - if (it != name_to_group_.end()) { - return it->second.get(); - } - } - - auto new_group = std::make_unique(device_indices); - - { - std::lock_guard lock(mutex_); - - auto [it, inserted] = name_to_group_.emplace(name, std::move(new_group)); - return it->second.get(); - } + return GetOrCreate(name, [&]() { return std::make_unique(name, device_indices); }); } const ProcessGroup *ProcessGroupFactory::Get(const std::string &name) const { diff --git a/infini_train/src/nn/parallel/rank.cc b/infini_train/src/nn/parallel/rank.cc index 73ecd5de..0ec36b8f 100644 --- a/infini_train/src/nn/parallel/rank.cc +++ b/infini_train/src/nn/parallel/rank.cc @@ -10,6 +10,8 @@ int Rank::thread_rank() const { return thread_rank_; } int Rank::process_size() const { return process_size_; } int Rank::thread_size() const { return thread_size_; } +int Rank::GlobalRank() const { return process_rank_ * thread_size_ + thread_rank_; } + bool Rank::IsParallel() const { return thread_size_ * process_size_ > 1; } bool Rank::IsMainRank() const { return thread_rank_ == 0; } } // namespace infini_train::nn::parallel diff --git a/infini_train/src/nn/parallel/tensor_parallel.cc b/infini_train/src/nn/parallel/tensor_parallel.cc index 09f04c88..bc94ac8b 100644 --- a/infini_train/src/nn/parallel/tensor_parallel.cc +++ b/infini_train/src/nn/parallel/tensor_parallel.cc @@ -35,7 +35,7 @@ std::shared_ptr GatherAlongFirstDim(const std::shared_ptr &tenso auto device = tensor->GetDevice(); auto tp_group - = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device->rank().thread_rank())); + = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device->rank().GlobalRank())); std::vector output_shape = tensor->Dims(); output_shape[0] *= world_size; @@ -55,7 +55,7 @@ std::shared_ptr GatherAlongLastDim(const std::shared_ptr &tensor auto device = tensor->GetDevice(); auto tp_group - = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device->rank().thread_rank())); + = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device->rank().GlobalRank())); std::vector output_shape = tensor->Dims(); output_shape[0] *= world_size; @@ -80,7 +80,7 @@ std::shared_ptr SplitAlongLastDim(const std::shared_ptr &tensor) auto device = tensor->GetDevice(); auto tp_group - = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device->rank().thread_rank())); + = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device->rank().GlobalRank())); auto rank = tp_group->GetGroupRank(device->rank().thread_rank()); auto last_dim_size = tensor->Dims().back() / world_size; @@ -98,7 +98,7 @@ std::shared_ptr Reduce(const std::shared_ptr &tensor) { auto device = tensor->GetDevice(); auto tp_group - = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device->rank().thread_rank())); + = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device->rank().GlobalRank())); auto output 
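The tensor-parallel helpers (GatherAlongFirstDim and friends) now also resolve their process group by global rank. For reference, a bare-NCCL sketch of the gather-along-first-dim collective such a helper can wrap; the comm and stream are assumed to come from the TP group:

#include <cuda_runtime.h>
#include <nccl.h>

// Sketch: gather a [local_rows, cols] shard from every TP rank into a
// [world_size * local_rows, cols] buffer with ncclAllGather.
void GatherAlongFirstDimSketch(const float *local_shard, float *gathered, size_t local_rows, size_t cols,
                               ncclComm_t tp_comm, cudaStream_t stream) {
    const size_t send_count = local_rows * cols;
    // ncclAllGather concatenates each rank's send buffer in rank order, which is
    // exactly a gather along the leading dimension for row-major data.
    ncclAllGather(local_shard, gathered, send_count, ncclFloat32, tp_comm, stream);
}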
= std::make_shared(*tensor); @@ -116,7 +116,7 @@ std::shared_ptr ReduceScatterAlongFirstDim(const std::shared_ptr auto device = tensor->GetDevice(); auto tp_group - = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device->rank().thread_rank())); + = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device->rank().GlobalRank())); auto output_shape = tensor->Dims(); CHECK_EQ(output_shape[0] % world_size, 0) << "First dimension of the tensor should be divisible by TP world size"; @@ -435,8 +435,7 @@ VocabParallelCrossEntropy::Forward(const std::vector> &i const ProcessGroup *tp_group = nullptr; int rank = 0; if (tp_size > 1) { - tp_group - = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device->rank().thread_rank())); + tp_group = ProcessGroupFactory::Instance()->Get(GetTensorParallelProcessGroupName(device->rank().GlobalRank())); rank = tp_group->GetGroupRank(device->rank().thread_rank()); } diff --git a/infini_train/src/nn/parallel/utils.cc b/infini_train/src/nn/parallel/utils.cc index 4f661880..45b3118c 100644 --- a/infini_train/src/nn/parallel/utils.cc +++ b/infini_train/src/nn/parallel/utils.cc @@ -4,26 +4,23 @@ namespace infini_train::nn::parallel { -std::string GetDataParallelProcessGroupName(int thread_rank) { - return "DP" + std::to_string(global::GetGroupId(global::DP, thread_rank)); +std::string GetDataParallelProcessGroupName(int global_rank) { + return "DP" + std::to_string(global::GetGroupId(global::DP, global_rank)); } -std::string GetTensorParallelProcessGroupName(int thread_rank) { - return "TP" + std::to_string(global::GetGroupId(global::TP, thread_rank)); +std::string GetTensorParallelProcessGroupName(int global_rank) { + return "TP" + std::to_string(global::GetGroupId(global::TP, global_rank)); } -std::string GetPipelineParallelProcessGroupName(int thread_rank) { - return "PP" + std::to_string(global::GetGroupId(global::PP, thread_rank)); +std::string GetPipelineParallelProcessGroupName(int global_rank) { + return "PP" + std::to_string(global::GetGroupId(global::PP, global_rank)); } -std::vector GetDataParallelGroupRanks(int thread_rank) { return global::GetGroupRanks(global::DP, thread_rank); } +std::vector GetDataParallelGroupRanks(int global_rank) { return global::GetGroupRanks(global::DP, global_rank); } -std::vector GetTensorParallelGroupRanks(int thread_rank) { return global::GetGroupRanks(global::TP, thread_rank); } +std::vector GetTensorParallelGroupRanks(int global_rank) { return global::GetGroupRanks(global::TP, global_rank); } -std::vector GetPipelineParallelGroupRanks(int pp_world_size) { - std::vector ranks; - ranks.reserve(pp_world_size); - for (int i = 0; i < pp_world_size; ++i) { ranks.push_back(i); } - return ranks; +std::vector GetPipelineParallelGroupRanks(int global_rank) { + return global::GetGroupRanks(global::PP, global_rank); } } // namespace infini_train::nn::parallel diff --git a/tools/infini_run/infini_run.cc b/tools/infini_run/infini_run.cc index bcc3f25a..86604f54 100644 --- a/tools/infini_run/infini_run.cc +++ b/tools/infini_run/infini_run.cc @@ -1,8 +1,11 @@ +#include +#include +#include +#include #include #include -#include -#include #include +#include #include "gflags/gflags.h" #include "glog/logging.h" @@ -12,17 +15,15 @@ DEFINE_int32(nproc_per_node, 1, "Number of processes per node"); DEFINE_int32(node_rank, 0, "Rank of this node"); DEFINE_string(rdzv_endpoint, "127.0.0.1:29500", "Rendezvous endpoint (host:port)"); -int main(int argc, char** argv) { +int main(int argc, 
char **argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); google::InitGoogleLogging(argv[0]); CHECK_GE(argc, 2) << "No training prgram specified!"; std::string train_program = argv[1]; - std::vector train_argv; - for (int i = 1; i < argc; ++i) { - train_argv.push_back(argv[i]); - } + std::vector train_argv; + for (int i = 1; i < argc; ++i) { train_argv.push_back(argv[i]); } train_argv.push_back(nullptr); int world_size = FLAGS_nnodes * FLAGS_nproc_per_node; @@ -33,13 +34,17 @@ int main(int argc, char** argv) { pid_t pid = fork(); if (pid == 0) { int global_proc_rank = FLAGS_node_rank * FLAGS_nproc_per_node + local_proc_rank; - setenv("GLOBAL_PROC_RANK", std::to_string(global_proc_rank).c_str(), 1); - setenv("LOCAL_PROC_RANK", std::to_string(local_proc_rank).c_str(), 1); - setenv("PROC_WORLD_SIZE", std::to_string(world_size).c_str(), 1); + setenv("NNODES", std::to_string(FLAGS_nnodes).c_str(), 1); setenv("NPROC_PER_NODE", std::to_string(FLAGS_nproc_per_node).c_str(), 1); + setenv("MASTER_ADDR", master_addr.c_str(), 1); setenv("MASTER_PORT", master_port.c_str(), 1); + setenv("GLOBAL_PROC_RANK", std::to_string(global_proc_rank).c_str(), 1); + setenv("LOCAL_PROC_RANK", std::to_string(local_proc_rank).c_str(), 1); + + setenv("PROC_WORLD_SIZE", std::to_string(world_size).c_str(), 1); + execvp(train_program.c_str(), train_argv.data()); perror("exec failed"); exit(1);
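infini_run now exports the full topology (NNODES, NPROC_PER_NODE, MASTER_ADDR/PORT, the two rank variables, and PROC_WORLD_SIZE) before exec'ing the training binary. A reduced sketch of the per-child setup; waitpid/cleanup handling is omitted and child_argv must already end with a nullptr, as in the code above:

#include <cstdio>
#include <cstdlib>
#include <string>
#include <unistd.h>
#include <vector>

// Sketch: per-child environment setup in a torchrun-style launcher.
void SpawnWorkers(const std::string &program, std::vector<char *> child_argv, int nnodes, int nproc_per_node,
                  int node_rank, const std::string &master_addr, const std::string &master_port) {
    const int world_size = nnodes * nproc_per_node;
    for (int local_proc_rank = 0; local_proc_rank < nproc_per_node; ++local_proc_rank) {
        if (fork() == 0) { // child process
            const int global_proc_rank = node_rank * nproc_per_node + local_proc_rank;

            setenv("NNODES", std::to_string(nnodes).c_str(), 1);
            setenv("NPROC_PER_NODE", std::to_string(nproc_per_node).c_str(), 1);
            setenv("MASTER_ADDR", master_addr.c_str(), 1);
            setenv("MASTER_PORT", master_port.c_str(), 1);
            setenv("GLOBAL_PROC_RANK", std::to_string(global_proc_rank).c_str(), 1);
            setenv("LOCAL_PROC_RANK", std::to_string(local_proc_rank).c_str(), 1);
            setenv("PROC_WORLD_SIZE", std::to_string(world_size).c_str(), 1);

            execvp(program.c_str(), child_argv.data());
            perror("exec failed");
            _exit(1);
        }
    }
}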