Skip to content

Commit

Permalink
fix-kul-cluster-sw-lib (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaoling-yi authored Aug 25, 2024
1 parent 17e5546 commit 118182c
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,75 +35,79 @@ int main() {

// Transfer data from L3 to L1
// Using DMA only
if (snrt_is_dm_core()) {
load_conv_input_data(Nbatch, H + 2 * pad_h, W + 2 * pad_w, Cin, local_a,
A);
load_weight_data(Cout, Kh, Kw, Cin, local_b, B);
}
if(snrt_cluster_idx() == 0){

// Wait for DMA to finish
snrt_cluster_hw_barrier();
if (snrt_is_dm_core()) {
load_conv_input_data(Nbatch, H + 2 * pad_h, W + 2 * pad_w, Cin, local_a,
A);
load_weight_data(Cout, Kh, Kw, Cin, local_b, B);
}

if (snrt_is_dm_core()) {
snrt_dma_start_1d(local_c, C,
M * N * meshRow * meshCol * sizeof(int32_t));
}
// Wait for DMA to finish
snrt_cluster_hw_barrier();

snrt_cluster_hw_barrier();
if (snrt_is_dm_core()) {
snrt_dma_start_1d(local_c, C,
M * N * meshRow * meshCol * sizeof(int32_t));
}

if (snrt_global_core_idx() == 0) {
// Set Streamer configuration CSR for conv2d
set_gemmx_streamer_csr(
Aslstride0, Aslstride1, Atlbound0, Atlstride0, Atlbound1,
Atlstride1, Atlbound2, Atlstride2, Atlbound3, Atlstride3, Atlbound4,
Atlstride4, Atlbound5, Atlstride5,
snrt_cluster_hw_barrier();

Bslstride0, Bslstride1, Btlbound0, Btlstride0, Btlbound1,
Btlstride1, Btlbound2, Btlstride2,
if (snrt_global_core_idx() == 0) {
// Set Streamer configuration CSR for conv2d
set_gemmx_streamer_csr(
Aslstride0, Aslstride1, Atlbound0, Atlstride0, Atlbound1,
Atlstride1, Atlbound2, Atlstride2, Atlbound3, Atlstride3, Atlbound4,
Atlstride4, Atlbound5, Atlstride5,

D8slstride0, D8slstride1, D8tlbound0, D8tlstride0, D8tlbound1,
D8tlstride1, D8tlbound2, D8tlstride2,
Bslstride0, Bslstride1, Btlbound0, Btlstride0, Btlbound1,
Btlstride1, Btlbound2, Btlstride2,

Cslstride0, Cslstride1, Ctlbound0, Ctlstride0, Ctlbound1,
Ctlstride1, Ctlbound2, Ctlstride2,
D8slstride0, D8slstride1, D8tlbound0, D8tlstride0, D8tlbound1,
D8tlstride1, D8tlbound2, D8tlstride2,

D32slstride0, D32slstride1, D32tlbound0, D32tlstride0, D32tlbound1,
D32tlstride1, D32tlbound2, D32tlstride2,
Cslstride0, Cslstride1, Ctlbound0, Ctlstride0, Ctlbound1,
Ctlstride1, Ctlbound2, Ctlstride2,

delta_local_a, delta_local_b, delta_local_d8, delta_local_c,
delta_local_d32, bypassSIMD);
D32slstride0, D32slstride1, D32tlbound0, D32tlstride0, D32tlbound1,
D32tlstride1, D32tlbound2, D32tlstride2,

// Set CSR to start Streamer for conv2d
set_gemmx_streamer_start();
delta_local_a, delta_local_b, delta_local_d8, delta_local_c,
delta_local_d32, bypassSIMD);

// Set GEMM configuration CSR
uint32_t subtraction_setting =
gen_subtraction_config(subtraction_a, subtraction_b);
// Set CSR to start Streamer for conv2d
set_gemmx_streamer_start();

uint32_t csr0 =
gen_csr0_config(input_zp_i, output_zp_i, shift_i, max_int_i);
uint32_t csr1 = gen_csr1_config(min_int_i, double_round_i);
uint32_t csr2 = gen_csr2_config(multiplier_i);
// Set GEMM configuration CSR
uint32_t subtraction_setting =
gen_subtraction_config(subtraction_a, subtraction_b);

set_gemmx_csr(K, N, M, subtraction_setting, csr0, csr1, csr2, M * N,
bypassSIMD);
uint32_t csr0 =
gen_csr0_config(input_zp_i, output_zp_i, shift_i, max_int_i);
uint32_t csr1 = gen_csr1_config(min_int_i, double_round_i);
uint32_t csr2 = gen_csr2_config(multiplier_i);

// Set CSR to start GEMM
set_gemmx_start();
set_gemmx_csr(K, N, M, subtraction_setting, csr0, csr1, csr2, M * N,
bypassSIMD);

// Poll until Streamer and GEMM accelerator finish
wait_gemmx_and_streamer();
// Set CSR to start GEMM
set_gemmx_start();

// Poll until Streamer and GEMM accelerator finish
wait_gemmx_and_streamer();

// check the result of the implicit im2col convolution
if (!bypassSIMD) {
err +=
check_gemmx_result_D8(local_d8, D8_direct_conv2d, Batch, M, N);
} else {
err += check_gemmx_result_D32(local_d32, D32_direct_conv2d, Batch,
M, N);
}
printf("SNAX GEMM Conv2d: %s, err = %d . bypassSIMD = %d .\n",
err ? "FAIL" : "PASS", err, bypassSIMD);
};

// check the result of the implicit im2col convolution
if (!bypassSIMD) {
err +=
check_gemmx_result_D8(local_d8, D8_direct_conv2d, Batch, M, N);
} else {
err += check_gemmx_result_D32(local_d32, D32_direct_conv2d, Batch,
M, N);
}
printf("SNAX GEMM Conv2d: %s, err = %d . bypassSIMD = %d .\n",
err ? "FAIL" : "PASS", err, bypassSIMD);
};

return err;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,50 +148,50 @@ void set_gemmx_streamer_csr(
}

// Set CSR to start STREAMER
// Writes the value 1 to CSR 1011 via write_csr().
// NOTE(review): the function is defined twice below. Both bodies target the
// same address (1005 + 6 == 1011), so they differ only in how the constant is
// spelled. This looks like the removed/added pair of a rendered diff; keep
// only one definition in a real translation unit.
void set_gemmx_streamer_start() { write_csr(1005 + 6, 1); }
void set_gemmx_streamer_start() { write_csr(1011, 1); }

// Set GEMM configuration CSR
//
// Programs the GEMM configuration through a sequence of write_csr() calls.
//   tempLoop0..tempLoop2 - loop bounds, innermost first
//   subtractions         - subtraction setting for a and b
//   csr0, csr1, csr2     - constants for the SIMD unit
//   temporal_loop_bound  - temporal loop bound
//   bypassSIMD           - written to the CSR right after temporal_loop_bound
//
// NOTE(review): every group below contains two sets of writes whose addresses
// differ by one (e.g. 1007 + 6 == 1013 versus 1014). This looks like the
// removed and added lines of a diff shown without +/- markers; confirm which
// set is current before relying on any address here.
void set_gemmx_csr(int tempLoop0, int tempLoop1, int tempLoop2,
int subtractions, uint32_t csr0, uint32_t csr1,
uint32_t csr2, uint32_t temporal_loop_bound,
uint32_t bypassSIMD) {
// set loop bounds, from innermost to outermost, aka from K to N to M
// "N + 6" form: addresses 1013, 1014, 1015
write_csr(1007 + 6, tempLoop0);
write_csr(1008 + 6, tempLoop1);
write_csr(1009 + 6, tempLoop2);
// literal form: addresses 1014, 1015, 1016
write_csr(1014, tempLoop0);
write_csr(1015, tempLoop1);
write_csr(1016, tempLoop2);

// set subtraction a and b
// 1010 + 6 == 1016; the literal form uses 1017
write_csr(1010 + 6, subtractions);
write_csr(1017, subtractions);

// set the constants for the SIMD unit
// "N + 6" form: addresses 1017, 1018, 1019
write_csr(1011 + 6, csr0);
write_csr(1012 + 6, csr1);
write_csr(1013 + 6, csr2);
// literal form: addresses 1018, 1019, 1020
write_csr(1018, csr0);
write_csr(1019, csr1);
write_csr(1020, csr2);

// set the temporal loop bound
// "N + 6" form: addresses 1020 and 1021; literal form: 1021 and 1022
write_csr(1014 + 6, temporal_loop_bound);
write_csr(1015 + 6, bypassSIMD);
write_csr(1021, temporal_loop_bound);
write_csr(1022, bypassSIMD);
}

// Set CSR to start GEMM
// Writes the value 1 to the GEMM start CSR via write_csr().
// NOTE(review): the function is defined twice below with different addresses
// (1016 + 6 == 1022 versus 1023). This looks like the removed/added pair of a
// rendered diff; confirm which address is current and keep one definition.
void set_gemmx_start() { write_csr(1016 + 6, 1); }
void set_gemmx_start() { write_csr(1023, 1); }

// Stall until Streamer and GEMM accelerator finish
// Writes 0 to the same CSRs that the start functions write 1 to: the Streamer
// start CSR (1011) twice, then the GEMM start CSR once.
// NOTE(review): how these writes stall the core is not visible here, nor is
// the reason the Streamer CSR is written twice; confirm against the hardware
// documentation.
// NOTE(review): the body holds two groups of three writes. The first group
// resolves to 1011, 1011, 1022 and the second to 1011, 1011, 1023. This looks
// like the removed and added lines of a diff shown without +/- markers.
void wait_gemmx_and_streamer() {
write_csr(1005 + 6, 0);
write_csr(1005 + 6, 0);
write_csr(1016 + 6, 0);
write_csr(1011, 0);
write_csr(1011, 0);
write_csr(1023, 0);
}

// Read performance counter of the Streamer, a read-only CSR
// Returns the value obtained from read_csr().
// NOTE(review): perf_counter is declared twice in the same scope (reading CSR
// 1006, then CSR 1013), which is a redeclaration in plain C. This looks like
// the removed/added pair of a rendered diff; confirm which address is current
// and keep a single declaration.
uint32_t read_gemmx_streamer_perf_counter() {
uint32_t perf_counter = read_csr(1006);
uint32_t perf_counter = read_csr(1013);
return perf_counter;
}

// Read performance counter of GEMM, a read-only CSR
// Returns the value obtained from read_csr().
// NOTE(review): perf_counter is declared twice in the same scope (reading CSR
// 1017, then CSR 1025), which is a redeclaration in plain C. This looks like
// the removed/added pair of a rendered diff; confirm which address is current
// and keep a single declaration.
uint32_t read_gemmx_perf_counter() {
uint32_t perf_counter = read_csr(1017);
uint32_t perf_counter = read_csr(1025);
return perf_counter;
}

Expand Down

0 comments on commit 118182c

Please sign in to comment.