diff --git a/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py b/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py index 37c975e2..f9f95976 100644 --- a/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py +++ b/dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py @@ -248,25 +248,31 @@ def capture(self, **kwargs): self.model.update_context_cudagraph(self.meta, context) current_stream = torch.cuda.current_stream() + # warmup + warmup_output = self.model(**padded_kwargs) + warmup_buffers = self.model.make_output_buffers(warmup_output) + aclgraph = torch.npu.NPUGraph() with ExitStack() as stack: + AscendGraphRunner.capturing = True with torch.npu.graph( aclgraph, auto_dispatch_capture=True, pool=self.pool, stream=current_stream, ): - output = self.model(**padded_kwargs) + graph_output = self.model(**padded_kwargs) + AscendGraphRunner.capturing = False - output_buffers = dict(logits=output) + output_buffers = self.model.make_output_buffers(graph_output) self.meta.output_buffers = output_buffers self._graph = aclgraph - return output + final_output = self.model.get_outputs_cudagraph(warmup_buffers, **kwargs) + return final_output @record_function("forward_cudagraph") def forward(self, **kwargs): """forward.""" - num_tokens = kwargs["input_ids"].size(-1) assert self._graph is not None self.model.fill_buffers_cudagraph(self.meta, **kwargs) context = self.ctx_mgr.current_context() @@ -281,7 +287,8 @@ def forward(self, **kwargs): else: update_attn_params(self.update_stream, self.meta, self.max_tokens) self._graph.replay() - output = self.meta.output_buffers["logits"][:, :num_tokens] + output_buffers = self.meta.output_buffers + output = self.model.get_outputs_cudagraph(output_buffers, **kwargs) return output def reset(self): @@ -368,7 +375,7 @@ def __call__(self, **kwargs): if not enable_graph: with record_function("forward_eager"): ret = self.model(**kwargs) - return ret + return self.model.make_output_buffers(ret) graph_key = self.get_graph_key(**kwargs) max_tokens = graph_key[0] @@ -386,9 +393,7 @@ def __call__(self, **kwargs): device=self.device, update_stream=self.update_stream, ) - AscendGraphRunner.capturing = True runner.capture(**kwargs) - AscendGraphRunner.capturing = False self._runner_map[graph_key] = runner else: runner = self._runner_map[graph_key]