11from __future__ import annotations
22
3+ from functools import lru_cache
34from typing import TYPE_CHECKING , Any
45
56import torch
@@ -32,7 +33,6 @@ class BoTorchWLKernel(Kernel):
3233 graph_lookup (list[nx.Graph]): List of graphs used for kernel computation.
3334 n_iter (int): Number of WL iterations.
3435 normalize (bool): Whether to normalize the kernel matrix.
35- cache (dict[tuple, Tensor]): Cache for storing precomputed kernel matrices.
3636 adjacency_cache (list[Tensor]): Cached adjacency matrices of the graphs.
3737 label_cache (list[Tensor]): Cached initial node labels of the graphs.
3838 """
@@ -51,7 +51,6 @@ def __init__(
5151 self .graph_lookup = graph_lookup
5252 self .n_iter = n_iter
5353 self .normalize = normalize
54- self .cache : dict [tuple , Tensor ] = {}
5554 self ._precompute_graph_data ()
5655
5756 def _precompute_graph_data (self ) -> None :
@@ -75,34 +74,15 @@ def forward(
7574 ** params : Any ,
7675 ) -> Tensor :
7776 """Compute kernel matrix containing pairwise similarities between graphs."""
78- # last_dim_is_batch is for compatibility with base Kernel class.
7977 if last_dim_is_batch :
8078 raise NotImplementedError ("Batch dimension handling is not implemented." )
8179
82- x1_is_x2 = torch .equal (x1 , x2 )
83- indices = tuple (x1 .flatten ().tolist ()) if x1_is_x2 else (
84- tuple (x1 .flatten ().tolist ()), tuple (x2 .flatten ().tolist ()))
85-
86- if indices in self .cache :
87- return self .cache [indices ]
88-
89- # Compute kernel matrix if not cached
90- K = self ._compute_kernel (x1 , x2 , diag = diag )
91- self .cache [indices ] = K
92- return K
93-
94- def _compute_kernel (self , x1 : Tensor , x2 : Tensor , diag : bool ) -> Tensor :
95- """Compute the kernel matrix."""
9680 if x1 .ndim == 3 :
9781 return self ._handle_batched_input (x1 , x2 , diag )
9882
9983 indices1 , indices2 = self ._prepare_indices (x1 , x2 )
10084
101- # Check if we're computing self-similarity or cross-similarity
102- if torch .equal (x1 , x2 ):
103- return self ._compute_self_kernel (indices1 , diag )
104- else :
105- return self ._compute_cross_kernel (indices1 , indices2 , diag )
85+ return self ._compute_kernel (tuple (indices1 ), tuple (indices2 ), diag )
10686
10787 def _handle_batched_input (self , x1 : Tensor , x2 : Tensor , diag : bool ) -> Tensor :
10888 """Handle computation for batched input tensors."""
@@ -111,7 +91,7 @@ def _handle_batched_input(self, x1: Tensor, x2: Tensor, diag: bool) -> Tensor:
11191
11292 out = torch .empty ((q_dim_size , x1 .shape [1 ], x2 .shape [1 ]), device = x1 .device )
11393 for q in range (q_dim_size ):
114- out [q ] = self ._compute_kernel (x1 [q ], x2 [q ], diag = diag )
94+ out [q ] = self .forward (x1 [q ], x2 [q ], diag = diag )
11595 return out
11696
11797 def _prepare_indices (self , x1 : Tensor , x2 : Tensor ) -> tuple [list [int ], list [int ]]:
@@ -127,34 +107,14 @@ def _prepare_indices(self, x1: Tensor, x2: Tensor) -> tuple[list[int], list[int]
127107
128108 return indices1 , indices2
129109
130- def _compute_self_kernel (self , indices : list [int ], diag : bool ) -> Tensor :
131- """Compute kernel matrix for self-similarity case."""
132- indices_tuple = tuple (indices )
133- if indices_tuple in self .cache :
134- return self .cache [indices_tuple ]
135-
136- adj_matrices = [self .adjacency_cache [i ] for i in indices ]
137- label_tensors = [self .label_cache [i ] for i in indices ]
138-
139- # Compute kernel matrix
140- K = self ._compute_base_kernel (adj_matrices , label_tensors )
141- if diag :
142- K = torch .diag (K )
143-
144- self .cache [indices_tuple ] = K
145- return K
146-
147- def _compute_cross_kernel (
110+ @lru_cache (maxsize = 128 )
111+ def _compute_kernel (
148112 self ,
149- indices1 : list [int ],
150- indices2 : list [int ],
113+ indices1 : tuple [int ],
114+ indices2 : tuple [int ],
151115 diag : bool ,
152116 ) -> Tensor :
153- """Compute kernel matrix for cross-similarity case."""
154- cache_key = (tuple (indices1 ), tuple (indices2 ))
155- if cache_key in self .cache :
156- return self .cache [cache_key ]
157-
117+ """Compute the kernel matrix."""
158118 all_graphs = list (set (indices1 + indices2 ))
159119 adj_matrices = [self .adjacency_cache [i ] for i in all_graphs ]
160120 label_tensors = [self .label_cache [i ] for i in all_graphs ]
@@ -168,7 +128,6 @@ def _compute_cross_kernel(
168128 if diag :
169129 K = torch .diag (K )
170130
171- self .cache [cache_key ] = K
172131 return K
173132
174133 def _compute_base_kernel (
@@ -206,6 +165,7 @@ def __init__(self, n_iter: int = 5, *, normalize: bool = True) -> None:
206165 self .label_dict = {}
207166 self .label_counter = 0
208167
168+ @lru_cache (maxsize = 128 )
209169 def _get_node_neighbors (self , adj : Tensor ) -> list [list [int ]]:
210170 """Extract neighborhood information from adjacency matrix."""
211171 if adj .layout == torch .sparse_csr :
@@ -221,6 +181,7 @@ def _get_node_neighbors(self, adj: Tensor) -> list[list[int]]:
221181
222182 return neighbors
223183
184+ @lru_cache (maxsize = 128 )
224185 def _wl_iteration (self , adj : Tensor , labels : Tensor ) -> Tensor :
225186 """Perform one WL iteration."""
226187 if not self .label_dict :
0 commit comments