Skip to content

Commit 2012d55

Browse files
authored
[Relax] Add Python function support and BasePyModule for PyTorch integration (#18229)
### **Overview** This PR implements native Python function support in TVM Relax through the `@I.pyfunc` decorator and `BasePyModule`, which enable seamless integration between TVM's compilation pipeline and Python/PyTorch runtime environments. This enhancement allows users to write Python functions directly in TVMScript that can interoperate with Relax and TIR functions that provides enhanced debugging capabilities and leveraging existing PyTorch operator libraries. ### **Key Features** **TVMScript Parser Enhancement** - `@I.pyfunc` decorator: Marks Python functions for integration into IRModules - Dual storage format: Stores both raw string representation (for TVMScript printing) and captured PackedFunc (for runtime execution) - ExternFunc representation: Each Python function is represented as an ExternFunc node with attributes storing source code and runtime wrapper **Complete BasePyModule Implementation** - DLPack-based tensor conversion: Seamless conversion between PyTorch tensors and TVM NDArrays - Cross-function interoperability: Python functions can call Relax/TIR functions and vice versa - JIT compilation: Delays compilation until module instantiation for flexible late-stage modifications - Dynamic function registration: Supports runtime addition of Python functions ### Future Work - TVMScript printer for IRModules with Python functions: Print IRModules in proper format with high-level operator mapping from Relax ops to PyTorch ops, handling symbolic shapes - R.call_py_func primitive: Introduce Relax primitive to invoke corresponding PackedFunc of specified Python functions at runtime
1 parent 472b2fc commit 2012d55

File tree

13 files changed

+1826
-7
lines changed

13 files changed

+1826
-7
lines changed

python/tvm/ir/module.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def __init__(self, functions=None, attrs=None, global_infos=None):
6767
attrs,
6868
global_infos,
6969
)
70+
self.pyfuncs = {}
7071

7172
def clone(self) -> "IRModule":
7273
return _ffi_api.Module_Clone(self)

python/tvm/relax/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@
9898
# utils
9999
from .utils import convert_to_expr
100100

101+
# BasePyModule
102+
from .base_py_module import BasePyModule
103+
101104
# Import submodules in the last to avoid dependency
102105
from . import exec_builder
103106
from . import expr

python/tvm/relax/base_py_module.py

Lines changed: 385 additions & 0 deletions
Large diffs are not rendered by default.

python/tvm/script/parser/core/entry.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from typing import Any, Dict, Union
2020

2121
import tvm
22+
from tvm.relax import ExternFunc
2223
from ....ir.module import IRModule
2324
from ...ir_builder import IRBuilder
2425
from . import doc
@@ -86,12 +87,14 @@ def parse(
8687
extra_vars = _default_globals()
8788

8889
ann = {}
90+
all_pyfuncs = {}
8991
if inspect.isfunction(program):
9092
ann = {program.__name__: program.__annotations__}
9193
elif inspect.isclass(program):
9294
for name, func in program.__dict__.items():
9395
if inspect.isfunction(func):
9496
ann[name] = func.__annotations__
97+
all_pyfuncs[name] = func
9598

9699
source = Source(program)
97100
parser = Parser(source, ann)
@@ -101,6 +104,10 @@ def parse(
101104
except ParserError as err:
102105
parser.report_error(err.node, err.args[0])
103106
ret = builder.get()
107+
# Attach pyfuncs to the IRModule
108+
if inspect.isclass(program) and isinstance(ret, IRModule):
109+
_attach_pyfuncs_to_irmodule(ret, all_pyfuncs)
110+
104111
# check well-formedness in both Relax and TIR
105112
if check_well_formed:
106113
check_ret = ret
@@ -122,3 +129,65 @@ def parse(
122129
err=f"{WELL_FORMED_ERROR_MESSAGE}\n\nTraceback: {str(err)}",
123130
)
124131
return ret
132+
133+
134+
def _create_python_packed_func(pyfunc):
135+
"""Create a PackedFunc wrapper for a Python function.
136+
137+
This function creates a PackedFunc that can be called from TVM runtime
138+
and will execute the original Python function.
139+
140+
Parameters
141+
----------
142+
pyfunc : Callable
143+
The Python function to wrap.
144+
145+
Returns
146+
-------
147+
PackedFunc
148+
A PackedFunc that wraps the Python function.
149+
"""
150+
151+
def packed_func_wrapper(*args, **kwargs):
152+
"""Wrapper function that calls the original Python function."""
153+
try:
154+
result = pyfunc(*args, **kwargs)
155+
return result
156+
except Exception as error:
157+
print(f"Error calling Python function {pyfunc.__name__}: {error}")
158+
raise
159+
160+
return packed_func_wrapper
161+
162+
163+
def _attach_pyfuncs_to_irmodule(irmodule, all_pyfuncs):
164+
"""Attach Python functions to IRModule with reduced nesting."""
165+
if not all_pyfuncs:
166+
return
167+
168+
if not hasattr(irmodule, "pyfuncs"):
169+
irmodule.pyfuncs = {}
170+
171+
for global_var, func in irmodule.functions_items():
172+
if not isinstance(func, ExternFunc):
173+
continue
174+
if not func.attrs.get("is_pyfunc", False):
175+
continue
176+
177+
pyfunc_name = global_var.name_hint
178+
if pyfunc_name not in all_pyfuncs:
179+
continue
180+
181+
pyfunc = all_pyfuncs[pyfunc_name]
182+
irmodule.pyfuncs[pyfunc_name] = pyfunc
183+
184+
try:
185+
source_code = inspect.getsource(pyfunc)
186+
func = func.with_attr("python_source", source_code)
187+
except (OSError, TypeError):
188+
func = func.with_attr("python_source", f"# Source unavailable for {pyfunc_name}")
189+
190+
packed_func = _create_python_packed_func(pyfunc)
191+
func = func.with_attr("python_packed_func", packed_func)
192+
193+
irmodule[global_var] = func

python/tvm/script/parser/core/parser.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,8 @@ class Parser(doc.NodeVisitor):
343343
function_annotations: Optional[Dict[str, Dict[str, Any]]]
344344
var_table: VarTable
345345
inside_function: bool # whether we are within a function
346+
current_class: Optional[str] = None # current class being parsed
347+
base_py_module_context: bool = False # whether current class inherits from BasePyModule
346348

347349
def __init__(
348350
self,
@@ -414,6 +416,39 @@ def pop_token():
414416

415417
return _deferred(pop_token)
416418

419+
def set_class_context(self, class_name: str, is_base_py_module: bool = False):
420+
"""Set the current class context for parsing.
421+
422+
Parameters
423+
----------
424+
class_name : str
425+
The name of the current class being parsed.
426+
is_base_py_module : bool
427+
Whether the current class inherits from BasePyModule.
428+
"""
429+
self.current_class = class_name
430+
self.base_py_module_context = is_base_py_module
431+
432+
def _get_current_class_context(self) -> Optional[str]:
433+
"""Get the current class context.
434+
435+
Returns
436+
-------
437+
Optional[str]
438+
The name of the current class, or None if not in a class context.
439+
"""
440+
return self.current_class
441+
442+
def _is_base_py_module_context(self) -> bool:
443+
"""Check if the current class context allows Python functions.
444+
445+
Returns
446+
-------
447+
bool
448+
True if Python functions are allowed in the current context.
449+
"""
450+
return self.base_py_module_context
451+
417452
def with_diag_source(self, source: Source):
418453
"""Add a new source as with statement.
419454

python/tvm/script/parser/ir/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from tvm.ir import Range
1919
from ...ir_builder.ir import * # pylint: disable=redefined-builtin
2020
from . import parser as _parser
21-
from .entry import ir_module
21+
from .entry import ir_module, pyfunc
2222

2323

2424
__all__ = [
@@ -28,5 +28,6 @@
2828
"dummy_global_info",
2929
"Range",
3030
"lookup_vdevice",
31+
"pyfunc",
3132
"vdevice",
3233
]

python/tvm/script/parser/ir/entry.py

Lines changed: 91 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,12 @@
1717
"""The entry point of TVM parser for ir module."""
1818

1919
import inspect
20-
from typing import Optional, Type
20+
from typing import Callable, Optional, Type
2121

22-
from tvm.ir import IRModule
22+
from tvm.ir import IRModule, GlobalVar
23+
from tvm.relax.expr import ExternFunc
24+
from tvm.relax.base_py_module import BasePyModule
25+
from tvm import cpu, ir
2326

2427
from .._core import parse, utils
2528

@@ -47,7 +50,86 @@ def ir_module(mod: Optional[Type] = None, check_well_formed: bool = True) -> IRM
4750
def decorator_wrapper(mod):
4851
if not inspect.isclass(mod):
4952
raise TypeError(f"Expect a class, but got: {mod}")
53+
54+
# Check BasePyModule inheritance
55+
base_py_module_inherited = any(base.__name__ == "BasePyModule" for base in mod.__bases__)
56+
5057
m = parse(mod, utils.inspect_class_capture(mod), check_well_formed=check_well_formed)
58+
59+
if base_py_module_inherited:
60+
# Collect pyfunc methods
61+
pyfunc_methods = [
62+
name
63+
for name, attr in mod.__dict__.items()
64+
if hasattr(attr, "dispatch_token") and attr.dispatch_token == "pyfunc"
65+
]
66+
67+
mod._pyfunc_methods = pyfunc_methods
68+
69+
# Create ExternFunc nodes
70+
71+
for method_name in pyfunc_methods:
72+
try:
73+
existing_gvars = [
74+
global_var
75+
for global_var in m.get_global_vars()
76+
if global_var.name_hint == method_name
77+
]
78+
79+
extern_func = ExternFunc(method_name)
80+
extern_func = extern_func.with_attr("is_pyfunc", True)
81+
extern_func = extern_func.with_attr("function_type", "python")
82+
extern_func = extern_func.with_attr("python_function_name", method_name)
83+
extern_func = extern_func.with_attr(
84+
"python_source", f"# Source for {method_name}"
85+
)
86+
extern_func = extern_func.with_attr("python_packed_func", None)
87+
88+
if existing_gvars:
89+
m[existing_gvars[0]] = extern_func
90+
else:
91+
m[GlobalVar(method_name)] = extern_func
92+
93+
except Exception: # pylint: disable=broad-exception-caught
94+
continue
95+
96+
class ModuleFactory:
97+
"""Factory class for creating BasePyModule instances with Python functions."""
98+
99+
def __init__(self, module, pyfunc_methods, original_class):
100+
self.ir_module = module
101+
self.pyfunc_methods = pyfunc_methods
102+
self.original_class = original_class
103+
104+
def __call__(self, device=None, target=None):
105+
106+
if device is None:
107+
device = cpu(0)
108+
109+
instance_ir_mod = ir.IRModule()
110+
for global_var, func in self.ir_module.functions_items():
111+
instance_ir_mod[global_var] = func
112+
113+
instance = BasePyModule(instance_ir_mod, device, target)
114+
115+
for method_name in self.pyfunc_methods:
116+
if hasattr(self.original_class, method_name):
117+
method = getattr(self.original_class, method_name)
118+
instance.add_python_function(method_name, method)
119+
120+
return instance
121+
122+
def __getattr__(self, name):
123+
if hasattr(self.ir_module, name):
124+
return getattr(self.ir_module, name)
125+
raise AttributeError(
126+
f"'{self.__class__.__name__}' object has no attribute '{name}'"
127+
)
128+
129+
factory = ModuleFactory(m, pyfunc_methods, mod)
130+
setattr(factory, "__name__", mod.__name__)
131+
return factory
132+
51133
setattr(m, "__name__", mod.__name__)
52134
return m
53135

@@ -61,4 +143,10 @@ def decorator_wrapper(mod):
61143
return decorator_wrapper
62144

63145

64-
setattr(ir_module, "dispatch_token", "ir")
146+
def pyfunc(func: Callable):
147+
# Set the dispatch_token on the decorated function
148+
setattr(func, "dispatch_token", "pyfunc")
149+
return func
150+
151+
152+
setattr(pyfunc, "dispatch_token", "pyfunc")

0 commit comments

Comments
 (0)