[Bug]: Memory leak with minimal example #2728

AdrianSosic opened this issue Feb 4, 2025 · 7 comments

@AdrianSosic (Contributor)

What happened?

Hi @esantorella & @saitcakmak 👋🏼 After a long time, I've finally had a moment to get back to #641, because I now have an actual minimal reproducing example.

For me, the process gets consistently killed after ~20 iterations: up to that point, it keeps allocating memory/swap until it eventually crashes. Could you perhaps confirm whether this is also the case for you?

I haven't yet used the gc.get_objects approach suggested by @esantorella here to verify whether it's an actual leak or just over-allocation. But in any case, a crash is unexpected, since the code obviously should not allocate any long-term resources for the independent optimizations happening in the loop.
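
For reference, a minimal sketch of that kind of check, assuming we simply tally live objects by type between iterations (the helper name is mine, just for illustration):

import gc
from collections import Counter

def top_object_counts(n=15):
    # Tally live Python objects by type; if the same types keep growing
    # across iterations, that points to a genuine leak rather than the
    # allocator merely holding on to already-freed memory.
    return Counter(type(o).__name__ for o in gc.get_objects()).most_common(n)

Printing top_object_counts() every few iterations of the loop below would show whether, e.g., tensor or linear-operator counts grow without bound.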

Please provide a minimal, reproducible example of the unexpected behavior.

Adapted from the landing page code:

import torch
from botorch.acquisition import qNegIntegratedPosteriorVariance
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.models.transforms import Normalize, Standardize
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood

d = 10
train_X = torch.rand(100, d, dtype=torch.double)
mc_points = torch.rand(100, d, dtype=torch.double)
train_Y = torch.rand(100, 1, dtype=torch.double)

gp = SingleTaskGP(
    train_X=train_X,
    train_Y=train_Y,
    input_transform=Normalize(d=d),
    outcome_transform=Standardize(m=1),
)
mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
fit_gpytorch_mll(mll)

acq = qNegIntegratedPosteriorVariance(model=gp, mc_points=mc_points)
bounds = torch.stack([torch.zeros(d), torch.ones(d)]).to(torch.double)

for i in range(1000):
    candidate, acq_value = optimize_acqf(
        acq, bounds=bounds, q=10, num_restarts=20, raw_samples=64, sequential=False
    )
    print(i)

Please paste any relevant traceback/logs produced by the example provided.

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
[1]    63481 killed

BoTorch Version

0.12.0

Python Version

3.10

Operating System

macOS

Code of Conduct

  • I agree to follow BoTorch's Code of Conduct
@AdrianSosic AdrianSosic added the bug Something isn't working label Feb 4, 2025
@saitcakmak (Contributor)

Hi @AdrianSosic. Thanks for sharing the simple repro. This does reproduce for me on both 0.12.0 and 0.13.0. The memory usage climbs by a few GB per replication until the process is killed. I'll investigate.
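
For anyone wanting to confirm this locally, a quick sketch using only the standard library; note that the units of ru_maxrss differ across platforms, so the divisor below assumes macOS:

import resource

def peak_rss_mb():
    # ru_maxrss is reported in bytes on macOS but in kilobytes on Linux;
    # adjust the divisor accordingly on other platforms.
    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1e6

Printing peak_rss_mb() once per iteration of the repro loop should make the per-iteration growth visible.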

@saitcakmak saitcakmak self-assigned this Feb 4, 2025
@saitcakmak (Contributor)

I think I've identified the part that causes the memory leak, but I don't yet know why. I've reduced the repro all the way down to evaluations of qNIPV and added a simplified implementation of the acqf to make it easy to modify.

from botorch.models import SingleTaskGP
import torch
from botorch import settings
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.models.model import Model
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.utils.transforms import t_batch_mode_transform
from torch import Tensor


class qNegIntegratedPosteriorVariance(AcquisitionFunction):
    def __init__(
        self,
        model: Model,
        mc_points: Tensor,
    ) -> None:
        super().__init__(model=model)
        self.sampler = SobolQMCNormalSampler(sample_shape=torch.Size([1]))
        self.register_buffer("mc_points", mc_points)

    @t_batch_mode_transform()
    def forward(self, X: Tensor) -> Tensor:
        # Construct the fantasy model (we do not actually use the full model;
        # this is just a convenient way of computing fast posterior covariances).
        fantasy_model = self.model.fantasize(
            X=X,
            sampler=self.sampler,
        )

        bdims = tuple(1 for _ in X.shape[:-2])
        mc_points = self.mc_points.view(*bdims, -1, X.size(-1))

        # evaluate the posterior at the grid points
        with settings.propagate_grads(True):  # Changing this to FALSE prevents the leak.
            posterior = fantasy_model.posterior(mc_points)
        neg_variance = posterior.variance.mul(-1.0)
        return neg_variance.mean(dim=-2).squeeze(-1).squeeze(0)

d = 10
mc_points = torch.rand(100, d, dtype=torch.double)

gp = SingleTaskGP(
    train_X=torch.rand(100, d, dtype=torch.double),
    train_Y=torch.rand(100, 1, dtype=torch.double),
).eval()


acq = qNegIntegratedPosteriorVariance(model=gp, mc_points=mc_points)

for i in range(10000):
    acq(torch.rand(128, 5, d, dtype=torch.double, requires_grad=False))
    print(f"{i=}")

If we replace settings.propagate_grads(True) with settings.propagate_grads(False), the leak goes away.

@saitcakmak (Contributor) commented Feb 4, 2025

That can be simplified further. No need for the acqf.

from botorch.models import SingleTaskGP
import torch
from botorch import settings
from botorch.sampling.normal import SobolQMCNormalSampler


d = 10
mc_points = torch.rand(100, d, dtype=torch.double)

gp = SingleTaskGP(
    train_X=torch.rand(100, d, dtype=torch.double),
    train_Y=torch.rand(100, 1, dtype=torch.double),
).eval()


for i in range(10000):
    X = torch.rand(128, 5, d, dtype=torch.double, requires_grad=False)
    fantasy_model = gp.fantasize(
        X=X,
        sampler=SobolQMCNormalSampler(sample_shape=torch.Size([1])),
    )
    with settings.propagate_grads(False):  # Set this to TRUE to reproduce the leak.
        posterior = fantasy_model.posterior(mc_points)
    print(f"{i=}")

@saitcakmak (Contributor)

And we can reproduce directly with a single gpytorch context manager:

from botorch.models import SingleTaskGP
import torch
from botorch.sampling.normal import SobolQMCNormalSampler
from gpytorch import settings as gpt_settings

d = 10
mc_points = torch.rand(100, d, dtype=torch.double)

gp = SingleTaskGP(
    train_X=torch.rand(100, d, dtype=torch.double),
    train_Y=torch.rand(100, 1, dtype=torch.double),
).eval()


for i in range(10000):
    X = torch.rand(128, 5, d, dtype=torch.double, requires_grad=False)
    fantasy_model = gp.fantasize(
        X=X,
        sampler=SobolQMCNormalSampler(sample_shape=torch.Size([1])),
    ).eval()
    with gpt_settings.detach_test_caches(False):  # Set to TRUE and the leak goes away.
        fantasy_model(mc_points)
    print(f"{i=}")

@AdrianSosic (Contributor, Author)

Interesting, thanks for sharing. So it seems that some of the backprop graphs are kept lying around, right? If I understand correctly what detach_test_caches is supposed to do, this is sort of the intended effect, but probably not to the extent we are currently observing ...

@saitcakmak (Contributor)

That's my guess as well. The context manager seems to control detach calls for some mean and covar caches within the GPyTorch prediction strategy. I'll look into it more today.
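
For intuition, a toy sketch of the suspected mechanism (plain PyTorch, not GPyTorch code): a tensor kept in a long-lived cache retains its autograd graph, and with it every intermediate the graph saved for backward.

import torch

x = torch.randn(1_000, 1_000, requires_grad=True)
cache = []
for i in range(5):
    z = x * 2.0               # fresh 1000x1000 intermediate (~4 MB) each iteration
    y = (z * z).sum()         # the grad_fn of y saves z for the backward pass
    cache.append(y)           # keeping y alive therefore keeps z alive as well
    # cache.append(y.detach())  # detaching would let the graph, and z, be freed

If the prediction-strategy caches behave like the cache above when detach_test_caches is off, each fantasize/posterior call would pin its own set of intermediates.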

@saitcakmak (Contributor)

I've isolated the issue to a specific detach call in mean_cache and created cornellius-gp/gpytorch#2631 to track this.
