From 1a4dc4cef6d542dc9560dd24c4a5909e99ac42c4 Mon Sep 17 00:00:00 2001 From: Mudit Gokhale Date: Thu, 16 Oct 2025 11:09:07 -0700 Subject: [PATCH] Add select / de-select all option and add a submit button to the hosts sidenav for trace_viewer PiperOrigin-RevId: 820313088 --- .../app/common/interfaces/navigation_event.ts | 2 + .../app/components/sidenav/sidenav.ng.html | 26 ++- frontend/app/components/sidenav/sidenav.scss | 48 ++++++ frontend/app/components/sidenav/sidenav.ts | 150 +++++++++++++++--- .../components/trace_viewer/trace_viewer.ts | 19 ++- plugin/xprof/convert/raw_to_tool_data.py | 3 +- plugin/xprof/profile_plugin.py | 130 +++++++++++---- plugin/xprof/profile_plugin_test.py | 14 +- xprof/convert/BUILD | 2 + .../streaming_trace_viewer_processor.cc | 90 ++++++----- xprof/convert/trace_viewer/BUILD | 1 + xprof/convert/trace_viewer/trace_events.cc | 83 ++++++++++ xprof/convert/trace_viewer/trace_events.h | 5 + xprof/convert/trace_viewer/trace_utils.h | 2 + xprof/convert/xplane_to_tools_data.cc | 2 + xprof/convert/xplane_to_trace_container.cc | 10 +- 16 files changed, 472 insertions(+), 115 deletions(-) diff --git a/frontend/app/common/interfaces/navigation_event.ts b/frontend/app/common/interfaces/navigation_event.ts index e7df483ac..2cd2b79e1 100644 --- a/frontend/app/common/interfaces/navigation_event.ts +++ b/frontend/app/common/interfaces/navigation_event.ts @@ -5,6 +5,8 @@ export declare interface NavigationEvent { run?: string; tag?: string; host?: string; + // Added to support multi-host functionality for trace_viewer. + hosts?: string[]; // Graph Viewer crosslink params opName?: string; moduleName?: string; diff --git a/frontend/app/components/sidenav/sidenav.ng.html b/frontend/app/components/sidenav/sidenav.ng.html index 53761e457..54a960ab6 100644 --- a/frontend/app/components/sidenav/sidenav.ng.html +++ b/frontend/app/components/sidenav/sidenav.ng.html @@ -32,10 +32,30 @@
- Hosts ({{hosts.length}}) + Hosts ({{ isMultiHostsEnabled ? selectedHostsInternal.length : hosts.length }})
- + + +
+
+ +
+
+ + {{host}} + +
+ +
+
+
+ {{host}} @@ -43,6 +63,7 @@
+
@@ -66,4 +87,3 @@

- diff --git a/frontend/app/components/sidenav/sidenav.scss b/frontend/app/components/sidenav/sidenav.scss index fcdb8792b..e21f6aef6 100644 --- a/frontend/app/components/sidenav/sidenav.scss +++ b/frontend/app/components/sidenav/sidenav.scss @@ -13,3 +13,51 @@ .mat-subheading-2 { margin-bottom: 0; } + +// Target the custom panel class defined in sidenav.ng.html +.multi-host-select-panel { + .select-panel-content { + display: flex; + flex-direction: column; + } + + // Container for the Select All/Deselect All button at the top + .select-all-button-container { + position: sticky; + top: 0; + z-index: 10; + background: #fff; + padding: 0 8px; + box-shadow: 0 2px 5px rgba(0, 0, 0, 0.05); + + .full-width { + width: 100%; + text-align: left; + padding-left: 0; + } + + .custom-divider { + border-bottom: 1px solid rgba(0, 0, 0, 0.12); + margin: 0 -8px; + } + } + + // Container for the submit button at the bottom of the dropdown + .submit-button-container { + position: sticky; + bottom: 0; + padding: 8px; + + background: #fff; + box-shadow: 0 -2px 5px rgba(0, 0, 0, 0.1); + + .full-width { + width: 100%; + } + } + + // Ensure options remain clickable and are not covered by the sticky button's padding. + mat-option { + flex-shrink: 0; + } +} diff --git a/frontend/app/components/sidenav/sidenav.ts b/frontend/app/components/sidenav/sidenav.ts index e4e759e54..78c9fe8e4 100644 --- a/frontend/app/components/sidenav/sidenav.ts +++ b/frontend/app/components/sidenav/sidenav.ts @@ -32,8 +32,11 @@ export class SideNav implements OnInit, OnDestroy { selectedRunInternal = ''; selectedTagInternal = ''; selectedHostInternal = ''; + selectedHostsInternal: string[] = []; + selectedHostsPending: string[] = []; selectedModuleInternal = ''; navigationParams: {[key: string]: string|boolean} = {}; + multiHostEnabledTools: string[] = ['trace_viewer', 'trace_viewer@']; hideCaptureProfileButton = false; @@ -65,6 +68,11 @@ export class SideNav implements OnInit, OnDestroy { return HLO_TOOLS.includes(this.selectedTag); } + get isMultiHostsEnabled() { + const tag = this.selectedTag || ''; + return this.multiHostEnabledTools.includes(tag); + } + // Getter for valid run given url router or user selection. get selectedRun() { return this.runs.find(validRun => validRun === this.selectedRunInternal) || @@ -90,6 +98,10 @@ export class SideNav implements OnInit, OnDestroy { this.moduleList[0] || ''; } + get selectedHosts() { + return this.selectedHostsInternal; + } + // https://github.com/angular/angular/issues/11023#issuecomment-752228784 mergeRouteParams(): Map { const params = new Map(); @@ -119,20 +131,25 @@ export class SideNav implements OnInit, OnDestroy { const run = params.get('run') || ''; const tag = params.get('tool') || params.get('tag') || ''; const host = params.get('host') || ''; + const hostsParam = params.get('hosts'); const opName = params.get('node_name') || params.get('opName') || ''; const moduleName = params.get('module_name') || ''; this.navigationParams['firstLoad'] = true; if (opName) { this.navigationParams['opName'] = opName; } - if (this.selectedRunInternal === run && this.selectedTagInternal === tag && - this.selectedHostInternal === host) { - return; - } this.selectedRunInternal = run; this.selectedTagInternal = tag; - this.selectedHostInternal = host; this.selectedModuleInternal = moduleName; + + if (this.isMultiHostsEnabled) { + if (hostsParam) { + this.selectedHostsInternal = hostsParam.split(','); + } + this.selectedHostsPending = [...this.selectedHostsInternal]; + } else { + this.selectedHostInternal = host; + } this.update(); } @@ -153,9 +170,13 @@ export class SideNav implements OnInit, OnDestroy { const navigationEvent: NavigationEvent = { run: this.selectedRun, tag: this.selectedTag, - host: this.selectedHost, ...this.navigationParams, }; + if (this.isMultiHostsEnabled) { + navigationEvent.hosts = this.selectedHosts; + } else { + navigationEvent.host = this.selectedHost; + } if (this.is_hlo_tool) { navigationEvent.moduleName = this.selectedModule; } @@ -242,8 +263,21 @@ export class SideNav implements OnInit, OnDestroy { this.afterUpdateTag(); } - onTagSelectionChange(tag: string) { + async onTagSelectionChange(tag: string) { this.selectedTagInternal = tag; + this.selectedHostsInternal = []; + this.selectedHostsPending = []; + this.selectedHostInternal = ''; + + if (this.isMultiHostsEnabled) { + this.hosts = await this.getHostsForSelectedTag(); + if (this.hosts.length > 0) { + this.selectedHostsInternal = [this.hosts[0]]; + } else { + this.selectedHostsInternal = []; + } + this.selectedHostsPending = [...this.selectedHostsInternal]; + } this.afterUpdateTag(); } @@ -255,6 +289,16 @@ export class SideNav implements OnInit, OnDestroy { // Keep them under the same update function as initial step of the separation. async updateHosts() { this.hosts = await this.getHostsForSelectedTag(); + if (this.isMultiHostsEnabled) { + if (this.selectedHostsInternal.length === 0 && this.hosts.length > 0) { + this.selectedHostsInternal = [this.hosts[0]]; + } + this.selectedHostsPending = [...this.selectedHostsInternal]; + } else { + if (!this.selectedHostInternal && this.hosts.length > 0) { + this.selectedHostInternal = this.hosts[0]; + } + } if (this.is_hlo_tool) { this.moduleList = await this.getModuleListForSelectedTag(); } @@ -262,11 +306,34 @@ export class SideNav implements OnInit, OnDestroy { this.afterUpdateHost(); } - onHostSelectionChange(host: string) { - this.selectedHostInternal = host; + onHostSelectionChange(selection: string) { + this.selectedHostInternal = selection; + this.navigateTools(); + } + + onHostsSelectionChange(selection: string[]) { + this.selectedHostsPending = + Array.isArray(selection) ? selection : [selection]; + } + + onSubmitHosts() { + this.selectedHostsInternal = [...this.selectedHostsPending]; this.navigateTools(); } + toggleAllHosts() { + const allAvailableHosts = this.hosts; + + const areAllSelected = allAvailableHosts.length > 0 && + allAvailableHosts.length === this.selectedHostsPending.length; + + if (areAllSelected) { + this.selectedHostsPending = []; + } else { + this.selectedHostsPending = [...allAvailableHosts]; + } + } + onModuleSelectionChange(module: string) { this.selectedModuleInternal = module; this.navigateTools(); @@ -276,26 +343,65 @@ export class SideNav implements OnInit, OnDestroy { this.navigateTools(); } + // Helper function to serialize query parameters + private serializeQueryParams( + params: {[key: string]: string|string[]|boolean|undefined}): string { + const searchParams = new URLSearchParams(); + for (const key in params) { + if (params.hasOwnProperty(key)) { + const value = params[key]; + // Only include non-null/non-undefined values + if (value !== undefined && value !== null) { + if (Array.isArray(value)) { + // Arrays are handled as comma-separated strings (like 'hosts') + searchParams.set(key, value.join(',')); + } else if (typeof value === 'boolean') { + // Only set boolean flags if they are explicitly true + if (value === true) { + searchParams.set(key, 'true'); + } + } else { + searchParams.set(key, String(value)); + } + } + } + } + const queryString = searchParams.toString(); + return queryString ? `?${queryString}` : ''; + } + updateUrlHistory() { - // TODO(xprof): change to camel case when constructing url - const toolQueryParams = Object.keys(this.navigationParams) - .map(key => { - return `${key}=${this.navigationParams[key]}`; - }) - .join('&'); - const toolQueryParamsString = - toolQueryParams.length ? `&${toolQueryParams}` : ''; - const moduleNameQuery = - this.is_hlo_tool ? `&module_name=${this.selectedModule}` : ''; - const url = `${window.parent.location.origin}?tool=${ - this.selectedTag}&host=${this.selectedHost}&run=${this.selectedRun}${ - toolQueryParamsString}${moduleNameQuery}#profile`; + const navigationEvent = this.getNavigationEvent(); + const queryParams: {[key: string]: string|string[]|boolean| + undefined} = {...navigationEvent}; + + if (this.isMultiHostsEnabled) { + // For multi-host enabled tools, ensure 'hosts' is a comma-separated string in the URL + if (queryParams['hosts'] && Array.isArray(queryParams['hosts'])) { + queryParams['hosts'] = (queryParams['hosts'] as string[]).join(','); + } + delete queryParams['host']; // Remove single host param + } else { + // For other tools, ensure 'host' is used + delete queryParams['hosts']; // Remove multi-host param + } + + // Get current path to avoid changing the base URL + const pathname = window.parent.location.pathname; + + // Use the custom serialization helper + const queryString = this.serializeQueryParams(queryParams); + const url = pathname + queryString; + window.parent.history.pushState({}, '', url); } navigateTools() { const navigationEvent = this.getNavigationEvent(); this.communicationService.onNavigateReady(navigationEvent); + + // This router.navigate call remains, as it's responsible for Angular + // routing this.router.navigate( [ this.selectedTag || 'empty', diff --git a/frontend/app/components/trace_viewer/trace_viewer.ts b/frontend/app/components/trace_viewer/trace_viewer.ts index 9eb5f0ebf..4ade3c9c4 100644 --- a/frontend/app/components/trace_viewer/trace_viewer.ts +++ b/frontend/app/components/trace_viewer/trace_viewer.ts @@ -1,5 +1,4 @@ import {PlatformLocation} from '@angular/common'; -import {HttpParams} from '@angular/common/http'; import {Component, inject, Injector, OnDestroy} from '@angular/core'; import {ActivatedRoute} from '@angular/router'; import {API_PREFIX, DATA_API, PLUGIN_NAME} from 'org_xprof/frontend/app/common/constants/constants'; @@ -38,11 +37,19 @@ export class TraceViewer implements OnDestroy { update(event: NavigationEvent) { const isStreaming = (event.tag === 'trace_viewer@'); - const params = new HttpParams() - .set('run', event.run!) - .set('tag', event.tag!) - .set('host', event.host!); - const traceDataUrl = this.pathPrefix + DATA_API + '?' + params.toString(); + const run = event.run || ''; + const tag = event.tag || ''; + + let queryString = `run=${run}&tag=${tag}`; + + if (event.hosts && typeof event.hosts === 'string') { + // Since event.hosts is a comma-separated string, we can use it directly. + queryString += `&hosts=${event.hosts}`; + } else if (event.host) { + queryString += `&host=${event.host}`; + } + + const traceDataUrl = `${this.pathPrefix}${DATA_API}?${queryString}`; this.url = this.pathPrefix + API_PREFIX + PLUGIN_NAME + '/trace_viewer_index.html' + '?is_streaming=' + isStreaming.toString() + '&is_oss=true' + diff --git a/plugin/xprof/convert/raw_to_tool_data.py b/plugin/xprof/convert/raw_to_tool_data.py index 04b4f5b19..e3d99c937 100644 --- a/plugin/xprof/convert/raw_to_tool_data.py +++ b/plugin/xprof/convert/raw_to_tool_data.py @@ -116,10 +116,9 @@ def xspace_to_tool_data( if success: data = process_raw_trace(raw_data) elif tool == 'trace_viewer@': - # Streaming trace viewer handles one host at a time. - assert len(xspace_paths) == 1 options = params.get('trace_viewer_options', {}) options['use_saved_result'] = params.get('use_saved_result', True) + options['hosts'] = params.get('hosts', []) raw_data, success = xspace_wrapper_func(xspace_paths, tool, options) if success: data = raw_data diff --git a/plugin/xprof/profile_plugin.py b/plugin/xprof/profile_plugin.py index b949ea5a8..2edcb1c36 100644 --- a/plugin/xprof/profile_plugin.py +++ b/plugin/xprof/profile_plugin.py @@ -380,23 +380,6 @@ def filenames_to_hosts(filenames: list[str], tool: str) -> list[str]: return sorted(hosts) -def validate_xplane_asset_paths(asset_paths: List[str]) -> None: - """Validates that all xplane asset paths that are provided are valid files. - - Args: - asset_paths: A list of asset paths. - - Raises: - FileNotFoundError: If any of the xplane asset paths do not exist. - """ - for asset_path in asset_paths: - if ( - str(asset_path).endswith(TOOLS['xplane']) - and not epath.Path(asset_path).exists() - ): - raise FileNotFoundError(f'Invalid asset path: {asset_path}') - - def _get_bool_arg( args: Mapping[str, Any], arg_name: str, default: bool ) -> bool: @@ -511,6 +494,10 @@ def is_active(self) -> bool: self._is_active = any(self.generate_runs()) return self._is_active + def _does_tool_support_multi_hosts_processing(self, tool: str) -> bool: + """Returns true if the tool supports multi-hosts processing.""" + return tool == 'trace_viewer@' or tool == 'trace_viewer' + def get_plugin_apps( self, ) -> dict[str, Callable[[wrappers.Request], wrappers.Response]]: @@ -718,6 +705,85 @@ def hlo_module_list_route( module_names_str = self.hlo_module_list_impl(request) return respond(module_names_str, 'text/plain') + def _get_valid_hosts( + self, run_dir: str, run: str, tool: str, hosts_param: str, host: str + ) -> tuple[List[str], List[epath.Path]]: + """Retrieves and validates the hosts and asset paths for a run and tool. + + Args: + run_dir: The run directory. + run: The frontend run name. + tool: The requested tool. + hosts_param: Comma-separated list of selected hosts. + host: The single host parameter. + + Returns: + A tuple containing (selected_hosts, asset_paths). + + Raises: + FileNotFoundError: If a required xplane file for the specified host(s) + is not found. + IOError: If there is an error reading asset directories. + """ + asset_paths = [] + selected_hosts = [] + all_xplane_files = {} # Map host to path + + # Find all available xplane files for the run and map them by host. + file_pattern = make_filename('*', 'xplane') + try: + path = epath.Path(run_dir) + for xplane_path in path.glob(file_pattern): + host_name, _ = _parse_filename(xplane_path.name) + if host_name: + print('host_name: %s', host_name) + all_xplane_files[host_name] = xplane_path + except OSError as e: + print('Error') + logger.warning('Cannot read asset directory: %s, OpError %s', run_dir, e) + raise IOError( + 'Cannot read asset directory: %s, OpError %s' % (run_dir, e) + ) from e + + if hosts_param and self._does_tool_support_multi_hosts_processing(tool): + selected_hosts = hosts_param.split(',') + for selected_host in selected_hosts: + if selected_host in all_xplane_files: + asset_paths.append(all_xplane_files[selected_host]) + else: + raise FileNotFoundError( + 'No xplane file found for host: %s in run: %s' + % (selected_host, run) + ) + logger.info('Inside trace_viewer@, asset_paths: %s') + elif host == ALL_HOSTS: + asset_paths = list(all_xplane_files.values()) + selected_hosts = list(all_xplane_files.keys()) + elif host and host in all_xplane_files: + selected_hosts = [host] + asset_paths = [all_xplane_files[host]] + elif host: + logger.warning('No xplane file found for host: %s in run: %s', host, run) + if host not in XPLANE_TOOLS_ALL_HOSTS_ONLY: + raise FileNotFoundError( + 'No xplane file found for host: %s in run: %s' % (host, run) + ) + + if not asset_paths: + logger.warning( + 'No matching asset paths found for run %s, tool %s, host(s) %s / %s', + run, + tool, + hosts_param, + host, + ) + if not host and tool not in XPLANE_TOOLS_ALL_HOSTS_ONLY: + raise FileNotFoundError( + 'Host must be specified for tool %s in run %s' % (tool, run) + ) + + return selected_hosts, asset_paths + def data_impl( self, request: wrappers.Request ) -> tuple[Optional[str], str, Optional[str]]: @@ -729,9 +795,17 @@ def data_impl( Returns: A string that can be served to the frontend tool or None if tool, run or host is invalid. + + Raises: + FileNotFoundError: If a required xplane file for the specified host(s) + is not found. + IOError: If there is an error reading asset directories. + AttributeError: If there is an error during xplane to tool data conversion + ValueError: If xplane conversion fails due to invalid data. """ run = request.args.get('run') tool = request.args.get('tag') + hosts_param = request.args.get('hosts') host = request.args.get('host') module_name = request.args.get('module_name') tqx = request.args.get('tqx') @@ -795,26 +869,16 @@ def data_impl( options['search_prefix'] = request.args.get('search_prefix') params['trace_viewer_options'] = options - asset_path = os.path.join(run_dir, make_filename(host, tool)) - _, content_encoding = None, None if use_xplane(tool): - if host == ALL_HOSTS: - file_pattern = make_filename('*', 'xplane') - try: - path = epath.Path(run_dir) - asset_paths = list(path.glob(file_pattern)) - except OSError as e: - logger.warning('Cannot read asset directory: %s, OpError %s', run_dir, - e) - raise IOError( - 'Cannot read asset directory: %s, OpError %s' % (run_dir, e) - ) from e - else: - asset_paths = [asset_path] + selected_hosts, asset_paths = self._get_valid_hosts( + run_dir, run, tool, hosts_param, host + ) + if not asset_paths: + return None, content_type, None + params['hosts'] = selected_hosts try: - validate_xplane_asset_paths(asset_paths) data, content_type = convert.xspace_to_tool_data( asset_paths, tool, params) except AttributeError as e: diff --git a/plugin/xprof/profile_plugin_test.py b/plugin/xprof/profile_plugin_test.py index 0f44db974..613e96309 100644 --- a/plugin/xprof/profile_plugin_test.py +++ b/plugin/xprof/profile_plugin_test.py @@ -330,7 +330,9 @@ def testData(self): with self.assertRaises(FileNotFoundError): self.plugin.data_impl( utils.make_data_request( - utils.DataRequestOptions(run='a', tool='trace_viewer', host='') + utils.DataRequestOptions( + run='a/foo', tool='trace_viewer', host='' + ) ) ) @@ -445,6 +447,7 @@ def testDataImplTraceViewerOptions(self, mock_xspace_to_tool_data): 'start_time_ms': '100', 'end_time_ms': '200', }, + 'hosts': ['host1'], } _, _, _ = self.plugin.data_impl( @@ -462,8 +465,11 @@ def testDataImplTraceViewerOptions(self, mock_xspace_to_tool_data): ) mock_xspace_to_tool_data.assert_called_once_with( - [expected_asset_path], 'trace_viewer@', expected_params + [mock.ANY], 'trace_viewer@', expected_params ) + actual_path_list = mock_xspace_to_tool_data.call_args[0][0] + self.assertLen(actual_path_list, 1) + self.assertEqual(str(actual_path_list[0]), expected_asset_path) def testActive(self): @@ -535,8 +541,10 @@ def test_generate_runs_from_path_params_with_run_path(self): # run3 is a file, not a directory, and should be ignored. with open(os.path.join(run_path, 'run3'), 'w') as f: f.write('dummy file') + with open(os.path.join(run2_path, 'host2.xplane.pb'), 'w') as f: + f.write('dummy xplane data for run2') runs = list(self.plugin._generate_runs_from_path_params(run_path=run_path)) - self.assertListEqual(['run1'], runs) + self.assertListEqual(['run1', 'run2'], sorted(runs)) self.assertEqual(run_path, self.plugin.logdir) def test_runs_impl_with_session(self): diff --git a/xprof/convert/BUILD b/xprof/convert/BUILD index 87b0426a1..8f1e8238c 100644 --- a/xprof/convert/BUILD +++ b/xprof/convert/BUILD @@ -1592,12 +1592,14 @@ cc_library( srcs = ["xplane_to_trace_container.cc"], hdrs = ["xplane_to_trace_container.h"], deps = [ + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", "@org_xprof//plugin/xprof/protobuf:trace_events_proto_cc", "@org_xprof//plugin/xprof/protobuf:trace_events_raw_proto_cc", "@org_xprof//xprof/convert/trace_viewer:trace_event_arguments_builder", "@org_xprof//xprof/convert/trace_viewer:trace_events", "@org_xprof//xprof/convert/trace_viewer:trace_events_util", + "@org_xprof//xprof/convert/trace_viewer:trace_utils", "@tsl//tsl/profiler/protobuf:xplane_proto_cc", "@xla//xla/tsl/profiler/utils:tf_xplane_visitor", "@xla//xla/tsl/profiler/utils:timespan", diff --git a/xprof/convert/streaming_trace_viewer_processor.cc b/xprof/convert/streaming_trace_viewer_processor.cc index 954885cf7..e90222d74 100644 --- a/xprof/convert/streaming_trace_viewer_processor.cc +++ b/xprof/convert/streaming_trace_viewer_processor.cc @@ -67,65 +67,69 @@ absl::StatusOr GetTraceViewOption(const ToolOptions& options) { absl::Status StreamingTraceViewerProcessor::ProcessSession( const SessionSnapshot& session_snapshot, const ToolOptions& options) { - if (session_snapshot.XSpaceSize() != 1) { - return tsl::errors::InvalidArgument( - "Trace events tool expects only 1 XSpace path but gets ", - session_snapshot.XSpaceSize()); - } - - google::protobuf::Arena arena; - TF_ASSIGN_OR_RETURN(XSpace * xspace, session_snapshot.GetXSpace(0, &arena)); - PreprocessSingleHostXSpace(xspace, /*step_grouping=*/true, - /*derived_timeline=*/true); - + TraceEventsContainer merged_trace_container; std::string tool_name = "trace_viewer@"; - std::string trace_viewer_json; - std::string host_name = session_snapshot.GetHostname(0); - auto sstable_path = session_snapshot.GetFilePath(tool_name, host_name); - if (!sstable_path) { - return tsl::errors::Unimplemented( - "streaming trace viewer hasn't been supported in Cloud AI"); - } - if (!tsl::Env::Default()->FileExists(*sstable_path).ok()) { - ProcessMegascaleDcn(xspace); - TraceEventsContainer trace_container; - ConvertXSpaceToTraceEventsContainer(host_name, *xspace, &trace_container); - std::unique_ptr file; - TF_RETURN_IF_ERROR( - tsl::Env::Default()->NewWritableFile(*sstable_path, &file)); - TF_RETURN_IF_ERROR(trace_container.StoreAsLevelDbTable(std::move(file))); - } TF_ASSIGN_OR_RETURN(TraceViewOption trace_option, GetTraceViewOption(options)); tensorflow::profiler::TraceOptions profiler_trace_options = TraceOptionsFromToolOptions(options); - auto visibility_filter = std::make_unique( - tsl::profiler::MilliSpan(trace_option.start_time_ms, - trace_option.end_time_ms), - trace_option.resolution, profiler_trace_options); - TraceEventsContainer trace_container; - // Trace smaller than threshold will be disabled from streaming. - constexpr int64_t kDisableStreamingThreshold = 500000; - auto trace_events_filter = - CreateTraceEventsFilterFromTraceOptions(profiler_trace_options); - TraceEventsLevelDbFilePaths file_paths; - file_paths.trace_events_file_path = *sstable_path; - TF_RETURN_IF_ERROR(trace_container.LoadFromLevelDbTable( - file_paths, std::move(trace_events_filter), std::move(visibility_filter), - kDisableStreamingThreshold)); + + // TODO(b/452217676) : Optimize this to process hosts in parallel. + for (int i = 0; i < session_snapshot.XSpaceSize(); ++i) { + int host_id = i+1; + google::protobuf::Arena arena; + TF_ASSIGN_OR_RETURN(XSpace * xspace, session_snapshot.GetXSpace(i, &arena)); + PreprocessSingleHostXSpace(xspace, /*step_grouping=*/true, + /*derived_timeline=*/true); + + std::string host_name = session_snapshot.GetHostname(i); + auto sstable_path = session_snapshot.GetFilePath(tool_name, host_name); + if (!sstable_path) { + return tsl::errors::Unimplemented( + "streaming trace viewer hasn't been supported in Cloud AI"); + } + if (!tsl::Env::Default()->FileExists(*sstable_path).ok()) { + ProcessMegascaleDcn(xspace); + TraceEventsContainer trace_container; + ConvertXSpaceToTraceEventsContainer(host_name, *xspace, + &trace_container); + std::unique_ptr file; + TF_RETURN_IF_ERROR( + tsl::Env::Default()->NewWritableFile(*sstable_path, &file)); + TF_RETURN_IF_ERROR(trace_container.StoreAsLevelDbTable(std::move(file))); + } + + auto visibility_filter = std::make_unique( + tsl::profiler::MilliSpan(trace_option.start_time_ms, + trace_option.end_time_ms), + trace_option.resolution, profiler_trace_options); + TraceEventsContainer trace_container; + // Trace smaller than threshold will be disabled from streaming. + constexpr int64_t kDisableStreamingThreshold = 500000; + auto trace_events_filter = + CreateTraceEventsFilterFromTraceOptions(profiler_trace_options); + TraceEventsLevelDbFilePaths file_paths; + file_paths.trace_events_file_path = *sstable_path; + TF_RETURN_IF_ERROR(trace_container.LoadFromLevelDbTable( + file_paths, std::move(trace_events_filter), + std::move(visibility_filter), kDisableStreamingThreshold)); + merged_trace_container.Merge(std::move(trace_container), host_id); + } + + std::string trace_viewer_json; JsonTraceOptions json_trace_options; tensorflow::profiler::TraceDeviceType device_type = tensorflow::profiler::TraceDeviceType::kUnknownDevice; - if (IsTpuTrace(trace_container.trace())) { + if (IsTpuTrace(merged_trace_container.trace())) { device_type = TraceDeviceType::kTpu; } json_trace_options.details = TraceOptionsToDetails(device_type, profiler_trace_options); IOBufferAdapter adapter(&trace_viewer_json); TraceEventsToJson( - json_trace_options, trace_container, &adapter); + json_trace_options, merged_trace_container, &adapter); SetOutput(trace_viewer_json, "application/json"); return absl::OkStatus(); diff --git a/xprof/convert/trace_viewer/BUILD b/xprof/convert/trace_viewer/BUILD index d0020296e..54ca46f03 100644 --- a/xprof/convert/trace_viewer/BUILD +++ b/xprof/convert/trace_viewer/BUILD @@ -133,6 +133,7 @@ cc_library( ":trace_events_filter_interface", ":trace_events_util", ":trace_viewer_visibility", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:endian", "@com_google_absl//absl/container:flat_hash_map", diff --git a/xprof/convert/trace_viewer/trace_events.cc b/xprof/convert/trace_viewer/trace_events.cc index fa5bce32d..cdcb79f4b 100644 --- a/xprof/convert/trace_viewer/trace_events.cc +++ b/xprof/convert/trace_viewer/trace_events.cc @@ -19,13 +19,17 @@ limitations under the License. #include #include #include +#include #include #include #include #include +#include #include +#include "absl/algorithm/container.h" #include "absl/base/internal/endian.h" +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" @@ -77,6 +81,15 @@ void MaybeAddEventUniqueId(std::vector& events) { } } +// Appends all events from src into dst. +inline void AppendEvents(TraceEventTrack&& src, TraceEventTrack* dst) { + if (dst->empty()) { + *dst = std::move(src); + } else { + absl::c_move(src, std::back_inserter(*dst)); + } +} + } // namespace TraceEvent::EventType GetTraceEventType(const TraceEvent& event) { @@ -293,5 +306,75 @@ void PurgeIrrelevantEntriesInTraceNameTable( trace.mutable_name_table()->swap(new_name_table); } +template +void TraceEventsContainerBase::MergeTrace( + const Trace& other_trace) { + trace_.mutable_tasks()->insert(other_trace.tasks().begin(), + other_trace.tasks().end()); + trace_.mutable_name_table()->insert(other_trace.name_table().begin(), + other_trace.name_table().end()); + if (other_trace.has_min_timestamp_ps() && + other_trace.has_max_timestamp_ps()) { + ExpandTraceSpan(TraceSpan(other_trace), &trace_); + } + trace_.set_num_events(trace_.num_events() + other_trace.num_events()); +} + +template +void TraceEventsContainerBase::Merge( + TraceEventsContainerBase&& other, int host_id) { + if (this == &other) return; + if (other.NumEvents() == 0 && other.trace().devices().empty()) return; + + const int kMaxDevicesPerHost = 1000; + absl::flat_hash_map other_to_this_device_id_map; + auto& this_device_map = *trace_.mutable_devices(); + + // Handle device id collisions. + // TODO(muditgokhale) : Check if this logic can be moved to + // xplane_to_trace_container. + for (const auto& [other_id, other_device] : other.trace().devices()) { + LOG(WARNING) << "Remapping device id " << other_id << "for host " << host_id + << " to " << other_id + host_id * kMaxDevicesPerHost; + uint32_t target_id = other_id + host_id * kMaxDevicesPerHost; + other_to_this_device_id_map[other_id] = target_id; + + Device device_copy = other_device; + device_copy.set_device_id(target_id); + + this_device_map.insert({target_id, device_copy}); + } + + other.ForAllMutableTracks([this, &other_to_this_device_id_map]( + uint32_t other_device_id, + ResourceValue resource_id_or_counter_name, + TraceEventTrack* track) { + uint32_t this_device_id = other_to_this_device_id_map.at(other_device_id); + for (TraceEvent* event : *track) { + event->set_device_id(this_device_id); + } + DeviceEvents& device = this->events_by_device_[this_device_id]; + if (const uint64_t* resource_id = + std::get_if(&resource_id_or_counter_name)) { + AppendEvents(std::move(*track), &device.events_by_resource[*resource_id]); + } else if (const absl::string_view* counter_name = + std::get_if( + &resource_id_or_counter_name)) { + AppendEvents(std::move(*track), + &device.counter_events_by_name[*counter_name]); + } + }); + + MergeTrace(other.trace()); + arenas_.insert(std::make_move_iterator(other.arenas_.begin()), + std::make_move_iterator(other.arenas_.end())); + other.arenas_.clear(); + other.events_by_device_.clear(); + other.trace_.Clear(); +} + +// Explicit instantiations for the common case. +template class TraceEventsContainerBase; + } // namespace profiler } // namespace tensorflow diff --git a/xprof/convert/trace_viewer/trace_events.h b/xprof/convert/trace_viewer/trace_events.h index 99a68d3b9..5737e18d0 100644 --- a/xprof/convert/trace_viewer/trace_events.h +++ b/xprof/convert/trace_viewer/trace_events.h @@ -729,6 +729,8 @@ class TraceEventsContainerBase { TraceEventsContainerBase(const TraceEventsContainerBase&) = delete; TraceEventsContainerBase& operator=(const TraceEventsContainerBase&) = delete; + void Merge(TraceEventsContainerBase&& other, int host_id); + // Creates a TraceEvent prefilled with the given values. void AddCompleteEvent(absl::string_view name, uint64_t resource_id, uint32_t device_id, tsl::profiler::Timespan timespan, @@ -1075,6 +1077,9 @@ class TraceEventsContainerBase { return copy; } + // Helper function to merge top-level trace metadata. + void MergeTrace(const Trace& other_trace); + // Adds an event from arenas_ to events_by_device_. void AddArenaEvent(TraceEvent* event) { ExpandTraceSpan(EventSpan(*event), &trace_); diff --git a/xprof/convert/trace_viewer/trace_utils.h b/xprof/convert/trace_viewer/trace_utils.h index 783fd8420..eaa9b0d3d 100644 --- a/xprof/convert/trace_viewer/trace_utils.h +++ b/xprof/convert/trace_viewer/trace_utils.h @@ -37,6 +37,8 @@ inline bool MaybeTpuNonCoreDeviceName(absl::string_view device_name) { IsTpuIciRouterDeviceName(device_name)); } +static constexpr int kMaxDevicesPerHost = 1000; + } // namespace profiler } // namespace tensorflow diff --git a/xprof/convert/xplane_to_tools_data.cc b/xprof/convert/xplane_to_tools_data.cc index 4ecce1108..f8bc5e712 100644 --- a/xprof/convert/xplane_to_tools_data.cc +++ b/xprof/convert/xplane_to_tools_data.cc @@ -177,6 +177,8 @@ absl::StatusOr ConvertXSpaceToTraceEvents( if (!tsl::Env::Default()->FileExists(*trace_events_sstable_path).ok()) { ProcessMegascaleDcn(xspace); TraceEventsContainer trace_container; + // No-op method which will be deprecated in the future, thus added + // /*host_id=*/1 as a placeholder for now. ConvertXSpaceToTraceEventsContainer(host_name, *xspace, &trace_container); std::unique_ptr trace_events_file; TF_RETURN_IF_ERROR(tsl::Env::Default()->NewWritableFile( diff --git a/xprof/convert/xplane_to_trace_container.cc b/xprof/convert/xplane_to_trace_container.cc index aba9c186c..95866dafe 100644 --- a/xprof/convert/xplane_to_trace_container.cc +++ b/xprof/convert/xplane_to_trace_container.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/base/optimization.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "xla/tsl/profiler/utils/tf_xplane_visitor.h" @@ -236,9 +237,12 @@ void ConvertXSpaceToTraceEventsContainer(absl::string_view hostname, } for (const XPlane* device_plane : device_planes) { - ConvertXPlaneToTraceEventsContainer( - tsl::profiler::kFirstDeviceId + device_plane->id(), hostname, - *device_plane, container); + uint32_t device_pid = tsl::profiler::kFirstDeviceId + device_plane->id(); + if (ABSL_PREDICT_FALSE(device_pid > tsl::profiler::kLastDeviceId)) { + device_pid = tsl::profiler::kFirstDeviceId; + } + ConvertXPlaneToTraceEventsContainer(device_pid, hostname, *device_plane, + container); } for (const XPlane* custom_plane : FindPlanesWithPrefix(space, tsl::profiler::kCustomPlanePrefix)) {