@@ -11,7 +11,7 @@ class TestTorchWLKernel:
11
11
device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
12
12
13
13
@pytest .fixture
14
- def example_graphs (self ):
14
+ def example_graphs_set (self ):
15
15
# Create example graphs for testing
16
16
G1 = nx .Graph ()
17
17
G1 .add_edges_from ([(0 , 1 ), (1 , 2 ), (1 , 3 ), (2 , 3 ), (3 , 4 )])
@@ -30,35 +30,67 @@ def example_graphs(self):
30
30
31
31
return [G1 , G2 , G3 ]
32
32
33
- @pytest .mark .parametrize ("n_iter" , [1 , 2 , 3 , 5 , 10 ])
34
- @pytest .mark .parametrize ("normalize" , [True , False ])
35
- def test_wl_kernel_against_grakel (self , n_iter , normalize , example_graphs ):
36
- adjacency_matrices , label_tensors = graphs_to_tensors (
37
- example_graphs , device = self .device )
33
+ @pytest .fixture
34
+ def random_graphs_sets (self ):
35
+ # Set a seed for reproducibility
36
+ seed = 100
37
+ np .random .seed (seed )
38
+ torch .manual_seed (seed )
39
+ random_graph_sets = []
38
40
39
- # Initialize Torch WL Kernel
40
- torch_kernel = TorchWLKernel (n_iter = n_iter , normalize = normalize )
41
- torch_kernel_matrix = torch_kernel (adjacency_matrices ,
42
- label_tensors ).cpu ().numpy ()
41
+ # Generate 10 random sets of graphs
42
+ for _ in range (10 ):
43
+ # Number of graphs in the set (2 to 10)
44
+ num_graphs = np .random .randint (2 , 11 )
45
+ graph_set = []
43
46
44
- # Initialize GraKel WL Kernel
45
- grakel_graphs = list (
46
- graph_from_networkx (example_graphs , node_labels_tag = "label" , as_Graph = True ))
47
- grakel_kernel = WeisfeilerLehman (n_iter = n_iter , normalize = normalize )
48
- grakel_kernel_matrix = grakel_kernel .fit_transform (grakel_graphs )
47
+ for _ in range (num_graphs ):
48
+ # Number of nodes in the graph (3 to 50)
49
+ num_nodes = np .random .randint (3 , 51 )
50
+ G = nx .Graph ()
49
51
50
- # Define tolerances based on normalization
51
- rtol = 1e-5 if normalize else 1e-4
52
- atol = 1e-8 if normalize else 1e-7
52
+ # Add nodes with labels
53
+ for node in range ( num_nodes ):
54
+ G . add_node ( node , label = str ( node ))
53
55
54
- # Compare the kernel matrices
55
- np .testing .assert_allclose (
56
- torch_kernel_matrix ,
57
- grakel_kernel_matrix ,
58
- rtol = rtol ,
59
- atol = atol ,
60
- err_msg = f"Kernel matrices differ for n_iter={ n_iter } , normalize={ normalize } "
61
- )
56
+ # Add random edges
57
+ for u in range (num_nodes ):
58
+ for v in range (u + 1 , num_nodes ):
59
+ if np .random .rand () > 0.5 : # 50% chance to add an edge
60
+ G .add_edge (u , v )
61
+
62
+ graph_set .append (G )
63
+
64
+ random_graph_sets .append (graph_set )
65
+
66
+ return random_graph_sets
67
+
68
+ @pytest .mark .parametrize ("n_iter" , [1 , 2 , 3 , 5 , 10 ])
69
+ @pytest .mark .parametrize ("normalize" , [True , False ])
70
+ def test_wl_kernel_against_grakel (self , n_iter , normalize , random_graphs_sets ):
71
+ for graph_set in random_graphs_sets :
72
+ adjacency_matrices , label_tensors = graphs_to_tensors (
73
+ graph_set , device = self .device )
74
+
75
+ # Initialize Torch WL Kernel
76
+ torch_kernel = TorchWLKernel (n_iter = n_iter , normalize = normalize )
77
+ torch_kernel_matrix = torch_kernel (adjacency_matrices ,
78
+ label_tensors ).cpu ().numpy ()
79
+
80
+ # Initialize GraKel WL Kernel
81
+ grakel_graphs = list (
82
+ graph_from_networkx (graph_set , node_labels_tag = "label" , as_Graph = True ))
83
+ grakel_kernel = WeisfeilerLehman (n_iter = n_iter , normalize = normalize )
84
+ grakel_kernel_matrix = grakel_kernel .fit_transform (grakel_graphs )
85
+
86
+ # Compare the kernel matrices
87
+ np .testing .assert_allclose (
88
+ torch_kernel_matrix ,
89
+ grakel_kernel_matrix ,
90
+ rtol = 1e-5 ,
91
+ atol = 1e-8 ,
92
+ err_msg = f"Kernel matrices differ for graph={ graph_set } , n_iter={ n_iter } "
93
+ )
62
94
63
95
def test_empty_graph (self ):
64
96
G_empty = nx .Graph ()
@@ -97,36 +129,47 @@ def test_kernel_on_single_node_graph(self):
97
129
expected = torch .ones (1 , 1 , device = self .device )
98
130
torch .testing .assert_close (K , expected )
99
131
100
- def test_wl_kernel_with_empty_graph_and_reordered_edges (self , example_graphs ):
132
+ def test_wl_kernel_with_empty_graph_and_reordered_edges (self , random_graphs_sets ):
101
133
"""Test the TorchWLKernel with an empty graph and a graph with reordered edges."""
102
- # Create example graphs for testing
103
- G_empty = nx .Graph ()
104
- G_empty .add_node (0 )
105
- G_empty .nodes [0 ]["label" ] = "0"
106
-
107
- G = example_graphs [0 ]
108
- G_reordered = nx .Graph ()
109
- G_reordered .add_edges_from ([(1 , 4 ), (2 , 3 ), (1 , 2 ), (0 , 1 ), (1 , 3 )])
110
- for node in G_reordered .nodes ():
111
- G_reordered .nodes [node ]["label" ] = str (node )
112
-
113
- graphs = [G_empty , G , G_reordered ]
114
- adjacency_matrices , label_tensors = graphs_to_tensors (graphs ,
115
- device = self .device )
116
-
117
- wl_kernel = TorchWLKernel (n_iter = 3 , normalize = True )
118
- K = wl_kernel (adjacency_matrices , label_tensors )
119
-
120
- assert K .shape == (3 , 3 ), "Kernel matrix shape is incorrect"
121
- assert K [1 , 1 ] == K [
122
- 2 , 2 ], "Kernel value for original and reordered graphs should be the same"
134
+ for graph_set in random_graphs_sets :
135
+ # Create an empty graph
136
+ G_empty = nx .Graph ()
137
+ G_empty .add_node (0 )
138
+ G_empty .nodes [0 ]["label" ] = "0"
139
+
140
+ # Select the first graph from the set to reorder its edges
141
+ G = graph_set [0 ]
142
+ G_reordered = nx .Graph ()
143
+
144
+ # Add all nodes from the original graph to G_reordered
145
+ for node in G .nodes ():
146
+ G_reordered .add_node (node , label = G .nodes [node ]["label" ])
147
+
148
+ # Reorder edges randomly
149
+ edges = list (G .edges ())
150
+ np .random .shuffle (edges ) # Randomly shuffle the edges
151
+ G_reordered .add_edges_from (edges )
152
+
153
+ # Combine the empty graph, original graph, and reordered graph
154
+ graphs = [G_empty , G , G_reordered ]
155
+ adjacency_matrices , label_tensors = graphs_to_tensors (
156
+ graphs , device = self .device
157
+ )
158
+
159
+ # Initialize and compute the kernel
160
+ wl_kernel = TorchWLKernel (n_iter = 3 , normalize = True )
161
+ K = wl_kernel (adjacency_matrices , label_tensors )
162
+
163
+ assert K .shape == (3 , 3 ), "Kernel matrix shape is incorrect"
164
+ assert torch .allclose (K [1 , 1 ], K [2 , 2 ]), \
165
+ "Kernel value for original and reordered graphs should be the same"
123
166
124
167
@pytest .mark .parametrize ("n_iter" , [1 , 2 , 3 , 4 , 5 , 6 , 7 ])
125
168
@pytest .mark .parametrize ("normalize" , [True , False ])
126
169
def test_wl_kernel_with_different_node_labels (self , n_iter , normalize ,
127
- example_graphs ):
170
+ example_graphs_set ):
128
171
graphs = []
129
- for i , G in enumerate (example_graphs ):
172
+ for i , G in enumerate (example_graphs_set ):
130
173
G_copy = G .copy ()
131
174
prefix = ["node_" , "vertex_" , "n" ][i ]
132
175
for node in G_copy .nodes ():
@@ -143,20 +186,15 @@ def test_wl_kernel_with_different_node_labels(self, n_iter, normalize,
143
186
grakel_wl = WeisfeilerLehman (n_iter = n_iter , normalize = normalize )
144
187
grakel_kernel_matrix = grakel_wl .fit_transform (grakel_graphs )
145
188
146
- # Define tolerances based on normalization, matching the main test
147
- rtol = 1e-5 if normalize else 1e-4
148
- atol = 1e-8 if normalize else 1e-7
149
-
150
- # Updated assertion with both rtol and atol
151
189
np .testing .assert_allclose (
152
190
torch_kernel_matrix ,
153
191
grakel_kernel_matrix ,
154
- rtol = rtol ,
155
- atol = atol ,
192
+ rtol = 1e-5 ,
193
+ atol = 1e-8 ,
156
194
err_msg = f"Kernel matrices differ for n_iter={ n_iter } , normalize={ normalize } "
157
195
)
158
196
159
- def test_wl_kernel_with_same_node_labels (self , example_graphs ):
197
+ def test_wl_kernel_with_same_node_labels (self , example_graphs_set ):
160
198
"""Test WL kernel behavior with same node labels but different structures.
161
199
162
200
Even when all nodes have the same label, the WL kernel should:
@@ -166,7 +204,7 @@ def test_wl_kernel_with_same_node_labels(self, example_graphs):
166
204
4. Maintain non-negative values (it's a valid kernel)
167
205
"""
168
206
graphs = []
169
- for G in example_graphs :
207
+ for G in example_graphs_set :
170
208
G_copy = G .copy ()
171
209
for node in G_copy .nodes ():
172
210
G_copy .nodes [node ]["label" ] = "A"
0 commit comments