Skip to content

Commit 06ec197

Browse files
committed
merg
1 parent 02780ea commit 06ec197

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

kernels/test/matmul.cu

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -265,22 +265,24 @@ int run_benchmark(size_t B, size_t M, size_t N, size_t K) {
265265
std::cout << "Converted result back to float" << std::endl;
266266

267267
// Check result
268-
float max_error = 0.0f;
268+
float max_abs_error = 0.0f;
269+
float max_rel_error = 0.0f;
269270
int error_count = 0;
270271
for (int i = 0; i < B * M * N; ++i) {
271272
float abs_error = std::abs(h_C[i] - h_C_ref[i]);
272273
float rel_error = std::abs(abs_error / h_C_ref[i]);
273-
if(error > .01 && abs_error > 1.0) { // large because of bf16 vs fp32 numerics
274+
if(rel_error > .01 && abs_error > 0.1) { // large because of bf16 vs fp32 numerics
274275
int b = i / (M * N), row = i % (M * N) / N, col = i % N;
275276
if(error_count < 20) std::cout << "Error at batch " << b << " row " << row << " col " << col << ": " << h_C[i] << " != " << h_C_ref[i] << " (ref)" << std::endl;
276277
else if(error_count == 21) std::cout << "Too many errors to show them all.\n";
277278
error_count++;
278279
}
279-
max_error = std::max(max_error, error);
280+
max_abs_error = std::max(max_abs_error, abs_error);
281+
max_rel_error = std::max(max_rel_error, rel_error);
280282
}
281283

282284
std::cout << "Total elements: " << M*N << std::endl;
283-
std::cout << "Max error: " << max_error << std::endl;
285+
std::cout << "Max relative error: " << max_rel_error * 100 << "%, max absolute error: " << max_abs_error <<std::endl;
284286
std::cout << "Error count: " << error_count << std::endl;
285287

286288
// Clean up

0 commit comments

Comments
 (0)