From 0c7e21ebbbfe504734a08949daab827f556fa19f Mon Sep 17 00:00:00 2001 From: snsinfu Date: Wed, 21 Oct 2020 02:20:35 +0900 Subject: [PATCH] Add matrix --- API.md | 31 ++++++ include/point.hpp | 220 +++++++++++++++++++++++++++++++++++++++ tests/Makefile | 4 +- tests/matrix_test.cc | 241 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 495 insertions(+), 1 deletion(-) create mode 100644 tests/matrix_test.cc diff --git a/API.md b/API.md index fc3d6d8..5a1e63b 100644 --- a/API.md +++ b/API.md @@ -61,3 +61,34 @@ For `cxx::point p, q` and `double x, y, z`, `std::istream in` and | p.squared\_distance(q), squared\_distance(p, q) | squared distance between p and q | | in >> q | read "x y z" from in | | out << p | write "x y z" to out | + +## cxx::matrix + +`matrix` is a three-dimensional square matrix. It supports linear operations, +matrix products, vector left-products and iostream. + +For `cxx::matrix A, B`, `cxx::vector v`, `double a, b, c, r, s, t, x, y, z`, +`size_t i, j`, `std::istream in` and `std::ostream out`: + +| Expression | Meaning | +| ----------------------------------- | --------------------------------- | +| matrix{} | zero matrix | +| matrix{a, b, c} | diagonal matrix | +| matrix{a, b, c, r, s, t, x, y, z} | full 3-by-3 matrix | +| A(i, j) | i,j-component of A | +| A += B | add v to u | +| A -= B | subtract v from u | +| A \*= a | multiply all elements of A by a | +| A /= a | divide all elements of A by a | +| +A | copy of A | +| -A | negated copy of A | +| A + B | sum of A and B | +| A - B | difference of A and B | +| A \* a | copy of A scaled by a | +| a \* A | copy of A scaled by a | +| A / a | copy of A scaled by 1/a | +| A.dot(B), dot(A, B) | matrix product of A and B | +| A.dot(v), dot(A, v) | matrix product of A and v | +| A.transpose(), transpose(A) | transposed copy of A | +| in >> A | read "a b c r s t x y z" from in | +| out << A | write "a b c r s t x y z" to out | diff --git a/include/point.hpp b/include/point.hpp index 438f320..43889a6 100644 --- a/include/point.hpp +++ b/include/point.hpp @@ -371,6 +371,226 @@ namespace cxx { return pa.squared_distance(pb); } + + // matrix is a three-dimensional square matrix. + struct matrix + { + double elements[3][3] = {}; + + // Default constructor initializes all elements to zero. + matrix() = default; + + // Initializes matrix with diagonal elements. + inline matrix(double a, double b, double c) noexcept + : elements{{a, 0, 0}, {0, b, 0}, {0, 0, c}} + { + } + + // Initializes matrix with full elements in row-major order. + inline matrix( + double a11, + double a12, + double a13, + double a21, + double a22, + double a23, + double a31, + double a32, + double a33 + ) noexcept + : elements{{a11, a12, a13}, {a21, a22, a23}, {a31, a32, a33}} + { + } + + // Returns a reference to the i,j element. + inline double& operator()(std::size_t i, std::size_t j) + { + return elements[i][j]; + } + + inline double const& operator()(std::size_t i, std::size_t j) const + { + return elements[i][j]; + } + + // Element-wise addition. + inline matrix& operator+=(matrix const& other) noexcept + { + for (std::size_t i = 0; i < 9; i++) { + (&elements[0][0])[i] += (&other.elements[0][0])[i]; + } + return *this; + } + + // Element-wise subtraction. + inline matrix& operator-=(matrix const& other) noexcept + { + for (std::size_t i = 0; i < 9; i++) { + (&elements[0][0])[i] -= (&other.elements[0][0])[i]; + } + return *this; + } + + // Element-wise multiplication by a scalar. + inline matrix& operator*=(double mult) noexcept + { + for (std::size_t i = 0; i < 9; i++) { + (&elements[0][0])[i] *= mult; + } + return *this; + } + + // Element-wise division by a scalar. + inline matrix& operator/=(double divisor) + { + return *this *= 1 / divisor; + } + + inline vector row(std::size_t i) const noexcept + { + return {elements[i][0], elements[i][1], elements[i][2]}; + } + + inline vector column(std::size_t i) const noexcept + { + return {elements[0][i], elements[1][i], elements[2][i]}; + } + + // Matrix product. + inline matrix dot(matrix const& other) const noexcept + { + return { + row(0).dot(other.column(0)), + row(0).dot(other.column(1)), + row(0).dot(other.column(2)), + row(1).dot(other.column(0)), + row(1).dot(other.column(1)), + row(1).dot(other.column(2)), + row(2).dot(other.column(0)), + row(2).dot(other.column(1)), + row(2).dot(other.column(2)) + }; + } + + // Matrix product with vector. + inline vector dot(vector const& vec) const noexcept + { + return {row(0).dot(vec), row(1).dot(vec), row(2).dot(vec)}; + } + + // Returns a transposed copy. + inline matrix transpose() const noexcept + { + return { + elements[0][0], elements[1][0], elements[2][0], + elements[0][1], elements[1][1], elements[2][1], + elements[0][2], elements[1][2], elements[2][2] + }; + } + }; + + // Returns a copy of mat. + inline matrix operator+(matrix const& mat) noexcept + { + return mat; + } + + // Returns a negated copy of mat. + inline matrix operator-(matrix const& mat) noexcept + { + return { + -mat(0, 0), -mat(0, 1), -mat(0, 2), + -mat(1, 0), -mat(1, 1), -mat(1, 2), + -mat(2, 0), -mat(2, 1), -mat(2, 2) + }; + } + + // Returns the sum of two matrices. + inline matrix operator+(matrix const& lhs, matrix const& rhs) noexcept + { + return matrix{lhs} += rhs; + } + + // Returns the difference of two matrices. + inline matrix operator-(matrix const& lhs, matrix const& rhs) noexcept + { + return matrix{lhs} -= rhs; + } + + // Returns scalar multiplication of a matrix. + inline matrix operator*(matrix const& lhs, double rhs) noexcept + { + return matrix{lhs} *= rhs; + } + + // Returns scalar multiplication of a matrix. + inline matrix operator*(double lhs, matrix const& rhs) noexcept + { + return matrix{rhs} *= lhs; + } + + // Returns scalar quotient of a matrix divided by scalar. + inline matrix operator/(matrix const& lhs, double rhs) noexcept + { + return matrix{lhs} *= 1 / rhs; + } + + // Returns matrix product of two matrices. + inline matrix dot(matrix const& lhs, matrix const& rhs) noexcept + { + return lhs.dot(rhs); + } + + // Returns matrix transformation of a vector. + inline vector dot(matrix const& lhs, vector const& rhs) noexcept + { + return lhs.dot(rhs); + } + + // Returns a transposed copy of mat. + inline matrix transpose(matrix const& mat) noexcept + { + return mat.transpose(); + } + + template + std::basic_ostream& operator<<( + std::basic_ostream& os, + matrix const& mat + ) + { + using sentry_type = typename std::basic_ostream::sentry; + + if (sentry_type sentry{os}) { + Char const delim = os.widen(' '); + Char const newline = os.widen('\n'); + + os << mat(0, 0) << delim << mat(0, 1) << delim << mat(0, 2); + os << newline; + os << mat(1, 0) << delim << mat(1, 1) << delim << mat(1, 2); + os << newline; + os << mat(2, 0) << delim << mat(2, 1) << delim << mat(2, 2); + } + + return os; + } + + template + std::basic_istream& operator>>( + std::basic_istream& is, + matrix& mat + ) + { + using sentry_type = typename std::basic_istream::sentry; + + if (sentry_type sentry{is}) { + is >> mat(0, 0) >> mat(0, 1) >> mat(0, 2); + is >> mat(1, 0) >> mat(1, 1) >> mat(1, 2); + is >> mat(2, 0) >> mat(2, 1) >> mat(2, 2); + } + + return is; + } } #endif diff --git a/tests/Makefile b/tests/Makefile index 1f52632..3886c0f 100644 --- a/tests/Makefile +++ b/tests/Makefile @@ -15,7 +15,8 @@ OBJECTS = \ main.o \ coordinates_test.o \ vector_test.o \ - point_test.o + point_test.o \ + matrix_test.o ARTIFACTS = \ $(OBJECTS) \ @@ -40,3 +41,4 @@ main.o: catch.hpp coordinates_test.o: catch.hpp ../include/point.hpp vector_test.o: catch.hpp ../include/point.hpp point_test.o: catch.hpp ../include/point.hpp +matrix_test.o: catch.hpp ../include/point.hpp diff --git a/tests/matrix_test.cc b/tests/matrix_test.cc new file mode 100644 index 0000000..780110c --- /dev/null +++ b/tests/matrix_test.cc @@ -0,0 +1,241 @@ +// Copyright snsinfu 2020. +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE.txt or copy at http://www.boost.org/LICENSE_1_0.txt) + +#include + +#include "catch.hpp" +#include "../include/point.hpp" + + +TEST_CASE("matrix - is default constructible as the null matrix") +{ + cxx::matrix mat; + CHECK(mat(0, 0) == 0); + CHECK(mat(0, 1) == 0); + CHECK(mat(0, 2) == 0); + CHECK(mat(1, 0) == 0); + CHECK(mat(1, 1) == 0); + CHECK(mat(1, 2) == 0); + CHECK(mat(2, 0) == 0); + CHECK(mat(2, 1) == 0); + CHECK(mat(2, 2) == 0); +} + +TEST_CASE("matrix - is constructible from diagonal elements") +{ + cxx::matrix mat = {1, 2, 3}; + CHECK(mat(0, 0) == 1); + CHECK(mat(0, 1) == 0); + CHECK(mat(0, 2) == 0); + CHECK(mat(1, 0) == 0); + CHECK(mat(1, 1) == 2); + CHECK(mat(1, 2) == 0); + CHECK(mat(2, 0) == 0); + CHECK(mat(2, 1) == 0); + CHECK(mat(2, 2) == 3); +} + +TEST_CASE("matrix - is constructible from full elements") +{ + cxx::matrix mat = { + 1, 2, 3, + 4, 5, 6, + 7, 8, 9 + }; + CHECK(mat(0, 0) == 1); + CHECK(mat(0, 1) == 2); + CHECK(mat(0, 2) == 3); + CHECK(mat(1, 0) == 4); + CHECK(mat(1, 1) == 5); + CHECK(mat(1, 2) == 6); + CHECK(mat(2, 0) == 7); + CHECK(mat(2, 1) == 8); + CHECK(mat(2, 2) == 9); +} + +TEST_CASE("matrix - provides mutable element references") +{ + cxx::matrix mat; + mat(0, 1) = 4; + mat(2, 0) = 5; + mat(1, 2) = 6; + CHECK(mat(0, 1) == 4); + CHECK(mat(2, 0) == 5); + CHECK(mat(1, 2) == 6); +} + +TEST_CASE("matrix - supports element-wise in-place addition") +{ + cxx::matrix mat = { + 1, 2, 3, + 4, 5, 6, + 7, 8, 9 + }; + mat += { + 3, 5, 7, + 6, 8, 1, + 9, 2, 4 + }; + CHECK(mat(0, 0) == 4); + CHECK(mat(0, 1) == 7); + CHECK(mat(0, 2) == 10); + CHECK(mat(1, 0) == 10); + CHECK(mat(1, 1) == 13); + CHECK(mat(1, 2) == 7); + CHECK(mat(2, 0) == 16); + CHECK(mat(2, 1) == 10); + CHECK(mat(2, 2) == 13); +} + +TEST_CASE("matrix - supports element-wise in-place subtraction") +{ + cxx::matrix mat = { + 1, 2, 3, + 4, 5, 6, + 7, 8, 9 + }; + mat -= { + 3, 5, 7, + 6, 8, 1, + 9, 2, 4 + }; + CHECK(mat(0, 0) == -2); + CHECK(mat(0, 1) == -3); + CHECK(mat(0, 2) == -4); + CHECK(mat(1, 0) == -2); + CHECK(mat(1, 1) == -3); + CHECK(mat(1, 2) == 5); + CHECK(mat(2, 0) == -2); + CHECK(mat(2, 1) == 6); + CHECK(mat(2, 2) == 5); +} + +TEST_CASE("matrix - supports element-wise in-place scaling") +{ + cxx::matrix mat = { + 1, 2, 3, + 4, 5, 6, + 7, 8, 9 + }; + mat *= 2; + CHECK(mat(0, 0) == 2); + CHECK(mat(0, 1) == 4); + CHECK(mat(0, 2) == 6); + CHECK(mat(1, 0) == 8); + CHECK(mat(1, 1) == 10); + CHECK(mat(1, 2) == 12); + CHECK(mat(2, 0) == 14); + CHECK(mat(2, 1) == 16); + CHECK(mat(2, 2) == 18); +} + +TEST_CASE("matrix - supports element-wise in-place division") +{ + cxx::matrix mat = { + 1, 2, 3, + 4, 5, 6, + 7, 8, 9 + }; + mat /= 2; + CHECK(mat(0, 0) == Approx(0.5)); + CHECK(mat(0, 1) == Approx(1.0)); + CHECK(mat(0, 2) == Approx(1.5)); + CHECK(mat(1, 0) == Approx(2.0)); + CHECK(mat(1, 1) == Approx(2.5)); + CHECK(mat(1, 2) == Approx(3.0)); + CHECK(mat(2, 0) == Approx(3.5)); + CHECK(mat(2, 1) == Approx(4.0)); + CHECK(mat(2, 2) == Approx(4.5)); +} + +TEST_CASE("matrix::dot - computes matrix-matrix product") +{ + cxx::matrix const mat1 = { + 1, 2, 3, + 4, 5, 6, + 7, 8, 9 + }; + cxx::matrix const mat2 = { + 3, 5, 7, + 6, 8, 1, + 9, 2, 4 + }; + cxx::matrix const expected = { + 42, 27, 21, + 96, 72, 57, + 150, 117, 93 + }; + cxx::matrix const actual = mat1.dot(mat2); + CHECK(actual(0, 0) == Approx(expected(0, 0))); + CHECK(actual(0, 1) == Approx(expected(0, 1))); + CHECK(actual(0, 2) == Approx(expected(0, 2))); + CHECK(actual(1, 0) == Approx(expected(1, 0))); + CHECK(actual(1, 1) == Approx(expected(1, 1))); + CHECK(actual(1, 2) == Approx(expected(1, 2))); + CHECK(actual(2, 0) == Approx(expected(2, 0))); + CHECK(actual(2, 1) == Approx(expected(2, 1))); + CHECK(actual(2, 2) == Approx(expected(2, 2))); +} + +TEST_CASE("matrix::dot - computes matrix-vector product") +{ + cxx::matrix const mat = { + 1, 2, 3, + 4, 5, 6, + 7, 8, 9 + }; + cxx::vector const vec = {10, 11, 12}; + cxx::vector const expected = {68, 167, 266}; + cxx::vector const actual = mat.dot(vec); + CHECK(actual.x == Approx(expected.x)); + CHECK(actual.y == Approx(expected.y)); + CHECK(actual.z == Approx(expected.z)); +} + +TEST_CASE("matrix::transpose - returns transposed copy") +{ + cxx::matrix const mat = { + 1, 2, 3, + 4, 5, 6, + 7, 8, 9 + }; + cxx::matrix const tr = mat.transpose(); + CHECK(tr(0, 0) == mat(0, 0)); + CHECK(tr(0, 1) == mat(1, 0)); + CHECK(tr(0, 2) == mat(2, 0)); + CHECK(tr(1, 0) == mat(0, 1)); + CHECK(tr(1, 1) == mat(1, 1)); + CHECK(tr(1, 2) == mat(2, 1)); + CHECK(tr(2, 0) == mat(0, 2)); + CHECK(tr(2, 1) == mat(1, 2)); + CHECK(tr(2, 2) == mat(2, 2)); +} + +TEST_CASE("matrix - is formattable") +{ + cxx::matrix const mat = { + 1.2, 3.4, 5.6, + 7.8, 9.1, 2.3, + 4.5, 6.7, 8.9 + }; + std::ostringstream str; + str << mat; + CHECK(str.str() == "1.2 3.4 5.6\n7.8 9.1 2.3\n4.5 6.7 8.9"); +} + +TEST_CASE("matrix - is parsable") +{ + cxx::matrix mat; + std::istringstream str("1.2 3.4 5.6\n7.8 9.1 2.3\n4.5 6.7 8.9"); + str >> mat; + CHECK(mat(0, 0) == Approx(1.2)); + CHECK(mat(0, 1) == Approx(3.4)); + CHECK(mat(0, 2) == Approx(5.6)); + CHECK(mat(1, 0) == Approx(7.8)); + CHECK(mat(1, 1) == Approx(9.1)); + CHECK(mat(1, 2) == Approx(2.3)); + CHECK(mat(2, 0) == Approx(4.5)); + CHECK(mat(2, 1) == Approx(6.7)); + CHECK(mat(2, 2) == Approx(8.9)); +}