Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion thunder/clang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,18 @@ def full_like(
device = devices.to_device(device) if device is not None else a.device
dtype = dtype if dtype is not None else a.true_dtype

return full(a.shape, fill_value, device=device, dtype=dtype)
is_stride_decreasing = all(x > y for x, y in zip(a.stride(), a.stride()[1:]))
if is_stride_decreasing:
return full(a.shape, fill_value, device=device, dtype=dtype)

permutation = [None] * len(a.stride())
permuted_shape = [None] * len(a.stride())
for i, s in enumerate(sorted(a.stride(), reverse=True)):
permutation[a.stride().index(s)] = i
permuted_shape[i] = a.shape[a.stride().index(s)]

# TODO: Return a non-view tensor. ATen implements this as empty_strided(...).fill_(...).
return transpose(full(permuted_shape, fill_value, device=device, dtype=dtype), permutation)


@clangop()
Expand Down
16 changes: 15 additions & 1 deletion thunder/core/prims.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ class PrimIDs(Enum):
PACK_SETITEM = auto()
DATACLASS_NEW = auto()
SHAPE = auto()
STRIDE = auto()
# TODO: UNPACK_SET
# Utility prims
COMMENT = auto()
Expand Down Expand Up @@ -1393,6 +1394,17 @@ def shape_meta(t: TensorProxy) -> Sequence[int | NumberProxy]:
)


def stride_meta(t: TensorProxy) -> Sequence[int]:
    """Meta function for the ``stride`` prim.

    Args:
        t: the tensor proxy whose strides are being queried.

    Returns:
        The value stored in ``t._stride``, returned as-is (no copy is made).
    """
    recorded_stride = t._stride
    return recorded_stride


# Prim that reports a tensor proxy's strides; its meta function is `stride_meta`.
stride = make_prim(PrimIDs.STRIDE, "stride", meta=stride_meta)


# NOTE UNPACK_GETITEM is intended only to be bound to directly, and not called
def unpack_getitem_meta(o: Any, key: Any) -> Any:
raise NotImplementedError
Expand Down Expand Up @@ -3802,10 +3814,12 @@ def transpose_meta(a: TensorProxy, /, permutation: tuple[int, ...]) -> TensorPro
utils.check_valid_permutation(a.ndim, permutation)

new_shape = [0] * a.ndim
new_stride = [0] * a.ndim
for idx, dim in enumerate(permutation):
new_shape[idx] = a.shape[dim]
new_stride[idx] = a.stride()[dim]

return TensorProxy(like=a, shape=new_shape)
return TensorProxy(like=a, shape=new_shape, stride=new_stride)


transpose = make_prim(PrimIDs.TRANSPOSE, "transpose", meta=transpose_meta, tags=(OpTags.SHAPE_OP,))
Expand Down
28 changes: 28 additions & 0 deletions thunder/core/proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import copy
from enum import auto, Enum
import itertools
from numbers import Number
from typing import Any
from collections.abc import Callable
Expand Down Expand Up @@ -1224,6 +1225,7 @@ class DistParallelType(Enum):
def _infer_tensor_properties(
like: TensorProxy | FutureTensorProxy | None = None,
shape: ShapeLike | None = None,
stride: Sequence[int] | None = None,
device: devices.Device | None = None,
dtype: dtypes.dtype | None = None,
requires_grad: bool | None = None,
Expand All @@ -1232,6 +1234,7 @@ def _infer_tensor_properties(
thunder_fsdp_padding_size: int | None = None,
):
_shape = None
_stride = None
_device = None
_dtype = None
_requires_grad: None | bool = None
Expand All @@ -1242,6 +1245,7 @@ def _infer_tensor_properties(
if like is not None:
baseutils.check_type(like, (TensorProxy, FutureTensorProxy))
_shape = tuple(like._shape)
_stride = tuple(like._stride)
_device = like.device
_dtype = like.true_dtype
_requires_grad = like.requires_grad
Expand All @@ -1252,6 +1256,7 @@ def _infer_tensor_properties(
baseutils.check_valid_shape(shape)

_shape = tuple(shape) if shape is not None else _shape
_stride = tuple(stride) if stride is not None else _stride
_device = device if device is not None else _device
_dtype = dtype if dtype is not None else _dtype
_dtype = dtypes.numbertype_to_dtype(_dtype) if dtypes.is_numbertype(_dtype) else _dtype
Expand All @@ -1269,6 +1274,8 @@ def _infer_tensor_properties(
_shape = tuple(pyval(x) for x in _shape)
# Computes derived properties
_numel = reduce(operator.mul, _shape, 1)
if _stride is None:
_stride = tuple(itertools.accumulate(reversed(_shape), operator.mul, initial=1))[-2::-1]
else:
# deferred computation of numel
# TODO: similar to how `shape` is handled, this should be CSE or lifted for efficiency
Expand Down Expand Up @@ -1299,6 +1306,7 @@ def _infer_tensor_properties(

return (
_shape,
_stride,
_device,
_dtype,
_true_dtype,
Expand All @@ -1321,6 +1329,7 @@ def __init__(
*,
like: TensorProxy | FutureTensorProxy | None = None,
shape: ShapeLike | None = None,
stride: Sequence[int] | None = None,
device: devices.Device | None = None,
dtype: dtypes.dtype | None = None,
prefix: None | str = None,
Expand All @@ -1332,6 +1341,7 @@ def __init__(
# NOTE FutureTensorProxies never require grad
(
self._shape,
self._stride,
self._device,
self._dtype,
self._true_dtype,
Expand All @@ -1344,6 +1354,7 @@ def __init__(
) = _infer_tensor_properties(
like,
shape,
stride,
device,
dtype,
False,
Expand All @@ -1357,6 +1368,9 @@ def __init__(
def shape(self):
return self._shape

def stride(self):
    """Return the strides stored on this proxy (``self._stride``).

    This is a plain method, so callers use ``proxy.stride()`` with parentheses.
    """
    stored_stride = self._stride
    return stored_stride

@property
def numel(self):
return self._numel
Expand Down Expand Up @@ -1445,6 +1459,7 @@ def __init__(
*,
like: TensorProxy | FutureTensorProxy | None = None,
shape: ShapeLike | None = None,
stride: Sequence[int] | None = None,
device: devices.Device | None = None,
dtype: dtypes.dtype | None = None,
requires_grad: bool = False,
Expand All @@ -1459,6 +1474,7 @@ def __init__(

(
self._shape,
self._stride,
self._device,
self._dtype,
self._true_dtype,
Expand All @@ -1471,6 +1487,7 @@ def __init__(
) = _infer_tensor_properties(
like,
shape,
stride,
device,
dtype,
requires_grad,
Expand All @@ -1491,6 +1508,14 @@ def shape(self):

return shape(self)

def stride(self):
    """Return this proxy's strides.

    When symbolic values are in use *and* a trace is active, the query is
    routed through the ``stride`` prim; in every other case the stored
    ``self._stride`` is returned directly.
    """
    if using_symbolic_values() and is_tracing():
        # Function-scope import, kept from the original code.
        # NOTE(review): presumably avoids a proxies <-> prims import cycle — confirm.
        from thunder.core.prims import stride

        return stride(self)

    return self._stride

@property
def ndim(self):
return self._ndim
Expand Down Expand Up @@ -1543,6 +1568,7 @@ def replace(self, **changes):
like = changes.get("like")
(
shape,
stride,
device,
dtype,
true_dtype,
Expand All @@ -1555,6 +1581,7 @@ def replace(self, **changes):
) = _infer_tensor_properties(
like,
changes.get("shape", self._shape if like is None else None),
changes.get("stride", self._stride if like is None else None),
changes.get("device", self._device if like is None else None),
changes.get("dtype", self._dtype if like is None else None),
changes.get("requires_grad", self._requires_grad if like is None else None),
Expand Down Expand Up @@ -2026,6 +2053,7 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple =
return TensorProxy(
name,
shape=tuple(shape),
stride=tuple(t.stride()),
device=device,
dtype=dtype,
requires_grad=t.requires_grad,
Expand Down
8 changes: 8 additions & 0 deletions thunder/executors/torchex.py
Original file line number Diff line number Diff line change
Expand Up @@ -2376,6 +2376,14 @@ def _shape_impl(t):
_register_implementation(prims.shape, shape, checker=_always_executable)


def _stride_impl(t):
    """Implementation of the ``stride`` operator: forward to ``t.stride()``.

    Args:
        t: the object being queried; it only needs to provide a ``stride()`` method.

    Returns:
        Whatever ``t.stride()`` returns, unmodified.
    """
    strides = t.stride()
    return strides


# Register `_stride_impl` with the executor under the name "stride",
# reusing the prim's own meta function.
stride = ex.register_operator(
    "stride",
    meta=prims.stride_meta,
    fn=_stride_impl,
)
# Map the `stride` prim onto that operator, using the `_always_executable` checker.
_register_implementation(
    prims.stride,
    stride,
    checker=_always_executable,
)


def _bitcast_impl(src, dtype):
return src.view(dtypes.to_torch_dtype(dtype))

Expand Down
43 changes: 43 additions & 0 deletions thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -6653,6 +6653,49 @@ def torch_empty_and_zero(*args, **kwargs):
tensor_creation_ops.append(empty_opinfo)


def empty_strided_sample_generator(op, device, dtype, requires_grad, **kwargs):
    """Yield one ``SampleInput`` per (shape, stride) pair for ``empty_strided``.

    Only ``device`` and ``dtype`` are forwarded to the samples; ``op``,
    ``requires_grad`` and ``kwargs`` are accepted but not read here.
    """
    # Each entry below is a (shape, stride) pair.
    contiguous = (
        ((2, 3), (3, 1)),
        ((4, 4), (4, 1)),
        ((2, 3, 4), (12, 4, 1)),
    )
    # Non-contiguous strides (column-major)
    column_major = (
        ((3, 4), (1, 3)),
        ((2, 3, 4), (1, 2, 6)),
    )
    one_dimensional = (
        ((10,), (1,)),
        ((5,), (2,)),
    )
    scalar = (((), ()),)

    for size, strides in (*contiguous, *column_major, *one_dimensional, *scalar):
        yield SampleInput(size, strides, device=device, dtype=dtype)


# Op under test for the `empty_strided` opinfo.
# The freshly allocated tensor is overwritten with zeros, so the result is
# always a zero tensor and the consistency tests pass.
def torch_empty_strided_and_zero(shape, stride, **kwargs):
    # The zero fill goes through full_like, which is relied on to keep the strides
    # of the tensor it is given.
    return ltorch.full_like(ltorch.empty_strided(shape, stride, **kwargs), 0)


def torch_empty_strided_reference(shape, stride, **kwargs):
    """PyTorch reference: allocate with ``torch.empty_strided`` and zero it in place.

    Args:
        shape: size of the tensor to allocate.
        stride: strides of the tensor to allocate.
        **kwargs: forwarded unchanged to ``torch.empty_strided``.

    Returns:
        The allocated tensor after ``fill_(0)``.
    """
    return torch.empty_strided(shape, stride, **kwargs).fill_(0)


# OpInfo for `empty_strided`: the op under test zero-fills its output, and the
# torch reference does the same, so the two can be compared.
empty_strided_opinfo = OpInfo(
    op=torch_empty_strided_and_zero,
    name="empty_strided",
    torch_reference=torch_empty_strided_reference,
    sample_input_generator=empty_strided_sample_generator,
)
tensor_creation_ops.append(empty_strided_opinfo)


def fixed_value_tensor_creation_op_sample_generator_with_bounds(op, device, dtype, requires_grad, **kwargs):
# shape
cases = (
Expand Down
42 changes: 42 additions & 0 deletions thunder/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3183,6 +3183,48 @@ def test_proxy_same_name():
TensorProxy(name="test", shape=(1,), device=cpu, dtype=float32)


def test_tensor_proxy_with_stride():
    """Check the strides reported by newly constructed ``TensorProxy`` objects.

    Proxies built without a ``stride`` argument are expected to report
    row-major strides derived from their shape; a proxy built with an explicit
    ``stride`` is expected to report exactly that value.
    """
    from thunder.core.proxies import TensorProxy
    from thunder.core.trace import detached_trace
    from thunder.core.dtypes import float32
    from thunder.core.devices import cpu

    # (constructor kwargs, expected shape, expected stride)
    cases = (
        ({"shape": (2, 3, 4)}, (2, 3, 4), (12, 4, 1)),
        ({"shape": (10,)}, (10,), (1,)),
        ({"shape": ()}, (), ()),
        # Explicit, non-default strides
        ({"shape": (2, 3, 4), "stride": (2, 8, 1)}, (2, 3, 4), (2, 8, 1)),
    )

    with detached_trace():
        for ctor_kwargs, expected_shape, expected_stride in cases:
            proxy = TensorProxy(device=cpu, dtype=float32, **ctor_kwargs)
            assert proxy.shape == expected_shape
            assert proxy.stride() == expected_stride


def test_stride_tracking():
    """Strides observed under ``thunder.jit`` must equal the eager ones.

    The function under test records the strides of its input and of the
    results of ``transpose``, ``permute`` and ``reshape`` applied in sequence.
    """

    def fn(x):
        swapped = x.transpose(0, 1)
        permuted = swapped.permute(2, 0, 1)
        reshaped = permuted.reshape(8, 3)
        return [x.stride(), swapped.stride(), permuted.stride(), reshaped.stride()]

    sample = torch.randn(2, 3, 4)
    compiled = thunder.jit(fn)

    actual = compiled(sample)
    expected = fn(sample)
    assert actual == expected


def test_save_trace():
def fn(x):
return x + 1
Expand Down
Loading