@@ -11,7 +11,7 @@ class TestTorchWLKernel:
1111 device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
1212
1313 @pytest .fixture
14- def example_graphs (self ):
14+ def example_graphs_set (self ):
1515 # Create example graphs for testing
1616 G1 = nx .Graph ()
1717 G1 .add_edges_from ([(0 , 1 ), (1 , 2 ), (1 , 3 ), (2 , 3 ), (3 , 4 )])
@@ -30,35 +30,67 @@ def example_graphs(self):
3030
3131 return [G1 , G2 , G3 ]
3232
33- @pytest .mark .parametrize ("n_iter" , [1 , 2 , 3 , 5 , 10 ])
34- @pytest .mark .parametrize ("normalize" , [True , False ])
35- def test_wl_kernel_against_grakel (self , n_iter , normalize , example_graphs ):
36- adjacency_matrices , label_tensors = graphs_to_tensors (
37- example_graphs , device = self .device )
33+ @pytest .fixture
34+ def random_graphs_sets (self ):
35+ # Set a seed for reproducibility
36+ seed = 100
37+ np .random .seed (seed )
38+ torch .manual_seed (seed )
39+ random_graph_sets = []
3840
39- # Initialize Torch WL Kernel
40- torch_kernel = TorchWLKernel (n_iter = n_iter , normalize = normalize )
41- torch_kernel_matrix = torch_kernel (adjacency_matrices ,
42- label_tensors ).cpu ().numpy ()
41+ # Generate 10 random sets of graphs
42+ for _ in range (10 ):
43+ # Number of graphs in the set (2 to 10)
44+ num_graphs = np .random .randint (2 , 11 )
45+ graph_set = []
4346
44- # Initialize GraKel WL Kernel
45- grakel_graphs = list (
46- graph_from_networkx (example_graphs , node_labels_tag = "label" , as_Graph = True ))
47- grakel_kernel = WeisfeilerLehman (n_iter = n_iter , normalize = normalize )
48- grakel_kernel_matrix = grakel_kernel .fit_transform (grakel_graphs )
47+ for _ in range (num_graphs ):
48+ # Number of nodes in the graph (3 to 50)
49+ num_nodes = np .random .randint (3 , 51 )
50+ G = nx .Graph ()
4951
50- # Define tolerances based on normalization
51- rtol = 1e-5 if normalize else 1e-4
52- atol = 1e-8 if normalize else 1e-7
52+ # Add nodes with labels
53+ for node in range ( num_nodes ):
54+ G . add_node ( node , label = str ( node ))
5355
54- # Compare the kernel matrices
55- np .testing .assert_allclose (
56- torch_kernel_matrix ,
57- grakel_kernel_matrix ,
58- rtol = rtol ,
59- atol = atol ,
60- err_msg = f"Kernel matrices differ for n_iter={ n_iter } , normalize={ normalize } "
61- )
56+ # Add random edges
57+ for u in range (num_nodes ):
58+ for v in range (u + 1 , num_nodes ):
59+ if np .random .rand () > 0.5 : # 50% chance to add an edge
60+ G .add_edge (u , v )
61+
62+ graph_set .append (G )
63+
64+ random_graph_sets .append (graph_set )
65+
66+ return random_graph_sets
67+
68+ @pytest .mark .parametrize ("n_iter" , [1 , 2 , 3 , 5 , 10 ])
69+ @pytest .mark .parametrize ("normalize" , [True , False ])
70+ def test_wl_kernel_against_grakel (self , n_iter , normalize , random_graphs_sets ):
71+ for graph_set in random_graphs_sets :
72+ adjacency_matrices , label_tensors = graphs_to_tensors (
73+ graph_set , device = self .device )
74+
75+ # Initialize Torch WL Kernel
76+ torch_kernel = TorchWLKernel (n_iter = n_iter , normalize = normalize )
77+ torch_kernel_matrix = torch_kernel (adjacency_matrices ,
78+ label_tensors ).cpu ().numpy ()
79+
80+ # Initialize GraKel WL Kernel
81+ grakel_graphs = list (
82+ graph_from_networkx (graph_set , node_labels_tag = "label" , as_Graph = True ))
83+ grakel_kernel = WeisfeilerLehman (n_iter = n_iter , normalize = normalize )
84+ grakel_kernel_matrix = grakel_kernel .fit_transform (grakel_graphs )
85+
86+ # Compare the kernel matrices
87+ np .testing .assert_allclose (
88+ torch_kernel_matrix ,
89+ grakel_kernel_matrix ,
90+ rtol = 1e-5 ,
91+ atol = 1e-8 ,
92+ err_msg = f"Kernel matrices differ for graph={ graph_set } , n_iter={ n_iter } "
93+ )
6294
6395 def test_empty_graph (self ):
6496 G_empty = nx .Graph ()
@@ -97,36 +129,47 @@ def test_kernel_on_single_node_graph(self):
97129 expected = torch .ones (1 , 1 , device = self .device )
98130 torch .testing .assert_close (K , expected )
99131
100- def test_wl_kernel_with_empty_graph_and_reordered_edges (self , example_graphs ):
132+ def test_wl_kernel_with_empty_graph_and_reordered_edges (self , random_graphs_sets ):
101133 """Test the TorchWLKernel with an empty graph and a graph with reordered edges."""
102- # Create example graphs for testing
103- G_empty = nx .Graph ()
104- G_empty .add_node (0 )
105- G_empty .nodes [0 ]["label" ] = "0"
106-
107- G = example_graphs [0 ]
108- G_reordered = nx .Graph ()
109- G_reordered .add_edges_from ([(1 , 4 ), (2 , 3 ), (1 , 2 ), (0 , 1 ), (1 , 3 )])
110- for node in G_reordered .nodes ():
111- G_reordered .nodes [node ]["label" ] = str (node )
112-
113- graphs = [G_empty , G , G_reordered ]
114- adjacency_matrices , label_tensors = graphs_to_tensors (graphs ,
115- device = self .device )
116-
117- wl_kernel = TorchWLKernel (n_iter = 3 , normalize = True )
118- K = wl_kernel (adjacency_matrices , label_tensors )
119-
120- assert K .shape == (3 , 3 ), "Kernel matrix shape is incorrect"
121- assert K [1 , 1 ] == K [
122- 2 , 2 ], "Kernel value for original and reordered graphs should be the same"
134+ for graph_set in random_graphs_sets :
135+ # Create an empty graph
136+ G_empty = nx .Graph ()
137+ G_empty .add_node (0 )
138+ G_empty .nodes [0 ]["label" ] = "0"
139+
140+ # Select the first graph from the set to reorder its edges
141+ G = graph_set [0 ]
142+ G_reordered = nx .Graph ()
143+
144+ # Add all nodes from the original graph to G_reordered
145+ for node in G .nodes ():
146+ G_reordered .add_node (node , label = G .nodes [node ]["label" ])
147+
148+ # Reorder edges randomly
149+ edges = list (G .edges ())
150+ np .random .shuffle (edges ) # Randomly shuffle the edges
151+ G_reordered .add_edges_from (edges )
152+
153+ # Combine the empty graph, original graph, and reordered graph
154+ graphs = [G_empty , G , G_reordered ]
155+ adjacency_matrices , label_tensors = graphs_to_tensors (
156+ graphs , device = self .device
157+ )
158+
159+ # Initialize and compute the kernel
160+ wl_kernel = TorchWLKernel (n_iter = 3 , normalize = True )
161+ K = wl_kernel (adjacency_matrices , label_tensors )
162+
163+ assert K .shape == (3 , 3 ), "Kernel matrix shape is incorrect"
164+ assert torch .allclose (K [1 , 1 ], K [2 , 2 ]), \
165+ "Kernel value for original and reordered graphs should be the same"
123166
124167 @pytest .mark .parametrize ("n_iter" , [1 , 2 , 3 , 4 , 5 , 6 , 7 ])
125168 @pytest .mark .parametrize ("normalize" , [True , False ])
126169 def test_wl_kernel_with_different_node_labels (self , n_iter , normalize ,
127- example_graphs ):
170+ example_graphs_set ):
128171 graphs = []
129- for i , G in enumerate (example_graphs ):
172+ for i , G in enumerate (example_graphs_set ):
130173 G_copy = G .copy ()
131174 prefix = ["node_" , "vertex_" , "n" ][i ]
132175 for node in G_copy .nodes ():
@@ -143,20 +186,15 @@ def test_wl_kernel_with_different_node_labels(self, n_iter, normalize,
143186 grakel_wl = WeisfeilerLehman (n_iter = n_iter , normalize = normalize )
144187 grakel_kernel_matrix = grakel_wl .fit_transform (grakel_graphs )
145188
146- # Define tolerances based on normalization, matching the main test
147- rtol = 1e-5 if normalize else 1e-4
148- atol = 1e-8 if normalize else 1e-7
149-
150- # Updated assertion with both rtol and atol
151189 np .testing .assert_allclose (
152190 torch_kernel_matrix ,
153191 grakel_kernel_matrix ,
154- rtol = rtol ,
155- atol = atol ,
192+ rtol = 1e-5 ,
193+ atol = 1e-8 ,
156194 err_msg = f"Kernel matrices differ for n_iter={ n_iter } , normalize={ normalize } "
157195 )
158196
159- def test_wl_kernel_with_same_node_labels (self , example_graphs ):
197+ def test_wl_kernel_with_same_node_labels (self , example_graphs_set ):
160198 """Test WL kernel behavior with same node labels but different structures.
161199
162200 Even when all nodes have the same label, the WL kernel should:
@@ -166,7 +204,7 @@ def test_wl_kernel_with_same_node_labels(self, example_graphs):
166204 4. Maintain non-negative values (it's a valid kernel)
167205 """
168206 graphs = []
169- for G in example_graphs :
207+ for G in example_graphs_set :
170208 G_copy = G .copy ()
171209 for node in G_copy .nodes ():
172210 G_copy .nodes [node ]["label" ] = "A"
0 commit comments