Skip to content

Commit

Permalink
runner: Add HTTPError to API
Browse files Browse the repository at this point in the history
  • Loading branch information
yondonfu committed Feb 2, 2024
1 parent 75b7d29 commit 21b607e
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 32 deletions.
47 changes: 34 additions & 13 deletions runner/app/routes/image_to_image.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,28 @@
from fastapi import Depends, APIRouter, UploadFile, File, Form
from fastapi.responses import JSONResponse
from app.pipelines import ImageToImagePipeline
from app.dependencies import get_pipeline
from app.routes.util import image_to_data_url, ImageResponse
from app.routes.util import image_to_data_url, ImageResponse, HTTPError, http_error
import PIL
from typing import Annotated
import logging

router = APIRouter()

logger = logging.getLogger(__name__)

responses = {400: {"model": HTTPError}, 500: {"model": HTTPError}}


# TODO: Make model_id optional once Go codegen tool supports OAPI 3.1
# https://github.com/deepmap/oapi-codegen/issues/373
@router.post("/image-to-image", response_model=ImageResponse)
@router.post("/image-to-image/", response_model=ImageResponse, include_in_schema=False)
@router.post("/image-to-image", response_model=ImageResponse, responses=responses)
@router.post(
"/image-to-image/",
response_model=ImageResponse,
responses=responses,
include_in_schema=False,
)
async def image_to_image(
prompt: Annotated[str, Form()],
image: Annotated[UploadFile, File()],
Expand All @@ -23,18 +34,28 @@ async def image_to_image(
pipeline: ImageToImagePipeline = Depends(get_pipeline),
):
if model_id != "" and model_id != pipeline.model_id:
raise Exception(
f"pipeline configured with {pipeline.model_id} but called with {model_id}"
return JSONResponse(
status_code=400,
content=http_error(
f"pipeline configured with {pipeline.model_id} but called with {model_id}"
),
)

images = pipeline(
prompt,
PIL.Image.open(image.file).convert("RGB"),
strength=strength,
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
seed=seed,
)
try:
images = pipeline(
prompt,
PIL.Image.open(image.file).convert("RGB"),
strength=strength,
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
seed=seed,
)
except Exception as e:
logger.error(f"ImageToImagePipeline error: {e}")
logger.exception(e)
return JSONResponse(
status_code=500, content=http_error("ImageToImagePipeline error")
)

output_images = []
for img in images:
Expand Down
51 changes: 37 additions & 14 deletions runner/app/routes/image_to_video.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,28 @@
from fastapi import Depends, APIRouter, UploadFile, File, Form
from fastapi.responses import JSONResponse
from app.pipelines import ImageToVideoPipeline
from app.dependencies import get_pipeline
from app.routes.util import image_to_data_url, VideoResponse
from app.routes.util import image_to_data_url, VideoResponse, HTTPError
import PIL
from typing import Annotated
import logging

router = APIRouter()

logger = logging.getLogger(__name__)

responses = {400: {"model": HTTPError}, 500: {"model": HTTPError}}


# TODO: Make model_id optional once Go codegen tool supports OAPI 3.1
# https://github.com/deepmap/oapi-codegen/issues/373
@router.post("/image-to-video", response_model=VideoResponse)
@router.post("/image-to-video/", response_model=VideoResponse, include_in_schema=False)
@router.post("/image-to-video", response_model=VideoResponse, responses=responses)
@router.post(
"/image-to-video/",
response_model=VideoResponse,
responses=responses,
include_in_schema=False,
)
async def image_to_video(
image: Annotated[UploadFile, File()],
model_id: Annotated[str, Form()] = "",
Expand All @@ -24,19 +35,31 @@ async def image_to_video(
pipeline: ImageToVideoPipeline = Depends(get_pipeline),
):
if model_id != "" and model_id != pipeline.model_id:
raise Exception(
f"pipeline configured with {pipeline.model_id} but called with {model_id}"
return JSONResponse(
status_code=400,
content={
"detail": {
"msg": f"pipeline configured with {pipeline.model_id} but called with {model_id}"
}
},
)

batch_frames = pipeline(
PIL.Image.open(image.file).convert("RGB"),
height=height,
width=width,
fps=fps,
motion_bucket_id=motion_bucket_id,
noise_aug_strength=noise_aug_strength,
seed=seed,
)
try:
batch_frames = pipeline(
PIL.Image.open(image.file).convert("RGB"),
height=height,
width=width,
fps=fps,
motion_bucket_id=motion_bucket_id,
noise_aug_strength=noise_aug_strength,
seed=seed,
)
except Exception as e:
logger.error(f"ImageToVideoPipeline error: {e}")
logger.exception(e)
return JSONResponse(
status_code=500, content={"detail": {"msg": "ImageToVideoPipeline error"}}
)

output_frames = []
for frames in batch_frames:
Expand Down
27 changes: 22 additions & 5 deletions runner/app/routes/text_to_image.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from pydantic import BaseModel
from fastapi import Depends, APIRouter
from fastapi.responses import JSONResponse
from app.pipelines import TextToImagePipeline
from app.dependencies import get_pipeline
from app.routes.util import image_to_data_url, ImageResponse
from app.routes.util import image_to_data_url, ImageResponse, HTTPError, http_error
import logging

router = APIRouter()

logger = logging.getLogger(__name__)


class TextToImageParams(BaseModel):
# TODO: Make model_id optional once Go codegen tool supports OAPI 3.1
Expand All @@ -19,17 +23,30 @@ class TextToImageParams(BaseModel):
seed: int = None


@router.post("/text-to-image", response_model=ImageResponse)
responses = {400: {"model": HTTPError}, 500: {"model": HTTPError}}


@router.post("/text-to-image", response_model=ImageResponse, responses=responses)
@router.post("/text-to-image/", response_model=ImageResponse, include_in_schema=False)
async def text_to_image(
params: TextToImageParams, pipeline: TextToImagePipeline = Depends(get_pipeline)
):
if params.model_id != "" and params.model_id != pipeline.model_id:
raise Exception(
f"pipeline configured with {pipeline.model_id} but called with {params.model_id}"
return JSONResponse(
status_code=400,
content=http_error(
f"pipeline configured with {pipeline.model_id} but called with {params.model_id}"
),
)

images = pipeline(**params.model_dump())
try:
images = pipeline(**params.model_dump())
except Exception as e:
logger.error(f"TextToImagePipeline error: {e}")
logger.exception(e)
return JSONResponse(
status_code=500, content=http_error("TextToImagePipeline error")
)

output_images = []
for img in images:
Expand Down
12 changes: 12 additions & 0 deletions runner/app/routes/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,18 @@ class VideoResponse(BaseModel):
frames: List[List[Media]]


class APIError(BaseModel):
msg: str


class HTTPError(BaseModel):
detail: APIError


def http_error(msg: str) -> HTTPError:
return {"detail": {"msg": msg}}


def image_to_base64(img: PIL.Image, format: str = "png") -> str:
buffered = io.BytesIO()
img.save(buffered, format=format)
Expand Down
85 changes: 85 additions & 0 deletions runner/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,26 @@
}
}
},
"400": {
"description": "Bad Request",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPError"
}
}
}
},
"500": {
"description": "Internal Server Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPError"
}
}
}
},
"422": {
"description": "Validation Error",
"content": {
Expand Down Expand Up @@ -87,6 +107,26 @@
}
}
},
"400": {
"description": "Bad Request",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPError"
}
}
}
},
"500": {
"description": "Internal Server Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPError"
}
}
}
},
"422": {
"description": "Validation Error",
"content": {
Expand Down Expand Up @@ -125,6 +165,26 @@
}
}
},
"400": {
"description": "Bad Request",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPError"
}
}
}
},
"500": {
"description": "Internal Server Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPError"
}
}
}
},
"422": {
"description": "Validation Error",
"content": {
Expand All @@ -141,6 +201,19 @@
},
"components": {
"schemas": {
"APIError": {
"properties": {
"msg": {
"type": "string",
"title": "Msg"
}
},
"type": "object",
"required": [
"msg"
],
"title": "APIError"
},
"Body_image_to_image_image_to_image_post": {
"properties": {
"prompt": {
Expand Down Expand Up @@ -232,6 +305,18 @@
],
"title": "Body_image_to_video_image_to_video_post"
},
"HTTPError": {
"properties": {
"detail": {
"$ref": "#/components/schemas/APIError"
}
},
"type": "object",
"required": [
"detail"
],
"title": "HTTPError"
},
"HTTPValidationError": {
"properties": {
"detail": {
Expand Down

0 comments on commit 21b607e

Please sign in to comment.