diff --git a/runner/app/main.py b/runner/app/main.py index 57acb6f85..626474871 100644 --- a/runner/app/main.py +++ b/runner/app/main.py @@ -23,7 +23,9 @@ async def lifespan(app: FastAPI): pipeline = os.environ["PIPELINE"] model_id = os.environ["MODEL_ID"] - app.pipeline = load_pipeline(pipeline, model_id) + task = os.environ["TASK"] if pipeline == "image-to-image-generic" else None + + app.pipeline = load_pipeline(pipeline, model_id, task) app.include_router(load_route(pipeline)) app.hardware_info_service.log_gpu_compute_info() @@ -34,7 +36,7 @@ async def lifespan(app: FastAPI): logger.info("Shutting down") -def load_pipeline(pipeline: str, model_id: str) -> any: +def load_pipeline(pipeline: str, model_id: str, task: str) -> any: match pipeline: case "text-to-image": from app.pipelines.text_to_image import TextToImagePipeline @@ -81,7 +83,7 @@ def load_pipeline(pipeline: str, model_id: str) -> any: case "image-to-image-generic": from app.pipelines.image_to_image_generic import ImageToImageGenericPipeline - return ImageToImageGenericPipeline(model_id) + return ImageToImageGenericPipeline(model_id, task) case _: raise EnvironmentError( f"{pipeline} is not a valid pipeline for model {model_id}"