Description
Note: If you have a model or program that is not supported yet but should be, please use the program coverage template.
🐛 Bug
As per the title: when the input to torch.cumsum is an integer or boolean tensor and dtype=None, thunder should return a torch.int64 tensor. Instead, the current implementation returns a proxy that keeps the input's dtype:
lightning-thunder/thunder/torch/__init__.py
Lines 3015 to 3019 in 7a8d7e6
def cumsum(a: TensorLike, dim: int, *, dtype: None | dtypeLike = None) -> TensorLike:
    # check the input dimension
    utils.canonicalize_dim(a.ndim, dim)
    if dtype is None:
        return TensorProxy(like=a)
PyTorch seems to use torch.int64 according to https://github.com/pytorch/pytorch/blob/78fe079c97fd35a2072717cbdcf98bf3a88e61be/torch/_refs/__init__.py#L2301-L2315. This method is used by cumsum there: https://github.com/pytorch/pytorch/blob/78fe079c97fd35a2072717cbdcf98bf3a88e61be/torch/_refs/__init__.py#L4655.
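The eager behavior described above can be checked on CPU without thunder. The snippet below is a minimal sketch: `expected_cumsum_dtype` is a hypothetical helper written for illustration (it is not part of PyTorch or thunder), and it only encodes the rule as observed for the dtypes tested here.

```python
import torch


def expected_cumsum_dtype(input_dtype, dtype=None):
    """Hypothetical helper: result dtype of torch.cumsum as observed in eager mode."""
    if dtype is not None:
        # An explicit dtype always wins.
        return dtype
    # Boolean and integer inputs are accumulated in int64.
    if not (input_dtype.is_floating_point or input_dtype.is_complex):
        return torch.int64
    # Floating-point inputs keep their dtype.
    return input_dtype


observed = {}
for in_dtype in (torch.bool, torch.int32, torch.int64, torch.float32):
    x = torch.ones(2, 3).to(in_dtype)
    observed[in_dtype] = torch.cumsum(x, dim=1).dtype
    # Eager PyTorch agrees with the rule sketched above.
    assert observed[in_dtype] == expected_cumsum_dtype(in_dtype)
```

Under this rule, returning `TensorProxy(like=a)` for an int32 input reports i32 where eager PyTorch produces i64, which is the mismatch visible in the trace further down.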
To Reproduce
Code sample
import torch
import thunder

def f(mask):
    t103 = torch.cumsum(mask, dim=1)
    add = t103 + 0.0
    incremental_indices = add * mask
    long = incremental_indices.to(torch.int64)
    position_ids = long + 1
    return t103, position_ids

def main():
    dev = torch.device("cuda")
    with dev:
        mask = torch.randint(0, 128, (1, 128), dtype=torch.int32)
    r_t103, r_position_ids = f(mask)
    print(f"### {r_t103.dtype=}, {r_t103.shape=}, {r_position_ids.dtype=}, {r_position_ids.shape=}")
    jitted = thunder.jit(f)
    t103, position_ids = jitted(mask)
    torch.testing.assert_close(t103, r_t103)
    torch.testing.assert_close(position_ids, r_position_ids)

if __name__ == "__main__":
    main()

This program fails with the following message:
[ERROR | nvfuser ]: An error occurred while executing nvFuser FusionDefinition 0.
If you believe this is a bug or need assistance, please file an issue at https://github.com/NVIDIA/Fuser/issues/new
...
Traceback (most recent call last):
File "/usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/__init__.py", line 317, in execute
out_tensors: list[DistributedTensor] = self._execute(
^^^^^^^^^^^^^^
RuntimeError: INTERNAL ASSERT FAILED at "/opt/pytorch/fuser/csrc/runtime/executor_utils.cpp":595, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. When trying to run the provided host program, there was an error with the provided input 0. Provided input was:
Tensor: shape: [1, 128], dtype: long int, device: cuda:0, pointer: 140382379705856
Fusion input was:
T0_g_int[bS0{1}, iS1{128}]
Expr eval provided the error:
"""Expected input 0, T0_g_int[bS0{1}, iS1{128}], to be bound to a tensor of dtype int, but got a tensor of dtype int64_t"""
Exception raised from bindInputs at /opt/pytorch/fuser/csrc/runtime/executor_utils.cpp:595 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x103 (0x7fae6d3cc771 in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #1: nvfuser::nvfErrorFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x62 (0x7fae6d81b582 in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #2: <unknown function> + 0x1c8277 (0x7fae6d416277 in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #3: <unknown function> + 0xa2d627 (0x7fae6dc7b627 in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #4: <unknown function> + 0xa2e4ee (0x7fae6dc7c4ee in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #5: <unknown function> + 0x9791eb (0x7fae6dbc71eb in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #6: <unknown function> + 0x96fcf8 (0x7fae6dbbdcf8 in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #7: nvfuser::FusionExecutorCache::runFusionWithInputs(nvfuser::KernelArgumentHolder, std::optional<nvfuser::PrimDataType>, std::optional<signed char>) + 0xbd (0x7fae6dbbe74d in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #8: nvfuser::python_frontend::FusionDefinition::execute(nvfuser::KernelArgumentHolder, std::optional<signed char>, bool, bool, bool, std::vector<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::allocator<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >, std::vector<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::allocator<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >) const + 0xe4f (0x7fae6dd9adcf in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #9: <unknown function> + 0x21fd37 (0x7fae6d46dd37 in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #10: <unknown function> + 0x2ee0a3 (0x7fae6d53c0a3 in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #11: <unknown function> + 0x20c223 (0x7fae6d45a223 in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #12: python() [0x58208f]
<omitting python frames>
frame #15: python() [0x54cd94]
frame #19: python() [0x5a3628]
frame #23: python() [0x608b42]
frame #24: python() [0x6b4e93]
frame #29: <unknown function> + 0x2a1ca (0x7fc27852b1ca in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #30: __libc_start_main + 0x8b (0x7fc27852b28b in /usr/lib/x86_64-linux-gnu/libc.so.6)
### r_t103.dtype=torch.int64, r_t103.shape=torch.Size([1, 128]), r_position_ids.dtype=torch.int64, r_position_ids.shape=torch.Size([1, 128])
# Constructed by Unwrap the actual return value
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(mask):
  # mask: "cuda:0 i32[1, 128]"
  # /workspace/nemo-bench/a.py:7: t103 = torch.cumsum(mask, dim=1)
  t8 = torch.cumsum(mask, 1)  # t8: "cuda:0 i32[1, 128]"
    # t8 = ltorch.cumsum(mask, 1, dtype=None)  # t8: "cuda:0 i32[1, 128]"
  [position_ids] = nvFusion0(t8, mask)
    # t1 = prims.convert_element_type(t8, dtypes.float32_)  # t1: "cuda:0 f32[1, 128]"
    # add = prims.add(t1, 0.0)  # add: "cuda:0 f32[1, 128]"
    # t3 = prims.convert_element_type(mask, dtypes.float32_)  # t3: "cuda:0 f32[1, 128]"
    # incremental_indices = prims.mul(add, t3)  # incremental_indices: "cuda:0 f32[1, 128]"
    # long = prims.convert_element_type(incremental_indices, dtypes.int64)  # long: "cuda:0 i64[1, 128]"
    # position_ids = prims.add(long, 1)  # position_ids: "cuda:0 i64[1, 128]"
  return (t8, position_ids)
Traceback (most recent call last):
File "/workspace/nemo-bench/a.py", line 31, in <module>
main()
File "/workspace/nemo-bench/a.py", line 24, in main
t103, position_ids = jitted(mask)
^^^^^^^^^^^^
File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 745, in wrapped
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 787, in fn_
result = cache_entry.computation_fn(*inps)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 709, in wrapped
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/pytorch/lightning-thunder/thunder/executors/torchex.py", line 170, in no_autocast_fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "thunder.computation_2", line 12, in computation
File "/opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 472, in __call__
return fd.execute(args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/__init__.py", line 317, in execute
out_tensors: list[DistributedTensor] = self._execute(
^^^^^^^^^^^^^^
RuntimeError: INTERNAL ASSERT FAILED at "/opt/pytorch/fuser/csrc/runtime/executor_utils.cpp":595, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. When trying to run the provided host program, there was an error with the provided input 0. Provided input was:
Tensor: shape: [1, 128], dtype: long int, device: cuda:0, pointer: 140382379705856
Fusion input was:
T0_g_int[bS0{1}, iS1{128}]
Expr eval provided the error:
"""Expected input 0, T0_g_int[bS0{1}, iS1{128}], to be bound to a tensor of dtype int, but got a tensor of dtype int64_t"""
Expected behavior
The script should run without errors: the jitted function should return t103 as a torch.int64 tensor, matching eager PyTorch.
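As a sketch of what "matching eager PyTorch" means here, the reference function from the repro can be run on CPU and its output dtypes checked directly. The `wrong_t103` tensor is a stand-in introduced only for illustration: it simulates a result that carries the input's int32 dtype and does not come from thunder.

```python
import torch


def f(mask):
    # Same computation as in the repro above.
    t103 = torch.cumsum(mask, dim=1)
    add = t103 + 0.0
    incremental_indices = add * mask
    long = incremental_indices.to(torch.int64)
    position_ids = long + 1
    return t103, position_ids


# CPU tensors are enough to inspect dtypes; no GPU or thunder is required.
mask = torch.randint(0, 128, (1, 128), dtype=torch.int32)
ref_t103, ref_position_ids = f(mask)

# Illustrative stand-in for an output with the wrong (int32) dtype.
wrong_t103 = ref_t103.to(torch.int32)
try:
    torch.testing.assert_close(wrong_t103, ref_t103)
    dtype_mismatch_detected = False
except AssertionError:
    # assert_close compares dtypes by default (check_dtype=True).
    dtype_mismatch_detected = True
```

Because `torch.testing.assert_close` checks dtypes by default, the comparison in the repro would be expected to flag an int32 `t103` even in a configuration where the nvFuser input binding did not fail first.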
Environment
- PyTorch Version (e.g., 1.0):
- OS (e.g., Linux):
- How you installed PyTorch (conda, pip, source):
- Build command you used (if compiling from source):
- Python version:
- CUDA/cuDNN version:
- GPU models and configuration:
- Any other relevant information:
Additional context
nemo peft - sangjeedondrub/tibetan-roberta-causal-base