Skip to content

Commit 5515495

Browse files
committed
[OpenCL] Optimized Single-Precision GEMM Kernel
This pull request adds a highly optimized single-precision General Matrix Multiplication (GEMM) kernel developed for OpenCL. The enhancements in this kernel aim to improve computational efficiency and performance for matrix operation, which reduces execution time and enhances throughput. **Self-evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: Donghyeon Jeong <[email protected]>
1 parent e0995a5 commit 5515495

File tree

2 files changed

+198
-71
lines changed

2 files changed

+198
-71
lines changed

nntrainer/tensor/cl_operations/blas_kernel_strings.cpp

Lines changed: 190 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -59,81 +59,211 @@ const std::string &getDotClKernel() {
5959

6060
const std::string &getSgemmClNoTransKernel() {
6161
static const std::string sgemm_cl_noTrans_kernel_ =
62-
R"(__kernel void sgemm_cl_noTrans(const __global float* A, const __global float* B,
63-
__global float* C, unsigned int K, unsigned int lda, unsigned int ldb, unsigned int ldc) {
64-
65-
unsigned int m = get_global_id(0);
66-
unsigned int n = get_global_id(1);
67-
float c = 0.0f;
68-
for (unsigned int k = 0; k < K; ++k) {
69-
float a, b;
70-
a = A[m * lda + k];
71-
b = B[k * ldb + n];
72-
c += a * b;
73-
}
74-
C[m * ldc + n] = c;
75-
})";
62+
R"(
63+
#define TS 16
64+
__kernel void sgemm_cl_noTrans(__global const float *A, __global const float *B,
65+
__global float *C, const int M, const int N,
66+
const int K) {
67+
const int globalRow = get_global_id(1); // M dimension
68+
const int globalCol = get_global_id(0); // N dimension
69+
70+
__local float Asub[TS][TS];
71+
__local float Bsub[TS][TS];
72+
73+
float sum = 0.0f;
74+
75+
const int localRow = get_local_id(1);
76+
const int localCol = get_local_id(0);
77+
const int groupRow = TS * get_group_id(1);
78+
const int groupCol = TS * get_group_id(0);
79+
80+
for (int t = 0; t < (K + TS - 1) / TS; ++t) {
81+
const int tiledRowA = groupRow + localRow;
82+
const int tiledColA = t * TS + localCol;
83+
84+
const int tiledRowB = t * TS + localRow;
85+
const int tiledColB = groupCol + localCol;
86+
87+
// Load A
88+
if (tiledRowA < M && tiledColA < K)
89+
Asub[localRow][localCol] = A[tiledRowA * K + tiledColA];
90+
else
91+
Asub[localRow][localCol] = 0.0f;
92+
93+
// Load B
94+
if (tiledRowB < K && tiledColB < N)
95+
Bsub[localRow][localCol] = B[tiledRowB * N + tiledColB];
96+
else
97+
Bsub[localRow][localCol] = 0.0f;
98+
99+
barrier(CLK_LOCAL_MEM_FENCE);
100+
101+
for (int k = 0; k < TS; ++k)
102+
sum += Asub[localRow][k] * Bsub[k][localCol];
103+
104+
barrier(CLK_LOCAL_MEM_FENCE);
105+
}
106+
107+
if (globalRow < M && globalCol < N)
108+
C[globalRow * N + globalCol] = sum;
109+
}
110+
)";
76111
return sgemm_cl_noTrans_kernel_;
77112
}
78113

79114
const std::string &getSgemmClTransAKernel() {
80115
static const std::string sgemm_cl_transA_kernel_ =
81-
R"(__kernel void sgemm_cl_transA(const __global float* A, const __global float* B,
82-
__global float* C, unsigned int K, unsigned int lda, unsigned int ldb, unsigned int ldc) {
83-
84-
unsigned int m = get_global_id(0);
85-
unsigned int n = get_global_id(1);
86-
float c = 0.0f;
87-
for (unsigned int k = 0; k < K; ++k) {
88-
float a, b;
89-
a = A[k * lda + m];
90-
b = B[k * ldb + n];
91-
c += a * b;
92-
}
93-
C[m * ldc + n] = c;
94-
})";
116+
R"(
117+
#define TS 16
118+
__kernel void sgemm_cl_transA(__global const float *A, __global const float *B,
119+
__global float *C, const int M, const int N,
120+
const int K) {
121+
const int globalRow = get_global_id(1); // M
122+
const int globalCol = get_global_id(0); // N
123+
124+
__local float Asub[TS][TS];
125+
__local float Bsub[TS][TS];
126+
127+
float sum = 0.0f;
128+
129+
const int localRow = get_local_id(1);
130+
const int localCol = get_local_id(0);
131+
const int groupRow = TS * get_group_id(1);
132+
const int groupCol = TS * get_group_id(0);
133+
134+
for (int t = 0; t < (K + TS - 1) / TS; ++t) {
135+
const int tiledRowA = t * TS + localCol;
136+
const int tiledColA = groupRow + localRow;
137+
138+
if (tiledRowA < K && tiledColA < M)
139+
Asub[localRow][localCol] = A[tiledRowA * M + tiledColA];
140+
else
141+
Asub[localRow][localCol] = 0.0f;
142+
143+
const int tiledRowB = t * TS + localRow;
144+
const int tiledColB = groupCol + localCol;
145+
146+
if (tiledRowB < K && tiledColB < N)
147+
Bsub[localRow][localCol] = B[tiledRowB * N + tiledColB];
148+
else
149+
Bsub[localRow][localCol] = 0.0f;
150+
151+
barrier(CLK_LOCAL_MEM_FENCE);
152+
153+
for (int k = 0; k < TS; ++k)
154+
sum += Asub[localRow][k] * Bsub[k][localCol];
155+
156+
barrier(CLK_LOCAL_MEM_FENCE);
157+
}
158+
159+
if (globalRow < M && globalCol < N)
160+
C[globalRow * N + globalCol] = sum;
161+
}
162+
)";
95163
return sgemm_cl_transA_kernel_;
96164
}
97165

98166
const std::string &getSgemmClTransBKernel() {
99167
static const std::string sgemm_cl_transB_kernel_ =
100-
R"(__kernel void sgemm_cl_transB(const __global float *A, const __global float *B,
101-
__global float *C, unsigned int K,
102-
unsigned int lda, unsigned int ldb,
103-
unsigned int ldc) {
104-
105-
unsigned int m = get_global_id(0);
106-
unsigned int n = get_global_id(1);
107-
float c = 0.0f;
108-
for (unsigned int k = 0; k < K; ++k) {
109-
float a, b;
110-
a = A[m * lda + k];
111-
b = B[n * ldb + k];
112-
c += a * b;
113-
}
114-
C[m * ldc + n] = c;
115-
})";
168+
R"(
169+
#define TS 16
170+
__kernel void sgemm_cl_transB(__global const float *A, __global const float *B,
171+
__global float *C, const int M, const int N,
172+
const int K) {
173+
const int globalRow = get_global_id(1);
174+
const int globalCol = get_global_id(0);
175+
176+
__local float Asub[TS][TS];
177+
__local float Bsub[TS][TS];
178+
179+
float sum = 0.0f;
180+
181+
const int localRow = get_local_id(1);
182+
const int localCol = get_local_id(0);
183+
const int groupRow = TS * get_group_id(1);
184+
const int groupCol = TS * get_group_id(0);
185+
186+
for (int t = 0; t < (K + TS - 1) / TS; ++t) {
187+
const int tiledRowA = groupRow + localRow;
188+
const int tiledColA = t * TS + localCol;
189+
190+
if (tiledRowA < M && tiledColA < K)
191+
Asub[localRow][localCol] = A[tiledRowA * K + tiledColA];
192+
else
193+
Asub[localRow][localCol] = 0.0f;
194+
195+
const int tiledRowB = groupCol + localCol;
196+
const int tiledColB = t * TS + localRow;
197+
198+
if (tiledRowB < N && tiledColB < K)
199+
Bsub[localRow][localCol] = B[tiledRowB * K + tiledColB];
200+
else
201+
Bsub[localRow][localCol] = 0.0f;
202+
203+
barrier(CLK_LOCAL_MEM_FENCE);
204+
205+
for (int k = 0; k < TS; ++k)
206+
sum += Asub[localRow][k] * Bsub[k][localCol];
207+
208+
barrier(CLK_LOCAL_MEM_FENCE);
209+
}
210+
211+
if (globalRow < M && globalCol < N)
212+
C[globalRow * N + globalCol] = sum;
213+
}
214+
)";
116215
return sgemm_cl_transB_kernel_;
117216
}
118217

119218
const std::string &getSgemmClTransABKernel() {
120219
static const std::string sgemm_cl_transAB_kernel_ =
121-
R"(__kernel void sgemm_cl_transAB(const __global float *A, const __global float *B,
122-
__global float *C, unsigned int K,
123-
unsigned int lda, unsigned int ldb,
124-
unsigned int ldc) {
125-
126-
unsigned int m = get_global_id(0);
127-
unsigned int n = get_global_id(1);
128-
float c = 0.0f;
129-
for (unsigned int k = 0; k < K; ++k) {
130-
float a, b;
131-
a = A[k * lda + m];
132-
b = B[n * ldb + k];
133-
c += a * b;
134-
}
135-
C[m * ldc + n] = c;
136-
})";
220+
R"(
221+
#define TS 16
222+
__kernel void sgemm_cl_transAB(__global const float *A, __global const float *B,
223+
__global float *C, const int M, const int N,
224+
const int K) {
225+
const int globalRow = get_global_id(1);
226+
const int globalCol = get_global_id(0);
227+
228+
__local float Asub[TS][TS];
229+
__local float Bsub[TS][TS];
230+
231+
float sum = 0.0f;
232+
233+
const int localRow = get_local_id(1);
234+
const int localCol = get_local_id(0);
235+
const int groupRow = TS * get_group_id(1);
236+
const int groupCol = TS * get_group_id(0);
237+
238+
for (int t = 0; t < (K + TS - 1) / TS; ++t) {
239+
const int tiledRowA = t * TS + localCol;
240+
const int tiledColA = groupRow + localRow;
241+
242+
if (tiledRowA < K && tiledColA < M)
243+
Asub[localRow][localCol] = A[tiledRowA * M + tiledColA];
244+
else
245+
Asub[localRow][localCol] = 0.0f;
246+
247+
const int tiledRowB = groupCol + localCol;
248+
const int tiledColB = t * TS + localRow;
249+
250+
if (tiledRowB < N && tiledColB < K)
251+
Bsub[localRow][localCol] = B[tiledRowB * K + tiledColB];
252+
else
253+
Bsub[localRow][localCol] = 0.0f;
254+
255+
barrier(CLK_LOCAL_MEM_FENCE);
256+
257+
for (int k = 0; k < TS; ++k)
258+
sum += Asub[localRow][k] * Bsub[k][localCol];
259+
260+
barrier(CLK_LOCAL_MEM_FENCE);
261+
}
262+
263+
if (globalRow < M && globalCol < N)
264+
C[globalRow * N + globalCol] = sum;
265+
}
266+
)";
137267
return sgemm_cl_transAB_kernel_;
138268
}
139269

nntrainer/tensor/cl_operations/blas_kernels.cpp

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -247,29 +247,26 @@ void sgemm_cl(bool TransA, bool TransB, const float *A, const float *B,
247247
break;
248248
}
249249

250-
result = kernel_sgemm_ptr->SetKernelArguments(3, &K, sizeof(int));
250+
result = kernel_sgemm_ptr->SetKernelArguments(3, &M, sizeof(int));
251251
if (!result) {
252252
break;
253253
}
254254

255-
result = kernel_sgemm_ptr->SetKernelArguments(4, &lda, sizeof(int));
255+
result = kernel_sgemm_ptr->SetKernelArguments(4, &N, sizeof(int));
256256
if (!result) {
257257
break;
258258
}
259259

260-
result = kernel_sgemm_ptr->SetKernelArguments(5, &ldb, sizeof(int));
260+
result = kernel_sgemm_ptr->SetKernelArguments(5, &K, sizeof(int));
261261
if (!result) {
262262
break;
263263
}
264+
const int tiled_size = 16;
265+
const int work_groups_count[3] = {
266+
(int)((N + tiled_size - 1) / tiled_size) * tiled_size,
267+
(int)((M + tiled_size - 1) / tiled_size) * tiled_size, 1}; // test-value
264268

265-
result = kernel_sgemm_ptr->SetKernelArguments(6, &ldc, sizeof(int));
266-
if (!result) {
267-
break;
268-
}
269-
270-
const int work_groups_count[3] = {(int)M, (int)N, 1};
271-
/// @todo: create a group size by device & input
272-
const int work_group_size[3] = {1, 1, 1}; // test-value
269+
const int work_group_size[3] = {tiled_size, tiled_size, 1}; // test-value
273270

274271
result = blas_cc->command_queue_inst_.DispatchCommand(
275272
kernel_sgemm_ptr, work_groups_count, work_group_size);

0 commit comments

Comments
 (0)