Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ static PyMethodDef algorithms_PyMethodDef[] = {
METH_VARARGS | METH_KEYWORDS, ""},
{"bubble_sort_llvm", (PyCFunction)bubble_sort_llvm,
METH_VARARGS | METH_KEYWORDS, ""},
{"selection_sort_llvm", (PyCFunction)selection_sort_llvm,
METH_VARARGS | METH_KEYWORDS, ""},
{"selection_sort", (PyCFunction) selection_sort,
METH_VARARGS | METH_KEYWORDS, ""},
{"insertion_sort", (PyCFunction) insertion_sort,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,131 @@ def _materialize(dtype: str) -> int:

except Exception as e:
raise RuntimeError(f"Failed to materialize function for dtype {dtype}: {e}")

def get_selection_sort_ptr(dtype: str) -> int:
"""Get function pointer for selection sort with specified dtype."""
dtype = dtype.lower().strip()
if dtype not in _SUPPORTED:
raise ValueError(f"Unsupported dtype '{dtype}'. Supported: {list(_SUPPORTED)}")

return _materialize_selection(dtype)


def _build_selection_sort_ir(dtype: str) -> str:
if dtype not in _SUPPORTED:
raise ValueError(f"Unsupported dtype '{dtype}'. Supported: {list(_SUPPORTED)}")

T, _ = _SUPPORTED[dtype]
i32 = ir.IntType(32)
i64 = ir.IntType(64)

mod = ir.Module(name=f"selection_sort_{dtype}_module")
fn_name = f"selection_sort_{dtype}"

fn_ty = ir.FunctionType(ir.VoidType(), [T.as_pointer(), i32])
fn = ir.Function(mod, fn_ty, name=fn_name)

arr, n = fn.args
arr.name, n.name = "arr", "n"

# Basic blocks
b_entry = fn.append_basic_block("entry")
b_outer = fn.append_basic_block("outer")
b_inner = fn.append_basic_block("inner")
b_inner_latch = fn.append_basic_block("inner.latch")
b_swap = fn.append_basic_block("swap")
b_exit = fn.append_basic_block("exit")

b = ir.IRBuilder(b_entry)
cond_trivial = b.icmp_signed("<=", n, ir.Constant(i32, 1))
b.cbranch(cond_trivial, b_exit, b_outer)

# Outer loop
b.position_at_end(b_outer)
i_phi = b.phi(i32, name="i")
i_phi.add_incoming(ir.Constant(i32, 0), b_entry) # start at 0

cond_outer = b.icmp_signed("<", i_phi, n)
b.cbranch(cond_outer, b_inner, b_exit)

# Inner loop: find min index
b.position_at_end(b_inner)
min_idx = b_phi = b_phi_i = b.phi(i32, name="min_idx")
min_idx.add_incoming(i_phi, b_outer) # initial min_idx = i

j_phi = b.phi(i32, name="j")
j_phi.add_incoming(b.add(i_phi, ir.Constant(i32, 1)), b_outer)

cond_inner = b.icmp_signed("<", j_phi, n)
b.cbranch(cond_inner, b_inner_latch, b_swap)

# Compare and update min_idx
b.position_at_end(b_inner_latch)
j64 = b.sext(j_phi, i64)
min64 = b.sext(min_idx, i64)
arr_j_ptr = b.gep(arr, [j64], inbounds=True)
arr_min_ptr = b.gep(arr, [min64], inbounds=True)
arr_j_val = b.load(arr_j_ptr)
arr_min_val = b.load(arr_min_ptr)

if isinstance(T, ir.IntType):
cmp = b.icmp_signed("<", arr_j_val, arr_min_val)
else:
cmp = b.fcmp_ordered("<", arr_j_val, arr_min_val)

with b.if_then(cmp):
min_idx = j_phi # update min_idx

j_next = b.add(j_phi, ir.Constant(i32, 1))
j_phi.add_incoming(j_next, b_inner_latch)
min_idx.add_incoming(min_idx, b_inner_latch) # propagate current min_idx
b.branch(b_inner)

# Swap arr[i] and arr[min_idx]
b.position_at_end(b_swap)
i64 = b.sext(i_phi, i64)
min64 = b.sext(min_idx, i64)
ptr_i = b.gep(arr, [i64], inbounds=True)
ptr_min = b.gep(arr, [min64], inbounds=True)
val_i = b.load(ptr_i)
val_min = b.load(ptr_min)
b.store(val_min, ptr_i)
b.store(val_i, ptr_min)

# Increment i
i_next = b.add(i_phi, ir.Constant(i32, 1))
i_phi.add_incoming(i_next, b_swap)
b.branch(b_outer)

# Exit
b.position_at_end(b_exit)
b.ret_void()

return str(mod)


def _materialize_selection(dtype: str) -> int:
_ensure_target_machine()

name = f"selection_sort_{dtype}"
if dtype in _fn_ptr_cache:
return _fn_ptr_cache[dtype]

try:
llvm_ir = _build_selection_sort_ir(dtype)
mod = binding.parse_assembly(llvm_ir)
mod.verify()

engine = binding.create_mcjit_compiler(mod, _target_machine)
engine.finalize_object()
engine.run_static_constructors()

addr = engine.get_function_address(name)
if not addr:
raise RuntimeError(f"Failed to get address for {name}")

_fn_ptr_cache[dtype] = addr
_engines[dtype] = engine
return addr
except Exception as e:
raise RuntimeError(f"Failed to materialize function for dtype {dtype}: {e}")
Loading
Loading