Skip to content

Commit 30626ea

Browse files
examples : fix simple (#770)
* Update README.md Correcting matrix multiplication expected result. * Update simple-ctx.cpp Fix incorrect striding through output. * simple : update readme --------- Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 1bf5e23 commit 30626ea

File tree

2 files changed

+43
-7
lines changed

2 files changed

+43
-7
lines changed

examples/simple/README.md

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22

33
This example simply performs a matrix multiplication, solely for the purpose of demonstrating a basic usage of ggml and backend handling. The code is commented to help understand what each part does.
44

5+
Traditional matrix multiplication goes like this (multiply row-by-column):
6+
7+
$$
8+
A \times B = C
9+
$$
10+
511
$$
612
\begin{bmatrix}
713
2 & 8 \\
@@ -16,9 +22,39 @@ $$
1622
\end{bmatrix}
1723
\=
1824
\begin{bmatrix}
19-
60 & 110 & 54 & 29 \\
20-
55 & 90 & 126 & 28 \\
21-
50 & 54 & 42 & 64 \\
25+
60 & 90 & 42 \\
26+
55 & 54 & 29 \\
27+
50 & 54 & 28 \\
28+
110 & 126 & 64 \\
29+
\end{bmatrix}
30+
$$
31+
32+
In `ggml`, we pass the matrix $B$ in transposed form and multiply row-by-row. The result $C$ is also transposed:
33+
34+
$$
35+
ggml\\_mul\\_mat(A, B^T) = C^T
36+
$$
37+
38+
$$
39+
ggml\\_mul\\_mat(
40+
\begin{bmatrix}
41+
2 & 8 \\
42+
5 & 1 \\
43+
4 & 2 \\
44+
8 & 6 \\
45+
\end{bmatrix}
46+
,
47+
\begin{bmatrix}
48+
10 & 5 \\
49+
9 & 9 \\
50+
5 & 4 \\
51+
\end{bmatrix}
52+
)
53+
\=
54+
\begin{bmatrix}
55+
60 & 55 & 50 & 110 \\
56+
90 & 54 & 54 & 126 \\
57+
42 & 29 & 28 & 64 \\
2258
\end{bmatrix}
2359
$$
2460

examples/simple/simple-ctx.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,9 @@ int main(void) {
104104
memcpy(out_data.data(), result->data, ggml_nbytes(result));
105105

106106
// expected result:
107-
// [ 60.00 110.00 54.00 29.00
108-
// 55.00 90.00 126.00 28.00
109-
// 50.00 54.00 42.00 64.00 ]
107+
// [ 60.00 55.00 50.00 110.00
108+
// 90.00 54.00 54.00 126.00
109+
// 42.00 29.00 28.00 64.00 ]
110110

111111
printf("mul mat (%d x %d) (transposed result):\n[", (int) result->ne[0], (int) result->ne[1]);
112112
for (int j = 0; j < result->ne[1] /* rows */; j++) {
@@ -115,7 +115,7 @@ int main(void) {
115115
}
116116

117117
for (int i = 0; i < result->ne[0] /* cols */; i++) {
118-
printf(" %.2f", out_data[i * result->ne[1] + j]);
118+
printf(" %.2f", out_data[j * result->ne[0] + i]);
119119
}
120120
}
121121
printf(" ]\n");

0 commit comments

Comments
 (0)