Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion llvm_passes/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ add_library(LLVMHipPasses MODULE HipPasses.cpp
HipPrintf.cpp HipGlobalVariables.cpp HipCleanup.cpp HipTextureLowering.cpp HipAbort.cpp
HipEmitLoweredNames.cpp HipWarps.cpp HipKernelArgSpiller.cpp
HipLowerZeroLengthArrays.cpp HipSanityChecks.cpp HipLowerSwitch.cpp
HipLowerMemset.cpp HipIGBADetector.cpp HipPromoteInts.cpp
HipLowerMemset.cpp HipIGBADetector.cpp HipPromoteInts.cpp
HipSpirvFunctionReorderPass.cpp
HipVerify.cpp
${EXTRA_OBJS})
Expand Down
69 changes: 39 additions & 30 deletions llvm_passes/HipPromoteInts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1238,15 +1238,18 @@ static bool needsKernelArgPromotion(Type *T) {
}

/// Promote narrow (< i32) integer kernel arguments to i32 for SPIR-V/OpenCL
/// conformance. Creates a wrapper kernel with i32 args that truncates them
/// back to the original type before calling the original function body.
/// conformance. Modifies the kernel function in-place: replaces narrow
/// params with i32 and inserts trunc instructions at the entry block.
///
/// Previous implementation created a wrapper+unpromoted-function pattern
/// which broke subgroup shuffle operations on Intel GPUs (IGC bug: shuffles
/// don't work in called non-kernel functions).
static bool promoteKernelArgs(Module &M) {
bool Changed = false;

SmallVector<Function *, 8> WorkList;
for (auto &F : M)
if (F.getCallingConv() == CallingConv::SPIR_KERNEL && !F.isDeclaration()) {
// Check if any argument needs promotion
for (const auto &Arg : F.args()) {
if (needsKernelArgPromotion(Arg.getType())) {
WorkList.push_back(&F);
Expand All @@ -1261,14 +1264,11 @@ static bool promoteKernelArgs(Module &M) {

// Build new argument type list, promoting narrow ints to i32
SmallVector<Type *, 8> NewArgTys;
SmallVector<unsigned, 4> PromotedArgIndices;
for (auto &Arg : F->args()) {
if (needsKernelArgPromotion(Arg.getType())) {
if (needsKernelArgPromotion(Arg.getType()))
NewArgTys.push_back(Type::getInt32Ty(M.getContext()));
PromotedArgIndices.push_back(Arg.getArgNo());
} else {
else
NewArgTys.push_back(Arg.getType());
}
}

// Create new kernel function with promoted signature
Expand All @@ -1278,32 +1278,41 @@ static bool promoteKernelArgs(Module &M) {
F->getAddressSpace(), "", F->getParent());
NewF->copyAttributesFrom(F);
NewF->takeName(F);

// Demote the original kernel to a regular internal function
F->setName(NewF->getName() + ".unpromoted");
F->setCallingConv(CallingConv::SPIR_FUNC);
F->setLinkage(GlobalValue::InternalLinkage);

// Build the wrapper body: truncate promoted args back, then call original
IRBuilder<> B(BasicBlock::Create(F->getContext(), "entry", NewF));
auto *RI = B.CreateRetVoid();
B.SetInsertPoint(RI);

SmallVector<Value *, 8> CallArgs;
for (auto &OrigArg : F->args()) {
auto *NewArg = NewF->getArg(OrigArg.getArgNo());
if (needsKernelArgPromotion(OrigArg.getType())) {
// Truncate i32 back to the original narrow type
auto *Trunc = B.CreateTrunc(NewArg, OrigArg.getType());
CallArgs.push_back(Trunc);
LLVM_DEBUG(dbgs() << " Arg " << OrigArg.getArgNo() << ": "
<< *OrigArg.getType() << " -> i32 (with trunc)\n");
NewF->setCallingConv(F->getCallingConv());
NewF->copyMetadata(F, 0);

// Move the function body from the old function to the new one.
NewF->splice(NewF->begin(), F);

// Replace old args with new args, inserting truncations where needed.
auto NewArgIt = NewF->arg_begin();
for (auto &OldArg : F->args()) {
NewArgIt->setName(OldArg.getName());
if (needsKernelArgPromotion(OldArg.getType())) {
// Insert trunc at the start of the entry block.
IRBuilder<> B(&*NewF->getEntryBlock().getFirstInsertionPt());
Value *Trunc = B.CreateTrunc(&*NewArgIt, OldArg.getType(),
OldArg.getName() + ".trunc");
OldArg.replaceAllUsesWith(Trunc);
LLVM_DEBUG(dbgs() << " Arg " << OldArg.getArgNo() << ": "
<< *OldArg.getType() << " -> i32 (with trunc)\n");
} else {
CallArgs.push_back(NewArg);
OldArg.replaceAllUsesWith(&*NewArgIt);
}
++NewArgIt;
}

// Remove parameter attributes specific to narrow types.
AttributeList Attrs = NewF->getAttributes();
for (unsigned I = 0; I < F->arg_size(); I++) {
if (needsKernelArgPromotion(F->getArg(I)->getType())) {
Attrs = Attrs.removeParamAttribute(M.getContext(), I, Attribute::SExt);
Attrs = Attrs.removeParamAttribute(M.getContext(), I, Attribute::ZExt);
}
}
NewF->setAttributes(Attrs);

B.CreateCall(F, CallArgs);
F->eraseFromParent();
Changed = true;
}

Expand Down
1 change: 1 addition & 0 deletions tests/runtime/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ add_hip_runtime_test(TestTypeCastIntrinsics.hip)
add_hip_runtime_test(TestHipLaunchHostFunc.cpp)

add_hip_runtime_test(TestBoolKernelParam.hip)
add_hip_runtime_test(TestBoolParamShuffle.cpp)
find_program(SPIRV_DIS spirv-dis)
if(SPIRV_DIS)
add_shell_test(TestBoolKernelParamSPIRV.bash)
Expand Down
117 changes: 117 additions & 0 deletions tests/runtime/TestBoolParamShuffle.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/**
* Regression test for IGC bug: subgroup shuffle returns wrong results
* when any kernel function parameter has OpTypeBool in SPIR-V.
*
* Two kernels perform an inclusive prefix sum (warp scan) via __shfl_up_sync:
* - scan_int_param: 4th parameter is int (should PASS)
* - scan_bool_param: 4th parameter is bool (FAIL without workaround)
*
* The bool parameter is unused by the scan logic; its mere presence in the
* function signature triggers the IGC miscompilation.
*/
#include <hip/hip_runtime.h>
#include <cstdio>
#include <cstdlib>
#include <cstring>

static constexpr int WARP_SIZE = 32;
static constexpr int N = 64; // two warps

/// Inclusive prefix sum using __shfl_up_sync -- 4th param is int.
__global__ void scan_int_param(const int *in, int *out, int n,
int /*unused_flag*/) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= n)
return;

int val = in[tid];
for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
int up = __shfl_up(val, offset, WARP_SIZE);
if ((tid % WARP_SIZE) >= offset)
val += up;
}
out[tid] = val;
}

/// Inclusive prefix sum using __shfl_up_sync -- 4th param is bool.
__global__ void scan_bool_param(const int *in, int *out, int n,
bool /*unused_flag*/) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= n)
return;

int val = in[tid];
for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
int up = __shfl_up(val, offset, WARP_SIZE);
if ((tid % WARP_SIZE) >= offset)
val += up;
}
out[tid] = val;
}

static bool verify(const char *label, const int *out, int n) {
bool pass = true;
int first_fail = -1;
for (int i = 0; i < n; i++) {
// Expected: inclusive prefix sum within each warp.
// Input is all-ones so expected[i] = (i % WARP_SIZE) + 1.
int expected = (i % WARP_SIZE) + 1;
if (out[i] != expected) {
if (first_fail < 0)
first_fail = i;
pass = false;
}
}
if (pass) {
printf("%s: PASS\n", label);
} else {
printf("%s: FAIL (first mismatch at [%d]: got %d, expected %d)\n", label,
first_fail, out[first_fail], (first_fail % WARP_SIZE) + 1);
}
return pass;
}

int main() {
// Skip on devices that don't support the Intel SPIR-V subgroup ops chipStar
// emits for __shfl_up. e.g. Mali-G52 has cl_khr_subgroup_shuffle but not
// cl_intel_subgroups, so the kernel SPIR-V is rejected at clBuildProgram.
// Detect by device name since chipStar reports warpSize=32 unconditionally.
hipDeviceProp_t prop;
hipGetDeviceProperties(&prop, 0);
if (std::strstr(prop.name, "Mali") != nullptr) {
printf("HIP_SKIP_THIS_TEST: device '%s' lacks SPV_INTEL_subgroups "
"support\n",
prop.name);
return CHIP_SKIP_TEST;
}

int h_in[N], h_out[N];
for (int i = 0; i < N; i++)
h_in[i] = 1;

int *d_in, *d_out;
hipMalloc(&d_in, N * sizeof(int));
hipMalloc(&d_out, N * sizeof(int));
hipMemcpy(d_in, h_in, N * sizeof(int), hipMemcpyHostToDevice);

int blocks = (N + 63) / 64;

// Test 1: int parameter (baseline, should always pass)
hipMemset(d_out, 0, N * sizeof(int));
scan_int_param<<<blocks, 64>>>(d_in, d_out, N, 0);
hipDeviceSynchronize();
hipMemcpy(h_out, d_out, N * sizeof(int), hipMemcpyDeviceToHost);
bool pass1 = verify("scan_int_param", h_out, N);

// Test 2: bool parameter (triggers IGC bug without workaround)
hipMemset(d_out, 0, N * sizeof(int));
scan_bool_param<<<blocks, 64>>>(d_in, d_out, N, false);
hipDeviceSynchronize();
hipMemcpy(h_out, d_out, N * sizeof(int), hipMemcpyDeviceToHost);
bool pass2 = verify("scan_bool_param", h_out, N);

hipFree(d_in);
hipFree(d_out);

return (pass1 && pass2) ? 0 : 1;
}
Loading