diff --git a/cgd/__init__.py b/cgd/__init__.py
index 2e65126..b974282 100644
--- a/cgd/__init__.py
+++ b/cgd/__init__.py
@@ -1,3 +1 @@
-# __init__.py for cgd package
-import sys
-sys.path.append('./cgd')
\ No newline at end of file
+from . import *
\ No newline at end of file
diff --git a/cgd/cgd.py b/cgd/cgd.py
index fbdf615..059acf9 100644
--- a/cgd/cgd.py
+++ b/cgd/cgd.py
@@ -1,108 +1,21 @@
 import argparse
 import os
-import sys
 import time
 from pathlib import Path
 
-import clip
 import lpips
 import torch as th
 import wandb
 from PIL import Image
-from torch.nn.functional import normalize
-from torchvision import transforms as T
-from torchvision.transforms import functional as tvf
 from torchvision.transforms.transforms import ToTensor
 from tqdm.auto import tqdm
-
-from cgd.clip_util import (CLIP_MODEL_NAMES, CLIP_NORMALIZE, MakeCutouts,
-                           load_clip)
-from cgd.loss_util import spherical_dist_loss, tv_loss
-from cgd.util import (CACHE_PATH, create_gif, download_guided_diffusion, fetch,
-                      load_guided_diffusion, log_image)
-
-sys.path.append(os.path.join(os.getcwd(), "guided-diffusion"))
-
-TIMESTEP_RESPACINGS = ("25", "50", "100", "250", "500", "1000",
-                       "ddim25", "ddim50", "ddim100", "ddim250", "ddim500", "ddim1000")
-DIFFUSION_SCHEDULES = (25, 50, 100, 250, 500, 1000)
-IMAGE_SIZES = (64, 128, 256, 512)
-
-
-def check_parameters(
-    prompts: list,
-    image_prompts: list,
-    image_size: int,
-    timestep_respacing: str,
-    diffusion_steps: int,
-    clip_model_name: str,
-    save_frequency: int,
-    noise_schedule: str,
-):
-    if not (len(prompts) > 0 or len(image_prompts) > 0):
-        raise ValueError("Must provide at least one prompt, text or image.")
-    if not (noise_schedule in ['linear', 'cosine']):
-        raise ValueError('Noise schedule should be one of: linear, cosine')
-    if not (image_size in IMAGE_SIZES):
-        raise ValueError(f"--image size should be one of {IMAGE_SIZES}")
-    if not (0 < save_frequency <= int(timestep_respacing.replace('ddim', ''))):
-        raise ValueError(
-            "--save_frequency must be greater than 0 and less than `timestep_respacing`")
-    if not (diffusion_steps in DIFFUSION_SCHEDULES):
-        print('(warning) Diffusion steps should be one of:', DIFFUSION_SCHEDULES)
-    if not (timestep_respacing in TIMESTEP_RESPACINGS):
-        print(
-            f"Pausing run. `timestep_respacing` should be one of {TIMESTEP_RESPACINGS}. CTRL-C if this was a mistake.")
-        time.sleep(5)
-        print("Resuming run.")
-    if clip_model_name.endswith('.pt') or clip_model_name.endswith('.pth'):
-        assert os.path.isfile(
-            clip_model_name), f"{clip_model_name} does not exist"
-        print(f"Loading custom model from {clip_model_name}")
-    elif not (clip_model_name in CLIP_MODEL_NAMES):
-        print(
-            f"--clip model name should be one of: {CLIP_MODEL_NAMES} unless you are trying to use your own checkpoint.")
-        print(f"Loading OpenAI CLIP - {clip_model_name}")
+from cgd import losses
+from cgd import clip_util
+from cgd import script_util
 
 # Define necessary functions
-def parse_prompt(prompt):  # parse a single prompt in the form "<text>:<weight>"
-    if prompt.startswith('http://') or prompt.startswith('https://'):
-        vals = prompt.rsplit(':', 2)  # theres two colons, so we grab the 2nd
-        vals = [vals[0] + ':' + vals[1], *vals[2:]]
-    else:
-        vals = prompt.rsplit(':', 1)  # grab weight after colon
-        vals = vals + ['', '1'][len(vals):]  # if no weight, use 1
-    return vals[0], float(vals[1])  # return text, weight
-
-
-def encode_text_prompt(txt, weight, clip_model_name="ViT-B/32", device="cpu"):
-    clip_model, _ = load_clip(clip_model_name, device)
-    txt_tokens = clip.tokenize(txt).to(device)
-    txt_encoded = clip_model.encode_text(txt_tokens).float()
-    return txt_encoded, weight
-
-
-def encode_image_prompt(image: str, weight: float, diffusion_size: int, num_cutouts, clip_model_name: str = "ViT-B/32", device: str = "cpu"):
-    clip_model, clip_size = load_clip(clip_model_name, device)
-    make_cutouts = MakeCutouts(cut_size=clip_size, num_cutouts=num_cutouts)
-    pil_img = Image.open(fetch(image)).convert('RGB')
-    smallest_side = min(diffusion_size, *pil_img.size)
-    # You can ignore the type warning caused by pytorch resize having
-    # an incorrect type hint for their resize signature. which does indeed support PIL.Image
-    pil_img = tvf.resize(pil_img, [smallest_side],
-                         tvf.InterpolationMode.LANCZOS)
-    batch = make_cutouts(tvf.to_tensor(pil_img).unsqueeze(0).to(device))
-    batch_embed = clip_model.encode_image(normalize(batch)).float()
-    batch_weight = [weight / make_cutouts.cutn] * make_cutouts.cutn
-    return batch_embed, batch_weight
-
-
-def range_loss(input):
-    return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3])
-
-
 def clip_guided_diffusion(
     image_size: int = 128,
     num_cutouts: int = 16,
@@ -121,7 +34,7 @@ def clip_guided_diffusion(
     seed: int = 0,
     diffusion_steps: int = 1000,
     skip_timesteps: int = 0,
-    checkpoints_dir: str = CACHE_PATH,
+    checkpoints_dir: str = script_util.CACHE_PATH,
     clip_model_name: str = "ViT-B/32",
     randomize_class: bool = True,
     prefix_path: str = 'outputs',
@@ -133,7 +46,6 @@ def clip_guided_diffusion(
     wandb_entity: str = None,
     progress: bool = True,
 ):
-
     if len(device) == 0:
         device = 'cuda' if th.cuda.is_available() else 'cpu'
         print(
@@ -142,39 +54,43 @@ def clip_guided_diffusion(
     print(f"Using device {device}")
     fp32_diffusion = (device == 'cpu')
 
+    wandb_run = None
     if wandb_project is not None:
         # just use local vars for config
-        wandb_run = wandb.init(project=wandb_project, entity=wandb_entity, config=locals())
+        wandb_run = wandb.init(project=wandb_project,
+                               entity=wandb_entity, config=locals())
+    else:
+        print(f"--wandb_project not specified. Skipping W&B integration.")
 
-    if seed:
-        th.manual_seed(seed)
+    th.manual_seed(seed)
 
     # only use magnitude for low timestep_respacing
-    use_magnitude = (int(timestep_respacing.replace("ddim", "")) <= 25 or image_size == 64)
+    use_magnitude = (int(timestep_respacing.replace(
+        "ddim", "")) <= 25 or image_size == 64)
     # only use saturation loss on ddim
     use_saturation = ("ddim" in timestep_respacing or image_size == 64)
 
     Path(prefix_path).mkdir(parents=True, exist_ok=True)
     Path(checkpoints_dir).mkdir(parents=True, exist_ok=True)
 
-    diffusion_path = download_guided_diffusion(
+    diffusion_path = script_util.download_guided_diffusion(
        image_size=image_size, checkpoints_dir=checkpoints_dir, class_cond=class_cond)
 
     # Load CLIP model/Encode text/Create `MakeCutouts`
     embeds_list = []
     weights_list = []
-    clip_model, clip_size = load_clip(clip_model_name, device)
+    clip_model, clip_size = clip_util.load_clip(clip_model_name, device)
 
     for prompt in prompts:
-        text, weight = parse_prompt(prompt)
-        text, weight = encode_text_prompt(
+        text, weight = script_util.parse_prompt(prompt)
+        text, weight = clip_util.encode_text_prompt(
             text, weight, clip_model_name, device)
         embeds_list.append(text)
         weights_list.append(weight)
 
     for image_prompt in image_prompts:
-        img, weight = parse_prompt(image_prompt)
-        image_prompt, batched_weight = encode_image_prompt(
+        img, weight = script_util.parse_prompt(image_prompt)
+        image_prompt, batched_weight = clip_util.encode_image_prompt(
             img, weight, image_size, num_cutouts=num_cutouts, clip_model_name=clip_model_name, device=device)
         embeds_list.append(image_prompt)
         weights_list.extend(batched_weight)
@@ -191,13 +107,13 @@ def clip_guided_diffusion(
     if use_augs:
         tqdm.write(
             f"Using augmentations to improve performance for lower timestep_respacing of {timestep_respacing}")
-    make_cutouts = MakeCutouts(cut_size=clip_size, num_cutouts=num_cutouts,
-                               cutout_size_power=cutout_power, use_augs=use_augs)
+    make_cutouts = clip_util.MakeCutouts(cut_size=clip_size, num_cutouts=num_cutouts,
+                                         cutout_size_power=cutout_power, use_augs=use_augs)
 
     # Load initial image (if provided)
     init_tensor = None
     if len(init_image) > 0:
-        pil_image = Image.open(fetch(init_image)).convert(
+        pil_image = Image.open(script_util.fetch(init_image)).convert(
             "RGB").resize((image_size, image_size), Image.LANCZOS)
         init_tensor = ToTensor()(pil_image).to(device).unsqueeze(0).mul(2).sub(1)
 
@@ -208,7 +124,7 @@ def clip_guided_diffusion(
         [batch_size], device=device, dtype=th.long)
 
     # Load guided diffusion
-    gd_model, diffusion = load_guided_diffusion(
+    gd_model, diffusion = script_util.load_guided_diffusion(
         checkpoint_path=diffusion_path,
         image_size=image_size,
         class_cond=class_cond,
         diffusion_steps=diffusion_steps,
@@ -229,7 +145,6 @@ def cond_fn(x, t, out, y=None):
         fac = diffusion.sqrt_one_minus_alphas_cumprod[current_timestep]
         sigmas = 1 - fac
         x_in = out["pred_xstart"] * fac + x * sigmas
-        wandb_run = None
         if wandb_project is not None:
            log['Generations'] = [
                wandb.Image(x, caption=f"Noisy Sample"),
@@ -238,16 +153,16 @@ def cond_fn(x, t, out, y=None):
                wandb.Image(x_in, caption=f"Blended (what CLIP sees)"),
            ]
 
-        clip_in = CLIP_NORMALIZE(make_cutouts(x_in.add(1).div(2)))
+        clip_in = clip_util.CLIP_NORMALIZE(make_cutouts(x_in.add(1).div(2)))
         cutout_embeds = clip_model.encode_image(
             clip_in).float().view([num_cutouts, n, -1])
-        dists = spherical_dist_loss(
+        dists = losses.spherical_dist_loss(
             cutout_embeds.unsqueeze(0), target_embeds.unsqueeze(0))
         dists = dists.view([num_cutouts, n, -1])
         clip_losses = dists.mul(weights).sum(2).mean(0)
 
-        range_losses = range_loss(out["pred_xstart"])
-        tv_losses = tv_loss(x_in)
+        range_losses = losses.range_loss(out["pred_xstart"])
+        tv_losses = losses.tv_loss(x_in)
 
         clip_losses = clip_losses.sum() * clip_guidance_scale
         range_losses = range_losses.sum() * range_scale
@@ -282,7 +197,7 @@ def cond_fn(x, t, out, y=None):
         if progress:
             tqdm.write(
                 "\t".join([f"{k}: {v:.3f}" for k, v in log.items() if "loss" in k.lower()]))
-        if wandb_run is not None:
+        if wandb_project is not None:
             wandb_run.log(log)
         return final_loss
 
@@ -312,11 +227,11 @@ def cond_fn(x, t, out, y=None):
             current_timestep -= 1
             if step % save_frequency == 0 or current_timestep == -1:
                 for batch_idx, image_tensor in enumerate(sample["pred_xstart"]):
-                    yield batch_idx, log_image(image_tensor, prefix_path, prompts, step, batch_idx)
+                    yield batch_idx, script_util.log_image(image_tensor, prefix_path, prompts, step, batch_idx)
                    # if wandb_project is not None: wandb.log({"image": wandb.Image(image_tensor, caption="|".join(prompts))})
 
        for batch_idx in range(batch_size):
-            create_gif(prefix_path, prompts, batch_idx)
+            script_util.create_gif(prefix_path, prompts, batch_idx)
 
    except (RuntimeError, KeyboardInterrupt) as runtime_ex:
        if "CUDA out of memory" in str(runtime_ex):
@@ -348,7 +263,7 @@ def main():
                   help="Number of timesteps to blend image for. CLIP guidance occurs after this.")
    p.add_argument("--prefix", "-dir", default="outputs", type=Path,
                   help="output directory")
-    p.add_argument("--checkpoints_dir", "-ckpts", default=CACHE_PATH,
+    p.add_argument("--checkpoints_dir", "-ckpts", default=script_util.CACHE_PATH,
                   type=Path, help="Path subdirectory containing checkpoints.")
    p.add_argument("--batch_size", "-bs", type=int, default=1,
                   help="the batch size")
@@ -373,7 +288,7 @@ def main():
    p.add_argument("--cutout_power", "-cutpow", type=float, default=1.0,
                   help="Cutout size power")
    p.add_argument("--clip_model", "-clip", type=str, default="ViT-B/32",
-                   help=f"clip model name. Should be one of: {CLIP_MODEL_NAMES} or a checkpoint filename ending in `.pt`")
+                   help=f"clip model name. Should be one of: {clip_util.CLIP_MODEL_NAMES} or a checkpoint filename ending in `.pt`")
    p.add_argument("--uncond", "-uncond", action="store_true",
                   help='Use finetuned unconditional checkpoints from OpenAI (256px) and Katherine Crowson (512px)')
    p.add_argument("--noise_schedule", "-sched", default='linear', type=str,
@@ -383,7 +298,7 @@ def main():
    p.add_argument("--device", "-dev", default='', type=str,
                   help="Device to use. Either cpu or cuda.")
    p.add_argument('--wandb_project', '-proj', default=None,
-                   help='Name W&B will use when saving results.\ne.g. `--wandb_name "my_project"`')
+                   help='Name W&B will use when saving results.\ne.g. `--wandb_project "my_project"`')
    p.add_argument('--wandb_entity', '-ent', default=None,
                   help='(optional) Name of W&B team/entity to log to.')
    p.add_argument('--quiet', '-q', action='store_true',
@@ -439,4 +354,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/cgd/clip_util.py b/cgd/clip_util.py
index 8b22f9f..74bc4e4 100644
--- a/cgd/clip_util.py
+++ b/cgd/clip_util.py
@@ -1,66 +1,40 @@
-from PIL import Image
-from torchvision.transforms.functional import to_tensor
-from cgd.util import resize_image, download
-import clip
 from functools import lru_cache
+
+import clip
 import torch as th
 import torch.nn.functional as tf
 import torchvision.transforms as tvt
+import torchvision.transforms.functional as tvf
 from data.imagenet1000_clsidx_to_labels import IMAGENET_CLASSES
+from PIL import Image
+
+from cgd import script_util
+from cgd.modules import MakeCutouts
+from cgd.ResizeRight import resize_right
+from cgd.ResizeRight.interp_methods import lanczos3
+CLIP_MODEL_NAMES = ("ViT-B/16", "ViT-B/32", "RN50",
+                    "RN101", "RN50x4", "RN50x16")
+CLIP_NORMALIZE = tvt.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[
+                               0.26862954, 0.26130258, 0.27577711])
 
-CLIP_MODEL_NAMES = ("ViT-B/16", "ViT-B/32", "RN50", "RN101", "RN50x4", "RN50x16")
-CLIP_NORMALIZE = tvt.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
 
 @lru_cache(maxsize=1)
 def load_clip(model_name='ViT-B/32', device="cpu"):
     print(f"Loading clip model\t{model_name}\ton device\t{device}.")
     if device == "cpu":
-        clip_model = clip.load(model_name, jit=False)[0].eval().to(device=device).float()
+        clip_model = clip.load(model_name, jit=False)[
+            0].eval().to(device=device).float()
         clip_size = clip_model.visual.input_resolution
         return clip_model, clip_size
     elif "cuda" in device:
-        clip_model = clip.load(model_name, jit=False)[0].eval().requires_grad_(False).to(device)
+        clip_model = clip.load(model_name, jit=False)[
+            0].eval().requires_grad_(False).to(device)
         clip_size = clip_model.visual.input_resolution
         return clip_model, clip_size
     else:
         raise ValueError("Invalid or unspecified device: {}".format(device))
 
 
-class MakeCutouts(th.nn.Module):
-    def __init__(self, cut_size: int, num_cutouts: int, cutout_size_power: float = 1.0, use_augs: bool = True):
-        super().__init__()
-        self.cut_size = cut_size
-        self.cutn = num_cutouts
-        self.cut_pow = cutout_size_power
-        custom_augs = []
-        if use_augs:
-            custom_augs = [
-                tvt.RandomHorizontalFlip(p=0.5),
-                tvt.Lambda(lambda x: x + th.randn_like(x) * 0.01),
-                tvt.RandomAffine(degrees=15, translate=(0.1, 0.1)),
-                tvt.Lambda(lambda x: x + th.randn_like(x) * 0.01),
-                tvt.RandomPerspective(distortion_scale=0.4, p=0.7),
-                tvt.Lambda(lambda x: x + th.randn_like(x) * 0.01),
-                tvt.RandomGrayscale(p=0.15),
-                tvt.Lambda(lambda x: x + th.randn_like(x) * 0.01),
-                tvt.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
-                tvt.Lambda(lambda x: x + th.randn_like(x) * 0.01),
-            ]  # TODO: test color jitter specifically
-        self.augs = tvt.Compose(custom_augs)
-
-    def forward(self, input: th.Tensor):
-        side_x, side_y = input.shape[2:4]
-        max_size = min(side_y, side_x)
-        min_size = min(side_y, side_x, self.cut_size)
-        cutouts = []
-        for _ in range(self.cutn):
-            size = int(th.rand([])**self.cut_pow * (max_size - min_size) + min_size)
-            offsetx = th.randint(0, side_x - size + 1, ())
-            offsety = th.randint(0, side_y - size + 1, ())
-            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
-            cutouts.append(tf.adaptive_avg_pool2d(cutout, self.cut_size))
-        return th.cat(cutouts)
-
 def imagenet_top_n(text_encodes, device: str = 'cuda', n: int = len(IMAGENET_CLASSES), clip_model_name: str = "ViT-B/32"):
     """
@@ -68,10 +42,36 @@ def imagenet_top_n(text_encodes, device: str = 'cuda', n: int = len(IMAGENET_CLA
     """
     clip_model, _ = load_clip(model_name=clip_model_name, device=device)
     with th.no_grad():
-        engineered_pronmpts = [f"an image of a {img_cls}" for img_cls in IMAGENET_CLASSES]
+        engineered_pronmpts = [
+            f"an image of a {img_cls}" for img_cls in IMAGENET_CLASSES]
         imagenet_lbl_tokens = clip.tokenize(engineered_pronmpts).to(device)
         imagenet_features = clip_model.encode_text(imagenet_lbl_tokens).float()
         imagenet_features /= imagenet_features.norm(dim=-1, keepdim=True)
-        prompt_features = text_encodes / text_encodes.norm(dim=-1, keepdim=True)
-        text_probs = (100.0 * prompt_features @ imagenet_features.T).softmax(dim=-1)
-        return text_probs.topk(n, dim=-1, sorted=True).indices[0].to(device)
\ No newline at end of file
+        prompt_features = text_encodes / \
+            text_encodes.norm(dim=-1, keepdim=True)
+        text_probs = (100.0 * prompt_features @
+                      imagenet_features.T).softmax(dim=-1)
+        return text_probs.topk(n, dim=-1, sorted=True).indices[0].to(device)
+
+
+def encode_image_prompt(image: str, weight: float, diffusion_size: int, num_cutouts, clip_model_name: str = "ViT-B/32", device: str = "cpu"):
+    clip_model, clip_size = load_clip(clip_model_name, device)
+    make_cutouts = MakeCutouts(cut_size=clip_size, num_cutouts=num_cutouts)
+    pil_img = Image.open(script_util.fetch(image)).convert('RGB')
+    smallest_side = min(diffusion_size, *pil_img.size)
+    pil_img = resize_right.resize(pil_img, scale_factors=None, out_shape=[smallest_side],
+                                  interp_method=lanczos3, support_sz=None,
+                                  antialiasing=True, by_convs=False, scale_tolerance=None,
+                                  max_denominator=10, pad_mode='constant')
+
+    batch = make_cutouts(tvf.to_tensor(pil_img).unsqueeze(0).to(device))
+    batch_embed = clip_model.encode_image(tf.normalize(batch)).float()
+    batch_weight = [weight / make_cutouts.cutn] * make_cutouts.cutn
+    return batch_embed, batch_weight
+
+
+def encode_text_prompt(txt, weight, clip_model_name="ViT-B/32", device="cpu"):
+    clip_model, _ = load_clip(clip_model_name, device)
+    txt_tokens = clip.tokenize(txt).to(device)
+    txt_encoded = clip_model.encode_text(txt_tokens).float()
+    return txt_encoded, weight
diff --git a/cgd/loss_util.py b/cgd/loss_util.py
deleted file mode 100644
index c51d7ea..0000000
--- a/cgd/loss_util.py
+++ /dev/null
@@ -1,21 +0,0 @@
-import torch as th
-from torch.nn import functional as tf
-
-
-def range_loss(input):
-    return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3])
-
-
-def spherical_dist_loss(x: th.Tensor, y: th.Tensor):
-    """(Katherine Crowson) - Spherical distance loss"""
-    x = tf.normalize(x, dim=-1)
-    y = tf.normalize(y, dim=-1)
-    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
-
-
-def tv_loss(input: th.Tensor):
-    """(Katherine Crowson) - L2 total variation loss, as in Mahendran et al."""
-    input = tf.pad(input, (0, 1, 0, 1), "replicate")
-    x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]
-    y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]
-    return (x_diff ** 2 + y_diff ** 2).mean([1, 2, 3])
diff --git a/cgd/modules.py b/cgd/modules.py
new file mode 100644
index 0000000..933993d
--- /dev/null
+++ b/cgd/modules.py
@@ -0,0 +1,42 @@
+import torch as th
+import torch.nn.functional as tf
+import torchvision.transforms as tvt
+
+
+class MakeCutouts(th.nn.Module):
+    def __init__(self, cut_size: int, num_cutouts: int, cutout_size_power: float = 1.0, use_augs: bool = False):
+        super().__init__()
+        self.cut_size = cut_size
+        self.cutn = num_cutouts
+        self.cut_pow = cutout_size_power
+        custom_augs = []
+        if use_augs:
+            custom_augs = [
+                tvt.RandomHorizontalFlip(p=0.5),
+                tvt.Lambda(lambda x: x + th.randn_like(x) * 0.01),
+                tvt.RandomAffine(degrees=15, translate=(0.1, 0.1)),
+                tvt.Lambda(lambda x: x + th.randn_like(x) * 0.01),
+                tvt.RandomPerspective(distortion_scale=0.4, p=0.7),
+                tvt.Lambda(lambda x: x + th.randn_like(x) * 0.01),
+                tvt.RandomGrayscale(p=0.15),
+                tvt.Lambda(lambda x: x + th.randn_like(x) * 0.01),
+                tvt.ColorJitter(brightness=0.1, contrast=0.1,
+                                saturation=0.1, hue=0.1),
+                tvt.Lambda(lambda x: x + th.randn_like(x) * 0.01),
+            ]  # TODO: test color jitter specifically
+        self.augs = tvt.Compose(custom_augs)
+
+    def forward(self, input: th.Tensor):
+        side_x, side_y = input.shape[2:4]
+        max_size = min(side_y, side_x)
+        min_size = min(side_y, side_x, self.cut_size)
+        cutouts = []
+        for _ in range(self.cutn):
+            size = int(th.rand([])**self.cut_pow *
+                       (max_size - min_size) + min_size)
+            offsetx = th.randint(0, side_x - size + 1, ())
+            offsety = th.randint(0, side_y - size + 1, ())
+            cutout = input[:, :, offsety:offsety +
+                           size, offsetx:offsetx + size]
+            cutouts.append(tf.adaptive_avg_pool2d(cutout, self.cut_size))
+        return th.cat(cutouts)
diff --git a/cgd/util.py b/cgd/script_util.py
similarity index 60%
rename from cgd/util.py
rename to cgd/script_util.py
index 2a2d68c..163a703 100644
--- a/cgd/util.py
+++ b/cgd/script_util.py
@@ -1,25 +1,117 @@
-from functools import lru_cache
 import glob
 import io
 import os
 import re
+import time
+from functools import lru_cache
 from pathlib import Path
-from typing import Tuple, Union
 from urllib import request
 
-from PIL import Image
 import requests
 import torch as th
-from guided_diffusion.script_util import (create_model_and_diffusion,
-                                          model_and_diffusion_defaults)
-from torch.nn import functional as tf
-from torchvision.transforms import functional as tvf
-from tqdm.autonotebook import tqdm
-
+import torchvision.transforms.functional as tvf
+from tqdm.auto import tqdm
 from data.diffusion_model_flags import DIFFUSION_LOOKUP
+from PIL import Image
+
+from cgd import clip_util
+from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults
 
 CACHE_PATH = os.path.expanduser("~/.cache/clip-guided-diffusion")
-ALPHANUMERIC_REGEX = r"[^\w\s]"
+TIMESTEP_RESPACINGS = ("25", "50", "100", "250", "500", "1000",
+                       "ddim25", "ddim50", "ddim100", "ddim250", "ddim500", "ddim1000")
+DIFFUSION_SCHEDULES = (25, 50, 100, 250, 500, 1000)
+IMAGE_SIZES = (64, 128, 256, 512)
+
+
+def check_parameters(
+    prompts: list,
+    image_prompts: list,
+    image_size: int,
+    timestep_respacing: str,
+    diffusion_steps: int,
+    clip_model_name: str,
+    save_frequency: int,
+    noise_schedule: str,
+):
+    if not (len(prompts) > 0 or len(image_prompts) > 0):
+        raise ValueError("Must provide at least one prompt, text or image.")
+    if not (noise_schedule in ['linear', 'cosine']):
+        raise ValueError('Noise schedule should be one of: linear, cosine')
+    if not (image_size in IMAGE_SIZES):
+        raise ValueError(f"--image size should be one of {IMAGE_SIZES}")
+    if not (0 < save_frequency <= int(timestep_respacing.replace('ddim', ''))):
+        raise ValueError(
+            "--save_frequency must be greater than 0 and less than `timestep_respacing`")
+    if not (diffusion_steps in DIFFUSION_SCHEDULES):
+        print('(warning) Diffusion steps should be one of:', DIFFUSION_SCHEDULES)
+    if not (timestep_respacing in TIMESTEP_RESPACINGS):
+        print(
+            f"Pausing run. `timestep_respacing` should be one of {TIMESTEP_RESPACINGS}. CTRL-C if this was a mistake.")
+        time.sleep(5)
+        print("Resuming run.")
+    if clip_model_name.endswith('.pt') or clip_model_name.endswith('.pth'):
+        assert os.path.isfile(
+            clip_model_name), f"{clip_model_name} does not exist"
+        print(f"Loading custom model from {clip_model_name}")
+    elif not (clip_model_name in clip_util.CLIP_MODEL_NAMES):
+        print(
+            f"--clip model name should be one of: {clip_util.CLIP_MODEL_NAMES} unless you are trying to use your own checkpoint.")
+        print(f"Loading OpenAI CLIP - {clip_model_name}")
+
+
+def parse_prompt(prompt):  # parse a single prompt in the form "<text>:<weight>"
+    if prompt.startswith('http://') or prompt.startswith('https://'):
+        vals = prompt.rsplit(':', 2)  # theres two colons, so we grab the 2nd
+        vals = [vals[0] + ':' + vals[1], *vals[2:]]
+    else:
+        vals = prompt.rsplit(':', 1)  # grab weight after colon
+        vals = vals + ['', '1'][len(vals):]  # if no weight, use 1
+    return vals[0], float(vals[1])  # return text, weight
+
+
+def fetch(url_or_path):
+    if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
+        r = requests.get(url_or_path)
+        r.raise_for_status()
+        fd = io.BytesIO()
+        fd.write(r.content)
+        fd.seek(0)
+        return fd
+    return open(url_or_path, 'rb')
+
+
+def alphanumeric_filter(s: str) -> str:
+    # regex to remove non-alphanumeric characters
+    ALPHANUMERIC_REGEX = r"[^\w\s]"
+    return re.sub(ALPHANUMERIC_REGEX, "", s).replace(" ", "_")
+
+
+def clean_and_combine_prompts(base_path, txts, batch_idx, max_length=255) -> str:
+    clean_txt = "_".join([alphanumeric_filter(txt)
+                         for txt in txts])[:max_length]
+    return os.path.join(base_path, clean_txt, f"{batch_idx:02}")
+
+
+def log_image(image: th.Tensor, base_path: str, txts: list, current_step: int, batch_idx: int) -> str:
+    dirname = clean_and_combine_prompts(base_path, txts, batch_idx)
+    os.makedirs(dirname, exist_ok=True)
+    stem = f"{current_step:04}"
+    filename = os.path.join(dirname, f'{stem}.png')
+    pil_image = tvf.to_pil_image(image.add(1).div(2).clamp(0, 1))
+    pil_image.save(os.path.join(os.getcwd(), f'current.png'))
+    pil_image.save(filename)
+    return str(filename)
+
+
+def create_gif(base, prompts, batch_idx):
+    io_safe_prompts = clean_and_combine_prompts(base, prompts, batch_idx)
+    images_glob = os.path.join(io_safe_prompts, "*.png")
+    imgs = [Image.open(f) for f in sorted(glob.glob(images_glob))]
+    gif_filename = f"{io_safe_prompts}_{batch_idx:02}.gif"
+    imgs[0].save(fp=gif_filename, format='GIF', append_images=imgs,
+                 save_all=True, duration=200, loop=0)
+    return gif_filename
+
 
 # modified from https://github.com/lucidrains/DALLE-pytorch/blob/d355100061911b13e1f1c22de8c2b5deb44e65f8/dalle_pytorch/vae.py
 def download(url: str, filename: str, root: str = CACHE_PATH) -> str:
@@ -47,11 +139,13 @@ def download_guided_diffusion(image_size: int, class_cond: bool, checkpoints_dir
     cond_key = 'cond' if class_cond else 'uncond'
     diffusion_model_info = DIFFUSION_LOOKUP[cond_key][image_size]
     if not overwrite:
-        target_path = Path(checkpoints_dir).joinpath(diffusion_model_info["filename"])
+        target_path = Path(checkpoints_dir).joinpath(
+            diffusion_model_info["filename"])
         if target_path.exists():
             return str(target_path)
     return download(diffusion_model_info["url"], diffusion_model_info["filename"], checkpoints_dir)
 
+
 @lru_cache(maxsize=1)
 def load_guided_diffusion(
     checkpoint_path: str,
@@ -71,8 +165,10 @@ def load_guided_diffusion(
     diffusion_steps: number of diffusion steps
     timestep_respacing: whether to use timestep-respacing or not
     '''
-    if not (len(device) > 0): raise ValueError("device must be set")
-    if not (noise_schedule in [ "linear", "cosine"]): raise ValueError("linear_or_cosine must be set")
+    if not (len(device) > 0):
+        raise ValueError("device must be set")
+    if not (noise_schedule in ["linear", "cosine"]):
+        raise ValueError("linear_or_cosine must be set")
 
     cond_key = 'cond' if class_cond else 'uncond'
     diffusion_model_info = DIFFUSION_LOOKUP[cond_key][image_size]
@@ -94,48 +190,3 @@ def load_guided_diffusion(
     if model_config["use_fp16"]:
         model.convert_to_fp16()
     return model.to(device), diffusion
-
-def fetch(url_or_path):
-    if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
-        r = requests.get(url_or_path)
-        r.raise_for_status()
-        fd = io.BytesIO()
-        fd.write(r.content)
-        fd.seek(0)
-        return fd
-    return open(url_or_path, 'rb')
-
-
-def alphanumeric_filter(s: str) -> str:
-    return re.sub(ALPHANUMERIC_REGEX, "", s).replace(" ", "_")
-
-def clean_and_combine_prompts(base_path, txts, batch_idx, max_length=255) -> str:
-    clean_txt = "_".join([alphanumeric_filter(txt) for txt in txts])[:max_length]
-    return os.path.join(base_path, clean_txt, f"{batch_idx:02}")
-
-def log_image(image: th.Tensor, base_path: str, txts: list, current_step: int, batch_idx: int) -> str:
-    dirname = clean_and_combine_prompts(base_path, txts, batch_idx)
-    os.makedirs(dirname, exist_ok=True)
-    stem = f"{current_step:04}"
-    filename = os.path.join(dirname, f'{stem}.png')
-    pil_image = tvf.to_pil_image(image.add(1).div(2).clamp(0, 1))
-    pil_image.save(os.path.join(os.getcwd(), f'current.png'))
-    pil_image.save(filename)
-    return str(filename)
-
-def create_gif(base, prompts, batch_idx):
-    io_safe_prompts = clean_and_combine_prompts(base, prompts, batch_idx)
-    images_glob = os.path.join(io_safe_prompts, "*.png")
-    imgs = [Image.open(f) for f in sorted(glob.glob(images_glob))]
-    gif_filename = f"{io_safe_prompts}_{batch_idx:02}.gif"
-    imgs[0].save(fp=gif_filename, format='GIF', append_images=imgs, save_all=True, duration=200, loop=0)
-    return gif_filename
-
-def resize_image(image: th.Tensor, out_size: Union[int, Tuple[int, int]]) -> th.Tensor:
-    """(Katherine Crowson) - Resize image"""
-    outsize_x = out_size if isinstance(out_size, int) else out_size[0]
-    outsize_y = out_size if isinstance(out_size, int) else out_size[1]
-    ratio = image.size(1) / image.size(1)
-    area = min(image.size(0) * image.size(1), outsize_x * outsize_y)
-    size = round((area * ratio)**0.5), round((area / ratio)**0.5)
-    return image.reshape(size)
\ No newline at end of file
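
Usage note (editor's sketch, not part of the patch): after this refactor the helpers live under cgd.script_util and cgd.clip_util, and cgd.cgd.clip_guided_diffusion remains the generator-style entry point. The snippet below is a minimal, hypothetical example of driving it from Python; it assumes the full signature (elided in the hunks above) still exposes prompts and timestep_respacing keyword arguments, as the function body implies, and the prompt text and values are illustrative only.

# Hypothetical usage sketch under the new package layout (assumptions noted above).
from cgd.cgd import clip_guided_diffusion

samples = clip_guided_diffusion(
    prompts=["an oil painting of a lighthouse:1.0"],  # "<text>:<weight>" form handled by script_util.parse_prompt
    image_size=128,
    num_cutouts=16,
    timestep_respacing="250",
    wandb_project=None,  # leaving this unset prints a notice and skips W&B logging
)
# The function yields (batch_idx, saved_png_path) every `save_frequency` steps
# and writes a per-batch GIF via script_util.create_gif when sampling finishes.
for batch_idx, image_path in samples:
    print(f"batch {batch_idx}: wrote {image_path}")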