55from pytensor import config
66from pytensor .link .numba .dispatch import basic as numba_basic
77from pytensor .link .numba .dispatch .basic import (
8+ generate_fallback_impl ,
89 numba_funcify ,
910 register_funcify_default_op_cache_key ,
1011)
4445from pytensor .tensor .type import complex_dtypes , integer_dtypes
4546
4647
47- _COMPLEX_DTYPE_NOT_SUPPORTED_MSG = (
48- "Complex dtype for {op} not supported in numba mode. "
49- "If you need this functionality, please open an issue at: https://github.com/pymc-devs/pytensor"
50- )
51-
52-
5348@numba_funcify .register (Cholesky )
5449def numba_funcify_Cholesky (op , node , ** kwargs ):
5550 """
@@ -65,7 +60,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
6560
6661 inp_dtype = node .inputs [0 ].type .numpy_dtype
6762 if inp_dtype .kind == "c" :
68- raise NotImplementedError ( _COMPLEX_DTYPE_NOT_SUPPORTED_MSG . format ( op = op ) )
63+ return generate_fallback_impl ( op , node = node , ** kwargs )
6964 discrete_inp = inp_dtype .kind in "ibu"
7065 if discrete_inp and config .optimizer_verbose :
7166 print ("Cholesky requires casting discrete input to float" ) # noqa: T201
@@ -125,7 +120,7 @@ def numba_pivot_to_permutation(piv):
125120def numba_funcify_LU (op , node , ** kwargs ):
126121 inp_dtype = node .inputs [0 ].type .numpy_dtype
127122 if inp_dtype .kind == "c" :
128- raise NotImplementedError ( _COMPLEX_DTYPE_NOT_SUPPORTED_MSG . format ( op = op ) )
123+ return generate_fallback_impl ( op , node = node , ** kwargs )
129124 discrete_inp = inp_dtype .kind in "ibu"
130125 if discrete_inp and config .optimizer_verbose :
131126 print ("LU requires casting discrete input to float" ) # noqa: T201
@@ -192,7 +187,7 @@ def lu(a):
192187def numba_funcify_LUFactor (op , node , ** kwargs ):
193188 inp_dtype = node .inputs [0 ].type .numpy_dtype
194189 if inp_dtype .kind == "c" :
195- NotImplementedError ( _COMPLEX_DTYPE_NOT_SUPPORTED_MSG . format ( op = op ) )
190+ return generate_fallback_impl ( op , node = node , ** kwargs )
196191 discrete_inp = inp_dtype .kind in "ibu"
197192 if discrete_inp and config .optimizer_verbose :
198193 print ("LUFactor requires casting discrete input to float" ) # noqa: T201
@@ -252,7 +247,7 @@ def numba_funcify_Solve(op, node, **kwargs):
252247 out_dtype = node .outputs [0 ].type .numpy_dtype
253248
254249 if A_dtype .kind == "c" or b_dtype .kind == "c" :
255- raise NotImplementedError ( _COMPLEX_DTYPE_NOT_SUPPORTED_MSG . format ( op = op ) )
250+ raise generate_fallback_impl ( op , node = node , ** kwargs )
256251 must_cast_A = A_dtype != out_dtype
257252 if must_cast_A and config .optimizer_verbose :
258253 print ("Solve requires casting first input `A`" ) # noqa: T201
@@ -326,7 +321,7 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
326321 out_dtype = node .outputs [0 ].type .numpy_dtype
327322
328323 if A_dtype .kind == "c" or b_dtype .kind == "c" :
329- raise NotImplementedError ( _COMPLEX_DTYPE_NOT_SUPPORTED_MSG . format ( op = op ) )
324+ raise generate_fallback_impl ( op , node = node , ** kwargs )
330325 must_cast_A = A_dtype != out_dtype
331326 if must_cast_A and config .optimizer_verbose :
332327 print ("SolveTriangular requires casting first input `A`" ) # noqa: T201
@@ -377,7 +372,7 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
377372 out_dtype = node .outputs [0 ].type .numpy_dtype
378373
379374 if c_dtype .kind == "c" or b_dtype .kind == "c" :
380- raise NotImplementedError ( _COMPLEX_DTYPE_NOT_SUPPORTED_MSG . format ( op = op ) )
375+ raise generate_fallback_impl ( op , node = node , ** kwargs )
381376 must_cast_c = c_dtype != out_dtype
382377 if must_cast_c and config .optimizer_verbose :
383378 print ("CholeskySolve requires casting first input `c`" ) # noqa: T201
@@ -425,7 +420,7 @@ def numba_funcify_QR(op, node, **kwargs):
425420
426421 dtype = node .inputs [0 ].dtype
427422 if dtype in complex_dtypes :
428- raise NotImplementedError ( _COMPLEX_DTYPE_NOT_SUPPORTED_MSG . format ( op = op ) )
423+ return generate_fallback_impl ( op , node = node , ** kwargs )
429424
430425 integer_input = dtype in integer_dtypes
431426 in_dtype = config .floatX if integer_input else dtype
0 commit comments