hylin2002
Collaborator

This PR is a workaround to extract the latency of collectives from the TPU trace, and should be reverted or refactored in the future.

The detailed changes include:

  • Add a dictionary TARGET_TASK_NAME_COLLECTIVES_MAP in src/benchmark_utils.py to map a collective to its corresponding operation on TPU devices.
  • Add a function get_metrics_from_trace_tpu to extract the execution time of a collective operation on TPU.
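
For reviewers who want a concrete picture of the two changes listed above, here is a minimal sketch. The dictionary entries are copied from the diff; the body of get_metrics_from_trace_tpu is hypothetical, since the PR's implementation is not shown in this conversation. It assumes the trace is a Chrome-trace-style dict with a "traceEvents" list whose events carry a "name" and a "dur" in microseconds, that names are matched from the start of the string, and that results are returned in milliseconds.

```python
import re
from typing import Any

# Entries copied from the PR diff: benchmark task name -> regex for the TPU op name.
TARGET_TASK_NAME_COLLECTIVES_MAP = {
    "all_to_all_ici_op": r"all-to-all.[0-9]+",
    "all_gather_ici_op": r"all-gather.[0-9]+",
    "psum_ici_op": r"all-reduce.[0-9]+",
    "ppermute_ici_op": r"collective-permute-done",
}


def get_metrics_from_trace_tpu(trace: dict[str, Any], task: str) -> list[float]:
    """Hypothetical body: durations (ms) of events whose name matches `task`."""
    pattern = re.compile(task)
    durations_ms = []
    for event in trace.get("traceEvents", []):
        name = event.get("name", "")
        # Assumption: only events with a "dur" field (microseconds) are timed ops.
        if "dur" in event and pattern.match(name):
            durations_ms.append(event["dur"] / 1e3)
    return durations_ms


# Made-up trace, for illustration only.
sample_trace = {
    "traceEvents": [
        {"name": "all-reduce.12", "dur": 1500},
        {"name": "fusion.3", "dur": 900},
        {"name": "all-reduce.12", "dur": 2500},
        {"name": "all-reduce.12"},  # no duration, so it is skipped
    ]
}
psum_ms = get_metrics_from_trace_tpu(
    sample_trace, TARGET_TASK_NAME_COLLECTIVES_MAP["psum_ici_op"]
)
```

Note that the unescaped `.` in these patterns matches any single character, so `all-reduce.12` matches as intended, but a name such as `all-reduce-12` would match too.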

@hylin2002 hylin2002 requested a review from chishuen August 29, 2025 09:31
@hylin2002 hylin2002 marked this pull request as draft August 29, 2025 18:06
Collaborator

@chishuen chishuen left a comment

I'll continue the review later today.

import subprocess
import shutil

# The dictionary to map a CPU collective function to its corresponding operation on TPU
Collaborator

nit: drop the term "CPU". E.g. "map a JAX (collective) operation to its main HLO"

import shutil

# The dictionary to map a CPU collective function to its corresponding operation on TPU
# "psum_scatter_ici_op" has different implementation according to its `matrix_dim` and the number of TPUs, so it's not considered in this mapping dictionary.
Collaborator

Let's remove this comment

"all_to_all_ici_op": r"all-to-all.[0-9]+",
"all_gather_ici_op": r"all-gather.[0-9]+",
"psum_ici_op": r"all-reduce.[0-9]+",
"ppermute_ici_op": r"collective-permute-done",
Collaborator

LMK when you have the data with xla_enable_async_collective_permute=false

Collaborator Author

I already have the data for this flag.


# Check if the given task name is a collective with a corresponding TPU operation.
# This is a workaround and should be reverted or refactored in future.
if task in TARGET_TASK_NAME_COLLECTIVES_MAP.keys():
Collaborator

nit: if task in TARGET_TASK_NAME_COLLECTIVES_MAP is sufficient.
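
To illustrate the nit: the `in` operator on a dict already tests its keys, so `.keys()` adds nothing. A small sketch, where `is_mapped` is a hypothetical helper used only for this example and the map is trimmed to one entry:

```python
# One entry from the PR's map, enough for the demonstration.
TARGET_TASK_NAME_COLLECTIVES_MAP = {"psum_ici_op": r"all-reduce.[0-9]+"}


def is_mapped(task: str) -> bool:
    # Membership on a dict checks its keys directly; `.keys()` is redundant.
    return task in TARGET_TASK_NAME_COLLECTIVES_MAP


# Both spellings give the same answer for present and absent keys.
same_for_present = is_mapped("psum_ici_op") == (
    "psum_ici_op" in TARGET_TASK_NAME_COLLECTIVES_MAP.keys()
)
same_for_absent = is_mapped("matmul") == (
    "matmul" in TARGET_TASK_NAME_COLLECTIVES_MAP.keys()
)
```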

# This is a workaround and should be reverted or refactored in future.
if task in TARGET_TASK_NAME_COLLECTIVES_MAP.keys():
  task = TARGET_TASK_NAME_COLLECTIVES_MAP[task]
  return get_metrics_from_trace_tpu(trace, task)
Collaborator

I think you don't need to call the function again?

def get_metrics_from_trace(trace: dict[str, Any], task: str) -> list[float]:

  # Check if the given task name is a collective with a corresponding TPU operation.
  # This is a workaround and should be reverted or refactored in future.
Collaborator

Add a comment: if the task is not present in the map, fall back to the default behavior and measure the timing from the CPU end.
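
A rough sketch of how the dispatch could read with that comment in place. Both extractor functions here are stand-ins that return marker values: `get_metrics_from_trace_tpu` replaces the PR's real function, and `get_metrics_from_trace_default` is a hypothetical name for the existing CPU-side logic, which in the actual file presumably lives inline in `get_metrics_from_trace`.

```python
from typing import Any

# One entry from the PR's map, enough for the demonstration.
TARGET_TASK_NAME_COLLECTIVES_MAP = {"psum_ici_op": r"all-reduce.[0-9]+"}


def get_metrics_from_trace_tpu(trace: dict[str, Any], task: str) -> list[float]:
    """Stand-in for the PR's TPU extractor; returns a marker value."""
    return [1.0]


def get_metrics_from_trace_default(trace: dict[str, Any], task: str) -> list[float]:
    """Hypothetical stand-in for the existing CPU-side measurement."""
    return [2.0]


def get_metrics_from_trace(trace: dict[str, Any], task: str) -> list[float]:
    # Workaround: collectives with a known TPU operation are timed from the TPU trace.
    if task in TARGET_TASK_NAME_COLLECTIVES_MAP:
        return get_metrics_from_trace_tpu(trace, TARGET_TASK_NAME_COLLECTIVES_MAP[task])
    # If the task is not present in the map, fall back to the default behavior
    # and measure the timing from the CPU end.
    return get_metrics_from_trace_default(trace, task)
```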

@hylin2002 hylin2002 force-pushed the extract-time-from-tpu-trace branch from f2ddc4c to 4e09b70 Compare September 1, 2025 05:23
@hylin2002 hylin2002 force-pushed the extract-time-from-tpu-trace branch 2 times, most recently from 0edd623 to 6ef9e45 Compare September 1, 2025 08:06
Change write to csv filename from csv to tsv.
@hylin2002 hylin2002 force-pushed the extract-time-from-tpu-trace branch 2 times, most recently from 8a477b4 to acd30c9 Compare September 4, 2025 03:13
@hylin2002 hylin2002 marked this pull request as ready for review September 4, 2025 04:51
@hylin2002 hylin2002 force-pushed the extract-time-from-tpu-trace branch from 2c55acd to afd28a0 Compare September 4, 2025 05:05