diff --git a/examples/simple/README.md b/examples/simple/README.md
index ba3a4fd1a..28c549fc9 100644
--- a/examples/simple/README.md
+++ b/examples/simple/README.md
@@ -2,6 +2,12 @@
 
 This example simply performs a matrix multiplication, solely for the purpose of demonstrating a basic usage of ggml and backend handling. The code is commented to help understand what each part does.
 
+Traditional matrix multiplication goes like this (multiply row-by-column):
+
+$$
+A \times B = C
+$$
+
 $$
 \begin{bmatrix}
 2 & 8 \\
@@ -16,9 +22,39 @@ $$
 \end{bmatrix}
 \=
 \begin{bmatrix}
-60 & 110 & 54 & 29 \\
-55 & 90 & 126 & 28 \\
-50 & 54 & 42 & 64 \\
+60 & 90 & 42 \\
+55 & 54 & 29 \\
+50 & 54 & 28 \\
+110 & 126 & 64 \\
+\end{bmatrix}
+$$
+
+In `ggml`, we pass the matrix $B$ in transposed form and multiply row-by-row. The result $C$ is also transposed:
+
+$$
+ggml\\_mul\\_mat(A, B^T) = C^T
+$$
+
+$$
+ggml\\_mul\\_mat(
+\begin{bmatrix}
+2 & 8 \\
+5 & 1 \\
+4 & 2 \\
+8 & 6 \\
+\end{bmatrix}
+,
+\begin{bmatrix}
+10 & 5 \\
+9 & 9 \\
+5 & 4 \\
+\end{bmatrix}
+)
+\=
+\begin{bmatrix}
+60 & 55 & 50 & 110 \\
+90 & 54 & 54 & 126 \\
+42 & 29 & 28 & 64 \\
 \end{bmatrix}
 $$
 
diff --git a/examples/simple/simple-ctx.cpp b/examples/simple/simple-ctx.cpp
index d331a4c1f..b2d4e4ba5 100644
--- a/examples/simple/simple-ctx.cpp
+++ b/examples/simple/simple-ctx.cpp
@@ -104,9 +104,9 @@ int main(void) {
     memcpy(out_data.data(), result->data, ggml_nbytes(result));
 
     // expected result:
-    // [ 60.00 110.00 54.00 29.00
-    //  55.00 90.00 126.00 28.00
-    //  50.00 54.00 42.00 64.00 ]
+    // [ 60.00 55.00 50.00 110.00
+    //  90.00 54.00 54.00 126.00
+    //  42.00 29.00 28.00 64.00 ]
 
     printf("mul mat (%d x %d) (transposed result):\n[", (int) result->ne[0], (int) result->ne[1]);
     for (int j = 0; j < result->ne[1] /* rows */; j++) {
@@ -115,7 +115,7 @@ int main(void) {
         }
 
         for (int i = 0; i < result->ne[0] /* cols */; i++) {
-            printf(" %.2f", out_data[i * result->ne[1] + j]);
+            printf(" %.2f", out_data[j * result->ne[0] + i]);
         }
     }
     printf(" ]\n");