Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add some new ops, fix some operators and add batch operations to certain operators. #747

Merged
merged 19 commits into from Mar 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 9 additions & 2 deletions examples/simple/simple-backend.cpp
Expand Up @@ -19,6 +19,13 @@
#include <string>
#include <vector>

// Default log sink: writes each message verbatim to stderr and flushes right away
// so output is not held back in the stream buffer.
// The severity level and the user-data pointer are deliberately ignored.
static void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
    (void) user_data;
    (void) level;

    FILE * sink = stderr;
    fputs(text, sink);
    fflush(sink);
}

// This is a simple model with two tensors a and b
struct simple_model {
struct ggml_tensor * a;
Expand Down Expand Up @@ -47,7 +54,7 @@ void load_model(simple_model & model, float * a, float * b, int rows_A, int cols

#ifdef GGML_USE_METAL
fprintf(stderr, "%s: using Metal backend\n", __func__);
ggml_metal_log_set_callback(ggml_log_callback_default, nullptr);
ggml_backend_metal_log_set_callback(ggml_log_callback_default, nullptr);
model.backend = ggml_backend_metal_init();
if (!model.backend) {
fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__);
Expand Down Expand Up @@ -191,7 +198,7 @@ int main(void) {
// 55.00 90.00 126.00 28.00
// 50.00 54.00 42.00 64.00 ]

printf("mul mat (%d x %d) (transposed result):\n[", result->ne[0], result->ne[1]);
printf("mul mat (%d x %d) (transposed result):\n[", (int) result->ne[0], (int) result->ne[1]);
for (int j = 0; j < result->ne[1] /* rows */; j++) {
if (j > 0) {
printf("\n");
Expand Down
2 changes: 1 addition & 1 deletion examples/simple/simple-ctx.cpp
Expand Up @@ -108,7 +108,7 @@ int main(void) {
// 55.00 90.00 126.00 28.00
// 50.00 54.00 42.00 64.00 ]

printf("mul mat (%d x %d) (transposed result):\n[", result->ne[0], result->ne[1]);
printf("mul mat (%d x %d) (transposed result):\n[", (int) result->ne[0], (int) result->ne[1]);
for (int j = 0; j < result->ne[1] /* rows */; j++) {
if (j > 0) {
printf("\n");
Expand Down
17 changes: 17 additions & 0 deletions include/ggml/ggml.h
Expand Up @@ -454,6 +454,8 @@ extern "C" {
GGML_OP_POOL_2D,
GGML_OP_UPSCALE, // nearest interpolate
GGML_OP_PAD,
GGML_OP_ARANGE,
GGML_OP_TIMESTEP_EMBEDDING,
GGML_OP_ARGSORT,
GGML_OP_LEAKY_RELU,

Expand Down Expand Up @@ -1661,6 +1663,15 @@ extern "C" {
int p2,
int p3);

// Builds a timestep-embedding tensor from a 1-D tensor of timesteps.
// Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
// ctx: context passed through to the op
// timesteps: [N,]
// dim: the `dim` extent of the returned tensor (see the return shape below)
// max_period: NOTE(review): presumably the same meaning as `max_period` in the referenced util.py - confirm against the implementation
// return: [N, dim]
GGML_API struct ggml_tensor * ggml_timestep_embedding(
struct ggml_context * ctx,
struct ggml_tensor * timesteps,
int dim,
int max_period);

// sort rows
enum ggml_sort_order {
GGML_SORT_ORDER_ASC,
Expand All @@ -1672,6 +1683,12 @@ extern "C" {
struct ggml_tensor * a,
enum ggml_sort_order order);

// NOTE(review): presumably returns a tensor holding the sequence start, start + step, ... bounded by stop,
// by analogy with `arange` in NumPy/PyTorch - confirm the element type and whether `stop` is exclusive
// against the implementation.
GGML_API struct ggml_tensor * ggml_arange(
struct ggml_context * ctx,
float start,
float stop,
float step);

// top k elements per row
GGML_API struct ggml_tensor * ggml_top_k(
struct ggml_context * ctx,
Expand Down