Skip to content
30 changes: 0 additions & 30 deletions pytensor/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,36 +224,6 @@ def codegen(context, builder, signature, args):
return sig, codegen


def int_to_float_fn(inputs, out_dtype):
"""Create a Numba function that converts integer and boolean ``ndarray``s to floats."""

if (
all(inp.type.dtype == out_dtype for inp in inputs)
and np.dtype(out_dtype).kind == "f"
):

@numba_njit(inline="always")
def inputs_cast(x):
return x

elif any(i.type.numpy_dtype.kind in "uib" for i in inputs):
args_dtype = np.dtype(f"f{out_dtype.itemsize}")

@numba_njit(inline="always")
def inputs_cast(x):
return x.astype(args_dtype)

else:
args_dtype_sz = max(_arg.type.numpy_dtype.itemsize for _arg in inputs)
args_dtype = np.dtype(f"f{args_dtype_sz}")

@numba_njit(inline="always")
def inputs_cast(x):
return x.astype(args_dtype)

return inputs_cast


@singledispatch
def numba_typify(data, dtype=None, **kwargs):
return data
Expand Down
34 changes: 22 additions & 12 deletions pytensor/link/numba/dispatch/linalg/decomposition/cholesky.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import numpy as np
from numba.core.extending import overload
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from numba.types import Float
from scipy import linalg

from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix
from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix


def _cholesky(a, lower=False, overwrite_a=False, check_finite=True):
Expand All @@ -24,30 +24,36 @@ def _cholesky(a, lower=False, overwrite_a=False, check_finite=True):
@overload(_cholesky)
def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
ensure_lapack()
_check_scipy_linalg_matrix(A, "cholesky")
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="cholesky")
dtype = A.dtype
w_type = _get_underlying_float(dtype)

numba_potrf = _LAPACK().numba_xpotrf(dtype)

def impl(A, lower=0, overwrite_a=False, check_finite=True):
def impl(A, lower=False, overwrite_a=False, check_finite=True):
_N = np.int32(A.shape[-1])
if A.shape[-2] != _N:
raise linalg.LinAlgError("Last 2 dimensions of A must be square")

UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
N = val_to_int_ptr(_N)
LDA = val_to_int_ptr(_N)
INFO = val_to_int_ptr(0)

transposed = False
if overwrite_a and A.flags.f_contiguous:
A_copy = A
elif overwrite_a and A.flags.c_contiguous:
# We can work on the transpose of A directly
A_copy = A.T
transposed = True
lower = not lower
else:
A_copy = _copy_to_fortran_order(A)

UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
N = val_to_int_ptr(_N)
LDA = val_to_int_ptr(_N)
INFO = val_to_int_ptr(0)

numba_potrf(
UPLO,
N,
A_copy.view(w_type).ctypes,
A_copy.ctypes,
LDA,
INFO,
)
Expand All @@ -61,6 +67,10 @@ def impl(A, lower=0, overwrite_a=False, check_finite=True):
for i in range(j + 1, _N):
A_copy[i, j] = 0.0

return A_copy, int_ptr_to_val(INFO)
info_int = int_ptr_to_val(INFO)

if transposed:
return A_copy.T, info_int
return A_copy, info_int

return impl
15 changes: 8 additions & 7 deletions pytensor/link/numba/dispatch/linalg/decomposition/lu.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,18 @@

import numpy as np
from numba.core.extending import overload
from numba.core.types import Float
from numba.np.linalg import ensure_lapack
from scipy import linalg

from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _getrf
from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix
from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix


@numba_basic.numba_njit
def _pivot_to_permutation(p, dtype):
p_inv = np.arange(len(p)).astype(dtype)
def _pivot_to_permutation(p):
p_inv = np.arange(len(p))
for i in range(len(p)):
p_inv[i], p_inv[p[i]] = p_inv[p[i]], p_inv[i]
return p_inv
Expand All @@ -29,7 +30,7 @@ def _lu_factor_to_lu(a, dtype, overwrite_a):

# Fortran is 1 indexed, so we need to subtract 1 from the IPIV array
IPIV = IPIV - 1
p_inv = _pivot_to_permutation(IPIV, dtype=dtype)
p_inv = _pivot_to_permutation(IPIV)
perm = np.argsort(p_inv).astype("int32")

return perm, L, U
Expand Down Expand Up @@ -116,7 +117,7 @@ def lu_impl_1(
False. Returns a tuple of (perm, L, U), where perm an integer array of row swaps, such that L[perm] @ U = A.
"""
ensure_lapack()
_check_scipy_linalg_matrix(a, "lu")
_check_linalg_matrix(a, ndim=2, dtype=Float, func_name="lu")
dtype = a.dtype

def impl(
Expand Down Expand Up @@ -146,7 +147,7 @@ def lu_impl_2(
"""

ensure_lapack()
_check_scipy_linalg_matrix(a, "lu")
_check_linalg_matrix(a, ndim=2, dtype=Float, func_name="lu")
dtype = a.dtype

def impl(
Expand Down Expand Up @@ -179,7 +180,7 @@ def lu_impl_3(
False. Returns a tuple of (P, L, U), such that P @ L @ U = A.
"""
ensure_lapack()
_check_scipy_linalg_matrix(a, "lu")
_check_linalg_matrix(a, ndim=2, dtype=Float, func_name="lu")
dtype = a.dtype

def impl(
Expand Down
13 changes: 5 additions & 8 deletions pytensor/link/numba/dispatch/linalg/decomposition/lu_factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,16 @@

import numpy as np
from numba.core.extending import overload
from numba.core.types import Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg

from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix,
)
from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix


def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
Expand All @@ -38,9 +36,8 @@ def getrf_impl(
A: np.ndarray, overwrite_a: bool = False
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray, int]]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "getrf")
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="getrf")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_getrf = _LAPACK().numba_xgetrf(dtype)

def impl(
Expand All @@ -59,7 +56,7 @@ def impl(
IPIV = np.empty(_N, dtype=np.int32) # type: ignore
INFO = val_to_int_ptr(0)

numba_getrf(M, N, A_copy.view(w_type).ctypes, LDA, IPIV.ctypes, INFO)
numba_getrf(M, N, A_copy.ctypes, LDA, IPIV.ctypes, INFO)

return A_copy, IPIV, int_ptr_to_val(INFO)

Expand All @@ -79,7 +76,7 @@ def lu_factor_impl(
A: np.ndarray, overwrite_a: bool = False
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray]]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "lu_factor")
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="lu_factor")

def impl(A: np.ndarray, overwrite_a: bool = False) -> tuple[np.ndarray, np.ndarray]:
A_copy, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a)
Expand Down
15 changes: 8 additions & 7 deletions pytensor/link/numba/dispatch/linalg/solve/cholesky.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import numpy as np
from numba.core.extending import overload
from numba.core.types import Float
from numba.np.linalg import ensure_lapack
from scipy import linalg

from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix,
_check_dtypes_match,
_check_linalg_matrix,
_copy_to_fortran_order_even_if_1d,
_solve_check,
)
Expand All @@ -31,10 +32,10 @@ def _cho_solve(
@overload(_cho_solve)
def cho_solve_impl(C, B, lower=False, overwrite_b=False, check_finite=True):
ensure_lapack()
_check_scipy_linalg_matrix(C, "cho_solve")
_check_scipy_linalg_matrix(B, "cho_solve")
_check_linalg_matrix(C, ndim=2, dtype=Float, func_name="cho_solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="cho_solve")
_check_dtypes_match((C, B), func_name="cho_solve")
dtype = C.dtype
w_type = _get_underlying_float(dtype)
numba_potrs = _LAPACK().numba_xpotrs(dtype)

def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
Expand Down Expand Up @@ -71,9 +72,9 @@ def impl(C, B, lower=False, overwrite_b=False, check_finite=True):
UPLO,
N,
NRHS,
C_f.view(w_type).ctypes,
C_f.ctypes,
LDA,
B_copy.view(w_type).ctypes,
B_copy.ctypes,
LDB,
INFO,
)
Expand Down
21 changes: 11 additions & 10 deletions pytensor/link/numba/dispatch/linalg/solve/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

import numpy as np
from numba.core.extending import overload
from numba.core.types import Float
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg

from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
Expand All @@ -16,7 +16,8 @@
from pytensor.link.numba.dispatch.linalg.solve.norm import _xlange
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix,
_check_dtypes_match,
_check_linalg_matrix,
_solve_check,
)

Expand All @@ -37,9 +38,8 @@ def xgecon_impl(
Compute the condition number of a matrix A.
"""
ensure_lapack()
_check_scipy_linalg_matrix(A, "gecon")
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="gecon")
dtype = A.dtype
w_type = _get_underlying_float(dtype)
numba_gecon = _LAPACK().numba_xgecon(dtype)

def impl(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]:
Expand All @@ -58,11 +58,11 @@ def impl(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]:
numba_gecon(
NORM,
N,
A_copy.view(w_type).ctypes,
A_copy.ctypes,
LDA,
A_NORM.view(w_type).ctypes,
RCOND.view(w_type).ctypes,
WORK.view(w_type).ctypes,
A_NORM.ctypes,
RCOND.ctypes,
WORK.ctypes,
IWORK.ctypes,
INFO,
)
Expand Down Expand Up @@ -106,8 +106,9 @@ def solve_gen_impl(
transposed: bool,
) -> Callable[[np.ndarray, np.ndarray, bool, bool, bool, bool, bool], np.ndarray]:
ensure_lapack()
_check_scipy_linalg_matrix(A, "solve")
_check_scipy_linalg_matrix(B, "solve")
_check_linalg_matrix(A, ndim=2, dtype=Float, func_name="solve")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="solve")
_check_dtypes_match((A, B), "solve")

def impl(
A: np.ndarray,
Expand Down
22 changes: 13 additions & 9 deletions pytensor/link/numba/dispatch/linalg/solve/lu_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@

import numpy as np
from numba.core.extending import overload
from numba.core.types import Float, int32
from numba.np.linalg import _copy_to_fortran_order, ensure_lapack
from scipy import linalg

from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
_get_underlying_float,
int_ptr_to_val,
val_to_int_ptr,
)
from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes
from pytensor.link.numba.dispatch.linalg.utils import (
_check_scipy_linalg_matrix,
_check_dtypes_match,
_check_linalg_matrix,
_copy_to_fortran_order_even_if_1d,
_solve_check,
_trans_char_to_int,
Expand Down Expand Up @@ -44,10 +45,11 @@ def getrs_impl(
[np.ndarray, np.ndarray, np.ndarray, _Trans, bool], tuple[np.ndarray, int]
]:
ensure_lapack()
_check_scipy_linalg_matrix(LU, "getrs")
_check_scipy_linalg_matrix(B, "getrs")
_check_linalg_matrix(LU, ndim=2, dtype=Float, func_name="getrs")
_check_linalg_matrix(B, ndim=(1, 2), dtype=Float, func_name="getrs")
_check_dtypes_match((LU, B), func_name="getrs")
_check_linalg_matrix(IPIV, ndim=1, dtype=int32, func_name="getrs")
dtype = LU.dtype
w_type = _get_underlying_float(dtype)
numba_getrs = _LAPACK().numba_xgetrs(dtype)

def impl(
Expand Down Expand Up @@ -84,10 +86,10 @@ def impl(
TRANS,
N,
NRHS,
LU.view(w_type).ctypes,
LU.ctypes,
LDA,
IPIV.ctypes,
B_copy.view(w_type).ctypes,
B_copy.ctypes,
LDB,
INFO,
)
Expand Down Expand Up @@ -124,8 +126,10 @@ def lu_solve_impl(
check_finite: bool,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, _Trans, bool, bool], np.ndarray]:
ensure_lapack()
_check_scipy_linalg_matrix(lu_and_piv[0], "lu_solve")
_check_scipy_linalg_matrix(b, "lu_solve")
lu, _piv = lu_and_piv
_check_linalg_matrix(lu, ndim=2, dtype=Float, func_name="lu_solve")
_check_linalg_matrix(b, ndim=(1, 2), dtype=Float, func_name="lu_solve")
_check_dtypes_match((lu, b), func_name="lu_solve")

def impl(
lu: np.ndarray,
Expand Down
Loading
Loading