 
 #include "example/common/tiny_shakespeare_dataset.h"
 #include "example/common/tokenizer.h"
-#include "example/common/utils.h"
 #include "example/llama3/net.h"
 
 // I/O
@@ -124,8 +123,8 @@ void Train(const nn::parallel::Rank &rank) {
     }
 
     if (pp_world_size > 1) {
-        pp_pg = ProcessGroupFactory::Instance()->GetOrCreate(
-            GetPipelineParallelProcessGroupName(rank.thread_rank()), GetPipelineParallelGroupRanks(pp_world_size));
+        pp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()),
+                                                             GetPipelineParallelGroupRanks(rank.GlobalRank()));
         pp_rank = pp_pg->GetGroupRank(rank.thread_rank());
 
         nn::parallel::pp_rank = pp_rank;
@@ -299,7 +298,7 @@ void Train(const nn::parallel::Rank &rank) {
         const double duration_us = std::chrono::duration<double, std::micro>(iter_end - iter_start).count();
         const double tps = FLAGS_total_batch_size / (duration_us / 1e6);
 
-        if (rank.thread_rank() == pp_world_size - 1) {
+        if (rank.GlobalRank() == pp_world_size - 1) {
             LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, "
                                       "DP={}, TP={}, SP={}, PP={})",
                                       step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f,
@@ -318,6 +317,10 @@ void Train(const nn::parallel::Rank &rank) {
     Profiler::Instance().Report("llama3.report", Profiler::SortBy::DeviceTimePercentage);
     Profiler::Instance().PrintRecords("llama3.records.log");
 #endif
+
+    if (pp_world_size > 1 && rank.IsMainRank()) {
+        pp_pg->Barrier();
+    }
 }
 
 int main(int argc, char *argv[]) {