From 21b607e2e9737689ab27f5c320a05d12f907e39b Mon Sep 17 00:00:00 2001
From: Yondon Fu <yondon.fu@gmail.com>
Date: Fri, 2 Feb 2024 21:15:44 +0000
Subject: [PATCH] runner: Add HTTPError to API

---
 runner/app/routes/image_to_image.py | 47 +++++++++++-----
 runner/app/routes/image_to_video.py | 51 ++++++++++++-----
 runner/app/routes/text_to_image.py  | 27 +++++++--
 runner/app/routes/util.py           | 12 ++++
 runner/openapi.json                 | 85 +++++++++++++++++++++++++++++
 5 files changed, 190 insertions(+), 32 deletions(-)

diff --git a/runner/app/routes/image_to_image.py b/runner/app/routes/image_to_image.py
index 6b00e64a..52a3684b 100644
--- a/runner/app/routes/image_to_image.py
+++ b/runner/app/routes/image_to_image.py
@@ -1,17 +1,28 @@
 from fastapi import Depends, APIRouter, UploadFile, File, Form
+from fastapi.responses import JSONResponse
 from app.pipelines import ImageToImagePipeline
 from app.dependencies import get_pipeline
-from app.routes.util import image_to_data_url, ImageResponse
+from app.routes.util import image_to_data_url, ImageResponse, HTTPError, http_error
 import PIL
 from typing import Annotated
+import logging
 
 router = APIRouter()
 
+logger = logging.getLogger(__name__)
+
+responses = {400: {"model": HTTPError}, 500: {"model": HTTPError}}
+
 
 # TODO: Make model_id optional once Go codegen tool supports OAPI 3.1
 # https://github.com/deepmap/oapi-codegen/issues/373
-@router.post("/image-to-image", response_model=ImageResponse)
-@router.post("/image-to-image/", response_model=ImageResponse, include_in_schema=False)
+@router.post("/image-to-image", response_model=ImageResponse, responses=responses)
+@router.post(
+    "/image-to-image/",
+    response_model=ImageResponse,
+    responses=responses,
+    include_in_schema=False,
+)
 async def image_to_image(
     prompt: Annotated[str, Form()],
     image: Annotated[UploadFile, File()],
@@ -23,18 +34,28 @@ async def image_to_image(
     pipeline: ImageToImagePipeline = Depends(get_pipeline),
 ):
     if model_id != "" and model_id != pipeline.model_id:
-        raise Exception(
-            f"pipeline configured with {pipeline.model_id} but called with {model_id}"
+        return JSONResponse(
+            status_code=400,
+            content=http_error(
+                f"pipeline configured with {pipeline.model_id} but called with {model_id}"
+            ),
         )
 
-    images = pipeline(
-        prompt,
-        PIL.Image.open(image.file).convert("RGB"),
-        strength=strength,
-        guidance_scale=guidance_scale,
-        negative_prompt=negative_prompt,
-        seed=seed,
-    )
+    try:
+        images = pipeline(
+            prompt,
+            PIL.Image.open(image.file).convert("RGB"),
+            strength=strength,
+            guidance_scale=guidance_scale,
+            negative_prompt=negative_prompt,
+            seed=seed,
+        )
+    except Exception as e:
+        logger.error(f"ImageToImagePipeline error: {e}")
+        logger.exception(e)
+        return JSONResponse(
+            status_code=500, content=http_error("ImageToImagePipeline error")
+        )
 
     output_images = []
     for img in images:
diff --git a/runner/app/routes/image_to_video.py b/runner/app/routes/image_to_video.py
index ab1174cd..0977a3c7 100644
--- a/runner/app/routes/image_to_video.py
+++ b/runner/app/routes/image_to_video.py
@@ -1,17 +1,28 @@
 from fastapi import Depends, APIRouter, UploadFile, File, Form
+from fastapi.responses import JSONResponse
 from app.pipelines import ImageToVideoPipeline
 from app.dependencies import get_pipeline
-from app.routes.util import image_to_data_url, VideoResponse
+from app.routes.util import image_to_data_url, VideoResponse, HTTPError
 import PIL
 from typing import Annotated
+import logging
 
 router = APIRouter()
 
+logger = logging.getLogger(__name__)
+
+responses = {400: {"model": HTTPError}, 500: {"model": HTTPError}}
+
 
 # TODO: Make model_id optional once Go codegen tool supports OAPI 3.1
 # https://github.com/deepmap/oapi-codegen/issues/373
-@router.post("/image-to-video", response_model=VideoResponse)
-@router.post("/image-to-video/", response_model=VideoResponse, include_in_schema=False)
+@router.post("/image-to-video", response_model=VideoResponse, responses=responses)
+@router.post(
+    "/image-to-video/",
+    response_model=VideoResponse,
+    responses=responses,
+    include_in_schema=False,
+)
 async def image_to_video(
     image: Annotated[UploadFile, File()],
     model_id: Annotated[str, Form()] = "",
@@ -24,19 +35,31 @@ async def image_to_video(
     pipeline: ImageToVideoPipeline = Depends(get_pipeline),
 ):
     if model_id != "" and model_id != pipeline.model_id:
-        raise Exception(
-            f"pipeline configured with {pipeline.model_id} but called with {model_id}"
+        return JSONResponse(
+            status_code=400,
+            content={
+                "detail": {
+                    "msg": f"pipeline configured with {pipeline.model_id} but called with {model_id}"
+                }
+            },
         )
 
-    batch_frames = pipeline(
-        PIL.Image.open(image.file).convert("RGB"),
-        height=height,
-        width=width,
-        fps=fps,
-        motion_bucket_id=motion_bucket_id,
-        noise_aug_strength=noise_aug_strength,
-        seed=seed,
-    )
+    try:
+        batch_frames = pipeline(
+            PIL.Image.open(image.file).convert("RGB"),
+            height=height,
+            width=width,
+            fps=fps,
+            motion_bucket_id=motion_bucket_id,
+            noise_aug_strength=noise_aug_strength,
+            seed=seed,
+        )
+    except Exception as e:
+        logger.error(f"ImageToVideoPipeline error: {e}")
+        logger.exception(e)
+        return JSONResponse(
+            status_code=500, content={"detail": {"msg": "ImageToVideoPipeline error"}}
+        )
 
     output_frames = []
     for frames in batch_frames:
diff --git a/runner/app/routes/text_to_image.py b/runner/app/routes/text_to_image.py
index 0ba5d793..fa49ecc0 100644
--- a/runner/app/routes/text_to_image.py
+++ b/runner/app/routes/text_to_image.py
@@ -1,11 +1,15 @@
 from pydantic import BaseModel
 from fastapi import Depends, APIRouter
+from fastapi.responses import JSONResponse
 from app.pipelines import TextToImagePipeline
 from app.dependencies import get_pipeline
-from app.routes.util import image_to_data_url, ImageResponse
+from app.routes.util import image_to_data_url, ImageResponse, HTTPError, http_error
+import logging
 
 router = APIRouter()
 
+logger = logging.getLogger(__name__)
+
 
 class TextToImageParams(BaseModel):
     # TODO: Make model_id optional once Go codegen tool supports OAPI 3.1
@@ -19,17 +23,30 @@ class TextToImageParams(BaseModel):
     seed: int = None
 
 
-@router.post("/text-to-image", response_model=ImageResponse)
+responses = {400: {"model": HTTPError}, 500: {"model": HTTPError}}
+
+
+@router.post("/text-to-image", response_model=ImageResponse, responses=responses)
 @router.post("/text-to-image/", response_model=ImageResponse, include_in_schema=False)
 async def text_to_image(
     params: TextToImageParams, pipeline: TextToImagePipeline = Depends(get_pipeline)
 ):
     if params.model_id != "" and params.model_id != pipeline.model_id:
-        raise Exception(
-            f"pipeline configured with {pipeline.model_id} but called with {params.model_id}"
+        return JSONResponse(
+            status_code=400,
+            content=http_error(
+                f"pipeline configured with {pipeline.model_id} but called with {params.model_id}"
+            ),
         )
 
-    images = pipeline(**params.model_dump())
+    try:
+        images = pipeline(**params.model_dump())
+    except Exception as e:
+        logger.error(f"TextToImagePipeline error: {e}")
+        logger.exception(e)
+        return JSONResponse(
+            status_code=500, content=http_error("TextToImagePipeline error")
+        )
 
     output_images = []
     for img in images:
diff --git a/runner/app/routes/util.py b/runner/app/routes/util.py
index f63685d2..9cb24f9e 100644
--- a/runner/app/routes/util.py
+++ b/runner/app/routes/util.py
@@ -17,6 +17,18 @@ class VideoResponse(BaseModel):
     frames: List[List[Media]]
 
 
+class APIError(BaseModel):
+    msg: str
+
+
+class HTTPError(BaseModel):
+    detail: APIError
+
+
+def http_error(msg: str) -> HTTPError:
+    return {"detail": {"msg": msg}}
+
+
 def image_to_base64(img: PIL.Image, format: str = "png") -> str:
     buffered = io.BytesIO()
     img.save(buffered, format=format)
diff --git a/runner/openapi.json b/runner/openapi.json
index e4d87db1..07fdde23 100644
--- a/runner/openapi.json
+++ b/runner/openapi.json
@@ -49,6 +49,26 @@
               }
             }
           },
+          "400": {
+            "description": "Bad Request",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/HTTPError"
+                }
+              }
+            }
+          },
+          "500": {
+            "description": "Internal Server Error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/HTTPError"
+                }
+              }
+            }
+          },
           "422": {
             "description": "Validation Error",
             "content": {
@@ -87,6 +107,26 @@
               }
             }
           },
+          "400": {
+            "description": "Bad Request",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/HTTPError"
+                }
+              }
+            }
+          },
+          "500": {
+            "description": "Internal Server Error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/HTTPError"
+                }
+              }
+            }
+          },
           "422": {
             "description": "Validation Error",
             "content": {
@@ -125,6 +165,26 @@
               }
             }
           },
+          "400": {
+            "description": "Bad Request",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/HTTPError"
+                }
+              }
+            }
+          },
+          "500": {
+            "description": "Internal Server Error",
+            "content": {
+              "application/json": {
+                "schema": {
+                  "$ref": "#/components/schemas/HTTPError"
+                }
+              }
+            }
+          },
           "422": {
             "description": "Validation Error",
             "content": {
@@ -141,6 +201,19 @@
   },
   "components": {
     "schemas": {
+      "APIError": {
+        "properties": {
+          "msg": {
+            "type": "string",
+            "title": "Msg"
+          }
+        },
+        "type": "object",
+        "required": [
+          "msg"
+        ],
+        "title": "APIError"
+      },
       "Body_image_to_image_image_to_image_post": {
         "properties": {
           "prompt": {
@@ -232,6 +305,18 @@
         ],
         "title": "Body_image_to_video_image_to_video_post"
       },
+      "HTTPError": {
+        "properties": {
+          "detail": {
+            "$ref": "#/components/schemas/APIError"
+          }
+        },
+        "type": "object",
+        "required": [
+          "detail"
+        ],
+        "title": "HTTPError"
+      },
       "HTTPValidationError": {
         "properties": {
           "detail": {