sync : adapt to CUDA changes (#0)
ggml-ci
ggerganov committed Mar 27, 2024
1 parent 6937369 commit 36c281f
Showing 61 changed files with 9,084 additions and 7 deletions.
3 changes: 2 additions & 1 deletion CMakeLists.txt
@@ -44,7 +44,8 @@ option(GGML_NO_ACCELERATE "ggml: disable Accelerate framework" OFF)
option(GGML_OPENBLAS "ggml: use OpenBLAS" OFF)
option(GGML_CLBLAST "ggml: use clBLAST" OFF)
option(GGML_HIPBLAS "ggml: use hipBLAS" OFF)
option(GGML_CUBLAS "ggml: use cuBLAS" OFF)
option(GGML_CUDA "ggml: use CUDA" OFF)
option(GGML_CUBLAS "ggml: use CUDA (deprecated)" OFF)
option(GGML_METAL "ggml: use Metal" OFF)

option(GGML_CUDA_FORCE_DMMV "ggml: use dmmv instead of mmvq CUDA kernels" OFF)
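Note: with this change, builds enable the CUDA backend with -DGGML_CUDA=ON; the old -DGGML_CUBLAS=ON spelling is kept only as a deprecated alias (see the src/CMakeLists.txt hunk further down, which maps it onto GGML_CUDA with a warning).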
4 changes: 4 additions & 0 deletions examples/common-ggml.cpp
@@ -70,6 +70,7 @@ bool ggml_common_quantize_0(
case GGML_FTYPE_MOSTLY_IQ1_S:
case GGML_FTYPE_MOSTLY_IQ4_NL:
case GGML_FTYPE_MOSTLY_IQ4_XS:
case GGML_FTYPE_MOSTLY_IQ1_M:
{
fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype);
return false;
@@ -193,6 +194,8 @@ bool ggml_common_quantize_0(
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
case GGML_TYPE_I64:
case GGML_TYPE_F64:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q8_K:
case GGML_TYPE_IQ2_XXS:
@@ -203,6 +206,7 @@
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_COUNT:
{
fprintf(stderr, "%s: unsupported quantization type %d (%s)\n", __func__, ttype, ggml_type_name((ggml_type) ttype));
6 changes: 3 additions & 3 deletions examples/whisper/whisper.cpp
@@ -8,7 +8,7 @@
#include "ggml-metal.h"
#endif

#ifdef GGML_USE_CUBLAS
#ifdef GGML_USE_CUDA
#include "ggml-cuda.h"
#endif

@@ -1031,7 +1031,7 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
ggml_backend_t backend_gpu = NULL;

// initialize the backends
#ifdef GGML_USE_CUBLAS
#ifdef GGML_USE_CUDA
if (params.use_gpu && ggml_cublas_loaded()) {
WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
backend_gpu = ggml_backend_cuda_init(params.gpu_device);
@@ -3852,7 +3852,7 @@ const char * whisper_print_system_info(void) {
s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
s += "SSSE3 = " + std::to_string(ggml_cpu_has_ssse3()) + " | ";
s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
s += "CUDA = " + std::to_string(ggml_cpu_has_cublas()) + " | ";
s += "CUDA = " + std::to_string(ggml_cpu_has_cuda()) + " | ";
s += "COREML = " + std::to_string(whisper_has_coreml()) + " | ";
s += "OPENVINO = " + std::to_string(whisper_has_openvino()) ;

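Since the compile-time macro is renamed from GGML_USE_CUBLAS to GGML_USE_CUDA, downstream code that must compile against both pre- and post-rename ggml can guard on either macro. A minimal sketch (the dual check is an illustration, not part of this commit):

#if defined(GGML_USE_CUDA) || defined(GGML_USE_CUBLAS) // new macro first, old macro for pre-rename ggml
#include "ggml-cuda.h"
#endif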
2 changes: 2 additions & 0 deletions scripts/sync-llama-am.sh
@@ -94,6 +94,7 @@ if [ -f $SRC_GGML/llama-src.patch ]; then
# ggml-backend-impl.h -> src/ggml-backend-impl.h
# ggml-backend.c -> src/ggml-backend.c
# ggml-common.h -> src/ggml-common.h
# ggml-cuda/* -> src/ggml-cuda/*
# ggml-cuda.cu -> src/ggml-cuda.cu
# ggml-cuda.h -> src/ggml-cuda.h
# ggml-impl.h -> src/ggml-impl.h
@@ -127,6 +128,7 @@ if [ -f $SRC_GGML/llama-src.patch ]; then
-e 's/\/ggml-backend-impl\.h/\/src\/ggml-backend-impl.h/g' \
-e 's/\/ggml-backend\.c/\/src\/ggml-backend.c/g' \
-e 's/\/ggml-common\.h/\/src\/ggml-common.h/g' \
-e 's/\/ggml-cuda\//\/src\/ggml-cuda\//g' \
-e 's/\/ggml-cuda\.cu/\/src\/ggml-cuda.cu/g' \
-e 's/\/ggml-cuda\.h/\/src\/ggml-cuda.h/g' \
-e 's/\/ggml-impl\.h/\/src\/ggml-impl.h/g' \
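For example, the new sed rule rewrites a patch path such as a/ggml-cuda/acc.cu to a/src/ggml-cuda/acc.cu, while the existing rules after it still handle the top-level ggml-cuda.cu and ggml-cuda.h files.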
1 change: 1 addition & 0 deletions scripts/sync-llama.sh
@@ -5,6 +5,7 @@ cp -rpv ../llama.cpp/ggml-alloc.c src/ggml-alloc.c
cp -rpv ../llama.cpp/ggml-backend-impl.h src/ggml-backend-impl.h
cp -rpv ../llama.cpp/ggml-backend.c src/ggml-backend.c
cp -rpv ../llama.cpp/ggml-common.h src/ggml-common.h
cp -rpv ../llama.cpp/ggml-cuda/* src/ggml-cuda/
cp -rpv ../llama.cpp/ggml-cuda.cu src/ggml-cuda.cu
cp -rpv ../llama.cpp/ggml-cuda.h src/ggml-cuda.h
cp -rpv ../llama.cpp/ggml-impl.h src/ggml-impl.h
14 changes: 11 additions & 3 deletions src/CMakeLists.txt
@@ -206,6 +206,11 @@ if (GGML_CLBLAST)
endif()

if (GGML_CUBLAS)
message(WARNING "GGML_CUBLAS is deprecated and will be removed in the future.\nUse GGML_CUDA instead")
set(GGML_CUDA ON)
endif()

if (GGML_CUDA)
cmake_minimum_required(VERSION 3.17)

find_package(CUDAToolkit)
@@ -214,9 +219,11 @@

enable_language(CUDA)

set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
file(GLOB GGML_CUDA_SOURCES "ggml-cuda/*.cu")
list(APPEND GGML_CUDA_SOURCES ggml-cuda.h)
list(APPEND GGML_CUDA_SOURCES ggml-cuda.cu)

set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_USE_CUBLAS)
set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_USE_CUDA)

if (GGML_CUDA_FORCE_DMMV)
add_compile_definitions(GGML_CUDA_FORCE_DMMV)
@@ -266,9 +273,10 @@ if (GGML_HIPBLAS)
if (${hipblas_FOUND} AND ${hip_FOUND})
message(STATUS "HIP and hipBLAS found")

add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS)
add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUDA)

add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h)

if (BUILD_SHARED_LIBS)
set_target_properties(ggml-rocm PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif()
47 changes: 47 additions & 0 deletions src/ggml-cuda/acc.cu
@@ -0,0 +1,47 @@
#include "acc.cuh"

static __global__ void acc_f32(const float * x, const float * y, float * dst, const int ne,
const int ne10, const int ne11, const int ne12,
const int nb1, const int nb2, int offset) {
const int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i >= ne) {
return;
}
int src1_idx = i - offset;
int oz = src1_idx / nb2;
int oy = (src1_idx - (oz * nb2)) / nb1;
int ox = src1_idx % nb1;
if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
} else {
dst[i] = x[i];
}
}

static void acc_f32_cuda(const float * x, const float * y, float * dst, const int n_elements,
const int ne10, const int ne11, const int ne12,
const int nb1, const int nb2, const int offset, cudaStream_t stream) {
int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE;
acc_f32<<<num_blocks, CUDA_ACC_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset);
}

void ggml_cuda_op_acc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(dst->ne[3] == 1); // only 3D tensors are supported

int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
// int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
int offset = dst->op_params[3] / 4; // byte offset from op_params, converted to float32 elements

acc_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, stream);
}
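The kernel adds src1 into a sub-view of src0 that starts offset floats into the tensor, with row and plane strides nb1 and nb2 (the op_params byte values divided by 4 above); elements outside the view are copied through unchanged. A CPU transliteration of the same indexing, useful as a reference when validating the kernel (illustrative only, not part of the commit):

// CPU reference mirroring acc_f32: same flattened indexing, same bounds test.
static void acc_f32_ref(const float * x, const float * y, float * dst, const int ne,
                        const int ne10, const int ne11, const int ne12,
                        const int nb1, const int nb2, const int offset) {
    for (int i = 0; i < ne; ++i) {
        const int src1_idx = i - offset;
        const int oz = src1_idx / nb2;
        const int oy = (src1_idx - oz*nb2) / nb1;
        const int ox = src1_idx % nb1;
        if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
            dst[i] = x[i] + y[ox + oy*ne10 + oz*ne10*ne11]; // inside the view: accumulate
        } else {
            dst[i] = x[i];                                  // outside: pass src0 through
        }
    }
}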
5 changes: 5 additions & 0 deletions src/ggml-cuda/acc.cuh
@@ -0,0 +1,5 @@
#include "common.cuh"

#define CUDA_ACC_BLOCK_SIZE 256

void ggml_cuda_op_acc(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
63 changes: 63 additions & 0 deletions src/ggml-cuda/alibi.cu
@@ -0,0 +1,63 @@
#include "alibi.cuh"

static __global__ void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows,
const int n_heads_log2_floor, const float m0, const float m1) {
const int col = blockDim.x*blockIdx.x + threadIdx.x;

if (col >= ncols) {
return;
}

const int row = blockDim.y*blockIdx.y + threadIdx.y;
const int i = row*ncols + col;

const int k = row/k_rows;

float m_k;
if (k < n_heads_log2_floor) {
m_k = powf(m0, k + 1);
} else {
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
}

dst[i] = col * m_k + x[i];
}

static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
const int k_rows, const int n_heads_log2_floor, const float m0,
const float m1, cudaStream_t stream) {
const dim3 block_dims(CUDA_ALIBI_BLOCK_SIZE, 1, 1);
const int num_blocks_x = (ncols + CUDA_ALIBI_BLOCK_SIZE - 1) / (CUDA_ALIBI_BLOCK_SIZE);
const dim3 block_nums(num_blocks_x, nrows, 1);
alibi_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1);
}

void ggml_cuda_op_alibi(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t nrows = ggml_nrows(src0);

//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_head = ((int32_t *) dst->op_params)[1];
float max_bias;
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));

//GGML_ASSERT(ne01 + n_past == ne00);
GGML_ASSERT(n_head == ne02);

const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));

const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);

alibi_f32_cuda(src0_d, dst_d, ne00, nrows, ne01, n_heads_log2_floor, m0, m1, stream);
}
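The slope selection follows the standard ALiBi scheme: with n = 2^floor(log2(n_head)), m0 = 2^(-max_bias/n), and m1 = 2^(-(max_bias/2)/n), head k gets slope m0^(k+1) when k < n, and otherwise an odd power of m1 that interleaves between the first set. A host-side helper expressing the same branch (a sketch for clarity, not part of the commit):

// Mirrors the per-head slope computed inside alibi_f32.
static float alibi_slope(const int k, const int n_heads_log2_floor, const float m0, const float m1) {
    return k < n_heads_log2_floor ? powf(m0, k + 1)
                                  : powf(m1, 2*(k - n_heads_log2_floor) + 1);
}

Each output element is then dst[i] = x[i] + col * m_k, a bias that grows linearly with the column (key position).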
5 changes: 5 additions & 0 deletions src/ggml-cuda/alibi.cuh
@@ -0,0 +1,5 @@
#include "common.cuh"

#define CUDA_ALIBI_BLOCK_SIZE 32

void ggml_cuda_op_alibi(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
34 changes: 34 additions & 0 deletions src/ggml-cuda/arange.cu
@@ -0,0 +1,34 @@
#include "arange.cuh"

static __global__ void arange_f32(float * dst, const int ne0, const float start, const float step) {
// blockIdx.x: idx of ne0 / BLOCK_SIZE
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;
}
dst[nidx] = start + step * nidx;
}

static void arange_f32_cuda(float * dst, const int ne0, const float start, const float step, cudaStream_t stream) {
int num_blocks = (ne0 + CUDA_ARANGE_BLOCK_SIZE - 1) / CUDA_ARANGE_BLOCK_SIZE;
arange_f32<<<num_blocks, CUDA_ARANGE_BLOCK_SIZE, 0, stream>>>(dst, ne0, start, step);
}

void ggml_cuda_op_arange(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(dst->type == GGML_TYPE_F32);

float start;
float stop;
float step;
memcpy(&start, (float *)dst->op_params + 0, sizeof(float));
memcpy(&stop, (float *)dst->op_params + 1, sizeof(float));
memcpy(&step, (float *)dst->op_params + 2, sizeof(float));

int64_t steps = (int64_t)ceil((stop - start) / step);
GGML_ASSERT(ggml_nelements(dst) == steps);

arange_f32_cuda(dst_d, dst->ne[0], start, step, stream);
}
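The element-count check mirrors half-open range semantics, steps = ceil((stop - start) / step): for example, start = 0, stop = 5, step = 2 gives steps = 3 and the values {0, 2, 4}.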
5 changes: 5 additions & 0 deletions src/ggml-cuda/arange.cuh
@@ -0,0 +1,5 @@
#include "common.cuh"

#define CUDA_ARANGE_BLOCK_SIZE 256

void ggml_cuda_op_arange(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
77 changes: 77 additions & 0 deletions src/ggml-cuda/argsort.cu
@@ -0,0 +1,77 @@
#include "argsort.cuh"

template<typename T>
static inline __device__ void ggml_cuda_swap(T & a, T & b) {
T tmp = a;
a = b;
b = tmp;
}

template<ggml_sort_order order>
static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols) {
// bitonic sort
int col = threadIdx.x;
int row = blockIdx.y;

if (col >= ncols) return;

const float * x_row = x + row * ncols;
int * dst_row = dst + row * ncols;

// initialize indices
if (col < ncols) {
dst_row[col] = col;
}
__syncthreads();

for (int k = 2; k <= ncols; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
ggml_cuda_swap(dst_row[col], dst_row[ixj]);
}
} else {
if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
ggml_cuda_swap(dst_row[col], dst_row[ixj]);
}
}
}
__syncthreads();
}
}
}

static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
// bitonic sort requires ncols to be power of 2
GGML_ASSERT((ncols & (ncols - 1)) == 0);

const dim3 block_dims(ncols, 1, 1);
const dim3 block_nums(1, nrows, 1);
if (order == GGML_SORT_ORDER_ASC) {
k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
} else if (order == GGML_SORT_ORDER_DESC) {
k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
} else {
GGML_ASSERT(false);
}
}

void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_is_contiguous(src0));

const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);

enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];

argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
}
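Each row is sorted by a single block of ncols threads running a full bitonic network, so ncols must be a power of two (asserted in the launcher) and is implicitly bounded by the device's maximum threads per block. For validation, a CPU argsort reference can be compared against the kernel's output (an illustrative test helper, not part of the commit; note the bitonic network is not stable, so for rows containing duplicate values compare the gathered values rather than the raw indices):

#include <algorithm>
#include <numeric>
#include <vector>

// Stable ascending argsort of a single row of length ncols.
static std::vector<int> argsort_ref(const float * x, const int ncols) {
    std::vector<int> idx(ncols);
    std::iota(idx.begin(), idx.end(), 0); // initialize indices 0..ncols-1
    std::stable_sort(idx.begin(), idx.end(),
                     [x](int a, int b) { return x[a] < x[b]; });
    return idx;
}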
3 changes: 3 additions & 0 deletions src/ggml-cuda/argsort.cuh
@@ -0,0 +1,3 @@
#include "common.cuh"

void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
