Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix examples/simple/simple-ctx #770

Merged
merged 3 commits into from Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
42 changes: 39 additions & 3 deletions examples/simple/README.md
Expand Up @@ -2,6 +2,12 @@

This example simply performs a matrix multiplication, solely for the purpose of demonstrating a basic usage of ggml and backend handling. The code is commented to help understand what each part does.

Traditional matrix multiplication goes like this (multiply row-by-column):

$$
A \times B = C
$$

$$
\begin{bmatrix}
2 & 8 \\
Expand All @@ -16,9 +22,39 @@ $$
\end{bmatrix}
\=
\begin{bmatrix}
60 & 110 & 54 & 29 \\
55 & 90 & 126 & 28 \\
50 & 54 & 42 & 64 \\
60 & 90 & 42 \\
55 & 54 & 29 \\
50 & 54 & 28 \\
110 & 126 & 64 \\
\end{bmatrix}
$$

In `ggml`, we pass the matrix $B$ in transposed form and multiply row-by-row. The result $C$ is also transposed:

$$
ggml\\_mul\\_mat(A, B^T) = C^T
$$

$$
ggml\\_mul\\_mat(
\begin{bmatrix}
2 & 8 \\
5 & 1 \\
4 & 2 \\
8 & 6 \\
\end{bmatrix}
,
\begin{bmatrix}
10 & 5 \\
9 & 9 \\
5 & 4 \\
\end{bmatrix}
)
\=
\begin{bmatrix}
60 & 55 & 50 & 110 \\
90 & 54 & 54 & 126 \\
42 & 29 & 28 & 64 \\
\end{bmatrix}
$$

Expand Down
8 changes: 4 additions & 4 deletions examples/simple/simple-ctx.cpp
Expand Up @@ -104,9 +104,9 @@ int main(void) {
memcpy(out_data.data(), result->data, ggml_nbytes(result));

// expected result:
// [ 60.00 110.00 54.00 29.00
// 55.00 90.00 126.00 28.00
// 50.00 54.00 42.00 64.00 ]
// [ 60.00 55.00 50.00 110.00
// 90.00 54.00 54.00 126.00
// 42.00 29.00 28.00 64.00 ]

printf("mul mat (%d x %d) (transposed result):\n[", (int) result->ne[0], (int) result->ne[1]);
for (int j = 0; j < result->ne[1] /* rows */; j++) {
Expand All @@ -115,7 +115,7 @@ int main(void) {
}

for (int i = 0; i < result->ne[0] /* cols */; i++) {
printf(" %.2f", out_data[i * result->ne[1] + j]);
printf(" %.2f", out_data[j * result->ne[0] + i]);
}
}
printf(" ]\n");
Expand Down