Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
static PyMethodDef algorithms_PyMethodDef[] = {
{"quick_sort", (PyCFunction) quick_sort,
METH_VARARGS | METH_KEYWORDS, ""},
{"quick_sort_llvm", (PyCFunction)quick_sort_llvm,
METH_VARARGS | METH_KEYWORDS, ""},
{"bubble_sort", (PyCFunction) bubble_sort,
METH_VARARGS | METH_KEYWORDS, ""},
{"bubble_sort_llvm", (PyCFunction)bubble_sort_llvm,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ def get_bubble_sort_ptr(dtype: str) -> int:

return _materialize(dtype)


def get_quick_sort_ptr(dtype: str) -> int:
dtype = dtype.lower().strip()
if dtype not in _SUPPORTED:
raise ValueError(f"Unsupported dtype '{dtype}'. Supported: {list(_SUPPORTED)}")

return _materialize_quick(dtype)

def _build_bubble_sort_ir(dtype: str) -> str:
if dtype not in _SUPPORTED:
raise ValueError(f"Unsupported dtype '{dtype}'. Supported: {list(_SUPPORTED)}")
Expand Down Expand Up @@ -131,6 +139,134 @@ def _build_bubble_sort_ir(dtype: str) -> str:

return str(mod)


def _build_quick_sort_ir(dtype: str) -> str:
if dtype not in _SUPPORTED:
raise ValueError(f"Unsupported dtype '{dtype}'. Supported: {list(_SUPPORTED)}")

T, _ = _SUPPORTED[dtype]
i32 = ir.IntType(32)
i64 = ir.IntType(64)

mod = ir.Module(name=f"quick_sort_{dtype}_module")
fn_name = f"quick_sort_{dtype}"

# void quick_sort(T* arr, int32 low, int32 high)
fn_ty = ir.FunctionType(ir.VoidType(), [T.as_pointer(), i32, i32])
fn = ir.Function(mod, fn_ty, name=fn_name)
arr, low, high = fn.args
arr.name, low.name, high.name = "arr", "low", "high"

entry = fn.append_basic_block("entry")
part = fn.append_basic_block("partition")
exit = fn.append_basic_block("exit")

b = ir.IRBuilder(entry)

# if (low < high)
cond = b.icmp_signed("<", low, high)
b.cbranch(cond, part, exit)

# --- Partition block
b.position_at_end(part)

# pivot = arr[high]
high_64 = b.sext(high, i64)
pivot_ptr = b.gep(arr, [high_64])
pivot = b.load(pivot_ptr, name="pivot")

# i = low - 1
i = b.alloca(i32, name="i")
i_init = b.sub(low, ir.Constant(i32, 1))
b.store(i_init, i)

# j = low
j = b.alloca(i32, name="j")
b.store(low, j)

loop = fn.append_basic_block("loop")
after_loop = fn.append_basic_block("after_loop")
body = fn.append_basic_block("body")
swap = fn.append_basic_block("swap")
skip_swap = fn.append_basic_block("skip_swap")

b.branch(loop)

# --- Loop: while (j < high)
b.position_at_end(loop)
j_val = b.load(j)
cond = b.icmp_signed("<", j_val, high)
b.cbranch(cond, body, after_loop)

# --- Body
b.position_at_end(body)
j64 = b.sext(j_val, i64)
elem_ptr = b.gep(arr, [j64])
elem = b.load(elem_ptr, name="elem")

if isinstance(T, ir.IntType):
cmp = b.icmp_signed("<=", elem, pivot)
else:
cmp = b.fcmp_ordered("<=", elem, pivot, fastmath=True)

b.cbranch(cmp, swap, skip_swap)

# --- Swap block
b.position_at_end(swap)
i_val = b.load(i)
i_next = b.add(i_val, ir.Constant(i32, 1))
b.store(i_next, i)

i64v = b.sext(i_next, i64)
iptr = b.gep(arr, [i64v])
ival = b.load(iptr)
# swap arr[i] and arr[j]
b.store(elem, iptr)
b.store(ival, elem_ptr)

b.branch(skip_swap)

# --- Skip swap
b.position_at_end(skip_swap)
j_next = b.add(j_val, ir.Constant(i32, 1))
b.store(j_next, j)
b.branch(loop)

# --- After loop
b.position_at_end(after_loop)
i_val = b.load(i)
i_plus1 = b.add(i_val, ir.Constant(i32, 1))

i64v = b.sext(i_plus1, i64)
iptr = b.gep(arr, [i64v])
ival = b.load(iptr)

# swap arr[i+1] and arr[high]
b.store(pivot, iptr)
b.store(ival, pivot_ptr)

# Now i+1 is the partition index
pi = i_plus1

# Recursive calls:
# quick_sort(arr, low, pi - 1)
low_call = low
high_call1 = b.sub(pi, ir.Constant(i32, 1))
b.call(fn, [arr, low_call, high_call1])

# quick_sort(arr, pi + 1, high)
low_call2 = b.add(pi, ir.Constant(i32, 1))
high_call2 = high
b.call(fn, [arr, low_call2, high_call2])

b.branch(exit)

# --- Exit
b.position_at_end(exit)
b.ret_void()

return str(mod)

def _materialize(dtype: str) -> int:
_ensure_target_machine()

Expand Down Expand Up @@ -167,3 +303,42 @@ def _materialize(dtype: str) -> int:

except Exception as e:
raise RuntimeError(f"Failed to materialize function for dtype {dtype}: {e}")


def _materialize_quick(dtype: str) -> int:
_ensure_target_machine()

key = f"quick_{dtype}"
if key in _fn_ptr_cache:
return _fn_ptr_cache[key]

try:
llvm_ir = _build_quick_sort_ir(dtype)
mod = binding.parse_assembly(llvm_ir)
mod.verify()

try:
pm = binding.ModulePassManager()
pm.add_instruction_combining_pass()
pm.add_reassociate_pass()
pm.add_gvn_pass()
pm.add_cfg_simplification_pass()
pm.run(mod)
except AttributeError:
pass

engine = binding.create_mcjit_compiler(mod, _target_machine)
engine.finalize_object()
engine.run_static_constructors()

addr = engine.get_function_address(f"quick_sort_{dtype}")
if not addr:
raise RuntimeError(f"Failed to get address for quick_sort_{dtype}")

_fn_ptr_cache[key] = addr
_engines[key] = engine

return addr

except Exception as e:
raise RuntimeError(f"Failed to materialize quick sort function for dtype {dtype}: {e}")
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,7 @@ static PyObject* bubble_sort_llvm(PyObject* self, PyObject* args, PyObject* kwds
Py_INCREF(arr_obj);
return arr_obj;
}

// Selection Sort
static PyObject* selection_sort_impl(PyObject* array, size_t lower, size_t upper,
PyObject* comp) {
Expand Down
Loading
Loading