Training code reproduction #187
Comments
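The problem above is solved. A follow-up question: how much memory does each of the 8 A100 GPUs mentioned in the paper have, and is batch_size 4 for the 8-GPU run? At the moment, vae.decode on a single A100-80G keeps running out of memory!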
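One common way to reduce peak memory in vae.decode is to decode the latents a few frames at a time instead of all at once. The sketch below is only an illustration under stated assumptions: `ToyDecoder` is a hypothetical stand-in for the real VAE decoder, `decode_in_chunks` and `chunk_size` are made-up names, and the decode is assumed to be used for validation or visualization only, so no gradients are needed. If the VAE is a diffusers `AutoencoderKL`, its `enable_slicing()` and `enable_tiling()` methods may serve a similar purpose.

```python
import torch
from torch import nn


class ToyDecoder(nn.Module):
    """Hypothetical stand-in for a VAE decoder: 4 latent channels -> 3 image channels, 8x upsampling."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Upsample(scale_factor=8, mode="nearest"),
            nn.Conv2d(4, 3, kernel_size=3, padding=1),
        )

    def forward(self, latents):
        return self.net(latents)


@torch.no_grad()  # assumption: decoding is for inspection only, so activations need not be kept
def decode_in_chunks(decode_fn, latents, chunk_size=4):
    """Decode latents of shape (N, C, H, W) chunk by chunk and collect the results on the CPU."""
    frames = []
    for start in range(0, latents.shape[0], chunk_size):
        chunk = latents[start:start + chunk_size]
        frames.append(decode_fn(chunk).cpu())  # move each decoded chunk off the GPU right away
    return torch.cat(frames, dim=0)


decoder = ToyDecoder().eval()
latents = torch.randn(10, 4, 8, 8)  # 10 frames of 8x8 latents
images = decode_in_chunks(decoder, latents, chunk_size=4)
print(images.shape)  # torch.Size([10, 3, 64, 64])
```

If the loss has to backpropagate through the decoder, `torch.no_grad()` cannot be used, and a smaller batch size or fewer frames per clip would be the first thing to try.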
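Hello, does the training code for this project have to be rebuilt from scratch?
Yes. You need to rewrite the entire training framework with reference to the open-source projects that were mentioned; only inference code is provided!
So the training framework follows AnimateAnyone and the data processing follows Hallo2, is that right? How good are the reproduced results?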
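Thanks for sharing!
1. I am reproducing the stage-2 training code with reference to Hallo2 and Moore-AnimateAnyone. Is the first argument of denoising_unet (the latent) the same as in those two open-source projects?
(1) First argument of self.denoising_unet: noisy_latents = train_noise_scheduler.add_noise(latents, noise, timesteps)
(2) Frozen weights: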
vae.requires_grad_(False)
denoising_unet.requires_grad_(False)
reference_unet.requires_grad_(False)
face_locator.requires_grad_(False)
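2. As in Hallo2, I compute the MSE loss between the noised original image and the value predicted by denoising_unet, but the loss becomes NaN from the second batch onward: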
mse_loss: tensor(0.2483, device='cuda:0', dtype=torch.float16, grad_fn=)
{'global_step:1, train_loss: 0.248291015625'}
mse_loss: tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=)
{'global_step:2, train_loss: nan'}
3. Which open-source project should I refer to for the implementation? Thanks again!
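The `dtype=torch.float16` in the log suggests the loss is being reduced in half precision, which is one possible (not confirmed) source of the NaN. The sketch below shows two generic mitigations: upcasting to float32 before the MSE reduction, and skipping the optimizer step when the loss is not finite. Everything here is an assumption for illustration: the `nn.Conv2d` is a hypothetical stand-in for whichever modules are actually trainable, `mse_loss_fp32` and `train_step` are made-up helper names, and the trainable weights are assumed to be kept in float32.

```python
import torch
import torch.nn.functional as F
from torch import nn


def mse_loss_fp32(model_pred, target):
    """Upcast both tensors to float32 before the reduction to avoid float16 overflow."""
    return F.mse_loss(model_pred.float(), target.float(), reduction="mean")


def train_step(model, optimizer, noisy_latents, target, max_grad_norm=1.0):
    """Run one optimizer step; return None and skip the update if the loss is not finite."""
    optimizer.zero_grad(set_to_none=True)
    loss = mse_loss_fp32(model(noisy_latents), target)
    if not torch.isfinite(loss):
        return None  # skip the update so NaN/Inf never reaches the weights
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
# Hypothetical stand-in for the trainable modules, with float32 weights.
model = nn.Conv2d(4, 4, kernel_size=3, padding=1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

noisy_latents = torch.randn(2, 4, 8, 8)
noise = torch.randn(2, 4, 8, 8)  # target for noise (epsilon) prediction

good = train_step(model, optimizer, noisy_latents, noise)
bad = train_step(model, optimizer, torch.full_like(noisy_latents, float("nan")), noise)
print(good is not None, bad is None)  # True True
```

Skipping a step only hides the symptom. If NaN appears at every step after the first, it is worth checking whether trainable weights were loaded in float16 and whether the learning rate is too high for the unfrozen modules.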