Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions benchmarks/communication/all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,21 @@ def timed_all_gather(input, output, start_event, end_event, args):
sync_all()

# time the actual comm op trials times and average it
start_event.record()
for i in range(args.trials):
if args.dist == 'torch':
if hasattr(torch.distributed, "_all_gather_base"):
dist._all_gather_base(output, input, group=None, async_op=args.async_op)
else:
output_tensors = list(
torch.chunk(output,
dist.get_world_size()))
dist.all_gather(output_tensors, input, group=None, async_op=True)
elif args.dist == 'deepspeed':
dist.allgather_fn(output, input, group=None, async_op=args.async_op)
end_event.record()
with prof(args) as profiler:
start_event.record()
for i in range(args.trials):
if args.dist == 'torch':
if hasattr(torch.distributed, "_all_gather_base"):
dist._all_gather_base(output, input, group=None, async_op=args.async_op)
else:
output_tensors = list(
torch.chunk(output,
dist.get_world_size()))
dist.all_gather(output_tensors, input, group=None, async_op=True)
elif args.dist == 'deepspeed':
dist.allgather_fn(output, input, group=None, async_op=args.async_op)
profiler.step()
end_event.record()
sync_all()
duration = start_event.elapsed_time(end_event) / 1000

Expand Down
10 changes: 6 additions & 4 deletions benchmarks/communication/all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@ def timed_all_reduce(input, start_event, end_event, args):
sync_all()

# time the actual comm op trials times and average it
start_event.record()
for i in range(args.trials):
dist.all_reduce(input, async_op=args.async_op)
end_event.record()
with prof(args) as profiler:
start_event.record()
for i in range(args.trials):
dist.all_reduce(input, async_op=args.async_op)
profiler.step()
end_event.record()
sync_all()
duration = start_event.elapsed_time(end_event) / 1000

Expand Down
16 changes: 9 additions & 7 deletions benchmarks/communication/all_to_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,15 @@ def timed_all_to_all(input, output, start_event, end_event, args):
sync_all()

# time the actual comm op trials times and average it
start_event.record()
for i in range(args.trials):
if args.all_to_all_v:
dist.all_to_all(output_list, input_list, async_op=args.async_op)
else:
dist.all_to_all_single(output, input, async_op=args.async_op)
end_event.record()
with prof(args) as profiler:
start_event.record()
for i in range(args.trials):
if args.all_to_all_v:
dist.all_to_all(output_list, input_list, async_op=args.async_op)
else:
dist.all_to_all_single(output, input, async_op=args.async_op)
profiler.step()
end_event.record()
sync_all()
duration = start_event.elapsed_time(end_event) / 1000

Expand Down
10 changes: 6 additions & 4 deletions benchmarks/communication/broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@ def timed_broadcast(input, start_event, end_event, args):
sync_all()

# time the actual comm op trials times and average it
start_event.record()
for i in range(args.trials):
dist.broadcast(input, 0, async_op=args.async_op)
end_event.record()
with prof(args) as profiler:
start_event.record()
for i in range(args.trials):
dist.broadcast(input, 0, async_op=args.async_op)
profiler.step()
end_event.record()
sync_all()
duration = start_event.elapsed_time(end_event) / 1000

Expand Down
29 changes: 15 additions & 14 deletions benchmarks/communication/pt2pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,21 @@ def timed_pt2pt(input, start_event, end_event, args):
sync_all()

# time the actual comm op trials times and average it
start_event.record()
for i in range(args.trials):
if dist.get_rank() == 0:
if args.async_op:
dist.isend(input, 1)
else:
dist.send(input, 1)
if dist.get_rank() == 1:
if args.async_op:
dist.irecv(input, src=0)
else:
dist.recv(input, src=0)

end_event.record()
with prof(args) as profiler:
start_event.record()
for i in range(args.trials):
if dist.get_rank() == 0:
if args.async_op:
dist.isend(input, 1)
else:
dist.send(input, 1)
if dist.get_rank() == 1:
if args.async_op:
dist.irecv(input, src=0)
else:
dist.recv(input, src=0)
profiler.step()
end_event.record()
sync_all()
duration = start_event.elapsed_time(end_event) / 1000

Expand Down
26 changes: 14 additions & 12 deletions benchmarks/communication/reduce_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,20 @@ def timed_reduce_scatter(input, start_event, end_event, args):
sync_all()

# time the actual comm op trials times and average it
start_event.record()
for i in range(args.trials):
if hasattr(torch.distributed, "reduce_scatter_tensor"):
dist.reduce_scatter_tensor(output, input, async_op=args.async_op)
elif hasattr(torch.distributed, "_reduce_scatter_base"):
dist._reduce_scatter_base(output, input, async_op=args.async_op)
else:
input_tensors = list(
torch.chunk(input,
dist.get_world_size()))
dist.reduce_scatter(output, input_tensors, async_op=args.async_op)
end_event.record()
with prof(args) as profiler:
start_event.record()
for i in range(args.trials):
if hasattr(torch.distributed, "reduce_scatter_tensor"):
dist.reduce_scatter_tensor(output, input, async_op=args.async_op)
elif hasattr(torch.distributed, "_reduce_scatter_base"):
dist._reduce_scatter_base(output, input, async_op=args.async_op)
else:
input_tensors = list(
torch.chunk(input,
dist.get_world_size()))
dist.reduce_scatter(output, input_tensors, async_op=args.async_op)
profiler.step()
end_event.record()
sync_all()
duration = start_event.elapsed_time(end_event) / 1000

Expand Down
45 changes: 45 additions & 0 deletions benchmarks/communication/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os, sys
import math
import argparse
from contextlib import nullcontext

COMMS_BENCH_DIR = os.path.join(os.path.dirname(__file__), "../")
sys.path.append(COMMS_BENCH_DIR)
Expand Down Expand Up @@ -235,4 +236,48 @@ def benchmark_parser():
parser.add_argument("--debug", action="store_true", help='Enables all_to_all debug prints')
parser.add_argument('--all-to-all-v', action='store_true',
help='Use alltoallv instead of alltoall. This will run the all_to_all benchmark with vector variant. Use with --all-to-all or alone to run just this benchmark.')
parser.add_argument("--profile", action="store_true", help='Enable PyTorch profiler during timed iterations')
return parser

class PassProfile:
    """No-op stand-in for a torch profiler object.

    Handed to callers (inside a ``nullcontext``) when profiling is disabled,
    so the benchmark loop can call ``profiler.step()`` unconditionally.
    """

    def step(self):
        """Do nothing; mirrors ``torch.profiler.profile.step``."""
        return None

def prof(args):
    """Return a context manager for the timed benchmark loop.

    When ``args.profile`` is falsy (or absent), or the PyTorch profiler
    cannot be imported, a ``nullcontext`` yielding a ``PassProfile`` is
    returned so callers can still invoke ``profiler.step()`` every trial.

    Otherwise returns a ``torch.profiler.profile`` tracing CPU (and CUDA,
    when available) activity and writing TensorBoard traces under
    ``<this dir>/profiles/rank_<rank>``.
    """
    if not getattr(args, 'profile', False):
        return nullcontext(PassProfile())

    try:
        from torch.profiler import profile, ProfilerActivity, schedule, tensorboard_trace_handler
    except Exception:
        # Profiler not available in this torch build; fall back to the no-op.
        return nullcontext(PassProfile())

    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)

    # NOTE(review): with wait=1/warmup=1/active=5, nothing is recorded unless
    # the loop calls step() at least 3 times (7 for a full active window) —
    # keep --trials above that when profiling.
    prof_schedule = schedule(wait=1, warmup=1, active=5, repeat=1)

    # Save traces under the communication benchmark folder.
    comm_dir = os.path.abspath(os.path.dirname(__file__))
    log_dir = os.path.join(comm_dir, 'profiles')
    os.makedirs(log_dir, exist_ok=True)

    # Per-rank trace files. dist.get_rank() raises when the process group is
    # not (yet) initialized, so fall back to the RANK environment variable
    # (set by torchrun/launchers) and finally to 0.
    rank = 0
    if 'dist' in globals():
        try:
            rank = dist.get_rank()
        except Exception:
            rank = int(os.environ.get('RANK', '0'))
    handler = tensorboard_trace_handler(os.path.join(log_dir, f'rank_{rank}'))

    return profile(
        activities=activities,
        schedule=prof_schedule,
        on_trace_ready=handler,
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    )