Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
#!/usr/bin/env bash
# WSE-3 support contributed by Integrated Reasoning, Inc.
# https://www.integrated-reasoning.com
# david@integrated-reasoning.com

set -e

cslc ./src/layout.csl --arch wse2 --fabric-dims=11,6 --fabric-offsets=4,1 \
cslc ./src/layout.csl --arch wse3 --fabric-dims=11,6 --fabric-offsets=4,1 \
--params=ncols:16,nrows:16,pcols:4,prows:4,max_local_nnz:8 \
--params=max_local_nnz_cols:4,max_local_nnz_rows:4,local_vec_sz:1 \
--params=local_out_vec_sz:1,y_pad_start_row_idx:4 -o=out \
Expand Down
33 changes: 17 additions & 16 deletions benchmarks/spmv-hypersparse/src/hypersparse_spmv/pe.csl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
param f_callback : fn ()void;

param input_queues:[4]u16;
param output_queues:[2]u16;
param output_queues:[4]u16;

// explicit DSR allocation
param dest_dsr_ids:[6]u16;
Expand Down Expand Up @@ -176,16 +176,16 @@ var tsc_reduce_end_buffer = @zeros([timestamp.tsc_size_words]u16);
// var TSC_VALUE_TO_WAIT_UNTIL = [3]u16 { 0x9c40, 0x0, 0x0 }; // 40K cycles
var TSC_VALUE_TO_WAIT_UNTIL = [3]u16 { 0x3e8, 0x0, 0x0 }; // 1K cycles

// WARNING: reserve input/output queue 0 for memcpy module
// WARNING: input/output queues must avoid reserved queues.
// uthreads for fabric data movement
const RX_NORTH_Q: u16 = input_queues[0];
const RX_SOUTH_Q: u16 = input_queues[1];
const TX_NORTH_Q: u16 = output_queues[0];
const TX_SOUTH_Q: u16 = output_queues[1];
// reduction trains, corresponding rx and tx are not active simultaneously
// NOTE: the two phases are exclusive, so uthreads can actually be reused from north-south
const TX_WEST_Q: u16 = output_queues[0];
const TX_EAST_Q: u16 = output_queues[1];
const TX_WEST_Q: u16 = output_queues[2];
const TX_EAST_Q: u16 = output_queues[3];
const RX_WEST_Q: u16 = input_queues[2];
const RX_EAST_Q: u16 = input_queues[3];

Expand Down Expand Up @@ -292,49 +292,41 @@ const rx_south_dsd = @get_dsd(fabin_dsd, .{
});
const tx_north_dsd = @get_dsd(fabout_dsd, .{
.extent = local_vec_sz, // fp32 => 1 per wavelet
.fabric_color = north_train,
.output_queue = @get_output_queue(TX_NORTH_Q),
});
const tx_south_dsd = @get_dsd(fabout_dsd, .{
.extent = local_vec_sz,
.fabric_color = south_train,
.output_queue = @get_output_queue(TX_SOUTH_Q),
});
const tx_north_ctrl_adv_dsd = @get_dsd(fabout_dsd, .{
.extent = 2, // two switch wavelets
.control = true,
.fabric_color = north_train,
.output_queue = @get_output_queue(TX_NORTH_Q),
});
const tx_south_ctrl_adv_dsd = @get_dsd(fabout_dsd, .{
.extent = 2, // two switch wavelets
.control = true,
.fabric_color = south_train,
.output_queue = @get_output_queue(TX_SOUTH_Q),
});
const tx_north_ctrl_rst_dsd = @get_dsd(fabout_dsd, .{
.extent = 1, // two switch wavelets
.control = true,
.fabric_color = north_train,
.output_queue = @get_output_queue(TX_NORTH_Q),
});
const tx_south_ctrl_rst_dsd = @get_dsd(fabout_dsd, .{
.extent = 1, // two switch wavelets
.control = true,
.fabric_color = south_train,
.output_queue = @get_output_queue(TX_SOUTH_Q),
});

// 2. reduce phase: west and east trains for partial output vectors (sparse: vals + rows)
const tx_west_dsd = @get_dsd(fabout_dsd, .{
.extent = 1,
.fabric_color = tx_west_train,
.output_queue = @get_output_queue(TX_WEST_Q),
});

const tx_east_dsd = @get_dsd(fabout_dsd, .{
.extent = 1,
.fabric_color = tx_east_train,
.output_queue = @get_output_queue(TX_EAST_Q),
});

Expand Down Expand Up @@ -1888,16 +1880,25 @@ comptime {
// the compiler no longer can generate the instruction to set up the
// config register of input queue.
comptime {
// color south_train maps to RX_NORTH_Q: u16 = 4;
// color north_train maps to RX_SOUTH_Q: u16 = 1;
// color rx_east_train maps to RX_WEST_Q: u16 = 6;
// color rx_west_train maps to RX_EAST_Q: u16 = 7;
// color south_train maps to RX_NORTH_Q: u16 = 2;
// color north_train maps to RX_SOUTH_Q: u16 = 3;
// color rx_east_train maps to RX_WEST_Q: u16 = 4;
// color rx_west_train maps to RX_EAST_Q: u16 = 5;
@initialize_queue(@get_input_queue(RX_NORTH_Q), .{.color = south_train});
@initialize_queue(@get_input_queue(RX_SOUTH_Q), .{.color = north_train});
@initialize_queue(@get_input_queue(RX_WEST_Q), .{.color = rx_east_train});
@initialize_queue(@get_input_queue(RX_EAST_Q), .{.color = rx_west_train});
}

comptime {
if (@is_arch("wse3")) {
@initialize_queue(@get_output_queue(TX_NORTH_Q), .{.color = north_train});
@initialize_queue(@get_output_queue(TX_SOUTH_Q), .{.color = south_train});
@initialize_queue(@get_output_queue(TX_WEST_Q), .{.color = tx_west_train});
@initialize_queue(@get_output_queue(TX_EAST_Q), .{.color = tx_east_train});
}
}

comptime {

const north_train_route = .{
Expand Down
64 changes: 37 additions & 27 deletions benchmarks/spmv-hypersparse/src/kernel.csl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ param memcpyParams;

param spmvParams;

param reduceParams;

// parameters
param nrows: u32; // total number of matrix rows
param ncols: u32; // total number of matrix cols (= nrows)
Expand Down Expand Up @@ -51,14 +49,11 @@ var local_nnz_rows = @zeros([1]u16); // actual local number of nnz rows
// final reduced local output vector (dense)
var y_local_buf = @zeros([local_out_vec_sz]f32);

// temporary buffer for allreduce
var dot = @zeros([1]f32);

const timestamp = @import_module("<time>");

const sys_mod = @import_module( "<memcpy/memcpy>", memcpyParams);

// input_queues cannot overlap with output_queues
// input_queues/output_queues must avoid memcpy-reserved queues.
const spmv_mod = @import_module( "hypersparse_spmv/pe.csl", .{
.spmv_params = spmvParams,
.f_callback = sys_mod.unblock_cmd_stream,
Expand All @@ -82,33 +77,22 @@ const spmv_mod = @import_module( "hypersparse_spmv/pe.csl", .{
.local_nnz_cols = &local_nnz_cols,
.local_nnz_rows = &local_nnz_rows,

.input_queues=[4]u16{4, 1, 6, 7},
.output_queues=[2]u16{2,3},
.input_queues=[4]u16{2, 3, 4, 5},
.output_queues=[4]u16{4, 5, 2, 3},
.dest_dsr_ids = [6]u16{1, 4, 5, 6, 2, 3},
.src1_dsr_ids = [6]u16{4, 1, 6, 7, 2, 3},
});

// allreduce uses input queue/output queue 5
// dest_dsr and src0_dsr must be a valid pair, for example (7,1) is invalid
const reduce_mod = @import_module( "allreduce2R1E/pe.csl", .{
.reduce_params = reduceParams,
.f_callback = sys_mod.unblock_cmd_stream,
.MAX_ZDIM = 1,
.queues = [1]u16{5},
.dest_dsr_ids = [1]u16{7},
.src0_dsr_ids = [1]u16{7},
.src1_dsr_ids = [1]u16{5}
});

// tsc library
var tsc_start_buffer = @zeros([timestamp.tsc_size_words]u16);
var tsc_end_buffer = @zeros([timestamp.tsc_size_words]u16);
var tsc_wait_until = [timestamp.tsc_size_words]u16{ 0x3e8, 0x0, 0x0 };

// time_buf_u16[0:5] = {tsc_start_buffer, tsc_end_buffer}
var time_buf_u16 = @zeros([timestamp.tsc_size_words*2]u16);
var ptr_time_buf_u16: [*]u16 = &time_buf_u16;

// reference clock inside allreduce module
// reference clock for host timing alignment
var time_ref_u16 = @zeros([timestamp.tsc_size_words]u16);
var ptr_time_ref_u16: [*]u16 = &time_ref_u16;

Expand Down Expand Up @@ -171,15 +155,41 @@ fn f_memcpy_timestamps() void {
sys_mod.unblock_cmd_stream();
}

fn f_sync( n: i16 ) void {
reduce_mod.allreduce(n, &dot);
fn is_less_than(aval: *[timestamp.tsc_size_words]u16, bval: *[timestamp.tsc_size_words]u16) bool {
if ((aval.*)[2] < (bval.*)[2]) {
return true;
} else if ((aval.*)[2] == (bval.*)[2]) {
if ((aval.*)[1] < (bval.*)[1]) {
return true;
} else if ((aval.*)[1] == (bval.*)[1]) {
if ((aval.*)[0] < (bval.*)[0]) {
return true;
}
}
}
return false;
}

fn f_reference_timestamps() void {
fn f_sync(n: i16) void {
if (n == -1) {
time_ref_u16[0] = time_ref_u16[0];
}
var curr = @zeros([timestamp.tsc_size_words]u16);
while (true) {
timestamp.get_timestamp(&curr);
if (!is_less_than(&curr, &tsc_wait_until)) {
break;
}
}
time_ref_u16[0] = curr[0];
time_ref_u16[1] = curr[1];
time_ref_u16[2] = curr[2];

time_ref_u16[0] = reduce_mod.tscRefBuffer[0];
time_ref_u16[1] = reduce_mod.tscRefBuffer[1];
time_ref_u16[2] = reduce_mod.tscRefBuffer[2];
sys_mod.unblock_cmd_stream();
}

fn f_reference_timestamps() void {
timestamp.get_timestamp(&time_ref_u16);

// the user must unblock cmd color for every PE
sys_mod.unblock_cmd_stream();
Expand Down
22 changes: 3 additions & 19 deletions benchmarks/spmv-hypersparse/src/layout.csl
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
// 4 c3 14 tx_south 24 compute_local 34 reserved (memcpy)
// 5 c4 15 rx_north 25 curr_rx_north_done 35 reserved (memcpy)
// 6 c5 16 rx_south 26 curr_rx_south_done 36 reserved (memcpy)
// 7 allreduce_c0 17 rx_east 27 reserved (memcpy) 37 reserved (memcpy)
// 8 allreduce_c1 18 rx_west 28 reserved (memcpy)
// 9 allreduce_EN1 19 tx_west 29 reserved (memcpy)
// 7 17 rx_east 27 reserved (memcpy) 37 reserved (memcpy)
// 8 18 rx_west 28 reserved (memcpy)
// 9 19 tx_west 29 reserved (memcpy)

// routable colors for spmv
param c0 = @get_color(1);
Expand All @@ -33,12 +33,6 @@ param c3 = @get_color(4);
param c4 = @get_color(5);
param c5 = @get_color(6);

// routable colors for allreduce
param allreduce_c0 = @get_color(7);
param allreduce_c1 = @get_color(8);
// entrypoint for allreduce
param allreduce_EN1: local_task_id = @get_local_task_id(9);

// entrypoints for spmv
param EN1: local_task_id = @get_local_task_id(10);
param EN2: local_task_id = @get_local_task_id(11);
Expand Down Expand Up @@ -82,14 +76,6 @@ const spmv = @import_module( "hypersparse_spmv/layout.csl", .{
.width = pcols,
.height = prows
});

const reduce = @import_module( "allreduce2R1E/layout.csl", .{
.colors = [2]color{allreduce_c0, allreduce_c1},
.entrypoints = [1]local_task_id{allreduce_EN1},
.width = pcols,
.height = prows
});

const memcpy = @import_module( "<memcpy/get_params>", .{
.width = pcols,
.height = prows,
Expand All @@ -115,11 +101,9 @@ layout {

const memcpyParams = memcpy.get_params(pcol_id);
const spmvParams = spmv.get_params(pcol_id, prow_id);
const reduceParams = reduce.get_params(pcol_id, prow_id);
var params = .{
.memcpyParams = memcpyParams,
.spmvParams = spmvParams,
.reduceParams = reduceParams,
.nrows = nrows,
.ncols = ncols,
.local_vec_sz = local_vec_sz,
Expand Down