22 changes: 13 additions & 9 deletions example/gpt2/main.cc
@@ -29,7 +29,6 @@

#include "example/common/tiny_shakespeare_dataset.h"
#include "example/common/tokenizer.h"
#include "example/common/utils.h"
#include "example/gpt2/net.h"

// I/O
@@ -128,14 +127,14 @@ void Train(const nn::parallel::Rank &rank) {
device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank());

if (ddp_world_size > 1) {
ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.thread_rank()),
GetDataParallelGroupRanks(rank.thread_rank()));
ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
GetDataParallelGroupRanks(rank.GlobalRank()));
ddp_rank = ddp_pg->GetGroupRank(rank.thread_rank());
}

if (tp_world_size > 1) {
tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.thread_rank()),
GetTensorParallelGroupRanks(rank.thread_rank()));
tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
GetTensorParallelGroupRanks(rank.GlobalRank()));
tp_rank = tp_pg->GetGroupRank(rank.thread_rank());
// NOTE(zbl): Reserved for VocabParallelEmbedding
nn::parallel::tp_rank = tp_rank;
@@ -146,7 +145,7 @@ void Train(const nn::parallel::Rank &rank) {
GetPipelineParallelProcessGroupName(rank.thread_rank()), GetPipelineParallelGroupRanks(pp_world_size));
pp_rank = pp_pg->GetGroupRank(rank.thread_rank());

nn::parallel::pp_rank = pp_rank;
nn::parallel::pp_rank_tls = pp_rank;
}
} else {
device = FLAGS_device == kDeviceCPU ? DeviceManager::Instance()->GetDefaultDevice()
@@ -232,7 +231,8 @@ void Train(const nn::parallel::Rank &rank) {
auto shapes = std::vector<std::vector<int64_t>>{{FLAGS_batch_size, FLAGS_sequence_length, model_config.n_embd}};

model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, num_micro_batches, shapes,
pp_rank, std::make_shared<optimizers::SGD>(optimizer));
pp_rank, std::make_shared<optimizers::SGD>(optimizer),
rank.thread_rank());
}

LOG(INFO) << "start training";
@@ -321,7 +321,7 @@ void Train(const nn::parallel::Rank &rank) {
const double duration_us = std::chrono::duration<double, std::micro>(iter_end - iter_start).count();
const double tps = FLAGS_total_batch_size / (duration_us / 1e6);

if (rank.thread_rank() == pp_world_size - 1) {
if (rank.GlobalRank() == pp_world_size - 1) {
LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, "
"DP={}, TP={}, SP={}, PP={})",
step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f,
@@ -340,6 +340,10 @@ void Train(const nn::parallel::Rank &rank) {
Profiler::Instance().Report("gpt2.report", Profiler::SortBy::DeviceTimePercentage);
Profiler::Instance().PrintRecords("gpt2.records.log");
#endif

if (pp_world_size > 1 && rank.IsMainRank()) {
pp_pg->Barrier();
}
}

int main(int argc, char *argv[]) {
@@ -355,7 +359,7 @@ int main(int argc, char *argv[]) {
if (FLAGS_nthread_per_process > 1) {
std::vector<std::thread> threads;
for (int idx = 0; idx < FLAGS_nthread_per_process; ++idx) {
nn::parallel::Rank rank(nn::parallel::global::GetLocalProcRank(), idx,
nn::parallel::Rank rank(nn::parallel::global::GetGlobalProcRank(), idx,
nn::parallel::global::GetNprocPerNode(), FLAGS_nthread_per_process);
threads.emplace_back(Train, rank);
}
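Reviewer note on the GetLocalProcRank() -> GetGlobalProcRank() switch: once more than one node is involved, each worker thread's Rank has to be unique across the whole job, not just within its node. A minimal sketch of the arithmetic this presumably relies on (illustrative only; the real logic lives in nn::parallel::Rank, and the member names here are hypothetical):

```cpp
// Hypothetical mirror of nn::parallel::Rank(proc_rank, thread_rank, nproc_per_node, nthread_per_process).
struct RankSketch {
    int proc_rank;            // assumed to be the global process rank, i.e. GetGlobalProcRank()
    int thread_rank;          // worker thread index within this process
    int nproc_per_node;       // processes per node
    int nthread_per_process;  // worker threads per process

    // Assumed mapping: the threads of process p occupy the contiguous
    // block [p * nthread_per_process, (p + 1) * nthread_per_process).
    int GlobalRank() const { return proc_rank * nthread_per_process + thread_rank; }
    bool IsMainRank() const { return GlobalRank() == 0; }
};
// With 2 processes x 2 threads: proc 0 -> global ranks {0, 1}, proc 1 -> {2, 3}.
// Seeding proc_rank with a *local* process rank would repeat {0, 1} on every node,
// which is why the process-group lookups above are now keyed by rank.GlobalRank().
```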
6 changes: 3 additions & 3 deletions example/gpt2/net.cc
@@ -179,7 +179,7 @@ Block::Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) {
GPT2::GPT2(const GPT2Config &config) : config_(config) {
int pp_size = nn::parallel::global::GetPipelineParallelSize();
auto [is_first_stage, is_last_stage, start_layer, end_layer]
= nn::parallel::PipelineParallel::GetStageInfo(config_.n_layer, pp_size);
= nn::parallel::PipelineParallel::GetStageInfo(config_.n_layer, pp_size, nn::parallel::pp_rank_tls);

auto tp_world_size = nn::parallel::global::GetTensorParallelSize();

@@ -230,7 +230,7 @@ GPT2::GPT2(const GPT2Config &config) : config_(config) {

std::vector<std::shared_ptr<infini_train::Tensor>>
GPT2::Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) {
int pp_rank = nn::parallel::pp_rank;
int pp_rank = nn::parallel::pp_rank_tls;
int pp_size = nn::parallel::global::GetPipelineParallelSize();
bool is_first_stage = (pp_rank == 0);
bool is_last_stage = (pp_rank == pp_size - 1);
@@ -353,7 +353,7 @@ std::shared_ptr<GPT2> GPT2::FromLLMC(const std::string &filepath) {

int pp_size = nn::parallel::global::GetPipelineParallelSize();
auto [is_first_stage, is_last_stage, start_layer, end_layer]
= nn::parallel::PipelineParallel::GetStageInfo(n_layer, pp_size);
= nn::parallel::PipelineParallel::GetStageInfo(n_layer, pp_size, nn::parallel::pp_rank_tls);

auto tp_rank = nn::parallel::tp_rank;
// calculate xx_size_per_partition
28 changes: 16 additions & 12 deletions example/llama3/main.cc
@@ -28,7 +28,6 @@

#include "example/common/tiny_shakespeare_dataset.h"
#include "example/common/tokenizer.h"
#include "example/common/utils.h"
#include "example/llama3/net.h"

// I/O
@@ -110,25 +109,25 @@ void Train(const nn::parallel::Rank &rank) {
device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank());

if (ddp_world_size > 1) {
ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.thread_rank()),
GetDataParallelGroupRanks(rank.thread_rank()));
ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.GlobalRank()),
GetDataParallelGroupRanks(rank.GlobalRank()));
ddp_rank = ddp_pg->GetGroupRank(rank.thread_rank());
}

if (tp_world_size > 1) {
tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.thread_rank()),
GetTensorParallelGroupRanks(rank.thread_rank()));
tp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.GlobalRank()),
GetTensorParallelGroupRanks(rank.GlobalRank()));
tp_rank = tp_pg->GetGroupRank(rank.thread_rank());
// NOTE(zbl): Reserved for VocabParallelEmbedding
nn::parallel::tp_rank = tp_rank;
}

if (pp_world_size > 1) {
pp_pg = ProcessGroupFactory::Instance()->GetOrCreate(
GetPipelineParallelProcessGroupName(rank.thread_rank()), GetPipelineParallelGroupRanks(pp_world_size));
pp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()),
GetPipelineParallelGroupRanks(rank.GlobalRank()));
pp_rank = pp_pg->GetGroupRank(rank.thread_rank());

nn::parallel::pp_rank = pp_rank;
nn::parallel::pp_rank_tls = pp_rank;
}
} else {
device = FLAGS_device == kDeviceCPU ? DeviceManager::Instance()->GetDefaultDevice()
@@ -210,8 +209,9 @@ void Train(const nn::parallel::Rank &rank) {
if (pp_world_size > 1) {
auto shapes = std::vector<std::vector<int64_t>>{{FLAGS_batch_size, FLAGS_sequence_length, model_config.n_embd}};

model = std::make_shared<nn::parallel::PipelineParallel>(
model, pp_world_size, num_micro_batches, shapes, pp_rank, std::make_shared<optimizers::Adam>(optimizer));
model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, num_micro_batches, shapes,
pp_rank, std::make_shared<optimizers::Adam>(optimizer),
rank.thread_rank());
}

for (int step = 0; step < FLAGS_num_iteration + 1; ++step) {
@@ -299,7 +299,7 @@ void Train(const nn::parallel::Rank &rank) {
const double duration_us = std::chrono::duration<double, std::micro>(iter_end - iter_start).count();
const double tps = FLAGS_total_batch_size / (duration_us / 1e6);

if (rank.thread_rank() == pp_world_size - 1) {
if (rank.GlobalRank() == pp_world_size - 1) {
LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, "
"DP={}, TP={}, SP={}, PP={})",
step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f,
@@ -318,6 +318,10 @@ void Train(const nn::parallel::Rank &rank) {
Profiler::Instance().Report("llama3.report", Profiler::SortBy::DeviceTimePercentage);
Profiler::Instance().PrintRecords("llama3.records.log");
#endif

if (pp_world_size > 1 && rank.IsMainRank()) {
pp_pg->Barrier();
}
}

int main(int argc, char *argv[]) {
@@ -333,7 +337,7 @@ int main(int argc, char *argv[]) {
if (FLAGS_nthread_per_process > 1) {
std::vector<std::thread> threads;
for (int idx = 0; idx < FLAGS_nthread_per_process; ++idx) {
nn::parallel::Rank rank(nn::parallel::global::GetLocalProcRank(), idx,
nn::parallel::Rank rank(nn::parallel::global::GetGlobalProcRank(), idx,
nn::parallel::global::GetNprocPerNode(), FLAGS_nthread_per_process);
threads.emplace_back(Train, rank);
}
6 changes: 3 additions & 3 deletions example/llama3/net.cc
@@ -328,7 +328,7 @@ std::vector<std::shared_ptr<Tensor>> Block::Forward(const std::vector<std::share
LLaMA3::LLaMA3(const LLaMA3Config &config) : config_(config) {
int pp_size = nn::parallel::global::GetPipelineParallelSize();
auto [is_first_stage, is_last_stage, start_layer, end_layer]
= nn::parallel::PipelineParallel::GetStageInfo(config_.n_layer, pp_size);
= nn::parallel::PipelineParallel::GetStageInfo(config_.n_layer, pp_size, nn::parallel::pp_rank_tls);

std::unordered_map<std::string, std::shared_ptr<nn::Module>> transformer;
if (is_first_stage) {
@@ -356,7 +356,7 @@ LLaMA3::LLaMA3(const LLaMA3Config &config) : config_(config) {
}

std::vector<std::shared_ptr<Tensor>> LLaMA3::Forward(const std::vector<std::shared_ptr<Tensor>> &x) {
int pp_rank = nn::parallel::pp_rank;
int pp_rank = nn::parallel::pp_rank_tls;
int pp_size = nn::parallel::global::GetPipelineParallelSize();
bool is_first_stage = (pp_rank == 0);
bool is_last_stage = (pp_rank == pp_size - 1);
@@ -467,7 +467,7 @@ std::shared_ptr<LLaMA3> LLaMA3::FromLLMC(const std::string &filepath) {
.max_gen_batch_size = max_gen_bs});
int pp_size = nn::parallel::global::GetPipelineParallelSize();
auto [is_first_stage, is_last_stage, start_layer, end_layer]
= nn::parallel::PipelineParallel::GetStageInfo(n_layer, pp_size);
= nn::parallel::PipelineParallel::GetStageInfo(n_layer, pp_size, nn::parallel::pp_rank_tls);

const int tp_size = nn::parallel::global::GetTensorParallelSize();
const int tp_rank = nn::parallel::tp_rank;
65 changes: 62 additions & 3 deletions infini_train/include/nn/parallel/global.h
@@ -29,14 +29,16 @@ class GlobalEnv {
void Init(int threads_per_process, int tensor_parallel_size, bool sequence_parallel_enabled,
int pipeline_parallel_size);

int nnodes() const;

int nproc_per_node() const;

int world_size() const;

int global_proc_rank() const;

int local_proc_rank() const;

int nproc_per_node() const;

int nthread_per_process() const;

int tensor_parallel_size() const;
@@ -57,9 +59,11 @@ class GlobalEnv {
GlobalEnv &operator=(const GlobalEnv &) = delete;

private:
int world_size_ = 1;
int nnodes_ = 1;
int nproc_per_node_ = 1;
int nthread_per_process_ = 1;
int world_size_ = 1;

int global_proc_rank_ = 0;
int local_proc_rank_ = 0;

@@ -82,6 +86,7 @@ inline void InitAllEnv(int nthread_per_process, int tensor_parallel_size, bool s
pipeline_parallel_size);
}

inline int GetNnodes() { return GlobalEnv::Instance().nnodes(); }
inline int GetWorldSize() { return GlobalEnv::Instance().world_size(); }
inline int GetNprocPerNode() { return GlobalEnv::Instance().nproc_per_node(); }
inline int GetNthreadPerProc() { return GlobalEnv::Instance().nthread_per_process(); }
@@ -93,28 +98,82 @@ inline bool GetSequenceParallelEnabled() { return GlobalEnv::Instance().sequence
inline int GetDataParallelSize() { return GlobalEnv::Instance().data_parallel_size(); }
inline int GetPipelineParallelSize() { return GlobalEnv::Instance().pipeline_parallel_size(); }

// =========================
// Layout Helper Functions
// =========================

/**
* @brief Get the global rank corresponding to the given (dp, tp, pp) coordinate.
*/
inline int GetRankOf(int dp, int tp, int pp) { return GlobalEnv::Instance().layout().RankOf(dp, tp, pp); }
/**
* @brief Get the (dp, tp, pp) coordinate corresponding to the given global rank.
*/
inline void GetCoordOf(int rank, int &dp, int &tp, int &pp) {
return GlobalEnv::Instance().layout().CoordOf(rank, dp, tp, pp);
}

/**
* @brief Get the group ID that the (dp, tp, pp) coordinate belongs to along a given parallel axis.
*/
inline int GetGroupId(Axis target, int dp, int tp, int pp) {
return GlobalEnv::Instance().layout().GroupId(target, dp, tp, pp);
}
/**
* @brief Get the group ID that a given rank belongs to along a specific parallel axis.
*/
inline int GetGroupId(Axis target, int rank) {
int dp, tp, pp;
GetCoordOf(rank, dp, tp, pp);
return GlobalEnv::Instance().layout().GroupId(target, dp, tp, pp);
}

/**
* @brief Get all ranks that belong to the same group as the given (dp, tp, pp) coordinate
* along a specified parallel axis (e.g., all ranks in the same TP group).
*/
inline std::vector<int> GetGroupRanks(Axis target, int dp, int tp, int pp) {
return GlobalEnv::Instance().layout().GroupRanks(target, dp, tp, pp);
}

/**
* @brief Get all ranks that belong to the same group as the given rank
* along a specified parallel axis (e.g., all ranks in the same DP group).
*/
inline std::vector<int> GetGroupRanks(Axis target, int rank) {
int dp, tp, pp;
GetCoordOf(rank, dp, tp, pp);
return GlobalEnv::Instance().layout().GroupRanks(target, dp, tp, pp);
}

/**
* @brief Generate a human-readable overview of all parallel communication groups.
*
* The output is intended for debugging, logging, and runtime verification of
* distributed parallelism configuration.
*
* @param L The Layout describing DP / TP / PP sizes and axis ordering.
* @param skip_trivial_axes
* If true, axes whose size <= 1 (i.e., parallel strategies that are not enabled)
* will be marked as "unenabled" and their detailed group listings will be skipped.
*
* @return A formatted string containing the full overview of process groups.
*
* Example:
* === Parallel Communication Groups ===
* world_size = 8, config: {DP=2, TP=4, PP=1}, order: {DP -> TP -> PP}
* [DP] size=2, num_groups=4
* - DP 0 (dp=-, tp=0, pp=0): [0, 4]
* - DP 1 (dp=-, tp=1, pp=0): [1, 5]
* - DP 2 (dp=-, tp=2, pp=0): [2, 6]
* - DP 3 (dp=-, tp=3, pp=0): [3, 7]
*
* [TP] size=4, num_groups=2
* - TP 0 (dp=0, tp=-, pp=0): [0, 1, 2, 3]
* - TP 1 (dp=1, tp=-, pp=0): [4, 5, 6, 7]
*
* [PP] size=1, unenabled
*/
std::string ProcessGroupOverview(const Layout &L = GlobalEnv::Instance().layout(), bool skip_trivial_axes = true);

} // namespace infini_train::nn::parallel::global
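For reference, a small usage sketch of the layout helpers declared above, using the {DP=2, TP=4, PP=1} configuration from the ProcessGroupOverview example. The Axis enumerator spellings below are assumptions (the header only shows the Axis type), so adjust them to whatever Layout actually declares:

```cpp
#include <vector>

using namespace infini_train::nn::parallel::global;

void InspectRank(int rank /* e.g. 5 */) {
    int dp = 0, tp = 0, pp = 0;
    GetCoordOf(rank, dp, tp, pp);      // expected (dp=1, tp=1, pp=0) under the DP -> TP -> PP order
    int back = GetRankOf(dp, tp, pp);  // should round-trip to the original rank

    // Peers along each axis, matching the example overview:
    std::vector<int> tp_peers = GetGroupRanks(Axis::TP, rank);  // e.g. [4, 5, 6, 7] (Axis::TP spelling is assumed)
    std::vector<int> dp_peers = GetGroupRanks(Axis::DP, rank);  // e.g. [1, 5]       (Axis::DP spelling is assumed)

    (void)back;
    (void)tp_peers;
    (void)dp_peers;
}
```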
8 changes: 4 additions & 4 deletions infini_train/include/nn/parallel/pp/pipeline_parallel.h
@@ -16,18 +16,18 @@ namespace infini_train::nn::parallel {
class PipelineStage;
class PipelineSchedule;

extern thread_local int pp_rank;
extern thread_local int pp_rank_tls;

class PipelineParallel : public Module {
public:
PipelineParallel(const std::shared_ptr<nn::Module> module, int num_stages, int num_micro_batches,
const std::vector<std::vector<int64_t>> &recv_shape, int rank,
const std::shared_ptr<Optimizer> &optimizer);
const std::shared_ptr<Optimizer> &optimizer, int device_id);

float TrainStep(const std::vector<std::shared_ptr<Tensor>> &input,
const std::vector<std::shared_ptr<Tensor>> &target, const std::shared_ptr<nn::Module> &loss_fn);

static std::tuple<bool, bool, int, int> GetStageInfo(int total_layers, int pp_size);
static std::tuple<bool, bool, int, int> GetStageInfo(int total_layers, int pp_size, int pp_rank);

private:
int num_stages_ = -1;
@@ -36,7 +36,7 @@ class PipelineParallel : public Module {
std::shared_ptr<PipelineSchedule> schedule_ = nullptr;

void BuildPipelineStage(const std::shared_ptr<nn::Module> &model, const std::shared_ptr<Optimizer> &optimizer,
const std::vector<std::vector<int64_t>> &recv_shape);
const std::vector<std::vector<int64_t>> &recv_shape, int device_id);

void SetupSchedule(int num_micro_batches);
};
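As a reading aid for the new three-argument contract, here is one plausible shape of GetStageInfo(total_layers, pp_size, pp_rank), assuming an even layer split with the remainder assigned to the last stages. The actual policy is defined in the PipelineParallel implementation and may differ:

```cpp
#include <algorithm>
#include <tuple>

// Returns {is_first_stage, is_last_stage, start_layer, end_layer} for one pipeline rank (sketch only).
std::tuple<bool, bool, int, int> GetStageInfoSketch(int total_layers, int pp_size, int pp_rank) {
    const int per_stage = total_layers / pp_size;
    const int remainder = total_layers % pp_size;
    // Stages pp_size - remainder .. pp_size - 1 each take one extra layer.
    const int extra_before = std::max(0, pp_rank - (pp_size - remainder));
    const int start_layer = pp_rank * per_stage + extra_before;
    const int end_layer = start_layer + per_stage + (pp_rank >= pp_size - remainder ? 1 : 0);
    return {pp_rank == 0, pp_rank == pp_size - 1, start_layer, end_layer};
}
// e.g. total_layers=12, pp_size=4: every stage gets 3 layers; stage 0 builds [0, 3), stage 3 builds [9, 12).
// Passing pp_rank explicitly (instead of reading a global) is what lets several pipeline
// stages coexist as threads in one process, each with its own pp_rank_tls.
```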
3 changes: 2 additions & 1 deletion infini_train/include/nn/parallel/pp/pipeline_stage.h
@@ -17,7 +17,8 @@ namespace infini_train::nn::parallel {
class PipelineStage {
public:
PipelineStage(const std::shared_ptr<nn::Module> &model, int stage_index, int num_stages,
const std::vector<std::vector<int64_t>> &recv_shape, std::shared_ptr<Optimizer> optimizer);
const std::vector<std::vector<int64_t>> &recv_shape, std::shared_ptr<Optimizer> optimizer,
int device_id);

std::vector<std::shared_ptr<Tensor>> ForwardOneChunk(const std::vector<std::shared_ptr<Tensor>> &inputs);
