Skip to content

Commit e187a6e

Browse files
committed
Numba Alloc: Patch so it works inside a Blockwise
1 parent 0433078 commit e187a6e

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

pytensor/link/numba/dispatch/tensor_basic.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,24 +74,26 @@ def numba_funcify_Alloc(op, node, **kwargs):
7474
f'if val.shape[{-i - 1}] == 1 and scalar_shape[{-i - 1}] != 1: raise ValueError("{Alloc._runtime_broadcast_error_msg}")'
7575
)
7676
check_runtime_broadcast_src = indent("\n".join(check_runtime_broadcast), " " * 4)
77-
77+
dtype = node.inputs[0].type.dtype
7878
alloc_def_src = f"""
7979
def alloc(val, {", ".join(shape_var_names)}):
8080
{shapes_to_items_src}
8181
scalar_shape = {create_tuple_string(shape_var_item_names)}
8282
{check_runtime_broadcast_src}
83-
res = np.empty(scalar_shape, dtype=val.dtype)
83+
res = np.empty(scalar_shape, dtype=np.{dtype})
8484
res[...] = val
8585
return res
8686
"""
8787
alloc_fn = compile_numba_function_src(
8888
alloc_def_src,
8989
"alloc",
9090
globals() | {"np": np},
91+
write_to_disk=True,
9192
)
9293

94+
cache_version = -1
9395
cache_key = sha256(
94-
str((type(op), node.inputs[0].type.broadcastable)).encode()
96+
str((type(op), node.inputs[0].type.broadcastable, cache_version)).encode()
9597
).hexdigest()
9698
return numba_basic.numba_njit(alloc_fn), cache_key
9799

tests/link/numba/test_blockwise.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
import pytest
33

44
from pytensor import function
5-
from pytensor.tensor import tensor, tensor3
6-
from pytensor.tensor.basic import ARange
5+
from pytensor.tensor import lvector, tensor, tensor3
6+
from pytensor.tensor.basic import Alloc, ARange, constant
77
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
88
from pytensor.tensor.nlinalg import SVD, Det
99
from pytensor.tensor.slinalg import Cholesky, cholesky
@@ -70,3 +70,13 @@ def test_repeated_args():
7070
final_node = fn.maker.fgraph.outputs[0].owner
7171
assert isinstance(final_node.op, BlockwiseWithCoreShape)
7272
assert final_node.inputs[0] is final_node.inputs[1]
73+
74+
75+
def test_blockwise_alloc():
76+
val = lvector("val")
77+
out = Blockwise(Alloc(), signature="(),(),()->(2,3)")(
78+
val, constant(2, dtype="int64"), constant(3, dtype="int64")
79+
)
80+
assert out.type.ndim == 3
81+
82+
compare_numba_and_py([val], [out], [np.arange(5)], eval_obj_mode=False)

0 commit comments

Comments
 (0)