Skip to content

Commit 8f9ae95

Browse files
Profiler Team authored and copybara-github committed
Support For Roofline Analysis Of Pallas Kernels
PiperOrigin-RevId: 814542103
1 parent 4d291b3 commit 8f9ae95

File tree

5 files changed

+481
-0
lines changed

5 files changed

+481
-0
lines changed

xprof/convert/BUILD

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -884,6 +884,7 @@ cc_library(
884884
":xplane_to_tf_functions",
885885
":xprof_thread_pool_executor",
886886
"@com_google_absl//absl/cleanup",
887+
"@com_google_absl//absl/container:flat_hash_map",
887888
"@com_google_absl//absl/container:flat_hash_set",
888889
"@com_google_absl//absl/log",
889890
"@com_google_absl//absl/log:check",

xprof/convert/xplane_to_op_stats.cc

Lines changed: 110 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@ limitations under the License.
2323
#include <ostream>
2424
#include <string>
2525
#include <vector>
26+
#include <queue>
2627

2728
#include "absl/cleanup/cleanup.h"
2829
#include "absl/container/flat_hash_set.h"
@@ -66,6 +67,7 @@ limitations under the License.
6667
#include "xprof/utils/kernel_stats_utils.h"
6768
#include "xprof/utils/op_utils.h"
6869
#include "xprof/utils/xprof_gpu_cost_analysis_types.h"
70+
#include "absl/container/flat_hash_map.h"
6971

7072
namespace tensorflow {
7173
namespace profiler {
@@ -77,6 +79,7 @@ using ::tsl::profiler::kGpuPlanePrefix;
7779
using ::tsl::profiler::kTpuPlanePrefix;
7880
using tsl::profiler::Timespan;
7981
using ::tsl::profiler::XPlaneBuilder;
82+
using tsl::profiler::kXlaOpLineName;
8083

8184
std::string Hostname(const XSpace& space) {
8285
if (space.hostnames().empty()) return "localhost";
@@ -380,6 +383,113 @@ OpStats ConvertXSpaceToOpStats(const XSpace& space,
380383
.has_value()) {
381384
op_metrics_db =
382385
ConvertTpuDeviceTraceXPlaneToOpMetricsDb(*device_plane);
386+
XPlaneVisitor visitorSecond =
387+
tsl::profiler::CreateTfXPlaneVisitor(device_plane);
388+
std::queue<XEventVisitor> custom_call_blocks;
389+
visitorSecond.ForEachLine([&](const XLineVisitor& line) {
390+
if (line.Name() == "XLA TraceMe") {
391+
line.ForEachEvent([&](const XEventVisitor& event) {
392+
if (absl::StartsWith(event.Name(), "__block_")){
393+
custom_call_blocks.push(event);
394+
}
395+
});
396+
}
397+
});
398+
absl::flat_hash_map<std::string,
399+
absl::flat_hash_map<std::string, uint>>
400+
custom_call_to_block_count;
401+
402+
XPlaneVisitor xlaEvents =
403+
tsl::profiler::CreateTfXPlaneVisitor(device_plane);
404+
xlaEvents.ForEachLine([&](const XLineVisitor& line) {
405+
if (line.Name() == kXlaOpLineName) {
406+
line.ForEachEvent([&](const XEventVisitor& event) {
407+
tsl::profiler::Timespan custom_call_timespan =
408+
GetDeviceEventTimespan(event);
409+
bool custom_call = false;
410+
event.Metadata().ForEachStat([&]
411+
(const XStatVisitor& stat) {
412+
if (stat.Type().has_value()) {
413+
switch (static_cast<StatType>(*stat.Type())) {
414+
case StatType::kHloCategory:
415+
custom_call =
416+
(stat.StrOrRefValue() == "custom-call");
417+
break;
418+
default:
419+
break;
420+
}
421+
}
422+
});
423+
if (custom_call){
424+
while (!custom_call_blocks.empty()){
425+
tsl::profiler::Timespan ccall_blck_timespan =
426+
GetDeviceEventTimespan(custom_call_blocks.front());
427+
if ((custom_call_timespan.begin_ps() <=
428+
ccall_blck_timespan.begin_ps()) &&
429+
(ccall_blck_timespan.end_ps() <=
430+
custom_call_timespan.end_ps())
431+
){
432+
custom_call_to_block_count[event.DisplayName()]
433+
[std::string(custom_call_blocks.front().
434+
Name())] += 1;
435+
custom_call_blocks.pop();
436+
}else{
437+
break;
438+
}
439+
}
440+
}
441+
});
442+
}
443+
});
444+
for (OpMetrics& op_metrics :
445+
*op_metrics_db.mutable_metrics_db()) {
446+
const HloInstructionWrapper* instr_wrapper =
447+
GetHloInstruction(hlo_module_map,
448+
op_metrics.hlo_module_id(), op_metrics.name());
449+
if (instr_wrapper != nullptr) {
450+
if (instr_wrapper->Category() == "custom-call"){
451+
uint64 total_flops = 0;
452+
uint64 total_bytes_accessed = 0;
453+
bool has_block_costs =
454+
custom_call_to_block_count.contains(op_metrics.name());
455+
if (has_block_costs){
456+
for (auto&[block_name, occurrence] :
457+
custom_call_to_block_count[op_metrics.name()]){
458+
auto block_cost_pair = instr_wrapper->
459+
GetCustomCallBlockCosts(block_name);
460+
if (block_cost_pair.has_value()){
461+
OpMetrics* child_metric =
462+
op_metrics.mutable_children()->add_metrics_db();
463+
child_metric->set_name(block_name);
464+
child_metric->set_occurrences(occurrence);
465+
child_metric->set_flops(
466+
block_cost_pair.value().first);
467+
child_metric->set_model_flops(
468+
block_cost_pair.value().first);
469+
child_metric->set_bytes_accessed(
470+
block_cost_pair.value().second);
471+
total_flops +=
472+
(occurrence*block_cost_pair.value().first);
473+
total_bytes_accessed +=
474+
(occurrence*block_cost_pair.value().second);
475+
}else{
476+
LOG(WARNING) << "No Costs Found for : " << block_name;
477+
}
478+
}
479+
if (instr_wrapper->FusedChildren().empty()){
480+
LOG(INFO) << "Custom - Call Name: "
481+
<< op_metrics.name() << " Total Flops: "
482+
<< total_flops
483+
<< " Total Bytes Accessed: " <<
484+
total_bytes_accessed;
485+
op_metrics.set_flops(total_flops);
486+
op_metrics.set_bytes_accessed(total_bytes_accessed);
487+
op_metrics.set_model_flops(total_flops);
488+
}
489+
}
490+
}
491+
}
492+
}
383493
UpdateOpMetricsDbFromHloModuleMap(op_metrics_db, hlo_module_map);
384494
}
385495
}

xprof/utils/BUILD

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -398,11 +398,27 @@ cc_library(
398398
srcs = ["hlo_module_map.cc"],
399399
hdrs = ["hlo_module_map.h"],
400400
deps = [
401+
":backend_configs_proto_cc",
401402
":hlo_cost_analysis_wrapper",
402403
":hlo_module_utils",
403404
":hlo_proto_map",
404405
":hlo_proto_to_module",
405406
":performance_info_wrapper",
407+
"//third_party/llvm/llvm-project/mlir:ArithDialect",
408+
"//third_party/llvm/llvm-project/mlir:DataLayoutInterfaces",
409+
"//third_party/llvm/llvm-project/mlir:FuncDialect",
410+
"//third_party/llvm/llvm-project/mlir:IR",
411+
"//third_party/llvm/llvm-project/mlir:LLVMDialect",
412+
"//third_party/llvm/llvm-project/mlir:MathDialect",
413+
"//third_party/llvm/llvm-project/mlir:MemRefDialect",
414+
"//third_party/llvm/llvm-project/mlir:Pass",
415+
"//third_party/llvm/llvm-project/mlir:SCFDialect",
416+
"//third_party/llvm/llvm-project/mlir:Support",
417+
"//third_party/llvm/llvm-project/mlir:VectorDialect",
418+
"//third_party/protobuf/json",
419+
"//third_party/protobuf/util:json_util",
420+
"//third_party/py/jax/jaxlib/mosaic:tpu_dialect",
421+
"@com_google_absl//absl/base:no_destructor",
406422
"@com_google_absl//absl/container:flat_hash_map",
407423
"@com_google_absl//absl/log",
408424
"@com_google_absl//absl/log:check",
@@ -412,6 +428,8 @@ cc_library(
412428
"@tsl//tsl/profiler/lib:traceme_encode",
413429
"@tsl//tsl/profiler/protobuf:xplane_proto_cc",
414430
"@xla//xla/hlo/ir:hlo",
431+
"@xla//xla/mlir_hlo",
432+
"@xla//xla/pjrt:mlir_to_hlo",
415433
"@xla//xla/service:hlo_cost_analysis",
416434
"@xla//xla/service:hlo_proto_cc",
417435
"@xla//xla/tsl/profiler/convert:xla_op_utils",

0 commit comments

Comments
 (0)