Skip to content

Commit

Permalink
Allow negative weights from networkx input
Browse files Browse the repository at this point in the history
  • Loading branch information
oscarhiggott committed Dec 24, 2021
1 parent cd9af28 commit 41d8d43
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 16 deletions.
2 changes: 0 additions & 2 deletions src/pymatching/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,6 @@ def load_from_networkx(self, graph: nx.Graph) -> None:
" (or convertible to a set), not {}".format(fault_ids))
all_fault_ids = all_fault_ids | fault_ids
weight = attr.get("weight", 1) # Default weight is 1 if not provided
if weight < 0:
raise ValueError("Weights cannot be negative.")
e_prob = attr.get("error_probability", -1)
g.add_edge(u, v, fault_ids, weight, e_prob, 0 <= e_prob <= 1)
self.matching_graph = g
Expand Down
9 changes: 0 additions & 9 deletions tests/test_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,6 @@ def test_too_many_checks_per_qubit_raises_value_error():
Matching(H)


def test_negative_weight_raises_value_error():
g = nx.Graph()
g.add_edge(0,1,weight=-1)
with pytest.raises(ValueError):
Matching(g)
with pytest.raises(ValueError):
Matching(csr_matrix([[1,1,0],[0,1,1]]), spacelike_weights=np.array([1,1,-1]))


def test_wrong_check_matrix_type_raises_type_error():
with pytest.raises(TypeError):
Matching("test")
Expand Down
12 changes: 7 additions & 5 deletions tests/test_negative_weghts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import numpy as np
import networkx as nx

from pymatching import Matching

Expand Down Expand Up @@ -32,11 +33,12 @@ def test_isolated_negative_weight(nn):

@pytest.mark.parametrize("nn", (None, 30))
def test_negative_and_positive_in_matching(nn):
m = Matching()
m.add_edge(0, 1, 0, 1)
m.add_edge(1, 2, 1, -10)
m.add_edge(2, 3, 2, 1)
m.add_edge(3, 0, 3, 1)
g = nx.Graph()
g.add_edge(0, 1, fault_ids=0, weight=1)
g.add_edge(1, 2, fault_ids=1, weight=-10)
g.add_edge(2, 3, fault_ids=2, weight=1)
g.add_edge(3, 0, fault_ids=3, weight=1)
m = Matching(g)
c, w = m.decode([0, 1, 0, 1], return_weight=True, num_neighbours=nn)
assert np.array_equal(c, np.array([0, 1, 1, 0]))
assert w == -9
Expand Down

0 comments on commit 41d8d43

Please sign in to comment.