2 changes: 1 addition & 1 deletion cmake/dawn.cmake
@@ -168,7 +168,7 @@ if(NOT DAWN_BUILD_FOUND)

# Ensure source present on required commit (idempotent remote setup)
if(NOT DEFINED DAWN_COMMIT OR DAWN_COMMIT STREQUAL "")
set(DAWN_COMMIT "e1d6e12337080cf9f6d8726209e86df449bc6e9a" CACHE STRING "Dawn commit to checkout" FORCE)
set(DAWN_COMMIT "3f79f3aefe0b0a498002564fcfb13eb21ab6c047" CACHE STRING "Dawn commit to checkout" FORCE)
Review comment (Collaborator, Author):
google/dawn@d7d27a6: required to set the subgroup size to 32 on macOS.

endif()
file(MAKE_DIRECTORY ${DAWN_DIR})
execute_process(COMMAND git init WORKING_DIRECTORY "${DAWN_DIR}")
225 changes: 182 additions & 43 deletions examples/matmul/run.cpp
@@ -613,6 +613,93 @@ inline KernelCode createMatmulWithTranspose(const char *shaderTemplate, const si
return {unrolledCode, workgroupSize, precision};
}

inline KernelCode createMatmul12(const char *shaderTemplate, const size_t M,
const size_t K, const size_t N,
const size_t TM, const size_t TN,
const size_t LID,
const Shape &workgroupSize = {256, 1, 1},
NumType precision = kf32) {
std::string codeString(shaderTemplate);
replaceAll(codeString, {{"{{precision}}", toString(precision)},
{"{{M}}", toString(M)},
{"{{K}}", toString(K)},
{"{{N}}", toString(N)},
{"{{TM}}", toString(TM)},
{"{{TN}}", toString(TN)},
{"{{LID}}", toString(LID)}
});
return {loopUnrolling(codeString), workgroupSize, precision};
}
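
// Example instantiation (illustrative sizes; the TM/TN/LID values match the
// version-12 path in selectMatmul below):
//
//   KernelCode code = createMatmul12(kShaderSubgroupMatrixMultiply,
//                                    /*M=*/4096, /*K=*/4096, /*N=*/4096,
//                                    /*TM=*/4, /*TN=*/8, /*LID=*/2,
//                                    /*wgSize=*/{32, 2, 1}, kf16);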

// ─────────────────────────────────────────────────────────────────────────────
// Optimised WGSL matrix‑multiply kernel using subgroupMatrixLoad/Store
// and subgroupMatrixMultiplyAccumulate
// ─────────────────────────────────────────────────────────────────────────────
const char* kShaderSubgroupMatrixMultiply = R"(
enable chromium_experimental_subgroup_matrix;
diagnostic (off, chromium.subgroup_matrix_uniformity);

@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;

@compute @workgroup_size({{workgroupSize}})
fn main(@builtin(workgroup_id) wg: vec3<u32>,
@builtin(local_invocation_id) localID : vec3<u32>) {
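  // Tiling scheme: workgroup (wg.x, wg.y) covers an (8*{{TM}}) x (8*{{TN}}*{{LID}})
  // tile of C; each subgroup, selected by localID.y, owns an (8*{{TM}}) x (8*{{TN}}) slab.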

let rowStart: u32 = wg.x * 8u * {{TM}};
let colStart: u32 = (wg.y * {{LID}} + localID.y) * 8u * {{TN}};

let baseA: u32 = rowStart * {{K}};
let baseB: u32 = colStart;
let cBase: u32 = rowStart * {{N}} + colStart;

var Ax: array<subgroup_matrix_left<{{precision}}, 8, 8>, {{TM}}>;
var Bx: array<subgroup_matrix_right<{{precision}}, 8, 8>, {{TN}}>;

// {{TM}} x {{TN}} accumulator tiles (8x8 each)
var accxx: array<subgroup_matrix_result<{{precision}}, 8, 8>, {{TM}} * {{TN}}>;

for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
Ax[idx_i] = subgroup_matrix_left<{{precision}}, 8, 8>(0);
}

for (var idx_i: u32 = 0; idx_i < {{TN}}; idx_i++) {
Bx[idx_i] = subgroup_matrix_right<{{precision}}, 8, 8>(0);
}

for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
for (var idx_j: u32 = 0; idx_j < {{TN}}; idx_j++) {
accxx[idx_i+idx_j*{{TM}}] = subgroup_matrix_result<{{precision}}, 8, 8>(0);
}
}
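
// March over K in 8-wide steps: each iteration loads {{TM}} tiles of A and
// {{TN}} tiles of B, then issues {{TM}}*{{TN}} 8x8 multiply-accumulates.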

for (var k: u32 = 0u; k < {{K}}; k = k + 8u) {
workgroupBarrier();
for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
Ax[idx_i] = subgroupMatrixLoad<subgroup_matrix_left<{{precision}},8,8>>(&A, baseA + k + 8u * {{K}} * idx_i, false, {{K}});
}

for (var idx_i: u32 = 0; idx_i < {{TN}}; idx_i++) {
Bx[idx_i] = subgroupMatrixLoad<subgroup_matrix_right<{{precision}},8,8>>(&B, baseB + k * {{N}} + 8u * idx_i, false, {{N}});
}

for (var idx_j: u32 = 0; idx_j < {{TN}}; idx_j++) {
for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
accxx[idx_j*{{TM}} + idx_i] = subgroupMatrixMultiplyAccumulate(Ax[idx_i], Bx[idx_j], accxx[idx_j*{{TM}} + idx_i]);
}
}
}

workgroupBarrier();
for (var idx_i: u32 = 0; idx_i < {{TM}}; idx_i++) {
for (var idx_j: u32 = 0; idx_j < {{TN}}; idx_j++) {
subgroupMatrixStore(&C, cBase + idx_i * 8u * {{N}} + 8u * idx_j, accxx[idx_j*{{TM}} + idx_i], false, {{N}});
}
}
}
)";

/**
* @brief No-Op shader with matmul bindings for performance testing
*/
@@ -683,26 +770,30 @@ Kernel selectMatmul(Context &ctx, int version,
const Bindings</* input, weights, output */ 3> &bindings,
size_t M, size_t K, size_t N, NumType numtype) {
Kernel kernel;
CompilationInfo info;
if (version == 1) {
Shape wgSize = {256, 1, 1};
Shape nWorkgroups = cdiv({M, N, 1}, {16, 16, 1});
KernelCode matmul = createNoOp(kShaderNoOp, /*wgsize*/ wgSize);
kernel = createKernel(ctx, matmul, bindings,
/*nWorkgroups*/ nWorkgroups);
/*nWorkgroups*/ nWorkgroups,
NoParam{}, &info);
} else if (version == 2) {
Shape wgSize = {16, 16, 1};
LOG(kDefLog, kInfo, "wgSize: %s", toString(wgSize).c_str());
KernelCode matmul =
createMatmul1(kShaderMatmul1, M, K, N, /*wgsize*/ wgSize, numtype);
kernel = createKernel(ctx, matmul, bindings,
/*nWorkgroups*/ cdiv({M, N, 1}, wgSize));
/*nWorkgroups*/ cdiv({M, N, 1}, wgSize),
NoParam{}, &info);
} else if (version == 3) {
static constexpr size_t tileSize = 16;
KernelCode matmul = createMatmul2(kShaderMatmul2, M, K, N,
/*wgSize*/ {tileSize * tileSize, 1, 1}, numtype);
kernel =
createKernel(ctx, matmul, bindings,
/* nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1}));
/* nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1}),
NoParam{}, &info);
} else if (version == 4 || version == 6) {
static constexpr size_t BM = 64;
static constexpr size_t BK = 4;
@@ -721,7 +812,8 @@ Kernel selectMatmul(Context &ctx, int version,
numtype,
/*Loop unrolling*/ version == 6 ? true: false);
kernel = createKernel(ctx, matmul, bindings,
/*nWorkgroups*/ nWorkgroups);
/*nWorkgroups*/ nWorkgroups,
NoParam{}, &info);
} else if (version == 5 || version == 7) {
static constexpr size_t BM = 64;
static constexpr size_t BK = 8;
@@ -739,7 +831,8 @@ Kernel selectMatmul(Context &ctx, int version,
numtype,
/*Loop unrolling*/ version == 7 ? true: false);
kernel = createKernel(ctx, matmul, bindings,
/*nWorkgroups*/ nWorkgroups);
/*nWorkgroups*/ nWorkgroups,
NoParam{}, &info);
} else if (version == 8 || version == 10) {
static constexpr size_t BM = 64;
static constexpr size_t BK = 8;
@@ -757,7 +850,8 @@ Kernel selectMatmul(Context &ctx, int version,
numtype,
/*Loop unrolling*/ true);
kernel = createKernel(ctx, matmul, bindings,
/*nWorkgroups*/ nWorkgroups);
/*nWorkgroups*/ nWorkgroups,
NoParam{}, &info);
} else if (version == 9 || version == 11) {
static constexpr size_t BM = 64;
static constexpr size_t BK = 8;
@@ -774,8 +868,38 @@ Kernel selectMatmul(Context &ctx, int version,
/*wgSize*/ wgSize,
numtype);
kernel = createKernel(ctx, matmul, bindings,
/*nWorkgroups*/ nWorkgroups);
/*nWorkgroups*/ nWorkgroups,
NoParam{}, &info);
} else if (version == 12 || version == 13) {
// Subgroup matrix multiply (version 12: f16, version 13: f32)
static constexpr size_t TM = 4;
static constexpr size_t TN = 8;
static constexpr size_t LID = 2;
Shape wgSize = {32, LID, 1}; // 32 lanes per subgroup, LID subgroups per workgroup
Shape nWorkgroups = {cdiv(M, 8 * TM), cdiv(N, 8 * TN * LID), 1};
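// Tile math: each workgroup covers 8*TM = 32 rows and 8*TN*LID = 128 columns
// of C, hence the ceil-divisions in nWorkgroups above.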
LOG(kDefLog, kInfo, "M: %zu, K: %zu, N: %zu", M, K, N);
LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str());
LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
KernelCode matmul = createMatmul12(kShaderSubgroupMatrixMultiply, M, K, N, TM, TN, LID, wgSize, numtype);
kernel = createKernel(ctx, matmul, bindings, nWorkgroups,
NoParam{}, &info);
}

if (info.status != WGPUCompilationInfoRequestStatus_Success) {
LOG(kDefLog, kError, "Failed to compile shader");
for (size_t i = 0; i < info.messages.size(); i++) {
LOG(kDefLog, kError, "Line %llu, Pos %llu: %s", info.lineNums[i],
info.linePos[i], info.messages[i].c_str());
}
exit(1);
} else {
LOG(kDefLog, kInfo, "Shader compiled successfully");
for (size_t i = 0; i < info.messages.size(); i++) {
LOG(kDefLog, kInfo, "Line %llu, Pos %llu: %s", info.lineNums[i],
info.linePos[i], info.messages[i].c_str());
}
}

return kernel;
}
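
// Usage sketch (hedged; mirrors the dispatch path in runTest below and
// gpu.hpp's promise/future API):
//
//   Kernel kernel = selectMatmul(ctx, /*version=*/12, bindings, M, K, N, kf16);
//   std::promise<void> promise;
//   std::future<void> future = promise.get_future();
//   dispatchKernel(ctx, kernel, promise);
//   wait(ctx, future);
//   toCPU(ctx, output, outputPtr.get(), M * N * sizeof(uint16_t)); // f16 output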

Expand All @@ -791,41 +915,51 @@ void runTest(int version, size_t M, size_t K, size_t N,
assert(numtype == kf16);
}

-  // Allocate GPU buffers and copy data
-  WGPUDeviceDescriptor devDescriptor = {};
-  devDescriptor.requiredFeatureCount = 1;
-  devDescriptor.requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data();
-
-  Context ctx;
-  if (numtype == kf16) {
-    ctx = createContext(
-        {}, {},
-        /*device descriptor, enabling f16 in WGSL*/
-        {
-            .requiredFeatureCount = 1,
-            .requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data()
-        });
-    if (ctx.adapterStatus != WGPURequestAdapterStatus_Success) {
-      LOG(kDefLog, kError, "Failed to create adapter with f16 support, try running an f32 test instead (`export MATMUL_VERSION=9`).");
-      exit(1);
-    }
-    if (ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
-      LOG(kDefLog, kError, "Failed to create device with f16 support, try running an f32 test instead (`export MATMUL_VERSION=9`).");
-      exit(1);
-    }
-  }
-
-  if (numtype == kf32) {
-    ctx = createContext({}, {}, {});
-    if (ctx.adapterStatus != WGPURequestAdapterStatus_Success ||
-        ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
-      LOG(kDefLog, kError, "Failed to create adapter or device");
-      // stop execution
-      exit(1);
-    } else {
-      LOG(kDefLog, kInfo, "Successfully created adapter and device");
-    }
-  }
+  static WGPUDawnTogglesDescriptor toggles = {};
+  toggles.chain.sType = WGPUSType_DawnTogglesDescriptor;
+  const char* enableList[] = {"allow_unsafe_apis"};
+  toggles.enabledToggles = enableList;
+  toggles.enabledToggleCount = 1;
+
+  static WGPUDeviceDescriptor devDesc = {};
+  devDesc.nextInChain = &toggles.chain;
+  // static so the feature array outlives this statement; requiredFeatures
+  // only stores a pointer.
+  static constexpr std::array features{
+      WGPUFeatureName_ShaderF16,
+      WGPUFeatureName_Subgroups,
+      WGPUFeatureName_ChromiumExperimentalSubgroupMatrix
+  };
+  devDesc.requiredFeatureCount = features.size();
+  devDesc.requiredFeatures = features.data();
+  devDesc.uncapturedErrorCallbackInfo = WGPUUncapturedErrorCallbackInfo {
+      .callback = [](WGPUDevice const * device, WGPUErrorType type, WGPUStringView msg, void*, void*) {
+        LOG(kDefLog, kError, "[Uncaptured %d] %.*s\n", (int)type, (int)msg.length, msg.data);
+      }
+  };
+  devDesc.deviceLostCallbackInfo = WGPUDeviceLostCallbackInfo {
+      .mode = WGPUCallbackMode_AllowSpontaneous,
+      .callback = [](WGPUDevice const * device, WGPUDeviceLostReason reason, WGPUStringView msg, void*, void*) {
+        LOG(kDefLog, kError, "[DeviceLost %d] %.*s\n", (int)reason, (int)msg.length, msg.data);
+      }
+  };
+
+  static WGPULimits requiredLimits = WGPU_LIMITS_INIT;
+  devDesc.requiredLimits = &requiredLimits;
+  Context ctx = createContext({}, {}, devDesc);
+
+  WGPULoggingCallbackInfo logCb{
+      .callback = [](WGPULoggingType type, WGPUStringView msg, void*, void*) {
+        LOG(kDefLog, kError, "[WGPU %d] %.*s\n", (int)type, (int)msg.length, msg.data);
+      }
+  };
+  wgpuDeviceSetLoggingCallback(ctx.device, logCb);
+
+  if (ctx.adapterStatus != WGPURequestAdapterStatus_Success ||
+      ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
+    LOG(kDefLog, kError, "Failed to create adapter or device");
+    // stop execution
+    exit(1);
+  } else {
+    LOG(kDefLog, kInfo, "Successfully created adapter and device");
+  }
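
// Optional diagnostic (a hedged sketch; assumes Context exposes the adapter as
// ctx.adapter, as gpu.hpp's Context does): distinguish "experimental feature
// missing on this adapter" from other device-creation failures.
//
//   if (!wgpuAdapterHasFeature(ctx.adapter,
//                              WGPUFeatureName_ChromiumExperimentalSubgroupMatrix)) {
//     LOG(kDefLog, kError,
//         "Adapter lacks subgroup-matrix support; try MATMUL_VERSION<=11.");
//   }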

Tensor input = createTensor(ctx, Shape{M, K}, numtype, inputPtr.get());
Tensor weights = createTensor(ctx, Shape{N, K}, numtype, weightsPtr.get()); // column-major
@@ -859,7 +993,7 @@ void runTest(int version, size_t M, size_t K, size_t N,
// Use microsecond for more accurate time measurement
auto duration =
std::chrono::duration_cast<std::chrono::microseconds>(end - start);
float gflops = 2 * M * N *
float gflops = 2.0f * M * N *
K / // factor of 2 for multiplication & accumulation
(static_cast<double>(duration.count()) / 1000000.0) /
1000000000.0 * static_cast<float>(nIter);
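// Sanity check on the arithmetic: at M = K = N = 4096, one dispatch is
// 2 * 4096^3 ≈ 1.37e11 FLOPs, so 10 ms per dispatch works out to ~13,700 GFLOPS.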
@@ -870,7 +1004,7 @@ void runTest(int version, size_t M, size_t K, size_t N,
show<precision>(outputPtr.get(), M, N, "Output[0]").c_str());

LOG(kDefLog, kInfo, "\n\n===================================================================="
"============\nExecution Time: (M = %d, K = %d, N = %d) x %d iterations "
"============\nExecution Time: (M = %zu, K = %zu, N = %zu) x %zu iterations "
":\n%.1f "
"milliseconds / dispatch ~ %.2f "
"GFLOPS\n================================================================"
@@ -913,13 +1047,16 @@ const std::string versionToStr(int version){
case 9: return "f32: 2D blocktiling with loop unrolling, vectorization and transpose";
case 10: return "f16: 2D blocktiling with loop unrolling and vectorization";
case 11: return "f16: 2D blocktiling with loop unrolling, vectorization and transpose";
case 12: return "f16: Subgroup matrix multiply with transpose (default)";
case 13: return "f32: Subgroup matrix multiply with transpose";
default: return "Not specified";
}
}

int main() {
std::cout << "Starting matmul test..." << std::endl;
char* version_str = getenv("MATMUL_VERSION");
int version = version_str == NULL ? 10 : atoi(version_str);
int version = version_str == NULL ? 12 : atoi(version_str);
// 1 == f32: No-Op
// 2 == f32: naive matmul
// 3 == f32: tiling
@@ -931,8 +1068,10 @@ int main() {
// 9 == f32: 2D blocktiling with loop unrolling, vectorization and transpose
// 10 == f16: 2D blocktiling with loop unrolling and vectorization
// 11 == f16: 2D blocktiling with loop unrolling, vectorization and transpose
bool enableF16 = version == 10 || version ==11;
bool transposedInput = version == 9 || version == 11;
// 12 == f16: Subgroup matrix multiply with transpose (default)
// 13 == f32: Subgroup matrix multiply with transpose
bool enableF16 = version == 10 || version == 11 || version == 12;
bool transposedInput = version == 9 || version == 11 || version == 12 || version == 13;
NumType numtype = enableF16 ? kf16 : kf32;
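
// Example (binary path illustrative): MATMUL_VERSION=13 ./build/matmul runs
// the f32 subgroup-matrix variant.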

size_t M, K, N; // Matrix dimensions
6 changes: 2 additions & 4 deletions gpu.hpp
@@ -1580,8 +1580,7 @@ inline void bufferMapCallback(WGPUMapAsyncStatus status, WGPUStringView message,
* and a promise to signal completion.
* @param userdata2 Unused.
*/
inline void queueWorkDoneCallback(WGPUQueueWorkDoneStatus status,
WGPUStringView message,
inline void queueWorkDoneCallback(WGPUQueueWorkDoneStatus status, WGPUStringView message,
void *userdata1, void * /*userdata2*/) {
const CallbackData *cbData = static_cast<CallbackData *>(userdata1);
// Ensure the queue work finished successfully.
@@ -2824,8 +2823,7 @@ Kernel createKernel(Context &ctx, const KernelCode &code,
* when the work is done.
* @param userdata2 Unused.
*/
inline void dispatchKernelCallback(WGPUQueueWorkDoneStatus status,
WGPUStringView message,
inline void dispatchKernelCallback(WGPUQueueWorkDoneStatus status, WGPUStringView message,
void *userdata1, void * /*userdata2*/) {
// Cast the userdata pointer back to our heap‑allocated promise.
auto *p = reinterpret_cast<std::promise<void> *>(userdata1);