Skip to content

Commit 70b03b5

Browse files
committed
Numba Cholesky: Allow inplace on C_contiguous inputs
1 parent c199d6d commit 70b03b5

File tree

2 files changed

+19
-9
lines changed

2 files changed

+19
-9
lines changed

pytensor/link/numba/dispatch/linalg/decomposition/cholesky.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,21 +29,27 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
2929

3030
numba_potrf = _LAPACK().numba_xpotrf(dtype)
3131

32-
def impl(A, lower=0, overwrite_a=False, check_finite=True):
32+
def impl(A, lower=False, overwrite_a=False, check_finite=True):
3333
_N = np.int32(A.shape[-1])
3434
if A.shape[-2] != _N:
3535
raise linalg.LinAlgError("Last 2 dimensions of A must be square")
3636

37-
UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
38-
N = val_to_int_ptr(_N)
39-
LDA = val_to_int_ptr(_N)
40-
INFO = val_to_int_ptr(0)
41-
37+
transposed = False
4238
if overwrite_a and A.flags.f_contiguous:
4339
A_copy = A
40+
elif overwrite_a and A.flags.c_contiguous:
41+
# We can work on the transpose of A directly
42+
A_copy = A.T
43+
transposed = True
44+
lower = not lower
4445
else:
4546
A_copy = _copy_to_fortran_order(A)
4647

48+
UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
49+
N = val_to_int_ptr(_N)
50+
LDA = val_to_int_ptr(_N)
51+
INFO = val_to_int_ptr(0)
52+
4753
numba_potrf(
4854
UPLO,
4955
N,
@@ -61,6 +67,10 @@ def impl(A, lower=0, overwrite_a=False, check_finite=True):
6167
for i in range(j + 1, _N):
6268
A_copy[i, j] = 0.0
6369

64-
return A_copy, int_ptr_to_val(INFO)
70+
info_int = int_ptr_to_val(INFO)
71+
72+
if transposed:
73+
return A_copy.T, info_int
74+
return A_copy, info_int
6575

6676
return impl

tests/link/numba/test_slinalg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -551,8 +551,8 @@ def test_cholesky(self, lower: bool, overwrite_a: bool):
551551
val_c_contig = np.copy(val, order="C")
552552
res_c_contig = fn(val_c_contig)
553553
np.testing.assert_allclose(res_c_contig, res)
554-
# Cannot destroy C-contiguous input
555-
np.testing.assert_allclose(val_c_contig, val)
554+
# C-contiguous input is destroyed if and only if overwrite_a is set
555+
assert (val == val_c_contig).all() == (not overwrite_a)
556556

557557
# Test non-contiguous input
558558
val_not_contig = np.repeat(val, 2, axis=0)[::2]

0 commit comments

Comments
 (0)