Possible PyTorch implementation of WL kernel #153

Closed · wants to merge 62 commits

Changes from 8 commits

Commits (62)
36fc3bd
Add a PyTorch implementation of WL kernel
vladislavalerievich Oct 28, 2024
b0d3842
Fix imports
vladislavalerievich Oct 29, 2024
f87abd6
Remove redundant copy
vladislavalerievich Oct 29, 2024
358fbb7
Increase precision for allclose
vladislavalerievich Oct 29, 2024
de140b6
Fix calculation for graphs with reordered edges
vladislavalerievich Oct 29, 2024
08c7aea
Increase test coverage
vladislavalerievich Oct 29, 2024
6f07858
Improve readability of TorchWLKernel
vladislavalerievich Oct 30, 2024
896f461
Add additional comments to TorchWLKernel
vladislavalerievich Oct 30, 2024
383e924
Add MixedSingleTaskGP to process graphs
vladislavalerievich Nov 8, 2024
65666a3
Refactor WLKernelWrapper into a standalone WLKernel class.
vladislavalerievich Nov 20, 2024
7fa9432
Update tests
vladislavalerievich Nov 20, 2024
4227f22
Add a check for empty inputs
vladislavalerievich Nov 20, 2024
f194bd2
Improve and combine tests
vladislavalerievich Nov 20, 2024
a104840
Update WLKernel
vladislavalerievich Nov 21, 2024
246f9f6
Add acquisition function with graph sampling
vladislavalerievich Nov 21, 2024
770c626
Add a custom __call__ method to pass graphs during optimization
vladislavalerievich Nov 21, 2024
8bf7ea7
Update MixedSingleTaskGP
vladislavalerievich Dec 7, 2024
84d0104
Remove not used argument
vladislavalerievich Dec 7, 2024
d63239a
Update sample_graphs
vladislavalerievich Dec 7, 2024
3db3f89
Handle different batch dimensions
vladislavalerievich Dec 7, 2024
f69ddbe
Set num_restarts=10
vladislavalerievich Dec 7, 2024
1c4cc83
Add acquisition function
vladislavalerievich Dec 7, 2024
dab9a8c
Update WLKernel
vladislavalerievich Dec 7, 2024
2999582
Make train_inputs private
vladislavalerievich Dec 7, 2024
ad55030
Update tests
vladislavalerievich Dec 7, 2024
8093d31
fix: Implement graph acquisition
eddiebergman Dec 16, 2024
9f978d6
fix: Implement graph acquisition (#164)
vladislavalerievich Dec 24, 2024
a1a29a8
Delete unused MixedSingleTaskGP
vladislavalerievich Dec 24, 2024
046ad66
Add seed_all and min_max_scale
vladislavalerievich Dec 24, 2024
0a609f7
Refactor optimize.py
vladislavalerievich Dec 24, 2024
5486dcc
Speed up WL kernel computations
vladislavalerievich Dec 24, 2024
f140c56
Process wl iterations in batches
vladislavalerievich Dec 24, 2024
371b530
Use CSR
vladislavalerievich Dec 25, 2024
1478fd9
Implement caching
vladislavalerievich Jan 16, 2025
a4ffaaf
Clean up __init__ methods
vladislavalerievich Jan 16, 2025
2ec7d5b
Split _compute_kernel logic into smaller methods
vladislavalerievich Jan 16, 2025
8d6b63b
Rename kernel to BoTorchWLKernel
vladislavalerievich Jan 16, 2025
f18642b
Move GraphDataset class into utils.py
vladislavalerievich Jan 16, 2025
bb92de4
Delete GraphDataset
vladislavalerievich Jan 19, 2025
e409798
Update tests
vladislavalerievich Jan 20, 2025
51e6ae4
Simplify TorchWLKernel
vladislavalerievich Jan 20, 2025
bdd32db
Remove torch_wl_usage_example.py
vladislavalerievich Jan 21, 2025
7747e49
Update grakel_wl_usage_example.py
vladislavalerievich Jan 23, 2025
21b32c8
Update TestTorchWLKernel
vladislavalerievich Jan 23, 2025
dabf4f0
Create graphs_to_tensors function
vladislavalerievich Jan 23, 2025
22cf6d5
Add docstring to BoTorchWLKernel
vladislavalerievich Jan 23, 2025
52b3b14
Add tests for the BoTorchWLKernel
vladislavalerievich Jan 23, 2025
fe79d63
Move redundant files to examples directory
vladislavalerievich Jan 23, 2025
7729d2c
Combine set_graph_lookup context managers into one
vladislavalerievich Jan 23, 2025
1cecacc
Update comments for the optimize_acqf_graph function
vladislavalerievich Jan 23, 2025
ab730d3
Move sample_graphs into utils.py
vladislavalerievich Jan 23, 2025
3eb793d
Rename mixed_single_task_gp_usage_example.py
vladislavalerievich Jan 23, 2025
0cfae28
Add comments
vladislavalerievich Jan 23, 2025
88ddfe1
Move set_graph_lookup into its own file
vladislavalerievich Jan 23, 2025
6d9ea56
Update imports
vladislavalerievich Jan 23, 2025
be04ad2
Print results
vladislavalerievich Jan 23, 2025
458d420
Provide better file names
vladislavalerievich Jan 23, 2025
f7922db
Organize imports
vladislavalerievich Jan 23, 2025
4cc0b29
Use lru_cache instead of simple dict cache
vladislavalerievich Jan 23, 2025
4e8bdad
Improve tests
vladislavalerievich Jan 23, 2025
ea77e44
Fix ruff and mypy complaints
vladislavalerievich Jan 23, 2025
5e2a33b
Improve kernels
vladislavalerievich Jan 24, 2025
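
For context on what the commits above implement, here is a minimal sketch of the Weisfeiler-Lehman (WL) subtree kernel written in plain PyTorch and NetworkX. It only illustrates the relabel-and-count-histograms idea; the function name wl_kernel_matrix and its structure are assumptions for illustration, not the PR's TorchWLKernel/BoTorchWLKernel code.

# Illustrative WL subtree kernel sketch; not the PR's implementation.
from collections import Counter

import networkx as nx
import torch


def wl_kernel_matrix(graphs: list[nx.Graph], n_iter: int = 5) -> torch.Tensor:
    """Pairwise WL subtree kernel matrix for a list of node-labeled graphs."""
    # Initial labels: use the "label" node attribute, falling back to degree.
    labels = [
        {v: str(g.nodes[v].get("label", g.degree[v])) for v in g.nodes}
        for g in graphs
    ]
    histograms = [Counter(lab.values()) for lab in labels]

    for _ in range(n_iter):
        new_labels = []
        for g, lab in zip(graphs, labels):
            # Relabel every node with its own label plus its sorted neighbor labels.
            new_labels.append({
                v: lab[v] + "|" + ",".join(sorted(lab[u] for u in g.neighbors(v)))
                for v in g.nodes
            })
        labels = new_labels
        for hist, lab in zip(histograms, labels):
            hist.update(lab.values())  # accumulate counts across iterations

    # Turn the accumulated label counts into a dense feature matrix.
    vocab = {lab: i for i, lab in enumerate({l for h in histograms for l in h})}
    features = torch.zeros(len(graphs), len(vocab))
    for i, hist in enumerate(histograms):
        for lab, count in hist.items():
            features[i, vocab[lab]] = count

    return features @ features.T  # K[i, j] = <phi(G_i), phi(G_j)>

With normalize=True (seen in the usage example below), such a matrix is typically rescaled entrywise as K[i, j] / sqrt(K[i, i] * K[j, j]) so that every graph has unit self-similarity.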
6 changes: 3 additions & 3 deletions grakel_replace/mixed_single_task_gp_usage_example.py

@@ -15,7 +15,7 @@
 from gpytorch import ExactMarginalLogLikelihood
 from gpytorch.kernels import AdditiveKernel, MaternKernel
 from grakel_replace.optimize import optimize_acqf_graph
-from grakel_replace.torch_wl_kernel import TorchWLKernel
+from grakel_replace.torch_wl_kernel import BoTorchWLKernel
 from grakel_replace.utils import min_max_scale, seed_all

 if TYPE_CHECKING:
@@ -63,7 +63,7 @@
     ScaleKernel(CategoricalKernel(
         ard_num_dims=N_CATEGORICAL,
         active_dims=range(N_NUMERICAL, N_NUMERICAL + N_CATEGORICAL))),
-    ScaleKernel(TorchWLKernel(
+    ScaleKernel(BoTorchWLKernel(
         graph_lookup=train_graphs, n_iter=5, normalize=True,
         active_dims=(X.shape[1] - 1,)))
 ]
@@ -82,7 +82,7 @@
 def set_graph_lookup(_gp: SingleTaskGP, new_graphs: list[nx.Graph]) -> Iterator[None]:
     kernel_prev_graphs: list[tuple[Kernel, list[nx.Graph]]] = []
     for kern in _gp.covar_module.sub_kernels():
-        if isinstance(kern, TorchWLKernel):
+        if isinstance(kern, BoTorchWLKernel):
             kernel_prev_graphs.append((kern, kern.graph_lookup))
             kern.set_graph_lookup(new_graphs)
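
The hunk above shows only the save-and-swap half of the example's set_graph_lookup helper; the part that restores the previous lookups is cut off by the diff. A hedged reconstruction of the full context-manager pattern, assuming a contextlib.contextmanager decorator (suggested by the Iterator[None] return type), could look like this:

# Reconstructed sketch of the swap-and-restore pattern; the restore step is
# inferred, not copied from the PR.
from collections.abc import Iterator
from contextlib import contextmanager

import networkx as nx
from botorch.models import SingleTaskGP
from gpytorch.kernels import Kernel
from grakel_replace.torch_wl_kernel import BoTorchWLKernel


@contextmanager
def set_graph_lookup(_gp: SingleTaskGP, new_graphs: list[nx.Graph]) -> Iterator[None]:
    """Temporarily point every BoTorchWLKernel in the GP at new_graphs."""
    kernel_prev_graphs: list[tuple[Kernel, list[nx.Graph]]] = []
    for kern in _gp.covar_module.sub_kernels():
        if isinstance(kern, BoTorchWLKernel):
            kernel_prev_graphs.append((kern, kern.graph_lookup))
            kern.set_graph_lookup(new_graphs)
    try:
        yield  # evaluate the model / acquisition with the swapped-in graphs
    finally:
        # Restore whatever each kernel was looking at before (reconstructed step).
        for kern, prev_graphs in kernel_prev_graphs:
            kern.set_graph_lookup(prev_graphs)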
6 changes: 3 additions & 3 deletions grakel_replace/optimize.py

@@ -8,7 +8,7 @@
 import networkx as nx
 import torch
 from botorch.optim import optimize_acqf_mixed
-from grakel_replace.torch_wl_kernel import TorchWLKernel
+from grakel_replace.torch_wl_kernel import BoTorchWLKernel

 if TYPE_CHECKING:
     from botorch.acquisition import AcquisitionFunction
@@ -33,13 +33,13 @@ def set_graph_lookup(
     kernel_prev_graphs: list[tuple[Kernel, list[nx.Graph]]] = []

     # Determine the modules to update based on the kernel type
-    if isinstance(kernel, TorchWLKernel):
+    if isinstance(kernel, BoTorchWLKernel):
         modules = [kernel]
     else:
         assert hasattr(
             kernel, "sub_kernels"
         ), "Kernel module must have sub_kernels method."
-        modules = [k for k in kernel.sub_kernels() if isinstance(k, TorchWLKernel)]
+        modules = [k for k in kernel.sub_kernels() if isinstance(k, BoTorchWLKernel)]

     # Save the current graph lookup and set the new graph lookup
     for kern in modules:
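
For a sense of how these pieces might fit together during acquisition optimization (the actual logic lives in optimize_acqf_graph, which this hunk does not show), here is a hypothetical sketch: each candidate graph is appended to the kernel's lookup via the set_graph_lookup context manager from the usage-example sketch above, and its index is fixed in the extra last input column when calling BoTorch's optimize_acqf_mixed. The helper name optimize_over_graphs, the raw_samples value, and the fixed-feature convention are assumptions; num_restarts=10 matches one of the commits above.

# Hypothetical sketch; not the PR's optimize_acqf_graph implementation.
import networkx as nx
import torch
from botorch.acquisition import AcquisitionFunction
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf_mixed


def optimize_over_graphs(
    acq: AcquisitionFunction,
    gp: SingleTaskGP,
    bounds: torch.Tensor,  # shape (2, d); the last column indexes graphs
    train_graphs: list[nx.Graph],
    candidate_graphs: list[nx.Graph],
) -> tuple[torch.Tensor, float, nx.Graph]:
    best_x, best_val, best_graph = None, float("-inf"), None
    graph_idx = bounds.shape[1] - 1  # extra column holding the graph index
    for graph in candidate_graphs:
        # Append the candidate so the WL kernel can see it, then fix its index.
        # set_graph_lookup is the context manager sketched earlier.
        with set_graph_lookup(gp, [*train_graphs, graph]):
            x, val = optimize_acqf_mixed(
                acq_function=acq,
                bounds=bounds,
                q=1,
                num_restarts=10,
                raw_samples=64,  # assumed value
                fixed_features_list=[{graph_idx: float(len(train_graphs))}],
            )
        if val.item() > best_val:
            best_x, best_val, best_graph = x, val.item(), graph
    return best_x, best_val, best_graph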