|
| 1 | +# Prediction interface for Cog ⚙️ |
| 2 | +# https://github.com/replicate/cog/blob/main/docs/python.md |
| 3 | + |
| 4 | +import os |
| 5 | +import sys |
| 6 | + |
| 7 | +import time |
| 8 | +import subprocess |
| 9 | +from cog import BasePredictor, Input, Path |
| 10 | + |
| 11 | +import cv2 |
| 12 | +import torch |
| 13 | +import numpy as np |
| 14 | +from PIL import Image |
| 15 | + |
| 16 | +from diffusers.utils import load_image |
| 17 | +from diffusers.models import ControlNetModel |
| 18 | + |
| 19 | +from insightface.app import FaceAnalysis |
| 20 | + |
| 21 | +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) |
| 22 | +from pipeline_stable_diffusion_xl_instantid import ( |
| 23 | + StableDiffusionXLInstantIDPipeline, |
| 24 | + draw_kps, |
| 25 | +) |
| 26 | + |
| 27 | +# for `ip-adapter`, `ControlNetModel`, and `stable-diffusion-xl-base-1.0` |
| 28 | +CHECKPOINTS_CACHE = "./checkpoints" |
| 29 | +CHECKPOINTS_URL = ( |
| 30 | + "https://weights.replicate.delivery/default/InstantID/checkpoints.tar" |
| 31 | +) |
| 32 | + |
| 33 | +# for `models/antelopev2` |
| 34 | +MODELS_CACHE = "./models" |
| 35 | +MODELS_URL = "https://weights.replicate.delivery/default/InstantID/models.tar" |
| 36 | + |
| 37 | + |
def resize_img(
    input_image,
    max_side=1280,
    min_side=1024,
    size=None,
    pad_to_max_side=False,
    mode=Image.BILINEAR,
    base_pixel_number=64,
):
    """Resize a PIL image and optionally center it on a white square canvas.

    If ``size`` is given, the image is resized directly to that
    ``(width, height)``. Otherwise the dimensions are first scaled by
    ``min_side / shorter_side``, then by ``max_side / longer_side``, and the
    result is floored to a multiple of ``base_pixel_number`` on each axis.

    Args:
        input_image: Image exposing ``.size`` and ``.resize`` (a PIL image).
        max_side: Target for the longer side; also the edge length of the
            padding canvas.
        min_side: Target for the shorter side in the first scaling step.
        size: Explicit ``(width, height)``; skips the ratio computation.
        pad_to_max_side: When true, paste the resized image centered on a
            white ``max_side`` x ``max_side`` canvas.
        mode: Resampling filter handed to ``resize``.
        base_pixel_number: Final width/height are multiples of this value.

    Returns:
        The resized (and possibly padded) image.
    """
    src_w, src_h = input_image.size

    if size is None:
        # Step 1: scale the dimensions so the shorter side matches min_side.
        first_ratio = min_side / min(src_h, src_w)
        scaled_w = round(first_ratio * src_w)
        scaled_h = round(first_ratio * src_h)

        # Step 2: rescale so the longer side matches max_side.
        second_ratio = max_side / max(scaled_h, scaled_w)
        fitted_w = round(second_ratio * scaled_w)
        fitted_h = round(second_ratio * scaled_h)
        input_image = input_image.resize([fitted_w, fitted_h], mode)

        # Step 3: floor both sides to a multiple of base_pixel_number.
        target_w = (fitted_w // base_pixel_number) * base_pixel_number
        target_h = (fitted_h // base_pixel_number) * base_pixel_number
    else:
        target_w, target_h = size

    input_image = input_image.resize([target_w, target_h], mode)

    if not pad_to_max_side:
        return input_image

    # White square canvas with the resized image pasted in the middle.
    canvas = np.full([max_side, max_side, 3], 255, dtype=np.uint8)
    left = (max_side - target_w) // 2
    top = (max_side - target_h) // 2
    canvas[top : top + target_h, left : left + target_w] = np.array(input_image)
    return Image.fromarray(canvas)
| 68 | + |
| 69 | + |
def download_weights(url, dest):
    """Download ``url`` into ``dest`` with the ``pget`` CLI and log the duration.

    Args:
        url: Remote location of the weights archive.
        dest: Local destination path handed to ``pget``.

    Raises:
        subprocess.CalledProcessError: If ``pget`` exits with a non-zero status.
    """
    began_at = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)

    # NOTE(review): `-x` presumably makes pget extract the tar archive into
    # `dest` -- confirm against the pget documentation.
    command = ["pget", "-x", url, dest]
    subprocess.check_call(command, close_fds=False)

    elapsed = time.time() - began_at
    print("downloading took: ", elapsed)
| 76 | + |
| 77 | + |
class Predictor(BasePredictor):
    """Cog predictor wrapping the InstantID SDXL pipeline.

    ``setup`` downloads the weights if they are missing and loads the face
    analyser, the ControlNet and the diffusion pipeline; ``predict`` extracts
    the largest face from the input image and uses its embedding and keypoints
    to condition generation.
    """

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        if not os.path.exists(CHECKPOINTS_CACHE):
            download_weights(CHECKPOINTS_URL, CHECKPOINTS_CACHE)

        if not os.path.exists(MODELS_CACHE):
            download_weights(MODELS_URL, MODELS_CACHE)

        # Detector input size; re-prepared in predict() when the request differs.
        self.width, self.height = 640, 640
        self.app = FaceAnalysis(
            name="antelopev2",
            root="./",
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
        )
        self.app.prepare(ctx_id=0, det_size=(self.width, self.height))

        # Paths to the InstantID models, derived from the checkpoint cache so
        # they cannot drift from the download location.
        face_adapter = os.path.join(CHECKPOINTS_CACHE, "ip-adapter.bin")
        controlnet_path = os.path.join(CHECKPOINTS_CACHE, "ControlNetModel")

        # Load pipeline
        self.controlnet = ControlNetModel.from_pretrained(
            controlnet_path,
            torch_dtype=torch.float16,
            cache_dir=CHECKPOINTS_CACHE,
            local_files_only=True,
        )

        base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
        self.pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
            base_model_path,
            controlnet=self.controlnet,
            torch_dtype=torch.float16,
            cache_dir=CHECKPOINTS_CACHE,
            local_files_only=True,
        )
        self.pipe.cuda()
        self.pipe.load_ip_adapter_instantid(face_adapter)

    def predict(
        self,
        image: Path = Input(description="Input image"),
        prompt: str = Input(
            description="Input prompt",
            default="analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality",
        ),
        negative_prompt: str = Input(
            description="Input Negative Prompt",
            default="",
        ),
        width: int = Input(
            description="Width of output image",
            default=640,
            ge=512,
            le=2048,
        ),
        height: int = Input(
            description="Height of output image",
            default=640,
            ge=512,
            le=2048,
        ),
        ip_adapter_scale: float = Input(
            description="Scale for IP adapter",
            default=0.8,
            ge=0,
            le=1,
        ),
        controlnet_conditioning_scale: float = Input(
            description="Scale for ControlNet conditioning",
            default=0.8,
            ge=0,
            le=1,
        ),
        num_inference_steps: int = Input(
            description="Number of denoising steps",
            default=30,
            ge=1,
            le=500,
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance",
            default=5,
            ge=1,
            le=50,
        ),
    ) -> Path:
        """Run a single prediction on the model.

        Returns:
            Path to the generated image, saved as ``result.jpg``.

        Raises:
            ValueError: If no face is detected in the input image.
        """
        # NOTE(review): width/height are only applied to the face detector's
        # det_size here; they are not forwarded to self.pipe(...). Confirm
        # whether they are meant to control the generated image size.
        if self.width != width or self.height != height:
            print(f"[!] Resizing output to {width}x{height}")
            self.width = width
            self.height = height
            self.app.prepare(ctx_id=0, det_size=(self.width, self.height))

        face_image = load_image(str(image))
        face_image = resize_img(face_image)

        # The face analyser is fed a BGR array, hence the colour conversion.
        detected_faces = self.app.get(
            cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR)
        )
        if not detected_faces:
            # Fail with an actionable message instead of an opaque IndexError.
            raise ValueError(
                "No face detected in the input image. "
                "Please provide an image with a clearly visible face."
            )

        # Only use the face with the largest bounding-box area.
        face_info = max(
            detected_faces,
            key=lambda face: (face["bbox"][2] - face["bbox"][0])
            * (face["bbox"][3] - face["bbox"][1]),
        )
        face_emb = face_info["embedding"]
        face_kps = draw_kps(face_image, face_info["kps"])

        self.pipe.set_ip_adapter_scale(ip_adapter_scale)
        # Kept under its own name so the `image` input is not shadowed.
        result_image = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image_embeds=face_emb,
            image=face_kps,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        ).images[0]

        output_path = "result.jpg"
        result_image.save(output_path)
        return Path(output_path)
0 commit comments