[Question] Memory keeps increasing in MetaAdam due to gradient link #218

@ycsos

Description

Required prerequisites

What version of TorchOpt are you using?

0.7.3

System information

>>> print(sys.version, sys.platform)
3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:36:13) [GCC 12.3.0] linux
>>> print(torchopt.__version__, torch.__version__, functorch.__version__)
0.7.3 2.3.0 2.3.0

Problem description

When I use torchopt.MetaAdam and call step several times, GPU memory usage increases continuously. It should not: once the next step executes, the tensors created in the previous step are no longer needed and should be released. I found the reason: MetaOptimizer does not detach the gradient link in the optimizer, so PyTorch cannot release the earlier tensors because later tensors still depend on them.
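
The dependency described above can be illustrated with plain PyTorch, without TorchOpt. The sketch below is an illustration under assumptions, not TorchOpt's implementation: a hand-written differentiable SGD step (`torch.autograd.grad` with `create_graph=True`) stands in for a meta-optimizer, and `count_graph_nodes` and `run` are hypothetical helper names. It counts the autograd nodes reachable from the updated weight after each step, with and without detaching.

```python
import torch


def count_graph_nodes(t: torch.Tensor) -> int:
    """Count autograd nodes reachable from t.grad_fn (0 for a leaf tensor)."""
    seen = set()
    stack = [t.grad_fn] if t.grad_fn is not None else []
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        # next_functions holds (node, input_index) pairs; node may be None.
        for next_node, _ in node.next_functions:
            if next_node is not None:
                stack.append(next_node)
    return len(seen)


def run(steps: int, detach: bool, lr: float = 0.1) -> list:
    """Differentiable SGD on a toy loss; returns the graph size after each step."""
    w = torch.randn(4, 4, requires_grad=True)
    sizes = []
    for _ in range(steps):
        loss = (w ** 2).sum()
        # create_graph=True makes the gradient itself part of the graph,
        # which is what a differentiable (meta) optimizer needs.
        (grad,) = torch.autograd.grad(loss, w, create_graph=True)
        # The new w is a non-leaf tensor that references the previous step.
        w = w - lr * grad
        if detach:
            # Drop the link to earlier steps so their tensors can be freed.
            w = w.detach().requires_grad_()
        sizes.append(count_graph_nodes(w))
    return sizes


print(run(5, detach=False))  # grows every step
print(run(5, detach=True))   # stays at 0: w is a fresh leaf each step
```

Without detaching, each new `w` keeps every earlier step's graph (and its saved tensors) alive, which is the pattern this report attributes the memory growth to.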

You can run the test code below. With the first version, memory increases as the number of steps increases; with the second (where I changed the code to detach the gradient link), memory stays stable as the number of steps increases:
Before:
[image not preserved]
After:
[image not preserved]

Reproducible example code

The Python snippets:

import torch
import torchopt

class test_nn(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.a = torch.nn.Parameter(torch.randn(768, 512, device="cuda")) # 768 * 512 * 4 / 1024 = 1536 KB
        self.b = torch.nn.Parameter(torch.randn(768, 768, device="cuda")) # 768 * 768 * 4 / 1024 = 2304 KB
        self.test_time = 10
 
    def forward(self):
        from torch.profiler import profile, record_function, ProfilerActivity
        with profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            schedule=torch.profiler.schedule(wait=0, warmup=0, active=1, repeat=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler('./log_test/'),
            profile_memory=True,
            with_stack=True,
            record_shapes=True,
        ) as prof:
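            def test_func1(a, b):
                # Small differentiable loss built from the two parameters.
                with torch.enable_grad():
                    c = a * 2
                    d = torch.matmul(b, c)
                    return torch.sum(a + d)

            # MetaAdam performs a differentiable update: after each step the
            # module's parameters are replaced by new, non-leaf tensors.
            optimizer = torchopt.MetaAdam(self, lr=0.1)
            for _ in range(self.test_time):
                loss = test_func1(self.a, self.b)
                optimizer.step(loss)
                # Peak GPU memory allocated so far, in bytes.
                print(torch.cuda.max_memory_allocated())

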
def main():
    a = test_nn()
    a.forward()
    
if __name__ == "__main__":
    main()

Command lines:

python test.py

Traceback

Current output (peak allocated GPU memory in bytes, printed after each step):
62526464
90054144
106309632
122827264
138558464
155600384
171331584
187587072
204104704
219835904

Expected behavior

It should be:
57019392
60951552
61737984
63179776
63179776
63179776
63179776
63179776
63179776
63179776
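
To put numbers on the two logs, the per-step growth of the peak can be computed directly from the values above. This is only a worked check on the reported numbers; `per_step_growth` and `param_bytes` are illustrative names, and `param_bytes` assumes float32 storage for `a` and `b`, as in the example.

```python
# Peak allocated GPU memory in bytes after each of the 10 steps, copied from
# the "current" and "should be" logs in this report.
current = [62526464, 90054144, 106309632, 122827264, 138558464,
           155600384, 171331584, 187587072, 204104704, 219835904]
expected = [57019392, 60951552, 61737984, 63179776, 63179776,
            63179776, 63179776, 63179776, 63179776, 63179776]

MIB = 1024 * 1024

# Bytes held by the parameters a (768x512) and b (768x768), assuming float32.
param_bytes = (768 * 512 + 768 * 768) * 4  # 3.75 MiB


def per_step_growth(samples):
    """Increase of the peak between consecutive steps, in bytes."""
    return [later - earlier for earlier, later in zip(samples, samples[1:])]


for name, samples in (("current", current), ("expected", expected)):
    growth = [f"{delta / MIB:.2f}" for delta in per_step_growth(samples)]
    print(name, "growth per step (MiB):", growth)
```

From the second step on, the current run's peak rises by about 15.0 to 16.25 MiB per step, roughly four times the 3.75 MiB held by `a` and `b`, while the expected run stops growing after the fourth reading.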

Additional context

No response

Metadata

Assignees

Labels

question (Further information is requested)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests