
ltorch.cumsum should return an int64 TensorProxy when input tensor is int or bool and dtype=None #1952

Description

@crcrpar

Note: If you have a model or program that is not supported yet but should be, please use the program coverage template.

🐛 Bug

As the title says, when torch.cumsum is given an int or bool tensor and dtype=None, thunder should return a torch.int64 tensor, but the current implementation keeps the input dtype instead:

def cumsum(a: TensorLike, dim: int, *, dtype: None | dtypeLike = None) -> TensorLike:
    # check the input dimension
    utils.canonicalize_dim(a.ndim, dim)
    if dtype is None:
        return TensorProxy(like=a)
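A minimal sketch of one possible fix, assuming TensorProxy(like=...) accepts a dtype override and that thunder's dtypes module exposes an int64 dtype and a bool/integer predicate (the predicate name below is hypothetical):

def cumsum(a: TensorLike, dim: int, *, dtype: None | dtypeLike = None) -> TensorLike:
    # check the input dimension
    utils.canonicalize_dim(a.ndim, dim)
    if dtype is None:
        # Match torch.cumsum: bool and integer inputs accumulate in int64
        # when no dtype is given; floating-point inputs keep their dtype.
        if dtypes.is_exact_dtype(a.dtype):  # hypothetical predicate for bool/int dtypes
            return TensorProxy(like=a, dtype=dtypes.int64)  # dtype override kwarg is an assumption
        return TensorProxy(like=a)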

PyTorch seems to use torch.int64 for integer and bool inputs, per https://github.com/pytorch/pytorch/blob/78fe079c97fd35a2072717cbdcf98bf3a88e61be/torch/_refs/__init__.py#L2301-L2315; that helper is used by cumsum at https://github.com/pytorch/pytorch/blob/78fe079c97fd35a2072717cbdcf98bf3a88e61be/torch/_refs/__init__.py#L4655.
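For reference, a quick eager-mode check of the promotion this issue expects (output dtypes reflect PyTorch's documented behavior):

import torch

x = torch.randint(0, 128, (1, 8), dtype=torch.int32)
print(torch.cumsum(x, dim=1).dtype)          # torch.int64: int inputs are promoted
print(torch.cumsum(x.bool(), dim=1).dtype)   # torch.int64: bool inputs are promoted
print(torch.cumsum(x.float(), dim=1).dtype)  # torch.float32: float inputs keep their dtype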

To Reproduce

Code sample

import torch

import thunder


def f(mask):
    t103 = torch.cumsum(mask, dim=1)
    add = t103 + 0.0
    incremental_indices = add * mask
    long = incremental_indices.to(torch.int64)
    position_ids = long + 1
    return t103, position_ids


def main():
    dev = torch.device("cuda")
    with dev:
        mask = torch.randint(0, 128, (1, 128), dtype=torch.int32)

    r_t103, r_position_ids = f(mask)
    print(f"### {r_t103.dtype=}, {r_t103.shape=}, {r_position_ids.dtype=}, {r_position_ids.shape=}")

    jitted = thunder.jit(f)
    t103, position_ids = jitted(mask)

    torch.testing.assert_close(t103, r_t103)
    torch.testing.assert_close(position_ids, r_position_ids)


if __name__ == "__main__":
    main()

This program fails with the following message:

[ERROR    | nvfuser            ]: An error occurred while executing nvFuser FusionDefinition 0.
If you believe this is a bug or need assistance, please file an issue at https://github.com/NVIDIA/Fuser/issues/new
...
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/__init__.py", line 317, in execute
    out_tensors: list[DistributedTensor] = self._execute(
                                           ^^^^^^^^^^^^^^
RuntimeError:  INTERNAL ASSERT FAILED at "/opt/pytorch/fuser/csrc/runtime/executor_utils.cpp":595, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. When trying to run the provided host program, there was an error with the provided input 0. Provided input was:
  Tensor: shape: [1, 128], dtype: long int, device: cuda:0, pointer: 140382379705856
Fusion input was:
  T0_g_int[bS0{1}, iS1{128}]
Expr eval provided the error:
"""Expected input 0, T0_g_int[bS0{1}, iS1{128}], to be bound to a tensor of dtype int, but got a tensor of dtype int64_t"""

Exception raised from bindInputs at /opt/pytorch/fuser/csrc/runtime/executor_utils.cpp:595 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x103 (0x7fae6d3cc771 in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #1: nvfuser::nvfErrorFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x62 (0x7fae6d81b582 in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #2: <unknown function> + 0x1c8277 (0x7fae6d416277 in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #3: <unknown function> + 0xa2d627 (0x7fae6dc7b627 in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #4: <unknown function> + 0xa2e4ee (0x7fae6dc7c4ee in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #5: <unknown function> + 0x9791eb (0x7fae6dbc71eb in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #6: <unknown function> + 0x96fcf8 (0x7fae6dbbdcf8 in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #7: nvfuser::FusionExecutorCache::runFusionWithInputs(nvfuser::KernelArgumentHolder, std::optional<nvfuser::PrimDataType>, std::optional<signed char>) + 0xbd (0x7fae6dbbe74d in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #8: nvfuser::python_frontend::FusionDefinition::execute(nvfuser::KernelArgumentHolder, std::optional<signed char>, bool, bool, bool, std::vector<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::allocator<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >, std::vector<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::allocator<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >) const + 0xe4f (0x7fae6dd9adcf in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #9: <unknown function> + 0x21fd37 (0x7fae6d46dd37 in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #10: <unknown function> + 0x2ee0a3 (0x7fae6d53c0a3 in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #11: <unknown function> + 0x20c223 (0x7fae6d45a223 in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #12: python() [0x58208f]
<omitting python frames>
frame #15: python() [0x54cd94]
frame #19: python() [0x5a3628]
frame #23: python() [0x608b42]
frame #24: python() [0x6b4e93]
frame #29: <unknown function> + 0x2a1ca (0x7fc27852b1ca in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #30: __libc_start_main + 0x8b (0x7fc27852b28b in /usr/lib/x86_64-linux-gnu/libc.so.6)

### r_t103.dtype=torch.int64, r_t103.shape=torch.Size([1, 128]), r_position_ids.dtype=torch.int64, r_position_ids.shape=torch.Size([1, 128])
# Constructed by Unwrap the actual return value
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(mask):
  # mask: "cuda:0 i32[1, 128]"

  # /workspace/nemo-bench/a.py:7: 	    t103 = torch.cumsum(mask, dim=1)
  t8 = torch.cumsum(mask, 1)  # t8: "cuda:0 i32[1, 128]"
    # t8 = ltorch.cumsum(mask, 1, dtype=None)  # t8: "cuda:0 i32[1, 128]"
  [position_ids] = nvFusion0(t8, mask)
    # t1 = prims.convert_element_type(t8, dtypes.float32_)  # t1: "cuda:0 f32[1, 128]"
    # add = prims.add(t1, 0.0)  # add: "cuda:0 f32[1, 128]"
    # t3 = prims.convert_element_type(mask, dtypes.float32_)  # t3: "cuda:0 f32[1, 128]"
    # incremental_indices = prims.mul(add, t3)  # incremental_indices: "cuda:0 f32[1, 128]"
    # long = prims.convert_element_type(incremental_indices, dtypes.int64)  # long: "cuda:0 i64[1, 128]"
    # position_ids = prims.add(long, 1)  # position_ids: "cuda:0 i64[1, 128]"
  return (t8, position_ids)
Traceback (most recent call last):
  File "/workspace/nemo-bench/a.py", line 31, in <module>
    main()
  File "/workspace/nemo-bench/a.py", line 24, in main
    t103, position_ids = jitted(mask)
                         ^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 745, in wrapped
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 787, in fn_
    result = cache_entry.computation_fn(*inps)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 709, in wrapped
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/executors/torchex.py", line 170, in no_autocast_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "thunder.computation_2", line 12, in computation
  File "/opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 472, in __call__
    return fd.execute(args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/__init__.py", line 317, in execute
    out_tensors: list[DistributedTensor] = self._execute(
                                           ^^^^^^^^^^^^^^
RuntimeError:  INTERNAL ASSERT FAILED at "/opt/pytorch/fuser/csrc/runtime/executor_utils.cpp":595, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. When trying to run the provided host program, there was an error with the provided input 0. Provided input was:
  Tensor: shape: [1, 128], dtype: long int, device: cuda:0, pointer: 140382379705856
Fusion input was:
  T0_g_int[bS0{1}, iS1{128}]
Expr eval provided the error:
"""Expected input 0, T0_g_int[bS0{1}, iS1{128}], to be bound to a tensor of dtype int, but got a tensor of dtype int64_t"""

Exception raised from bindInputs at /opt/pytorch/fuser/csrc/runtime/executor_utils.cpp:595 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x103 (0x7fae6d3cc771 in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #1: nvfuser::nvfErrorFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x62 (0x7fae6d81b582 in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #2: <unknown function> + 0x1c8277 (0x7fae6d416277 in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #3: <unknown function> + 0xa2d627 (0x7fae6dc7b627 in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #4: <unknown function> + 0xa2e4ee (0x7fae6dc7c4ee in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #5: <unknown function> + 0x9791eb (0x7fae6dbc71eb in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #6: <unknown function> + 0x96fcf8 (0x7fae6dbbdcf8 in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #7: nvfuser::FusionExecutorCache::runFusionWithInputs(nvfuser::KernelArgumentHolder, std::optional<nvfuser::PrimDataType>, std::optional<signed char>) + 0xbd (0x7fae6dbbe74d in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #8: nvfuser::python_frontend::FusionDefinition::execute(nvfuser::KernelArgumentHolder, std::optional<signed char>, bool, bool, bool, std::vector<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::allocator<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >, std::vector<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::allocator<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >) const + 0xe4f (0x7fae6dd9adcf in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #9: <unknown function> + 0x21fd37 (0x7fae6d46dd37 in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #10: <unknown function> + 0x2ee0a3 (0x7fae6d53c0a3 in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #11: <unknown function> + 0x20c223 (0x7fae6d45a223 in /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.27a0+5111d3b-py3.12-linux-x86_64.egg/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #12: python() [0x58208f]
<omitting python frames>
frame #15: python() [0x54cd94]
frame #19: python() [0x5a3628]
frame #23: python() [0x608b42]
frame #24: python() [0x6b4e93]
frame #29: <unknown function> + 0x2a1ca (0x7fc27852b1ca in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #30: __libc_start_main + 0x8b (0x7fc27852b28b in /usr/lib/x86_64-linux-gnu/libc.so.6)

Expected behavior

The script should run without errors.
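Until the dtype inference is fixed, one possible workaround (untested here; it assumes thunder's cumsum honors an explicit dtype argument) is to request the accumulation dtype explicitly so the trace carries int64 from the start:

def f(mask):
    # Passing dtype explicitly mirrors eager PyTorch's default promotion for int/bool inputs.
    t103 = torch.cumsum(mask, dim=1, dtype=torch.int64)
    add = t103 + 0.0
    incremental_indices = add * mask
    long = incremental_indices.to(torch.int64)
    position_ids = long + 1
    return t103, position_ids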

Environment

  • PyTorch Version (e.g., 1.0):
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, source):
  • Build command you used (if compiling from source):
  • Python version:
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

NeMo PEFT with sangjeedondrub/tibetan-roberta-causal-base

Labels

hf-transformers, operators, thunderfx (for things that could be applicable to the dynamo+thunder frontend)
