add some new ops, fix some operators and add batch operations to certain operators. #747

Merged: 19 commits, merged on Mar 3, 2024
Changes from 10 commits
17 changes: 17 additions & 0 deletions include/ggml/ggml.h
@@ -454,6 +454,8 @@ extern "C" {
GGML_OP_POOL_2D,
GGML_OP_UPSCALE, // nearest interpolate
GGML_OP_PAD,
GGML_OP_ARANGE,
GGML_OP_TIMESTEP_EMBEDDING,
GGML_OP_ARGSORT,
GGML_OP_LEAKY_RELU,

@@ -1661,6 +1663,15 @@ extern "C" {
int p2,
int p3);

// Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
// timesteps: [N,]
// return: [N, dim]
GGML_API struct ggml_tensor * ggml_timestep_embedding(
struct ggml_context * ctx,
struct ggml_tensor * timesteps,
int dim,
int max_period);
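
// Aside (not part of this diff): a minimal usage sketch of the new entry point,
// assuming a valid ggml_context * ctx; the tensor sizes are illustrative only.
// struct ggml_tensor * timesteps = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);           // [N] = [4]
// struct ggml_tensor * t_emb     = ggml_timestep_embedding(ctx, timesteps, 320, 10000); // [4, 320]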

// sort rows
enum ggml_sort_order {
GGML_SORT_ORDER_ASC,
@@ -1672,6 +1683,12 @@ extern "C" {
struct ggml_tensor * a,
enum ggml_sort_order order);

GGML_API struct ggml_tensor * ggml_arange(
struct ggml_context * ctx,
float start,
float stop,
float step);
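
// Aside (not part of this diff): a hedged sketch of ggml_arange with illustrative values.
// ceil((6 - 0) / 1.5) = 4, so the result is the 1-D f32 tensor [0, 1.5, 3.0, 4.5].
// struct ggml_tensor * range = ggml_arange(ctx, 0.0f, 6.0f, 1.5f);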

// top k elements per row
GGML_API struct ggml_tensor * ggml_top_k(
struct ggml_context * ctx,
220 changes: 181 additions & 39 deletions src/ggml-cuda.cu
@@ -616,6 +616,8 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + Q
#define CUDA_UPSCALE_BLOCK_SIZE 256
#define CUDA_CONCAT_BLOCK_SIZE 256
#define CUDA_PAD_BLOCK_SIZE 256
#define CUDA_ARANGE_BLOCK_SIZE 256
#define CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
#define CUDA_ACC_BLOCK_SIZE 256
#define CUDA_IM2COL_BLOCK_SIZE 256
#define CUDA_POOL2D_BLOCK_SIZE 256
@@ -1000,7 +1002,11 @@ static __global__ void concat_f32(const float * x,const float * y, float * dst,
}
}

static __global__ void upscale_f32(const float * x, float * dst, const int ne00, const int nb02, const int scale_factor) {
static __global__ void upscale_f32(const float * x, float * dst, const int ne00, const int ne00xne01, const int scale_factor) {
// blockIdx.z: idx of ne02*ne03
// blockIdx.y: idx of ne01*scale_factor, aka ne1
// blockIdx.x: idx of ne00*scale_factor / BLOCK_SIZE
// ne00xne01: ne00 * ne01
int ne0 = ne00 * scale_factor;
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
@@ -1012,15 +1018,18 @@ static __global__ void upscale_f32(const float * x, float * dst, const int ne00,
int offset_src =
i00 +
i01 * ne00 +
blockIdx.z * nb02;
blockIdx.z * ne00xne01;
int offset_dst =
nidx +
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
dst[offset_dst] = x[offset_src];
}

static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02) {
static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
// blockIdx.z: idx of ne2*ne3, aka ne02*ne03
// blockIdx.y: idx of ne1
// blockIdx.x: idx of ne0 / BLOCK_SIZE
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;
@@ -1031,7 +1040,7 @@ static __global__ void pad_f32(const float * x, float * dst, const int ne0, cons
nidx +
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02) {
if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
int offset_src =
nidx +
blockIdx.y * ne00 +
@@ -1042,8 +1051,42 @@ static __global__ void pad_f32(const float * x, float * dst, const int ne0, cons
}
}

static __global__ void arange_f32(float * dst, const int ne0, const float start, const float step) {
// blockIdx.x: idx of ne0 / BLOCK_SIZE
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;
}
dst[nidx] = start + step * nidx;
}

static __global__ void timestep_embedding_f32(const float * timesteps, float * dst, const int nb1, const int dim, const int max_period) {
// blockIdx.y: idx of timesteps->ne[0]
// blockIdx.x: idx of ((dim + 1) / 2) / BLOCK_SIZE
int i = blockIdx.y;
int j = threadIdx.x + blockIdx.x * blockDim.x;
float * embed_data = (float *)((char *)dst + i*nb1);

if (dim % 2 != 0 && j == ((dim + 1) / 2)) {
embed_data[dim] = 0.f;
}

int half = dim / 2;
if (j >= half) {
return;
}

float timestep = timesteps[i];
float freq = (float)exp(-logf(max_period) * j / half);
float arg = timestep * freq;
embed_data[j] = cos(arg);
embed_data[j + half] = sin(arg);
}
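
// For reference, the kernel above mirrors the CompVis helper cited in ggml.h: frequencies decay
// geometrically from 1 down to 1/max_period across the first half of each row, cosines fill
// indices [0, half) and sines fill [half, dim). A scalar CPU sketch of one output row follows;
// it is illustrative only (assumes an even dim) and is not code from this PR. Requires <math.h>.
// static void timestep_embedding_row_ref(float timestep, float * row, int dim, int max_period) {
//     const int half = dim / 2;
//     for (int j = 0; j < half; ++j) {
//         const float freq = expf(-logf((float) max_period) * j / half);
//         const float arg  = timestep * freq;
//         row[j]        = cosf(arg); // first half: cosines
//         row[j + half] = sinf(arg); // second half: sines
//     }
// }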

template <int block_size>
static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
// blockIdx.x: num_groups idx
// threadIdx.x: block_size idx
int start = blockIdx.x * group_size;
int end = start + group_size;

@@ -6449,25 +6492,25 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;

if (i >= ne) {
return;
}

// determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
// then combine those indices with the corresponding byte offsets to get the total offsets
const int i03 = i/(ne00 * ne01 * ne02);
const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;

const int i13 = i/(ne10 * ne11 * ne12);
const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
const int64_t i03 = i/(ne00 * ne01 * ne02);
const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;

const int64_t i13 = i/(ne10 * ne11 * ne12);
const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;

cpy_1(cx + x_offset, cdst + dst_offset);
}
@@ -6956,23 +6999,23 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,

template <typename T>
static __global__ void im2col_kernel(
const float * x, T * dst, int batch_offset,
int offset_delta, int IC, int IW, int IH, int OH, int OW, int KW, int KH, int pelements, int CHW,
const float * x, T * dst, int64_t batch_offset,
int64_t offset_delta, int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW,
int s0, int s1, int p0, int p1, int d0, int d1) {
const int i = threadIdx.x + blockIdx.x * blockDim.x;
const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= pelements) {
return;
}

const int ksize = OW * (KH > 1 ? KW : 1);
const int kx = i / ksize;
const int kd = kx * ksize;
const int ky = (i - kd) / OW;
const int ix = i % OW;
const int64_t ksize = OW * (KH > 1 ? KW : 1);
const int64_t kx = i / ksize;
const int64_t kd = kx * ksize;
const int64_t ky = (i - kd) / OW;
const int64_t ix = i % OW;

const int oh = blockIdx.y;
const int batch = blockIdx.z / IC;
const int ic = blockIdx.z % IC;
const int64_t oh = blockIdx.y;
const int64_t batch = blockIdx.z / IC;
const int64_t ic = blockIdx.z % IC;

const int64_t iiw = ix * s0 + kx * d0 - p0;
const int64_t iih = oh * s1 + ky * d1 - p1;
@@ -7298,19 +7341,33 @@ static void concat_f32_cuda(const float * x, const float * y, float * dst, const
concat_f32<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
}

static void upscale_f32_cuda(const float * x, float * dst, const int ne00, const int ne01, const int ne02, const int scale_factor, cudaStream_t stream) {
static void upscale_f32_cuda(const float * x, float * dst, const int ne00, const int ne01, const int ne02, const int ne03,
const int scale_factor, cudaStream_t stream) {
int ne0 = (ne00 * scale_factor);
int num_blocks = (ne0 + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
dim3 gridDim(num_blocks, (ne01 * scale_factor), ne02);
dim3 gridDim(num_blocks, (ne01 * scale_factor), ne02*ne03);
upscale_f32<<<gridDim, CUDA_UPSCALE_BLOCK_SIZE, 0, stream>>>(x, dst, ne00, ne00 * ne01, scale_factor);
}

static void pad_f32_cuda(const float * x, float * dst,
const int ne00, const int ne01, const int ne02,
const int ne0, const int ne1, const int ne2, cudaStream_t stream) {
const int ne00, const int ne01, const int ne02, const int ne03,
const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
dim3 gridDim(num_blocks, ne1, ne2);
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02);
dim3 gridDim(num_blocks, ne1, ne2*ne3);
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
}
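
// Note: folding ne2*ne3 (and ne02*ne03) into gridDim.z is what extends pad_f32 and upscale_f32
// from 3-D to batched 4-D tensors; each z-block now indexes one (channel, batch) slice.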

static void arange_f32_cuda(float * dst, const int ne0, const float start, const float step, cudaStream_t stream) {
int num_blocks = (ne0 + CUDA_ARANGE_BLOCK_SIZE - 1) / CUDA_ARANGE_BLOCK_SIZE;
arange_f32<<<num_blocks, CUDA_ARANGE_BLOCK_SIZE, 0, stream>>>(dst, ne0, start, step);
}

static void timestep_embedding_f32_cuda(const float * x, float * dst, const int ne00, const int nb1,
const int dim, const int max_period, cudaStream_t stream) {
int half_ceil = (dim + 1) / 2;
int num_blocks = (half_ceil + CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE;
dim3 gridDim(num_blocks, ne00, 1);
timestep_embedding_f32<<<gridDim, CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE, 0, stream>>>(x, dst, nb1, dim, max_period);
}

static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
@@ -8443,8 +8500,8 @@ static void soft_max_f32_cuda(const float * x, const float * mask, const float *

template <typename T>
static void im2col_cuda(const float* x, T* dst,
int IW, int IH, int OW, int OH, int KW, int KH, int IC,
int batch, int batch_offset, int offset_delta,
int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
int64_t batch, int64_t batch_offset, int64_t offset_delta,
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
const int parallel_elements = OW * KW * KH;
const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
@@ -9123,7 +9180,7 @@ static void ggml_cuda_op_group_norm(

int num_groups = dst->op_params[0];
int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
group_norm_f32_cuda(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream);
group_norm_f32_cuda(src0_dd, dst_dd, num_groups * src0->ne[3], group_size, ggml_nelements(src0), main_stream);

(void) src1;
(void) dst;
@@ -9156,7 +9213,7 @@ static void ggml_cuda_op_upscale(

const int scale_factor = dst->op_params[0];

upscale_f32_cuda(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], scale_factor, main_stream);
upscale_f32_cuda(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], scale_factor, main_stream);

(void) src1;
(void) dst;
@@ -9172,8 +9229,46 @@ static void ggml_cuda_op_pad(
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors

pad_f32_cuda(src0_dd, dst_dd,
src0->ne[0], src0->ne[1], src0->ne[2],
dst->ne[0], dst->ne[1], dst->ne[2], main_stream);
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], main_stream);

(void) src1;
(void) dst;
(void) src1_dd;
}

static void ggml_cuda_op_arange(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(dst->type == GGML_TYPE_F32);

const float start = ((float*)dst->op_params)[0];
const float stop = ((float*)dst->op_params)[1];
const float step = ((float*)dst->op_params)[2];

int64_t steps = (int64_t)ceil((stop - start) / step);
GGML_ASSERT(ggml_nelements(dst) == steps);

arange_f32_cuda(dst_dd, dst->ne[0], start, step, main_stream);

(void) src0;
(void) src1;
(void) src0_dd;
(void) src1_dd;
}
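
// For reference, the assertion above just re-derives the element count implied by the op
// parameters, e.g. start = 0, stop = 6, step = 1.5 gives ceil(6 / 1.5) = 4, which must equal
// ggml_nelements(dst).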


static void ggml_cuda_op_timestep_embedding(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);

const int dim = dst->op_params[0];
const int max_period = dst->op_params[1];

timestep_embedding_f32_cuda(src0_dd, dst_dd, src0->ne[0], dst->nb[1], dim, max_period, main_stream);

(void) src1;
(void) dst;
@@ -10458,6 +10553,45 @@ static void ggml_cuda_pad(const ggml_tensor * src0, const ggml_tensor * src1, gg
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_pad);
}

static void ggml_cuda_arange(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;

const bool dst_on_device = dst->backend == GGML_BACKEND_TYPE_GPU;

// dd = data device
float * src0_ddf = nullptr;
float * src1_ddf = nullptr;
float * dst_ddf = nullptr;

cuda_pool_alloc<float> dst_f;

ggml_cuda_set_device(g_main_device);
cudaStream_t main_stream = g_cudaStreams[g_main_device][0];

if (dst_on_device) {
dst_ddf = (float *) dst_extra->data_device[g_main_device];
} else {
dst_ddf = dst_f.alloc(ggml_nelements(dst));
}

// do the computation
ggml_cuda_op_arange(src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
CUDA_CHECK(cudaGetLastError());

// copy dst to host if necessary
if (!dst_on_device) {
CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream));
}

if (dst->backend == GGML_BACKEND_TYPE_CPU) {
CUDA_CHECK(cudaDeviceSynchronize());
}
}
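
// Note on the structure: arange has no input tensors, so it presumably cannot go through
// ggml_cuda_op_flatten (which dereferences src0); the wrapper above therefore only allocates or
// reuses the destination buffer and copies the result back when dst lives on the CPU backend.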

static void ggml_cuda_timestep_embedding(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_timestep_embedding);
}

static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm);
}
@@ -11358,6 +11492,12 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
case GGML_OP_PAD:
func = ggml_cuda_pad;
break;
case GGML_OP_ARANGE:
func = ggml_cuda_arange;
break;
case GGML_OP_TIMESTEP_EMBEDDING:
func = ggml_cuda_timestep_embedding;
break;
case GGML_OP_LEAKY_RELU:
func = ggml_cuda_leaky_relu;
break;
@@ -12253,6 +12393,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_GROUP_NORM:
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_LEAKY_RELU:
return true;
default: