Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
b6b1dd2
Modernise extract-target pass. Note how the location of the bounds wa…
gabrielrodcanal Jul 31, 2025
a62cba6
Add the signature of the offloaded function to the module calling the…
gabrielrodcanal Aug 1, 2025
cf99ae9
Fix location index. Now all the examples build
gabrielrodcanal Aug 1, 2025
b1c2ac1
Add upper bound to the offload function arguments. Otherwise the gene…
gabrielrodcanal Aug 1, 2025
c6c360f
Add pass that converts target function to HLS compatible format. Func…
gabrielrodcanal Aug 1, 2025
a7cf606
More modular design to remove the target operation
gabrielrodcanal Aug 1, 2025
a56dd6a
Keep only the top-level module to contain the HLS function - necessar…
gabrielrodcanal Aug 1, 2025
865e2b4
Merge branch 'main' into gabriel/target_to_hls
mesham Aug 3, 2025
58a649c
Dereference in the HLS pass not necessary anymore, as it is processed…
gabrielrodcanal Aug 4, 2025
c13fc06
Add scripts to generate LLVM IR compatible with Vitis HLS
gabrielrodcanal Aug 4, 2025
10b1893
Add option to trigger the generation of LLVM IR for Vitis HLS. This o…
gabrielrodcanal Aug 4, 2025
fa1bfb9
Add support for omp.parallel as a pipelined loop. Tested on ex4.F90
gabrielrodcanal Aug 5, 2025
fea7ad2
Add support for SIMD directive as an unrolled loop
gabrielrodcanal Aug 5, 2025
69c46d9
Merge branch 'main' into gabriel/target_to_hls
gabrielrodcanal Aug 6, 2025
0962050
Add option to print to target fpga-host
gabrielrodcanal Aug 6, 2025
ebc67ff
Register unregistered passes
gabrielrodcanal Aug 6, 2025
73291b2
extract-target must be now launched after lower-omp-target-data, sinc…
gabrielrodcanal Aug 6, 2025
e164950
extract-target now extracts the body of the kernel_create operation i…
gabrielrodcanal Aug 7, 2025
5271ff9
Simplify loop
gabrielrodcanal Aug 7, 2025
9efc806
Remove the removal of the omp.target operation in the target-to-hls p…
gabrielrodcanal Aug 7, 2025
ce0e724
Remove everything FPGA-specific: xftn options and bash files
gabrielrodcanal Aug 7, 2025
fdf5b70
Merge branch 'main' into gabriel/new_extract_target
gabrielrodcanal Aug 7, 2025
5889bf1
Remove function call in favour of device.kernel_create operation with…
gabrielrodcanal Aug 7, 2025
10b2aaf
Merge main into gabriel/new_extract_target
gabrielrodcanal Aug 8, 2025
1568041
Merge branch 'main' into gabriel/new_extract_target
gabrielrodcanal Aug 8, 2025
7fb953f
A device.create_kernel operation pointing to a device function lacks …
gabrielrodcanal Aug 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions ftn/tools/ftn_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@

from xdsl.dialects.builtin import ModuleOp

from typing import IO

from ftn.transforms.rewrite_fir_to_core import RewriteFIRToCore
from ftn.transforms.merge_memref_deref import MergeMemRefDeref
from ftn.transforms.extract_target import ExtractTarget
from ftn.transforms.fpga.target_to_hls import TargetToHLSPass
from ftn.transforms.lower_omp_target_data import LowerOmpTargetDataPass
# from ftn.transforms.extract_target import ExtractTarget
from ftn.transforms.apply_target_config import ApplyTargetConfig
from ftn.transforms.omp_target_to_kernel import OmpTargetToKernelPass
# from ftn.transforms.isolate_target import IsolateTarget
# from psy.extract_stencil import ExtractStencil
# from ftn.transforms.tenstorrent.convert_to_tt import ConvertToTT
Expand All @@ -18,6 +23,7 @@

from xdsl.xdsl_opt_main import xDSLOptMain
from ftn.dialects import ftn_relative_cf
from ftn.dialects import device

import traceback

Expand All @@ -27,13 +33,23 @@ def register_all_passes(self):
super().register_all_passes()
self.register_pass("rewrite-fir-to-core", lambda: RewriteFIRToCore)
self.register_pass("merge-memref-deref", lambda: MergeMemRefDeref)
self.register_pass("extract-target", lambda: ExtractTarget)
self.register_pass("target-to-hls", lambda: TargetToHLSPass)
self.register_pass("lower-omp-target-data", lambda: LowerOmpTargetDataPass)
# self.register_pass("extract-target", lambda: ExtractTarget)
self.register_pass("apply-target", lambda: ApplyTargetConfig)
self.register_pass("omp-target-to-kernel", lambda: OmpTargetToKernelPass)
# self.register_pass("isolate-target", lambda: IsolateTarget)
# self.register_pass("convert-to-tt", lambda: ConvertToTT)

def register_all_targets(self):
def _output_fpga_host(prog: ModuleOp, output: IO[str]):
from ftn.dialects.fpga.host_printer import HostPrinter

printer = HostPrinter(stream=output)
printer.print(prog)

super().register_all_targets()
self.available_targets["fpga-host"] = _output_fpga_host

def setup_pipeline(self):
super().setup_pipeline()
Expand All @@ -50,6 +66,7 @@ def register_all_arguments(self, arg_parser: argparse.ArgumentParser):
def register_all_dialects(self):
super().register_all_dialects()
self.ctx.load_dialect(ftn_relative_cf.Ftn_relative_cf)
self.ctx.load_dialect(device.Device)

@staticmethod
def get_passes_as_dict() -> Dict[str, Callable[[ModuleOp], None]]:
Expand Down
158 changes: 49 additions & 109 deletions ftn/transforms/extract_target.py
Original file line number Diff line number Diff line change
@@ -1,119 +1,57 @@
from abc import ABC
from typing import TypeVar, cast
from dataclasses import dataclass
import itertools
from xdsl.utils.hints import isa
from xdsl.dialects import memref, scf, omp
from xdsl.ir import Operation, SSAValue, OpResult, Attribute, MLContext, Block, Region
from dataclasses import dataclass, field
from xdsl.context import Context
from xdsl.ir import Operation, Block, Region

from xdsl.pattern_rewriter import (RewritePattern, PatternRewriter,
op_type_rewrite_pattern,
PatternRewriteWalker,
GreedyRewritePatternApplier)
from xdsl.passes import ModulePass
from xdsl.dialects import builtin, func, llvm, arith
from ftn.util.visitor import Visitor
from xdsl.dialects import builtin, func
from xdsl.rewriter import InsertPoint
from ftn.dialects import device

@dataclass
class RewriteTarget(RewritePattern):
def __init__(self):
self.target_ops=[]
module : builtin.ModuleOp
target_ops: list[Operation] = field(default_factory=list)

@op_type_rewrite_pattern
def match_and_rewrite(self, op: omp.TargetOp, rewriter: PatternRewriter, /):
arg_types=[]
arg_ssa=[]

loc_idx=0

locations={}

memref_dim_ops=[]
# Grab bounds and info, then at end the terminator
for var in op.map_vars:
var_op=var.owner
var_op.parent.detach_op(var_op)
arg_types.append(var_op.var_ptr[0].type)
arg_ssa.append(var_op.var_ptr[0])
if isa(var_op.var_ptr[0].type, builtin.MemRefType):
memref_type=var_op.var_ptr[0].type
src_memref=var_op.var_ptr[0]
if isa(memref_type.element_type, builtin.MemRefType):
assert len(memref_type.shape) == 0
memref_type=var_op.var_ptr[0].type.element_type
memref_loadop=memref.Load.get(src_memref, [])
src_memref=memref_loadop.results[0]
memref_dim_ops.append(memref_loadop)
for idx, s in enumerate(memref_type.shape):
assert isa(s, builtin.IntAttr)
if (s.data == -1):
# Need to pass the dimension shape size in explicitly as it is deferred
const_op=arith.Constant.from_int_and_width(idx, builtin.IndexType())
dim_size=memref.Dim.from_source_and_index(src_memref, const_op)
memref_dim_ops+=[const_op, dim_size]
arg_ssa.append(dim_size.results[0])
arg_types.append(dim_size.results[0].type)

locations[var_op]=loc_idx
loc_idx+=1
if len(var_op.bounds) > 0:
bound_op=var_op.bounds[0].owner
bound_op.parent.detach_op(bound_op)
#self.target_ops+=[bound_op, var_op]
arg_types.append(bound_op.lower[0].type)
arg_ssa.append(bound_op.lower[0])
locations[bound_op]=loc_idx
# Add two, as second is the size
loc_idx+=2
else:
pass#self.target_ops+=[var_op]

new_block = Block(arg_types=arg_types)

new_mapinfo_ssa=[]
for var in op.map_vars:
var_op=var.owner

map_bounds=[]
if len(var_op.bounds) > 0:
bound_op=var_op.bounds[0].owner
res_types=[]
for res in bound_op.results: res_types.append(res.type)
new_bounds_op=omp.BoundsOp.build(operands=[[new_block.args[locations[bound_op]]], [], [], [], []],
properties={"stride_in_bytes": bound_op.stride_in_bytes},
result_types=res_types)

new_block.add_op(new_bounds_op)
map_bounds=[new_bounds_op.results[0]]

res_types=[]
for res in var_op.results: res_types.append(res.type)
mapinfo_op=omp.MapInfoOp.build(operands=[[new_block.args[locations[var_op]]], [], map_bounds],
properties={"map_type": var_op.map_type, "var_name": var_op.var_name, "var_type": var_op.var_type},
result_types=res_types)
new_mapinfo_ssa.append(mapinfo_op.results[0])

new_block.add_op(mapinfo_op)

reg=op.region
op.detach_region(reg)

new_omp_target_op=omp.TargetOp.build(operands=[[],[],[], new_mapinfo_ssa], regions=[reg])
new_block.add_op(new_omp_target_op)
new_block.add_op(func.Return())

new_fn_type=builtin.FunctionType.from_lists(arg_types, [])

body = Region()
body.add_block(new_block)

new_func=func.FuncOp("tt_device", new_fn_type, body)

self.target_ops=[new_func]

call_fn=func.Call.create(properties={"callee": builtin.SymbolRefAttr("tt_device")}, operands=arg_ssa, result_types=[])
op.parent.insert_ops_before(memref_dim_ops+[call_fn], op)

op.parent.detach_op(op)
def match_and_rewrite(self, op: device.KernelCreate, rewriter: PatternRewriter, /):
arg_types = []
for var in op.mapped_data:
assert isinstance(var.type, builtin.MemRefType)
var_type = var.type
arg_types.append(var_type)

assert op.body
dev_func_body = rewriter.move_region_contents_to_new_regions(op.body)
dev_func = func.FuncOp.from_region(
"tt_device",
arg_types,
[],
dev_func_body
)

## Fix type of the block arguments (there is a mismatch between mapped_data and block args)
n_args = len(dev_func_body.block.args)
for block_arg, arg_type in zip(dev_func_body.block.args, arg_types):
new_block_arg = dev_func_body.block.insert_arg(arg_type, len(dev_func_body.block.args))
block_arg.replace_by(new_block_arg)

for arg in dev_func_body.block.args[:n_args]:
rewriter.erase_block_argument(arg)

# kernel_create cannot have both a pointer to a device_function and a body.
op.device_function = builtin.SymbolRefAttr(dev_func.sym_name)

assert dev_func_body.block.last_op is not None, "The last operation in the device function block must not be None"
rewriter.erase_op(dev_func_body.block.last_op)

rewriter.insert_op(func.ReturnOp(), InsertPoint.at_end(dev_func_body.block))

self.target_ops = [dev_func]


@dataclass(frozen=True)
class ExtractTarget(ModulePass):
Expand All @@ -122,15 +60,17 @@ class ExtractTarget(ModulePass):
"""
name = 'extract-target'

def apply(self, ctx: MLContext, module: builtin.ModuleOp):
rw_target= RewriteTarget()
def apply(self, ctx: Context, module: builtin.ModuleOp):
rw_target= RewriteTarget(module)
walker = PatternRewriteWalker(GreedyRewritePatternApplier([
rw_target,
]), apply_recursively=False, walk_reverse=True)

walker.rewrite_module(module)

containing_mod=builtin.ModuleOp([])
# NOTE: The region recieving the block must be empty. Otherwise, the single block region rule of
# the module will not be satisfied.
containing_mod=builtin.ModuleOp(Region())
module.regions[0].move_blocks(containing_mod.regions[0])

new_module=builtin.ModuleOp(rw_target.target_ops, {"target": builtin.StringAttr("tt_device")})
Expand Down
Loading