Skip to content

Commit

Permalink
custom clip fid features
Browse files Browse the repository at this point in the history
  • Loading branch information
GaParmar committed Sep 5, 2022
1 parent fb49969 commit 55ec168
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 18 deletions.
1 change: 0 additions & 1 deletion cleanfid/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def get_reference_statistics(name, res, mode="clean", model_name="inception_v3",
model_modifier = ""
else:
model_modifier = "_"+model_name

if metric == "FID":
rel_path = (f"{name}_{mode}{model_modifier}_{split}_{res}.npz").lower()
url = f"{base_url}/{rel_path}"
Expand Down
58 changes: 41 additions & 17 deletions cleanfid/fid.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,32 +288,39 @@ def compare_folders(fdir1, fdir2, feat_model, mode, num_workers=0,
"""
Test if a custom statistic exists
"""
def test_stats_exists(name, mode, metric="FID"):
def test_stats_exists(name, mode, model_name="inception_v3", metric="FID"):
stats_folder = os.path.join(os.path.dirname(cleanfid.__file__), "stats")
split, res = "custom", "na"
if model_name=="inception_v3":
model_modifier = ""
else:
model_modifier = "_"+model_name
if metric == "FID":
fname = f"{name}_{mode}_{split}_{res}.npz"
fname = f"{name}_{mode}{model_modifier}_{split}_{res}.npz"
elif metric == "KID":
fname = f"{name}_{mode}_{split}_{res}_kid.npz"
fname = f"{name}_{mode}{model_modifier}_{split}_{res}_kid.npz"
fpath = os.path.join(stats_folder, fname)
return os.path.exists(fpath)


"""
Remove the custom FID features from the stats folder
"""
def remove_custom_stats(name, mode="clean"):
def remove_custom_stats(name, mode="clean", model_name="inception_v3"):
stats_folder = os.path.join(os.path.dirname(cleanfid.__file__), "stats")
# remove the FID stats
split, res = "custom", "na"
outname = f"{name}_{mode}_{split}_{res}.npz"
outf = os.path.join(stats_folder, outname)
if model_name=="inception_v3":
model_modifier = ""
else:
model_modifier = "_"+model_name
outf = os.path.join(stats_folder, f"{name}_{mode}{model_modifier}_{split}_{res}.npz".lower())
if not os.path.exists(outf):
msg = f"The stats file {name} does not exist."
raise Exception(msg)
os.remove(outf)
# remove the KID stats
outf = os.path.join(stats_folder, f"{name}_{mode}_{split}_{res}_kid.npz")
outf = os.path.join(stats_folder, f"{name}_{mode}{model_modifier}_{split}_{res}_kid.npz")
if not os.path.exists(outf):
msg = f"The stats file {name} does not exist."
raise Exception(msg)
Expand All @@ -323,31 +330,48 @@ def remove_custom_stats(name, mode="clean"):
"""
Cache a custom dataset statistics file
"""
def make_custom_stats(name, fdir, num=None, mode="clean",
num_workers=0, batch_size=64, device=torch.device("cuda")):
def make_custom_stats(name, fdir, num=None, mode="clean", model_name="inception_v3",
num_workers=0, batch_size=64, device=torch.device("cuda"), verbose=True):
stats_folder = os.path.join(os.path.dirname(cleanfid.__file__), "stats")
os.makedirs(stats_folder, exist_ok=True)
split, res = "custom", "na"
outname = f"{name}_{mode}_{split}_{res}.npz".lower()
outf = os.path.join(stats_folder, outname)
if model_name=="inception_v3":
model_modifier = ""
else:
model_modifier = "_"+model_name
outf = os.path.join(stats_folder, f"{name}_{mode}{model_modifier}_{split}_{res}.npz".lower())
# if the custom stat file already exists
if os.path.exists(outf):
msg = f"The statistics file {name} already exists. "
msg += "Use remove_custom_stats function to delete it first."
raise Exception(msg)
if model_name=="inception_v3":
feat_model = build_feature_extractor(mode, device)
custom_fn_resize = None
custom_image_tranform = None
elif model_name=="clip_vit_b_32":
from cleanfid.clip_features import CLIP_fx, img_preprocess_clip
clip_fx = CLIP_fx("ViT-B/32")
feat_model = clip_fx
custom_fn_resize = img_preprocess_clip
custom_image_tranform = None
else:
raise ValueError(f"The entered model name - {model_name} was not recognized.")

feat_model = build_feature_extractor(mode, device)
fbname = os.path.basename(fdir)
# get all inception features for folder images
np_feats = get_folder_features(fdir, feat_model, num_workers=num_workers, num=num,
batch_size=batch_size, device=device,
mode=mode, description=f"custom stats: {fbname} : ")
batch_size=batch_size, device=device, verbose=verbose,
mode=mode, description=f"custom stats: {os.path.basename(fdir)} : ",
custom_image_tranform=custom_image_tranform,
custom_fn_resize=custom_fn_resize)

mu = np.mean(np_feats, axis=0)
sigma = np.cov(np_feats, rowvar=False)
print(f"saving custom FID stats to {outf}")
np.savez_compressed(outf, mu=mu, sigma=sigma)

# KID stats
outf = os.path.join(stats_folder, f"{name}_{mode}_{split}_{res}_kid.npz".lower())
outf = os.path.join(stats_folder, f"{name}_{mode}{model_modifier}_{split}_{res}_kid.npz".lower())
print(f"saving custom KID stats to {outf}")
np.savez_compressed(outf, feats=np_feats)

Expand Down Expand Up @@ -435,7 +459,7 @@ def compute_fid(fdir1=None, fdir2=None, gen=None,
dataset_res=1024, dataset_split="train", num_gen=50_000, z_dim=512,
custom_feat_extractor=None, verbose=True,
custom_image_tranform=None, custom_fn_resize=None):
# build the feature extractor based on the mode
# build the feature extractor based on the mode and the model to be used
if custom_feat_extractor is None and model_name=="inception_v3":
feat_model = build_feature_extractor(mode, device)
elif custom_feat_extractor is None and model_name=="clip_vit_b_32":
Expand Down

0 comments on commit 55ec168

Please sign in to comment.