Skip to content

Commit 4cc0b29

Browse files
committed
Use lru_cache instead of simple dict cache
1 parent f7922db commit 4cc0b29

File tree

2 files changed

+16
-50
lines changed

2 files changed

+16
-50
lines changed

grakel_replace/graph_aware_gp_optimization_example.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from gpytorch import ExactMarginalLogLikelihood
1414
from gpytorch.kernels import AdditiveKernel, MaternKernel
1515
from grakel_replace.context_managers import set_graph_lookup
16-
from grakel_replace.kernels import BoTorchWLKernel
16+
from grakel_replace.kernels import BoTorchWLKernel, TorchWLKernel
1717
from grakel_replace.optimization import optimize_acqf_graph
1818
from grakel_replace.utils import min_max_scale, seed_all
1919

@@ -122,3 +122,8 @@
122122
print(f"Best candidate: {best_candidate}")
123123
print(f"Best score: {best_score}")
124124
print(f"Elapsed time: {time.time() - start_time} seconds")
125+
126+
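# Clear the lru_cache stores after optimization: they are attached to the class-level methods, so cached results (and the instances used as cache keys) stay alive until explicitly cleared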
127+
BoTorchWLKernel._compute_kernel.cache_clear()
128+
TorchWLKernel._get_node_neighbors.cache_clear()
129+
TorchWLKernel._wl_iteration.cache_clear()

grakel_replace/kernels.py

Lines changed: 10 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from functools import lru_cache
34
from typing import TYPE_CHECKING, Any
45

56
import torch
@@ -32,7 +33,6 @@ class BoTorchWLKernel(Kernel):
3233
graph_lookup (list[nx.Graph]): List of graphs used for kernel computation.
3334
n_iter (int): Number of WL iterations.
3435
normalize (bool): Whether to normalize the kernel matrix.
35-
cache (dict[tuple, Tensor]): Cache for storing precomputed kernel matrices.
3636
adjacency_cache (list[Tensor]): Cached adjacency matrices of the graphs.
3737
label_cache (list[Tensor]): Cached initial node labels of the graphs.
3838
"""
@@ -51,7 +51,6 @@ def __init__(
5151
self.graph_lookup = graph_lookup
5252
self.n_iter = n_iter
5353
self.normalize = normalize
54-
self.cache: dict[tuple, Tensor] = {}
5554
self._precompute_graph_data()
5655

5756
def _precompute_graph_data(self) -> None:
@@ -75,34 +74,15 @@ def forward(
7574
**params: Any,
7675
) -> Tensor:
7776
"""Compute kernel matrix containing pairwise similarities between graphs."""
78-
# last_dim_is_batch is for compatibility with base Kernel class.
7977
if last_dim_is_batch:
8078
raise NotImplementedError("Batch dimension handling is not implemented.")
8179

82-
x1_is_x2 = torch.equal(x1, x2)
83-
indices = tuple(x1.flatten().tolist()) if x1_is_x2 else (
84-
tuple(x1.flatten().tolist()), tuple(x2.flatten().tolist()))
85-
86-
if indices in self.cache:
87-
return self.cache[indices]
88-
89-
# Compute kernel matrix if not cached
90-
K = self._compute_kernel(x1, x2, diag=diag)
91-
self.cache[indices] = K
92-
return K
93-
94-
def _compute_kernel(self, x1: Tensor, x2: Tensor, diag: bool) -> Tensor:
95-
"""Compute the kernel matrix."""
9680
if x1.ndim == 3:
9781
return self._handle_batched_input(x1, x2, diag)
9882

9983
indices1, indices2 = self._prepare_indices(x1, x2)
10084

101-
# Check if we're computing self-similarity or cross-similarity
102-
if torch.equal(x1, x2):
103-
return self._compute_self_kernel(indices1, diag)
104-
else:
105-
return self._compute_cross_kernel(indices1, indices2, diag)
85+
return self._compute_kernel(tuple(indices1), tuple(indices2), diag)
10686

10787
def _handle_batched_input(self, x1: Tensor, x2: Tensor, diag: bool) -> Tensor:
10888
"""Handle computation for batched input tensors."""
@@ -111,7 +91,7 @@ def _handle_batched_input(self, x1: Tensor, x2: Tensor, diag: bool) -> Tensor:
11191

11292
out = torch.empty((q_dim_size, x1.shape[1], x2.shape[1]), device=x1.device)
11393
for q in range(q_dim_size):
114-
out[q] = self._compute_kernel(x1[q], x2[q], diag=diag)
94+
out[q] = self.forward(x1[q], x2[q], diag=diag)
11595
return out
11696

11797
def _prepare_indices(self, x1: Tensor, x2: Tensor) -> tuple[list[int], list[int]]:
@@ -127,34 +107,14 @@ def _prepare_indices(self, x1: Tensor, x2: Tensor) -> tuple[list[int], list[int]
127107

128108
return indices1, indices2
129109

130-
def _compute_self_kernel(self, indices: list[int], diag: bool) -> Tensor:
131-
"""Compute kernel matrix for self-similarity case."""
132-
indices_tuple = tuple(indices)
133-
if indices_tuple in self.cache:
134-
return self.cache[indices_tuple]
135-
136-
adj_matrices = [self.adjacency_cache[i] for i in indices]
137-
label_tensors = [self.label_cache[i] for i in indices]
138-
139-
# Compute kernel matrix
140-
K = self._compute_base_kernel(adj_matrices, label_tensors)
141-
if diag:
142-
K = torch.diag(K)
143-
144-
self.cache[indices_tuple] = K
145-
return K
146-
147-
def _compute_cross_kernel(
110+
@lru_cache(maxsize=128)
111+
def _compute_kernel(
148112
self,
149-
indices1: list[int],
150-
indices2: list[int],
113+
indices1: tuple[int],
114+
indices2: tuple[int],
151115
diag: bool,
152116
) -> Tensor:
153-
"""Compute kernel matrix for cross-similarity case."""
154-
cache_key = (tuple(indices1), tuple(indices2))
155-
if cache_key in self.cache:
156-
return self.cache[cache_key]
157-
117+
"""Compute the kernel matrix."""
158118
all_graphs = list(set(indices1 + indices2))
159119
adj_matrices = [self.adjacency_cache[i] for i in all_graphs]
160120
label_tensors = [self.label_cache[i] for i in all_graphs]
@@ -168,7 +128,6 @@ def _compute_cross_kernel(
168128
if diag:
169129
K = torch.diag(K)
170130

171-
self.cache[cache_key] = K
172131
return K
173132

174133
def _compute_base_kernel(
@@ -206,6 +165,7 @@ def __init__(self, n_iter: int = 5, *, normalize: bool = True) -> None:
206165
self.label_dict = {}
207166
self.label_counter = 0
208167

168+
@lru_cache(maxsize=128)
209169
def _get_node_neighbors(self, adj: Tensor) -> list[list[int]]:
210170
"""Extract neighborhood information from adjacency matrix."""
211171
if adj.layout == torch.sparse_csr:
@@ -221,6 +181,7 @@ def _get_node_neighbors(self, adj: Tensor) -> list[list[int]]:
221181

222182
return neighbors
223183

184+
@lru_cache(maxsize=128)
224185
def _wl_iteration(self, adj: Tensor, labels: Tensor) -> Tensor:
225186
"""Perform one WL iteration."""
226187
if not self.label_dict:

0 commit comments

Comments
 (0)