77import os
88import traceback
99from typing import List
10+ import logging
11+
12+ from streamer import PipelineStreamer
13+ from trickle import TrickleSubscriber
1014
1115# loads neighbouring modules with absolute paths
1216infer_root = os .path .abspath (os .path .dirname (__file__ ))
1721from streamer .zeromq import ZeroMQStreamer
1822
1923
20- async def main (http_port : int , stream_protocol : str , subscribe_url : str , publish_url : str , pipeline : str , params : dict ):
24+ async def main (http_port : int , stream_protocol : str , subscribe_url : str , publish_url : str , control_url : str , pipeline : str , params : dict ):
2125 if stream_protocol == "trickle" :
2226 handler = TrickleStreamer (subscribe_url , publish_url , pipeline , ** (params or {}))
2327 elif stream_protocol == "zeromq" :
@@ -29,6 +33,7 @@ async def main(http_port: int, stream_protocol: str, subscribe_url: str, publish
2933 try :
3034 handler .start ()
3135 runner = await start_http_server (handler , http_port )
36+ asyncio .create_task (start_control_subscriber (handler , control_url ))
3237 except Exception as e :
3338 logging .error (f"Error starting socket handler or HTTP server: { e } " )
3439 logging .error (f"Stack trace:\n { traceback .format_exc ()} " )
@@ -56,6 +61,19 @@ def signal_handler(sig, _):
5661 signal .signal (sig , signal_handler )
5762 return await future
5863
64+ async def start_control_subscriber (handler : PipelineStreamer , control_url : str ):
65+ if control_url is None or control_url .strip () == "" :
66+ logging .warning ("No control-url provided, inference won't get updates from the control trickle subscription" )
67+ return
68+ logging .info ("Starting Control subscriber at %s" , control_url )
69+ subscriber = TrickleSubscriber (url = control_url )
70+ while True :
71+ segment = await subscriber .next ()
72+ if segment .eos ():
73+ return
74+ params = await segment .read ()
75+ logging .info ("Received control message, updating model with params: %s" , params )
76+ handler .update_params (** json .loads (params ))
5977
6078if __name__ == "__main__" :
6179 parser = argparse .ArgumentParser (description = "Infer process to run the AI pipeline" )
@@ -81,6 +99,9 @@ def signal_handler(sig, _):
8199 parser .add_argument (
82100 "--publish-url" , type = str , required = True , help = "URL to publish output frames (trickle). For zeromq this is the output socket address"
83101 )
102+ parser .add_argument (
103+ "--control-url" , type = str , help = "URL to subscribe for Control API JSON messages"
104+ )
84105 parser .add_argument (
85106 "-v" , "--verbose" ,
86107 action = "store_true" ,
@@ -103,7 +124,7 @@ def signal_handler(sig, _):
103124
104125 try :
105126 asyncio .run (
106- main (args .http_port , args .stream_protocol , args .subscribe_url , args .publish_url , args .pipeline , params )
127+ main (args .http_port , args .stream_protocol , args .subscribe_url , args .publish_url , args .control_url , args . pipeline , params )
107128 )
108129 except Exception as e :
109130 logging .error (f"Fatal error in main: { e } " )
0 commit comments