Skip to content

Commit

Permalink
examples : fix simple (#770)
Browse files Browse the repository at this point in the history
* Update README.md

Correcting matrix multiplication expected result.

* Update simple-ctx.cpp

Fix incorrect striding through output.

* simple : update readme

---------

Co-authored-by: Georgi Gerganov <[email protected]>
  • Loading branch information
wilderfield and ggerganov committed Mar 22, 2024
1 parent 1bf5e23 commit 30626ea
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 7 deletions.
42 changes: 39 additions & 3 deletions examples/simple/README.md
Expand Up @@ -2,6 +2,12 @@

This example simply performs a matrix multiplication, solely for the purpose of demonstrating a basic usage of ggml and backend handling. The code is commented to help understand what each part does.

Traditional matrix multiplication goes like this (multiply row-by-column):

$$
A \times B = C
$$

$$
\begin{bmatrix}
2 & 8 \\
Expand All @@ -16,9 +22,39 @@ $$
\end{bmatrix}
\=
\begin{bmatrix}
60 & 110 & 54 & 29 \\
55 & 90 & 126 & 28 \\
50 & 54 & 42 & 64 \\
60 & 90 & 42 \\
55 & 54 & 29 \\
50 & 54 & 28 \\
110 & 126 & 64 \\
\end{bmatrix}
$$

In `ggml`, we pass the matrix $B$ in transposed form and multiply row-by-row. The result $C$ is also transposed:

$$
ggml\\_mul\\_mat(A, B^T) = C^T
$$

$$
ggml\\_mul\\_mat(
\begin{bmatrix}
2 & 8 \\
5 & 1 \\
4 & 2 \\
8 & 6 \\
\end{bmatrix}
,
\begin{bmatrix}
10 & 5 \\
9 & 9 \\
5 & 4 \\
\end{bmatrix}
)
\=
\begin{bmatrix}
60 & 55 & 50 & 110 \\
90 & 54 & 54 & 126 \\
42 & 29 & 28 & 64 \\
\end{bmatrix}
$$

Expand Down
8 changes: 4 additions & 4 deletions examples/simple/simple-ctx.cpp
Expand Up @@ -104,9 +104,9 @@ int main(void) {
memcpy(out_data.data(), result->data, ggml_nbytes(result));

// expected result:
// [ 60.00 110.00 54.00 29.00
// 55.00 90.00 126.00 28.00
// 50.00 54.00 42.00 64.00 ]
// [ 60.00 55.00 50.00 110.00
// 90.00 54.00 54.00 126.00
// 42.00 29.00 28.00 64.00 ]

printf("mul mat (%d x %d) (transposed result):\n[", (int) result->ne[0], (int) result->ne[1]);
for (int j = 0; j < result->ne[1] /* rows */; j++) {
Expand All @@ -115,7 +115,7 @@ int main(void) {
}

for (int i = 0; i < result->ne[0] /* cols */; i++) {
printf(" %.2f", out_data[i * result->ne[1] + j]);
printf(" %.2f", out_data[j * result->ne[0] + i]);
}
}
printf(" ]\n");
Expand Down

0 comments on commit 30626ea

Please sign in to comment.