Skip to content

Commit 1b0ce21

Browse files
committed
Numba linalg: Fallback to objmode with complex inputs
1 parent f77444d commit 1b0ce21

File tree

1 file changed

+8
-13
lines changed

1 file changed

+8
-13
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pytensor import config
66
from pytensor.link.numba.dispatch import basic as numba_basic
77
from pytensor.link.numba.dispatch.basic import (
8+
generate_fallback_impl,
89
numba_funcify,
910
register_funcify_default_op_cache_key,
1011
)
@@ -44,12 +45,6 @@
4445
from pytensor.tensor.type import complex_dtypes, integer_dtypes
4546

4647

47-
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG = (
48-
"Complex dtype for {op} not supported in numba mode. "
49-
"If you need this functionality, please open an issue at: https://github.com/pymc-devs/pytensor"
50-
)
51-
52-
5348
@numba_funcify.register(Cholesky)
5449
def numba_funcify_Cholesky(op, node, **kwargs):
5550
"""
@@ -65,7 +60,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
6560

6661
inp_dtype = node.inputs[0].type.numpy_dtype
6762
if inp_dtype.kind == "c":
68-
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
63+
return generate_fallback_impl(op, node=node, **kwargs)
6964
discrete_inp = inp_dtype.kind in "ibu"
7065
if discrete_inp and config.optimizer_verbose:
7166
print("Cholesky requires casting discrete input to float") # noqa: T201
@@ -125,7 +120,7 @@ def numba_pivot_to_permutation(piv):
125120
def numba_funcify_LU(op, node, **kwargs):
126121
inp_dtype = node.inputs[0].type.numpy_dtype
127122
if inp_dtype.kind == "c":
128-
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
123+
return generate_fallback_impl(op, node=node, **kwargs)
129124
discrete_inp = inp_dtype.kind in "ibu"
130125
if discrete_inp and config.optimizer_verbose:
131126
print("LU requires casting discrete input to float") # noqa: T201
@@ -192,7 +187,7 @@ def lu(a):
192187
def numba_funcify_LUFactor(op, node, **kwargs):
193188
inp_dtype = node.inputs[0].type.numpy_dtype
194189
if inp_dtype.kind == "c":
195-
NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
190+
return generate_fallback_impl(op, node=node, **kwargs)
196191
discrete_inp = inp_dtype.kind in "ibu"
197192
if discrete_inp and config.optimizer_verbose:
198193
print("LUFactor requires casting discrete input to float") # noqa: T201
@@ -252,7 +247,7 @@ def numba_funcify_Solve(op, node, **kwargs):
252247
out_dtype = node.outputs[0].type.numpy_dtype
253248

254249
if A_dtype.kind == "c" or b_dtype.kind == "c":
255-
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
250+
raise generate_fallback_impl(op, node=node, **kwargs)
256251
must_cast_A = A_dtype != out_dtype
257252
if must_cast_A and config.optimizer_verbose:
258253
print("Solve requires casting first input `A`") # noqa: T201
@@ -326,7 +321,7 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
326321
out_dtype = node.outputs[0].type.numpy_dtype
327322

328323
if A_dtype.kind == "c" or b_dtype.kind == "c":
329-
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
324+
raise generate_fallback_impl(op, node=node, **kwargs)
330325
must_cast_A = A_dtype != out_dtype
331326
if must_cast_A and config.optimizer_verbose:
332327
print("SolveTriangular requires casting first input `A`") # noqa: T201
@@ -377,7 +372,7 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
377372
out_dtype = node.outputs[0].type.numpy_dtype
378373

379374
if c_dtype.kind == "c" or b_dtype.kind == "c":
380-
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
375+
raise generate_fallback_impl(op, node=node, **kwargs)
381376
must_cast_c = c_dtype != out_dtype
382377
if must_cast_c and config.optimizer_verbose:
383378
print("CholeskySolve requires casting first input `c`") # noqa: T201
@@ -425,7 +420,7 @@ def numba_funcify_QR(op, node, **kwargs):
425420

426421
dtype = node.inputs[0].dtype
427422
if dtype in complex_dtypes:
428-
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
423+
return generate_fallback_impl(op, node=node, **kwargs)
429424

430425
integer_input = dtype in integer_dtypes
431426
in_dtype = config.floatX if integer_input else dtype

0 commit comments

Comments
 (0)