Skip to content

Commit

Permalink
runner: Support models w/o fp16 variant
Browse files Browse the repository at this point in the history
  • Loading branch information
yondonfu committed Jan 26, 2024
1 parent 92b0d6b commit f5ba038
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 4 deletions.
12 changes: 11 additions & 1 deletion runner/app/pipelines/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,27 @@
from app.pipelines.util import get_torch_device, get_model_dir

from diffusers import AutoPipelineForImage2Image
from huggingface_hub import model_info
import torch
import PIL
from typing import List
import logging

logger = logging.getLogger(__name__)


class ImageToImagePipeline(Pipeline):
def __init__(self, model_id: str):
kwargs = {"cache_dir": get_model_dir()}

torch_device = get_torch_device()
if torch_device != "cpu":
model_data = model_info(model_id)
has_fp16_variant = any(
".fp16.safetensors" in file.rfilename for file in model_data.siblings
)
if torch_device != "cpu" and has_fp16_variant:
logger.info("ImageToImagePipeline loading fp16 variant for %s", model_id)

kwargs["torch_dtype"] = torch.float16
kwargs["variant"] = "fp16"

Expand Down
12 changes: 11 additions & 1 deletion runner/app/pipelines/image_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,27 @@
from app.pipelines.util import get_torch_device, get_model_dir

from diffusers import StableVideoDiffusionPipeline
from huggingface_hub import model_info
import torch
import PIL
from typing import List
import logging

logger = logging.getLogger(__name__)


class ImageToVideoPipeline(Pipeline):
def __init__(self, model_id: str):
kwargs = {"cache_dir": get_model_dir()}

torch_device = get_torch_device()
if torch_device != "cpu":
model_data = model_info(model_id)
has_fp16_variant = any(
".fp16.safetensors" in file.rfilename for file in model_data.siblings
)
if torch_device != "cpu" and has_fp16_variant:
logger.info("ImageToVideoPipeline loading fp16 variant for %s", model_id)

kwargs["torch_dtype"] = torch.float16
kwargs["variant"] = "fp16"

Expand Down
12 changes: 11 additions & 1 deletion runner/app/pipelines/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,27 @@
from app.pipelines.util import get_torch_device, get_model_dir

from diffusers import AutoPipelineForText2Image
from huggingface_hub import model_info
import torch
import PIL
from typing import List
import logging

logger = logging.getLogger(__name__)


class TextToImagePipeline(Pipeline):
def __init__(self, model_id: str):
kwargs = {"cache_dir": get_model_dir()}

torch_device = get_torch_device()
if torch_device != "cpu":
model_data = model_info(model_id)
has_fp16_variant = any(
".fp16.safetensors" in file.rfilename for file in model_data.siblings
)
if torch_device != "cpu" and has_fp16_variant:
logger.info("TextToImagePipeline loading fp16 variant for %s", model_id)

kwargs["torch_dtype"] = torch.float16
kwargs["variant"] = "fp16"

Expand Down
3 changes: 3 additions & 0 deletions runner/dl_checkpoints.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,8 @@ mkdir -p models

# text-to-image, image-to-image
huggingface-cli download stabilityai/sd-turbo --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models
huggingface-cli download runwayml/stable-diffusion-v1-5 --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models
huggingface-cli download stabilityai/stable-diffusion-xl-base-1.0 --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models
huggingface-cli download prompthero/openjourney-v4 --include "*.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models
# image-to-video
huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models
3 changes: 2 additions & 1 deletion runner/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ fastapi
pydantic
Pillow
python-multipart
uvicorn
uvicorn
huggingface_hub

0 comments on commit f5ba038

Please sign in to comment.