From ee505ae5c58500e648025999b6191dd6f08a0986 Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Wed, 22 Oct 2025 13:17:56 -0700 Subject: [PATCH 1/5] Add Tensor.stride() --- thunder/core/prims.py | 12 ++++++++++++ thunder/core/proxies.py | 28 ++++++++++++++++++++++++++++ thunder/executors/torchex.py | 8 ++++++++ 3 files changed, 48 insertions(+) diff --git a/thunder/core/prims.py b/thunder/core/prims.py index 90e1f3689b..67acf457c8 100644 --- a/thunder/core/prims.py +++ b/thunder/core/prims.py @@ -131,6 +131,7 @@ class PrimIDs(Enum): PACK_SETITEM = auto() DATACLASS_NEW = auto() SHAPE = auto() + STRIDE = auto() # TODO: UNPACK_SET # Utility prims COMMENT = auto() @@ -1393,6 +1394,17 @@ def shape_meta(t: TensorProxy) -> Sequence[int | NumberProxy]: ) +def stride_meta(t: TensorProxy) -> Sequence[int]: + return t._stride + + +stride = make_prim( + PrimIDs.STRIDE, + "stride", + meta=stride_meta, +) + + # NOTE UNPACK_GETITEM is intended only to be bound to directly, and not called def unpack_getitem_meta(o: Any, key: Any) -> Any: raise NotImplementedError diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index b19e0ea1fc..9742a41c1b 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -2,6 +2,7 @@ import copy from enum import auto, Enum +import itertools from numbers import Number from typing import Any from collections.abc import Callable @@ -1224,6 +1225,7 @@ class DistParallelType(Enum): def _infer_tensor_properties( like: TensorProxy | FutureTensorProxy | None = None, shape: ShapeLike | None = None, + stride: Sequence[int] | None = None, device: devices.Device | None = None, dtype: dtypes.dtype | None = None, requires_grad: bool | None = None, @@ -1232,6 +1234,7 @@ def _infer_tensor_properties( thunder_fsdp_padding_size: int | None = None, ): _shape = None + _stride = None _device = None _dtype = None _requires_grad: None | bool = None @@ -1242,6 +1245,7 @@ def _infer_tensor_properties( if like is not None: baseutils.check_type(like, (TensorProxy, FutureTensorProxy)) _shape = tuple(like._shape) + _stride = tuple(like._stride) _device = like.device _dtype = like.true_dtype _requires_grad = like.requires_grad @@ -1252,6 +1256,7 @@ def _infer_tensor_properties( baseutils.check_valid_shape(shape) _shape = tuple(shape) if shape is not None else _shape + _stride = tuple(stride) if stride is not None else _stride _device = device if device is not None else _device _dtype = dtype if dtype is not None else _dtype _dtype = dtypes.numbertype_to_dtype(_dtype) if dtypes.is_numbertype(_dtype) else _dtype @@ -1269,6 +1274,8 @@ def _infer_tensor_properties( _shape = tuple(pyval(x) for x in _shape) # Computes derived properties _numel = reduce(operator.mul, _shape, 1) + if _stride is None: + _stride = tuple(itertools.accumulate(reversed(_shape), operator.mul, initial=1))[-2::-1] else: # deferred computation of numel # TODO: similar to how `shape` is handled, this should be CSE or lifted for efficiency @@ -1299,6 +1306,7 @@ def _infer_tensor_properties( return ( _shape, + _stride, _device, _dtype, _true_dtype, @@ -1321,6 +1329,7 @@ def __init__( *, like: TensorProxy | FutureTensorProxy | None = None, shape: ShapeLike | None = None, + stride: Sequence[int] | None = None, device: devices.Device | None = None, dtype: dtypes.dtype | None = None, prefix: None | str = None, @@ -1332,6 +1341,7 @@ def __init__( # NOTE FutureTensorProxies never require grad ( self._shape, + self._stride, self._device, self._dtype, self._true_dtype, @@ -1344,6 +1354,7 @@ def __init__( ) = _infer_tensor_properties( like, shape, + stride, device, dtype, False, @@ -1357,6 +1368,9 @@ def __init__( def shape(self): return self._shape + def stride(self): + return self._stride + @property def numel(self): return self._numel @@ -1445,6 +1459,7 @@ def __init__( *, like: TensorProxy | FutureTensorProxy | None = None, shape: ShapeLike | None = None, + stride: Sequence[int] | None = None, device: devices.Device | None = None, dtype: dtypes.dtype | None = None, requires_grad: bool = False, @@ -1459,6 +1474,7 @@ def __init__( ( self._shape, + self._stride, self._device, self._dtype, self._true_dtype, @@ -1471,6 +1487,7 @@ def __init__( ) = _infer_tensor_properties( like, shape, + stride, device, dtype, requires_grad, @@ -1491,6 +1508,14 @@ def shape(self): return shape(self) + def stride(self): + if not using_symbolic_values() or not is_tracing(): + return self._stride + else: + from thunder.core.prims import stride + + return stride(self) + @property def ndim(self): return self._ndim @@ -1543,6 +1568,7 @@ def replace(self, **changes): like = changes.get("like") ( shape, + stride, device, dtype, true_dtype, @@ -1555,6 +1581,7 @@ def replace(self, **changes): ) = _infer_tensor_properties( like, changes.get("shape", self._shape if like is None else None), + changes.get("stride", self._stride if like is None else None), changes.get("device", self._device if like is None else None), changes.get("dtype", self._dtype if like is None else None), changes.get("requires_grad", self._requires_grad if like is None else None), @@ -2026,6 +2053,7 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple = return TensorProxy( name, shape=tuple(shape), + stride=tuple(t.stride()), device=device, dtype=dtype, requires_grad=t.requires_grad, diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py index 84a8ea9f93..3ef1ab2c34 100644 --- a/thunder/executors/torchex.py +++ b/thunder/executors/torchex.py @@ -2376,6 +2376,14 @@ def _shape_impl(t): _register_implementation(prims.shape, shape, checker=_always_executable) +def _stride_impl(t): + return t.stride() + + +stride = ex.register_operator("stride", meta=prims.stride_meta, fn=_stride_impl) +_register_implementation(prims.stride, stride, checker=_always_executable) + + def _bitcast_impl(src, dtype): return src.view(dtypes.to_torch_dtype(dtype)) From 6e0af6994146c3cf171400d94196accfad72e317 Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Wed, 22 Oct 2025 13:19:53 -0700 Subject: [PATCH 2/5] Add tests --- thunder/tests/opinfos.py | 43 ++++++++++++++++++++++++++++++++++++++ thunder/tests/test_core.py | 42 +++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+) diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index 33df903a10..d25d63ae60 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -6653,6 +6653,49 @@ def torch_empty_and_zero(*args, **kwargs): tensor_creation_ops.append(empty_opinfo) +def empty_strided_sample_generator(op, device, dtype, requires_grad, **kwargs): + # (shape, stride) pairs + cases = ( + # Contiguous strides + ((2, 3), (3, 1)), + ((4, 4), (4, 1)), + ((2, 3, 4), (12, 4, 1)), + # Non-contiguous strides (column-major) + ((3, 4), (1, 3)), + ((2, 3, 4), (1, 2, 6)), + # 1D tensors + ((10,), (1,)), + ((5,), (2,)), + # Scalar + ((), ()), + ) + + for shape, stride in cases: + yield SampleInput(shape, stride, device=device, dtype=dtype) + + +# Helper function for `empty_strided` opinfo. +# It always returns zero tensors, so that the consistency tests pass. +def torch_empty_strided_and_zero(shape, stride, **kwargs): + result = ltorch.empty_strided(shape, stride, **kwargs) + # Use full_like to fill with zeros, which preserves the stride + return ltorch.full_like(result, 0) + + +def torch_empty_strided_reference(shape, stride, **kwargs): + result = torch.empty_strided(shape, stride, **kwargs) + return result.fill_(0) + + +empty_strided_opinfo = OpInfo( + name="empty_strided", + op=torch_empty_strided_and_zero, + sample_input_generator=empty_strided_sample_generator, + torch_reference=torch_empty_strided_reference, +) +tensor_creation_ops.append(empty_strided_opinfo) + + def fixed_value_tensor_creation_op_sample_generator_with_bounds(op, device, dtype, requires_grad, **kwargs): # shape cases = ( diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index 4210fdca67..57f47ad816 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -3183,6 +3183,48 @@ def test_proxy_same_name(): TensorProxy(name="test", shape=(1,), device=cpu, dtype=float32) +def test_tensor_proxy_with_stride(): + from thunder.core.proxies import TensorProxy + from thunder.core.trace import detached_trace + from thunder.core.dtypes import float32 + from thunder.core.devices import cpu + + # Test that TensorProxy can be created with custom strides + with detached_trace(): + t1 = TensorProxy(shape=(2, 3, 4), device=cpu, dtype=float32) + assert t1.shape == (2, 3, 4) + assert t1.stride() == (12, 4, 1) + + t2 = TensorProxy(shape=(10,), device=cpu, dtype=float32) + assert t2.shape == (10,) + assert t2.stride() == (1,) + + t3 = TensorProxy(shape=(), device=cpu, dtype=float32) + assert t3.shape == () + assert t3.stride() == () + + t4 = TensorProxy(shape=(2, 3, 4), stride=(2, 8, 1), device=cpu, dtype=float32) + assert t4.shape == (2, 3, 4) + assert t4.stride() == (2, 8, 1) + + +def test_stride_tracking(): + def fn(x): + result = [x.stride()] + x = x.transpose(0, 1) + result.append(x.stride()) + x = x.permute(2, 0, 1) + result.append(x.stride()) + x = x.reshape(8, 3) + result.append(x.stride()) + return result + + jfn = thunder.jit(fn) + x = torch.randn(2, 3, 4) + + assert jfn(x) == fn(x) + + def test_save_trace(): def fn(x): return x + 1 From 71390fabc0d85fb5c52013569ac03b86e3252a96 Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Wed, 22 Oct 2025 16:26:37 -0700 Subject: [PATCH 3/5] Make full_like(a) transpose the created tensor if needed --- thunder/clang/__init__.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/thunder/clang/__init__.py b/thunder/clang/__init__.py index 333086b906..0605fa9265 100644 --- a/thunder/clang/__init__.py +++ b/thunder/clang/__init__.py @@ -265,7 +265,17 @@ def full_like( device = devices.to_device(device) if device is not None else a.device dtype = dtype if dtype is not None else a.true_dtype - return full(a.shape, fill_value, device=device, dtype=dtype) + is_stride_decreasing = all(x > y for x, y in zip(a.stride(), a.stride()[1:])) + if is_stride_decreasing: + return full(a.shape, fill_value, device=device, dtype=dtype) + + permutation = [None] * len(a.stride()) + permuted_shape = [None] * len(a.stride()) + for i, s in enumerate(sorted(a.stride(), reverse=True)): + permutation[a.stride().index(s)] = i + permuted_shape[i] = a.shape[a.stride().index(s)] + + return transpose(full(permuted_shape, fill_value, device=device, dtype=dtype), permutation) @clangop() From d73d6e9e84b2e777a73ff439500e3cb2b4be08f7 Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Wed, 22 Oct 2025 17:03:37 -0700 Subject: [PATCH 4/5] Add comment --- thunder/clang/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/thunder/clang/__init__.py b/thunder/clang/__init__.py index 0605fa9265..3222698c1d 100644 --- a/thunder/clang/__init__.py +++ b/thunder/clang/__init__.py @@ -275,6 +275,7 @@ def full_like( permutation[a.stride().index(s)] = i permuted_shape[i] = a.shape[a.stride().index(s)] + # TODO: Return a non-view tensor. ATen implements this as empty_strided(...).fill_(...). return transpose(full(permuted_shape, fill_value, device=device, dtype=dtype), permutation) From 2ab2f19f3d22af157384cebf438d1c4d81878ac4 Mon Sep 17 00:00:00 2001 From: Masato Shinokawa Date: Wed, 22 Oct 2025 20:56:01 -0700 Subject: [PATCH 5/5] Track strides over Tensor.transpose --- thunder/core/prims.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/thunder/core/prims.py b/thunder/core/prims.py index 67acf457c8..6186ef406e 100644 --- a/thunder/core/prims.py +++ b/thunder/core/prims.py @@ -3814,10 +3814,12 @@ def transpose_meta(a: TensorProxy, /, permutation: tuple[int, ...]) -> TensorPro utils.check_valid_permutation(a.ndim, permutation) new_shape = [0] * a.ndim + new_stride = [0] * a.ndim for idx, dim in enumerate(permutation): new_shape[idx] = a.shape[dim] + new_stride[idx] = a.stride()[dim] - return TensorProxy(like=a, shape=new_shape) + return TensorProxy(like=a, shape=new_shape, stride=new_stride) transpose = make_prim(PrimIDs.TRANSPOSE, "transpose", meta=transpose_meta, tags=(OpTags.SHAPE_OP,))