From e4fa28b151eeab5ea36fad8e8acc36b1efc5824c Mon Sep 17 00:00:00 2001 From: Oscar Higgott Date: Sun, 5 Sep 2021 23:03:45 +0100 Subject: [PATCH] Add docstrings to pybind bindings and add doctest examples to docstrings. Fixes #13 --- README.md | 2 +- src/pymatching/__init__.py | 2 +- src/pymatching/bindings.cpp | 612 ++++++++++++++++++++++++++++++++-- src/pymatching/lemon_mwpm.cpp | 70 ++-- src/pymatching/lemon_mwpm.h | 27 +- src/pymatching/matching.py | 67 +++- 6 files changed, 715 insertions(+), 65 deletions(-) diff --git a/README.md b/README.md index 253c83a0..e5087af0 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ pip install -e ./PyMatching ``` The installation may take a few minutes since the C++ extension has to be compiled. If you'd also like to run the tests, first install [pytest](https://docs.pytest.org/en/stable/), and then run: ``` -pytest ./PyMatching/tests +pytest ./PyMatching/tests ./PyMatching/src ``` ## Usage diff --git a/src/pymatching/__init__.py b/src/pymatching/__init__.py index 576034d2..28d36452 100644 --- a/src/pymatching/__init__.py +++ b/src/pymatching/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pymatching._cpp_mwpm import (randomize, set_seed, rand_float) +from pymatching._cpp_mwpm import (randomize, set_seed, rand_float, BlossomFailureException) from pymatching.matching import * randomize() # Set random seed using std::random_device diff --git a/src/pymatching/bindings.cpp b/src/pymatching/bindings.cpp index b8c8bdf7..900039a2 100644 --- a/src/pymatching/bindings.cpp +++ b/src/pymatching/bindings.cpp @@ -23,34 +23,488 @@ using namespace pybind11::literals; PYBIND11_MODULE(_cpp_mwpm, m) { py::class_(m, "WeightedEdgeData") + .def(py::init<>(), u8R"( + Initialises a WeightedEdgeData object + )") + .def(py::init, double, double, bool>(), + "qubit_ids"_a, "weight"_a, "error_probability"_a, "has_error_probability"_a, u8R"( + Initialises a WeightedEdgeData object + + Parameters + ---------- + qubit_ids: set[int] + A set of qubit ids + weight: float + The edge weight + error_probability: float + The probability that the edge flips. If no error_probability is associated + with the edge, set to -1. + has_error_probability: bool + Whether the edge has an error_probability + )") + .def("__repr__", &WeightedEdgeData::repr) .def_readwrite("qubit_ids", &WeightedEdgeData::qubit_ids) .def_readwrite("weight", &WeightedEdgeData::weight) .def_readwrite("error_probability", &WeightedEdgeData::error_probability) .def_readwrite("has_error_probability", &WeightedEdgeData::has_error_probability); - py::class_(m, "MatchingGraph") - .def(py::init<>()) - .def(py::init&>(), "num_detectors"_a, "boundary"_a) + py::class_(m, "MatchingGraph", u8R"( + A matching graph to be decoded with minimum-weight perfect matching + + Examples + -------- + >>> import math + >>> from pymatching._cpp_mwpm import MatchingGraph, set_seed + >>> set_seed(0) + >>> graph = MatchingGraph() + >>> graph.add_edge(0, 1, qubit_ids={0}, weight=math.log((1-0.1)/0.1), error_probability=0.1, has_error_probability=True) + >>> graph.add_edge(1, 2, qubit_ids={1}, weight=math.log((1-0.2)/0.2), error_probability=0.2, has_error_probability=True) + >>> graph.add_edge(2, 3, qubit_ids={2}, weight=math.log((1-0.4)/0.4), error_probability=0.4, has_error_probability=True) + >>> graph.add_edge(3, 4, qubit_ids={3}, weight=math.log((1-0.05)/0.05), error_probability=0.05, has_error_probability=True) + >>> graph.set_boundary({0, 4}) + >>> graph + + >>> graph.get_path(1, 4) + [1, 2, 3, 4] + >>> graph.get_nearest_neighbours(source=2, num_neighbours=2, defect_id=[-1, 0, 1, 2, -1]) + [(3, 0.4054651081081642), (1, 1.3862943611198906)] + )") + .def(py::init<>(), u8R"( + Initialises a `pymatching._cpp_mwpm.MatchingGraph` + )") + .def(py::init&>(), "num_detectors"_a, "boundary"_a, u8R"( + Initialises a `pymatching._cpp_mwpm.MatchingGraph` + + Parameters + ---------- + num_detectors: int + The number of detectors in the matching graph. A detector is a node that is not + a boundary node, and has the same meaning as in Stim. + boundary: set[int] + The ids of the boundary nodes in the matching graph. + + )") .def("all_edges_have_error_probabilities", - &MatchingGraph::AllEdgesHaveErrorProbabilities) + &MatchingGraph::AllEdgesHaveErrorProbabilities, u8R"( + Outputs whether or not all edges have been assigned error probabilities + + Returns + ------- + bool + True if all edges have been assigned error probabilities, else False + + Examples + -------- + >>> import math + >>> from pymatching._cpp_mwpm import MatchingGraph + >>> graph = MatchingGraph() + >>> graph.add_edge(0, 1, {0}, math.log((1-0.1)/0.1), 0.1, True) + >>> graph.all_edges_have_error_probabilities() + True + >>> graph.add_edge(1, 2, {1}, 0, -1, False) + >>> graph.all_edges_have_error_probabilities() + False + )") .def("add_edge", &MatchingGraph::AddEdge, "node1"_a, "node2"_a, "qubit_ids"_a, - "weight"_a, "error_probability"_a=-1.0, "has_error_probability"_a=false) - .def("add_noise", &MatchingGraph::AddNoise) - .def("get_boundary", &MatchingGraph::GetBoundary) - .def("set_boundary", &MatchingGraph::SetBoundary, "boundary"_a) - .def("get_edges", &MatchingGraph::GetEdges) + "weight"_a, "error_probability"_a=-1.0, "has_error_probability"_a=false, u8R"( + Adds an edge to the matching graph + + Parameters + ---------- + node1: int + The id of the first node in the edge to be added + node2: int + The id of the second node in the edge to be added + qubit_ids: set[int] + The ids of the qubits associated with the edge + weight: float + The weight of the edge + error_probability: float + The probability that the edge is flipped. This parameter is optional + and should be set to -1 (the default value) if no error probability + needs to be set for the edge + has_error_probability: bool + Whether or not the edge has been given an error probability + + Examples + -------- + >>> import math + >>> from pymatching._cpp_mwpm import MatchingGraph + >>> graph = MatchingGraph() + >>> graph.add_edge(0, 1, {0, 1}, math.log((1-0.1)/0.1), 0.1, True) + >>> graph.add_edge(1, 2, {2}, math.log((1-0.1)/0.1), 0.1, True) + >>> graph.add_edge(2, 3, {3}, math.log((1-0.2)/0.2), 0.2, True) + >>> graph.add_edge(3, 0, set(), 0, -1, False) + >>> graph.set_boundary({0, 3}) + >>> graph.get_num_qubits() + 4 + >>> graph.get_num_edges() + 4 + )") + .def("add_noise", &MatchingGraph::AddNoise, u8R"( + Flips each edge independently with the associated error probability, + returning the noise vector and syndrome. + + Returns + ------- + numpy.ndarray + A binary array (of dtype numpy.uint8) specifying whether each + qubit has been flipped. Element i is one if the qubit with + qubit_id==i has been flipped, and is zero otherwise. + numpy.ndarray + A binary array (of dtype numpy.uint8) with length equal to the + number of nodes in the matching graph, specifying the syndrome. + Element i is one if node i is a defect, and zero otherwise. Note + that boundary nodes are never defects (their syndrome is zero). + + Examples + -------- + >>> import math + >>> from pymatching._cpp_mwpm import MatchingGraph, set_seed + >>> set_seed(0) + >>> graph = MatchingGraph() + >>> graph.add_edge(0, 1, {0}, math.log((1-0.4)/0.4), 0.4, True) + >>> graph.add_edge(1, 2, {1}, math.log((1-0.0001)/0.0001), 0.0001, True) + >>> graph.add_edge(2, 3, {2}, math.log((1-0.45)/0.45), 0.45, True) + >>> graph.add_noise() + (array([0, 0, 0], dtype=uint8), array([0, 0, 0, 0], dtype=uint8)) + >>> graph.add_noise() + (array([0, 0, 1], dtype=uint8), array([0, 0, 1, 1], dtype=uint8)) + >>> graph.add_noise() + (array([1, 0, 1], dtype=uint8), array([1, 1, 1, 1], dtype=uint8)) + >>> graph.add_noise() + (array([0, 0, 0], dtype=uint8), array([0, 0, 0, 0], dtype=uint8)) + )") + .def("get_boundary", &MatchingGraph::GetBoundary, u8R"( + Get the ids of the boundary nodes + + Returns + ------- + set[int] + The ids of the boundary nodes + + Examples + -------- + >>> from pymatching._cpp_mwpm import MatchingGraph + >>> graph = MatchingGraph() + >>> graph.set_boundary({1,3,5}) + >>> graph.get_boundary() + {1, 3, 5} + )") + .def("set_boundary", &MatchingGraph::SetBoundary, "boundary"_a, u8R"( + Set the ids of the boundary nodes + + Parameters + ------- + boundary: set[int] + The ids of the boundary nodes + + Examples + -------- + >>> from pymatching._cpp_mwpm import MatchingGraph + >>> graph = MatchingGraph() + >>> graph.set_boundary({2,3,4}) + >>> graph.get_boundary() + {2, 3, 4} + )") + .def("get_edges", &MatchingGraph::GetEdges, u8R"( + Get the edges and edge data in the MatchingGraph + + Returns + ------- + list[tuple[int, int, WeightedEdgeData]] + A list of edges. Each edges is a tuple (n1, n2, edge_data) where n1 and n2 + are the ids of the first and second nodes in the edge, respectively, and + edge_data is the WeightedEdgeData associated with the edge. + + Examples + -------- + >>> import math + >>> from pymatching._cpp_mwpm import MatchingGraph + >>> graph = MatchingGraph() + >>> graph.add_edge(0, 1, {0}, math.log((1-0.4)/0.4), 0.4, True) + >>> graph.add_edge(1, 2, {1}, math.log((1-0.3)/0.3), 0.3, True) + >>> graph.get_edges() + [(0, 1, pymatching._cpp_mwpm.WeightedEdgeData({0}, 0.405465, 0.4, 1)), (1, 2, pymatching._cpp_mwpm.WeightedEdgeData({1}, 0.847298, 0.3, 1))] + )") .def("get_nearest_neighbours", &MatchingGraph::GetNearestNeighbours, - "source"_a, "num_neighbours"_a, "defect_id"_a) - .def("get_path", &MatchingGraph::GetPath, "source"_a, "target"_a) - .def("distance", &MatchingGraph::Distance, "node1"_a, "node2"_a) - .def("shortest_path", &MatchingGraph::ShortestPath, "node1"_a, "node2"_a) - .def("qubit_ids", &MatchingGraph::QubitIDs, "node1"_a, "node2"_a) - .def("get_num_qubits", &MatchingGraph::GetNumQubits) - .def("get_num_nodes", &MatchingGraph::GetNumNodes) - .def("get_num_edges", &MatchingGraph::GetNumEdges) - .def("compute_all_pairs_shortest_paths", &MatchingGraph::ComputeAllPairsShortestPaths) - .def("has_computed_all_pairs_shortest_paths", &MatchingGraph::HasComputedAllPairsShortestPaths) - .def("get_num_connected_components", &MatchingGraph::GetNumConnectedComponents); + "source"_a, "num_neighbours"_a, "defect_id"_a, u8R"( + Find the nearest `num_neighbours` defects from a source node in the matching graph, + using a modified Dijkstra algorithm (local Dijkstra). + + Parameters + ---------- + source: int + The index of the source node + num_neighbours: int + The maximum number of defects to find in the matching graph (excluding the source node) + defect_id: list[int] + A list of length `MatchingGraph.get_num_nodes()` specifying the defect id of each node + in the matching graph. Node `i` satisfies `defect_id[i]>=0` if it is a defect and + `defect_id[i]==-1` otherwise. The value of `defect_id[i]` (if not -1) is the index + of the corresponding defect in the syndrome vector. + + Returns + ------- + list[tuple[int, float]] + A list of tuples of the form `(i, d)` where `i` is a node id of + a defect, and `d` is its distance from the source node (the sum of the + weights of the edges along the shortest path from the source to node `i`). + The list has `num_neighbours` elements/tuples, and the nodes in the list + are the `num_neighbours` closest defects from the source node. + + Examples + -------- + >>> from pymatching._cpp_mwpm import MatchingGraph + >>> graph = MatchingGraph() + >>> graph.add_edge(0, 1, {0}, 0.1) + >>> graph.add_edge(1, 2, {1}, 0.2) + >>> graph.add_edge(2, 3, {2}, 0.1) + >>> graph.add_edge(3, 4, {3}, 0.1) + >>> graph.add_edge(4, 5, {4}, 0.1) + >>> graph.get_nearest_neighbours(2, 0, [-1, 0, 1, 2, -1, 3]) + [] + >>> graph.get_nearest_neighbours(2, 1, [-1, 0, 1, 2, -1, 3]) + [(3, 0.1)] + >>> graph.get_nearest_neighbours(2, 2, [-1, 0, 1, 2, -1, 3]) + [(3, 0.1), (1, 0.2)] + >>> graph.get_nearest_neighbours(2, 3, [-1, 0, 1, 2, -1, 3]) + [(3, 0.1), (1, 0.2), (5, 0.30000000000000004)] + )") + .def("get_path", &MatchingGraph::GetPath, "source"_a, "target"_a, u8R"( + Find the nodes along the shortest path from a source to a target + node using Dijkstra's algorithm. + + Parameters + ---------- + source: int + The id of the source vertex + target: int + The id of the target vertex + + Returns + ------- + list[int] + A list of ids of the nodes along the shortest path from source to + target (including the source and target nodes). Elements `i` and `i+1` + of the list are nodes in the `i`th edge along the shortest path from + source to target. + + Examples + -------- + >>> from pymatching._cpp_mwpm import MatchingGraph + >>> graph = MatchingGraph() + >>> graph.add_edge(0, 1, {0}, 1) + >>> graph.add_edge(1, 2, {1}, 1) + >>> graph.add_edge(2, 3, {2}, 1) + >>> graph.add_edge(3, 4, {3}, 1) + >>> graph.add_edge(4, 5, {4}, 1) + >>> graph.get_path(2, 5) + [2, 3, 4, 5] + >>> graph.get_path(4, 1) + [4, 3, 2, 1] + >>> graph.get_path(5, 0) + [5, 4, 3, 2, 1, 0] + )") + .def("distance", &MatchingGraph::Distance, "node1"_a, "node2"_a, u8R"( + Get the distance between node1 and node2 using the precomputed all-pairs shortest paths + computed using Dijkstra. If the all-pairs-shortest-paths have not yet been computed, this + function will also compute these. + + Parameters + ---------- + node1: int + The id of the first node + node2: int + The id of the second node + + Returns + ------- + float + The sum of the weights of the edges along the shortest path from node1 to node2 + + Examples + -------- + >>> from pymatching._cpp_mwpm import MatchingGraph + >>> graph = MatchingGraph() + >>> graph.add_edge(0, 1, {0}, 0.12) + >>> graph.add_edge(1, 2, {1}, 0.11) + >>> graph.add_edge(2, 3, {2}, 0.3) + >>> graph.distance(0, 3) + 0.53 + >>> graph.distance(1, 2) + 0.11 + >>> graph.distance(3, 2) + 0.3 + )") + .def("shortest_path", &MatchingGraph::ShortestPath, "node1"_a, "node2"_a, u8R"( + Find the shortest path between node1 and node2, using the precomputed all-pairs shortest paths + computed using Dijkstra's algorithm. If the all-pairs-shortest-paths have not yet been computed, this + function will also compute these. + + Parameters + ---------- + node1: int + The id of the first node + node2: int + The id of the second node + + Returns + ------- + list[int] + A list of ids of the nodes along the shortest path from source to + target (including the source and target nodes). Elements `i` and `i+1` + of the list are nodes in the `i`th edge along the shortest path from + source to target. + + Examples + -------- + >>> from pymatching._cpp_mwpm import MatchingGraph + >>> graph = MatchingGraph() + >>> graph.add_edge(0, 1, {0}, 1) + >>> graph.add_edge(1, 2, {1}, 1) + >>> graph.add_edge(2, 3, {2}, 1) + >>> graph.add_edge(3, 4, {3}, 1) + >>> graph.shortest_path(0, 4) + [0, 1, 2, 3, 4] + >>> graph.shortest_path(3, 2) + [3, 2] + >>> graph.shortest_path(1, 2) + [1, 2] + )") + .def("qubit_ids", &MatchingGraph::QubitIDs, "node1"_a, "node2"_a, u8R"( + Returns the qubit_ids associated with the edge (node1, node2) + + Parameters + ---------- + node1: int + The id of the first node + node2: int + The id of the second node + + Returns + ------- + set[int] + The qubit_ids associated with the edge (node1, node2) + + Examples + -------- + >>> from pymatching._cpp_mwpm import MatchingGraph + >>> graph = MatchingGraph() + >>> graph.add_edge(0, 1, {0}, 1) + >>> graph.add_edge(1, 2, {1, 2, 3}, 1) + >>> graph.qubit_ids(0, 1) + {0} + >>> graph.qubit_ids(1, 2) + {1, 2, 3} + >>> graph.qubit_ids(1, 0) + {0} + )") + .def("get_num_qubits", &MatchingGraph::GetNumQubits, u8R"( + Returns the number of qubits associated with edges in the matching graph. + + Returns + ------- + int + The number of qubits in the matching graph + + Examples + -------- + >>> from pymatching._cpp_mwpm import MatchingGraph + >>> graph = MatchingGraph() + >>> graph.add_edge(0, 1, {0}, 1) + >>> graph.add_edge(1, 2, {1, 2, 3}, 1) + >>> graph.add_edge(2, 3, {4, 5}, 1) + >>> graph.get_num_qubits() + 6 + )") + .def("get_num_nodes", &MatchingGraph::GetNumNodes, u8R"( + Returns the number of nodes in the matching graph + + Returns + ------- + int + The number of nodes in the matching graph + + Examples + -------- + >>> from pymatching._cpp_mwpm import MatchingGraph + >>> graph = MatchingGraph() + >>> graph.add_edge(0, 1, {0}, 1) + >>> graph.add_edge(1, 2, {1}, 1) + >>> graph.add_edge(2, 3, {2}, 1) + >>> graph.get_num_nodes() + 4 + )") + .def("get_num_edges", &MatchingGraph::GetNumEdges, u8R"( + Get the number of edges in the matching graph + + Returns + ------- + int + The number of edges in the matching graph + + Examples + -------- + >>> from pymatching._cpp_mwpm import MatchingGraph + >>> graph = MatchingGraph() + >>> graph.add_edge(0, 1, {0}, 1) + >>> graph.add_edge(1, 2, {1}, 1) + >>> graph.add_edge(2, 3, {2}, 1) + >>> graph.get_num_edges() + 3 + )") + .def("compute_all_pairs_shortest_paths", &MatchingGraph::ComputeAllPairsShortestPaths, u8R"( + Computes the shortest paths between all pairs of nodes in the matching graph using Dijkstra's + algorithm. Note that this method is very memory intensive and is not used for local matching, + only for exact matching. + )") + .def("has_computed_all_pairs_shortest_paths", &MatchingGraph::HasComputedAllPairsShortestPaths, u8R"( + Returns whether or not the all-pairs shortest paths have already been computed. + + Examples + -------- + >>> from pymatching._cpp_mwpm import MatchingGraph + >>> graph = MatchingGraph() + >>> graph.add_edge(0, 1, {0}, 0.5) + >>> graph.add_edge(1, 2, {1}, 1) + >>> graph.has_computed_all_pairs_shortest_paths() + False + >>> graph.compute_all_pairs_shortest_paths() + >>> graph.has_computed_all_pairs_shortest_paths() + True + >>> graph.add_edge(2, 3, {2}, 0.8) + >>> graph.has_computed_all_pairs_shortest_paths() + False + >>> graph.compute_all_pairs_shortest_paths() + >>> graph.has_computed_all_pairs_shortest_paths() + True + )") + .def("get_num_connected_components", &MatchingGraph::GetNumConnectedComponents, u8R"( + Get the number of connected components in the matching graph + + Returns + ------- + int + The number of connected components in the matching graph + + Examples + -------- + >>> from pymatching._cpp_mwpm import MatchingGraph + >>> graph = MatchingGraph() + >>> graph.add_edge(0, 1, {0}, 1) + >>> graph.add_edge(1, 2, {1}, 1) + >>> graph.get_num_connected_components() + 1 + >>> graph.add_edge(3, 4, {2}, 1) + >>> graph.get_num_connected_components() + 2 + >>> graph.add_edge(5, 6, {3}, 1) + >>> graph.add_edge(6, 7, {4}, 1) + >>> graph.get_num_connected_components() + 3 + )") + .def("__repr__", &MatchingGraph::repr); m.def("randomize", &randomize, u8R"( Choose a random seed using std::random_device @@ -85,17 +539,125 @@ PYBIND11_MODULE(_cpp_mwpm, m) { to: float Largest float that can be drawn from the distribution + + Returns + ------- + float + The random float )"); py::class_(m, "MatchingResult") + .def(py::init<>()) + .def(py::init, double>(), "correction"_a, "weight"_a) .def_readwrite("correction", &MatchingResult::correction) - .def_readwrite("weight", &MatchingResult::weight); + .def_readwrite("weight", &MatchingResult::weight) + .def("__repr__", &MatchingResult::repr); py::register_exception(m, "BlossomFailureException", PyExc_RuntimeError); m.def("local_matching", &LocalMatching, - "sg"_a, "defects"_a, "num_neighbours"_a=30, - "return_weight"_a=false, "max_attempts"_a=10); - m.def("exact_matching", &LemonDecode, "sg"_a, "defects"_a, - "return_weight"_a=false); + "graph"_a, "defects"_a, "num_neighbours"_a=30, + "return_weight"_a=false, "max_attempts"_a=10, u8R"( + Decode using local matching. + + Parameters + ---------- + graph: MatchingGraph + The matching graph to be used to decode the syndrome + defects: np.ndarray[int] + A numpy array of integers giving the indices of the -1 measurements + (defects) in the syndrome. i.e. This is an array of IDs of nodes with + a non-trivial syndrome (detectors that have fired). + num_neighbours: int + Number of closest neighbours (with non-trivial syndrome) of each matching + graph node to consider when decoding. `num_neighbours` corresponds to + the parameter `m` in the local matching algorithm in the paper: + https://arxiv.org/abs/2105.13082 is used, and `num_neighbours` + It is recommended to set `num_neighbours` to at least 20 for + decoding performance to closely match that of exact matching. + return_weight: bool + If True, also return the weight of the matching found. By default False + max_attempts: int + The blossom algorithm can very occasionally fail to find a solution if a + perfect matching does not exist in the graph derived from nodes only with + non-trivial syndromes (called a syndrome graph in https://arxiv.org/abs/2105.13082), + since the syndrome graph is not a complete graph in local matching. + If this happens, `num_neighbours` is doubled `max_attempts` times until + a solution is found. It is highly unlikely that more than one attempt is + required, but if a solution is not found after `max_attempts` tries + (doubling `num_neighbours` each time), then a `pymatching.BlossomFailureException` + is raised. By default 10 + + Returns + ------- + MatchingResult + The recovery operator (and, optionally, weight of the matching). + `MatchingResult.correction[i]` is 1 if qubit_id 1 + should be flipped when applying the minimum weight correction, and 0 + otherwise. `MatchingResult.weight` gives the sum of the weights of the + edges included in the minimum-weight perfect matching correction if + `return_weight=True`, and is -1 otherwise. + + Examples + -------- + >>> from pymatching._cpp_mwpm import local_matching, MatchingGraph + >>> graph = MatchingGraph() + >>> graph.add_edge(0, 1, {0}, 1.0) + >>> graph.add_edge(1, 2, {1}, 1.0) + >>> graph.add_edge(2, 0, {2}, 1.0) + >>> res = local_matching(graph, [0, 1]) + >>> res.correction + array([1, 0, 0], dtype=uint8) + + By setting `return_weight=True`, the weight of the matching is also + returned: + >>> from pymatching._cpp_mwpm import local_matching, MatchingGraph + >>> graph = MatchingGraph() + >>> graph.add_edge(0, 1, {0}, 0.23) + >>> graph.add_edge(1, 2, {1}, 0.34) + >>> graph.add_edge(2, 3, {2}, 0.91) + >>> graph.add_edge(3, 0, {3}, 0.86) + >>> res = local_matching(graph, [0, 2], return_weight=True) + >>> res.correction + array([1, 1, 0, 0], dtype=uint8) + >>> res.weight + 0.5700000000000001 + )"); + m.def("exact_matching", &ExactMatching, "graph"_a, "defects"_a, + "return_weight"_a=false, u8R"( + Decode using exact matching + + Parameters + ---------- + graph: MatchingGraph + The matching graph to be used to decode the syndrome + defects: np.ndarray[int] + A numpy array of integers giving the indices of the -1 measurements + (defects) in the syndrome. i.e. This is an array of IDs of nodes with + a non-trivial syndrome (detectors that have fired). + return_weight: bool + If True, also return the weight of the matching found. By default False + + Returns + ------- + MatchingResult + The recovery operator (and, optionally, weight of the matching). + `MatchingResult.correction[i]` is 1 if qubit_id 1 + should be flipped when applying the minimum weight correction, and 0 + otherwise. `MatchingResult.weight` gives the sum of the weights of the + edges included in the minimum-weight perfect matching correction if + `return_weight=True`, and is -1 otherwise. + + Examples + -------- + >>> from pymatching._cpp_mwpm import exact_matching, MatchingGraph + >>> graph = MatchingGraph() + >>> graph.add_edge(0, 1, {0}, 0.8) + >>> graph.add_edge(1, 2, {1}, 0.5) + >>> graph.add_edge(2, 3, {2, 3}, 0.9) + >>> graph.add_edge(3, 0, {4}, 1.1) + >>> res = exact_matching(graph, [1, 3]) + >>> res.correction + array([0, 1, 1, 1, 0], dtype=uint8) + )"); } diff --git a/src/pymatching/lemon_mwpm.cpp b/src/pymatching/lemon_mwpm.cpp index 85f2aede..826a7a52 100644 --- a/src/pymatching/lemon_mwpm.cpp +++ b/src/pymatching/lemon_mwpm.cpp @@ -38,6 +38,36 @@ const char * BlossomFailureException::what() const throw() { "perfect matching problem."; } +MatchingResult::MatchingResult() {} + +MatchingResult::MatchingResult(py::array_t correction, double weight) + : correction(correction), weight(weight) {} + + +std::string arr_repr(py::array_t arr) { + std::stringstream ss; + ss << "array(["; + bool first = true; + for (auto i : arr) { + if (first){ + first = false; + } else { + ss << ", "; + } + ss << i; + } + ss << "], dtype=uint8)"; + return ss.str(); +} + +std::string MatchingResult::repr() const { + std::stringstream ss; + ss << "pymatching._cpp_mwpm.MatchingResult(correction="; + ss << arr_repr(correction); + ss << ", weight=" << weight << ")"; + return ss.str(); +} + class DefectGraph { public: @@ -63,16 +93,16 @@ void DefectGraph::AddEdge(int i, int j, double weight){ } -MatchingResult LemonDecode( - MatchingGraph& sg, +MatchingResult ExactMatching( + MatchingGraph& graph, const py::array_t& defects, bool return_weight ){ MatchingResult matching_result; - if (!sg.HasComputedAllPairsShortestPaths()){ - sg.ComputeAllPairsShortestPaths(); + if (!graph.HasComputedAllPairsShortestPaths()){ + graph.ComputeAllPairsShortestPaths(); } - int num_nodes = sg.GetNumNodes(); + int num_nodes = graph.GetNumNodes(); auto d = defects.unchecked<1>(); std::set defects_set; @@ -84,7 +114,7 @@ MatchingResult LemonDecode( } defects_set.insert(d(i)); } - sg.FlipBoundaryNodesIfNeeded(defects_set); + graph.FlipBoundaryNodesIfNeeded(defects_set); std::vector defects_vec(defects_set.begin(), defects_set.end()); @@ -94,7 +124,7 @@ MatchingResult LemonDecode( for (py::size_t i = 0; i(N, 0); std::set qids; for (py::size_t i = 0; i path = sg.ShortestPath( + std::vector path = graph.ShortestPath( defects_vec[i], defects_vec[j] ); for (std::vector::size_type k=0; k= 0) && (qid < N)){ (*correction)[qid] = ((*correction)[qid] + 1) % 2; @@ -141,7 +171,7 @@ MatchingResult LemonDecode( MatchingResult LocalMatching( - MatchingGraph& sg, + MatchingGraph& graph, const py::array_t& defects, int num_neighbours, bool return_weight, @@ -159,7 +189,7 @@ MatchingResult LocalMatching( while (true) { try{ return LemonDecodeMatchNeighbourhood( - sg, + graph, defects_set, num_neighbours, return_weight @@ -177,14 +207,14 @@ MatchingResult LocalMatching( MatchingResult LemonDecodeMatchNeighbourhood( - MatchingGraph& sg, + MatchingGraph& graph, std::set& defects_set, int num_neighbours, bool return_weight ){ MatchingResult matching_result; - int num_nodes = sg.GetNumNodes(); + int num_nodes = graph.GetNumNodes(); for (auto d : defects_set){ if (d >= num_nodes){ @@ -194,7 +224,7 @@ MatchingResult LemonDecodeMatchNeighbourhood( } } - sg.FlipBoundaryNodesIfNeeded(defects_set); + graph.FlipBoundaryNodesIfNeeded(defects_set); std::vector defects_vec(defects_set.begin(), defects_set.end()); int num_defects = defects_vec.size(); @@ -202,7 +232,7 @@ MatchingResult LemonDecodeMatchNeighbourhood( for (int i=0; i(N, 0); std::set remaining_defects; @@ -246,9 +276,9 @@ MatchingResult LemonDecodeMatchNeighbourhood( remaining_defects.erase(remaining_defects.begin()); j = defect_graph.g.id(pm.mate(defect_graph.g.nodeFromId(i))); remaining_defects.erase(j); - path = sg.GetPath(defects_vec[i], defects_vec[j]); + path = graph.GetPath(defects_vec[i], defects_vec[j]); for (std::vector::size_type k=0; k= 0) && (qid < N)){ (*correction)[qid] = ((*correction)[qid] + 1) % 2; diff --git a/src/pymatching/lemon_mwpm.h b/src/pymatching/lemon_mwpm.h index d45e5600..d0d52bca 100644 --- a/src/pymatching/lemon_mwpm.h +++ b/src/pymatching/lemon_mwpm.h @@ -16,6 +16,8 @@ #include #include #include +#include +#include struct BlossomFailureException : public std::exception { @@ -30,6 +32,8 @@ struct BlossomFailureException : public std::exception { * */ struct MatchingResult { + MatchingResult(); + MatchingResult(py::array_t correction, double weight); /** * @brief The correction operator corresponding to the minimum-weight perfect matching. * correction[i] is 1 if the ith qubit is flipped and correction[i] is 0 otherwise. @@ -43,43 +47,44 @@ struct MatchingResult { * */ double weight; + std::string repr() const; }; /** - * @brief Given a matching graph sg and a vector `defects` of indices of nodes that have a -1 syndrome, + * @brief Given a matching graph `graph` and a vector `defects` of indices of nodes that have a -1 syndrome, * find the find the minimum weight perfect matching in the complete graph with nodes in the defects - * list, and where the edge between node i and j is given by the distance between i and j in sg. The - * distances and shortest paths between nodes in the matching graph sg are all precomputed and this + * list, and where the edge between node i and j is given by the distance between i and j in `graph`. The + * distances and shortest paths between nodes in the matching graph `graph` are all precomputed and this * method returns the exact minimum-weight perfect matching. As a result it is suitable for matching graphs * with a few thousand nodes or less, but will be very memory and compute intensive for larger matching graphs. * Returns a noise vector N for which N[i]=1 if qubit_id appeared an odd number of times in the minimum weight * perfect matching and N[i]=0 otherwise. * - * @param sg A matching graph + * @param graph A matching graph * @param defects The indices of nodes that are associated with a -1 syndrome * @return MatchingResult A struct containing the correction vector for the minimum-weight perfect matching and the matching weight. * The matching weight is set to -1 if it is not requested. */ -MatchingResult LemonDecode(MatchingGraph& sg, const py::array_t& defects, bool return_weight=false); +MatchingResult ExactMatching(MatchingGraph& graph, const py::array_t& defects, bool return_weight=false); /** - * @brief Given a matching graph `sg`, a vector `defects` of indices of nodes that have a -1 syndrome and + * @brief Given a matching graph `graph`, a vector `defects` of indices of nodes that have a -1 syndrome and * a chosen `num_neighbours`, find the minimum weight perfect matching in the graph V where each defect node - * is connected by an edge to each of the `num_neighbours` nearest other defect nodes in sg, and where the - * weight of each edge is the distance between the two defect nodes in `sg`. + * is connected by an edge to each of the `num_neighbours` nearest other defect nodes in graph, and where the + * weight of each edge is the distance between the two defect nodes in `graph`. * Returns a noise vector N for which N[i]=1 if qubit_id appeared an odd number of times in the minimum weight * perfect matching and N[i]=0 otherwise. * - * @param sg A matching graph + * @param graph A matching graph * @param defects The indices of nodes that are associated with a -1 syndrome * @param num_neighbours The number of closest defects to connect each defect to in the matching graph * @return MatchingResult A struct containing the correction vector for the minimum-weight perfect matching and the matching weight. * The matching weight is set to -1 if it is not requested. */ -MatchingResult LemonDecodeMatchNeighbourhood(MatchingGraph& sg, std::set& defects, +MatchingResult LemonDecodeMatchNeighbourhood(MatchingGraph& graph, std::set& defects, int num_neighbours=30, bool return_weight=false); MatchingResult LocalMatching( - MatchingGraph& sg, + MatchingGraph& graph, const py::array_t& defects, int num_neighbours=30, bool return_weight=false, diff --git a/src/pymatching/matching.py b/src/pymatching/matching.py index 7420a082..297cfae5 100644 --- a/src/pymatching/matching.py +++ b/src/pymatching/matching.py @@ -385,7 +385,7 @@ def load_from_check_matrix(self, v1 = H.indices[s] + H.shape[0] * t v2 = H.indices[e - 1] + H.shape[0] * t if e - s == 2 else next(iter(boundary)) self.matching_graph.add_edge(v1, v2, {i}, weights[i], - error_probabilities[i], error_probabilities[i] >= 0) + error_probabilities[i], error_probabilities[i] >= 0) for t in range(repetitions - 1): for i in range(H.shape[0]): self.matching_graph.add_edge(i + t * H.shape[0], i + (t + 1) * H.shape[0], @@ -479,7 +479,7 @@ def num_detectors(self) -> int: return self.num_nodes - len(self.boundary) def decode(self, - z: np.ndarray, + z: Union[np.ndarray, List[int]], num_neighbours: int=30, return_weight: bool=False ) -> Union[np.ndarray, Tuple[np.ndarray, int]]: @@ -499,10 +499,10 @@ def decode(self, (modulo 2) between the (noisy) measurement of stabiliser `i` in time step `j+1` and time step `j` (for the case where `repetitions>1`). num_neighbours : int, optional - Number of closest neighbours of each matching graph node to consider - when decoding. If `num_neighbours` is set (as it is by default), - then the local matching decoder in the Appendix of - https://arxiv.org/abs/2010.09626 is used, and `num_neighbours` + Number of closest neighbours (with non-trivial syndrome) of each matching + graph node to consider when decoding. If `num_neighbours` is set + (as it is by default), then the local matching decoder in + https://arxiv.org/abs/2105.13082 is used, and `num_neighbours` corresponds to the parameter `m` in the paper. It is recommended to leave `num_neighbours` set to at least 20. If `num_neighbours is None`, then instead full matching is @@ -518,7 +518,7 @@ def decode(self, Returns ------- - numpy.ndarray + numpy.ndarray or list[int] A 1D numpy array of ints giving the minimum-weight correction operator. The number of elements equals the number of qubits, and an element is 1 if the corresponding qubit should be flipped, @@ -529,6 +529,59 @@ def decode(self, The sum of the weights of the edges in the minimum-weight perfect matching. + Examples + -------- + >>> import pymatching + >>> import numpy as np + >>> H = np.array([[1, 1, 0, 0], + ... [0, 1, 1, 0], + ... [0, 0, 1, 1]]) + >>> m = pymatching.Matching(H) + >>> z = np.array([0, 1, 0]) + >>> m.decode(z) + array([1, 1, 0, 0], dtype=uint8) + + Each bit in the correction provided by Matching.decode corresponds to a + qubit_id. The index of a bit in a correction corresponds to its qubit_id. + For example, here an error on edge (0, 1) flips qubit_id 2 and 3, as + inferred by the minimum-weight correction: + >>> import pymatching + >>> m = pymatching.Matching() + >>> m.add_edge(0, 1, qubit_id={2, 3}) + >>> m.add_edge(1, 2, qubit_id=1) + >>> m.add_edge(2, 0, qubit_id=0) + >>> m.decode([1, 1, 0]) + array([0, 0, 1, 1], dtype=uint8) + + To decode with a phenomenological noise model (qubits and measurements both suffering + bit-flip errors), you can provide a check matrix and number of syndrome repetitions to + construct a matching graph with a time dimension (where nodes in consecutive time steps + are connected by an edge), and then decode with a 2D syndrome + (dimension 0 is space/qubits, dimension 1 is time): + >>> import pymatching + >>> import numpy as np + >>> np.random.seed(0) + >>> H = np.array([[1, 1, 0, 0], + ... [0, 1, 1, 0], + ... [0, 0, 1, 1]]) + >>> m = pymatching.Matching(H, repetitions=5) + >>> data_qubit_noise = (np.random.rand(4, 5) < 0.1).astype(np.uint8) + >>> print(data_qubit_noise) + [[0 0 0 0 0] + [0 0 0 0 0] + [0 0 0 0 1] + [1 1 0 0 0]] + >>> cumulative_noise = (np.cumsum(data_qubit_noise, 1) % 2).astype(np.uint8) + >>> syndrome = H@cumulative_noise % 2 + >>> print(syndrome) + [[0 0 0 0 0] + [0 0 0 0 1] + [1 0 0 0 1]] + >>> syndrome[:,:-1] ^= (np.random.rand(3, 4) < 0.1).astype(np.uint8) + >>> # Take the parity of consecutive timesteps to construct a difference syndrome: + >>> syndrome[:,1:] = syndrome[:,:-1] ^ syndrome[:,1:] + >>> m.decode(syndrome) + array([0, 0, 1, 0], dtype=uint8) """ try: z = np.array(z, dtype=np.uint8)