Skip to content

Commit

Permalink
update for video input instead of directory.
Browse files Browse the repository at this point in the history
  • Loading branch information
JJassonn69 committed Aug 8, 2024
1 parent b84dffa commit d91717e
Show file tree
Hide file tree
Showing 9 changed files with 189 additions and 107 deletions.
2 changes: 1 addition & 1 deletion runner/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Based on https://github.com/huggingface/api-inference-community/blob/main/docker_images/diffusers/Dockerfile

FROM nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu20.04
FROM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04
LABEL maintainer="Yondon Fu <[email protected]>"

# Add any system dependency here
Expand Down
4 changes: 2 additions & 2 deletions runner/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def load_pipeline(pipeline: str, model_id: str) -> any:
from app.pipelines.audio_to_text import AudioToTextPipeline

return AudioToTextPipeline(model_id)
case "FILMPipeline":
case "frame-interpolation":
from app.pipelines.frame_interpolation import FILMPipeline

return FILMPipeline(model_id)
Expand Down Expand Up @@ -80,7 +80,7 @@ def load_route(pipeline: str) -> any:
from app.routes import audio_to_text

return audio_to_text.router
case "FILMPipeline":
case "frame-interpolation":
from app.routes import frame_interpolation

return frame_interpolation.router
Expand Down
3 changes: 2 additions & 1 deletion runner/app/pipelines/frame_interpolation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import os
from torchvision.transforms import v2
from tqdm import tqdm
import bisect
Expand All @@ -10,7 +11,7 @@ class FILMPipeline:
model: torch.jit.ScriptModule

def __init__(self, model_id: str):
self.model_id = model_id
model_id = os.environ.get("MODEL_ID", "")
model_dir = get_model_dir() # Get the directory where models are stored
model_path = f"{model_dir}/{model_id}" # Construct the full path to the model file

Expand Down
26 changes: 25 additions & 1 deletion runner/app/pipelines/upscale.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging
import os
import time
from compel import Compel, ReturnedEmbeddingsType
from typing import List, Optional, Tuple

import PIL
Expand Down Expand Up @@ -114,7 +116,29 @@ def __call__(
if num_inference_steps is None or num_inference_steps < 1:
kwargs.pop("num_inference_steps", None)

output = self.ldm(prompt, image=image, **kwargs)
# trying differnt configs of promp_embed for different models
try:
compel_proc=Compel(tokenizer=self.ldm.tokenizer, text_encoder=self.ldm.text_encoder)
prompt=embeds = compel_proc(prompt)
output = self.ldm(prompt_embeds=prompt_embeds, image=image, **kwargs)
except Exception as e:
logging.info(f"Failed to generate prompt embeddings: {e}. Using prompt and pooled embeddings.")

try:
compel_proc = Compel(tokenizer=[self.ldm.tokenizer, self.ldm.tokenizer_2],
text_encoder=[self.ldm.text_encoder, self.ldm.text_encoder_2],
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
requires_pooled=[False, True])
prompt_embeds, pooled_prompt_embeds = compel_proc(prompt)
output = self.ldm(
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
image=image,
**kwargs
)
except Exception as e:
logging.info(f"Failed to generate prompt and pooled embeddings: {e}. Trying normal prompt.")
output = self.ldm(prompt, image=image, **kwargs)

if safety_check:
_, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images)
Expand Down
2 changes: 1 addition & 1 deletion runner/app/pipelines/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def video_shredder(video_data, is_file_path=True) -> np.ndarray:
# Handle in-memory video input
# Create a temporary file to store in-memory video data
with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_file:
temp_file.write(video_data.getvalue())
temp_file.write(video_data)
temp_file_path = temp_file.name

# Open the temporary video file
Expand Down
60 changes: 23 additions & 37 deletions runner/app/routes/frame_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import torch
import glob
import cv2
from typing import Annotated, Optional
from fastapi import APIRouter, Depends, File, Form, UploadFile, status
from fastapi.responses import JSONResponse
Expand All @@ -12,8 +13,8 @@

from app.dependencies import get_pipeline
from app.pipelines.frame_interpolation import FILMPipeline
from app.pipelines.utils.utils import DirectoryReader, DirectoryWriter, get_torch_device, get_model_dir
from app.routes.util import HTTPError, ImageResponse, http_error, image_to_data_url
from app.pipelines.utils.utils import DirectoryReader, DirectoryWriter, get_torch_device, video_shredder
from app.routes.util import HTTPError, VideoResponse, http_error, image_to_data_url

ImageFile.LOAD_TRUNCATED_IMAGES = True

Expand All @@ -27,18 +28,16 @@
status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError},
}

@router.post("/frame_interpolation", response_model=ImageResponse, responses=RESPONSES)
@router.post("/frame-interpolation", response_model=VideoResponse, responses=RESPONSES)
@router.post(
"/frame_interpolation/",
response_model=ImageResponse,
"/frame-interpolation/",
response_model=VideoResponse,
responses=RESPONSES,
include_in_schema=False,
)
async def frame_interpolation(
model_id: Annotated[str, Form()],
image1: Annotated[UploadFile, File()]=None,
image2: Annotated[UploadFile, File()]=None,
image_dir: Annotated[str, Form()]="",
model_id: Annotated[str, Form()] = "",
video: Annotated[UploadFile, File()]=None,
inter_frames: Annotated[int, Form()] = 2,
token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
):
Expand All @@ -51,10 +50,9 @@ async def frame_interpolation(
content=http_error("Invalid bearer token"),
)


# Initialize FILMPipeline
film_pipeline = FILMPipeline(model_id)
film_pipeline.to(device=get_torch_device(),dtype=torch.float16)
film_pipeline.to(device=get_torch_device(), dtype=torch.float16)

# Prepare directories for input and output
temp_input_dir = "temp_input"
Expand All @@ -63,31 +61,21 @@ async def frame_interpolation(
os.makedirs(temp_output_dir, exist_ok=True)

try:
if os.path.isdir(image_dir):
if image1 and image2:
logger.info("Both directory and individual images provided. Directory will be used, and images will be ignored.")
reader = DirectoryReader(image_dir)
else:
if not (image1 and image2):
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=http_error("Either a directory or two images must be provided."),
)

image1_path = os.path.join(temp_input_dir, "0.png")
image2_path = os.path.join(temp_input_dir, "1.png")

with open(image1_path, "wb") as f:
f.write(await image1.read())
with open(image2_path, "wb") as f:
f.write(await image2.read())

reader = DirectoryReader(temp_input_dir)
# Extract frames from video
video_data = await video.read()
frames = video_shredder(video_data, is_file_path=False)

# Save frames to temporary directory
for i, frame in enumerate(frames):
frame_path = os.path.join(temp_input_dir, f"{i}.png")
cv2.imwrite(frame_path, frame)

# Create DirectoryReader and DirectoryWriter
reader = DirectoryReader(temp_input_dir)
writer = DirectoryWriter(temp_output_dir)

# Perform interpolation
film_pipeline(reader, writer, inter_frames=inter_frames)

writer.close()
reader.reset()

Expand All @@ -96,8 +84,8 @@ async def frame_interpolation(
for frame_path in sorted(glob.glob(os.path.join(temp_output_dir, "*.png"))):
frame = Image.open(frame_path)
output_frames.append(frame)

output_images = [{"url": image_to_data_url(frame),"seed":0, "nsfw":False} for frame in output_frames]
# Wrap output frames in a list of batches (with a single batch in this case)
output_images = [[{"url": image_to_data_url(frame), "seed": 0, "nsfw": False} for frame in output_frames]]

except Exception as e:
logger.error(f"FILMPipeline error: {e}")
Expand All @@ -106,15 +94,13 @@ async def frame_interpolation(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=http_error("FILMPipeline error"),
)

finally:
# Clean up temporary directories
for file_path in glob.glob(os.path.join(temp_input_dir, "*")):
os.remove(file_path)
os.rmdir(temp_input_dir)

for file_path in glob.glob(os.path.join(temp_output_dir, "*")):
os.remove(file_path)
os.rmdir(temp_output_dir)

return {"images": output_images}
return {"frames": output_images}
107 changes: 86 additions & 21 deletions runner/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -332,27 +332,104 @@
]
}
},
"/frame_interpolation": {
"/llm-generate": {
"post": {
"summary": "Llm Generate",
"operationId": "llm_generate",
"requestBody": {
"content": {
"application/x-www-form-urlencoded": {
"schema": {
"$ref": "#/components/schemas/Body_llm_generate_llm_generate_post"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/LlmResponse"
}
}
}
},
"400": {
"description": "Bad Request",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPError"
}
}
}
},
"401": {
"description": "Unauthorized",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPError"
}
}
}
},
"500": {
"description": "Internal Server Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPError"
}
}
}
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
}
}
},
"security": [
{
"HTTPBearer": []
}
]
}
},
"/frame-interpolation": {
"post": {
"summary": "Frame Interpolation",
"operationId": "frame_interpolation",
"requestBody": {
"content": {
"multipart/form-data": {
"schema": {
"$ref": "#/components/schemas/Body_frame_interpolation_frame_interpolation_post"
"allOf": [
{
"$ref": "#/components/schemas/Body_frame_interpolation_frame_interpolation_post"
}
],
"title": "Body"
}
}
},
"required": true
}
},
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ImageResponse"
"$ref": "#/components/schemas/VideoResponse"
}
}
}
Expand Down Expand Up @@ -517,22 +594,13 @@
"properties": {
"model_id": {
"type": "string",
"title": "Model Id"
},
"image1": {
"type": "string",
"format": "binary",
"title": "Image1"
"title": "Model Id",
"default": ""
},
"image2": {
"video": {
"type": "string",
"format": "binary",
"title": "Image2"
},
"image_dir": {
"type": "string",
"title": "Image Dir",
"default": ""
"title": "Video"
},
"inter_frames": {
"type": "integer",
Expand All @@ -541,9 +609,6 @@
}
},
"type": "object",
"required": [
"model_id"
],
"title": "Body_frame_interpolation_frame_interpolation_post"
},
"Body_image_to_image_image_to_image_post": {
Expand Down
12 changes: 7 additions & 5 deletions worker/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@ const containerCreator = "ai-worker"
// This only works right now on a single GPU because if there is another container
// using the GPU we stop it so we don't have to worry about having enough ports
var containerHostPorts = map[string]string{
"text-to-image": "8000",
"image-to-image": "8001",
"image-to-video": "8002",
"upscale": "8003",
"audio-to-text": "8004",
"text-to-image": "8000",
"image-to-image": "8001",
"image-to-video": "8002",
"upscale": "8003",
"audio-to-text": "8004",
"llm": "8005",
"frame-interpolation": "8006",
}

type DockerManager struct {
Expand Down
Loading

0 comments on commit d91717e

Please sign in to comment.