diff --git a/xprof/convert/streaming_trace_viewer_processor.cc b/xprof/convert/streaming_trace_viewer_processor.cc index 852a64dcb..ff422d291 100644 --- a/xprof/convert/streaming_trace_viewer_processor.cc +++ b/xprof/convert/streaming_trace_viewer_processor.cc @@ -1,5 +1,6 @@ #include "xprof/convert/streaming_trace_viewer_processor.h" +#include #include #include #include @@ -46,6 +47,10 @@ struct TraceViewOption { uint64_t resolution = 0; double start_time_ms = 0.0; double end_time_ms = 0.0; + std::string event_name = ""; + std::string search_prefix = ""; + double duration_ms = 0.0; + uint64_t unique_id = 0; }; absl::StatusOr GetTraceViewOption(const ToolOptions& options) { @@ -56,10 +61,21 @@ absl::StatusOr GetTraceViewOption(const ToolOptions& options) { GetParamWithDefault(options, "end_time_ms", "0.0"); auto resolution_opt = GetParamWithDefault(options, "resolution", "0"); + trace_options.event_name = + GetParamWithDefault(options, "event_name", ""); + trace_options.search_prefix = + GetParamWithDefault(options, "search_prefix", ""); + auto duration_ms_opt = + GetParamWithDefault(options, "duration_ms", "0.0"); + auto unique_id_opt = + GetParamWithDefault(options, "unique_id", "0"); + if (!absl::SimpleAtoi(resolution_opt, &trace_options.resolution) || !absl::SimpleAtod(start_time_ms_opt, &trace_options.start_time_ms) || - !absl::SimpleAtod(end_time_ms_opt, &trace_options.end_time_ms)) { + !absl::SimpleAtod(end_time_ms_opt, &trace_options.end_time_ms) || + !absl::SimpleAtoi(unique_id_opt, &trace_options.unique_id) || + !absl::SimpleAtod(duration_ms_opt, &trace_options.duration_ms)) { return tsl::errors::InvalidArgument("wrong arguments"); } return trace_options; @@ -84,36 +100,81 @@ absl::Status StreamingTraceViewerProcessor::ProcessSession( /*derived_timeline=*/true); std::string host_name = session_snapshot.GetHostname(i); - auto sstable_path = session_snapshot.GetFilePath(tool_name, host_name); - if (!sstable_path) { + auto trace_events_sstable_path = session_snapshot.MakeHostDataFilePath( + tensorflow::profiler::StoredDataType::TRACE_LEVELDB, host_name); + auto trace_events_metadata_sstable_path = + session_snapshot.MakeHostDataFilePath( + tensorflow::profiler::StoredDataType::TRACE_EVENTS_METADATA_LEVELDB, + host_name); + auto trace_events_prefix_trie_sstable_path = + session_snapshot.MakeHostDataFilePath( + tensorflow::profiler::StoredDataType:: + TRACE_EVENTS_PREFIX_TRIE_LEVELDB, + host_name); + if (!trace_events_sstable_path || !trace_events_metadata_sstable_path || + !trace_events_prefix_trie_sstable_path) { return tsl::errors::Unimplemented( "streaming trace viewer hasn't been supported in Cloud AI"); } - if (!tsl::Env::Default()->FileExists(*sstable_path).ok()) { + if (!tsl::Env::Default()->FileExists(*trace_events_sstable_path).ok()) { ProcessMegascaleDcn(xspace); TraceEventsContainer trace_container; ConvertXSpaceToTraceEventsContainer(host_name, *xspace, &trace_container); - std::unique_ptr file; - TF_RETURN_IF_ERROR( - tsl::Env::Default()->NewWritableFile(*sstable_path, &file)); - TF_RETURN_IF_ERROR(trace_container.StoreAsLevelDbTable(std::move(file))); + std::unique_ptr trace_events_file; + TF_RETURN_IF_ERROR(tsl::Env::Default()->NewWritableFile( + *trace_events_sstable_path, &trace_events_file)); + std::unique_ptr trace_events_metadata_file; + TF_RETURN_IF_ERROR(tsl::Env::Default()->NewWritableFile( + *trace_events_metadata_sstable_path, &trace_events_metadata_file)); + std::unique_ptr trace_events_prefix_trie_file; + TF_RETURN_IF_ERROR(tsl::Env::Default()->NewWritableFile( + *trace_events_prefix_trie_sstable_path, + &trace_events_prefix_trie_file)); + TF_RETURN_IF_ERROR(trace_container.StoreAsLevelDbTables( + std::move(trace_events_file), + std::move(trace_events_metadata_file), + std::move(trace_events_prefix_trie_file) + )); } - auto visibility_filter = std::make_unique( - tsl::profiler::MilliSpan(trace_option.start_time_ms, - trace_option.end_time_ms), - trace_option.resolution, profiler_trace_options); - TraceEventsContainer trace_container; - // Trace smaller than threshold will be disabled from streaming. - constexpr int64_t kDisableStreamingThreshold = 500000; - auto trace_events_filter = - CreateTraceEventsFilterFromTraceOptions(profiler_trace_options); TraceEventsLevelDbFilePaths file_paths; - file_paths.trace_events_file_path = *sstable_path; - TF_RETURN_IF_ERROR(trace_container.LoadFromLevelDbTable( - file_paths, std::move(trace_events_filter), - std::move(visibility_filter), kDisableStreamingThreshold)); + file_paths.trace_events_file_path = *trace_events_sstable_path; + file_paths.trace_events_metadata_file_path = + *trace_events_metadata_sstable_path; + file_paths.trace_events_prefix_trie_file_path = + *trace_events_prefix_trie_sstable_path; + + TraceEventsContainer trace_container; + if (!trace_option.event_name.empty()) { + TF_RETURN_IF_ERROR(trace_container.ReadFullEventFromLevelDbTable( + *trace_events_metadata_sstable_path, *trace_events_sstable_path, + trace_option.event_name, + static_cast(std::round(trace_option.start_time_ms * 1E9)), + static_cast(std::round(trace_option.duration_ms * 1E9)), + trace_option.unique_id)); + } else if (!trace_option.search_prefix.empty()) { // Search Events Request + if (tsl::Env::Default() + ->FileExists(*trace_events_prefix_trie_sstable_path).ok()) { + auto trace_events_filter = + CreateTraceEventsFilterFromTraceOptions(profiler_trace_options); + TF_RETURN_IF_ERROR(trace_container.SearchInLevelDbTable( + file_paths, + trace_option.search_prefix, std::move(trace_events_filter))); + } + } else { + auto visibility_filter = std::make_unique( + tsl::profiler::MilliSpan(trace_option.start_time_ms, + trace_option.end_time_ms), + trace_option.resolution, profiler_trace_options); + // Trace smaller than threshold will be disabled from streaming. + constexpr int64_t kDisableStreamingThreshold = 500000; + auto trace_events_filter = + CreateTraceEventsFilterFromTraceOptions(profiler_trace_options); + TF_RETURN_IF_ERROR(trace_container.LoadFromLevelDbTable( + file_paths, std::move(trace_events_filter), + std::move(visibility_filter), kDisableStreamingThreshold)); + } merged_trace_container.Merge(std::move(trace_container), host_id); }