Skip to content

Commit

Permalink
Add Control API (#279)
Browse files Browse the repository at this point in the history
  • Loading branch information
leszko authored Nov 19, 2024
1 parent 27c969f commit 04aaa59
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 76 deletions.
25 changes: 23 additions & 2 deletions runner/app/live/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
import os
import traceback
from typing import List
import logging

from streamer import PipelineStreamer
from trickle import TrickleSubscriber

# loads neighbouring modules with absolute paths
infer_root = os.path.abspath(os.path.dirname(__file__))
Expand All @@ -17,7 +21,7 @@
from streamer.zeromq import ZeroMQStreamer


async def main(http_port: int, stream_protocol: str, subscribe_url: str, publish_url: str, pipeline: str, params: dict):
async def main(http_port: int, stream_protocol: str, subscribe_url: str, publish_url: str, control_url: str, pipeline: str, params: dict):
if stream_protocol == "trickle":
handler = TrickleStreamer(subscribe_url, publish_url, pipeline, **(params or {}))
elif stream_protocol == "zeromq":
Expand All @@ -29,6 +33,7 @@ async def main(http_port: int, stream_protocol: str, subscribe_url: str, publish
try:
handler.start()
runner = await start_http_server(handler, http_port)
asyncio.create_task(start_control_subscriber(handler, control_url))
except Exception as e:
logging.error(f"Error starting socket handler or HTTP server: {e}")
logging.error(f"Stack trace:\n{traceback.format_exc()}")
Expand Down Expand Up @@ -56,6 +61,19 @@ def signal_handler(sig, _):
signal.signal(sig, signal_handler)
return await future

async def start_control_subscriber(handler: PipelineStreamer, control_url: str):
if control_url is None or control_url.strip() == "":
logging.warning("No control-url provided, inference won't get updates from the control trickle subscription")
return
logging.info("Starting Control subscriber at %s", control_url)
subscriber = TrickleSubscriber(url=control_url)
while True:
segment = await subscriber.next()
if segment.eos():
return
params = await segment.read()
logging.info("Received control message, updating model with params: %s", params)
handler.update_params(**json.loads(params))

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Infer process to run the AI pipeline")
Expand All @@ -81,6 +99,9 @@ def signal_handler(sig, _):
parser.add_argument(
"--publish-url", type=str, required=True, help="URL to publish output frames (trickle). For zeromq this is the output socket address"
)
parser.add_argument(
"--control-url", type=str, help="URL to subscribe for Control API JSON messages"
)
parser.add_argument(
"-v", "--verbose",
action="store_true",
Expand All @@ -103,7 +124,7 @@ def signal_handler(sig, _):

try:
asyncio.run(
main(args.http_port, args.stream_protocol, args.subscribe_url, args.publish_url, args.pipeline, params)
main(args.http_port, args.stream_protocol, args.subscribe_url, args.publish_url, args.control_url, args.pipeline, params)
)
except Exception as e:
logging.error(f"Fatal error in main: {e}")
Expand Down
3 changes: 2 additions & 1 deletion runner/app/live/trickle/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .media import run_subscribe, run_publish
from .trickle_subscriber import TrickleSubscriber

__all__ = ["run_subscribe", "run_publish"]
__all__ = ["run_subscribe", "run_publish", "TrickleSubscriber"]
4 changes: 3 additions & 1 deletion runner/app/pipelines/live_video_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,17 @@ def __call__(
):
try:
if not self.process:
logger.info(f"Starting stream, subscribe={kwargs['subscribe_url']} publish={kwargs['publish_url']}, control={kwargs['control_url']}")
self.start_process(
pipeline=self.model_id, # we use the model_id as the pipeline name for now
http_port=8888,
subscribe_url=kwargs["subscribe_url"],
publish_url=kwargs["publish_url"],
control_url=kwargs["control_url"],
initial_params=json.dumps(kwargs["params"]),
# TODO: set torch device from self.torch_device
)
logger.info(f"Starting stream, subscribe={kwargs['subscribe_url']} publish={kwargs['publish_url']}")
logger.info(f"Starting stream, subscribe={kwargs['subscribe_url']} publish={kwargs['publish_url']}, control={kwargs['control_url']}")
return
except Exception as e:
raise InferenceError(original_exception=e)
Expand Down
8 changes: 7 additions & 1 deletion runner/app/routes/live_video_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ class LiveVideoToVideoParams(BaseModel):
description="Destination URL of the outgoing stream to publish.",
),
]
control_url: Annotated[
str,
Field(
default="",description="URL for subscribing via Trickle protocol for updates in the live video-to-video generation params.",
),
]
model_id: Annotated[
str,
Field(
Expand Down Expand Up @@ -129,5 +135,5 @@ async def live_video_to_video(
)

# outputs unused for now; the orchestrator is setting these
return {'publish_url':"", 'subscribe_url': ""}
return {'publish_url':"", 'subscribe_url': "", 'control_url': ""}

4 changes: 3 additions & 1 deletion runner/app/routes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ class LiveVideoToVideoResponse(BaseModel):
publish_url: str = Field(
..., description="Destination URL of the outgoing stream to publish to"
)

control_url: str = Field(
..., description="URL for updating the live video-to-video generation"
)

class APIError(BaseModel):
"""API error response model."""
Expand Down
11 changes: 11 additions & 0 deletions runner/gateway.openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -924,6 +924,12 @@ components:
type: string
title: Publish Url
description: Destination URL of the outgoing stream to publish.
control_url:
type: string
title: Control Url
description: URL for subscribing via Trickle protocol for updates in the
live video-to-video generation params.
default: ''
model_id:
type: string
title: Model Id
Expand All @@ -950,10 +956,15 @@ components:
type: string
title: Publish Url
description: Destination URL of the outgoing stream to publish to
control_url:
type: string
title: Control Url
description: URL for updating the live video-to-video generation
type: object
required:
- subscribe_url
- publish_url
- control_url
title: LiveVideoToVideoResponse
description: Response model for live video-to-video generation.
MasksResponse:
Expand Down
11 changes: 11 additions & 0 deletions runner/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -941,6 +941,12 @@ components:
type: string
title: Publish Url
description: Destination URL of the outgoing stream to publish.
control_url:
type: string
title: Control Url
description: URL for subscribing via Trickle protocol for updates in the
live video-to-video generation params.
default: ''
model_id:
type: string
title: Model Id
Expand All @@ -966,10 +972,15 @@ components:
type: string
title: Publish Url
description: Destination URL of the outgoing stream to publish to
control_url:
type: string
title: Control Url
description: URL for updating the live video-to-video generation
type: object
required:
- subscribe_url
- publish_url
- control_url
title: LiveVideoToVideoResponse
description: Response model for live video-to-video generation.
MasksResponse:
Expand Down
Loading

0 comments on commit 04aaa59

Please sign in to comment.