diff --git a/ftn/tools/ftn_opt.py b/ftn/tools/ftn_opt.py index 931f854..38b968b 100755 --- a/ftn/tools/ftn_opt.py +++ b/ftn/tools/ftn_opt.py @@ -3,10 +3,16 @@ from xdsl.dialects.builtin import ModuleOp +from typing import IO + from ftn.transforms.rewrite_fir_to_core import RewriteFIRToCore from ftn.transforms.merge_memref_deref import MergeMemRefDeref +from ftn.transforms.extract_target import ExtractTarget +from ftn.transforms.fpga.target_to_hls import TargetToHLSPass from ftn.transforms.lower_omp_target_data import LowerOmpTargetDataPass -# from ftn.transforms.extract_target import ExtractTarget +from ftn.transforms.apply_target_config import ApplyTargetConfig +from ftn.transforms.omp_target_to_kernel import OmpTargetToKernelPass +from ftn.transforms.lift_omp_to_tensor import LiftOmpToTensorPass # from ftn.transforms.isolate_target import IsolateTarget # from psy.extract_stencil import ExtractStencil # from ftn.transforms.tenstorrent.convert_to_tt import ConvertToTT @@ -18,6 +24,7 @@ from xdsl.xdsl_opt_main import xDSLOptMain from ftn.dialects import ftn_relative_cf +from ftn.dialects import device import traceback @@ -27,13 +34,24 @@ def register_all_passes(self): super().register_all_passes() self.register_pass("rewrite-fir-to-core", lambda: RewriteFIRToCore) self.register_pass("merge-memref-deref", lambda: MergeMemRefDeref) + self.register_pass("extract-target", lambda: ExtractTarget) + self.register_pass("target-to-hls", lambda: TargetToHLSPass) self.register_pass("lower-omp-target-data", lambda: LowerOmpTargetDataPass) - # self.register_pass("extract-target", lambda: ExtractTarget) + self.register_pass("apply-target", lambda: ApplyTargetConfig) + self.register_pass("omp-target-to-kernel", lambda: OmpTargetToKernelPass) + self.register_pass("lift-omp-to-tensor", lambda: LiftOmpToTensorPass) # self.register_pass("isolate-target", lambda: IsolateTarget) # self.register_pass("convert-to-tt", lambda: ConvertToTT) def register_all_targets(self): + def _output_fpga_host(prog: ModuleOp, output: IO[str]): + from ftn.dialects.fpga.host_printer import HostPrinter + + printer = HostPrinter(stream=output) + printer.print(prog) + super().register_all_targets() + self.available_targets["fpga-host"] = _output_fpga_host def setup_pipeline(self): super().setup_pipeline() @@ -50,6 +68,7 @@ def register_all_arguments(self, arg_parser: argparse.ArgumentParser): def register_all_dialects(self): super().register_all_dialects() self.ctx.load_dialect(ftn_relative_cf.Ftn_relative_cf) + self.ctx.load_dialect(device.Device) @staticmethod def get_passes_as_dict() -> Dict[str, Callable[[ModuleOp], None]]: diff --git a/ftn/transforms/apply_target_config.py b/ftn/transforms/apply_target_config.py index d0cf79a..9319c4a 100644 --- a/ftn/transforms/apply_target_config.py +++ b/ftn/transforms/apply_target_config.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from abc import ABC, abstractmethod from xdsl.context import Context from xdsl.dialects import builtin, dlti @@ -7,16 +8,32 @@ from ftn.dialects import device -class TenstorrentConfiguration: - def get(): +class TargetConfiguration(ABC): + @classmethod + @abstractmethod + def get(cls) -> dlti.TargetDeviceSpecAttr: ... + + @classmethod + @abstractmethod + def _memory_subsystem(cls) -> dlti.MapAttr: ... + + @classmethod + @abstractmethod + def _compute_subsystem(cls) -> dlti.MapAttr: ... + + +class TenstorrentConfiguration(TargetConfiguration): + @classmethod + def get(cls): return dlti.TargetDeviceSpecAttr( { - "memory": TenstorrentConfiguration._memory_subsystem(), - "compute": TenstorrentConfiguration._compute_subsystem(), + "memory": cls._memory_subsystem(), + "compute": cls._compute_subsystem(), } ) - def _memory_subsystem(): + @classmethod + def _memory_subsystem(cls): config = { "DRAM": { "kind": device.MemoryKindAttr(device.MemoryKind.DDR), @@ -25,7 +42,8 @@ def _memory_subsystem(): } return dlti.MapAttr(config) - def _compute_subsystem(): + @classmethod + def _compute_subsystem(cls): config = { "architecture_type": device.ArchitectureKindAttr( device.ArchitectureKind.MANYCORE @@ -43,16 +61,18 @@ def _compute_subsystem(): return dlti.MapAttr(config) -class U280Configuration: - def get(): +class U280Configuration(TargetConfiguration): + @classmethod + def get(cls): return dlti.TargetDeviceSpecAttr( { - "memory": U280Configuration._memory_subsystem(), - "compute": U280Configuration._compute_subsystem(), + "memory": cls._memory_subsystem(), + "compute": cls._compute_subsystem(), } ) - def _memory_subsystem(): + @classmethod + def _memory_subsystem(cls): config = { "DRAM": { "kind": device.MemoryKindAttr(device.MemoryKind.DDR), @@ -66,7 +86,8 @@ def _memory_subsystem(): } return dlti.MapAttr(config) - def _compute_subsystem(): + @classmethod + def _compute_subsystem(cls): config = { "architecture_type": device.ArchitectureKindAttr( device.ArchitectureKind.FPGA @@ -108,14 +129,22 @@ def generate_system_config(self, accelerator_name, accelerator_config): } ) + def _get_config(self) -> dlti.TargetDeviceSpecAttr: + """ + Get the device spec for the current `self.taregt` + + If overriding this function, make sure to *not* specify `name` field again + """ + if config := SYSTEM_CONFIGURATIONS.get(self.target): + return config.get() + raise ValueError(f"No such target configuration {self.target}") + def apply(self, ctx: Context, op: builtin.ModuleOp) -> None: op.attributes["omp.target_triples"] = builtin.ArrayAttr( [builtin.StringAttr(self.target)] ) - assert self.target in SYSTEM_CONFIGURATIONS.keys() - - config = SYSTEM_CONFIGURATIONS[self.target].get() + config = self._get_config() op.attributes["dlti.target_system_spec"] = self.generate_system_config( self.target, config diff --git a/ftn/transforms/extract_target.py b/ftn/transforms/extract_target.py index a84eb8c..8324c88 100644 --- a/ftn/transforms/extract_target.py +++ b/ftn/transforms/extract_target.py @@ -1,119 +1,57 @@ -from abc import ABC -from typing import TypeVar, cast -from dataclasses import dataclass -import itertools -from xdsl.utils.hints import isa -from xdsl.dialects import memref, scf, omp -from xdsl.ir import Operation, SSAValue, OpResult, Attribute, MLContext, Block, Region +from dataclasses import dataclass, field +from xdsl.context import Context +from xdsl.ir import Operation, Block, Region from xdsl.pattern_rewriter import (RewritePattern, PatternRewriter, op_type_rewrite_pattern, PatternRewriteWalker, GreedyRewritePatternApplier) from xdsl.passes import ModulePass -from xdsl.dialects import builtin, func, llvm, arith -from ftn.util.visitor import Visitor +from xdsl.dialects import builtin, func +from xdsl.rewriter import InsertPoint +from ftn.dialects import device +@dataclass class RewriteTarget(RewritePattern): - def __init__(self): - self.target_ops=[] + module : builtin.ModuleOp + target_ops: list[Operation] = field(default_factory=list) @op_type_rewrite_pattern - def match_and_rewrite(self, op: omp.TargetOp, rewriter: PatternRewriter, /): - arg_types=[] - arg_ssa=[] - - loc_idx=0 - - locations={} - - memref_dim_ops=[] - # Grab bounds and info, then at end the terminator - for var in op.map_vars: - var_op=var.owner - var_op.parent.detach_op(var_op) - arg_types.append(var_op.var_ptr[0].type) - arg_ssa.append(var_op.var_ptr[0]) - if isa(var_op.var_ptr[0].type, builtin.MemRefType): - memref_type=var_op.var_ptr[0].type - src_memref=var_op.var_ptr[0] - if isa(memref_type.element_type, builtin.MemRefType): - assert len(memref_type.shape) == 0 - memref_type=var_op.var_ptr[0].type.element_type - memref_loadop=memref.Load.get(src_memref, []) - src_memref=memref_loadop.results[0] - memref_dim_ops.append(memref_loadop) - for idx, s in enumerate(memref_type.shape): - assert isa(s, builtin.IntAttr) - if (s.data == -1): - # Need to pass the dimension shape size in explicitly as it is deferred - const_op=arith.Constant.from_int_and_width(idx, builtin.IndexType()) - dim_size=memref.Dim.from_source_and_index(src_memref, const_op) - memref_dim_ops+=[const_op, dim_size] - arg_ssa.append(dim_size.results[0]) - arg_types.append(dim_size.results[0].type) - - locations[var_op]=loc_idx - loc_idx+=1 - if len(var_op.bounds) > 0: - bound_op=var_op.bounds[0].owner - bound_op.parent.detach_op(bound_op) - #self.target_ops+=[bound_op, var_op] - arg_types.append(bound_op.lower[0].type) - arg_ssa.append(bound_op.lower[0]) - locations[bound_op]=loc_idx - # Add two, as second is the size - loc_idx+=2 - else: - pass#self.target_ops+=[var_op] - - new_block = Block(arg_types=arg_types) - - new_mapinfo_ssa=[] - for var in op.map_vars: - var_op=var.owner - - map_bounds=[] - if len(var_op.bounds) > 0: - bound_op=var_op.bounds[0].owner - res_types=[] - for res in bound_op.results: res_types.append(res.type) - new_bounds_op=omp.BoundsOp.build(operands=[[new_block.args[locations[bound_op]]], [], [], [], []], - properties={"stride_in_bytes": bound_op.stride_in_bytes}, - result_types=res_types) - - new_block.add_op(new_bounds_op) - map_bounds=[new_bounds_op.results[0]] - - res_types=[] - for res in var_op.results: res_types.append(res.type) - mapinfo_op=omp.MapInfoOp.build(operands=[[new_block.args[locations[var_op]]], [], map_bounds], - properties={"map_type": var_op.map_type, "var_name": var_op.var_name, "var_type": var_op.var_type}, - result_types=res_types) - new_mapinfo_ssa.append(mapinfo_op.results[0]) - - new_block.add_op(mapinfo_op) - - reg=op.region - op.detach_region(reg) - - new_omp_target_op=omp.TargetOp.build(operands=[[],[],[], new_mapinfo_ssa], regions=[reg]) - new_block.add_op(new_omp_target_op) - new_block.add_op(func.Return()) - - new_fn_type=builtin.FunctionType.from_lists(arg_types, []) - - body = Region() - body.add_block(new_block) - - new_func=func.FuncOp("tt_device", new_fn_type, body) - - self.target_ops=[new_func] - - call_fn=func.Call.create(properties={"callee": builtin.SymbolRefAttr("tt_device")}, operands=arg_ssa, result_types=[]) - op.parent.insert_ops_before(memref_dim_ops+[call_fn], op) - - op.parent.detach_op(op) + def match_and_rewrite(self, op: device.KernelCreate, rewriter: PatternRewriter, /): + arg_types = [] + for var in op.mapped_data: + assert isinstance(var.type, builtin.MemRefType) + var_type = var.type + arg_types.append(var_type) + + assert op.body + dev_func_body = rewriter.move_region_contents_to_new_regions(op.body) + dev_func = func.FuncOp.from_region( + "tt_device", + arg_types, + [], + dev_func_body + ) + + ## Fix type of the block arguments (there is a mismatch between mapped_data and block args) + n_args = len(dev_func_body.block.args) + for block_arg, arg_type in zip(dev_func_body.block.args, arg_types): + new_block_arg = dev_func_body.block.insert_arg(arg_type, len(dev_func_body.block.args)) + block_arg.replace_by(new_block_arg) + + for arg in dev_func_body.block.args[:n_args]: + rewriter.erase_block_argument(arg) + + # kernel_create cannot have both a pointer to a device_function and a body. + op.device_function = builtin.SymbolRefAttr(dev_func.sym_name) + + assert dev_func_body.block.last_op is not None, "The last operation in the device function block must not be None" + rewriter.erase_op(dev_func_body.block.last_op) + + rewriter.insert_op(func.ReturnOp(), InsertPoint.at_end(dev_func_body.block)) + + self.target_ops = [dev_func] + @dataclass(frozen=True) class ExtractTarget(ModulePass): @@ -122,15 +60,17 @@ class ExtractTarget(ModulePass): """ name = 'extract-target' - def apply(self, ctx: MLContext, module: builtin.ModuleOp): - rw_target= RewriteTarget() + def apply(self, ctx: Context, module: builtin.ModuleOp): + rw_target= RewriteTarget(module) walker = PatternRewriteWalker(GreedyRewritePatternApplier([ rw_target, ]), apply_recursively=False, walk_reverse=True) walker.rewrite_module(module) - containing_mod=builtin.ModuleOp([]) + # NOTE: The region recieving the block must be empty. Otherwise, the single block region rule of + # the module will not be satisfied. + containing_mod=builtin.ModuleOp(Region()) module.regions[0].move_blocks(containing_mod.regions[0]) new_module=builtin.ModuleOp(rw_target.target_ops, {"target": builtin.StringAttr("tt_device")}) diff --git a/ftn/transforms/fpga/target_to_hls.py b/ftn/transforms/fpga/target_to_hls.py new file mode 100644 index 0000000..81f1a2c --- /dev/null +++ b/ftn/transforms/fpga/target_to_hls.py @@ -0,0 +1,252 @@ +from dataclasses import dataclass, field +from xdsl.utils.hints import isa +from xdsl.dialects import memref, scf, omp +from xdsl.context import Context +from xdsl.ir import SSAValue, Block + +from xdsl.pattern_rewriter import (RewritePattern, PatternRewriter, + op_type_rewrite_pattern, + PatternRewriteWalker, + GreedyRewritePatternApplier) +from xdsl.passes import ModulePass +from xdsl.dialects import builtin, func, arith +from xdsl.rewriter import InsertPoint +from xdsl.dialects.experimental.hls import PragmaPipelineOp, PragmaUnrollOp + +class DerefMemrefs: + @staticmethod + def deref_scalar_memops(scalar_ssa: SSAValue, rewriter: PatternRewriter): + for use in scalar_ssa.uses: + if isinstance(use.operation, memref.LoadOp): + load_op = use.operation + # NOTE: the operand of the load operation is not a memref anymore, we have dereferenced it, + # so we forward it. + load_op.res.replace_by(load_op.memref) + rewriter.erase_op(use.operation) + elif isinstance(use.operation, memref.StoreOp): + rewriter.erase_op(use.operation) + + @staticmethod + def deref_memref_memops(memref_ssa: SSAValue, rewriter: PatternRewriter): + for use in memref_ssa.uses: + if isinstance(use.operation, memref.LoadOp): + # The first load was used to load the pointer to the array. The index to retrieve an element from the array is applied + # in the next load. Since we have dereferenced the first pointer, we need to end up with a single load that accesses + # the array directly. + ptr_load_op = use.operation + + # FIXME: this is assuming each dereferencing load only has one use + for ptr_use in ptr_load_op.res.uses: + if isinstance(ptr_use.operation, memref.LoadOp): + array_load_op = ptr_use.operation + array_idx = array_load_op.indices + + new_load_op = memref.LoadOp.get(memref_ssa, array_idx) + array_load_op.res.replace_by(ptr_load_op.res) + rewriter.replace_op(ptr_load_op, new_load_op) + rewriter.erase_op(array_load_op) + + elif isinstance(ptr_use.operation, memref.StoreOp): + array_store_op = ptr_use.operation + array_idx = array_store_op.indices + new_store_op = memref.StoreOp.get(array_store_op.value, memref_ssa, array_idx) + rewriter.insert_op(new_store_op, InsertPoint.before(ptr_load_op)) + rewriter.erase_op(array_store_op) + rewriter.erase_op(ptr_load_op) + + @staticmethod + def deref_args(func_op: func.FuncOp, rewriter : PatternRewriter): + """Dereference the arguments of a function operation.""" + new_input_types = [] + + for arg in func_op.body.block.args: + if isa(arg.type, builtin.MemRefType): + deref_type = arg.type.element_type + func_op.replace_argument_type(arg, deref_type, rewriter) + new_input_types.append(deref_type) + +class RemoveOps: + @staticmethod + def transform_omp_loop_nest_into_scf_for(loop_nest_op : omp.LoopNestOp, rewriter: PatternRewriter): + flat_loop_body = rewriter.move_region_contents_to_new_regions(loop_nest_op.body) + rewriter.replace_op(flat_loop_body.block.last_op, scf.YieldOp()) + flat_lb = arith.IndexCastOp(loop_nest_op.lowerBound[0], builtin.IndexType()) + flat_ub = arith.IndexCastOp(loop_nest_op.upperBound[0], builtin.IndexType()) + flat_step = arith.IndexCastOp(loop_nest_op.step[0], builtin.IndexType()) + rewriter.insert_op(flat_lb, InsertPoint.after(loop_nest_op.lowerBound[0].owner)) + rewriter.insert_op(flat_ub, InsertPoint.after(loop_nest_op.upperBound[0].owner)) + rewriter.insert_op(flat_step, InsertPoint.after(loop_nest_op.step[0].owner)) + rewriter.replace_value_with_new_type(flat_loop_body.block.args[0], builtin.IndexType()) + + # TODO: convert between i32 and index where appropriate, since omp.loop_nest operates with i32 and + # scf.for with index. + for arg_use in flat_loop_body.block.args[0].uses: + if isinstance(arg_use.operation, memref.StoreOp): + store_op = arg_use.operation + + ## Replace the store ops first + if isinstance(store_op.memref.type.element_type, builtin.IntegerType): + index_to_i32 = arith.IndexCastOp(store_op.value, builtin.i32) + store_op.value.replace_by_if(index_to_i32.result, lambda use: isinstance(use.operation, memref.StoreOp)) + rewriter.insert_op(index_to_i32, InsertPoint.before(store_op)) + + else: + alloca_op = store_op.memref.owner + assert isinstance(alloca_op, memref.AllocaOp) + index_alloca = memref.AllocaOp.get(builtin.IndexType(), shape=alloca_op.memref.type.shape) + rewriter.replace_op(alloca_op, index_alloca) + + idx_memref = store_op.memref + for idx_memref_use in idx_memref.uses: + if isinstance(idx_memref_use.operation, memref.LoadOp): + load_op = idx_memref_use.operation + index_load = memref.LoadOp.get(load_op.memref, load_op.indices) + rewriter.replace_op(load_op, index_load) + + # Original type of the block arg of the loop nest op + cast_ind_var = arith.IndexCastOp(index_load.res, builtin.i32) + rewriter.insert_op(cast_ind_var, InsertPoint.after(index_load)) + index_load.res.replace_by_if(cast_ind_var.result, lambda use: use.operation != cast_ind_var) + + flat_loop = scf.ForOp(flat_lb, flat_ub, flat_step, (), flat_loop_body) + #flat_loop = scf.ForOp(loop_nest_op.lowerBound[0], omp_loop_op.upperBound[0], omp_loop_op.step[0], (), flat_loop_body) + rewriter.replace_op(loop_nest_op, flat_loop) + + return flat_loop + + + @staticmethod + def remove_omp_parallel(parallel_op: omp.ParallelOp, rewriter: PatternRewriter): + ws_loop = None + for op in parallel_op.walk(): + if isinstance(op, omp.WsLoopOp): + ws_loop = op + break + + assert ws_loop + print(ws_loop.body.block) + omp_loop_op = ws_loop.body.block.first_op + ws_loop_block = ws_loop.body.block + ws_loop.body.detach_block(ws_loop_block) + rewriter.inline_block(ws_loop_block, InsertPoint.before(ws_loop)) + rewriter.erase_op(ws_loop) + + if isinstance(omp_loop_op, omp.LoopNestOp): + flat_loop = RemoveOps.transform_omp_loop_nest_into_scf_for(omp_loop_op, rewriter) + one = arith.ConstantOp.from_int_and_width(1, 32) + pragma_pipeline = PragmaPipelineOp(one) + rewriter.insert_op([one, pragma_pipeline], InsertPoint.at_start(flat_loop.body.block)) + + parallel_block = parallel_op.region.block + parallel_op.region.detach_block(parallel_block) + rewriter.erase_op(parallel_block.last_op) + rewriter.inline_block(parallel_block, InsertPoint.before(parallel_op)) + rewriter.erase_op(parallel_op) + + @staticmethod + def remove_omp_simd(simd_op : omp.SimdOp, rewriter : PatternRewriter): + for priv_var in simd_op.private_vars: + arg_idx = simd_op.operands.index(priv_var) + simd_op.body.block.args[arg_idx].replace_by(priv_var) + + omp_loop_op = simd_op.body.block.first_op + + if isinstance(omp_loop_op, omp.LoopNestOp): + flat_loop = RemoveOps.transform_omp_loop_nest_into_scf_for(omp_loop_op, rewriter) + simd_factor = simd_op.simdlen.value.data + ssa_simd_factor = arith.ConstantOp.from_int_and_width(simd_factor, 32) + pragma_unroll = PragmaUnrollOp(ssa_simd_factor) + rewriter.insert_op([ssa_simd_factor, pragma_unroll], InsertPoint.at_start(flat_loop.body.block)) + else: + flat_loop = simd_op.body.block.first_op + + assert isinstance(flat_loop, scf.ForOp) + flat_loop.detach() + rewriter.insert_op(flat_loop, InsertPoint.before(simd_op)) + #rewriter.erase_matched_op() #FIXME: this does not work + rewriter.erase_op(simd_op) + + + @staticmethod + def remove_remaining_omp_ops(target_func: func.FuncOp, rewriter: PatternRewriter): + """Remove any remaining OpenMP operations in the target function.""" + for op in target_func.walk(): + if isinstance(op, omp.MapInfoOp): + rewriter.erase_op(op) + + for op in target_func.walk(): + if isinstance(op, omp.MapBoundsOp): + rewriter.erase_op(op) + + @staticmethod + def forward_map_info(map_info: omp.MapInfoOp, rewriter: PatternRewriter): + map_info.omp_ptr.replace_by(map_info.var_ptr) + + +@dataclass +class TargetFuncToHLS(RewritePattern): + target_funcs : list[func.FuncOp] = field(default_factory=list) + + @op_type_rewrite_pattern + def match_and_rewrite(self, module: builtin.ModuleOp, rewriter: PatternRewriter, /): + if "target" not in module.attributes: + return + + target_name = module.attributes["target"].data + target_func = [op for op in module.walk() if isinstance(op, func.FuncOp) and op.sym_name.data == target_name][0] + self.target_funcs.append(target_func) + + for map_info in target_func.walk(): + if not isinstance(map_info, omp.MapInfoOp): + continue + + RemoveOps.forward_map_info(map_info, rewriter) + + omp_parallel = None + for op in target_func.walk(): + if isinstance(op, omp.ParallelOp): + omp_parallel = op + break + + if omp_parallel: + RemoveOps.remove_omp_parallel(omp_parallel, rewriter) + + omp_simd = None + for op in target_func.walk(): + if isinstance(op, omp.SimdOp): + omp_simd = op + break + + if omp_simd: + RemoveOps.remove_omp_simd(omp_simd, rewriter) + + RemoveOps.remove_remaining_omp_ops(target_func, rewriter) + + +@dataclass(frozen=True) +class TargetToHLSPass(ModulePass): + """ + This is the entry point for the transformation pass which will then apply the rewriter + """ + name = 'target-to-hls' + + generate : str = "hls" + + def apply(self, ctx: Context, module: builtin.ModuleOp): + target_funcs : list[func.FuncOp] = [] + walker = PatternRewriteWalker(GreedyRewritePatternApplier([ + TargetFuncToHLS(target_funcs), + ]), apply_recursively=False, walk_reverse=True) + + walker.rewrite_module(module) + + if self.generate == "hls": + # Keep only the top level module to contain the function + for target_func in target_funcs: + target_func.detach() + + module_block = module.body.block + module.body.detach_block(module.body.block) + module_block.erase() + module.body.add_block(Block(target_funcs)) + module.attributes = {} \ No newline at end of file diff --git a/ftn/transforms/lift_omp_to_tensor.py b/ftn/transforms/lift_omp_to_tensor.py index 46e0092..7a2e073 100644 --- a/ftn/transforms/lift_omp_to_tensor.py +++ b/ftn/transforms/lift_omp_to_tensor.py @@ -497,7 +497,7 @@ def lift_op( tensor_sizes.append(int((upper_const - (lower_const - 1)) / step_const)) else: # Otherwise this dimension size is dynamic - tensor_sizes.append(-1) + tensor_sizes.append(builtin.DYNAMIC_INDEX) # Create dependency tree walker, this is passed the private (intermediate) # memrefs, and SSA of the device mapped data diff --git a/ftn/transforms/lower_omp_target_data.py b/ftn/transforms/lower_omp_target_data.py index 130c95c..6c1453e 100644 --- a/ftn/transforms/lower_omp_target_data.py +++ b/ftn/transforms/lower_omp_target_data.py @@ -1,10 +1,11 @@ +from collections.abc import Sequence from dataclasses import dataclass from enum import Enum from xdsl.builder import Builder from xdsl.context import Context from xdsl.dialects import arith, builtin, memref, omp, scf -from xdsl.ir import BlockArgument +from xdsl.ir import Block, BlockArgument, Operation, Region, SSAValue from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( GreedyRewritePatternApplier, @@ -24,6 +25,8 @@ class DataEnvironmentDirection(Enum): BOTH = 3 +_DO_NOT_WAIT = 1<<64 + class DataMovementGenerator: def collect_mapped_vars_by_stack_and_heap(mapped_vars): sorted_mapped_vars = [] @@ -291,7 +294,7 @@ def generate_conditional_on_data_exists( """ data_exists_op = device.DataCheckExists(var_name, memory_space) - ops = [data_exists_op] + ops: list[Operation] = [data_exists_op] if is_not_conditional: const_op = arith.ConstantOp.from_int_and_width(1, 1) ex_io_op = arith.XOrIOp(const_op, data_exists_op, builtin.i1) @@ -300,6 +303,12 @@ def generate_conditional_on_data_exists( else: condition_ssa = data_exists_op + if false_region is None: + @Builder.implicit_region([]) + def false_region(args: tuple[BlockArgument, ...]) -> None: + dummy_tag = memref.AllocaOp.get(builtin.i32, shape=[]) + do_not_wait = arith.ConstantOp(builtin.IntegerAttr.from_index_int_value(_DO_NOT_WAIT)) + scf.YieldOp(dummy_tag, do_not_wait) cond = scf.IfOp( condition_ssa, conditional_return_type, true_region, false_region ) @@ -320,7 +329,7 @@ def generate_allocate_on_device(var_type, var_name, memory_space, size_ssas): """ dynamic_ssas = [] for idx, shape in enumerate(var_type.shape): - if shape.data == -1: + if shape.data == builtin.DYNAMIC_INDEX: dynamic_ssas.append(size_ssas[idx]) return device.AllocOp( @@ -427,15 +436,20 @@ def generate_copy_from_device(var_name, memory_space, var_type, dest): ) return tag_ssa, [device_memref] + ops_list - def generate_dma_waits_for_tags(wait_ssas_list): + @staticmethod + def generate_dma_waits_for_tags(wait_ssas_list: Sequence[tuple[SSAValue|Operation, SSAValue|Operation]]): """ Generates the DMA wait operations based upon the provided wait list, each entry in the wait list is a tuple (wait tag, number elements). """ - ops_list = [] + ops_list: list[Operation] = [] for tag, num_els in wait_ssas_list: - wait_op = memref.DmaWaitOp.get(tag, [], num_els) - ops_list.append(wait_op) + do_not_wait = arith.ConstantOp(builtin.IntegerAttr.from_index_int_value(_DO_NOT_WAIT)) + should_wait = arith.CmpiOp(num_els, do_not_wait, "ne") + if_op = scf.IfOp(should_wait, [], Region(Block( + [memref.DmaWaitOp.get(tag, [], num_els), scf.YieldOp()] + ))) + ops_list.extend([do_not_wait, should_wait, if_op]) return ops_list diff --git a/ftn/transforms/tenstorrent/convert_to_tt.py b/ftn/transforms/tenstorrent/convert_to_tt.py index d94fbbd..0bdb595 100644 --- a/ftn/transforms/tenstorrent/convert_to_tt.py +++ b/ftn/transforms/tenstorrent/convert_to_tt.py @@ -497,7 +497,7 @@ def generate_data_in(self, module, memory_type, references, cb_idxs, new_block, mem_size_bytes_op=arith.Muli(dt_width_conversion_op, new_block.args[(len(memory_type)*2)+idx]) read_op=data_movement.DMNocAsyncRead(dm_op.results[0], cb_op.results[0], mem_size_bytes_op) - target_memref=builtin.MemRefType(element_type, [-1]) + target_memref=builtin.MemRefType(element_type, [builtin.DYNAMIC_INDEX]) conversion_op=builtin.UnrealizedConversionCastOp.get([cb_op.results[0]], [target_memref]) conversion_op.results[0].name_hint = f"src{idx}_data" diff --git a/ftn/transforms/to_core/components/ftn_types.py b/ftn/transforms/to_core/components/ftn_types.py index 876667a..c211ddd 100644 --- a/ftn/transforms/to_core/components/ftn_types.py +++ b/ftn/transforms/to_core/components/ftn_types.py @@ -25,7 +25,7 @@ def compare_memrefs(memref_a, memref_b): return MemrefComparison.INCOMPATIBLE for dim_size_a, dim_size_b in zip(memref_a.shape, memref_b.shape): if dim_size_a.data != dim_size_b.data and ( - dim_size_a.data == -1 or dim_size_b.data == -1 + dim_size_a.data == builtin.DYNAMIC_INDEX or dim_size_b.data == builtin.DYNAMIC_INDEX ): return MemrefComparison.CONVERTABLE return MemrefComparison.SAME @@ -111,7 +111,7 @@ def convert_fir_type_to_standard(fir_type, ref_as_mem_ref=True): if isa(shape_el, builtin.IntegerAttr): dim_sizes.append(shape_el.value.data) else: - dim_sizes.append(-1) + dim_sizes.append(builtin.DYNAMIC_INDEX) # Reverse the sizes to go from Fortran to C allocation semantics dim_sizes.reverse() return builtin.MemRefType( @@ -229,7 +229,7 @@ def translate_convert(program_state: ProgramState, ctx: SSAValueCtx, op: fir.Con shape_size = [] for s in out_type.type.shape.data: if isa(s, fir.DeferredAttr): - shape_size.append(-1) + shape_size.append(builtin.DYNAMIC_INDEX) else: shape_size.append(s.value.data) # Reverse shape_size to get it from Fortran allocation to C/MLIR allocation diff --git a/ftn/transforms/to_core/components/intrinsics.py b/ftn/transforms/to_core/components/intrinsics.py index 5725281..8449863 100644 --- a/ftn/transforms/to_core/components/intrinsics.py +++ b/ftn/transforms/to_core/components/intrinsics.py @@ -19,12 +19,12 @@ def handle_create_temporary_linalg_output_memref( result_type, element_type, input_ssas, input_dims_to_read ): output_shape = [ - -1 if isa(s, fir.DeferredAttr) else s.value for s in result_type.shape + builtin.DYNAMIC_INDEX if isa(s, fir.DeferredAttr) else s.value for s in result_type.shape ] ops_list = [] dynamic_sizes = [] - if -1 in output_shape: + if builtin.DYNAMIC_INDEX in output_shape: # If we have deferred sizes then grab the output sizes from the input array sizes # Ensure all elements are -1 assert len(set(output_shape)) == 1 @@ -177,9 +177,9 @@ def handle_reduction_operation( memref_shape = [] memref_dynamic_sizes = [] elif len(reduction_dimensions) == 1: - if -1 in input_array_shape: + if builtin.DYNAMIC_INDEX in input_array_shape: assert len(set(input_array_shape)) == 1 - memref_shape = [-1] * (len(array_load_ssa.type.shape) - 1) + memref_shape = [builtin.DYNAMIC_INDEX] * (len(array_load_ssa.type.shape) - 1) memref_dynamic_sizes = [] if len(array_load_ssa.type.shape) > 1: for dim in list(range(len(array_load_ssa.type.shape))): diff --git a/ftn/transforms/to_core/components/load_store.py b/ftn/transforms/to_core/components/load_store.py index c05768e..8149148 100644 --- a/ftn/transforms/to_core/components/load_store.py +++ b/ftn/transforms/to_core/components/load_store.py @@ -410,13 +410,13 @@ def generate_allocatable_array_allocate( assert len(dim_sizes) == len(dim_starts) == len(dim_ends) - # Now create memref, passing -1 as shape will make this deferred size + # Now create memref, passing DYNAMIC_INDEX as shape will make this deferred size # Reverse the indicies as Fortran and C/MLIR are opposite in terms of # the order of the contiguous dimension (F is least, whereas C/MLIR is highest) dim_ssa_reversed = dim_ssas.copy() dim_ssa_reversed.reverse() memref_allocation_op = memref_alloca_op = memref.AllocOp.get( - base_type, shape=[-1] * len(dim_ssas), dynamic_sizes=dim_ssa_reversed + base_type, shape=[builtin.DYNAMIC_INDEX] * len(dim_ssas), dynamic_sizes=dim_ssa_reversed ) ops_list.append(memref_allocation_op) @@ -462,13 +462,13 @@ def handle_pointer_assignment( source_ssa = ctx[source_op] ops = [] - if any(i.data != -1 for i in ctx[source_op].type.shape.data): + if any(i.data != builtin.DYNAMIC_INDEX for i in ctx[source_op].type.shape.data): # The source type has explicit dimension sizes, by definition a pointer must be unknown # dimension sizes so we need to convert num_dims = len(ctx[source_op].type.shape.data) cast_op = memref.CastOp.get( source_ssa, - builtin.MemRefType(source_ssa.type.element_type, shape=num_dims * [-1]), + builtin.MemRefType(source_ssa.type.element_type, shape=num_dims * [builtin.DYNAMIC_INDEX]), ) source_ssa = cast_op.results[0] ops.append(cast_op) diff --git a/ftn/transforms/to_core/components/memory.py b/ftn/transforms/to_core/components/memory.py index d439ae6..7c8d000 100644 --- a/ftn/transforms/to_core/components/memory.py +++ b/ftn/transforms/to_core/components/memory.py @@ -423,7 +423,7 @@ def translate_declare( else: alloc_memref_container = memref.AllocaOp.get( builtin.MemRefType( - op.results[0].type.type.type.type.type, shape=num_dims * [-1] + op.results[0].type.type.type.type.type, shape=num_dims * [builtin.DYNAMIC_INDEX] ), shape=[], ) @@ -720,14 +720,14 @@ def translate_elemental(program_state, ctx, op: hlfir.ElementalOp): sizes.reverse() memref_shape = [ - -1 if isa(f, fir.DeferredAttr) else f.value.data + builtin.DYNAMIC_INDEX if isa(f, fir.DeferredAttr) else f.value.data for f in op.results[0].type.shape ] dynamic_sizes = [] for idx, s in enumerate(memref_shape): - if s == -1: + if s == builtin.DYNAMIC_INDEX: dynamic_sizes.append(ctx[sizes[idx]]) memref_alloca_op = memref.AllocOp(