Incorrect metadata for custom_op in benchmark_inference.py #2680

@riccardofelluga

🐛 Bug

With nvfuser version nvfuser-0.2.34+git2a7a1f9 and Thunder commit 13f7171784d6a953fd02879f325a7facfe124d0d, the inference benchmark fails with:

AssertionError: expected size 2048==2048, stride 64==1 at dim=1; expected size 64==64, stride 1==2048 at dim=2

To Reproduce

To reproduce, run the following:

torchrun --local-ranks-filter 0 --nproc-per-node 2 thunder/benchmarks/benchmark_inference.py --input-length 2048 --output-length 512 --mode thunder --num-iterations 10

Stack trace

In the latest container, the run fails with the following error:

[rank0]:     outputs = self.model(
[rank0]:               ^^^^^^^^^^^
[rank0]:   File "/opt/pytorch/lightning-thunder/thunder/dynamo/compiler.py", line 263, in __call__
[rank0]:     return self._func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 418, in __call__
[rank0]:     return super().__call__(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1788, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 886, in compile_wrapper
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1788, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/transformers/utils/generic.py", line 969, in wrapper
[rank0]:     output = func(self, *args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py", line 936, in forward
[rank0]:     outputs = self.model(
[rank0]:               ^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1788, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/transformers/utils/generic.py", line 969, in wrapper
[rank0]:     output = func(self, *args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py", line 578, in forward
[rank0]:     causal_mask, chunk_causal_mask = self._update_causal_mask(
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py", line 578, in torch_dynamo_resume_in_forward_at_578
[rank0]:     causal_mask, chunk_causal_mask = self._update_causal_mask(
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 418, in __call__
[rank0]:     return super().__call__(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1788, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 1098, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 838, in call_wrapped
[rank0]:     return self._wrapped_call(self, *args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 413, in __call__
[rank0]:     raise e
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 400, in __call__
[rank0]:     return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1788, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "<eval_with_key>.135", line 8, in forward
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1788, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/opt/pytorch/lightning-thunder/thunder/dynamo/splitter.py", line 235, in forward
[rank0]:     return self.fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py", line 2782, in wrapper
[rank0]:     return pytree.tree_unflatten(compiled_fn(*args, **kwargs), spec)
[rank0]:                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 1098, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/aot_autograd.py", line 1135, in forward
[rank0]:     return compiled_fn(full_args)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1962, in __call__
[rank0]:     return self.compiled_fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 357, in runtime_wrapper
[rank0]:     all_outs = call_func_at_runtime_with_args(
[rank0]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 131, in call_func_at_runtime_with_args
[rank0]:     out = normalize_as_list(f(args))
[rank0]:                             ^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 531, in wrapper
[rank0]:     return compiled_fn(runtime_args)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/output_code.py", line 617, in __call__
[rank0]:     return self.current_callable(inputs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/utils.py", line 3018, in run
[rank0]:     out = model(new_inputs)
[rank0]:           ^^^^^^^^^^^^^^^^^
[rank0]:   File "/tmp/torchinductor_root/62/c627kc7k3bmo72ywcp6dkgglykj4huqrud6yei7k7dweawq3vsgg.py", line 49, in call
[rank0]:     assert_size_stride(arg0_1, (1, 2048, 64), (131072, 1, 2048))
[rank0]: AssertionError: expected size 2048==2048, stride 64==1 at dim=1; expected size 64==64, stride 1==2048 at dim=2
[rank0]: This error most often comes from a incorrect fake (aka meta) kernel for a custom op.
[rank0]: Use torch.library.opcheck to test your custom op.
[rank0]: See https://pytorch.org/docs/stable/library.html#torch.library.opcheck
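For anyone reading the assertion: the compiled graph expects `arg0_1` to have size `(1, 2048, 64)` with strides `(131072, 1, 2048)`, while the strides reported for dims 1 and 2 (64 and 1) are what a plain contiguous tensor of that size would have. The sketch below is pure Python with no torch dependency and reproduces that arithmetic. The helper names are made up for illustration, and the idea that the expected layout comes from a transposed view of a contiguous `(1, 64, 2048)` tensor is only an assumption that happens to fit the numbers; it is not confirmed by the trace.

```python
def contiguous_strides(shape):
    """Row-major (C-contiguous) strides, in elements, for the given shape."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def permute(seq, dims):
    """Reorder a size or stride tuple the way a dimension permutation would."""
    return tuple(seq[d] for d in dims)


# Layout asserted by the generated code (copied from the trace above).
expected_size = (1, 2048, 64)
expected_stride = (131072, 1, 2048)

# Strides of a plain contiguous tensor with the same size.
# Assumption: the tensor actually passed in is contiguous; the error message
# only reports the strides of dims 1 and 2.
actual_stride = contiguous_strides(expected_size)

# Hypothetical origin of the expected layout: a contiguous (1, 64, 2048)
# tensor viewed with its last two dimensions swapped.
base_shape = (1, 64, 2048)
view_size = permute(base_shape, (0, 2, 1))
view_stride = permute(contiguous_strides(base_shape), (0, 2, 1))

# Dimensions on which the two layouts disagree.
mismatched_dims = [
    d for d in range(len(expected_size)) if actual_stride[d] != expected_stride[d]
]

print("actual strides  :", actual_stride)
print("expected strides:", expected_stride)
print("mismatched dims :", mismatched_dims)
```

Under these assumptions the two layouts differ only in dims 1 and 2, which are the two dimensions the `AssertionError` names.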

cc. @shino16

cc @crcrpar

Labels

benchmarking, thunderfx (for things that could be applicable to the dynamo+thunder frontend)
