Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions ftn/tools/ftn_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from ftn.transforms.rewrite_fir_to_core import RewriteFIRToCore
from ftn.transforms.merge_memref_deref import MergeMemRefDeref
from ftn.transforms.lower_omp_target_data import LowerOmpTargetDataPass
# from ftn.transforms.extract_target import ExtractTarget
from ftn.transforms.apply_target_config import ApplyTargetConfig
from ftn.transforms.omp_target_to_kernel import OmpTargetToKernelPass
from ftn.transforms.lift_omp_to_tensor import LiftOmpToTensorPass
# from ftn.transforms.isolate_target import IsolateTarget
# from psy.extract_stencil import ExtractStencil
# from ftn.transforms.tenstorrent.convert_to_tt import ConvertToTT
Expand All @@ -28,7 +30,9 @@ def register_all_passes(self):
self.register_pass("rewrite-fir-to-core", lambda: RewriteFIRToCore)
self.register_pass("merge-memref-deref", lambda: MergeMemRefDeref)
self.register_pass("lower-omp-target-data", lambda: LowerOmpTargetDataPass)
# self.register_pass("extract-target", lambda: ExtractTarget)
self.register_pass("apply-target", lambda: ApplyTargetConfig)
self.register_pass("omp-target-to-kernel", lambda: OmpTargetToKernelPass)
self.register_pass("lift-omp-to-tensor", lambda: LiftOmpToTensorPass)
# self.register_pass("isolate-target", lambda: IsolateTarget)
# self.register_pass("convert-to-tt", lambda: ConvertToTT)

Expand Down
59 changes: 44 additions & 15 deletions ftn/transforms/apply_target_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from abc import ABC, abstractmethod

from xdsl.context import Context
from xdsl.dialects import builtin, dlti
Expand All @@ -7,16 +8,32 @@
from ftn.dialects import device


class TenstorrentConfiguration:
def get():
class TargetConfiguration(ABC):
    """Interface for per-target device specifications.

    Concrete targets provide a memory and a compute subsystem map; `get`
    combines these into a single `dlti.TargetDeviceSpecAttr`.
    """

    @classmethod
    @abstractmethod
    def get(cls) -> dlti.TargetDeviceSpecAttr:
        """Return the complete device spec for this target."""

    @classmethod
    @abstractmethod
    def _memory_subsystem(cls) -> dlti.MapAttr:
        """Return the memory subsystem description for this target."""

    @classmethod
    @abstractmethod
    def _compute_subsystem(cls) -> dlti.MapAttr:
        """Return the compute subsystem description for this target."""


class TenstorrentConfiguration(TargetConfiguration):
@classmethod
def get(cls):
    """Assemble the Tenstorrent device spec from its subsystem maps."""
    subsystems = {
        "memory": cls._memory_subsystem(),
        "compute": cls._compute_subsystem(),
    }
    return dlti.TargetDeviceSpecAttr(subsystems)

def _memory_subsystem():
@classmethod
def _memory_subsystem(cls):
config = {
"DRAM": {
"kind": device.MemoryKindAttr(device.MemoryKind.DDR),
Expand All @@ -25,7 +42,8 @@ def _memory_subsystem():
}
return dlti.MapAttr(config)

def _compute_subsystem():
@classmethod
def _compute_subsystem(cls):
config = {
"architecture_type": device.ArchitectureKindAttr(
device.ArchitectureKind.MANYCORE
Expand Down Expand Up @@ -53,16 +71,18 @@ def _compute_subsystem():
return dlti.MapAttr(config)


class U280Configuration:
def get():
class U280Configuration(TargetConfiguration):
@classmethod
def get(cls):
return dlti.TargetDeviceSpecAttr(
{
"memory": U280Configuration._memory_subsystem(),
"compute": U280Configuration._compute_subsystem(),
"memory": cls._memory_subsystem(),
"compute": cls._compute_subsystem(),
}
)

def _memory_subsystem():
@classmethod
def _memory_subsystem(cls):
config = {
"DRAM": {
"kind": device.MemoryKindAttr(device.MemoryKind.DDR),
Expand All @@ -76,7 +96,8 @@ def _memory_subsystem():
}
return dlti.MapAttr(config)

def _compute_subsystem():
@classmethod
def _compute_subsystem(cls):
config = {
"architecture_type": device.ArchitectureKindAttr(
device.ArchitectureKind.FPGA
Expand Down Expand Up @@ -153,14 +174,22 @@ def generate_system_config(self, accelerator_name, accelerator_config):
}
)

def _get_config(self) -> dlti.TargetDeviceSpecAttr:
    """
    Get the device spec for the current `self.target`.

    Looks `self.target` up in the module-level `SYSTEM_CONFIGURATIONS`
    mapping and returns that configuration's device spec.

    If overriding this function, make sure to *not* specify the `name`
    field again.

    Raises:
        ValueError: if `self.target` has no registered configuration.
    """
    if config := SYSTEM_CONFIGURATIONS.get(self.target):
        return config.get()
    raise ValueError(f"No such target configuration {self.target}")

def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
op.attributes["omp.target_triples"] = builtin.ArrayAttr(
[builtin.StringAttr(self.target)]
)

assert self.target in SYSTEM_CONFIGURATIONS.keys()

config = SYSTEM_CONFIGURATIONS[self.target].get()
config = self._get_config()

op.attributes["dlti.target_system_spec"] = self.generate_system_config(
self.target, config
Expand Down
2 changes: 1 addition & 1 deletion ftn/transforms/lift_omp_to_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ def lift_op(
tensor_sizes.append(int((upper_const - (lower_const - 1)) / step_const))
else:
# Otherwise this dimension size is dynamic
tensor_sizes.append(-1)
tensor_sizes.append(builtin.DYNAMIC_INDEX)

# Create dependency tree walker, this is passed the private (intermediate)
# memrefs, and SSA of the device mapped data
Expand Down
28 changes: 21 additions & 7 deletions ftn/transforms/lower_omp_target_data.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum

from xdsl.builder import Builder
from xdsl.context import Context
from xdsl.dialects import arith, builtin, memref, omp, scf
from xdsl.ir import BlockArgument, Block, Region
from xdsl.ir import Block, BlockArgument, Operation, Region, SSAValue
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
Expand All @@ -25,6 +26,8 @@ class DataEnvironmentDirection(Enum):
BOTH = 3


# Sentinel "number of elements" value meaning "skip this DMA wait"; waits
# are guarded by an `ne` comparison against this constant before the
# `memref.dma_wait` is emitted.
# NOTE(review): 1 << 64 does not fit in a signed (or unsigned) 64-bit
# index — confirm `IntegerAttr.from_index_int_value` accepts it;
# (1 << 63) - 1 may have been intended.
_DO_NOT_WAIT = 1<<64

class DataMovementGenerator:
def collect_mapped_vars_by_stack_and_heap(mapped_vars, use_mapped_vars):
sorted_mapped_vars = []
Expand Down Expand Up @@ -305,7 +308,7 @@ def generate_conditional_on_data_exists(
"""
data_exists_op = device.DataCheckExists(var_name, memory_space)

ops = [data_exists_op]
ops: list[Operation] = [data_exists_op]
if is_not_conditional:
const_op = arith.ConstantOp.from_int_and_width(1, 1)
ex_io_op = arith.XOrIOp(const_op, data_exists_op, builtin.i1)
Expand All @@ -314,6 +317,12 @@ def generate_conditional_on_data_exists(
else:
condition_ssa = data_exists_op

if false_region is None:
@Builder.implicit_region([])
def false_region(args: tuple[BlockArgument, ...]) -> None:
dummy_tag = memref.AllocaOp.get(builtin.i32, shape=[])
do_not_wait = arith.ConstantOp(builtin.IntegerAttr.from_index_int_value(_DO_NOT_WAIT))
scf.YieldOp(dummy_tag, do_not_wait)
cond = scf.IfOp(
condition_ssa, conditional_return_type, true_region, false_region
)
Expand All @@ -334,7 +343,7 @@ def generate_allocate_on_device(var_type, var_name, memory_space, size_ssas):
"""
dynamic_ssas = []
for idx, shape in enumerate(var_type.shape):
if shape.data == -1:
if shape.data == builtin.DYNAMIC_INDEX:
dynamic_ssas.append(size_ssas[idx])

return device.AllocOp(
Expand Down Expand Up @@ -441,15 +450,20 @@ def generate_copy_from_device(var_name, memory_space, var_type, dest):
)
return tag_ssa, [device_memref] + ops_list

def generate_dma_waits_for_tags(wait_ssas_list):
@staticmethod
def generate_dma_waits_for_tags(wait_ssas_list: Sequence[tuple[SSAValue | Operation, SSAValue | Operation]]):
    """
    Generates the DMA wait operations based upon the provided wait list; each
    entry in the wait list is a tuple (wait tag, number of elements).

    Each wait is wrapped in an `scf.if` that compares the element count
    against the `_DO_NOT_WAIT` sentinel and only performs the
    `memref.dma_wait` when the two are not equal, so entries carrying the
    sentinel count are skipped at runtime.

    Returns:
        The list of generated operations (constants, comparisons, and the
        guarding `scf.if` ops), in insertion order.
    """
    ops_list: list[Operation] = []
    for tag, num_els in wait_ssas_list:
        # NOTE(review): a fresh sentinel constant is emitted per entry;
        # presumably a later CSE pass folds the duplicates — confirm.
        do_not_wait = arith.ConstantOp(builtin.IntegerAttr.from_index_int_value(_DO_NOT_WAIT))
        should_wait = arith.CmpiOp(num_els, do_not_wait, "ne")
        if_op = scf.IfOp(
            should_wait,
            [],
            Region(Block([memref.DmaWaitOp.get(tag, [], num_els), scf.YieldOp()])),
        )
        ops_list.extend([do_not_wait, should_wait, if_op])
    return ops_list


Expand Down
2 changes: 1 addition & 1 deletion ftn/transforms/tenstorrent/convert_to_tt.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ def generate_data_in(self, module, memory_type, references, cb_idxs, new_block,
mem_size_bytes_op=arith.Muli(dt_width_conversion_op, new_block.args[(len(memory_type)*2)+idx])
read_op=data_movement.DMNocAsyncRead(dm_op.results[0], cb_op.results[0], mem_size_bytes_op)

target_memref=builtin.MemRefType(element_type, [-1])
target_memref=builtin.MemRefType(element_type, [builtin.DYNAMIC_INDEX])
conversion_op=builtin.UnrealizedConversionCastOp.get([cb_op.results[0]], [target_memref])
conversion_op.results[0].name_hint = f"src{idx}_data"

Expand Down
14 changes: 7 additions & 7 deletions ftn/transforms/to_core/components/ftn_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def compare_memrefs(memref_a, memref_b):
return MemrefComparison.INCOMPATIBLE
for dim_size_a, dim_size_b in zip(memref_a.shape, memref_b.shape):
if dim_size_a.data != dim_size_b.data and (
dim_size_a.data == -1 or dim_size_b.data == -1
dim_size_a.data == builtin.DYNAMIC_INDEX or dim_size_b.data == builtin.DYNAMIC_INDEX
):
return MemrefComparison.CONVERTABLE
return MemrefComparison.SAME
Expand Down Expand Up @@ -77,7 +77,7 @@ def does_type_represent_ftn_pointer(type_chain):

def convert_fir_type_to_standard_if_needed(fir_type):
if isa(fir_type, fir.ReferenceType) and fir_type.type == builtin.i8:
return llvm.LLVMPointerType.opaque()
return llvm.LLVMPointerType()
else:
return convert_fir_type_to_standard(fir_type)

Expand All @@ -93,7 +93,7 @@ def convert_fir_type_to_standard(fir_type, ref_as_mem_ref=True):
base_t, [], builtin.NoneAttr(), builtin.NoneAttr()
)
else:
return llvm.LLVMPointerType.opaque()
return llvm.LLVMPointerType()
elif isa(fir_type, fir.BoxType):
return convert_fir_type_to_standard(fir_type.type, ref_as_mem_ref)
elif isa(fir_type, fir.HeapType):
Expand All @@ -111,7 +111,7 @@ def convert_fir_type_to_standard(fir_type, ref_as_mem_ref=True):
if isa(shape_el, builtin.IntegerAttr):
dim_sizes.append(shape_el.value.data)
else:
dim_sizes.append(-1)
dim_sizes.append(builtin.DYNAMIC_INDEX)
# Reverse the sizes to go from Fortran to C allocation semantics
dim_sizes.reverse()
return builtin.MemRefType(
Expand All @@ -121,7 +121,7 @@ def convert_fir_type_to_standard(fir_type, ref_as_mem_ref=True):
return builtin.i1
elif isa(fir_type, fir.BoxCharType):
return llvm.LLVMStructType.from_type_list(
[llvm.LLVMPointerType.opaque(), builtin.i64]
[llvm.LLVMPointerType(), builtin.i64]
)
elif isa(fir_type, builtin.TupleType):
new_types = []
Expand Down Expand Up @@ -214,7 +214,7 @@ def translate_convert(program_state: ProgramState, ctx: SSAValueCtx, op: fir.Con
get_element_ptr = llvm.GEPOp(
ctx[op.value],
[0, 0],
result_type=llvm.LLVMPointerType.opaque(),
result_type=llvm.LLVMPointerType(),
pointee_type=llvm.LLVMArrayType.from_size_and_type(
1, builtin.IntegerType(8)
),
Expand All @@ -229,7 +229,7 @@ def translate_convert(program_state: ProgramState, ctx: SSAValueCtx, op: fir.Con
shape_size = []
for s in out_type.type.shape.data:
if isa(s, fir.DeferredAttr):
shape_size.append(-1)
shape_size.append(builtin.DYNAMIC_INDEX)
else:
shape_size.append(s.value.data)
# Reverse shape_size to get it from Fortran allocation to C/MLIR allocation
Expand Down
2 changes: 1 addition & 1 deletion ftn/transforms/to_core/components/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def translate_function(program_state: ProgramState, ctx: SSAValueCtx, fn: func.F
if ftn_types.does_type_represent_ftn_pointer(fir_type):
# If we are passing a Fortran pointer then we need to handle this differently, actually pass
# the LLVM pointer of this and reconstruct, to access the same underlying memref
converted_type = llvm.LLVMPointerType.opaque()
converted_type = llvm.LLVMPointerType()
ptr_unpack_args.append((idx, fir_type))
else:
converted_type = ftn_types.convert_fir_type_to_standard(fir_type)
Expand Down
8 changes: 4 additions & 4 deletions ftn/transforms/to_core/components/intrinsics.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ def handle_create_temporary_linalg_output_memref(
result_type, element_type, input_ssas, input_dims_to_read
):
output_shape = [
-1 if isa(s, fir.DeferredAttr) else s.value for s in result_type.shape
builtin.DYNAMIC_INDEX if isa(s, fir.DeferredAttr) else s.value for s in result_type.shape
]

ops_list = []
dynamic_sizes = []
if -1 in output_shape:
if builtin.DYNAMIC_INDEX in output_shape:
# If we have deferred sizes then grab the output sizes from the input array sizes
# Ensure all elements are -1
assert len(set(output_shape)) == 1
Expand Down Expand Up @@ -177,9 +177,9 @@ def handle_reduction_operation(
memref_shape = []
memref_dynamic_sizes = []
elif len(reduction_dimensions) == 1:
if -1 in input_array_shape:
if builtin.DYNAMIC_INDEX in input_array_shape:
assert len(set(input_array_shape)) == 1
memref_shape = [-1] * (len(array_load_ssa.type.shape) - 1)
memref_shape = [builtin.DYNAMIC_INDEX] * (len(array_load_ssa.type.shape) - 1)
memref_dynamic_sizes = []
if len(array_load_ssa.type.shape) > 1:
for dim in list(range(len(array_load_ssa.type.shape))):
Expand Down
8 changes: 4 additions & 4 deletions ftn/transforms/to_core/components/load_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,13 +410,13 @@ def generate_allocatable_array_allocate(

assert len(dim_sizes) == len(dim_starts) == len(dim_ends)

# Now create memref, passing -1 as shape will make this deferred size
# Now create memref; passing DYNAMIC_INDEX as a shape entry makes that dimension a deferred (dynamic) size
# Reverse the indices as Fortran and C/MLIR are opposite in terms of
# the order of the contiguous dimension (F is least, whereas C/MLIR is highest)
dim_ssa_reversed = dim_ssas.copy()
dim_ssa_reversed.reverse()
memref_allocation_op = memref_alloca_op = memref.AllocOp.get(
base_type, shape=[-1] * len(dim_ssas), dynamic_sizes=dim_ssa_reversed
base_type, shape=[builtin.DYNAMIC_INDEX] * len(dim_ssas), dynamic_sizes=dim_ssa_reversed
)
ops_list.append(memref_allocation_op)

Expand Down Expand Up @@ -462,13 +462,13 @@ def handle_pointer_assignment(
source_ssa = ctx[source_op]
ops = []

if any(i.data != -1 for i in ctx[source_op].type.shape.data):
if any(i.data != builtin.DYNAMIC_INDEX for i in ctx[source_op].type.shape.data):
# The source type has explicit dimension sizes, by definition a pointer must be unknown
# dimension sizes so we need to convert
num_dims = len(ctx[source_op].type.shape.data)
cast_op = memref.CastOp.get(
source_ssa,
builtin.MemRefType(source_ssa.type.element_type, shape=num_dims * [-1]),
builtin.MemRefType(source_ssa.type.element_type, shape=num_dims * [builtin.DYNAMIC_INDEX]),
)
source_ssa = cast_op.results[0]
ops.append(cast_op)
Expand Down
Loading