@@ -23,6 +23,7 @@ limitations under the License.
2323#include  < ostream> 
2424#include  < string> 
2525#include  < vector> 
26+ #include  < queue> 
2627
2728#include  " absl/cleanup/cleanup.h" 
2829#include  " absl/container/flat_hash_set.h" 
@@ -66,6 +67,7 @@ limitations under the License.
6667#include  " xprof/utils/kernel_stats_utils.h" 
6768#include  " xprof/utils/op_utils.h" 
6869#include  " xprof/utils/xprof_gpu_cost_analysis_types.h" 
70+ #include  " absl/container/flat_hash_map.h" 
6971
7072namespace  tensorflow  {
7173namespace  profiler  {
@@ -77,6 +79,7 @@ using ::tsl::profiler::kGpuPlanePrefix;
7779using  ::tsl::profiler::kTpuPlanePrefix ;
7880using  tsl::profiler::Timespan;
7981using  ::tsl::profiler::XPlaneBuilder;
82+ using  tsl::profiler::kXlaOpLineName ;
8083
8184std::string Hostname (const  XSpace& space) {
8285  if  (space.hostnames ().empty ()) return  " localhost" 
@@ -380,6 +383,113 @@ OpStats ConvertXSpaceToOpStats(const XSpace& space,
380383                     .has_value ()) {
381384              op_metrics_db =
382385                  ConvertTpuDeviceTraceXPlaneToOpMetricsDb (*device_plane);
386+               XPlaneVisitor visitorSecond =
387+                   tsl::profiler::CreateTfXPlaneVisitor (device_plane);
388+               std::queue<XEventVisitor> custom_call_blocks;
389+                   visitorSecond.ForEachLine ([&](const  XLineVisitor& line) {
390+                 if  (line.Name () == " XLA TraceMe" 
391+                   line.ForEachEvent ([&](const  XEventVisitor& event) {
392+                     if  (absl::StartsWith (event.Name (), " __block_" 
393+                       custom_call_blocks.push (event);
394+                     }
395+                   });
396+                 }
397+               });
398+               absl::flat_hash_map<std::string,
399+                 absl::flat_hash_map<std::string, uint>>
400+                   custom_call_to_block_count;
401+ 
402+               XPlaneVisitor xlaEvents =
403+                   tsl::profiler::CreateTfXPlaneVisitor (device_plane);
404+               xlaEvents.ForEachLine ([&](const  XLineVisitor& line) {
405+                 if  (line.Name () == kXlaOpLineName ) {
406+                   line.ForEachEvent ([&](const  XEventVisitor& event) {
407+                     tsl::profiler::Timespan custom_call_timespan =
408+                     GetDeviceEventTimespan (event);
409+                     bool  custom_call = false ;
410+                     event.Metadata ().ForEachStat ([&]
411+                       (const  XStatVisitor& stat) {
412+                       if  (stat.Type ().has_value ()) {
413+                         switch  (static_cast <StatType>(*stat.Type ())) {
414+                           case  StatType::kHloCategory :
415+                             custom_call =
416+                             (stat.StrOrRefValue () == " custom-call" 
417+                             break ;
418+                           default :
419+                             break ;
420+                         }
421+                       }
422+                     });
423+                     if  (custom_call){
424+                       while  (!custom_call_blocks.empty ()){
425+                         tsl::profiler::Timespan ccall_blck_timespan =
426+                           GetDeviceEventTimespan (custom_call_blocks.front ());
427+                         if  ((custom_call_timespan.begin_ps () <=
428+                           ccall_blck_timespan.begin_ps ()) &&
429+                           (ccall_blck_timespan.end_ps () <=
430+                             custom_call_timespan.end_ps ())
431+                           ){
432+                             custom_call_to_block_count[event.DisplayName ()]
433+                               [std::string (custom_call_blocks.front ().
434+                                 Name ())] += 1 ;
435+                             custom_call_blocks.pop ();
436+                         }else {
437+                           break ;
438+                         }
439+                       }
440+                     }
441+                   });
442+                 }
443+               });
444+               for  (OpMetrics& op_metrics :
445+                     *op_metrics_db.mutable_metrics_db ()) {
446+                   const  HloInstructionWrapper* instr_wrapper =
447+                   GetHloInstruction (hlo_module_map,
448+                       op_metrics.hlo_module_id (), op_metrics.name ());
449+                 if  (instr_wrapper != nullptr ) {
450+                   if  (instr_wrapper->Category () == " custom-call" 
451+                     uint64 total_flops = 0 ;
452+                     uint64 total_bytes_accessed = 0 ;
453+                     bool  has_block_costs =
454+                       custom_call_to_block_count.contains (op_metrics.name ());
455+                     if  (has_block_costs){
456+                       for  (auto &[block_name, occurrence] :
457+                         custom_call_to_block_count[op_metrics.name ()]){
458+                         auto  block_cost_pair = instr_wrapper->
459+                           GetCustomCallBlockCosts (block_name);
460+                         if   (block_cost_pair.has_value ()){
461+                           OpMetrics* child_metric =
462+                             op_metrics.mutable_children ()->add_metrics_db ();
463+                           child_metric->set_name (block_name);
464+                           child_metric->set_occurrences (occurrence);
465+                           child_metric->set_flops (
466+                             block_cost_pair.value ().first );
467+                           child_metric->set_model_flops (
468+                             block_cost_pair.value ().first );
469+                           child_metric->set_bytes_accessed (
470+                             block_cost_pair.value ().second );
471+                           total_flops +=
472+                             (occurrence*block_cost_pair.value ().first );
473+                           total_bytes_accessed +=
474+                             (occurrence*block_cost_pair.value ().second );
475+                         }else {
476+                           LOG (WARNING) << " No Costs Found for : " 
477+                         }
478+                       }
479+                       if   (instr_wrapper->FusedChildren ().empty ()){
480+                         LOG (INFO) << " Custom - Call Name: " 
481+                           << op_metrics.name () << "  Total Flops: " 
482+                           << total_flops
483+                           << "  Total Bytes Accessed: " 
484+                         total_bytes_accessed;
485+                         op_metrics.set_flops (total_flops);
486+                         op_metrics.set_bytes_accessed (total_bytes_accessed);
487+                         op_metrics.set_model_flops (total_flops);
488+                       }
489+                     }
490+                   }
491+                 }
492+               }
383493              UpdateOpMetricsDbFromHloModuleMap (op_metrics_db, hlo_module_map);
384494            }
385495          }
0 commit comments