(tests pass) Substantial refactor to clean up main file
Sam Sepiol committed Oct 1, 2021
1 parent ef51825 commit 9e40100
Showing 6 changed files with 232 additions and 247 deletions.
4 changes: 1 addition & 3 deletions cgd/__init__.py
@@ -1,3 +1 @@
-import sys
-sys.path.append('./cgd')
-from . import *
+# __init__.py for cgd package
151 changes: 33 additions & 118 deletions cgd/cgd.py
@@ -1,108 +1,21 @@
import argparse
import os
import sys
import time
from pathlib import Path

import clip
import lpips
import torch as th
import wandb
from PIL import Image
from torch.nn.functional import normalize
from torchvision import transforms as T
from torchvision.transforms import functional as tvf
from torchvision.transforms.transforms import ToTensor
from tqdm.auto import tqdm

from cgd.clip_util import (CLIP_MODEL_NAMES, CLIP_NORMALIZE, MakeCutouts,
load_clip)
from cgd.loss_util import spherical_dist_loss, tv_loss
from cgd.util import (CACHE_PATH, create_gif, download_guided_diffusion, fetch,
load_guided_diffusion, log_image)

sys.path.append(os.path.join(os.getcwd(), "guided-diffusion"))

TIMESTEP_RESPACINGS = ("25", "50", "100", "250", "500", "1000",
"ddim25", "ddim50", "ddim100", "ddim250", "ddim500", "ddim1000")
DIFFUSION_SCHEDULES = (25, 50, 100, 250, 500, 1000)
IMAGE_SIZES = (64, 128, 256, 512)


def check_parameters(
prompts: list,
image_prompts: list,
image_size: int,
timestep_respacing: str,
diffusion_steps: int,
clip_model_name: str,
save_frequency: int,
noise_schedule: str,
):
if not (len(prompts) > 0 or len(image_prompts) > 0):
raise ValueError("Must provide at least one prompt, text or image.")
if not (noise_schedule in ['linear', 'cosine']):
raise ValueError('Noise schedule should be one of: linear, cosine')
if not (image_size in IMAGE_SIZES):
raise ValueError(f"--image size should be one of {IMAGE_SIZES}")
if not (0 < save_frequency <= int(timestep_respacing.replace('ddim', ''))):
raise ValueError(
"--save_frequency must be greater than 0 and less than `timestep_respacing`")
if not (diffusion_steps in DIFFUSION_SCHEDULES):
print('(warning) Diffusion steps should be one of:', DIFFUSION_SCHEDULES)
if not (timestep_respacing in TIMESTEP_RESPACINGS):
print(
f"Pausing run. `timestep_respacing` should be one of {TIMESTEP_RESPACINGS}. CTRL-C if this was a mistake.")
time.sleep(5)
print("Resuming run.")
if clip_model_name.endswith('.pt') or clip_model_name.endswith('.pth'):
assert os.path.isfile(
clip_model_name), f"{clip_model_name} does not exist"
print(f"Loading custom model from {clip_model_name}")
elif not (clip_model_name in CLIP_MODEL_NAMES):
print(
f"--clip model name should be one of: {CLIP_MODEL_NAMES} unless you are trying to use your own checkpoint.")
print(f"Loading OpenAI CLIP - {clip_model_name}")
from cgd import losses
from cgd import clip_util
from cgd import script_util


# Define necessary functions

def parse_prompt(prompt): # parse a single prompt in the form "<text||img_url>:<weight>"
if prompt.startswith('http://') or prompt.startswith('https://'):
vals = prompt.rsplit(':', 2) # there are two colons, so we grab the 2nd
vals = [vals[0] + ':' + vals[1], *vals[2:]]
else:
vals = prompt.rsplit(':', 1) # grab weight after colon
vals = vals + ['', '1'][len(vals):] # if no weight, use 1
return vals[0], float(vals[1]) # return text, weight
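# e.g. parse_prompt("a red fox:0.5") -> ("a red fox", 0.5)
# parse_prompt("https://example.com/fox.png:2") -> ("https://example.com/fox.png", 2.0)
# parse_prompt("a red fox") -> ("a red fox", 1.0), since the weight defaults to 1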


def encode_text_prompt(txt, weight, clip_model_name="ViT-B/32", device="cpu"):
clip_model, _ = load_clip(clip_model_name, device)
txt_tokens = clip.tokenize(txt).to(device)
txt_encoded = clip_model.encode_text(txt_tokens).float()
return txt_encoded, weight


def encode_image_prompt(image: str, weight: float, diffusion_size: int, num_cutouts, clip_model_name: str = "ViT-B/32", device: str = "cpu"):
clip_model, clip_size = load_clip(clip_model_name, device)
make_cutouts = MakeCutouts(cut_size=clip_size, num_cutouts=num_cutouts)
pil_img = Image.open(fetch(image)).convert('RGB')
smallest_side = min(diffusion_size, *pil_img.size)
# You can ignore the type warning here: torchvision's resize has an
# incorrect type hint, but it does accept PIL.Image inputs.
pil_img = tvf.resize(pil_img, [smallest_side],
tvf.InterpolationMode.LANCZOS)
batch = make_cutouts(tvf.to_tensor(pil_img).unsqueeze(0).to(device))
batch_embed = clip_model.encode_image(normalize(batch)).float()
batch_weight = [weight / make_cutouts.cutn] * make_cutouts.cutn
return batch_embed, batch_weight
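# Note: the weight is divided evenly across the cutout batch, so the image
# prompt's total contribution is independent of num_cutouts.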


def range_loss(input):
return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3])
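# range_loss penalizes pixel values outside [-1, 1]: it is zero wherever the
# input is already in range and grows quadratically with the overshoot.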


def clip_guided_diffusion(
image_size: int = 128,
num_cutouts: int = 16,
@@ -121,7 +34,7 @@ def clip_guided_diffusion(
seed: int = 0,
diffusion_steps: int = 1000,
skip_timesteps: int = 0,
checkpoints_dir: str = CACHE_PATH,
checkpoints_dir: str = script_util.CACHE_PATH,
clip_model_name: str = "ViT-B/32",
randomize_class: bool = True,
prefix_path: str = 'outputs',
@@ -133,7 +46,6 @@ def clip_guided_diffusion(
wandb_entity: str = None,
progress: bool = True,
):

if len(device) == 0:
device = 'cuda' if th.cuda.is_available() else 'cpu'
print(
@@ -142,39 +54,43 @@ def clip_guided_diffusion(
print(f"Using device {device}")
fp32_diffusion = (device == 'cpu')

wandb_run = None
if wandb_project is not None:
# just use local vars for config
wandb_run = wandb.init(project=wandb_project, entity=wandb_entity, config=locals())
wandb_run = wandb.init(project=wandb_project,
entity=wandb_entity, config=locals())
else:
print(f"--wandb_project not specified. Skipping W&B integration.")

if seed:
th.manual_seed(seed)
th.manual_seed(seed)

# only use magnitude for low timestep_respacing
use_magnitude = (int(timestep_respacing.replace("ddim", "")) <= 25 or image_size == 64)
use_magnitude = (int(timestep_respacing.replace(
"ddim", "")) <= 25 or image_size == 64)
# only use saturation loss on ddim
use_saturation = ("ddim" in timestep_respacing or image_size == 64)

Path(prefix_path).mkdir(parents=True, exist_ok=True)
Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)

diffusion_path = download_guided_diffusion(
diffusion_path = script_util.download_guided_diffusion(
image_size=image_size, checkpoints_dir=checkpoints_dir, class_cond=class_cond)

# Load CLIP model/Encode text/Create `MakeCutouts`
embeds_list = []
weights_list = []
clip_model, clip_size = load_clip(clip_model_name, device)
clip_model, clip_size = clip_util.load_clip(clip_model_name, device)

for prompt in prompts:
text, weight = parse_prompt(prompt)
text, weight = encode_text_prompt(
text, weight = script_util.parse_prompt(prompt)
text, weight = clip_util.encode_text_prompt(
text, weight, clip_model_name, device)
embeds_list.append(text)
weights_list.append(weight)

for image_prompt in image_prompts:
img, weight = parse_prompt(image_prompt)
image_prompt, batched_weight = encode_image_prompt(
img, weight = script_util.parse_prompt(image_prompt)
image_prompt, batched_weight = clip_util.encode_image_prompt(
img, weight, image_size, num_cutouts=num_cutouts, clip_model_name=clip_model_name, device=device)
embeds_list.append(image_prompt)
weights_list.extend(batched_weight)
@@ -191,13 +107,13 @@ def clip_guided_diffusion(
if use_augs:
tqdm.write(
f"Using augmentations to improve performance for lower timestep_respacing of {timestep_respacing}")
make_cutouts = MakeCutouts(cut_size=clip_size, num_cutouts=num_cutouts,
cutout_size_power=cutout_power, use_augs=use_augs)
make_cutouts = clip_util.MakeCutouts(cut_size=clip_size, num_cutouts=num_cutouts,
cutout_size_power=cutout_power, use_augs=use_augs)

# Load initial image (if provided)
init_tensor = None
if len(init_image) > 0:
pil_image = Image.open(fetch(init_image)).convert(
pil_image = Image.open(script_util.fetch(init_image)).convert(
"RGB").resize((image_size, image_size), Image.LANCZOS)
init_tensor = ToTensor()(pil_image).to(device).unsqueeze(0).mul(2).sub(1)
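# ToTensor gives values in [0, 1]; mul(2).sub(1) rescales them to the
# [-1, 1] range the diffusion model works in.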

@@ -208,7 +124,7 @@ def clip_guided_diffusion(
[batch_size], device=device, dtype=th.long)

# Load guided diffusion
gd_model, diffusion = load_guided_diffusion(
gd_model, diffusion = script_util.load_guided_diffusion(
checkpoint_path=diffusion_path,
image_size=image_size, class_cond=class_cond,
diffusion_steps=diffusion_steps,
Expand All @@ -229,7 +145,6 @@ def cond_fn(x, t, out, y=None):
fac = diffusion.sqrt_one_minus_alphas_cumprod[current_timestep]
sigmas = 1 - fac
x_in = out["pred_xstart"] * fac + x * sigmas
wandb_run = None
if wandb_project is not None:
log['Generations'] = [
wandb.Image(x, caption=f"Noisy Sample"),
Expand All @@ -238,16 +153,16 @@ def cond_fn(x, t, out, y=None):
wandb.Image(x_in, caption=f"Blended (what CLIP sees)"),
]

clip_in = CLIP_NORMALIZE(make_cutouts(x_in.add(1).div(2)))
clip_in = clip_util.CLIP_NORMALIZE(make_cutouts(x_in.add(1).div(2)))
cutout_embeds = clip_model.encode_image(
clip_in).float().view([num_cutouts, n, -1])
dists = spherical_dist_loss(
dists = losses.spherical_dist_loss(
cutout_embeds.unsqueeze(0), target_embeds.unsqueeze(0))
dists = dists.view([num_cutouts, n, -1])

clip_losses = dists.mul(weights).sum(2).mean(0)
range_losses = range_loss(out["pred_xstart"])
tv_losses = tv_loss(x_in)
range_losses = losses.range_loss(out["pred_xstart"])
tv_losses = losses.tv_loss(x_in)
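# Three guidance terms: CLIP similarity (spherical distance on embeddings),
# a range penalty on out-of-bounds pixels, and total-variation smoothing.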

clip_losses = clip_losses.sum() * clip_guidance_scale
range_losses = range_losses.sum() * range_scale
@@ -282,7 +197,7 @@ def cond_fn(x, t, out, y=None):
if progress:
tqdm.write(
"\t".join([f"{k}: {v:.3f}" for k, v in log.items() if "loss" in k.lower()]))
if wandb_run is not None:
if wandb_project is not None:
wandb_run.log(log)
return final_loss

@@ -312,11 +227,11 @@ def cond_fn(x, t, out, y=None):
current_timestep -= 1
if step % save_frequency == 0 or current_timestep == -1:
for batch_idx, image_tensor in enumerate(sample["pred_xstart"]):
yield batch_idx, log_image(image_tensor, prefix_path, prompts, step, batch_idx)
yield batch_idx, script_util.log_image(image_tensor, prefix_path, prompts, step, batch_idx)
# if wandb_project is not None: wandb.log({"image": wandb.Image(image_tensor, caption="|".join(prompts))})

for batch_idx in range(batch_size):
create_gif(prefix_path, prompts, batch_idx)
script_util.create_gif(prefix_path, prompts, batch_idx)

except (RuntimeError, KeyboardInterrupt) as runtime_ex:
if "CUDA out of memory" in str(runtime_ex):
@@ -348,7 +263,7 @@ def main():
help="Number of timesteps to blend image for. CLIP guidance occurs after this.")
p.add_argument("--prefix", "-dir", default="outputs",
type=Path, help="output directory")
p.add_argument("--checkpoints_dir", "-ckpts", default=CACHE_PATH,
p.add_argument("--checkpoints_dir", "-ckpts", default=script_util.CACHE_PATH,
type=Path, help="Path subdirectory containing checkpoints.")
p.add_argument("--batch_size", "-bs", type=int,
default=1, help="the batch size")
@@ -373,7 +288,7 @@ def main():
p.add_argument("--cutout_power", "-cutpow", type=float,
default=1.0, help="Cutout size power")
p.add_argument("--clip_model", "-clip", type=str, default="ViT-B/32",
help=f"clip model name. Should be one of: {CLIP_MODEL_NAMES} or a checkpoint filename ending in `.pt`")
help=f"clip model name. Should be one of: {clip_util.CLIP_MODEL_NAMES} or a checkpoint filename ending in `.pt`")
p.add_argument("--uncond", "-uncond", action="store_true",
help='Use finetuned unconditional checkpoints from OpenAI (256px) and Katherine Crowson (512px)')
p.add_argument("--noise_schedule", "-sched", default='linear', type=str,
@@ -383,7 +298,7 @@ def main():
p.add_argument("--device", "-dev", default='', type=str,
help="Device to use. Either cpu or cuda.")
p.add_argument('--wandb_project', '-proj', default=None,
help='Name W&B will use when saving results.\ne.g. `--wandb_name "my_project"`')
help='Name W&B will use when saving results.\ne.g. `--wandb_project "my_project"`')
p.add_argument('--wandb_entity', '-ent', default=None,
help='(optional) Name of W&B team/entity to log to.')
p.add_argument('--quiet', '-q', action='store_true',
@@ -439,4 +354,4 @@ def main():


if __name__ == "__main__":
main()
main()