@@ -59,81 +59,211 @@ const std::string &getDotClKernel() {
59
59
60
60
const std::string &getSgemmClNoTransKernel () {
61
61
static const std::string sgemm_cl_noTrans_kernel_ =
62
- R"( __kernel void sgemm_cl_noTrans(const __global float* A, const __global float* B,
63
- __global float* C, unsigned int K, unsigned int lda, unsigned int ldb, unsigned int ldc) {
64
-
65
- unsigned int m = get_global_id(0);
66
- unsigned int n = get_global_id(1);
67
- float c = 0.0f;
68
- for (unsigned int k = 0; k < K; ++k) {
69
- float a, b;
70
- a = A[m * lda + k];
71
- b = B[k * ldb + n];
72
- c += a * b;
73
- }
74
- C[m * ldc + n] = c;
75
- })" ;
62
+ R"(
63
+ #define TS 16
64
+ __kernel void sgemm_cl_noTrans(__global const float *A, __global const float *B,
65
+ __global float *C, const int M, const int N,
66
+ const int K) {
67
+ const int globalRow = get_global_id(1); // M dimension
68
+ const int globalCol = get_global_id(0); // N dimension
69
+
70
+ __local float Asub[TS][TS];
71
+ __local float Bsub[TS][TS];
72
+
73
+ float sum = 0.0f;
74
+
75
+ const int localRow = get_local_id(1);
76
+ const int localCol = get_local_id(0);
77
+ const int groupRow = TS * get_group_id(1);
78
+ const int groupCol = TS * get_group_id(0);
79
+
80
+ for (int t = 0; t < (K + TS - 1) / TS; ++t) {
81
+ const int tiledRowA = groupRow + localRow;
82
+ const int tiledColA = t * TS + localCol;
83
+
84
+ const int tiledRowB = t * TS + localRow;
85
+ const int tiledColB = groupCol + localCol;
86
+
87
+ // Load A
88
+ if (tiledRowA < M && tiledColA < K)
89
+ Asub[localRow][localCol] = A[tiledRowA * K + tiledColA];
90
+ else
91
+ Asub[localRow][localCol] = 0.0f;
92
+
93
+ // Load B
94
+ if (tiledRowB < K && tiledColB < N)
95
+ Bsub[localRow][localCol] = B[tiledRowB * N + tiledColB];
96
+ else
97
+ Bsub[localRow][localCol] = 0.0f;
98
+
99
+ barrier(CLK_LOCAL_MEM_FENCE);
100
+
101
+ for (int k = 0; k < TS; ++k)
102
+ sum += Asub[localRow][k] * Bsub[k][localCol];
103
+
104
+ barrier(CLK_LOCAL_MEM_FENCE);
105
+ }
106
+
107
+ if (globalRow < M && globalCol < N)
108
+ C[globalRow * N + globalCol] = sum;
109
+ }
110
+ )" ;
76
111
return sgemm_cl_noTrans_kernel_;
77
112
}
78
113
79
114
const std::string &getSgemmClTransAKernel () {
80
115
static const std::string sgemm_cl_transA_kernel_ =
81
- R"( __kernel void sgemm_cl_transA(const __global float* A, const __global float* B,
82
- __global float* C, unsigned int K, unsigned int lda, unsigned int ldb, unsigned int ldc) {
83
-
84
- unsigned int m = get_global_id(0);
85
- unsigned int n = get_global_id(1);
86
- float c = 0.0f;
87
- for (unsigned int k = 0; k < K; ++k) {
88
- float a, b;
89
- a = A[k * lda + m];
90
- b = B[k * ldb + n];
91
- c += a * b;
92
- }
93
- C[m * ldc + n] = c;
94
- })" ;
116
+ R"(
117
+ #define TS 16
118
+ __kernel void sgemm_cl_transA(__global const float *A, __global const float *B,
119
+ __global float *C, const int M, const int N,
120
+ const int K) {
121
+ const int globalRow = get_global_id(1); // M
122
+ const int globalCol = get_global_id(0); // N
123
+
124
+ __local float Asub[TS][TS];
125
+ __local float Bsub[TS][TS];
126
+
127
+ float sum = 0.0f;
128
+
129
+ const int localRow = get_local_id(1);
130
+ const int localCol = get_local_id(0);
131
+ const int groupRow = TS * get_group_id(1);
132
+ const int groupCol = TS * get_group_id(0);
133
+
134
+ for (int t = 0; t < (K + TS - 1) / TS; ++t) {
135
+ const int tiledRowA = t * TS + localCol;
136
+ const int tiledColA = groupRow + localRow;
137
+
138
+ if (tiledRowA < K && tiledColA < M)
139
+ Asub[localRow][localCol] = A[tiledRowA * M + tiledColA];
140
+ else
141
+ Asub[localRow][localCol] = 0.0f;
142
+
143
+ const int tiledRowB = t * TS + localRow;
144
+ const int tiledColB = groupCol + localCol;
145
+
146
+ if (tiledRowB < K && tiledColB < N)
147
+ Bsub[localRow][localCol] = B[tiledRowB * N + tiledColB];
148
+ else
149
+ Bsub[localRow][localCol] = 0.0f;
150
+
151
+ barrier(CLK_LOCAL_MEM_FENCE);
152
+
153
+ for (int k = 0; k < TS; ++k)
154
+ sum += Asub[localRow][k] * Bsub[k][localCol];
155
+
156
+ barrier(CLK_LOCAL_MEM_FENCE);
157
+ }
158
+
159
+ if (globalRow < M && globalCol < N)
160
+ C[globalRow * N + globalCol] = sum;
161
+ }
162
+ )" ;
95
163
return sgemm_cl_transA_kernel_;
96
164
}
97
165
98
166
const std::string &getSgemmClTransBKernel () {
99
167
static const std::string sgemm_cl_transB_kernel_ =
100
- R"( __kernel void sgemm_cl_transB(const __global float *A, const __global float *B,
101
- __global float *C, unsigned int K,
102
- unsigned int lda, unsigned int ldb,
103
- unsigned int ldc) {
104
-
105
- unsigned int m = get_global_id(0);
106
- unsigned int n = get_global_id(1);
107
- float c = 0.0f;
108
- for (unsigned int k = 0; k < K; ++k) {
109
- float a, b;
110
- a = A[m * lda + k];
111
- b = B[n * ldb + k];
112
- c += a * b;
113
- }
114
- C[m * ldc + n] = c;
115
- })" ;
168
+ R"(
169
+ #define TS 16
170
+ __kernel void sgemm_cl_transB(__global const float *A, __global const float *B,
171
+ __global float *C, const int M, const int N,
172
+ const int K) {
173
+ const int globalRow = get_global_id(1);
174
+ const int globalCol = get_global_id(0);
175
+
176
+ __local float Asub[TS][TS];
177
+ __local float Bsub[TS][TS];
178
+
179
+ float sum = 0.0f;
180
+
181
+ const int localRow = get_local_id(1);
182
+ const int localCol = get_local_id(0);
183
+ const int groupRow = TS * get_group_id(1);
184
+ const int groupCol = TS * get_group_id(0);
185
+
186
+ for (int t = 0; t < (K + TS - 1) / TS; ++t) {
187
+ const int tiledRowA = groupRow + localRow;
188
+ const int tiledColA = t * TS + localCol;
189
+
190
+ if (tiledRowA < M && tiledColA < K)
191
+ Asub[localRow][localCol] = A[tiledRowA * K + tiledColA];
192
+ else
193
+ Asub[localRow][localCol] = 0.0f;
194
+
195
+ const int tiledRowB = groupCol + localCol;
196
+ const int tiledColB = t * TS + localRow;
197
+
198
+ if (tiledRowB < N && tiledColB < K)
199
+ Bsub[localRow][localCol] = B[tiledRowB * K + tiledColB];
200
+ else
201
+ Bsub[localRow][localCol] = 0.0f;
202
+
203
+ barrier(CLK_LOCAL_MEM_FENCE);
204
+
205
+ for (int k = 0; k < TS; ++k)
206
+ sum += Asub[localRow][k] * Bsub[k][localCol];
207
+
208
+ barrier(CLK_LOCAL_MEM_FENCE);
209
+ }
210
+
211
+ if (globalRow < M && globalCol < N)
212
+ C[globalRow * N + globalCol] = sum;
213
+ }
214
+ )" ;
116
215
return sgemm_cl_transB_kernel_;
117
216
}
118
217
119
218
const std::string &getSgemmClTransABKernel () {
120
219
static const std::string sgemm_cl_transAB_kernel_ =
121
- R"( __kernel void sgemm_cl_transAB(const __global float *A, const __global float *B,
122
- __global float *C, unsigned int K,
123
- unsigned int lda, unsigned int ldb,
124
- unsigned int ldc) {
125
-
126
- unsigned int m = get_global_id(0);
127
- unsigned int n = get_global_id(1);
128
- float c = 0.0f;
129
- for (unsigned int k = 0; k < K; ++k) {
130
- float a, b;
131
- a = A[k * lda + m];
132
- b = B[n * ldb + k];
133
- c += a * b;
134
- }
135
- C[m * ldc + n] = c;
136
- })" ;
220
+ R"(
221
+ #define TS 16
222
+ __kernel void sgemm_cl_transAB(__global const float *A, __global const float *B,
223
+ __global float *C, const int M, const int N,
224
+ const int K) {
225
+ const int globalRow = get_global_id(1);
226
+ const int globalCol = get_global_id(0);
227
+
228
+ __local float Asub[TS][TS];
229
+ __local float Bsub[TS][TS];
230
+
231
+ float sum = 0.0f;
232
+
233
+ const int localRow = get_local_id(1);
234
+ const int localCol = get_local_id(0);
235
+ const int groupRow = TS * get_group_id(1);
236
+ const int groupCol = TS * get_group_id(0);
237
+
238
+ for (int t = 0; t < (K + TS - 1) / TS; ++t) {
239
+ const int tiledRowA = t * TS + localCol;
240
+ const int tiledColA = groupRow + localRow;
241
+
242
+ if (tiledRowA < K && tiledColA < M)
243
+ Asub[localRow][localCol] = A[tiledRowA * M + tiledColA];
244
+ else
245
+ Asub[localRow][localCol] = 0.0f;
246
+
247
+ const int tiledRowB = groupCol + localCol;
248
+ const int tiledColB = t * TS + localRow;
249
+
250
+ if (tiledRowB < N && tiledColB < K)
251
+ Bsub[localRow][localCol] = B[tiledRowB * K + tiledColB];
252
+ else
253
+ Bsub[localRow][localCol] = 0.0f;
254
+
255
+ barrier(CLK_LOCAL_MEM_FENCE);
256
+
257
+ for (int k = 0; k < TS; ++k)
258
+ sum += Asub[localRow][k] * Bsub[k][localCol];
259
+
260
+ barrier(CLK_LOCAL_MEM_FENCE);
261
+ }
262
+
263
+ if (globalRow < M && globalCol < N)
264
+ C[globalRow * N + globalCol] = sum;
265
+ }
266
+ )" ;
137
267
return sgemm_cl_transAB_kernel_;
138
268
}
139
269
0 commit comments