diff --git a/ftn/tools/ftn_opt.py b/ftn/tools/ftn_opt.py index 931f854..01962d8 100755 --- a/ftn/tools/ftn_opt.py +++ b/ftn/tools/ftn_opt.py @@ -3,10 +3,15 @@ from xdsl.dialects.builtin import ModuleOp +from typing import IO + from ftn.transforms.rewrite_fir_to_core import RewriteFIRToCore from ftn.transforms.merge_memref_deref import MergeMemRefDeref +from ftn.transforms.extract_target import ExtractTarget +from ftn.transforms.fpga.target_to_hls import TargetToHLSPass from ftn.transforms.lower_omp_target_data import LowerOmpTargetDataPass -# from ftn.transforms.extract_target import ExtractTarget +from ftn.transforms.apply_target_config import ApplyTargetConfig +from ftn.transforms.omp_target_to_kernel import OmpTargetToKernelPass # from ftn.transforms.isolate_target import IsolateTarget # from psy.extract_stencil import ExtractStencil # from ftn.transforms.tenstorrent.convert_to_tt import ConvertToTT @@ -18,6 +23,7 @@ from xdsl.xdsl_opt_main import xDSLOptMain from ftn.dialects import ftn_relative_cf +from ftn.dialects import device import traceback @@ -27,13 +33,23 @@ def register_all_passes(self): super().register_all_passes() self.register_pass("rewrite-fir-to-core", lambda: RewriteFIRToCore) self.register_pass("merge-memref-deref", lambda: MergeMemRefDeref) + self.register_pass("extract-target", lambda: ExtractTarget) + self.register_pass("target-to-hls", lambda: TargetToHLSPass) self.register_pass("lower-omp-target-data", lambda: LowerOmpTargetDataPass) - # self.register_pass("extract-target", lambda: ExtractTarget) + self.register_pass("apply-target", lambda: ApplyTargetConfig) + self.register_pass("omp-target-to-kernel", lambda: OmpTargetToKernelPass) # self.register_pass("isolate-target", lambda: IsolateTarget) # self.register_pass("convert-to-tt", lambda: ConvertToTT) def register_all_targets(self): + def _output_fpga_host(prog: ModuleOp, output: IO[str]): + from ftn.dialects.fpga.host_printer import HostPrinter + + printer = HostPrinter(stream=output) + 
printer.print(prog) + super().register_all_targets() + self.available_targets["fpga-host"] = _output_fpga_host def setup_pipeline(self): super().setup_pipeline() @@ -50,6 +66,7 @@ def register_all_arguments(self, arg_parser: argparse.ArgumentParser): def register_all_dialects(self): super().register_all_dialects() self.ctx.load_dialect(ftn_relative_cf.Ftn_relative_cf) + self.ctx.load_dialect(device.Device) @staticmethod def get_passes_as_dict() -> Dict[str, Callable[[ModuleOp], None]]: diff --git a/ftn/transforms/extract_target.py b/ftn/transforms/extract_target.py index a84eb8c..8324c88 100644 --- a/ftn/transforms/extract_target.py +++ b/ftn/transforms/extract_target.py @@ -1,119 +1,57 @@ -from abc import ABC -from typing import TypeVar, cast -from dataclasses import dataclass -import itertools -from xdsl.utils.hints import isa -from xdsl.dialects import memref, scf, omp -from xdsl.ir import Operation, SSAValue, OpResult, Attribute, MLContext, Block, Region +from dataclasses import dataclass, field +from xdsl.context import Context +from xdsl.ir import Operation, Block, Region from xdsl.pattern_rewriter import (RewritePattern, PatternRewriter, op_type_rewrite_pattern, PatternRewriteWalker, GreedyRewritePatternApplier) from xdsl.passes import ModulePass -from xdsl.dialects import builtin, func, llvm, arith -from ftn.util.visitor import Visitor +from xdsl.dialects import builtin, func +from xdsl.rewriter import InsertPoint +from ftn.dialects import device +@dataclass class RewriteTarget(RewritePattern): - def __init__(self): - self.target_ops=[] + module : builtin.ModuleOp + target_ops: list[Operation] = field(default_factory=list) @op_type_rewrite_pattern - def match_and_rewrite(self, op: omp.TargetOp, rewriter: PatternRewriter, /): - arg_types=[] - arg_ssa=[] - - loc_idx=0 - - locations={} - - memref_dim_ops=[] - # Grab bounds and info, then at end the terminator - for var in op.map_vars: - var_op=var.owner - var_op.parent.detach_op(var_op) - 
arg_types.append(var_op.var_ptr[0].type) - arg_ssa.append(var_op.var_ptr[0]) - if isa(var_op.var_ptr[0].type, builtin.MemRefType): - memref_type=var_op.var_ptr[0].type - src_memref=var_op.var_ptr[0] - if isa(memref_type.element_type, builtin.MemRefType): - assert len(memref_type.shape) == 0 - memref_type=var_op.var_ptr[0].type.element_type - memref_loadop=memref.Load.get(src_memref, []) - src_memref=memref_loadop.results[0] - memref_dim_ops.append(memref_loadop) - for idx, s in enumerate(memref_type.shape): - assert isa(s, builtin.IntAttr) - if (s.data == -1): - # Need to pass the dimension shape size in explicitly as it is deferred - const_op=arith.Constant.from_int_and_width(idx, builtin.IndexType()) - dim_size=memref.Dim.from_source_and_index(src_memref, const_op) - memref_dim_ops+=[const_op, dim_size] - arg_ssa.append(dim_size.results[0]) - arg_types.append(dim_size.results[0].type) - - locations[var_op]=loc_idx - loc_idx+=1 - if len(var_op.bounds) > 0: - bound_op=var_op.bounds[0].owner - bound_op.parent.detach_op(bound_op) - #self.target_ops+=[bound_op, var_op] - arg_types.append(bound_op.lower[0].type) - arg_ssa.append(bound_op.lower[0]) - locations[bound_op]=loc_idx - # Add two, as second is the size - loc_idx+=2 - else: - pass#self.target_ops+=[var_op] - - new_block = Block(arg_types=arg_types) - - new_mapinfo_ssa=[] - for var in op.map_vars: - var_op=var.owner - - map_bounds=[] - if len(var_op.bounds) > 0: - bound_op=var_op.bounds[0].owner - res_types=[] - for res in bound_op.results: res_types.append(res.type) - new_bounds_op=omp.BoundsOp.build(operands=[[new_block.args[locations[bound_op]]], [], [], [], []], - properties={"stride_in_bytes": bound_op.stride_in_bytes}, - result_types=res_types) - - new_block.add_op(new_bounds_op) - map_bounds=[new_bounds_op.results[0]] - - res_types=[] - for res in var_op.results: res_types.append(res.type) - mapinfo_op=omp.MapInfoOp.build(operands=[[new_block.args[locations[var_op]]], [], map_bounds], - 
properties={"map_type": var_op.map_type, "var_name": var_op.var_name, "var_type": var_op.var_type}, - result_types=res_types) - new_mapinfo_ssa.append(mapinfo_op.results[0]) - - new_block.add_op(mapinfo_op) - - reg=op.region - op.detach_region(reg) - - new_omp_target_op=omp.TargetOp.build(operands=[[],[],[], new_mapinfo_ssa], regions=[reg]) - new_block.add_op(new_omp_target_op) - new_block.add_op(func.Return()) - - new_fn_type=builtin.FunctionType.from_lists(arg_types, []) - - body = Region() - body.add_block(new_block) - - new_func=func.FuncOp("tt_device", new_fn_type, body) - - self.target_ops=[new_func] - - call_fn=func.Call.create(properties={"callee": builtin.SymbolRefAttr("tt_device")}, operands=arg_ssa, result_types=[]) - op.parent.insert_ops_before(memref_dim_ops+[call_fn], op) - - op.parent.detach_op(op) + def match_and_rewrite(self, op: device.KernelCreate, rewriter: PatternRewriter, /): + arg_types = [] + for var in op.mapped_data: + assert isinstance(var.type, builtin.MemRefType) + var_type = var.type + arg_types.append(var_type) + + assert op.body + dev_func_body = rewriter.move_region_contents_to_new_regions(op.body) + dev_func = func.FuncOp.from_region( + "tt_device", + arg_types, + [], + dev_func_body + ) + + ## Fix type of the block arguments (there is a mismatch between mapped_data and block args) + n_args = len(dev_func_body.block.args) + for block_arg, arg_type in zip(dev_func_body.block.args, arg_types): + new_block_arg = dev_func_body.block.insert_arg(arg_type, len(dev_func_body.block.args)) + block_arg.replace_by(new_block_arg) + + for arg in dev_func_body.block.args[:n_args]: + rewriter.erase_block_argument(arg) + + # kernel_create cannot have both a pointer to a device_function and a body. 
+ op.device_function = builtin.SymbolRefAttr(dev_func.sym_name) + + assert dev_func_body.block.last_op is not None, "The last operation in the device function block must not be None" + rewriter.erase_op(dev_func_body.block.last_op) + + rewriter.insert_op(func.ReturnOp(), InsertPoint.at_end(dev_func_body.block)) + + self.target_ops = [dev_func] + @dataclass(frozen=True) class ExtractTarget(ModulePass): @@ -122,15 +60,17 @@ class ExtractTarget(ModulePass): """ name = 'extract-target' - def apply(self, ctx: MLContext, module: builtin.ModuleOp): - rw_target= RewriteTarget() + def apply(self, ctx: Context, module: builtin.ModuleOp): + rw_target= RewriteTarget(module) walker = PatternRewriteWalker(GreedyRewritePatternApplier([ rw_target, ]), apply_recursively=False, walk_reverse=True) walker.rewrite_module(module) - containing_mod=builtin.ModuleOp([]) + # NOTE: The region receiving the block must be empty. Otherwise, the single block region rule of + # the module will not be satisfied. + containing_mod=builtin.ModuleOp(Region()) module.regions[0].move_blocks(containing_mod.regions[0]) new_module=builtin.ModuleOp(rw_target.target_ops, {"target": builtin.StringAttr("tt_device")}) diff --git a/ftn/transforms/fpga/target_to_hls.py b/ftn/transforms/fpga/target_to_hls.py new file mode 100644 index 0000000..81f1a2c --- /dev/null +++ b/ftn/transforms/fpga/target_to_hls.py @@ -0,0 +1,252 @@ +from dataclasses import dataclass, field +from xdsl.utils.hints import isa +from xdsl.dialects import memref, scf, omp +from xdsl.context import Context +from xdsl.ir import SSAValue, Block + +from xdsl.pattern_rewriter import (RewritePattern, PatternRewriter, + op_type_rewrite_pattern, + PatternRewriteWalker, + GreedyRewritePatternApplier) +from xdsl.passes import ModulePass +from xdsl.dialects import builtin, func, arith +from xdsl.rewriter import InsertPoint +from xdsl.dialects.experimental.hls import PragmaPipelineOp, PragmaUnrollOp + +class DerefMemrefs: + @staticmethod + def 
deref_scalar_memops(scalar_ssa: SSAValue, rewriter: PatternRewriter): + for use in scalar_ssa.uses: + if isinstance(use.operation, memref.LoadOp): + load_op = use.operation + # NOTE: the operand of the load operation is not a memref anymore, we have dereferenced it, + # so we forward it. + load_op.res.replace_by(load_op.memref) + rewriter.erase_op(use.operation) + elif isinstance(use.operation, memref.StoreOp): + rewriter.erase_op(use.operation) + + @staticmethod + def deref_memref_memops(memref_ssa: SSAValue, rewriter: PatternRewriter): + for use in memref_ssa.uses: + if isinstance(use.operation, memref.LoadOp): + # The first load was used to load the pointer to the array. The index to retrieve an element from the array is applied + # in the next load. Since we have dereferenced the first pointer, we need to end up with a single load that accesses + # the array directly. + ptr_load_op = use.operation + + # FIXME: this is assuming each dereferencing load only has one use + for ptr_use in ptr_load_op.res.uses: + if isinstance(ptr_use.operation, memref.LoadOp): + array_load_op = ptr_use.operation + array_idx = array_load_op.indices + + new_load_op = memref.LoadOp.get(memref_ssa, array_idx) + array_load_op.res.replace_by(ptr_load_op.res) + rewriter.replace_op(ptr_load_op, new_load_op) + rewriter.erase_op(array_load_op) + + elif isinstance(ptr_use.operation, memref.StoreOp): + array_store_op = ptr_use.operation + array_idx = array_store_op.indices + new_store_op = memref.StoreOp.get(array_store_op.value, memref_ssa, array_idx) + rewriter.insert_op(new_store_op, InsertPoint.before(ptr_load_op)) + rewriter.erase_op(array_store_op) + rewriter.erase_op(ptr_load_op) + + @staticmethod + def deref_args(func_op: func.FuncOp, rewriter : PatternRewriter): + """Dereference the arguments of a function operation.""" + new_input_types = [] + + for arg in func_op.body.block.args: + if isa(arg.type, builtin.MemRefType): + deref_type = arg.type.element_type + 
func_op.replace_argument_type(arg, deref_type, rewriter) + new_input_types.append(deref_type) + +class RemoveOps: + @staticmethod + def transform_omp_loop_nest_into_scf_for(loop_nest_op : omp.LoopNestOp, rewriter: PatternRewriter): + flat_loop_body = rewriter.move_region_contents_to_new_regions(loop_nest_op.body) + rewriter.replace_op(flat_loop_body.block.last_op, scf.YieldOp()) + flat_lb = arith.IndexCastOp(loop_nest_op.lowerBound[0], builtin.IndexType()) + flat_ub = arith.IndexCastOp(loop_nest_op.upperBound[0], builtin.IndexType()) + flat_step = arith.IndexCastOp(loop_nest_op.step[0], builtin.IndexType()) + rewriter.insert_op(flat_lb, InsertPoint.after(loop_nest_op.lowerBound[0].owner)) + rewriter.insert_op(flat_ub, InsertPoint.after(loop_nest_op.upperBound[0].owner)) + rewriter.insert_op(flat_step, InsertPoint.after(loop_nest_op.step[0].owner)) + rewriter.replace_value_with_new_type(flat_loop_body.block.args[0], builtin.IndexType()) + + # TODO: convert between i32 and index where appropriate, since omp.loop_nest operates with i32 and + # scf.for with index. 
+ for arg_use in flat_loop_body.block.args[0].uses: + if isinstance(arg_use.operation, memref.StoreOp): + store_op = arg_use.operation + + ## Replace the store ops first + if isinstance(store_op.memref.type.element_type, builtin.IntegerType): + index_to_i32 = arith.IndexCastOp(store_op.value, builtin.i32) + store_op.value.replace_by_if(index_to_i32.result, lambda use: isinstance(use.operation, memref.StoreOp)) + rewriter.insert_op(index_to_i32, InsertPoint.before(store_op)) + + else: + alloca_op = store_op.memref.owner + assert isinstance(alloca_op, memref.AllocaOp) + index_alloca = memref.AllocaOp.get(builtin.IndexType(), shape=alloca_op.memref.type.shape) + rewriter.replace_op(alloca_op, index_alloca) + + idx_memref = store_op.memref + for idx_memref_use in idx_memref.uses: + if isinstance(idx_memref_use.operation, memref.LoadOp): + load_op = idx_memref_use.operation + index_load = memref.LoadOp.get(load_op.memref, load_op.indices) + rewriter.replace_op(load_op, index_load) + + # Original type of the block arg of the loop nest op + cast_ind_var = arith.IndexCastOp(index_load.res, builtin.i32) + rewriter.insert_op(cast_ind_var, InsertPoint.after(index_load)) + index_load.res.replace_by_if(cast_ind_var.result, lambda use: use.operation != cast_ind_var) + + flat_loop = scf.ForOp(flat_lb, flat_ub, flat_step, (), flat_loop_body) + #flat_loop = scf.ForOp(loop_nest_op.lowerBound[0], omp_loop_op.upperBound[0], omp_loop_op.step[0], (), flat_loop_body) + rewriter.replace_op(loop_nest_op, flat_loop) + + return flat_loop + + + @staticmethod + def remove_omp_parallel(parallel_op: omp.ParallelOp, rewriter: PatternRewriter): + ws_loop = None + for op in parallel_op.walk(): + if isinstance(op, omp.WsLoopOp): + ws_loop = op + break + + assert ws_loop + print(ws_loop.body.block) + omp_loop_op = ws_loop.body.block.first_op + ws_loop_block = ws_loop.body.block + ws_loop.body.detach_block(ws_loop_block) + rewriter.inline_block(ws_loop_block, InsertPoint.before(ws_loop)) + 
rewriter.erase_op(ws_loop) + + if isinstance(omp_loop_op, omp.LoopNestOp): + flat_loop = RemoveOps.transform_omp_loop_nest_into_scf_for(omp_loop_op, rewriter) + one = arith.ConstantOp.from_int_and_width(1, 32) + pragma_pipeline = PragmaPipelineOp(one) + rewriter.insert_op([one, pragma_pipeline], InsertPoint.at_start(flat_loop.body.block)) + + parallel_block = parallel_op.region.block + parallel_op.region.detach_block(parallel_block) + rewriter.erase_op(parallel_block.last_op) + rewriter.inline_block(parallel_block, InsertPoint.before(parallel_op)) + rewriter.erase_op(parallel_op) + + @staticmethod + def remove_omp_simd(simd_op : omp.SimdOp, rewriter : PatternRewriter): + for priv_var in simd_op.private_vars: + arg_idx = simd_op.operands.index(priv_var) + simd_op.body.block.args[arg_idx].replace_by(priv_var) + + omp_loop_op = simd_op.body.block.first_op + + if isinstance(omp_loop_op, omp.LoopNestOp): + flat_loop = RemoveOps.transform_omp_loop_nest_into_scf_for(omp_loop_op, rewriter) + simd_factor = simd_op.simdlen.value.data + ssa_simd_factor = arith.ConstantOp.from_int_and_width(simd_factor, 32) + pragma_unroll = PragmaUnrollOp(ssa_simd_factor) + rewriter.insert_op([ssa_simd_factor, pragma_unroll], InsertPoint.at_start(flat_loop.body.block)) + else: + flat_loop = simd_op.body.block.first_op + + assert isinstance(flat_loop, scf.ForOp) + flat_loop.detach() + rewriter.insert_op(flat_loop, InsertPoint.before(simd_op)) + #rewriter.erase_matched_op() #FIXME: this does not work + rewriter.erase_op(simd_op) + + + @staticmethod + def remove_remaining_omp_ops(target_func: func.FuncOp, rewriter: PatternRewriter): + """Remove any remaining OpenMP operations in the target function.""" + for op in target_func.walk(): + if isinstance(op, omp.MapInfoOp): + rewriter.erase_op(op) + + for op in target_func.walk(): + if isinstance(op, omp.MapBoundsOp): + rewriter.erase_op(op) + + @staticmethod + def forward_map_info(map_info: omp.MapInfoOp, rewriter: PatternRewriter): + 
map_info.omp_ptr.replace_by(map_info.var_ptr) + + +@dataclass +class TargetFuncToHLS(RewritePattern): + target_funcs : list[func.FuncOp] = field(default_factory=list) + + @op_type_rewrite_pattern + def match_and_rewrite(self, module: builtin.ModuleOp, rewriter: PatternRewriter, /): + if "target" not in module.attributes: + return + + target_name = module.attributes["target"].data + target_func = [op for op in module.walk() if isinstance(op, func.FuncOp) and op.sym_name.data == target_name][0] + self.target_funcs.append(target_func) + + for map_info in target_func.walk(): + if not isinstance(map_info, omp.MapInfoOp): + continue + + RemoveOps.forward_map_info(map_info, rewriter) + + omp_parallel = None + for op in target_func.walk(): + if isinstance(op, omp.ParallelOp): + omp_parallel = op + break + + if omp_parallel: + RemoveOps.remove_omp_parallel(omp_parallel, rewriter) + + omp_simd = None + for op in target_func.walk(): + if isinstance(op, omp.SimdOp): + omp_simd = op + break + + if omp_simd: + RemoveOps.remove_omp_simd(omp_simd, rewriter) + + RemoveOps.remove_remaining_omp_ops(target_func, rewriter) + + +@dataclass(frozen=True) +class TargetToHLSPass(ModulePass): + """ + This is the entry point for the transformation pass which will then apply the rewriter + """ + name = 'target-to-hls' + + generate : str = "hls" + + def apply(self, ctx: Context, module: builtin.ModuleOp): + target_funcs : list[func.FuncOp] = [] + walker = PatternRewriteWalker(GreedyRewritePatternApplier([ + TargetFuncToHLS(target_funcs), + ]), apply_recursively=False, walk_reverse=True) + + walker.rewrite_module(module) + + if self.generate == "hls": + # Keep only the top level module to contain the function + for target_func in target_funcs: + target_func.detach() + + module_block = module.body.block + module.body.detach_block(module.body.block) + module_block.erase() + module.body.add_block(Block(target_funcs)) + module.attributes = {} \ No newline at end of file