[Optimization][Operator] Implement and enable Conv2d-Reshape-Add-ReLU fusion #18171
This PR introduces an operator fusion for the common `conv2d` → `reshape` → `add` → `relu` sequence found in deep learning models (e.g., the convolution + bias + activation pattern produced by PyTorch frontends). The optimization improves performance and efficiency by reducing kernel-launch overhead and memory traffic.

**Performance improvement:**
Previously, `conv2d`, `reshape`, `add`, and `relu` each required a separate kernel call. Fusing these four operations into a single, unified DNNL kernel (e.g., `dnnl_fused_conv2d_bias_relu`) significantly reduces the overhead of multiple kernel launches; see `src/runtime/contrib/dnnl/dnnl.cc:154-158`, where all four operations are handled by a single `execute` call. Intermediate results (`conv_out`, `bias_add`) traditionally required costly memory write-backs and reads; fusion lets these values stay in registers or cache, cutting unnecessary memory accesses and thus memory-bandwidth usage and overall execution time.

**Increased efficiency:**
By leveraging the `FuseOpsByPattern` and `MergeCompositeFunctions` passes, this change generates a composite operation optimized for specific backends (such as DNNL). Common patterns imported from frontends like PyTorch are automatically recognized in the TVM graph and mapped to high-performance fused kernels provided by libraries like DNNL.

The fusion is achieved through a two-stage transformation within the TVM Relax framework:
1. **Pattern recognition and composite function creation (`FuseConv2dReshapeAddRelu` pass):**
   The `FuseConv2dReshapeAddRelu` class, registered as a `tvm.transform.module_pass`, transforms the `IRModule`. A `_conv2d_reshape_add_relu_pattern()` helper defines the specific sequence `conv2d` → `reshape` (applied to the bias) → `add` → `relu` using TVM's Declarative Pattern Language (DPL): the input tensors (`data`, `weight`, `bias`, `shape`) are matched with `wildcard()`, and the operation sequence is identified with `is_op()`. The `relax.transform.FuseOpsByPattern` pass then detects this pattern in the input `IRModule` and encapsulates each match in a new Relax function carrying the `{"Composite": "dnnl.conv2d_reshape_add_relu", "Primitive": True}` attributes, marking it as a logical "composite" unit.

2. **Composite function merging and codegen attribute assignment (`MergeCompositeFunctions` pass):**
   After the `FuseConv2dReshapeAddRelu` pass, the `MergeCompositeFunctions` pass is applied via `tvm.ir.transform.Sequential`. It groups functions carrying the `Composite` attribute and transforms them into external functions bearing the `{"Codegen": "dnnl"}` attribute, which indicates that the composite operation should be offloaded to a specific TVM backend, such as DNNL. At runtime, functions with the `Codegen` attribute are mapped to and executed by an optimized, single DNNL kernel, here `dnnl_fused_conv2d_bias_relu` (defined in `src/runtime/contrib/dnnl/dnnl.cc:199-207`).
This implementation enables fusion of the `conv2d + reshape + add + relu` pattern, ensuring that common convolution + bias + activation patterns originating from frontends like PyTorch are fully optimized and executed as a single, highly efficient DNNL kernel within TVM.

To verify the fusion, run the dedicated test case:

```shell
python tests/python/relax/test_conv2d_reshape_add_relu.py
```