Skip to content

Commit 4e8bdad

Browse files
Improve tests
1 parent 4cc0b29 commit 4e8bdad

File tree

2 files changed

+105
-65
lines changed

2 files changed

+105
-65
lines changed

tests/test_optimization_over_graphs.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from grakel_replace.utils import min_max_scale
2020

2121

22-
class TestPipeline:
22+
class TestGraphOptimizationPipeline:
2323
@pytest.fixture
2424
def setup_data(self):
2525
"""Fixture to set up common data for tests."""
@@ -79,11 +79,13 @@ def test_gp_fit_and_predict(self, setup_data):
7979
kernels = [
8080
ScaleKernel(MaternKernel(nu=2.5, ard_num_dims=setup_data["N_NUMERICAL"],
8181
active_dims=range(setup_data["N_NUMERICAL"]))),
82-
ScaleKernel(CategoricalKernel(ard_num_dims=setup_data["N_CATEGORICAL"],
83-
active_dims=range(setup_data["N_NUMERICAL"],
84-
setup_data["N_NUMERICAL"] +
85-
setup_data[
86-
"N_CATEGORICAL"]))),
82+
ScaleKernel(
83+
CategoricalKernel(ard_num_dims=setup_data["N_CATEGORICAL"],
84+
active_dims=range(setup_data["N_NUMERICAL"],
85+
setup_data["N_NUMERICAL"] +
86+
setup_data["N_CATEGORICAL"])
87+
)
88+
),
8789
ScaleKernel(
8890
BoTorchWLKernel(graph_lookup=train_graphs, n_iter=5, normalize=True,
8991
active_dims=(train_x.shape[1] - 1,)))

tests/test_torch_wl_kernel.py

Lines changed: 97 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class TestTorchWLKernel:
1111
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1212

1313
@pytest.fixture
14-
def example_graphs(self):
14+
def example_graphs_set(self):
1515
# Create example graphs for testing
1616
G1 = nx.Graph()
1717
G1.add_edges_from([(0, 1), (1, 2), (1, 3), (2, 3), (3, 4)])
@@ -30,35 +30,67 @@ def example_graphs(self):
3030

3131
return [G1, G2, G3]
3232

33-
@pytest.mark.parametrize("n_iter", [1, 2, 3, 5, 10])
34-
@pytest.mark.parametrize("normalize", [True, False])
35-
def test_wl_kernel_against_grakel(self, n_iter, normalize, example_graphs):
36-
adjacency_matrices, label_tensors = graphs_to_tensors(
37-
example_graphs, device=self.device)
33+
@pytest.fixture
34+
def random_graphs_sets(self):
35+
# Set a seed for reproducibility
36+
seed = 100
37+
np.random.seed(seed)
38+
torch.manual_seed(seed)
39+
random_graph_sets = []
3840

39-
# Initialize Torch WL Kernel
40-
torch_kernel = TorchWLKernel(n_iter=n_iter, normalize=normalize)
41-
torch_kernel_matrix = torch_kernel(adjacency_matrices,
42-
label_tensors).cpu().numpy()
41+
# Generate 10 random sets of graphs
42+
for _ in range(10):
43+
# Number of graphs in the set (2 to 10)
44+
num_graphs = np.random.randint(2, 11)
45+
graph_set = []
4346

44-
# Initialize GraKel WL Kernel
45-
grakel_graphs = list(
46-
graph_from_networkx(example_graphs, node_labels_tag="label", as_Graph=True))
47-
grakel_kernel = WeisfeilerLehman(n_iter=n_iter, normalize=normalize)
48-
grakel_kernel_matrix = grakel_kernel.fit_transform(grakel_graphs)
47+
for _ in range(num_graphs):
48+
# Number of nodes in the graph (3 to 50)
49+
num_nodes = np.random.randint(3, 51)
50+
G = nx.Graph()
4951

50-
# Define tolerances based on normalization
51-
rtol = 1e-5 if normalize else 1e-4
52-
atol = 1e-8 if normalize else 1e-7
52+
# Add nodes with labels
53+
for node in range(num_nodes):
54+
G.add_node(node, label=str(node))
5355

54-
# Compare the kernel matrices
55-
np.testing.assert_allclose(
56-
torch_kernel_matrix,
57-
grakel_kernel_matrix,
58-
rtol=rtol,
59-
atol=atol,
60-
err_msg=f"Kernel matrices differ for n_iter={n_iter}, normalize={normalize}"
61-
)
56+
# Add random edges
57+
for u in range(num_nodes):
58+
for v in range(u + 1, num_nodes):
59+
if np.random.rand() > 0.5: # 50% chance to add an edge
60+
G.add_edge(u, v)
61+
62+
graph_set.append(G)
63+
64+
random_graph_sets.append(graph_set)
65+
66+
return random_graph_sets
67+
68+
@pytest.mark.parametrize("n_iter", [1, 2, 3, 5, 10])
69+
@pytest.mark.parametrize("normalize", [True, False])
70+
def test_wl_kernel_against_grakel(self, n_iter, normalize, random_graphs_sets):
71+
for graph_set in random_graphs_sets:
72+
adjacency_matrices, label_tensors = graphs_to_tensors(
73+
graph_set, device=self.device)
74+
75+
# Initialize Torch WL Kernel
76+
torch_kernel = TorchWLKernel(n_iter=n_iter, normalize=normalize)
77+
torch_kernel_matrix = torch_kernel(adjacency_matrices,
78+
label_tensors).cpu().numpy()
79+
80+
# Initialize GraKel WL Kernel
81+
grakel_graphs = list(
82+
graph_from_networkx(graph_set, node_labels_tag="label", as_Graph=True))
83+
grakel_kernel = WeisfeilerLehman(n_iter=n_iter, normalize=normalize)
84+
grakel_kernel_matrix = grakel_kernel.fit_transform(grakel_graphs)
85+
86+
# Compare the kernel matrices
87+
np.testing.assert_allclose(
88+
torch_kernel_matrix,
89+
grakel_kernel_matrix,
90+
rtol=1e-5,
91+
atol=1e-8,
92+
err_msg=f"Kernel matrices differ for graph={graph_set}, n_iter={n_iter}"
93+
)
6294

6395
def test_empty_graph(self):
6496
G_empty = nx.Graph()
@@ -97,36 +129,47 @@ def test_kernel_on_single_node_graph(self):
97129
expected = torch.ones(1, 1, device=self.device)
98130
torch.testing.assert_close(K, expected)
99131

100-
def test_wl_kernel_with_empty_graph_and_reordered_edges(self, example_graphs):
132+
def test_wl_kernel_with_empty_graph_and_reordered_edges(self, random_graphs_sets):
101133
"""Test the TorchWLKernel with an empty graph and a graph with reordered edges."""
102-
# Create example graphs for testing
103-
G_empty = nx.Graph()
104-
G_empty.add_node(0)
105-
G_empty.nodes[0]["label"] = "0"
106-
107-
G = example_graphs[0]
108-
G_reordered = nx.Graph()
109-
G_reordered.add_edges_from([(1, 4), (2, 3), (1, 2), (0, 1), (1, 3)])
110-
for node in G_reordered.nodes():
111-
G_reordered.nodes[node]["label"] = str(node)
112-
113-
graphs = [G_empty, G, G_reordered]
114-
adjacency_matrices, label_tensors = graphs_to_tensors(graphs,
115-
device=self.device)
116-
117-
wl_kernel = TorchWLKernel(n_iter=3, normalize=True)
118-
K = wl_kernel(adjacency_matrices, label_tensors)
119-
120-
assert K.shape == (3, 3), "Kernel matrix shape is incorrect"
121-
assert K[1, 1] == K[
122-
2, 2], "Kernel value for original and reordered graphs should be the same"
134+
for graph_set in random_graphs_sets:
135+
# Create an empty graph
136+
G_empty = nx.Graph()
137+
G_empty.add_node(0)
138+
G_empty.nodes[0]["label"] = "0"
139+
140+
# Select the first graph from the set to reorder its edges
141+
G = graph_set[0]
142+
G_reordered = nx.Graph()
143+
144+
# Add all nodes from the original graph to G_reordered
145+
for node in G.nodes():
146+
G_reordered.add_node(node, label=G.nodes[node]["label"])
147+
148+
# Reorder edges randomly
149+
edges = list(G.edges())
150+
np.random.shuffle(edges) # Randomly shuffle the edges
151+
G_reordered.add_edges_from(edges)
152+
153+
# Combine the empty graph, original graph, and reordered graph
154+
graphs = [G_empty, G, G_reordered]
155+
adjacency_matrices, label_tensors = graphs_to_tensors(
156+
graphs, device=self.device
157+
)
158+
159+
# Initialize and compute the kernel
160+
wl_kernel = TorchWLKernel(n_iter=3, normalize=True)
161+
K = wl_kernel(adjacency_matrices, label_tensors)
162+
163+
assert K.shape == (3, 3), "Kernel matrix shape is incorrect"
164+
assert torch.allclose(K[1, 1], K[2, 2]), \
165+
"Kernel value for original and reordered graphs should be the same"
123166

124167
@pytest.mark.parametrize("n_iter", [1, 2, 3, 4, 5, 6, 7])
125168
@pytest.mark.parametrize("normalize", [True, False])
126169
def test_wl_kernel_with_different_node_labels(self, n_iter, normalize,
127-
example_graphs):
170+
example_graphs_set):
128171
graphs = []
129-
for i, G in enumerate(example_graphs):
172+
for i, G in enumerate(example_graphs_set):
130173
G_copy = G.copy()
131174
prefix = ["node_", "vertex_", "n"][i]
132175
for node in G_copy.nodes():
@@ -143,20 +186,15 @@ def test_wl_kernel_with_different_node_labels(self, n_iter, normalize,
143186
grakel_wl = WeisfeilerLehman(n_iter=n_iter, normalize=normalize)
144187
grakel_kernel_matrix = grakel_wl.fit_transform(grakel_graphs)
145188

146-
# Define tolerances based on normalization, matching the main test
147-
rtol = 1e-5 if normalize else 1e-4
148-
atol = 1e-8 if normalize else 1e-7
149-
150-
# Updated assertion with both rtol and atol
151189
np.testing.assert_allclose(
152190
torch_kernel_matrix,
153191
grakel_kernel_matrix,
154-
rtol=rtol,
155-
atol=atol,
192+
rtol=1e-5,
193+
atol=1e-8,
156194
err_msg=f"Kernel matrices differ for n_iter={n_iter}, normalize={normalize}"
157195
)
158196

159-
def test_wl_kernel_with_same_node_labels(self, example_graphs):
197+
def test_wl_kernel_with_same_node_labels(self, example_graphs_set):
160198
"""Test WL kernel behavior with same node labels but different structures.
161199
162200
Even when all nodes have the same label, the WL kernel should:
@@ -166,7 +204,7 @@ def test_wl_kernel_with_same_node_labels(self, example_graphs):
166204
4. Maintain non-negative values (it's a valid kernel)
167205
"""
168206
graphs = []
169-
for G in example_graphs:
207+
for G in example_graphs_set:
170208
G_copy = G.copy()
171209
for node in G_copy.nodes():
172210
G_copy.nodes[node]["label"] = "A"

0 commit comments

Comments
 (0)