Commit 43be9d7
lazy loading works
zsxkib committed Aug 28, 2024
1 parent fbccf68 commit 43be9d7
Showing 1 changed file with 103 additions and 59 deletions.
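The change keeps only the active FLUX pipeline resident on CUDA, parks everything else in CPU memory, and moves weights across on demand; inpaint pipelines are built lazily on first use and share their text encoders with the base pipeline. A minimal sketch of the swap pattern (illustrative class and names, not part of the commit):

import torch

class PipelineSwapper:
    """Keep one diffusers pipeline resident on the GPU at a time."""

    def __init__(self, pipes):
        # pipes: dict of name -> pipeline, all starting on CPU
        self.pipes = pipes
        self.active = None

    def activate(self, name):
        if self.active == name:
            return self.pipes[name]  # already on CUDA; nothing to move
        if self.active is not None:
            self.pipes[self.active].to("cpu")  # park the previous pipeline in system RAM
        pipe = self.pipes[name].to("cuda")  # stream the requested weights onto the GPU
        torch.cuda.empty_cache()  # return the freed VRAM to the CUDA allocator
        self.active = name
        return pipe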
162 changes: 103 additions & 59 deletions predict.py
@@ -6,6 +6,7 @@
from typing import List, Tuple
from cog import BasePredictor, Input, Path
import numpy as np
import warnings
from diffusers import FluxPipeline, FluxInpaintPipeline
from transformers import CLIPImageProcessor
from diffusers.pipelines.stable_diffusion.safety_checker import (
@@ -43,25 +44,21 @@ class LoadedLoRAs:
main: str | None
extra: str | None


class Predictor(BasePredictor):
def setup(self) -> None:
"""Load the model into memory to make running multiple predictions efficient"""
"""Load dev model into CUDA memory and initialize other models"""
start = time.time()
# Don't pull weights from the network; use the local cache only
os.environ["TRANSFORMERS_OFFLINE"] = "1"

self.weights_cache = WeightsDownloadCache()

print("Loading safety checker...")
if not os.path.exists(SAFETY_CACHE):
download_base_weights(SAFETY_URL, SAFETY_CACHE)

# TODO: implement safety checker w/ lazy loading
# self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
# SAFETY_CACHE, torch_dtype=torch.float16
# ).to("cuda")
# self.feature_extractor = CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR)
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
SAFETY_CACHE, torch_dtype=torch.float16
).to("cuda")
self.feature_extractor = CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR)

print("Loading Flux dev pipeline")
if not os.path.exists("FLUX.1-dev"):
@@ -81,7 +78,7 @@ def setup(self) -> None:
tokenizer=dev_pipe.tokenizer,
tokenizer_2=dev_pipe.tokenizer_2,
torch_dtype=torch.bfloat16,
).to("cuda")
).to("cpu") # Keep schnell on CPU initially

self.pipes = {
"dev": dev_pipe,
@@ -91,23 +88,76 @@
"dev": LoadedLoRAs(main=None, extra=None),
"schnell": LoadedLoRAs(main=None, extra=None),
}
self.inpaint_pipes = {}
self.inpaint_pipes = {
"dev": None,
"schnell": None,
}
self.current_model = "dev"
self.current_inpaint = None

self.loaded_models = ["safety_checker", "dev"]
print(f"[!] Loaded models: {self.loaded_models}")

print("setup took: ", time.time() - start)

def setup_inpaint_pipeline(self, model: str):
if model not in self.inpaint_pipes:
print(f"Creating inpaint pipeline for {model}")
base_pipe = self.pipes[model]
inpaint_pipe = FluxInpaintPipeline.from_pretrained(
f"FLUX.1-{model}",
text_encoder=base_pipe.text_encoder,
text_encoder_2=base_pipe.text_encoder_2,
tokenizer=base_pipe.tokenizer,
tokenizer_2=base_pipe.tokenizer_2,
torch_dtype=torch.bfloat16,
).to("cuda")
self.inpaint_pipes[model] = inpaint_pipe
def configure_active_model(self, model: str, inpaint: bool = False):
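"""Lazily move the requested pipeline onto CUDA, offloading the previously active one (and any unused inpaint pipeline) to CPU."""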
initial_models = set(self.loaded_models)

# Unload current model if it's different
if self.current_model != model:
if self.current_model:
self.pipes[self.current_model].to("cpu")
self.loaded_models.remove(self.current_model)

self.pipes[model].to("cuda")
self.current_model = model
self.loaded_models.append(model)

# Ensure the model and all its components are on CUDA
pipe = self.pipes[model]
if pipe.device.type != "cuda":
print(f"Moving {model} model to CUDA.")
pipe.to("cuda")

# Explicitly move specific model components to CUDA
for component_name in ['unet', 'text_encoder', 'text_encoder_2', 'vae']:
if hasattr(pipe, component_name):
component = getattr(pipe, component_name)
if isinstance(component, torch.nn.Module):
component.to("cuda")

# Handle inpainting models
if inpaint:
if self.current_inpaint != model:
if self.current_inpaint:
self.inpaint_pipes[self.current_inpaint].to("cpu")
self.loaded_models.remove(f"{self.current_inpaint}_inpaint")

if self.inpaint_pipes[model] is None:
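# First inpaint request for this model: build the pipeline once, sharing the base pipe's text encoders and tokenizers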
base_pipe = self.pipes[model]
self.inpaint_pipes[model] = FluxInpaintPipeline.from_pretrained(
f"FLUX.1-{model}",
text_encoder=base_pipe.text_encoder,
text_encoder_2=base_pipe.text_encoder_2,
tokenizer=base_pipe.tokenizer,
tokenizer_2=base_pipe.tokenizer_2,
torch_dtype=torch.bfloat16,
).to("cuda")
else:
self.inpaint_pipes[model].to("cuda")

self.current_inpaint = model
self.loaded_models.append(f"{model}_inpaint")
else:
if self.current_inpaint:
self.inpaint_pipes[self.current_inpaint].to("cpu")
self.loaded_models.remove(f"{self.current_inpaint}_inpaint")
self.current_inpaint = None

torch.cuda.empty_cache()  # release VRAM freed by pipelines moved to CPU

if set(self.loaded_models) != initial_models:
print(f"[!] Loaded models: {self.loaded_models}")

@torch.inference_mode()
def predict(
@@ -224,30 +274,26 @@ def predict(
flux_kwargs = {}
print(f"Prompt: {prompt}")

if image and mask:
inpaint_mode = image is not None and mask is not None
self.configure_active_model(model, inpaint_mode)

if inpaint_mode:
print("inpaint mode")
self.setup_inpaint_pipeline(model)
input_image = self.load_image(image)
mask_image = self.load_image(mask)
width, height = self.resize_image_dimensions(input_image.size)
flux_kwargs["image"] = input_image.resize((width, height), Image.LANCZOS)
flux_kwargs["mask_image"] = mask_image.resize(
(width, height), Image.LANCZOS
)
flux_kwargs["prompt_strength"] = prompt_strength
flux_kwargs["strength"] = prompt_strength
print(f"Using {model} model for inpainting")
pipe = self.inpaint_pipes[model]
else:
# Unload inpainting pipelines if they exist
if self.inpaint_pipes:
print("Unloading inpaint pipelines to free CUDA memory")
self.inpaint_pipes.clear()
torch.cuda.empty_cache()

if image:
print("img2img mode")
flux_kwargs["image"] = self.load_image(image)
flux_kwargs["prompt_strength"] = prompt_strength
flux_kwargs["strength"] = prompt_strength
else:
print("txt2img mode")

@@ -270,8 +316,6 @@ def predict(
max_sequence_length = 256
guidance_scale = 0

print("Available pipelines:", list(self.pipes.keys()))

if replicate_weights:
start_time = time.time()
if extra_lora:
@@ -290,7 +334,15 @@ def predict(
pipe.unload_lora_weights()
self.loaded_lora_urls[model] = LoadedLoRAs(main=None, extra=None)

generator = torch.Generator("cuda").manual_seed(seed)
# Ensure all model components are on the correct device
device = pipe.device
for component_name in ['transformer', 'text_encoder', 'text_encoder_2', 'vae']:  # Flux has a transformer, not a unet
if hasattr(pipe, component_name):
component = getattr(pipe, component_name)
if isinstance(component, torch.nn.Module):
component.to(device)

generator = torch.Generator(device=device).manual_seed(seed)  # seed on whichever device the pipeline currently occupies

common_args = {
"prompt": [prompt] * num_outputs,
Expand All @@ -303,17 +355,10 @@ def predict(

output = pipe(**common_args, **flux_kwargs)

# TODO: implement safety checker w/ lazy loading
# if not disable_safety_checker:
# _, has_nsfw_content = self.run_safety_checker(output.images)
# else:
# has_nsfw_content = [False] * len(output.images)
disable_safety_checker = (
True # TODO: remove this when we have a safety checker back
)
has_nsfw_content = [False] * len(
output.images
) # TODO: remove this when we have a safety checker back
if not disable_safety_checker:
_, has_nsfw_content = self.run_safety_checker(output.images)
else:
has_nsfw_content = [False] * len(output.images)

output_paths = []
for i, image in enumerate(output.images):
@@ -376,17 +421,16 @@ def load_multiple_loras(self, main_lora_url: str, extra_lora_url: str, model: str):

@torch.amp.autocast("cuda")
def run_safety_checker(self, image):
# TODO: implement safety checker w/ lazy loading
# safety_checker_input = self.feature_extractor(image, return_tensors="pt").to(
# "cuda"
# )
# np_image = [np.array(val) for val in image]
# image, has_nsfw_concept = self.safety_checker(
# images=np_image,
# clip_input=safety_checker_input.pixel_values.to(torch.float16),
# )
# return image, has_nsfw_concept
return image, False # TODO: remove this when we have a safety checker back
safety_checker_input = self.feature_extractor(image, return_tensors="pt").to(
"cuda"
)
np_image = [np.array(val) for val in image]
image, has_nsfw_concept = self.safety_checker(
images=np_image,
clip_input=safety_checker_input.pixel_values.to(torch.float16),
)
return image, has_nsfw_concept


def aspect_ratio_to_width_height(self, aspect_ratio: str) -> Tuple[int, int]:
return ASPECT_RATIOS[aspect_ratio]
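
For reference, a rough illustration of how the lazy loading plays out across calls, assuming the FLUX.1-dev and FLUX.1-schnell weights are already on disk:

predictor = Predictor()
predictor.setup()  # dev pipeline on CUDA, schnell parked on CPU
predictor.configure_active_model("schnell")  # dev -> CPU, schnell -> CUDA
predictor.configure_active_model("schnell", inpaint=True)  # builds the schnell inpaint pipe on first use
predictor.configure_active_model("dev")  # schnell and its inpaint pipe -> CPU, dev -> CUDA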
