Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ if(NOT TILELANG_BACKEND_USER_SELECTED)
set(USE_METAL ON)
elseif($ENV{USE_ROCM})
set(USE_ROCM ON)
elseif($ENV{TILELANG_USE_DLCOMPILER})
set(USE_DLCOMPILER ON)
else()
if($ENV{USE_CUDA})
set(USE_CUDA ON)
Expand All @@ -198,6 +200,14 @@ if(NOT TILELANG_BACKEND_USER_SELECTED)
endif()
endif()

if(USE_DLCOMPILER)
  # Pull the DLCompiler common-IR build rules in from an external checkout.
  # DLCOMPILER_SOURCE must point at the root of a DLCompiler source tree.
  # Paths are quoted so values containing spaces (or an unset/empty env var)
  # do not break argument parsing.
  if(EXISTS "$ENV{DLCOMPILER_SOURCE}/cmake/commonir.cmake")
    include("$ENV{DLCOMPILER_SOURCE}/cmake/commonir.cmake")
  else()
    # Name the env var the user actually sets (TILELANG_USE_DLCOMPILER) and
    # the one they must fix (DLCOMPILER_SOURCE) so the failure is actionable.
    message(FATAL_ERROR
      "TILELANG_USE_DLCOMPILER is set, but "
      "'$ENV{DLCOMPILER_SOURCE}/cmake/commonir.cmake' was not found. "
      "Point the DLCOMPILER_SOURCE environment variable at a DLCompiler source tree.")
  endif()
endif()

if(USE_METAL)
file(GLOB TILE_LANG_METAL_SRCS
src/target/rt_mod_metal.cc
Expand Down
6 changes: 6 additions & 0 deletions tilelang/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,9 @@ class Environment:
CUTLASS_INCLUDE_DIR = EnvVar("TL_CUTLASS_PATH", None)
COMPOSABLE_KERNEL_INCLUDE_DIR = EnvVar("TL_COMPOSABLE_KERNEL_PATH", None)

# For DLCompiler
TILELANG_USE_DLCOMPILER = EnvVar("TILELANG_USE_DLCOMPILER", "0")

# TVM integration
TVM_PYTHON_PATH = EnvVar("TVM_IMPORT_PYTHON_PATH", None)
TVM_LIBRARY_PATH = EnvVar("TVM_LIBRARY_PATH", None)
Expand Down Expand Up @@ -328,6 +331,9 @@ def use_gemm_v1(self) -> bool:
{"1", "true", "yes", "on"} (case-insensitive).
"""
return str(self.TILELANG_USE_GEMM_V1).lower() in ("1", "true", "yes", "on")

def use_dlcompiler(self) -> bool:
    """Return True when the DLCompiler backend is enabled.

    Interprets ``TILELANG_USE_DLCOMPILER`` as a truthy string: any of
    ``"1"``, ``"true"``, ``"yes"``, ``"on"`` (case-insensitive) enables it.
    """
    flag = str(self.TILELANG_USE_DLCOMPILER).lower()
    return flag in {"1", "true", "yes", "on"}

def get_default_target(self) -> str:
"""Get default compilation target from environment."""
Expand Down
20 changes: 20 additions & 0 deletions tilelang/jit/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,13 @@ def _compile_and_create_adapter(self, tilelang_func: PrimFunc, out_idx: list[int
compile_flags_cfg + compile_flags if compile_flags_cfg is not None else compile_flags
)

if env.use_dlcompiler():
from triton.backends.dicp_triton.commonir.adapter import AdapterWrapper

adapter_wrapper = AdapterWrapper.compile_and_create_adapter(tilelang_func)
self.artifact = adapter_wrapper.artifact
return adapter_wrapper.adapter

# Compile the function with TVM, optimizing with shared memory lowering.
enable_host_codegen = execution_backend == "tvm_ffi"
enable_device_compile = execution_backend == "tvm_ffi"
Expand Down Expand Up @@ -343,6 +350,19 @@ def _create_adapter_from_database(
target = self.target
execution_backend = self.execution_backend

if env.use_dlcompiler():
from triton.backends.dicp_triton.commonir.adapter import AdapterWrapper
adapter_wrapper = AdapterWrapper.from_database(
params=params,
result_idx=result_idx,
target=target,
func_or_mod=func_or_mod,
host_kernel_source=host_kernel_source,
kernel_lib_path=kernel_lib_path,
pass_configs=pass_configs,
)
return adapter_wrapper.adapter

# Create an adapter based on the specified execution backend.
if execution_backend == "tvm_ffi":
adapter = TVMFFIKernelAdapter.from_database(
Expand Down
6 changes: 5 additions & 1 deletion tilelang/utils/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from platform import mac_ver
from typing import Literal
from tilelang import env
from tilelang import tvm as tvm
from tilelang import _ffi_api
from tvm.target import Target
Expand All @@ -20,6 +21,7 @@
"webgpu": "WebGPU target for browser/WebGPU runtimes.",
"c": "C source backend.",
"cutedsl": "CuTe DSL GPU target.",
"dlcompiler": "DLCompiler for various device target.",
}


Expand Down Expand Up @@ -145,7 +147,9 @@ def determine_target(target: str | Target | Literal["auto"] = "auto", return_obj
is_hip_available = check_hip_availability()

# Determine the target based on availability
if is_cuda_available:
if env.use_dlcompiler():
return_var = "dlcompiler"
elif is_cuda_available:
if torch.cuda.is_available() and (cap := torch.cuda.get_device_capability(0)):
return_var = Target({"kind": "cuda", "arch": f"sm_{nvcc.get_target_arch(cap)}"})
else:
Expand Down