Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions thunder/benchmarks/benchmark_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ class InferenceBenchmarkConfig:
mode: str
disable_moe_replacement: bool
attn_implementation: str | None
profile: bool
thunder_cache: str | None
enable_thunder_cudagraph: bool

Expand Down Expand Up @@ -493,10 +492,17 @@ def run_benchmark(self) -> InferenceMetrics:
for _ in tqdm(range(self.config.num_iterations), disable=LOCAL_RANK != 0):
past_key_values.reset()

if self.config.profile:
is_under_nsys = bool(os.environ.get("NSYS_PROFILING_SESSION_ID"))
# Wrap each non-warmup iteration with cudaProfilerStart() and
# cudaProfilerStop(). This allows the user to run
# ```shell
# nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat:<N> ...
# ```
# to record only the non-warmup iterations.
if is_under_nsys:
torch.cuda.cudart().cudaProfilerStart()
iter_metrics = self.measure_inference_step(input_ids, past_key_values, self.config.output_length)
if self.config.profile:
if is_under_nsys:
torch.cuda.cudart().cudaProfilerStop()

all_metrics.append(iter_metrics)
Expand Down Expand Up @@ -680,11 +686,6 @@ def parse_args() -> argparse.Namespace:
action="store_true",
help="let nvfuser take care of linear and matmul, note that this might fail with distributed run. See: https://github.com/NVIDIA/Fuser/issues/4507",
)
parser.add_argument(
"--profile",
action="store_true",
help="Wrap each non-warmup iteration with cudaProfilerStart() and cudaProfilerStop(). This allows us to run `nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat:<N> ... --profile` to record only the non-warmup iterations.",
)

parser.add_argument("--save-results", action="store_true", help="Save results to JSON file")
parser.add_argument("--output-dir", type=str, default="./results", help="Directory to save results")
Expand Down Expand Up @@ -728,7 +729,6 @@ def main():
enable_nv_linear=args.enable_nv_linear,
disable_moe_replacement=args.disable_moe_replacement,
attn_implementation=args.attn_implementation,
profile=args.profile,
thunder_cache=args.thunder_cache,
enable_thunder_cudagraph=args.enable_thunder_cudagraph,
)
Expand Down
Loading