I attempted to use the following script for batch image generation:
import SimpleITK as sitk
import torch
import os
import numpy as np
from utils import trim_state_dict_name
from matplotlib import pyplot as plt
import nibabel as nib

latent_dim = 1024
save_step = 80000
batch_size = 1
img_size = 256
num_class = 0
exp_name = "HA_GAN_run1"
num_images = 3039  # Number of images to generate in a loop

if img_size == 256:
    from models.Model_HA_GAN_256 import Generator, Encoder, Sub_Encoder
elif img_size == 128:
    from models.Model_HA_GAN_128 import Generator, Encoder, Sub_Encoder

G = Generator(mode='eval', latent_dim=latent_dim, num_class=num_class).cuda()
E = Encoder().cuda()
Sub_E = Sub_Encoder(latent_dim=latent_dim).cuda()

# ----------------------
# Load Generator weights
ckpt_path = "/HA-GAN/GSP_HA_GAN_pretrained/G_iter80000.pth"
ckpt = torch.load(ckpt_path)['model']
ckpt = trim_state_dict_name(ckpt)
G.load_state_dict(ckpt)

# Load Encoder weights
ckpt_path = "/HA-GAN/GSP_HA_GAN_pretrained/E_iter80000.pth"
ckpt = torch.load(ckpt_path)['model']
ckpt = trim_state_dict_name(ckpt)
E.load_state_dict(ckpt)

# Load Sub_Encoder weights
ckpt_path = "/HA-GAN/GSP_HA_GAN_pretrained/Sub_E_iter80000.pth"
ckpt = torch.load(ckpt_path)['model']
ckpt = trim_state_dict_name(ckpt)
Sub_E.load_state_dict(ckpt)

print(exp_name, save_step, "step weights loaded.")

outpath = "/HA-GAN/GSP_HA_GAN_images"
os.makedirs(outpath, exist_ok=True)

G.eval()
E.eval()
Sub_E.eval()
torch.cuda.empty_cache()

# ----------------------
# Loop to generate multiple images and save them
low_threshold = -1024
high_threshold = 600

with torch.no_grad():
    for i in range(num_images):
        z_rand = torch.randn((batch_size, latent_dim)).cuda()
        x_rand = G(z_rand, 0)
        x_rand = x_rand.detach().cpu().numpy()
        # Map the generated output from [-1, 1] to [0, 1]; adjust this step based on your model's output
        x_rand = 0.5 * x_rand + 0.5
        # For batch_size=1, take the first image and channel 0
        x_rand = x_rand[0, 0, :, :, :]
        # Map to the typical CT intensity range [low_threshold, high_threshold]
        x_rand_nifti = x_rand * (high_threshold - low_threshold) + low_threshold
        x_rand_nifti = x_rand_nifti.astype(np.int16)
        # Transpose and encapsulate in NIfTI format
        x_rand_nifti = nib.Nifti1Image(x_rand_nifti.transpose((2, 1, 0)), affine=np.eye(4))
        # Construct output filename with loop index i
        out_filename = os.path.join(outpath, f"x_rand_nifti_{i}.nii.gz")
        nib.save(x_rand_nifti, out_filename)
        print(f"Saved {out_filename}")
However, when loading the GSP_HA_GAN pretrained weights, I encountered the following error:
Traceback (most recent call last):
File "inference_gsp.py", line 31, in <module>
ckpt = torch.load(ckpt_path)['model']
File ".../torch/serialization.py", line 593, in load
return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
File ".../torch/serialization.py", line 747, in _legacy_load
return legacy_load(f)
File ".../torch/serialization.py", line 672, in legacy_load
tar.extract('storages', path=tmpdir)
File ".../tarfile.py", line 2060, in extract
tarinfo = self.getmember(member)
File ".../tarfile.py", line 1782, in getmember
raise KeyError("filename %r not found" % name)
KeyError: "filename 'storages' not found"
From what I found online, this error typically indicates that the weight file is corrupted or was not downloaded completely.
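Supporting that diagnosis: PyTorch >= 1.6 saves checkpoints as zip archives, and the traceback above shows torch falling back to the legacy tar-based loader, which suggests the file on disk is not a valid checkpoint at all. A quick sanity check like the following (a minimal sketch using only the standard library and the same ckpt_path as in the script above) can distinguish a truncated download or a Git LFS pointer file from a real checkpoint:

import os
import zipfile

ckpt_path = "/HA-GAN/GSP_HA_GAN_pretrained/G_iter80000.pth"

# A real checkpoint should be tens to hundreds of MB;
# a Git LFS pointer file is only about 130 bytes of text.
print("size (bytes):", os.path.getsize(ckpt_path))

# torch.save on PyTorch >= 1.6 writes a zip archive by default; if this
# prints False, torch.load falls back to the legacy loader seen above.
print("zip-format checkpoint:", zipfile.is_zipfile(ckpt_path))

# An LFS pointer file begins with an ASCII header instead of binary data:
# "version https://git-lfs.github.com/spec/v1"
with open(ckpt_path, "rb") as f:
    print("first bytes:", f.read(48))

If the size and header look normal, the corruption would be in the archive itself and a re-upload would be needed.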
Could you please update the repository with a new, valid version of the pretrained weight files? Your assistance would be greatly appreciated.
Thank you very much!