Skip to content

Commit 36c281f

Browse files
committed
sync : adapt to CUDA changes (#0)
ggml-ci
1 parent 6937369 commit 36c281f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+9084
-7
lines changed

CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ option(GGML_NO_ACCELERATE "ggml: disable Accelerate framework" OFF)
4444
option(GGML_OPENBLAS "ggml: use OpenBLAS" OFF)
4545
option(GGML_CLBLAST "ggml: use clBLAST" OFF)
4646
option(GGML_HIPBLAS "ggml: use hipBLAS" OFF)
47-
option(GGML_CUBLAS "ggml: use cuBLAS" OFF)
47+
option(GGML_CUDA "ggml: use CUDA" OFF)
48+
option(GGML_CUBLAS "ggml: use CUDA (deprecated)" OFF)
4849
option(GGML_METAL "ggml: use Metal" OFF)
4950

5051
option(GGML_CUDA_FORCE_DMMV "ggml: use dmmv instead of mmvq CUDA kernels" OFF)

examples/common-ggml.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ bool ggml_common_quantize_0(
7070
case GGML_FTYPE_MOSTLY_IQ1_S:
7171
case GGML_FTYPE_MOSTLY_IQ4_NL:
7272
case GGML_FTYPE_MOSTLY_IQ4_XS:
73+
case GGML_FTYPE_MOSTLY_IQ1_M:
7374
{
7475
fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype);
7576
return false;
@@ -193,6 +194,8 @@ bool ggml_common_quantize_0(
193194
case GGML_TYPE_I8:
194195
case GGML_TYPE_I16:
195196
case GGML_TYPE_I32:
197+
case GGML_TYPE_I64:
198+
case GGML_TYPE_F64:
196199
case GGML_TYPE_Q8_1:
197200
case GGML_TYPE_Q8_K:
198201
case GGML_TYPE_IQ2_XXS:
@@ -203,6 +206,7 @@ bool ggml_common_quantize_0(
203206
case GGML_TYPE_IQ1_S:
204207
case GGML_TYPE_IQ4_NL:
205208
case GGML_TYPE_IQ4_XS:
209+
case GGML_TYPE_IQ1_M:
206210
case GGML_TYPE_COUNT:
207211
{
208212
fprintf(stderr, "%s: unsupported quantization type %d (%s)\n", __func__, ttype, ggml_type_name((ggml_type) ttype));

examples/whisper/whisper.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
#include "ggml-metal.h"
99
#endif
1010

11-
#ifdef GGML_USE_CUBLAS
11+
#ifdef GGML_USE_CUDA
1212
#include "ggml-cuda.h"
1313
#endif
1414

@@ -1031,7 +1031,7 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
10311031
ggml_backend_t backend_gpu = NULL;
10321032

10331033
// initialize the backends
1034-
#ifdef GGML_USE_CUBLAS
1034+
#ifdef GGML_USE_CUDA
10351035
if (params.use_gpu && ggml_cublas_loaded()) {
10361036
WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
10371037
backend_gpu = ggml_backend_cuda_init(params.gpu_device);
@@ -3852,7 +3852,7 @@ const char * whisper_print_system_info(void) {
38523852
s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
38533853
s += "SSSE3 = " + std::to_string(ggml_cpu_has_ssse3()) + " | ";
38543854
s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
3855-
s += "CUDA = " + std::to_string(ggml_cpu_has_cublas()) + " | ";
3855+
s += "CUDA = " + std::to_string(ggml_cpu_has_cuda()) + " | ";
38563856
s += "COREML = " + std::to_string(whisper_has_coreml()) + " | ";
38573857
s += "OPENVINO = " + std::to_string(whisper_has_openvino()) ;
38583858

scripts/sync-llama-am.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ if [ -f $SRC_GGML/llama-src.patch ]; then
9494
# ggml-backend-impl.h -> src/ggml-backend-impl.h
9595
# ggml-backend.c -> src/ggml-backend.c
9696
# ggml-common.h -> src/ggml-common.h
97+
# ggml-cuda/* -> src/ggml-cuda/*
9798
# ggml-cuda.cu -> src/ggml-cuda.cu
9899
# ggml-cuda.h -> src/ggml-cuda.h
99100
# ggml-impl.h -> src/ggml-impl.h
@@ -127,6 +128,7 @@ if [ -f $SRC_GGML/llama-src.patch ]; then
127128
-e 's/\/ggml-backend-impl\.h/\/src\/ggml-backend-impl.h/g' \
128129
-e 's/\/ggml-backend\.c/\/src\/ggml-backend.c/g' \
129130
-e 's/\/ggml-common\.h/\/src\/ggml-common.h/g' \
131+
-e 's/\/ggml-cuda\//\/src\/ggml-cuda\//g' \
130132
-e 's/\/ggml-cuda\.cu/\/src\/ggml-cuda.cu/g' \
131133
-e 's/\/ggml-cuda\.h/\/src\/ggml-cuda.h/g' \
132134
-e 's/\/ggml-impl\.h/\/src\/ggml-impl.h/g' \

scripts/sync-llama.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ cp -rpv ../llama.cpp/ggml-alloc.c src/ggml-alloc.c
55
cp -rpv ../llama.cpp/ggml-backend-impl.h src/ggml-backend-impl.h
66
cp -rpv ../llama.cpp/ggml-backend.c src/ggml-backend.c
77
cp -rpv ../llama.cpp/ggml-common.h src/ggml-common.h
8+
cp -rpv ../llama.cpp/ggml-cuda/* src/ggml-cuda/
89
cp -rpv ../llama.cpp/ggml-cuda.cu src/ggml-cuda.cu
910
cp -rpv ../llama.cpp/ggml-cuda.h src/ggml-cuda.h
1011
cp -rpv ../llama.cpp/ggml-impl.h src/ggml-impl.h

src/CMakeLists.txt

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,11 @@ if (GGML_CLBLAST)
206206
endif()
207207

208208
if (GGML_CUBLAS)
209+
message(WARNING "GGML_CUBLAS is deprecated and will be removed in the future.\nUse GGML_CUDA instead")
210+
set(GGML_CUDA ON)
211+
endif()
212+
213+
if (GGML_CUDA)
209214
cmake_minimum_required(VERSION 3.17)
210215

211216
find_package(CUDAToolkit)
@@ -214,9 +219,11 @@ if (GGML_CUBLAS)
214219

215220
enable_language(CUDA)
216221

217-
set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
222+
file(GLOB GGML_CUDA_SOURCES "ggml-cuda/*.cu")
223+
list(APPEND GGML_CUDA_SOURCES ggml-cuda.h)
224+
list(APPEND GGML_CUDA_SOURCES ggml-cuda.cu)
218225

219-
set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_USE_CUBLAS)
226+
set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_USE_CUDA)
220227

221228
if (GGML_CUDA_FORCE_DMMV)
222229
add_compile_definitions(GGML_CUDA_FORCE_DMMV)
@@ -266,9 +273,10 @@ if (GGML_HIPBLAS)
266273
if (${hipblas_FOUND} AND ${hip_FOUND})
267274
message(STATUS "HIP and hipBLAS found")
268275

269-
add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS)
276+
add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUDA)
270277

271278
add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h)
279+
272280
if (BUILD_SHARED_LIBS)
273281
set_target_properties(ggml-rocm PROPERTIES POSITION_INDEPENDENT_CODE ON)
274282
endif()

src/ggml-cuda/acc.cu

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#include "acc.cuh"
2+
3+
// Accumulate y into a strided 3D region of x and write the result to dst:
// dst[i] = x[i] + y[...] where element i falls inside the region, else dst[i] = x[i].
//
// Launch layout: 1D grid of 1D blocks, one thread per element of dst; threads
// whose flat index is >= ne return without touching memory.
//
//   x, dst           : read / written at flat indices 0..ne-1
//   y                : read at ox + oy*ne10 + oz*ne10*ne11
//   ne10, ne11, ne12 : bounds applied to ox, oy, oz respectively
//   nb1, nb2         : strides of the region inside dst for oy and oz, in elements
//   offset           : flat index in dst at which the region starts, in elements
static __global__ void acc_f32(const float * x, const float * y, float * dst, const int ne,
    const int ne10, const int ne11, const int ne12,
    const int nb1, const int nb2, int offset) {
    const int i = blockDim.x * blockIdx.x + threadIdx.x;
    if (i >= ne) {
        return;
    }

    // position of this element relative to the start of the region
    const int src1_idx = i - offset;

    // peel the strides off from the outermost coordinate inwards
    const int oz  = src1_idx / nb2;
    const int rem = src1_idx - oz * nb2;
    const int oy  = rem / nb1;
    // The innermost coordinate is whatever is left after removing oz*nb2 and oy*nb1.
    // Using `src1_idx % nb1` instead is only equivalent when nb2 is a multiple of nb1.
    const int ox  = rem - oy * nb1;

    float val = x[i];
    // src1_idx < 0 means the element lies before the region; the short-circuit
    // keeps the (then negative) coordinates from being used as indices.
    if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
        val += y[ox + oy * ne10 + oz * ne10 * ne11];
    }
    dst[i] = val;
}
20+
21+
// Host-side launcher for acc_f32: enqueues the kernel on `stream` with one
// thread per dst element, in 1D blocks of CUDA_ACC_BLOCK_SIZE threads.
// All remaining arguments are forwarded to the kernel unchanged.
static void acc_f32_cuda(const float * x, const float * y, float * dst, const int n_elements,
    const int ne10, const int ne11, const int ne12,
    const int nb1, const int nb2, const int offset, cudaStream_t stream) {
    // ceil-div so that a partially filled last block is still launched
    const int block_size = CUDA_ACC_BLOCK_SIZE;
    const int grid_size  = (n_elements + block_size - 1) / block_size;

    acc_f32<<<grid_size, block_size, 0, stream>>>(
        x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset);
}
27+
28+
// CUDA implementation of the "acc" op: dst = src0 with src1 accumulated into a
// strided region of it. Only F32 tensors with ne[3] == 1 are accepted.
//
// dst->op_params layout as read here (all values stored in bytes):
//   [0] nb1, [1] nb2, [2] nb3 (unused), [3] offset
void ggml_cuda_op_acc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    const float * src0_d = (const float *) src0->data;
    const float * src1_d = (const float *) src1->data;
    float       * dst_d  = (float *) dst->data;

    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported

    // The kernel indexes in float elements, so convert the byte values
    // from op_params by dividing by the 4-byte size of a float32.
    const int f32_bytes = 4;

    const int nb1    = dst->op_params[0] / f32_bytes;
    const int nb2    = dst->op_params[1] / f32_bytes;
    // op_params[2] (nb3) is not needed for 3D tensors
    const int offset = dst->op_params[3] / f32_bytes; // now in elements, not bytes

    acc_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst),
                 src1->ne[0], src1->ne[1], src1->ne[2],
                 nb1, nb2, offset, stream);
}

src/ggml-cuda/acc.cuh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#include "common.cuh"
2+
3+
#define CUDA_ACC_BLOCK_SIZE 256
4+
5+
void ggml_cuda_op_acc(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

src/ggml-cuda/alibi.cu

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#include "alibi.cuh"
2+
3+
// Adds a column-proportional bias to every element: dst[i] = col * m_k + x[i].
//
// Launch layout: threads are indexed by column along x and by row along y;
// threads with col >= ncols return immediately. There is no bound check on
// the row, so the launch must not create more rows than the buffers hold.
//
// The multiplier m_k depends on k = row / k_rows:
//   k <  n_heads_log2_floor : m_k = m0^(k + 1)
//   k >= n_heads_log2_floor : m_k = m1^(2*(k - n_heads_log2_floor) + 1)
static __global__ void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows,
    const int n_heads_log2_floor, const float m0, const float m1) {
    const int col = blockDim.x*blockIdx.x + threadIdx.x;
    if (col >= ncols) {
        return;
    }

    const int row = blockDim.y*blockIdx.y + threadIdx.y;

    // every k_rows consecutive rows share the same multiplier
    const int k = row/k_rows;

    const float m_k = k < n_heads_log2_floor
        ? powf(m0, k + 1)
        : powf(m1, 2 * (k - n_heads_log2_floor) + 1);

    // flat index of (row, col) in the row-major buffers
    const int i = row*ncols + col;
    dst[i] = col * m_k + x[i];
}
25+
26+
// Host-side launcher for alibi_f32 on `stream`.
// Grid: ceil(ncols / CUDA_ALIBI_BLOCK_SIZE) blocks along x, one block per row along y.
// Block: CUDA_ALIBI_BLOCK_SIZE x 1 x 1 threads.
static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
    const int k_rows, const int n_heads_log2_floor, const float m0,
    const float m1, cudaStream_t stream) {
    // round up so the trailing columns are covered by a (partially idle) block
    const int blocks_per_row = (ncols + CUDA_ALIBI_BLOCK_SIZE - 1) / (CUDA_ALIBI_BLOCK_SIZE);

    const dim3 grid(blocks_per_row, nrows, 1);
    const dim3 block(CUDA_ALIBI_BLOCK_SIZE, 1, 1);

    alibi_f32<<<grid, block, 0, stream>>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1);
}
34+
35+
// CUDA implementation of the ALiBi op for F32 tensors.
//
// dst->op_params layout as read here (32-bit slots):
//   [0] n_past   (not used)
//   [1] n_head   (must equal src0->ne[2])
//   [2] max_bias (float stored in the slot, read via memcpy)
void ggml_cuda_op_alibi(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    const float * src0_d = (const float *) src0->data;
    float       * dst_d  = (float *) dst->data;

    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    const int64_t ne00  = src0->ne[0];
    const int64_t ne01  = src0->ne[1];
    const int64_t ne02  = src0->ne[2];
    const int64_t nrows = ggml_nrows(src0);

    int32_t * params = (int32_t *) dst->op_params;

    const int n_head = params[1];

    // memcpy instead of a pointer cast to reinterpret the slot as a float
    float max_bias;
    memcpy(&max_bias, params + 2, sizeof(float));

    GGML_ASSERT(n_head == ne02);

    // largest power of two that does not exceed n_head
    const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));

    // bases of the two multiplier series applied by the kernel
    const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);

    alibi_f32_cuda(src0_d, dst_d, ne00, nrows, ne01, n_heads_log2_floor, m0, m1, stream);
}

src/ggml-cuda/alibi.cuh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#include "common.cuh"
2+
3+
#define CUDA_ALIBI_BLOCK_SIZE 32
4+
5+
void ggml_cuda_op_alibi(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

src/ggml-cuda/arange.cu

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#include "arange.cuh"
2+
3+
// Fills dst with an arithmetic sequence: dst[i] = start + step * i for i in [0, ne0).
// Launch layout: 1D grid of 1D blocks, one thread per output element;
// threads past ne0 do nothing.
static __global__ void arange_f32(float * dst, const int ne0, const float start, const float step) {
    const int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx < ne0) {
        dst[idx] = start + step * idx;
    }
}
11+
12+
// Host-side launcher for arange_f32: one thread per element on `stream`,
// in 1D blocks of CUDA_ARANGE_BLOCK_SIZE threads.
static void arange_f32_cuda(float * dst, const int ne0, const float start, const float step, cudaStream_t stream) {
    // ceil-div so the tail of the sequence gets its own block
    const int block_size = CUDA_ARANGE_BLOCK_SIZE;
    const int grid_size  = (ne0 + block_size - 1) / block_size;

    arange_f32<<<grid_size, block_size, 0, stream>>>(dst, ne0, start, step);
}
16+
17+
// CUDA implementation of the arange op: fills the F32 tensor dst with
// start, start + step, start + 2*step, ...
//
// dst->op_params holds three floats: [0] start, [1] stop, [2] step.
// The element count of dst must equal ceil((stop - start) / step).
void ggml_cuda_op_arange(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    float * dst_d = (float *) dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    // pull the three consecutive float parameters out in one copy
    float range[3];
    memcpy(range, (float *) dst->op_params, sizeof(range));

    const float start = range[0];
    const float stop  = range[1];
    const float step  = range[2];

    // sanity check: the tensor was sized for exactly this sequence
    int64_t steps = (int64_t)ceil((stop - start) / step);
    GGML_ASSERT(ggml_nelements(dst) == steps);

    arange_f32_cuda(dst_d, dst->ne[0], start, step, stream);
}

src/ggml-cuda/arange.cuh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#include "common.cuh"
2+
3+
#define CUDA_ARANGE_BLOCK_SIZE 256
4+
5+
void ggml_cuda_op_arange(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

src/ggml-cuda/argsort.cu

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
#include "argsort.cuh"
2+
3+
// Device-side helper that exchanges the values of a and b through a temporary copy.
template<typename T>
static inline __device__ void ggml_cuda_swap(T & a, T & b) {
    T saved = a;
    a = b;
    b = saved;
}
9+
10+
// Per-row argsort using an in-place bitonic sorting network.
//
// For every row, dst receives a permutation of 0..ncols-1 such that the values
// x_row[dst_row[0..ncols-1]] are ordered according to `order`.
//
// Launch layout: blockIdx.y selects the row and threadIdx.x the column, so each
// row is handled by a single block that needs at least ncols threads along x.
// Preconditions: ncols is a power of two (the network pairs col with col ^ j).
//
// Threads with col >= ncols are masked out with a predicate rather than an
// early return, so every thread of the block reaches every __syncthreads().
template<ggml_sort_order order>
static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols) {
    const int col = threadIdx.x;
    const int row = blockIdx.y;

    const bool active = col < ncols;

    const float * x_row = x + row * ncols;
    int * dst_row = dst + row * ncols;

    // start from the identity permutation
    if (active) {
        dst_row[col] = col;
    }
    __syncthreads();

    // k: size of the bitonic sequences being merged; j: compare distance
    for (int k = 2; k <= ncols; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            const int ixj = col ^ j;

            // only the lower index of each (col, ixj) pair does the compare/swap,
            // so no two threads touch the same dst_row entries within a step
            if (active && ixj > col) {
                const float v_col = x_row[dst_row[col]];
                const float v_ixj = x_row[dst_row[ixj]];

                // bit k of col selects the direction of this sub-sequence
                const bool forward = (col & k) == 0;

                bool out_of_order;
                if (order == GGML_SORT_ORDER_ASC) {
                    out_of_order = forward ? v_col > v_ixj : v_col < v_ixj;
                } else {
                    out_of_order = forward ? v_col < v_ixj : v_col > v_ixj;
                }

                if (out_of_order) {
                    ggml_cuda_swap(dst_row[col], dst_row[ixj]);
                }
            }

            // all swaps of this step must be visible before the next one starts
            __syncthreads();
        }
    }
}
45+
46+
// Host-side launcher for k_argsort_f32_i32 on `stream`.
// Grid: one block per row (along y). Block: ncols threads along x.
// Aborts via GGML_ASSERT if ncols is not a power of two or `order` is unknown.
// NOTE(review): ncols is used directly as the block size, so it presumably must
// not exceed the device's maximum threads per block - confirm against callers.
static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
    // bitonic sort requires ncols to be power of 2
    GGML_ASSERT((ncols & (ncols - 1)) == 0);

    const dim3 grid(1, nrows, 1);
    const dim3 block(ncols, 1, 1);

    // the sort direction is a template parameter, so dispatch on it here
    switch (order) {
        case GGML_SORT_ORDER_ASC:
            k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<grid, block, 0, stream>>>(x, dst, ncols);
            break;
        case GGML_SORT_ORDER_DESC:
            k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<grid, block, 0, stream>>>(x, dst, ncols);
            break;
        default:
            GGML_ASSERT(false);
            break;
    }
}
60+
61+
// CUDA implementation of the argsort op: for each row of the contiguous F32
// tensor src0, writes the sorting permutation of that row into the I32 tensor dst.
// The sort direction is taken from dst->op_params[0].
void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    const float * src0_d = (const float *) src0->data;
    // dst holds 32-bit indices, so address its buffer as int
    int * dst_d = (int *) dst->data;

    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_I32);
    GGML_ASSERT(ggml_is_contiguous(src0));

    const int64_t ncols = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);

    const enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];

    argsort_f32_i32_cuda(src0_d, dst_d, ncols, nrows, order, stream);
}

src/ggml-cuda/argsort.cuh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
#include "common.cuh"
2+
3+
void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

0 commit comments

Comments
 (0)