Skip to content

Commit

Permalink
chore:add task param to env var for access to pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
RUFFY-369 committed Dec 30, 2024
1 parent 441e48c commit cab7133
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions runner/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ async def lifespan(app: FastAPI):
pipeline = os.environ["PIPELINE"]
model_id = os.environ["MODEL_ID"]

app.pipeline = load_pipeline(pipeline, model_id)
task = os.environ["TASK"] if pipeline == "image-to-image-generic" else None

app.pipeline = load_pipeline(pipeline, model_id, task)
app.include_router(load_route(pipeline))

app.hardware_info_service.log_gpu_compute_info()
Expand All @@ -34,7 +36,7 @@ async def lifespan(app: FastAPI):
logger.info("Shutting down")


def load_pipeline(pipeline: str, model_id: str) -> any:
def load_pipeline(pipeline: str, model_id: str, task: str) -> any:
match pipeline:
case "text-to-image":
from app.pipelines.text_to_image import TextToImagePipeline
Expand Down Expand Up @@ -81,7 +83,7 @@ def load_pipeline(pipeline: str, model_id: str) -> any:
case "image-to-image-generic":
from app.pipelines.image_to_image_generic import ImageToImageGenericPipeline

return ImageToImageGenericPipeline(model_id)
return ImageToImageGenericPipeline(model_id, task)
case _:
raise EnvironmentError(
f"{pipeline} is not a valid pipeline for model {model_id}"
Expand Down

0 comments on commit cab7133

Please sign in to comment.