Extract latency from TPU trace for collective microbenchmarks #26
Conversation
I'll continue the review later today.
src/benchmark_utils.py (Outdated)
import subprocess
import shutil

# The dictionary to map a CPU collective function to its corresponding operation on TPU
nit: drop the term "CPU". E.g. "map a JAX (collective) operation to its main HLO"
src/benchmark_utils.py (Outdated)
import shutil

# The dictionary to map a CPU collective function to its corresponding operation on TPU
# "psum_scatter_ici_op" has a different implementation according to its `matrix_dim` and the number of TPUs, so it's not considered in this mapping dictionary.
Let's remove this comment
src/benchmark_utils.py (Outdated)
"all_to_all_ici_op": r"all-to-all.[0-9]+", | ||
"all_gather_ici_op": r"all-gather.[0-9]+", | ||
"psum_ici_op": r"all-reduce.[0-9]+", | ||
"ppermute_ici_op": r"collective-permute-done", |
LMK when you have the data with xla_enable_async_collective_permute=false
I already have the data for this flag.
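For context, a minimal sketch of how such a mapping could be applied when scanning a profile; the `matches_collective` helper and the example event name are illustrative assumptions, not code from this PR:

```python
import re

# The mapping from the diff above: JAX collective -> regex for its
# main HLO op name in the TPU trace.
TARGET_TASK_NAME_COLLECTIVES_MAP = {
    "all_to_all_ici_op": r"all-to-all.[0-9]+",
    "all_gather_ici_op": r"all-gather.[0-9]+",
    "psum_ici_op": r"all-reduce.[0-9]+",
    "ppermute_ici_op": r"collective-permute-done",
}

def matches_collective(event_name: str, task: str) -> bool:
    # Hypothetical helper: true if a trace event name matches the
    # HLO pattern registered for this collective task.
    pattern = TARGET_TASK_NAME_COLLECTIVES_MAP[task]
    return re.match(pattern, event_name) is not None

# e.g. matches_collective("all-reduce.17", "psum_ici_op") -> True
```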
src/benchmark_utils.py (Outdated)
# Check if the given task name is a collective with a corresponding TPU operation.
# This is a workaround and should be reverted or refactored in future.
if task in TARGET_TASK_NAME_COLLECTIVES_MAP.keys():
nit: `if task in TARGET_TASK_NAME_COLLECTIVES_MAP` is sufficient.
src/benchmark_utils.py (Outdated)
# This is a workaround and should be reverted or refactored in future.
if task in TARGET_TASK_NAME_COLLECTIVES_MAP.keys():
    task = TARGET_TASK_NAME_COLLECTIVES_MAP[task]
    return get_metrics_from_trace_tpu(trace, task)
I think you don't need to call the function again?
def get_metrics_from_trace(trace: dict[str, Any], task: str) -> list[float]:

    # Check if the given task name is a collective with a corresponding TPU operation.
    # This is a workaround and should be reverted or refactored in future.
add comment: if `task` is not present in the map, fall back to the default behavior and measure the timing from the CPU end.
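Folding the review suggestions together (plain `in` test, no duplicate call, and the fallback comment), the dispatch could look roughly like this; `get_metrics_from_trace_cpu` is a hypothetical name for the existing default path, not necessarily what the PR uses:

```python
from typing import Any

def get_metrics_from_trace(trace: dict[str, Any], task: str) -> list[float]:
    # Workaround: collectives with a known TPU HLO op are measured from the
    # TPU trace; should be reverted or refactored in future.
    if task in TARGET_TASK_NAME_COLLECTIVES_MAP:
        tpu_op = TARGET_TASK_NAME_COLLECTIVES_MAP[task]
        return get_metrics_from_trace_tpu(trace, tpu_op)
    # If the task is not present in the map, fall back to the default
    # behavior and measure the timing from the CPU end.
    return get_metrics_from_trace_cpu(trace, task)  # hypothetical default path
```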
Commit: Change the write-to-CSV filename extension from csv to tsv.
This PR is a workaround to extract the latency of collectives from the TPU trace, and should be reverted or refactored in future.
The detailed changes include:
- Add `TARGET_TASK_NAME_COLLECTIVES_MAP` in `src/benchmark_utils.py` to map a collective to its corresponding operation on TPU devices.
- Add `get_metrics_from_trace_tpu` to extract the execution time of collective operations on TPU (see the sketch below).
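As an illustration of what `get_metrics_from_trace_tpu` might do, here is a sketch under the assumption that the profile is a Chrome-trace-style dict with a "traceEvents" list whose entries carry "name" and "dur" (microseconds) fields; the actual trace schema in this repo may differ:

```python
import re
from typing import Any

def get_metrics_from_trace_tpu(trace: dict[str, Any], tpu_op: str) -> list[float]:
    """Return the durations (in ms) of TPU trace events matching tpu_op.

    Sketch only: assumes Chrome-trace-format events with "name" and
    "dur" (microseconds) fields.
    """
    pattern = re.compile(tpu_op)
    durations = []
    for event in trace.get("traceEvents", []):
        if "dur" in event and pattern.match(event.get("name", "")):
            durations.append(event["dur"] / 1e3)  # microseconds -> ms
    return durations
```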