About memory usage #13

Open
Arksyd96 opened this issue Mar 14, 2023 · 3 comments

Comments

@Arksyd96

Hello, I'm having issues with memory usage.
Is it normal that even with 48 GB of VRAM I cannot run the reverse process for generation with a small batch size of 2?
What are your specs?

@w86763777
Owner

No, that is abnormal.
To train on CIFAR-10, an 11 GB card such as the 2080 Ti is sufficient. However, if you use a larger model, the VRAM requirements may increase.

@Arksyd96
Author

Arksyd96 commented Mar 15, 2023

Yeah, problem fixed. Actually I'm training on 1x128x128 BraTS images and I forgot to wrap the reverse process in torch.no_grad().
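
(For reference, a minimal sketch of that fix: wrapping the reverse process in torch.no_grad() so no activations are kept for a backward pass. The sampler name and the shapes below are placeholders, not code from this repository.)

    import torch

    device = "cuda"
    # Sampling only calls the model forward, so autograd bookkeeping is pure overhead;
    # without no_grad(), activations for all T reverse steps are kept and VRAM usage explodes.
    with torch.no_grad():
        x_T = torch.randn(2, 1, 128, 128, device=device)  # batch of 2 single-channel 128x128 images
        x_0 = sampler(x_T)                                 # reverse process (placeholder sampler module)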

However, I still have an issue with the reverse process. During training, the MSE loss is optimized well, but sampling only generates noise. Here's my sampling code if you want to take a look and tell me if it's OK:

    def q_mean_variance(self, x_0, x_t, t):
        posterior_mean = (
            self.posterior_mean_c1[t, None, None, None].to(device) * x_0 + 
            self.posterior_mean_c2[t, None, None, None].to(device) * x_t
        )
        posterior_log_var = self.posterior_log_var[t, None, None, None]
        return posterior_mean, posterior_log_var
    
    def p_mean_variance(self, x_t, t):
        model_logvar = torch.log(torch.cat([self.posterior_var[1: 2], self.betas[1:]])).to(device)
        model_logvar = model_logvar[t, None, None, None]

        eps = self.model(x_t, t.to(device))
        x_0 = self.predict_x_start_from_eps(x_t, t, eps)
        model_mean, _ = self.q_mean_variance(x_0, x_t, t)

        return model_mean, model_logvar
    
    def predict_x_start_from_eps(self, x_t, t, eps):
        return (
            torch.sqrt(1. - self.alpha_prods[t, None, None, None].to(device)) * x_t +
            torch.sqrt(1. / self.alpha_prods[t, None, None, None].to(device) - 1.) * eps
        )

    def forward(self, x_T):
        x_t = x_T
        for timestep in reversed(range(self.T)):
            t = torch.full((x_T.shape[0],), fill_value=timestep, dtype=torch.long)
            mean, logvar = self.p_mean_variance(x_t, t)
            if timestep > 0:
                noise = torch.randn_like(x_T)
            else:
                noise = 0
            x_t = mean + torch.exp(0.5 * logvar) * noise
        x_0 = x_t
        return torch.clip(x_0, -1, 1)
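
(For comparison, the standard DDPM recovery of x_0 from the predicted noise inverts x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, i.e. it scales x_t by sqrt(1 / alpha_bar_t) and subtracts the eps term. A minimal reference sketch, assuming alpha_prods stores the cumulative product of the alphas:)

    def predict_x_start_from_eps_reference(self, x_t, t, eps):
        # Invert x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps:
        #   x_0 = sqrt(1 / alpha_bar_t) * x_t - sqrt(1 / alpha_bar_t - 1) * eps
        alpha_bar = self.alpha_prods[t, None, None, None].to(device)
        return (
            torch.sqrt(1. / alpha_bar) * x_t -
            torch.sqrt(1. / alpha_bar - 1.) * eps
        )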

@w86763777
Owner

w86763777 commented Apr 13, 2023

Apologies for the delayed response.

To the best of my recollection, you do not need to update the GaussianDiffusionTrainer and GaussianDiffusionSampler when training with images of different sizes. These components are capable of adapting to different image dimensions.

However, you will need to modify the model and data-related code, including the UNet, dataset, and dataloader, to accommodate the new image sizes.
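
As an illustration of the data-side changes, below is a minimal sketch of a dataset and dataloader that yield 1x128x128 tensors scaled to [-1, 1]. The BraTSSlices class and the per-slice normalization are placeholders rather than code from this repository, and if the UNet's first and last convolutions assume 3-channel input, they would also need to be changed to 1 channel.

    import numpy as np
    import torch
    from torch.utils.data import DataLoader, Dataset

    class BraTSSlices(Dataset):
        """Placeholder dataset: yields single-channel 128x128 slices scaled to [-1, 1]."""
        def __init__(self, slices):
            self.slices = slices  # e.g. 2D numpy arrays extracted from BraTS volumes

        def __len__(self):
            return len(self.slices)

        def __getitem__(self, idx):
            x = torch.as_tensor(self.slices[idx], dtype=torch.float32).unsqueeze(0)  # (1, 128, 128)
            x = x / x.max().clamp(min=1e-8)  # per-slice scale to [0, 1]
            return x * 2. - 1.               # map to [-1, 1], as in the usual DDPM setup

    # Dummy data just to show the expected shapes.
    slices = [np.random.rand(128, 128).astype(np.float32) for _ in range(8)]
    loader = DataLoader(BraTSSlices(slices), batch_size=2, shuffle=True)
    x = next(iter(loader))
    print(x.shape)  # torch.Size([2, 1, 128, 128])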
