Skip to content

Commit 04aaa59

Browse files
authored
Add Control API (#279)
1 parent 27c969f commit 04aaa59

File tree

8 files changed

+137
-76
lines changed

8 files changed

+137
-76
lines changed

runner/app/live/infer.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
import os
88
import traceback
99
from typing import List
10+
import logging
11+
12+
from streamer import PipelineStreamer
13+
from trickle import TrickleSubscriber
1014

1115
# loads neighbouring modules with absolute paths
1216
infer_root = os.path.abspath(os.path.dirname(__file__))
@@ -17,7 +21,7 @@
1721
from streamer.zeromq import ZeroMQStreamer
1822

1923

20-
async def main(http_port: int, stream_protocol: str, subscribe_url: str, publish_url: str, pipeline: str, params: dict):
24+
async def main(http_port: int, stream_protocol: str, subscribe_url: str, publish_url: str, control_url: str, pipeline: str, params: dict):
2125
if stream_protocol == "trickle":
2226
handler = TrickleStreamer(subscribe_url, publish_url, pipeline, **(params or {}))
2327
elif stream_protocol == "zeromq":
@@ -29,6 +33,7 @@ async def main(http_port: int, stream_protocol: str, subscribe_url: str, publish
2933
try:
3034
handler.start()
3135
runner = await start_http_server(handler, http_port)
36+
asyncio.create_task(start_control_subscriber(handler, control_url))
3237
except Exception as e:
3338
logging.error(f"Error starting socket handler or HTTP server: {e}")
3439
logging.error(f"Stack trace:\n{traceback.format_exc()}")
@@ -56,6 +61,19 @@ def signal_handler(sig, _):
5661
signal.signal(sig, signal_handler)
5762
return await future
5863

64+
async def start_control_subscriber(handler: PipelineStreamer, control_url: str):
65+
if control_url is None or control_url.strip() == "":
66+
logging.warning("No control-url provided, inference won't get updates from the control trickle subscription")
67+
return
68+
logging.info("Starting Control subscriber at %s", control_url)
69+
subscriber = TrickleSubscriber(url=control_url)
70+
while True:
71+
segment = await subscriber.next()
72+
if segment.eos():
73+
return
74+
params = await segment.read()
75+
logging.info("Received control message, updating model with params: %s", params)
76+
handler.update_params(**json.loads(params))
5977

6078
if __name__ == "__main__":
6179
parser = argparse.ArgumentParser(description="Infer process to run the AI pipeline")
@@ -81,6 +99,9 @@ def signal_handler(sig, _):
8199
parser.add_argument(
82100
"--publish-url", type=str, required=True, help="URL to publish output frames (trickle). For zeromq this is the output socket address"
83101
)
102+
parser.add_argument(
103+
"--control-url", type=str, help="URL to subscribe for Control API JSON messages"
104+
)
84105
parser.add_argument(
85106
"-v", "--verbose",
86107
action="store_true",
@@ -103,7 +124,7 @@ def signal_handler(sig, _):
103124

104125
try:
105126
asyncio.run(
106-
main(args.http_port, args.stream_protocol, args.subscribe_url, args.publish_url, args.pipeline, params)
127+
main(args.http_port, args.stream_protocol, args.subscribe_url, args.publish_url, args.control_url, args.pipeline, params)
107128
)
108129
except Exception as e:
109130
logging.error(f"Fatal error in main: {e}")

runner/app/live/trickle/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .media import run_subscribe, run_publish
2+
from .trickle_subscriber import TrickleSubscriber
23

3-
__all__ = ["run_subscribe", "run_publish"]
4+
__all__ = ["run_subscribe", "run_publish", "TrickleSubscriber"]

runner/app/pipelines/live_video_to_video.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,17 @@ def __call__(
3131
):
3232
try:
3333
if not self.process:
34+
logger.info(f"Starting stream, subscribe={kwargs['subscribe_url']} publish={kwargs['publish_url']}, control={kwargs['control_url']}")
3435
self.start_process(
3536
pipeline=self.model_id, # we use the model_id as the pipeline name for now
3637
http_port=8888,
3738
subscribe_url=kwargs["subscribe_url"],
3839
publish_url=kwargs["publish_url"],
40+
control_url=kwargs["control_url"],
3941
initial_params=json.dumps(kwargs["params"]),
4042
# TODO: set torch device from self.torch_device
4143
)
42-
logger.info(f"Starting stream, subscribe={kwargs['subscribe_url']} publish={kwargs['publish_url']}")
44+
logger.info(f"Starting stream, subscribe={kwargs['subscribe_url']} publish={kwargs['publish_url']}, control={kwargs['control_url']}")
4345
return
4446
except Exception as e:
4547
raise InferenceError(original_exception=e)

runner/app/routes/live_video_to_video.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ class LiveVideoToVideoParams(BaseModel):
4545
description="Destination URL of the outgoing stream to publish.",
4646
),
4747
]
48+
control_url: Annotated[
49+
str,
50+
Field(
51+
default="",description="URL for subscribing via Trickle protocol for updates in the live video-to-video generation params.",
52+
),
53+
]
4854
model_id: Annotated[
4955
str,
5056
Field(
@@ -129,5 +135,5 @@ async def live_video_to_video(
129135
)
130136

131137
# outputs unused for now; the orchestrator is setting these
132-
return {'publish_url':"", 'subscribe_url': ""}
138+
return {'publish_url':"", 'subscribe_url': "", 'control_url': ""}
133139

runner/app/routes/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,9 @@ class LiveVideoToVideoResponse(BaseModel):
9292
publish_url: str = Field(
9393
..., description="Destination URL of the outgoing stream to publish to"
9494
)
95-
95+
control_url: str = Field(
96+
..., description="URL for updating the live video-to-video generation"
97+
)
9698

9799
class APIError(BaseModel):
98100
"""API error response model."""

runner/gateway.openapi.yaml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -924,6 +924,12 @@ components:
924924
type: string
925925
title: Publish Url
926926
description: Destination URL of the outgoing stream to publish.
927+
control_url:
928+
type: string
929+
title: Control Url
930+
description: URL for subscribing via Trickle protocol for updates in the
931+
live video-to-video generation params.
932+
default: ''
927933
model_id:
928934
type: string
929935
title: Model Id
@@ -950,10 +956,15 @@ components:
950956
type: string
951957
title: Publish Url
952958
description: Destination URL of the outgoing stream to publish to
959+
control_url:
960+
type: string
961+
title: Control Url
962+
description: URL for updating the live video-to-video generation
953963
type: object
954964
required:
955965
- subscribe_url
956966
- publish_url
967+
- control_url
957968
title: LiveVideoToVideoResponse
958969
description: Response model for live video-to-video generation.
959970
MasksResponse:

runner/openapi.yaml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -941,6 +941,12 @@ components:
941941
type: string
942942
title: Publish Url
943943
description: Destination URL of the outgoing stream to publish.
944+
control_url:
945+
type: string
946+
title: Control Url
947+
description: URL for subscribing via Trickle protocol for updates in the
948+
live video-to-video generation params.
949+
default: ''
944950
model_id:
945951
type: string
946952
title: Model Id
@@ -966,10 +972,15 @@ components:
966972
type: string
967973
title: Publish Url
968974
description: Destination URL of the outgoing stream to publish to
975+
control_url:
976+
type: string
977+
title: Control Url
978+
description: URL for updating the live video-to-video generation
969979
type: object
970980
required:
971981
- subscribe_url
972982
- publish_url
983+
- control_url
973984
title: LiveVideoToVideoResponse
974985
description: Response model for live video-to-video generation.
975986
MasksResponse:

0 commit comments

Comments
 (0)