import os
import re
from typing import Callable, List


class AdapterWrapper:
    """Bridge from a lowered tilelang module to the CommonIR compile/launch path.

    Holds two objects:
      * ``adapter``  — a tilelang ``BaseKernelAdapter`` wrapping the compiled
        CommonIR kernel as a callable;
      * ``artifact`` — the kernel source text plus its extracted parameters.

    Heavy project imports (tilelang, the CommonIR compiler) are deferred into
    the methods that need them, matching the lazy-import style already used in
    this file, so the module itself imports with only the standard library.
    """

    def __init__(self) -> None:
        # Lazy imports: tilelang/tvm are only required once a wrapper is built.
        from tilelang import tvm as tvm
        from tvm import tir
        from tilelang.engine.param import KernelParam
        from tilelang.jit.adapter import BaseKernelAdapter
        from triton.backends.dicp_triton.commonir.compiler import CompiledKernel

        class Artifact:
            """Kernel source text and the ``KernelParam`` list derived from it."""

            def __init__(self) -> None:
                self.kernel_source: str = None
                self.params: List[KernelParam] = None

            def set_kernel_source(self, kernel_source) -> None:
                self.kernel_source = str(kernel_source)
                self.params = self._extract_params(kernel_source)

            def _extract_params(self, func: tir.PrimFunc) -> List[KernelParam]:
                """Map PrimFunc params to KernelParams (buffers vs. scalars)."""
                tensor_types = []
                for var in func.params:
                    if var in func.buffer_map:
                        # Buffer-mapped vars describe tensors.
                        tensor_types.append(
                            KernelParam.from_buffer(func.buffer_map[var])
                        )
                    else:
                        tensor_types.append(KernelParam.from_var(var))
                return tensor_types

        class Adapter(BaseKernelAdapter):
            """Thin adapter exposing the compiled CommonIR kernel as a callable."""

            def __init__(self) -> None:
                self.mod = None
                self.func = None
                self.libpath = None
                self.kernel_source = None

            def set_info(self, mod, kernel_source, func: CompiledKernel) -> None:
                self.mod = mod
                self.func = func
                # Path of the shared-object launcher produced by the driver.
                self.libpath = func._run.so_launcher_path
                self.kernel_source = str(kernel_source)

            def _convert_torch_func(self) -> Callable:
                return self.func

            def get_kernel_source(self) -> str:
                return self.kernel_source

        self.adapter = Adapter()
        self.artifact = Artifact()

    @classmethod
    def compile_and_create_adapter(cls, tilelang_module):
        """Lower *tilelang_module* to CommonIR, compile it, and wrap the result."""
        from triton.backends.dicp_triton.commonir.compiler import (
            CommonIRCompiler,
            CommonIRSource,
        )

        adapter_wrapper = AdapterWrapper()
        adapter_wrapper.artifact.set_kernel_source(tilelang_module)
        mlir_content = cls._tilelang_to_commonir(tilelang_module)
        grid = cls._parse_grid(tilelang_module)
        signature = cls._parse_signature(mlir_content)

        commonir_compiler = CommonIRCompiler()
        func = commonir_compiler.compile(CommonIRSource(mlir_content, grid, signature))
        adapter_wrapper.adapter.set_info(mlir_content, tilelang_module, func)

        return adapter_wrapper

    @classmethod
    def from_database(
        cls,
        params,
        result_idx,
        target,
        func_or_mod,
        kernel_global_source,
        kernel_lib_path,
        pass_configs,
    ):
        """Database-restore hook; recompiles from the stored module.

        Only ``func_or_mod`` is used — the remaining arguments are accepted for
        interface compatibility with tilelang's adapter protocol.
        """
        return cls.compile_and_create_adapter(func_or_mod)

    @classmethod
    def _tilelang_to_commonir(cls, tilelang_module):
        """Run tilelang's ``lower`` and return the resulting MLIR text.

        ``lower`` may return either a path to a ``.mlir`` file or the MLIR
        text itself; both are normalized to text here.
        """
        from tilelang.engine import lower
        from tilelang import tvm as tvm
        from tvm.ir.instrument import PrintAfterAll, PrintBeforeAll

        # TILELANG_PRINT_COMMONIR=1/true/on dumps IR before/after each pass.
        debug_enabled = os.environ.get("TILELANG_PRINT_COMMONIR", "0") in (
            "1",
            "true",
            "on",
        )

        instruments = [PrintAfterAll(), PrintBeforeAll()] if debug_enabled else []
        with tvm.transform.PassContext(instruments=instruments):
            mlir_path = lower(tilelang_module)
        if mlir_path.endswith(".mlir"):
            mlir_content = cls._read_mlir_file(mlir_path)
        else:
            mlir_content = mlir_path
        return mlir_content

    @classmethod
    def _parse_grid(cls, tilelang_module):
        """Extract the launch grid [x, y, z] from ``T.launch_thread`` calls.

        Missing dimensions default to 1.
        """
        patterns = {
            "x": r'T\.launch_thread\("blockIdx\.x",\s*(\d+)\)',
            "y": r'T\.launch_thread\("blockIdx\.y",\s*(\d+)\)',
            "z": r'T\.launch_thread\("blockIdx\.z",\s*(\d+)\)',
        }
        text = str(tilelang_module)  # was str(str(...)) — redundant double cast
        block_indices = {"x": None, "y": None, "z": None}
        for dim, pattern in patterns.items():
            match = re.search(pattern, text)
            if match:
                block_indices[dim] = int(match.group(1))
        return [
            block_indices["x"] if block_indices["x"] is not None else 1,
            block_indices["y"] if block_indices["y"] is not None else 1,
            block_indices["z"] if block_indices["z"] is not None else 1,
        ]

    @classmethod
    def _read_mlir_file(cls, file_path) -> str:
        """Read an MLIR file; returns None (after printing) on failure.

        NOTE(review): a None return propagates into the signature parser and
        fails there — consider raising instead; kept for compatibility.
        """
        try:
            with open(file_path, "r", encoding="utf-8") as file:
                content = file.read()
            return content
        except FileNotFoundError:
            print(f"Error: File '{file_path}' does not exist")
            return None
        except Exception as e:
            print(f"Error occurred while reading the file: {e}")
            return None

    @classmethod
    def _parse_signature(cls, mlir_content) -> dict:
        """Parse the first ``func.func`` argument list into a triton signature.

        Returns ``{position: type}`` where memref/tensor element types (seen
        as ``x<type>``) are prefixed with ``*``; parameters named ``%argsN``
        are skipped and do not consume an index. Returns ``{}`` when no
        function signature is found.
        """
        # Fixed tuple => deterministic matching (the original iterated a set,
        # whose order is not guaranteed between runs).
        target_types = (
            "i1", "i8", "i16", "i32", "i64",
            "u32", "u64",
            "fp16", "bf16", "fp32", "fp64",
            "f16", "f32",
        )
        # MLIR float spellings -> triton spellings.
        normalize = {"f16": "fp16", "*f16": "*fp16", "f32": "fp32", "*f32": "*fp32"}

        match = re.search(r"func\.func\s*@[^(]*\(([^)]*)\)", mlir_content)
        if not match:
            return {}

        params = cls._split_params(match.group(1))

        result = {}
        index = 0
        for param in params:
            if re.match(r"%args\d+", param.strip()):
                continue

            found_type = None
            for t_type in target_types:
                # "x<type>" indicates the element type of a shaped container.
                if re.search(r"\bx" + t_type + r"\b", param):
                    found_type = "*" + t_type
                    break
                if re.search(r"\b" + t_type + r"\b", param):
                    found_type = t_type
                    break

            if found_type:
                result[index] = normalize.get(found_type, found_type)
                index += 1

        return result

    @classmethod
    def _split_params(cls, params_str) -> list:
        """Split a parameter list on top-level commas, respecting {} and <> nesting."""
        params = []
        current = ""
        brace_count = 0
        angle_count = 0
        for char in params_str:
            if char == "," and brace_count == 0 and angle_count == 0:
                params.append(current.strip())
                current = ""
            else:
                current += char
                if char == "{":
                    brace_count += 1
                elif char == "}":
                    brace_count -= 1
                elif char == "<":
                    angle_count += 1
                elif char == ">":
                    angle_count -= 1
        if current:
            params.append(current.strip())
        return params
import functools
import os
from typing import Any

from ..compiler import DICPOptions
from ..driver import DICPDriver
from ..utils import get_current_backend


class CommonIRBackend:
    """Backend facade routing CommonIR compilation/launch to the active DICP target.

    Every public method dispatches on ``self.driver.target`` (one of
    ``dicp`` / ``mlu`` / ``maca`` / ``ascend``); target-specific modules are
    imported lazily so only the active toolchain needs to be installed.
    """

    # Extension of the final kernel artifact; overridden per target below.
    binary_ext = "ttlinalgdir"

    def __init__(self) -> None:
        target = get_current_backend()
        # FIX: keep the raw target object — parse_options() and
        # get_codegen_implementation() read self.target, which was never
        # assigned before and raised AttributeError at runtime.
        self.target = target
        self.driver = DICPDriver(target)
        if self.driver.target == "dicp":
            self.binary_ext = "ttlinalgdir"
        elif self.driver.target == "mlu":
            self.capability = target.arch
            assert isinstance(self.capability, int)
            self.binary_ext = "cnbin"
        elif self.driver.target == "maca":
            self.capability = 80
            self.binary_ext = "mcfatbin"
        elif self.driver.target == "ascend":
            self.binary_ext = "npubin"
        else:
            # FIX: was f"...{self.target_type}..." — an undefined attribute,
            # so the intended RuntimeError itself crashed with AttributeError.
            raise RuntimeError(f"Target '{self.driver.target}' is not supported.")

    def get_attrs_descriptor(self, params, args):
        """Return the target's attrs descriptor (only implemented for ascend)."""
        if self.driver.target == "ascend":
            from triton.backends.dicp_triton.npu import AscendAttrsDescriptor

            return AscendAttrsDescriptor(params, args)
        else:
            raise RuntimeError(
                f"backend {self.driver.target} not supported for get_attrs_descriptor."
            )

    def add_stages(self, stages, options, language=None):
        """Populate *stages* with the target's compilation pipeline callbacks."""
        if self.driver.target == "ascend":
            from triton.backends.dicp_triton.npu import (
                commonir_to_linkedir,
                linalg_to_bin_enable_npu_compile,
            )

            stages["linkedir"] = lambda src, metadata: commonir_to_linkedir(
                src, metadata, options, named_ops=True
            )
            stages["npubin"] = lambda src, metadata: linalg_to_bin_enable_npu_compile(
                src, metadata, options
            )
        else:
            raise RuntimeError("backend not supported")

    def load_dialects(self, ctx):
        """Load target MLIR dialects into *ctx* (only mlu registers any)."""
        if self.driver.target == "mlu":
            from triton._C.libtriton import mlu

            mlu.load_dialects(ctx)
            return

    def get_driver(self):
        return self.driver

    # parse add_kernel[(16,)](x, y, output, n_elements, BLOCK_SIZE=1024)
    def parse_options(self, options: dict) -> Any:
        """Convert a plain options dict into the target-specific options dataclass."""
        if self.driver.target == "ascend":
            from triton.backends.dicp_triton.npu import NPUOptions

            args = {
                k: options[k]
                for k in NPUOptions.__dataclass_fields__.keys()
                if k in options
            }
            options = NPUOptions(**args)
            return options
        elif self.driver.target == "mlu":
            from triton.backends.dicp_triton.mlu import MLUOptions

            args = {
                k: options[k]
                for k in MLUOptions.__dataclass_fields__.keys()
                if k in options
            }
            # When arch is less than mtp_5xx, tf32 is not supported, use fp32
            # for calculation.
            if "allowed_dot_input_precisions" not in args:
                if self.capability < 500:
                    args["allowed_dot_input_precisions"] = "ieee"

            if "supported_fp8_dtypes" not in args:
                supported_fp8_dtypes = set(MLUOptions.supported_fp8_dtypes)
                if self.capability >= 600:
                    supported_fp8_dtypes = supported_fp8_dtypes.union(
                        ("fp8e5", "fp8e4nv")
                    )
                args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))

            args["max_num_imprecise_acc_default"] = 0

            if "enable_fp_fusion" not in args:
                args["enable_fp_fusion"] = (
                    os.getenv("TRITON_DEFAULT_FP_FUSION", "1") == "1"
                )

            if "enable_mlu_bound_check" not in args:
                args["enable_mlu_bound_check"] = (
                    os.getenv("TRITON_ENABLE_MLU_BOUND_CHECK", "0") == "1"
                )
            return MLUOptions(**args)
        elif self.driver.target == "maca":
            from triton.backends.dicp_triton.maca import MACAOptions

            args = {
                k: options[k]
                for k in MACAOptions.__dataclass_fields__.keys()
                if k in options
            }
            # USE_MACA: support allow_fp8e4nv (i.e. float8_e4m3fn)
            args["allow_fp8e4nv"] = True
            args["allow_fp8e4b15"] = False
            args["max_num_imprecise_acc_default"] = (
                2**30 if self.capability == 90 else 0
            )
            return MACAOptions(**args)
        else:
            args = {"arch": self.target}
            args.update(
                {
                    k: options[k]
                    for k in DICPOptions.__dataclass_fields__.keys()
                    if k in options
                }
            )
            return DICPOptions(**args)

    def get_codegen_implementation(self, options=None):
        """Return the target's codegen helper functions."""
        codegen_fns = dict()
        if self.driver.target == "ascend":
            from triton.backends.dicp_triton.npu import min_dot_size

            codegen_fns = {"min_dot_size": min_dot_size(self.target)}
        elif self.driver.target == "mlu":
            from triton.backends.dicp_triton.mlu import min_dot_size

            codegen_fns = {
                "convert_custom_types": lambda arg, dst_ty: arg,
                "min_dot_size": min_dot_size(self.target),
            }
        elif self.driver.target == "maca":
            import triton.language.extra.cuda as cuda

            codegen_fns = {
                "convert_custom_types": (
                    cuda.convert_custom_float8_sm80
                    if self.capability >= 80
                    else cuda.convert_custom_float8_sm70
                )
            }
        return codegen_fns

    def pack_metadata(self, metadata):
        """Collect the per-target metadata passed to the kernel launcher."""
        if self.driver.target == "ascend":
            from triton.backends.dicp_triton.npu import TRITON_PROFILER_REGISTERED

            # The CANN runtime limits kernel names to <= 50 chars and a '\n'
            # is appended, so over-long names keep only their last 49 chars.
            # (TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 can make the name unique.)
            KERNEL_NAME_MAX_LEN = 49
            # metadata.name is "<kernel_name> <mix_mode>"; mix_mode is unused here.
            kernel_name_orig, mix_mode = metadata.name.split()
            if len(kernel_name_orig) > KERNEL_NAME_MAX_LEN:
                kernel_name = kernel_name_orig[-KERNEL_NAME_MAX_LEN:]
            else:
                kernel_name = kernel_name_orig
            return {
                "kernel_name": kernel_name,
                "hash": metadata.hash,
                "debug": metadata.debug,
                "profiler_registered": TRITON_PROFILER_REGISTERED,
            }
        elif self.driver.target == "mlu":
            return (metadata.num_warps,)
        return (
            metadata.num_warps,
            metadata.num_ctas,
            metadata.shared,
            metadata.cluster_dims[0],
            metadata.cluster_dims[1],
            metadata.cluster_dims[2],
        )

    @functools.lru_cache()
    def hash(self):
        """Backend component of the compilation cache key.

        lru_cache on a method keeps ``self`` alive, which is fine here: the
        class is instantiated once as a module-level singleton.
        """
        if self.driver.target == "mlu":
            from triton.backends.dicp_triton.mlu import get_cnas_version

            version = get_cnas_version()
            return f"{version}-{self.capability}"
        version_key = self.driver.target
        return str(version_key)


commonir_backend = CommonIRBackend()


# (module header of backend/commonir/compiler.py in the original patch)
import hashlib
import json
from pathlib import Path
from typing import Any, List
from collections import namedtuple
from triton._C.libtriton import get_cache_invalidating_env_vars
from triton.runtime.cache import triton_key
from triton.backends.compiler import GPUTarget
from triton.compiler.compiler import AsmDict, _raise_error
from triton.compiler.compiler import LazyDict
from triton.runtime.cache import get_cache_manager


class CommonIRSource:
    """Container for a CommonIR kernel: MLIR text, launch grid and signature."""

    def __init__(self, src: str, grid: List[int], signature: dict):
        self.src = src
        self.grid = grid
        self.signature = signature
class CompiledKernel:
    """A CommonIR kernel restored from the on-disk compilation cache.

    Rehydrates the JSON metadata, collects the per-stage IR artifacts, and
    lazily creates a launcher so the kernel can be invoked directly or via
    ``kernel[grid](*args)``.
    """

    def __init__(self, src: CommonIRSource, metadata_group, hash):
        # The group maps stage names to file paths; exactly one entry is the
        # JSON metadata file.
        metadata_path = next(
            (Path(p) for c, p in metadata_group.items() if c.endswith(".json"))
        )
        metadata = json.loads(metadata_path.read_text())
        metadata["cluster_dims"] = tuple(metadata["cluster_dims"])
        # JSON serialization dumps the target as a dict. Restore it to a GPUTarget.
        target = metadata["target"]
        metadata["target"] = GPUTarget(
            target["backend"], target["arch"], target["warp_size"]
        )
        # Sorted field order keeps the namedtuple layout deterministic.
        KernelMetadata = namedtuple("KernelMetadata", sorted(list(metadata.keys())))
        self.metadata = KernelMetadata(**metadata)
        self.packed_metadata = commonir_backend.pack_metadata(self.metadata)
        self.src = src
        self.hash = hash
        self.name = self.metadata.name
        self.grid = src.grid
        # stores the text of each level of IR that was generated during compilation
        asm_files = [
            Path(p) for c, p in metadata_group.items() if not c.endswith(".json")
        ]
        binary_ext = commonir_backend.binary_ext
        # The final binary is read as bytes; every other stage is text.
        self.asm = AsmDict(
            {
                file.suffix[1:]: (
                    file.read_bytes()
                    if file.suffix[1:] == binary_ext
                    else file.read_text()
                )
                for file in asm_files
            }
        )
        self.metadata_group = metadata_group
        self.kernel = self.asm[binary_ext]
        self.module = None
        self._init_handles()

    def _init_handles(self):
        """Create the launcher and load the binary onto the current device (idempotent)."""
        if self.module is not None:
            return

        device = commonir_backend.get_driver().get_current_device()
        # create launcher
        self._run = commonir_backend.get_driver().launcher_cls(self.src, self.metadata)
        (
            self.module,
            self.function,
            self.n_regs,
            self.n_spills,
        ) = commonir_backend.get_driver().utils.load_binary(
            self.name, self.kernel, self.metadata.shared, device
        )

    @property
    def run(self):
        # Lazily (re)initialize the launcher on first use.
        if self._run is None:
            self._init_handles()
        return self._run

    def launch_metadata(self, grid, stream, *args):
        """Return the lazy launch-metadata dict consumed by launch hooks."""
        self._init_handles()
        ret = LazyDict({"name": self.name, "function": self.function, "stream": stream})
        return ret

    def __call__(self, *args: Any) -> Any:
        """Launch the kernel on the current stream with the grid parsed at compile time."""
        device = commonir_backend.get_driver().get_current_device()
        stream = commonir_backend.get_driver().get_current_stream(device)
        # launch kernel

        launch_metadata = self.launch_metadata(self.grid, stream, *args)
        self.run(
            self.grid[0],
            self.grid[1],
            self.grid[2],
            stream,
            self.function,
            self.packed_metadata,
            launch_metadata,
            None,  # knobs.runtime.launch_enter_hook,
            None,  # knobs.runtime.launch_exit_hook,
            *args,
        )

    def __getitem__(self, grid):
        """Return a launcher bound to an explicit *grid* (triton ``kernel[grid]`` style)."""
        self._init_handles()

        def runner(*args, stream=None):
            if stream is None:
                device = commonir_backend.get_driver().get_current_device()
                stream = commonir_backend.get_driver().get_current_stream(device)
            launch_metadata = self.launch_metadata(grid, stream, *args)
            self.run(
                grid[0],
                grid[1],
                grid[2],
                stream,
                self.function,
                self.packed_metadata,
                launch_metadata,
                None,  # knobs.runtime.launch_enter_hook,
                None,  # knobs.runtime.launch_exit_hook,
                *args,
            )

        return runner


class CommonIRCompiler(object):
    """Drives the staged compilation of a CommonIRSource through the backend pipeline."""

    def compile(self, commonir_src: CommonIRSource, options=None, _env_vars=None):
        """Compile *commonir_src* (with caching) and return a CompiledKernel."""

        target = commonir_backend.get_driver().get_current_target()
        assert isinstance(target, GPUTarget), "target must be of GPUTarget type"

        extra_options = {}
        options = commonir_backend.parse_options(
            dict(options or dict(), **extra_options)
        )
        # create cache manager
        env_vars = get_cache_invalidating_env_vars() if _env_vars is None else _env_vars

        # Cache key covers triton version, source, backend, options and env.
        src_hash = hashlib.sha256(commonir_src.src.encode("utf-8")).hexdigest()
        key = f"{triton_key()}-{src_hash}-{commonir_backend.hash()}-{options.hash()}-{str(sorted(env_vars.items()))}"
        hash = hashlib.sha256(key.encode("utf-8")).hexdigest()
        fn_cache_manager = get_cache_manager(hash)
        store_only_binary = False
        file_name = "tilelang-commonir"
        metadata_filename = f"{file_name}.json"
        metadata_group = fn_cache_manager.get_group(metadata_filename) or {}
        # initialize metadata
        metadata = {
            "hash": hash,
            "target": target,
            **options.__dict__,
            **env_vars,
        }
        # run compilation pipeline and populate metadata
        stages = dict()
        commonir_backend.add_stages(stages, options)
        module = commonir_src.src
        ir_filename = f"{file_name}.source"
        metadata_group[ir_filename] = fn_cache_manager.put(module, ir_filename)

        # Each stage consumes the previous stage's module and may mutate metadata.
        for ext, compile_ir in list(stages.items()):
            next_module = compile_ir(module, metadata)
            ir_filename = f"{file_name}.{ext}"
            if (not store_only_binary) or (ext in ("cubin", "hsaco", "json")):
                metadata_group[ir_filename] = fn_cache_manager.put(
                    next_module, ir_filename
                )
            module = next_module
        # write-back metadata
        metadata_group[metadata_filename] = fn_cache_manager.put(
            json.dumps(metadata, default=vars), metadata_filename, binary=False
        )
        fn_cache_manager.put_group(metadata_filename, metadata_group)
        return CompiledKernel(commonir_src, metadata_group, hash)

    @functools.lru_cache()
    def hash(self):
        # Constant: the compiler itself contributes no version information.
        return "CommonIRCompiler"
import functools
from typing import Optional

import torch
import triton

# Resident warps per SM, keyed by (major, minor) compute capability.
WARPS_PER_SM = {
    (8, 0): 64,
    (8, 6): 48,
    (8, 7): 48,
    (8, 9): 48,
    (9, 0): 64,
    (10, 0): 64,
    (10, 1): 48,
    (12, 0): 48,
}


@functools.lru_cache
def get_device_props(device=None):
    """Return SM count and warps-per-SM for *device* (default: current CUDA device)."""
    dev = torch.cuda.current_device() if device is None else device
    props = torch.cuda.get_device_properties(dev)
    return {
        "multi_processor_count": props.multi_processor_count,
        "warps_per_sm": WARPS_PER_SM.get((props.major, props.minor), 32),
    }


@functools.lru_cache
def get_number_cores():
    """Number of compute cores: AI cores on NPU, SMs on CUDA."""
    if is_npu():
        import triton.runtime.driver as driver

        dev = torch.npu.current_device()
        return driver.active.utils.get_device_properties(dev)["num_aicore"]
    if is_cuda():
        return torch.cuda.get_device_properties("cuda").multi_processor_count
    raise RuntimeError("Please implement this function.")


def is_mlu_592():
    """True when the active triton target is an MLU with arch 592."""
    tgt = triton.runtime.driver.active.get_current_target()
    return tgt.backend == "mlu" and tgt.arch == 592


def is_muxi():
    """True when the active triton target is the MACA (muxi) backend."""
    tgt = triton.runtime.driver.active.get_current_target()
    return tgt.backend == "maca"


@functools.lru_cache
def is_cuda():
    """True when CUDA is available; never raises."""
    try:
        available = torch.cuda.is_available()
    except Exception:
        available = False
    return available


@functools.lru_cache
def is_npu():
    """True when an NPU is available; never raises."""
    try:
        available = torch.npu.is_available()
    except Exception:
        available = False
    return available


@functools.lru_cache
def is_tesla():
    """True when device 0 is a Tesla-class GPU; never raises."""
    try:
        name = torch.cuda.get_device_name(0)
    except Exception:
        return False
    return "Tesla" in name


@functools.lru_cache
def is_nvidia_hopper():
    """True for Hopper-or-newer NVIDIA GPUs (by name or compute capability)."""
    try:
        if not is_cuda():
            return False
        return (
            "NVIDIA H" in torch.cuda.get_device_name(0)
            or torch.cuda.get_device_capability()[0] >= 9
        )
    except Exception:
        return False


def set_allocator(device_: str):
    """Install a triton allocator that hands out int8 buffers on *device_*."""

    def _alloc(size: int, alignment: int, stream: Optional[int]):
        return torch.empty(size, device=device_, dtype=torch.int8)

    triton.set_allocator(_alloc)
@functools.lru_cache
def is_tma_supported():
    """True when TMA tensor descriptors are usable (Hopper+ and triton support).

    As a side effect, installs the CUDA allocator the first time support is
    detected. Never raises.
    """
    try:
        supported = (
            is_cuda()
            and torch.cuda.get_device_capability(0)[0] >= 9
            and (
                hasattr(triton.language, "_experimental_make_tensor_descriptor")
                or hasattr(triton.language, "make_tensor_descriptor")
            )
        )
        # FIX: local no longer shadows the function's own name.
        if supported:
            set_allocator("cuda")
        return supported
    except Exception:
        return False


@functools.lru_cache
def infer_device():
    """
    Get current device name based on available devices
    """
    if is_npu():
        return "npu"
    elif is_mlu_592():
        return "mlu"
    elif is_muxi():
        return "cuda"
    elif is_nvidia_hopper():
        return "cuda"
    elif is_cuda():
        return "cuda"
    else:
        return "cpu"


# NOTE(review): evaluated at import time — get_number_cores() raises on hosts
# with neither NPU nor CUDA; confirm this module is only imported on devices.
NUM_CORES = get_number_cores()
DEVICE = infer_device()


def commonir_to_linkedir(commonir, metadata, opt, *, named_ops=False):
    """Lower CommonIR (linalg-level MLIR text) to linked HIVM MLIR via dicp-opt.

    Args:
        commonir: MLIR module text to lower.
        metadata: compilation metadata dict (its "hash" keys the debug dump).
        opt: backend options; ``opt.debug`` enables IR dumping.
        named_ops: accepted for pipeline-interface compatibility; the pass
            flag below is currently hard-coded to ``named-ops=true``.

    Raises:
        RuntimeError: when the dicp-opt subprocess fails.
    """
    assert isinstance(commonir, str)
    with tempfile.TemporaryDirectory() as tmpdir:
        src_path = os.path.join(tmpdir, "kernel.commonir.mlir")
        dst_path = os.path.join(tmpdir, "kernel.linked.mlir")
        Path(src_path).write_text(commonir)
        cmd_list = [
            _get_dicp_opt_path(),
            src_path,
            "--lower-affine",
            "--normalize-slice-ops",
            "--linalg-if-to-select",
            "--linalg-generic-to-scf",
            "--scalar-to-1d-tensor",
            "--linalg-to-linked=global-kernel=false named-ops=true",
            "--linked-to-hivm",
            "-o",
            dst_path,
        ]
        try:
            subprocess.run(cmd_list, capture_output=True, check=True)
        except subprocess.CalledProcessError as e:
            print(f"Error: code={e.returncode}, stdout:{e.stdout},stderr: {e.stderr}")
            # FIX: previously the error was only printed and execution fell
            # through to read a dst_path that was never produced, crashing
            # with a misleading FileNotFoundError. Fail loudly instead.
            raise RuntimeError("dicp-opt failed to lower CommonIR to linkedir") from e
        content = Path(dst_path).read_text()

        # TODO(zmz): temporary text post-processing in Python; remove once
        # bishengir-compile supports these forms natively.
        # Replace "*xf..." element markers with "?xf..." (dynamic dims).
        content = content.replace("*xf", "?xf")
        content = content.replace("*xi", "?xi")
        content = content.replace("*xbf", "?xbf")
        # Match patterns of the form "memref<...> to tensor<...>" and comment
        # out the tensor part, keeping both types readable in the IR.
        pattern = r"(memref\<.*?\>)\s+to\s+(tensor\<.*?\>)"
        content = re.sub(pattern, r"\1 // to \2", content)

        if opt.debug or dump_ir:
            # Recorded for the dump only: the pipeline that produced this IR.
            cmd_list = [
                _get_dicp_opt_path(),
                "kernel.ttshared.mlir",
                "--lower-affine",
                "--normalize-slice-ops",
                "--linalg-if-to-select",
                "--linalg-generic-to-scf",
                "--scalar-to-1d-tensor",
                "--linalg-to-linked=global-kernel=false named-ops=true",
                "--linked-to-hivm",
            ]
            dicp_utils._dump_stage_ir(
                content, metadata["hash"], "kernel.linkedir.mlir", cmd_list
            )

        if replace_linked_ir is not None:
            # Debug escape hatch: substitute a hand-edited linked IR file.
            print(f"[DEBUG] Replace Linkedir with {replace_linked_ir}")
            return Path(replace_linked_ir).read_text()
        return content
signature, workspace_size, mix_mode, lock_num, lock_init_value ) - so_launcher_path = make_npu_launcher_stub(wrapper_src, debug_mode) + self.so_launcher_path = make_npu_launcher_stub(wrapper_src, debug_mode) # initialize launcher import importlib.util spec = importlib.util.spec_from_file_location( - "__triton_launcher", so_launcher_path + "__triton_launcher", self.so_launcher_path ) mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) diff --git a/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.h b/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.h index 7ae43b6c..0e18f26d 100644 --- a/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.h +++ b/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.h @@ -20,6 +20,9 @@ std::unique_ptr> createLinalgGenericToSCFPass(); std::unique_ptr> createScalarTo1DTensorPass(); +std::unique_ptr> +createAnnotateTransposePass(); + std::unique_ptr> createNormalizeSliceOpsPass(); diff --git a/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.td b/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.td index c486210a..528e6166 100644 --- a/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.td +++ b/compiler/include/dicp/Dialect/LinalgExt/Transforms/Passes.td @@ -68,4 +68,18 @@ def NormalizeSliceOps : Pass<"normalize-slice-ops", "func::FuncOp"> { let dependentDialects = ["mlir::tensor::TensorDialect"]; } +def AnnotateTranspose : Pass<"annotate-transpose", "func::FuncOp"> { + let summary = "Annotate operations with permuted memref type"; + let description = [{ + Adds MayImplicitTransposeWithLastAxis annotations to operations with permuted memref type. 
+ }]; + let constructor = "mlir::dicp::LinalgExt::createAnnotateTransposePass()"; + let dependentDialects = [ + "mlir::func::FuncDialect", + "mlir::linalg::LinalgDialect", + "mlir::memref::MemRefDialect", + "mlir::bufferization::BufferizationDialect" + ]; +} + #endif diff --git a/compiler/include/dicp/Utils/Utils.h b/compiler/include/dicp/Utils/Utils.h index e50253dd..23e53351 100644 --- a/compiler/include/dicp/Utils/Utils.h +++ b/compiler/include/dicp/Utils/Utils.h @@ -1,19 +1,42 @@ #ifndef TRITON_UTILS_H #define TRITON_UTILS_H +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Value.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSwitch.h" - +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/LogicalResult.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" + +#include +#include +#include #include +#include #include // Dispatch conversion pattern handlers based on backend string. 
Executes @@ -43,7 +66,19 @@ llvm::StringRef getBackend(ModuleOp module); bool isAscendBackend(ModuleOp module); -bool isaPermutedMemRefType(MemRefType); +inline bool isaPermutedMemRefType(MemRefType memRefType) { + auto [ptrStrides, ptrOffsets] = memRefType.getStridesAndOffset(); + + switch (ptrStrides.size()) { + case 0: + return false; + case 1: + return false; + default: { + return ptrStrides[ptrStrides.size() - 1] != 1; + } + } +} // Retrieves the last (innermost) stride of a memref::ReinterpretCastOp if it is // a constant. diff --git a/compiler/lib/Dialect/LinalgExt/Transforms/AnnotateTransposePass.cpp b/compiler/lib/Dialect/LinalgExt/Transforms/AnnotateTransposePass.cpp new file mode 100644 index 00000000..c8820d06 --- /dev/null +++ b/compiler/lib/Dialect/LinalgExt/Transforms/AnnotateTransposePass.cpp @@ -0,0 +1,350 @@ +#include "dicp/Dialect/LinalgExt/Transforms/Passes.h" + +#include "bishengir/Dialect/Annotation/IR/Annotation.h" +#include "dicp/Utils/Utils.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" + +#include + +#define DEBUG_TYPE "annotate-transpose-pass" + +using namespace mlir; +using namespace mlir::dicp; + +namespace mlir { +namespace dicp { +namespace LinalgExt { +#define GEN_PASS_DEF_ANNOTATETRANSPOSE +#include "dicp/Dialect/LinalgExt/Transforms/Passes.h.inc" +} // namespace LinalgExt +} // namespace dicp +} // namespace mlir + 
+namespace { + +// ============================================================================== +// 辅助函数定义 +// ============================================================================== + +/// 检查一个MemRefType是否是置换类型的辅助函数 +/// 判定标准:使用 dicp::isaPermutedMemRefType 或 检查 stride 是否非标准 +bool isPermutedOrHasNonUnitLastStride(MemRefType memRefType) { + if (!memRefType) + return false; + + // 1. 使用现有的判定函数 + if (mlir::dicp::isaPermutedMemRefType(memRefType)) { + return true; + } + + // 2. 额外检查:最后维度的stride是否为1 + // 对于 Ascend 来说,如果最后维度 stride != + // 1,通常意味着不是连续内存,可能需要隐式转置 + auto [strides, offset] = memRefType.getStridesAndOffset(); + if (!strides.empty() && strides.back() != 1) { + return true; + } + + return false; +} + +/// 递归检查值的来源是否具有非标准stride +bool checkValueOriginHasNonStandardStride(Value value) { + if (auto memRefType = dyn_cast(value.getType())) { + if (isPermutedOrHasNonUnitLastStride(memRefType)) { + return true; + } + } + + // 检查定义操作 + if (Operation *defOp = value.getDefiningOp()) { + // 检查Subview操作 + if (auto subViewOp = dyn_cast(defOp)) { + return checkValueOriginHasNonStandardStride(subViewOp.getSource()); + } + // 检查ReinterpretCast操作 + if (auto castOp = dyn_cast(defOp)) { + return checkValueOriginHasNonStandardStride(castOp.getSource()); + } + } + + return false; +} + +struct AnnotateTransposePass + : public mlir::dicp::LinalgExt::impl::AnnotateTransposeBase< + AnnotateTransposePass> { + + void runOnOperation() override { + auto funcOp = getOperation(); + + LLVM_DEBUG(llvm::dbgs() + << "[INFO] Starting AnnotateTransposePass on function: " + << funcOp.getName() << "\n"); + + // 待处理列表 + SmallVector toTensorOpsToMark; + SmallVector opsToErase; // 用于存储被重写后需要删除的旧Op + + // ============================================================================== + // 1. 
遍历 memref.copy 操作 + // 核心逻辑:检测 Dynamic Subview Copy -> 重写为 Static Full Copy + + // Annotation + // ============================================================================== + funcOp.walk([&](memref::CopyOp copyOp) { + auto source = copyOp.getSource(); + auto target = copyOp.getTarget(); + + LLVM_DEBUG(llvm::dbgs() << "[MEMREF_COPY_VISIT] " << copyOp << "\n"); + + // --- 尝试进行 IR 重写 (Rewrite) --- + // 目标:将 memref.copy(subview(A), subview(B)) 转换为 memref.copy(A, B) + // 条件:A 是静态 Permuted,B 是静态 Contiguous,且形状匹配 + + auto srcSubView = source.getDefiningOp(); + auto dstSubView = target.getDefiningOp(); + + if (srcSubView && dstSubView) { + Value baseSource = srcSubView.getSource(); + Value baseTarget = dstSubView.getSource(); + + auto baseSourceType = dyn_cast(baseSource.getType()); + auto baseTargetType = dyn_cast(baseTarget.getType()); + + if (baseSourceType && baseTargetType && + baseSourceType.hasStaticShape() && + baseTargetType.hasStaticShape()) { + + bool isBaseSourcePermuted = + isPermutedOrHasNonUnitLastStride(baseSourceType); + // 简化判定:如果不是 Permuted 且 stride 正常,视为 Contiguous + bool isBaseTargetContiguous = + !isPermutedOrHasNonUnitLastStride(baseTargetType); + + // 检查 Static Shape 是否一致 (例如都是 2x8xf32) + if (isBaseSourcePermuted && isBaseTargetContiguous && + baseSourceType.getShape() == baseTargetType.getShape()) { + + LLVM_DEBUG(llvm::dbgs() + << " [REWRITE_MATCH] Found Dynamic Subview Copy " + "candidate for Static Rewrite.\n"); + LLVM_DEBUG(llvm::dbgs() << " Base Source (Permuted): " + << baseSourceType << "\n"); + LLVM_DEBUG(llvm::dbgs() << " Base Target (Contiguous): " + << baseTargetType << "\n"); + + // 执行重写 + OpBuilder builder(copyOp->getContext()); + builder.setInsertionPoint(copyOp); + + // 1. 创建新的静态 Copy (Base -> Base) + auto newCopyOp = builder.create( + copyOp.getLoc(), baseSource, baseTarget); + LLVM_DEBUG(llvm::dbgs() << " -> Replaced with Static Copy: " + << newCopyOp << "\n"); + + // 2. 
关键:在 Base Target (MemRef) 上添加 Annotation + // 这指导 Ascend 编译器生成隐式转置指令 + builder.setInsertionPointAfter(newCopyOp); + auto markOp = + builder.create(copyOp.getLoc(), baseTarget); + markOp->setAttr("MayImplicitTransposeWithLastAxis", + UnitAttr::get(builder.getContext())); + LLVM_DEBUG(llvm::dbgs() + << " -> Added Annotation to Base Target MemRef: " + << markOp << "\n"); + + // 3. 追踪 Base Target 的 Tensor 使用者 + // 我们需要标记 bufferization.to_tensor(BaseTarget),这样后续的 + // MatMul 才能识别到 Layout 变化 + for (auto user : baseTarget.getUsers()) { + if (auto toTensorOp = dyn_cast(user)) { + // 去重检查 + bool exists = false; + for (auto op : toTensorOpsToMark) + if (op == toTensorOp) + exists = true; + + if (!exists) { + toTensorOpsToMark.push_back(toTensorOp); + LLVM_DEBUG(llvm::dbgs() + << " -> Scheduled Base Target's ToTensorOp " + "for annotation: " + << toTensorOp << "\n"); + } + } + } + + // 4. 标记旧的 Copy Op 待删除 + opsToErase.push_back(copyOp); + + // 重写完成,跳过后续分析 + return; + } + } + } + + // --- 如果没有触发重写,执行常规的传播分析 --- + // (针对代码中已经是静态 Copy 的情况,或者仅仅进行标记传播) + + if (auto sourceType = dyn_cast(source.getType())) { + bool isSourcePermuted = isPermutedOrHasNonUnitLastStride(sourceType); + + if (auto targetType = dyn_cast(target.getType())) { + bool isTargetPermuted = isPermutedOrHasNonUnitLastStride(targetType); + + // 如果源是置换的,追踪目标的使用者 + if (isSourcePermuted) { + // 检查目标的使用者 + for (auto user : target.getUsers()) { + if (auto toTensorOp = dyn_cast(user)) { + bool exists = false; + for (auto op : toTensorOpsToMark) + if (op == toTensorOp) + exists = true; + if (!exists) { + toTensorOpsToMark.push_back(toTensorOp); + LLVM_DEBUG( + llvm::dbgs() + << " [PROPAGATE] Marked bufferization.to_tensor (Source " + "was permuted)\n"); + } + } + } + + // 如果目标是 Subview,追踪其父 MemRef + if (auto sourceDefOp = target.getDefiningOp()) { + if (auto subviewOp = dyn_cast(sourceDefOp)) { + Value parentMemRef = subviewOp.getSource(); + for (auto user : parentMemRef.getUsers()) { + if (auto toTensorOp = + 
dyn_cast(user)) { + bool exists = false; + for (auto op : toTensorOpsToMark) + if (op == toTensorOp) + exists = true; + if (!exists) { + toTensorOpsToMark.push_back(toTensorOp); + LLVM_DEBUG( + llvm::dbgs() + << " [PROPAGATE_PARENT] Marked " + "bufferization.to_tensor of Parent MemRef\n"); + } + } + } + } + } + } else if (isTargetPermuted) { + // 如果目标本身就是置换的 + for (auto user : target.getUsers()) { + if (auto toTensorOp = dyn_cast(user)) { + bool exists = false; + for (auto op : toTensorOpsToMark) + if (op == toTensorOp) + exists = true; + if (!exists) { + toTensorOpsToMark.push_back(toTensorOp); + LLVM_DEBUG( + llvm::dbgs() + << " [PROPAGATE_TARGET] Marked bufferization.to_tensor " + "(Target is permuted)\n"); + } + } + } + } + } + } + }); + + // 删除被重写的旧 Op + for (auto op : opsToErase) { + op->erase(); + } + + // ============================================================================== + // 2. 扫描所有 bufferization.to_tensor 操作 (查漏补缺) + // ============================================================================== + funcOp.walk([&](bufferization::ToTensorOp toTensorOp) { + // 如果已经在列表中,跳过 + for (auto existing : toTensorOpsToMark) { + if (existing == toTensorOp) + return; + } + + Value sourceMemRef = toTensorOp.getOperand(); + bool hasNonStandardStride = + checkValueOriginHasNonStandardStride(sourceMemRef); + + bool shouldMark = false; + if (auto memRefType = dyn_cast(sourceMemRef.getType())) { + if (isPermutedOrHasNonUnitLastStride(memRefType)) { + shouldMark = true; + } + } + + if (shouldMark || hasNonStandardStride) { + toTensorOpsToMark.push_back(toTensorOp); + LLVM_DEBUG(llvm::dbgs() + << "[TO_TENSOR_CHECK] Found permuted/strided origin: " + << toTensorOp << "\n"); + } + }); + + // ============================================================================== + // 3. 
执行最终标记:为收集到的 Tensor 添加 Annotation + // ============================================================================== + for (auto toTensorOp : toTensorOpsToMark) { + // 双重检查:防止重复添加 MarkOp (虽然 OpBuilder 会创建新的 + // Op,但逻辑上我们不希望冗余) 简单检查该 Value 是否已经被 MarkOp 使用 + bool alreadyMarked = false; + // 注意:annotation::MarkOp 通常不直接作为 User 挂在 Value + // 上,而是作为一个独立的 Op 存在。 为了稳妥,这里我们假设 list + // 中可能有重复(如果 func.walk 逻辑有交集),去重已经在 push_back + // 时做了。 + + LLVM_DEBUG(llvm::dbgs() << " [ANNOTATE_ACTION] Adding annotation to: " + << toTensorOp << "\n"); + + OpBuilder builder(toTensorOp->getContext()); + builder.setInsertionPointAfter(toTensorOp); + + auto markOp = builder.create(toTensorOp->getLoc(), + toTensorOp.getResult()); + + markOp->setAttr("MayImplicitTransposeWithLastAxis", + UnitAttr::get(builder.getContext())); + + LLVM_DEBUG(llvm::dbgs() + << " -> Created annotation::MarkOp: " << markOp << "\n"); + } + + LLVM_DEBUG(llvm::dbgs() + << "[INFO] Finished AnnotateTransposePass on function: " + << funcOp.getName() << "\n"); + } +}; +} // namespace + +namespace mlir::dicp::LinalgExt { +std::unique_ptr> createAnnotateTransposePass() { + return std::make_unique(); +} +} // namespace mlir::dicp::LinalgExt diff --git a/compiler/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt b/compiler/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt index 0b28548a..272103de 100644 --- a/compiler/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt +++ b/compiler/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt @@ -4,6 +4,7 @@ add_triton_library(LinalgExtTransforms ScalarTo1DTensorPass.cpp RemoveSingleIterationLoop.cpp TensorTransform.cpp + AnnotateTransposePass.cpp DEPENDS LinalgExtTransformsIncGen @@ -26,4 +27,4 @@ add_triton_library(LinalgExtTransforms TritonArithToLinalg StructuredToMemref TritonToStructured -) +) \ No newline at end of file diff --git a/test/ascend/attention/test_lightning_attn.py b/test/ascend/attention/test_lightning_attn.py new file mode 100644 index 00000000..ff00b75e --- /dev/null +++ 
b/test/ascend/attention/test_lightning_attn.py
@@ -0,0 +1,728 @@
+import math
+import pytest
+import torch
+
+from triton.backends.dicp_triton.device_utils import infer_device
+
+import enum
+from typing import Tuple
+
+import triton
+import triton.language as tl
+
+
+class BackendType(enum.Enum):
+    """Backend type."""
+
+    TORCH = enum.auto()
+    TRITON = enum.auto()
+
+
+def lightning_attention_prefill_forward_torch(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    past_key_value: torch.Tensor,
+    slope_rate: torch.Tensor,
+    BLOCK_SIZE=64,
+    in_place=True,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Perform lightning attention prefill.
+    modify from: https://github.com/MiniMax-AI/MiniMax-M1/blob/main/modeling_minimax_m1.py
+
+    Args:
+        q: Query tensor of shape [B, H, N, D]
+        k: Key tensor of shape [B, H, N, D]
+        v: Value tensor of shape [B, H, N, E]
+        kv_caches: Key-value cache tensor [B, H, D, E]
+        slope_rate: Decay rate tensor
+        BLOCK_SIZE: Size of blocks for processing
+        BLOCK_MODEL: Size of blocks for parallel processing
+
+    Returns:
+        output: Attention output tensor [B, H, N, E]
+        kv_caches: Key-value cache tensor [B, H, D, E]
+    """
+    b, h, n, d = q.shape
+    e = v.shape[-1]
+    assert q.ndim == 4
+    assert past_key_value.shape == (b, h, d, e)
+
+    s = slope_rate.to(torch.float32)
+    NUM_BLOCK = (n + BLOCK_SIZE - 1) // BLOCK_SIZE
+
+    array = torch.arange(BLOCK_SIZE).to(q) + 1
+    q_decay = torch.exp(-s * array.reshape(-1, 1))
+    k_decay = torch.exp(-s * (BLOCK_SIZE - array.reshape(-1, 1)))
+    index = array[:, None] - array[None, :]
+    s_index = (
+        s
+        * index[
+            None,
+            None,
+        ]
+    )
+    s_index = torch.where(index >= 0, -s_index, float("-inf"))
+    diag_decay = torch.exp(s_index)
+
+    if past_key_value is not None:
+        kv = past_key_value
+    else:
+        kv = torch.zeros(b, h, d, e).to(torch.float32).to(q.device)
+    output = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device)
+
+    for i in range(NUM_BLOCK):
+        si = i * BLOCK_SIZE
+        ei = min(si + 
BLOCK_SIZE, n) + m = ei - si + qi = q[:, :, si:ei].contiguous() + ki = k[:, :, si:ei].contiguous() + vi = v[:, :, si:ei].contiguous() + qkv_none_diag = torch.matmul(qi * q_decay[:, :m], kv).to(torch.float32) + qk = ( + torch.matmul(qi, ki.transpose(-1, -2)).to(torch.float32) + * diag_decay[:, :, :m, :m] + ) + qkv_diag = torch.matmul(qk, vi.to(torch.float32)) + output[:, :, si:ei] = qkv_none_diag + qkv_diag + block_decay = torch.exp(-s * m) + kv = block_decay * kv + torch.matmul( + (ki * k_decay[:, -m:]).transpose(-1, -2).to(vi.dtype), vi + ) + if in_place: + past_key_value.copy_(kv) + return output, past_key_value + else: + return output, kv + + +@triton.jit +def _fwd_loop_kernel( + q_ptr, + k_ptr, + v_ptr, + output_ptr, + slope_rate, + kv_cache_ptr, + b: tl.constexpr, + h: tl.constexpr, + n: tl.constexpr, + d: tl.constexpr, + e: tl.constexpr, + BLOCK: tl.constexpr, + NUM_BLOCK: tl.constexpr, + BLOCK_MODEL: tl.constexpr, +): + """ + Kernel for lightning attention prefill with KV cache. + """ + # get offset + off_bh = tl.program_id(0) + off_h = off_bh % h + off_e = tl.program_id(1) + qk_offset = off_bh * n * d + v_offset = off_bh * n * e + o_offset = off_bh * n * e + e_offset = off_e * BLOCK_MODEL + kv_offset = off_bh * d * e + + # get block ptr + Q_block_ptr = q_ptr + qk_offset + tl.arange(0, d)[None, :] + K_trans_block_ptr = k_ptr + qk_offset + tl.arange(0, d)[:, None] + V_block_ptr = v_ptr + v_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :] + O_block_ptr = output_ptr + o_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :] + S_block_ptr = slope_rate + off_h + + # init decay + s = tl.load(S_block_ptr).to(tl.float32) + off_block = tl.arange(0, BLOCK) + q_decay = tl.exp(-s.to(tl.float32) * (off_block[:, None] + 1)) + k_trans_decay = tl.exp(-s.to(tl.float32) * (BLOCK - 1 - off_block[None, :])) + block_decay = tl.exp(-s.to(tl.float32) * BLOCK) + + index = off_block[:, None] - off_block[None, :] + s_index = s * index + s_index = tl.where(index >= 0, -s_index, 
float("-inf")) + diag_decay = tl.exp(s_index) + + kv = tl.zeros([d, BLOCK_MODEL], dtype=tl.float32) + + # loop compute + for i in range(NUM_BLOCK): + if n < BLOCK * (i + 1): + block_decay = tl.exp(-s.to(tl.float32) * (n - BLOCK * i)) + # (BLOCK - 1 - off_block[None, :] + n - BLOCK) + k_trans_decay = tl.exp(-s.to(tl.float32) * (n - 1 - off_block[None, :])) + # load + q = tl.load( + Q_block_ptr + off_block[:, None] * d, mask=off_block[:, None] < n, other=0.0 + ).to(tl.float32) + k_trans = tl.load( + K_trans_block_ptr + off_block[None, :] * d, + mask=off_block[None, :] < n, + other=0.0, + ).to(tl.float32) + v = tl.load( + V_block_ptr + off_block[:, None] * e, mask=off_block[:, None] < n, other=0.0 + ).to(tl.float32) + + # compute + qk = tl.dot(q, k_trans) * diag_decay + o_intra = tl.dot(qk, v) + o_inter = tl.dot(q * q_decay, kv) + o = o_intra + o_inter + + # save and update + tl.store( + O_block_ptr + off_block[:, None] * e, + o.to(O_block_ptr.dtype.element_ty), + mask=off_block[:, None] < n, + ) + kv = block_decay * kv + tl.dot(k_trans * k_trans_decay, v) + off_block += BLOCK + + KV_block_ptr = ( + kv_cache_ptr + + kv_offset + + e_offset + + tl.arange(0, d)[:, None] * e + + tl.arange(0, BLOCK_MODEL)[None, :] + ) + tl.store( + KV_block_ptr, + kv.to(KV_block_ptr.dtype.element_ty), + ) + + +def lightning_attention_prefill_forward_triton_loop( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + past_key_value: torch.Tensor, + slope_rate: torch.Tensor, + BLOCK_SIZE=64, + BLOCK_MODEL=32, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform lightning attention prefill. 
+ modify from: https://github.com/OpenNLPLab/lightning-attention/blob/main/lightning_attn/ops/triton/lightning_attn2.py + + Args: + q: Query tensor of shape [B, H, N, D] + k: Key tensor of shape [B, H, N, D] + v: Value tensor of shape [B, H, N, E] + kv_caches: Key-value cache tensor [B, H, D, E] + slope_rate: Decay rate tensor + BLOCK_SIZE: Size of blocks for processing + BLOCK_MODEL: Size of blocks for parallel processing + + Returns: + output: Attention output tensor [B, H, N, E] + kv_caches: Key-value cache tensor [B, H, D, E] + """ + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + s = slope_rate.contiguous() + + b, h, n, d = q.shape + e = v.shape[-1] + o = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device) + assert past_key_value.shape == (b, h, d, e) + assert o.shape == v.shape + assert o.dtype == v.dtype + + NUM_BLOCK = triton.cdiv(q.shape[2], BLOCK_SIZE) + # parallel over channel + BLOCK_M = min(triton.next_power_of_2(e), BLOCK_MODEL) + assert e % BLOCK_M == 0 + grid = (b * h, triton.cdiv(e, BLOCK_M)) + + if past_key_value is not None: + kv = past_key_value + assert kv.dtype == torch.float32 + else: + kv = torch.zeros(b, h, d, e).to(torch.float32).to(q.device) + + _fwd_loop_kernel[grid]( + q, + k, + v, + o, + s, + kv, + b, + h, + n, + d, + e, + BLOCK=BLOCK_SIZE, + NUM_BLOCK=NUM_BLOCK, + BLOCK_MODEL=BLOCK_M, + add_annotate_transpose=True, + ) + return o, kv + + +def lightning_attention_decode_forward_torch( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + past_key_value: torch.Tensor, + slope_rate: torch.Tensor, + in_place=True, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform lightning attention decoding. 
+    modify from: https://github.com/MiniMax-AI/MiniMax-M1/blob/main/modeling_minimax_m1.py
+
+    Args:
+        q: Query tensor of shape [B, H, 1, D]
+        k: Key tensor of shape [B, H, 1, D]
+        v: Value tensor of shape [B, H, 1, E]
+        kv_caches: Key-value cache tensor [B, H, D, E]
+        slope_rate: Decay rate tensor
+
+    Returns:
+        output: Attention output tensor [B, H, 1, E]
+        kv_caches: Key-value cache tensor [B, H, D, E]
+    """
+    assert q.ndim == 4
+    B, H, _, D = q.shape
+    E = v.shape[-1]
+    assert k.shape == (B, H, 1, D)
+    assert v.shape == (B, H, 1, E)
+    assert past_key_value.shape == (B, H, D, E)
+    kv = past_key_value
+    s = torch.exp(-slope_rate)
+    kv = (
+        torch.einsum(
+            "... n d, ... n e -> ... d e",
+            k,
+            v,
+        )
+        + s * kv
+    )
+    qkv = torch.einsum("... n d, ... d e -> ... n e", q, kv.to(q.dtype))
+    # Only mutate the cache when in_place is requested; the previous
+    # unconditional copy_ clobbered past_key_value even for in_place=False.
+    if in_place:
+        past_key_value.copy_(kv)
+        return qkv, past_key_value
+    else:
+        return qkv, kv
+
+
+@triton.jit
+def _lightningattn_attn_decode_kernel(
+    q_ptr,
+    k_ptr,
+    v_ptr,
+    kv_cache_ptr,
+    slope_rate,
+    output_ptr,
+    D: tl.constexpr,
+    qkv_b_stride,
+    qkv_h_stride,
+    cache_b_stride,
+    cache_h_stride,
+    cache_d_stride,
+    cache_e_stride,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Kernel for lightning attention decoding with KV cache. 
+ """ + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_d = tl.program_id(2) + + batch_id = pid_b + head_id = pid_h + + # Load decay rate for the current head + ratio = tl.load(slope_rate + pid_h) + + # Calculate offsets for dimensions + qk_d_offsets = tl.arange(0, D) + v_d_offsets = tl.arange(0, BLOCK_SIZE) + pid_d * BLOCK_SIZE + cache_d_offsets = ( + qk_d_offsets[:, None] * cache_d_stride + v_d_offsets[None, :] * cache_e_stride + ) + + # Calculate offsets for the current batch and head + q_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride + k_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride + v_offset = batch_id * qkv_b_stride + head_id * qkv_h_stride + + cache_offset = batch_id * cache_b_stride + head_id * cache_h_stride + + # Create masks for loading tensors + qk_mask = qk_d_offsets < D + v_mask = v_d_offsets < D + + # Load query, key, and value tensors + q = tl.load(q_ptr + q_offset + qk_d_offsets, mask=qk_mask, other=0.0) + k = tl.load(k_ptr + k_offset + qk_d_offsets, mask=qk_mask, other=0.0) + v = tl.load(v_ptr + v_offset + v_d_offsets, mask=v_mask, other=0.0) + + # Compute key-value outer product + kv_outer = k[:, None] * v[None, :] + kv_mask = qk_mask[:, None] & v_mask[None, :] + + # Apply decay to previous KV cache + ratio = tl.exp(-ratio) + kv_ptr = kv_cache_ptr + cache_offset + cache_d_offsets + kv_cache_old = tl.load(kv_ptr, mask=kv_mask, other=0.0) + kv_outer = kv_outer + ratio * kv_cache_old + + # Compute attention output + output = q[:, None].to(tl.float32) * kv_outer + output = tl.sum(output, axis=0) + tl.store(kv_ptr, kv_outer, mask=kv_mask) + tl.store(output_ptr + q_offset + v_d_offsets, output, mask=v_mask) + + +def lightning_attention_decode_forward_triton( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + past_key_value: torch.Tensor, + slope_rate: torch.Tensor, + BLOCK_SIZE: int = 128, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Perform lightning attention decoding using Triton kernels. 
+ modify from: https://github.com/vllm-project/vllm/vllm/model_executor/layers/lightning_attn.py + + Args: + q: Query tensor of shape [B, H, 1, D] + k: Key tensor of shape [B, H, 1, D] + v: Value tensor of shape [B, H, 1, E] + kv_caches: Key-value cache tensor + slope_rate: Decay rate tensor + BLOCK_SIZE: Size of blocks for processing + + Returns: + output: Attention output tensor [B, H, 1, E] + kv_caches: Key-value cache tensor [B, H, D, E] + """ + assert q.ndim == 4 + B, H, _, D = q.shape + E = v.shape[-1] + assert k.shape == (B, H, 1, D) + assert v.shape == (B, H, 1, E) + assert past_key_value.shape == (B, H, D, E) + + # Initialize output tensor + o = torch.empty((B, H, 1, E), dtype=q.dtype, device=q.device) + + # Set grid dimensions for the kernel + grid = (B, H, D // BLOCK_SIZE) + + # Calculate strides for tensors + qkv_b_stride = q.stride(0) + qkv_h_stride = q.stride(1) + + cache_b_stride = past_key_value.stride(0) + cache_h_stride = past_key_value.stride(1) + cache_d_stride = past_key_value.stride(2) + cache_e_stride = past_key_value.stride(3) + + # Launch the kernel + _lightningattn_attn_decode_kernel[grid]( + q, + k, + v, + past_key_value, + slope_rate, + o, + D, + qkv_b_stride, + qkv_h_stride, + cache_b_stride, + cache_h_stride, + cache_d_stride, + cache_e_stride, + BLOCK_SIZE=BLOCK_SIZE, + ) + return o, past_key_value + + +def lightning_attention_prefill_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + past_key_value: torch.Tensor, + slope_rate: torch.Tensor, + BLOCK_SIZE=64, + BackendType: int = BackendType.TORCH, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + q: Query tensor of shape [B, H, N, D] + k: Key tensor of shape [B, H, N, D] + v: Value tensor of shape [B, H, N, E] + kv_caches: Key-value cache tensor [B, H, D, E] + slope_rate: Decay rate tensor + BLOCK_SIZE: Size of blocks for processing + BLOCK_MODEL: Size of blocks for parallel processing + + Returns: + output: Attention output tensor [B, H, N, E] + kv_caches: 
Key-value cache tensor [B, H, D, E]
+    """
+    if BackendType == BackendType.TRITON:
+        return lightning_attention_prefill_forward_triton_loop(
+            q, k, v, past_key_value, slope_rate, BLOCK_SIZE
+        )
+    else:
+        return lightning_attention_prefill_forward_torch(
+            q,
+            k,
+            v,
+            past_key_value,
+            slope_rate,
+            BLOCK_SIZE,
+        )
+
+
+def lightning_attention_decode_forward(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    past_key_value: torch.Tensor,
+    slope_rate: torch.Tensor,
+    BLOCK_SIZE: int = 128,
+    BackendType: int = BackendType.TORCH,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Args:
+        q: Query tensor of shape [B, H, 1, D]
+        k: Key tensor of shape [B, H, 1, D]
+        v: Value tensor of shape [B, H, 1, E]
+        kv_caches: Key-value cache tensor [B, H, D, E]
+        slope_rate: Decay rate tensor
+        BLOCK_SIZE: Size of blocks for processing in triton
+        BackendType: torch or triton
+
+    Returns:
+        output: Attention output tensor [B, H, 1, E]
+        kv_caches: Key-value cache tensor [B, H, D, E]
+    """
+    if BackendType == BackendType.TRITON:
+        return lightning_attention_decode_forward_triton(
+            q, k, v, past_key_value, slope_rate, BLOCK_SIZE
+        )
+    else:
+        return lightning_attention_decode_forward_torch(
+            q, k, v, past_key_value, slope_rate
+        )
+
+
+class TestLightningAttn:
+
+    @pytest.fixture
+    def B(self, request):
+        yield request.param
+
+    @pytest.fixture
+    def H(self, request):
+        yield request.param
+
+    @pytest.fixture
+    def N(self, request):
+        yield request.param
+
+    @pytest.fixture
+    def D(self, request):
+        yield request.param
+
+    @pytest.fixture
+    def E(self, request):
+        yield request.param
+
+    @pytest.fixture
+    def dtype(self, request):
+        yield request.param
+
+    @pytest.fixture
+    def BLOCK_SIZE(self, request):
+        yield request.param
+
+    @pytest.fixture
+    def q_states(self, B, H, N, D, dtype):
+        yield torch.randn([B, H, N, D], dtype=dtype, device=infer_device())
+
+    @pytest.fixture
+    def k_states(self, B, H, N, D, dtype):
+        yield torch.randn([B, H, N, D], dtype=dtype, device=infer_device()) 
+ + @pytest.fixture + def v_states(self, B, H, N, E, dtype): + yield torch.randn([B, H, N, E], dtype=dtype, device=infer_device()) + + @pytest.fixture + def past_key_value(self, B, H, D, E, dtype): + yield torch.randn([B, H, D, E], dtype=dtype, device=infer_device()) + + @pytest.fixture + def slope_rate(self, H, dtype): + def get_slopes(n): + + def get_slopes_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) + + slope_rate = torch.tensor( + get_slopes(H), dtype=dtype, device=infer_device() + ).reshape(H, 1, 1) + yield slope_rate * (1 + 1e-5) + + # float32 only + @pytest.mark.parametrize( + ["B", "H", "N", "D", "E", "dtype", "BLOCK_SIZE"], + [ + (1, 64, 5, 64, 128, torch.float32, 8), + (1, 64, 72, 64, 64, torch.float32, 8), + # (1, 64, 72, 64, 64, torch.float32, 16), + ], + indirect=True, + ) + def test_lightning_attention_prefill( + self, + q_states, + k_states, + v_states, + slope_rate, + past_key_value, + BLOCK_SIZE, + dtype, + ): + past_key_value_torch = torch.zeros_like(past_key_value) + past_key_value_triton = torch.zeros_like(past_key_value) + out_torch, _ = lightning_attention_prefill_forward( + q_states, + k_states, + v_states, + past_key_value_torch, + slope_rate, + BLOCK_SIZE, + BackendType=BackendType.TORCH, + ) + out_triton, _ = lightning_attention_prefill_forward( + q_states, + k_states, + v_states, + past_key_value_triton, + slope_rate, + BLOCK_SIZE, + BackendType=BackendType.TRITON, + ) + + if dtype == torch.float32: + rtol = 1e-03 + atol = 1 + else: + rtol = 1e-03 + atol = 1 + + kv_check = torch.allclose( + past_key_value_torch, + past_key_value_triton, + rtol=rtol, + atol=atol, + ) + output_check = torch.allclose( + 
out_torch, + out_triton, + rtol=rtol, + atol=atol, + ) + + assert ( + kv_check + ), f"past_key_value torch:{past_key_value_torch}, past_key_value triton:{past_key_value_triton}" + assert output_check, f"output torch:{out_torch}, output triton:{out_triton}" + # print( + # f"debug torch kv_check:{past_key_value_torch}, triton :{past_key_value_triton}" + # ) + + # float32 only + @pytest.mark.parametrize( + ["B", "H", "N", "D", "E", "dtype", "BLOCK_SIZE"], + [ + (8, 64, 1, 128, 128, torch.float32, 64), + (16, 64, 1, 128, 128, torch.float32, 64), + ], + indirect=True, + ) + def test_lightning_attention_decode( + self, + q_states, + k_states, + v_states, + slope_rate, + past_key_value, + BLOCK_SIZE, + dtype, + ): + past_key_value_torch = torch.zeros_like(past_key_value) + past_key_value_triton = torch.zeros_like(past_key_value) + out_torch, _ = lightning_attention_decode_forward( + q_states, + k_states, + v_states, + past_key_value_torch, + slope_rate, + BLOCK_SIZE, + BackendType=BackendType.TORCH, + ) + out_triton, _ = lightning_attention_decode_forward( + q_states, + k_states, + v_states, + past_key_value_triton, + slope_rate, + BLOCK_SIZE, + BackendType=BackendType.TRITON, + ) + + if dtype == torch.float32: + rtol = 1e-03 + atol = 1e-02 + else: + rtol = 1e-03 + atol = 1e-02 + + kv_check = torch.allclose( + past_key_value_torch, + past_key_value_triton, + rtol=rtol, + atol=atol, + ) + output_check = torch.allclose( + out_torch, + out_triton, + rtol=rtol, + atol=atol, + ) + + assert ( + kv_check + ), f"past_key_value torch:{past_key_value_torch}, past_key_value triton:{past_key_value_triton}" + assert output_check, f"output torch:{out_torch}, output triton:{out_triton}" + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/ascend/run_tests.sh b/test/ascend/run_tests.sh index f4bb54cc..ab0dd697 100644 --- a/test/ascend/run_tests.sh +++ b/test/ascend/run_tests.sh @@ -26,4 +26,6 @@ for test_dir in "${pytestcase_dir[@]}"; do done - +export 
BISHENG_INSTALL_PATH=/mnt/data01/CI/DLCompiler/data/bishengir_20251215/bin/ +export PATH=$BISHENG_INSTALL_PATH:$PATH +run_pytestcases "attention" diff --git a/tools/dicp_triton_opt/dicp_triton_opt.cpp b/tools/dicp_triton_opt/dicp_triton_opt.cpp index dccd884e..a0482c95 100644 --- a/tools/dicp_triton_opt/dicp_triton_opt.cpp +++ b/tools/dicp_triton_opt/dicp_triton_opt.cpp @@ -105,6 +105,7 @@ inline void registerDICPDialects(mlir::DialectRegistry ®istry) { dicp::LinalgExt::registerLinalgGenericToSCFPass(); dicp::LinalgExt::registerScalarTo1DTensorPass(); dicp::LinalgExt::registerNormalizeSliceOpsPass(); + dicp::LinalgExt::registerAnnotateTransposePass(); registry.insert( dicp::LinalgExt::createScalarTo1DTensorPass()); }); + m.def("add_annotate_transpose", [](mlir::PassManager &pm) { + pm.addNestedPass( + dicp::LinalgExt::createAnnotateTransposePass()); + }); m.def("add_linalg_to_linked", [](mlir::PassManager &pm, bool globalKernel, bool namedOps) { pm.addPass(mlir::dicp::linked::createLinalgToLinkedPass(globalKernel, @@ -107,6 +111,7 @@ void init_triton_dicp_triton(py::module &&m) { dicp::LinalgExt::registerLinalgGenericToSCFPass(); dicp::LinalgExt::registerScalarTo1DTensorPass(); dicp::LinalgExt::registerNormalizeSliceOpsPass(); + dicp::LinalgExt::registerAnnotateTransposePass(); context.appendDialectRegistry(registry); context.loadAllAvailableDialects();