
Commit 181a687

feat: support multi-node pp training
1 parent 0b2a039 commit 181a687

File tree

9 files changed: +39 -19 lines changed

example/gpt2/main.cc

Lines changed: 5 additions & 2 deletions
@@ -29,7 +29,6 @@
 
 #include "example/common/tiny_shakespeare_dataset.h"
 #include "example/common/tokenizer.h"
-#include "example/common/utils.h"
 #include "example/gpt2/net.h"
 
 // I/O
@@ -321,7 +320,7 @@ void Train(const nn::parallel::Rank &rank) {
     const double duration_us = std::chrono::duration<double, std::micro>(iter_end - iter_start).count();
     const double tps = FLAGS_total_batch_size / (duration_us / 1e6);
 
-    if (rank.thread_rank() == pp_world_size - 1) {
+    if (rank.GlobalRank() == pp_world_size - 1) {
         LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, "
                                   "DP={}, TP={}, SP={}, PP={})",
                                   step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f,
@@ -340,6 +339,10 @@ void Train(const nn::parallel::Rank &rank) {
     Profiler::Instance().Report("gpt2.report", Profiler::SortBy::DeviceTimePercentage);
     Profiler::Instance().PrintRecords("gpt2.records.log");
 #endif
+
+    if (pp_world_size > 1 && rank.IsMainRank()) {
+        pp_pg->Barrier();
+    }
 }
 
 int main(int argc, char *argv[]) {

example/llama3/main.cc

Lines changed: 7 additions & 4 deletions
@@ -28,7 +28,6 @@
 
 #include "example/common/tiny_shakespeare_dataset.h"
 #include "example/common/tokenizer.h"
-#include "example/common/utils.h"
 #include "example/llama3/net.h"
 
 // I/O
@@ -124,8 +123,8 @@ void Train(const nn::parallel::Rank &rank) {
     }
 
     if (pp_world_size > 1) {
-        pp_pg = ProcessGroupFactory::Instance()->GetOrCreate(
-            GetPipelineParallelProcessGroupName(rank.thread_rank()), GetPipelineParallelGroupRanks(pp_world_size));
+        pp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()),
+                                                             GetPipelineParallelGroupRanks(rank.GlobalRank()));
         pp_rank = pp_pg->GetGroupRank(rank.thread_rank());
 
         nn::parallel::pp_rank = pp_rank;
@@ -299,7 +298,7 @@ void Train(const nn::parallel::Rank &rank) {
     const double duration_us = std::chrono::duration<double, std::micro>(iter_end - iter_start).count();
     const double tps = FLAGS_total_batch_size / (duration_us / 1e6);
 
-    if (rank.thread_rank() == pp_world_size - 1) {
+    if (rank.GlobalRank() == pp_world_size - 1) {
         LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, "
                                   "DP={}, TP={}, SP={}, PP={})",
                                   step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f,
@@ -318,6 +317,10 @@ void Train(const nn::parallel::Rank &rank) {
     Profiler::Instance().Report("llama3.report", Profiler::SortBy::DeviceTimePercentage);
     Profiler::Instance().PrintRecords("llama3.records.log");
 #endif
+
+    if (pp_world_size > 1 && rank.IsMainRank()) {
+        pp_pg->Barrier();
+    }
 }
 
 int main(int argc, char *argv[]) {

infini_train/include/nn/parallel/process_group.h

Lines changed: 2 additions & 0 deletions
@@ -55,6 +55,8 @@ class ProcessGroup {
 
     std::vector<std::shared_ptr<Tensor>> NcclRecv(std::vector<std::shared_ptr<Tensor>> tensors, int src_rank) const;
 
+    void Barrier() const;
+
 private:
     void InitSingleProcess(const std::vector<int> &ranks);

infini_train/include/nn/parallel/utils.h

Lines changed: 1 addition & 1 deletion
@@ -14,5 +14,5 @@ std::vector<int> GetDataParallelGroupRanks(int rank);
 
 std::vector<int> GetTensorParallelGroupRanks(int rank);
 
-std::vector<int> GetPipelineParallelGroupRanks(int pp_world_size);
+std::vector<int> GetPipelineParallelGroupRanks(int rank);
 } // namespace infini_train::nn::parallel

infini_train/src/nn/parallel/global.cc

Lines changed: 1 addition & 2 deletions
@@ -111,8 +111,7 @@ void GlobalEnv::Init(int nthread_per_process, int tensor_parallel_size, bool seq
 
     layout_.sizes[DP] = data_parallel_size_;
     layout_.sizes[TP] = tensor_parallel_size_;
-    // FIXME(zbl): set PP size
-    layout_.sizes[PP] = 1;
+    layout_.sizes[PP] = pipeline_parallel_size_;
     layout_.InitStrides();
 
     initialized_ = true;
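
With the PP axis now sized from pipeline_parallel_size_, layout_ describes a full DP x TP x PP rank grid. The sketch below illustrates how such a row-major layout can translate between a global rank and its per-axis coordinates; the Layout3D type, the axis ordering, and the stride convention are assumptions for illustration, not the project's actual GlobalEnv implementation.

#include <array>

// Hypothetical 3-axis rank layout, analogous in spirit to GlobalEnv::layout_.
struct Layout3D {
    enum Axis { DP = 0, TP = 1, PP = 2 };
    std::array<int, 3> sizes{};   // e.g. {2, 2, 4} -> world size 16
    std::array<int, 3> strides{}; // row-major strides; PP varies fastest in this sketch

    void InitStrides() {
        strides[PP] = 1;
        strides[TP] = sizes[PP];
        strides[DP] = sizes[PP] * sizes[TP];
    }

    // Decompose a global rank into (dp, tp, pp) coordinates.
    std::array<int, 3> Coords(int rank) const {
        return {rank / strides[DP] % sizes[DP], rank / strides[TP] % sizes[TP], rank / strides[PP] % sizes[PP]};
    }

    // Recompose a global rank from coordinates.
    int Rank(const std::array<int, 3> &c) const {
        return c[DP] * strides[DP] + c[TP] * strides[TP] + c[PP] * strides[PP];
    }
};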

infini_train/src/nn/parallel/pp/pipeline_schedule.cc

Lines changed: 4 additions & 4 deletions
@@ -2,12 +2,10 @@
 #include "infini_train/include/nn/parallel/pp/pipeline_schedule.h"
 
 #include <cstddef>
-#include <cstdint>
 #include <vector>
 
 #include "glog/logging.h"
 
-#include "infini_train/include/autograd/grad_mode.h"
 #include "infini_train/include/device.h"
 #include "infini_train/include/nn/init.h"
 #include "infini_train/include/nn/modules/module.h"
@@ -90,9 +88,11 @@ float ScheduleGPipe::StepMicroBatches(const std::vector<std::shared_ptr<Tensor>>
         for (int mb = 0; mb < n; ++mb) {
             auto out_tensor = outputs[mb][0];
 
-            auto gradient = std::make_shared<Tensor>(out_tensor->Dims(), out_tensor->Dtype(), out_tensor->GetDevice());
+            auto dummy_gradient
+                = std::make_shared<Tensor>(out_tensor->Dims(), out_tensor->Dtype(), out_tensor->GetDevice());
 
-            out_tensor->Backward(gradient);
+            out_tensor->Backward(dummy_gradient);
+            cudaStreamSynchronize(dynamic_cast<const CudaDevice *>(stage_->device())->Stream());
         }
     } else {
         for (int mb = 0; mb < n; ++mb) {

infini_train/src/nn/parallel/pp/pipeline_stage.cc

Lines changed: 2 additions & 1 deletion
@@ -15,7 +15,8 @@ PipelineStage::PipelineStage(const std::shared_ptr<Module> &model, int stage_ind
       prev_rank_(stage_index > 0 ? stage_index - 1 : -1),
       next_rank_(stage_index < num_stages - 1 ? stage_index + 1 : -1), recv_shape_(recv_shape),
       optimizer_(std::move(optimizer)),
-      device_(DeviceManager::Instance()->GetAllAvailableDevices(DeviceType::kCUDA).at(stage_index)) {}
+      // FIXME(dcj): use correct device
+      device_(DeviceManager::Instance()->GetAllAvailableDevices(DeviceType::kCUDA).at(stage_index % 8)) {}
 
 std::vector<std::shared_ptr<Tensor>>
 PipelineStage::ForwardOneChunk(const std::vector<std::shared_ptr<Tensor>> &inputs) {
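
The % 8 maps a global stage index onto a local GPU ordinal and implicitly assumes 8 devices per node, which is what the FIXME flags. A minimal sketch of a more general mapping that asks the CUDA runtime for the node's device count instead of hard-coding it; the helper name is illustrative only:

#include <cuda_runtime.h>

// Map a global pipeline-stage index to a local CUDA device ordinal,
// assuming stages are assigned contiguously across homogeneous nodes.
int LocalDeviceForStage(int stage_index) {
    int devices_per_node = 0;
    cudaGetDeviceCount(&devices_per_node); // e.g. 8 on an 8-GPU node
    return stage_index % devices_per_node;
}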

infini_train/src/nn/parallel/process_group.cc

Lines changed: 15 additions & 0 deletions
@@ -411,6 +411,21 @@ std::vector<std::shared_ptr<Tensor>> ProcessGroup::NcclRecv(std::vector<std::sha
     }
     return tensors;
 }
+
+void ProcessGroup::Barrier() const {
+    // NOTE(dcj): use ncclAllreduce to barrier all processes before destroying the communicators
+    // FIXME(dcj): should only call by one rank
+    int dummy = 1;
+    std::vector<int> results(1, 0);
+
+    NCCL_CHECK(ncclGroupStart());
+    for (const auto &device : devices_) {
+        auto comm = device_comm_map_.at(device);
+        auto cuda_dev = dynamic_cast<const CudaDevice *>(device);
+        NCCL_CHECK(ncclAllReduce(&dummy, &dummy, 1, ncclInt, ncclSum, comm, cuda_dev->Stream()));
+    }
+    NCCL_CHECK(ncclGroupEnd());
+}
 #endif
 
 ProcessGroupFactory *ProcessGroupFactory::Instance() {
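
Barrier() leans on ncclAllReduce as a synchronization point: the collective cannot complete until every rank in the communicator has entered it, so it doubles as a barrier before the communicators are torn down. Below is a stripped-down sketch of the same idea on a single, already-initialized communicator; the comm/stream setup and the NcclBarrier name are assumptions, and this sketch reduces a device-side buffer and blocks on the stream so the barrier has completed when the call returns.

#include <cuda_runtime.h>
#include <nccl.h>

// Minimal NCCL-based barrier sketch: every rank all-reduces one int over `comm`,
// so no rank can get past the stream sync until all peers have joined the collective.
// Assumes `comm` and `stream` were initialized elsewhere (e.g. via ncclCommInitRank).
void NcclBarrier(ncclComm_t comm, cudaStream_t stream) {
    int *d_flag = nullptr;
    cudaMalloc(&d_flag, sizeof(int));
    cudaMemsetAsync(d_flag, 0, sizeof(int), stream);
    ncclAllReduce(d_flag, d_flag, 1, ncclInt, ncclSum, comm, stream);
    cudaStreamSynchronize(stream); // the collective is done once this returns
    cudaFree(d_flag);
}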

infini_train/src/nn/parallel/utils.cc

Lines changed: 2 additions & 5 deletions
@@ -20,10 +20,7 @@ std::vector<int> GetDataParallelGroupRanks(int thread_rank) { return global::Get
 
 std::vector<int> GetTensorParallelGroupRanks(int thread_rank) { return global::GetGroupRanks(global::TP, thread_rank); }
 
-std::vector<int> GetPipelineParallelGroupRanks(int pp_world_size) {
-    std::vector<int> ranks;
-    ranks.reserve(pp_world_size);
-    for (int i = 0; i < pp_world_size; ++i) { ranks.push_back(i); }
-    return ranks;
+std::vector<int> GetPipelineParallelGroupRanks(int thread_rank) {
+    return global::GetGroupRanks(global::PP, thread_rank);
 }
 } // namespace infini_train::nn::parallel
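
GetPipelineParallelGroupRanks now returns, for the calling rank, the global ranks that share its other coordinates and differ only along the PP axis, rather than the fixed [0, pp_world_size) list, which lines up with the multi-node support in this commit. A self-contained sketch of that enumeration over a row-major grid like the Layout3D sketch above; the function and parameter names are illustrative, not the library's API.

#include <array>
#include <vector>

// Global ranks that share every coordinate with `rank` except the one on `axis`.
// `sizes`/`strides` describe a row-major DP x TP x PP grid as sketched earlier.
std::vector<int> GroupRanks(const std::array<int, 3> &sizes, const std::array<int, 3> &strides, int axis, int rank) {
    const int coord = rank / strides[axis] % sizes[axis];
    const int base = rank - coord * strides[axis]; // zero out this axis' contribution
    std::vector<int> group;
    group.reserve(sizes[axis]);
    for (int i = 0; i < sizes[axis]; ++i) { group.push_back(base + i * strides[axis]); }
    return group;
}

With DP = TP = 1 this still yields {0, ..., PP-1}, matching the behaviour of the removed hard-coded loop on a single node.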
