1- Inductor C++ Wrapper Tutorial
1+ TorchInductor C++ Wrapper Tutorial
22==============================================================
33
44**Author**: `Chunyuan Wu <https://github.com/chunyuan-w>`_, `Bin Bao <https://github.com/desertfire>`__, `Jiong Gong <https://github.com/jgong5>`__
@@ -10,85 +10,119 @@ Prerequisites:
1010Introduction
1111------------
1212
13- Python, as the primary interface of PyTorch, is easy to use and efficient for development and debugging.
14- The Inductor's default wrapper generates Python code to invoke generated kernels and external kernels.
15- However, in deployments requiring high performance, Python, as an interpreted language, runs relatively slower compared to compiled languages.
13+ In ``torch.compile``, the default backend **TorchInductor** emits Python wrapper
14+ code that manages memory allocation and kernel invocation. This design provides
15+ flexibility and ease of debugging, but the interpreted nature of Python
16+ introduces runtime overhead in performance-sensitive environments.
1617
17- We implemented an Inductor C++ wrapper by leveraging the PyTorch C++ APIs
18- to generate pure C++ code that combines the generated and external kernels.
19- This allows for the execution of each captured Dynamo graph in pure C++,
20- thereby reducing the Python overhead within the graph.
18+ To address this limitation, TorchInductor includes a specialized mode that
19+ generates **C++ wrapper code** in place of the Python wrapper, enabling faster
20+ execution with minimal Python involvement.
2121
2222
23- Enabling the API
23+ Enabling the C++ wrapper mode
2424----------------
25- This feature is still in prototype stage. To activate this feature, add the following to your code:
25+ To enable this C++ wrapper mode for TorchInductor, add the following config to your code:
2626
2727.. code:: python
2828
2929 import torch._inductor.config as config
3030 config.cpp_wrapper = True
3131
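If the flag should apply to only part of a program, one option is to set it temporarily instead of globally. The sketch below assumes that the ``patch`` context manager of ``torch._inductor.config`` is available in your PyTorch version; it only toggles the flag and does not compile anything. ``cpp_wrapper_enabled`` is a hypothetical helper added for illustration.

```python
import torch._inductor.config as config


def cpp_wrapper_enabled() -> bool:
    """Hypothetical helper: report the current value of the cpp_wrapper flag."""
    return bool(config.cpp_wrapper)


# Remember the value that was active before the scoped override.
previous = cpp_wrapper_enabled()

# Assumption: config.patch accepts keyword overrides and restores them on exit.
with config.patch(cpp_wrapper=True):
    inside = cpp_wrapper_enabled()

# After leaving the block, the flag should be back to its earlier value.
after = cpp_wrapper_enabled()
```

Any ``torch.compile`` call that should use the C++ wrapper would need to run (and trigger compilation) inside the ``with`` block.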
32- This will speed up your models by reducing the Python overhead of the Inductor wrapper.
33-
3432
3533 Example code
3634------------
3735
38- We will use the below frontend code as an example:
36+ We will use the following model code as an example:
3937
4038.. code:: python
41-
39+
4240 import torch
41+ import torch._inductor.config as config
4342
44- def fn(x):
45-     return torch.tensor(list(range(2, 40, 2)), device=x.device) + x
43+ config.cpp_wrapper = True
4644
47- x = torch.randn(1)
48- opt_fn = torch.compile()(fn)
49- y = opt_fn(x)
45+ def fn(x, y):
46+     return (x + y).sum()
47+
48+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49+ x = torch.randn(128, 128, device=device)
50+ y = torch.randn(128, 128, device=device)
51+
52+ opt_fn = torch.compile(fn)
53+ result = opt_fn(x, y)
5054
5155
5356 **For CPU**
5357
54- The main part of Inductor-generated code with the default Python wrapper will look like this:
58+ The main part of TorchInductor-generated code with the default Python wrapper will look like this:
5559
5660.. code:: python
5761
58- def call(args):
59-     arg0_1, = args
60-     args.clear()
61-     assert_size_stride(arg0_1, (1, ), (1, ))
62-     buf0 = empty_strided((19, ), (1, ), device='cpu', dtype=torch.float32)
63-     cpp_fused_add_lift_fresh_0(c_void_p(constant0.data_ptr()), c_void_p(arg0_1.data_ptr()), c_void_p(buf0.data_ptr()))
64-     del arg0_1
65-     return (buf0, )
62+ class Runner:
63+     def __init__(self, partitions):
64+         self.partitions = partitions
65+
66+     def call(self, args):
67+         arg0_1, arg1_1 = args
68+         args.clear()
69+         assert_size_stride(arg0_1, (128, 128), (128, 1))
70+         assert_size_stride(arg1_1, (128, 128), (128, 1))
71+         buf0 = empty_strided_cpu((), (), torch.float32)
72+         cpp_fused_add_sum_0(arg0_1, arg1_1, buf0)
73+         del arg0_1
74+         del arg1_1
75+         return (buf0, )
6676
6777 By turning on the C++ wrapper, the generated code for the ``call`` function becomes a C++ function
68- ``inductor_entry_cpp`` of the C++ extension ``module``:
78+ ``inductor_entry_impl``:
6979
7080.. code:: python
71-
72- std::vector<at::Tensor> inductor_entry_cpp(const std::vector<at::Tensor>& args) {
73-     at::Tensor arg0_1 = args[0];
74-     at::Tensor constant0 = args[1];
75-     auto buf0 = at::empty_strided({19L, }, {1L, }, at::device(at::kCPU).dtype(at::kFloat));
76-     cpp_fused_add_lift_fresh_0((long*)(constant0.data_ptr()), (float*)(arg0_1.data_ptr()), (float*)(buf0.data_ptr()));
81+ cpp_wrapper_src = (
82+ r'''
83+ #include <torch/csrc/inductor/cpp_wrapper/cpu.h>
84+ extern "C" void cpp_fused_add_sum_0(const float* in_ptr0,
85+                                     const float* in_ptr1,
86+                                     float* out_ptr0);
87+ CACHE_TORCH_DTYPE(float32);
88+ CACHE_TORCH_DEVICE(cpu);
89+
90+ void inductor_entry_impl(
91+     AtenTensorHandle*
92+         input_handles, // array of input AtenTensorHandle; handles
93+                        // are stolen; the array itself is borrowed
94+     AtenTensorHandle*
95+         output_handles // array for writing output AtenTensorHandle; handles
96+                        // will be stolen by the caller; the array itself is
97+                        // borrowed)
98+ ) {
99+     py::gil_scoped_release_simple release;
100+
101+     auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, 2);
102+     auto arg0_1 = std::move(inputs[0]);
103+     auto arg1_1 = std::move(inputs[1]);
104+     static constexpr int64_t *int_array_0=nullptr;
105+     AtenTensorHandle buf0_handle;
106+     AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(0, int_array_0, int_array_0, cached_torch_dtype_float32, cached_torch_device_type_cpu, 0, &buf0_handle));
107+     RAIIAtenTensorHandle buf0(buf0_handle);
108+     cpp_fused_add_sum_0((const float*)(arg0_1.data_ptr()), (const float*)(arg1_1.data_ptr()), (float*)(buf0.data_ptr()));
77109     arg0_1.reset();
78-     return {buf0};
79- }
80-
81- module = CppWrapperCodeCache.load(cpp_wrapper_src, 'inductor_entry_cpp', 'c2buojsvlqbywxe3itb43hldieh4jqulk72iswa2awalwev7hjn2', False)
82-
83- def _wrap_func(f):
84-     def g(args):
85-         args_tensor = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg) for arg in args]
86-         constants_tensor = [constant0]
87-         args_tensor.extend(constants_tensor)
88-
89-         return f(args_tensor)
90-     return g
91- call = _wrap_func(module.inductor_entry_cpp)
110+     arg1_1.reset();
111+     output_handles[0] = buf0.release();
112+ } // inductor_entry_impl
113+ ...
114+ '''
115+ )
116+
117+ inductor_entry = CppWrapperCodeCache.load_pybinding(
118+     argtypes=["std::vector<AtenTensorHandle>"],
119+     main_code=cpp_wrapper_src,
120+     device_type="cpu",
121+     num_outputs=1,
122+     kernel_code=None,
123+ )
124+
125+ call = _wrap_func(inductor_entry)
92126
93127 **For GPU**
94128
@@ -113,47 +147,36 @@ Based on the same example code, the generated code for GPU will look like this:
113147 With the C++ wrapper turned on, the below equivalent C++ code will be generated:
114148
115149.. code:: python
116-
117- std::vector<at::Tensor> inductor_entry_cpp(const std::vector<at::Tensor>& args) {
118-     at::Tensor arg0_1 = args[0];
119-     at::Tensor constant0 = args[1];
120-
121-     at::cuda::CUDAGuard device_guard(0);
122-     auto buf0 = at::empty_strided({19L, }, {1L, }, at::TensorOptions(c10::Device(at::kCUDA, 0)).dtype(at::kFloat));
123-     // Source Nodes: [add, tensor], Original ATen: [aten.add, aten.lift_fresh]
124-     if (triton_poi_fused_add_lift_fresh_0 == nullptr) {
125-         triton_poi_fused_add_lift_fresh_0 = loadKernel("/tmp/torchinductor_user/mm/cmm6xjgijjffxjku4akv55eyzibirvw6bti6uqmfnruujm5cvvmw.cubin", "triton_poi_fused_add_lift_fresh_0_0d1d2d3");
126-     }
127-     CUdeviceptr var_0 = reinterpret_cast<CUdeviceptr>(constant0.data_ptr());
128-     CUdeviceptr var_1 = reinterpret_cast<CUdeviceptr>(arg0_1.data_ptr());
129-     CUdeviceptr var_2 = reinterpret_cast<CUdeviceptr>(buf0.data_ptr());
130-     auto var_3 = 19;
131-     void* kernel_args_var_0[] = {&var_0, &var_1, &var_2, &var_3};
132-     cudaStream_t stream0 = at::cuda::getCurrentCUDAStream(0);
133-     launchKernel(triton_poi_fused_add_lift_fresh_0, 1, 1, 1, 1, 0, kernel_args_var_0, stream0);
134-     arg0_1.reset();
135-     return {buf0};
136- }
137-
138- module = CppWrapperCodeCache.load(cpp_wrapper_src, 'inductor_entry_cpp', 'czbpeilh4qqmbyejdgsbpdfuk2ss5jigl2qjb7xs4gearrjvuwem', True)
150+ inductor_entry = CppWrapperCodeCache.load_pybinding(
151+     argtypes=["std::vector<AtenTensorHandle>"],
152+     main_code=cpp_wrapper_src,
153+     device_type="cuda",
154+     num_outputs=1,
155+     kernel_code=None,
156+ )
139157
140158 def _wrap_func(f):
141159     def g(args):
142-         args_tensor = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg) for arg in args]
143-         constants_tensor = [constant0]
144-         args_tensor.extend(constants_tensor)
160+         input_tensors = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg, device='cpu') for arg in args]
161+         input_handles = torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(input_tensors)
162+
163+         args.clear()
164+         del input_tensors
165+
166+         output_handles = f(input_handles)
167+         output_tensors = torch._C._aoti.alloc_tensors_by_stealing_from_void_ptrs(output_handles)
168+         return output_tensors
145169
146-         return f(args_tensor)
147170     return g
148- call = _wrap_func(module.inductor_entry_cpp)
171+
172+ call = _wrap_func(inductor_entry)
149173
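The ``_wrap_func`` helper above follows an ownership convention: inputs are converted to raw handles, the caller's argument list is cleared, the entry point steals the input handles, and the wrapper in turn steals the output handles. The pure-Python sketch below mimics that flow without PyTorch so that it can run anywhere. ``FakeHandle``, ``fake_entry``, and ``wrap_entry`` are hypothetical stand-ins, not TorchInductor APIs; the real code relies on ``torch._C._aoti`` helpers and a compiled C++ entry point.

```python
from typing import Callable, List


class FakeHandle:
    """Hypothetical stand-in for an AtenTensorHandle that owns one value."""

    def __init__(self, value: float) -> None:
        self.value = value
        self.alive = True

    def release(self) -> float:
        # Hand the payload to a new owner; this handle must not be reused.
        if not self.alive:
            raise RuntimeError("handle was already stolen")
        self.alive = False
        return self.value


def fake_entry(input_handles: List[FakeHandle]) -> List[FakeHandle]:
    """Stand-in for the compiled entry point: steals inputs, returns outputs."""
    values = [h.release() for h in input_handles]  # inputs are stolen here
    return [FakeHandle(sum(values))]  # the caller will steal this output


def wrap_entry(f: Callable[[List[FakeHandle]], List[FakeHandle]]):
    """Mirror the shape of ``_wrap_func``: convert, clear, call, convert back."""

    def g(args: List[float]) -> List[float]:
        input_handles = [FakeHandle(a) for a in args]
        args.clear()  # drop the caller's references before running the graph
        output_handles = f(input_handles)
        return [h.release() for h in output_handles]

    return g


call = wrap_entry(fake_entry)
inputs = [1.0, 2.0]
outputs = call(inputs)
print(outputs, inputs)  # prints "[3.0] []"
```

Clearing ``args`` inside the wrapper means the caller's list no longer keeps the inputs alive, which is presumably why the generated code does the same before invoking the C++ entry point.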
150174
151175 Conclusion
152176------------
153177
154- In this tutorial, we introduced a new C++ wrapper in TorchInductor to speed up your models with just two lines of code changes.
155- We explained the motivation of this new feature and walked through the easy-to-use API to activate this experimental feature.
156- Furthermore, we demonstrated the Inductor-generated code using the default Python wrapper and the new C++ wrapper on both CPU and GPU
157- to visually showcase the difference between these two wrappers.
158-
159- This feature is still in prototype stage. If you have any feature requests or run into any issues, please file a bug report at `GitHub issues <https://github.com/pytorch/pytorch/issues >`_.
178+ This tutorial introduced the **C++ wrapper** feature in TorchInductor, designed
179+ to improve model performance with minimal code modification. We described the
180+ motivation for this feature, detailed the experimental API used to enable it,
181+ and compared the generated outputs of the default Python wrapper and the new
182+ C++ wrapper on both CPU and GPU backends to illustrate their distinctions.