🐛 Bug
It appears that with nvFuser version nvfuser-0.2.34+git2a7a1f9 and Thunder commit 13f7171784d6a953fd02879f325a7facfe124d0d, the inference benchmark fails with:
AssertionError: expected size 2048==2048, stride 64==1 at dim=1; expected size 64==64, stride 1==2048 at dim=2
To Reproduce
Run the following command:
torchrun --local-ranks-filter 0 --nproc-per-node 2 thunder/benchmarks/benchmark_inference.py --input-length 2048 --output-length 512 --mode thunder --num-iterations 10
Stack trace
In the latest container, the following error is raised:
[rank0]: outputs = self.model(
[rank0]: ^^^^^^^^^^^
[rank0]: File "/opt/pytorch/lightning-thunder/thunder/dynamo/compiler.py", line 263, in __call__
[rank0]: return self._func(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 418, in __call__
[rank0]: return super().__call__(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1788, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 886, in compile_wrapper
[rank0]: return fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1788, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/transformers/utils/generic.py", line 969, in wrapper
[rank0]: output = func(self, *args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py", line 936, in forward
[rank0]: outputs = self.model(
[rank0]: ^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1788, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/transformers/utils/generic.py", line 969, in wrapper
[rank0]: output = func(self, *args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py", line 578, in forward
[rank0]: causal_mask, chunk_causal_mask = self._update_causal_mask(
[rank0]: File "/usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py", line 578, in torch_dynamo_resume_in_forward_at_578
[rank0]: causal_mask, chunk_causal_mask = self._update_causal_mask(
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 418, in __call__
[rank0]: return super().__call__(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1788, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 1098, in _fn
[rank0]: return fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 838, in call_wrapped
[rank0]: return self._wrapped_call(self, *args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 413, in __call__
[rank0]: raise e
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/fx/graph_module.py", line 400, in __call__
[rank0]: return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1788, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "<eval_with_key>.135", line 8, in forward
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1777, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1788, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/pytorch/lightning-thunder/thunder/dynamo/splitter.py", line 235, in forward
[rank0]: return self.fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py", line 2782, in wrapper
[rank0]: return pytree.tree_unflatten(compiled_fn(*args, **kwargs), spec)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 1098, in _fn
[rank0]: return fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/aot_autograd.py", line 1135, in forward
[rank0]: return compiled_fn(full_args)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1962, in __call__
[rank0]: return self.compiled_fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 357, in runtime_wrapper
[rank0]: all_outs = call_func_at_runtime_with_args(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 131, in call_func_at_runtime_with_args
[rank0]: out = normalize_as_list(f(args))
[rank0]: ^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 531, in wrapper
[rank0]: return compiled_fn(runtime_args)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/output_code.py", line 617, in __call__
[rank0]: return self.current_callable(inputs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/utils.py", line 3018, in run
[rank0]: out = model(new_inputs)
[rank0]: ^^^^^^^^^^^^^^^^^
[rank0]: File "/tmp/torchinductor_root/62/c627kc7k3bmo72ywcp6dkgglykj4huqrud6yei7k7dweawq3vsgg.py", line 49, in call
[rank0]: assert_size_stride(arg0_1, (1, 2048, 64), (131072, 1, 2048))
[rank0]: AssertionError: expected size 2048==2048, stride 64==1 at dim=1; expected size 64==64, stride 1==2048 at dim=2
[rank0]: This error most often comes from a incorrect fake (aka meta) kernel for a custom op.
[rank0]: Use torch.library.opcheck to test your custom op.
[rank0]: See https://pytorch.org/docs/stable/library.html#torch.library.opcheck
cc @shino16 @crcrpar
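For context, assert_size_stride compares the sizes and strides the Inductor-compiled kernel was specialized for against those of the tensor it actually receives at runtime. Reading the message as actual==expected, the kernel expects the transposed layout with strides (131072, 1, 2048) from the assert above, but seems to receive a contiguous tensor with strides (131072, 64, 1). A minimal sketch of the two layouts (plain PyTorch, illustrative tensors only, not the benchmark's actual code):

import torch

# Layout the compiled kernel asserts on, taken from the failing call:
# assert_size_stride(arg0_1, (1, 2048, 64), (131072, 1, 2048))
expected_shape = (1, 2048, 64)
expected_stride = (131072, 1, 2048)

# A transposed view of a contiguous (1, 64, 2048) tensor has exactly that layout.
transposed = torch.empty(1, 64, 2048).transpose(1, 2)
assert tuple(transposed.shape) == expected_shape
assert transposed.stride() == expected_stride

# A contiguous tensor of the same shape has strides (131072, 64, 1) instead,
# which matches the mismatch the AssertionError reports at dims 1 and 2.
contiguous = torch.empty(expected_shape)
print(contiguous.stride())  # (131072, 64, 1)

So presumably, somewhere between Thunder's splitter and the Inductor-compiled subgraph, the tensor's layout changes between compile time and run time (or a fake kernel reports the wrong strides), which is what the hint about fake/meta kernels points at.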
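The hint at the end of the trace suggests torch.library.opcheck for custom ops whose fake (meta) kernel reports wrong metadata. A hypothetical example of that workflow (mylib::scale is made up for illustration and is not an op from this benchmark):

import torch

# Hypothetical custom op, only to illustrate the opcheck workflow suggested above.
@torch.library.custom_op("mylib::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

# The fake (meta) kernel must report the same shapes/strides as the real kernel;
# a disagreement here is the kind of bug that surfaces as assert_size_stride failures.
@scale.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    return torch.empty_like(x)

# opcheck runs the op through several tests, including fake-kernel metadata checks.
torch.library.opcheck(scale, (torch.randn(2, 3), 2.0))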