runner/live: Improve pipeline status schema #362

Merged
13 commits merged on Dec 13, 2024
runner/app/live/streamer/process.py (7 changes: 4 additions & 3 deletions)
@@ -5,6 +5,7 @@
import queue
import sys
import time
from typing import Any

from PIL import Image

@@ -52,10 +53,10 @@ def stop(self):
logging.error("Failed to terminate process, killing")
self.process.kill()

for q in [self.input_queue, self.output_queue, self.param_update_queue, self.error_queue, self.log_queue]:
for q in [self.input_queue, self.output_queue, self.param_update_queue,
self.error_queue, self.log_queue]:
q.cancel_join_thread()
q.close()
self.done = None

def is_done(self):
return self.done.is_set()
@@ -158,7 +159,7 @@ def _setup_logging(self):
sys.stdout = QueueTeeStream(sys.stdout, self)
sys.stderr = QueueTeeStream(sys.stderr, self)

def _queue_put_fifo(self, _queue: mp.Queue, item: any):
def _queue_put_fifo(self, _queue: mp.Queue, item: Any):
"""Helper to put an item on a queue, dropping oldest items if needed"""
while not self.is_done():
try:
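
For context, the diff above truncates the body of _queue_put_fifo. A minimal standalone sketch of the drop-oldest pattern it implements (the helper name put_drop_oldest and the exact retry behavior here are assumptions, not the PR's code):

import queue
import multiprocessing as mp
from typing import Any

def put_drop_oldest(q: mp.Queue, item: Any) -> None:
    # Keep trying to enqueue; when the queue is full, discard the
    # oldest entry so the newest item always gets through.
    while True:
        try:
            q.put_nowait(item)
            return
        except queue.Full:
            try:
                q.get_nowait()  # drop the oldest queued item
            except queue.Empty:
                pass  # a consumer drained the queue first; retry the put
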
runner/app/live/streamer/streamer.py (159 changes: 110 additions & 49 deletions)
@@ -2,11 +2,12 @@
import logging
import os
import time
import traceback
import numpy as np
from multiprocessing.synchronize import Event
from typing import AsyncGenerator
from asyncio import Lock
import hashlib
import json

import cv2
from PIL import Image
@@ -18,39 +19,68 @@
fps_log_interval = 10
status_report_interval = 10

class PipelineStatus(BaseModel):
"""Holds metrics for the pipeline streamer"""
type: str = "status"
pipeline: str
start_time: float
class InputStatus(BaseModel):
"""Holds metrics for the input stream"""
last_input_time: float | None = None
fps: float = 0.0

def model_dump(self, **kwargs):
return _convert_timestamps(super().model_dump(**kwargs))

class InferenceStatus(BaseModel):
"""Holds metrics for the inference process"""
last_output_time: float | None = None
fps: float = 0.0

last_params_update_time: float | None = None
last_params: dict | None = None
last_params_hash: str | None = None

input_fps: float = 0.0
output_fps: float = 0.0
last_input_time: float | None = None
last_output_time: float | None = None
last_error_time: float | None = None
last_error: str | None = None

restart_count: int = 0
last_restart_time: float | None = None
last_restart_logs: list[str] | None = None # Will contain last N lines before restart
last_error: str | None = None
last_error_time: float | None = None
last_restart_logs: list[str] | None = None
restart_count: int = 0

def model_dump(self, **kwargs):
return _convert_timestamps(super().model_dump(**kwargs))

# Use a class instead of an enum since Pydantic can't handle serializing enums
class PipelineState:
OFFLINE = "OFFLINE"
ONLINE = "ONLINE"
DEGRADED_INPUT = "DEGRADED_INPUT"
DEGRADED_INFERENCE = "DEGRADED_INFERENCE"

def update_params(self, params: dict):
self.last_params = params
self.last_params_hash = str(hash(str(sorted(params.items()))))
class PipelineStatus(BaseModel):
"""Holds metrics for the pipeline streamer"""
type: str = "status"
pipeline: str
start_time: float
state: str = PipelineState.OFFLINE
last_state_update_time: float | None = None

input_status: InputStatus = InputStatus()
inference_status: InferenceStatus = InferenceStatus()

def update_params(self, params: dict, do_update_time=True):
self.inference_status.last_params = params
self.inference_status.last_params_hash = hashlib.md5(json.dumps(params, sort_keys=True).encode()).hexdigest()
if do_update_time:
self.inference_status.last_params_update_time = time.time()
return self

def model_dump(self, **kwargs):
data = super().model_dump(**kwargs)
# Convert all fields ending with _time to milliseconds
for field, value in data.items():
if field.endswith('_time'):
data[field] = _timestamp_to_ms(value)
return data
return _convert_timestamps(super().model_dump(**kwargs))


def _convert_timestamps(data: dict) -> dict:
"""Convert timestamp fields ending with _time to milliseconds"""
for field, value in data.items():
if field.endswith('_time'):
data[field] = _timestamp_to_ms(value)
return data

def _timestamp_to_ms(v: float | None) -> int | None:
return int(v * 1000) if v is not None else None
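
To make the serialization behavior above concrete, a hypothetical usage sketch (pipeline name and params are invented; only the top-level *_time fields are shown being converted):

import time

status = PipelineStatus(pipeline="example-pipeline", start_time=time.time())
status.update_params({"prompt": "a red fox", "seed": 42})

event = status.model_dump()
print(event["type"])        # "status"
print(event["start_time"])  # integer milliseconds, e.g. 1734086400123
print(status.inference_status.last_params_hash)
# md5 over json.dumps(params, sort_keys=True), so the hash is stable
# regardless of the order the params dict was built in
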
@@ -64,7 +94,7 @@ def __init__(self, protocol: StreamProtocol, pipeline: str, input_timeout: int,
self.process = None
self.input_timeout = input_timeout # 0 means disabled
self.done_future = None
self.status = PipelineStatus(pipeline=pipeline, start_time=time.time()).update_params(params)
self.status = PipelineStatus(pipeline=pipeline, start_time=time.time()).update_params(params, False)
self.control_task = None
self.report_status_task = None
self.report_status_lock = Lock()
@@ -132,23 +162,25 @@ async def _restart(self):
# don't call the full start/stop methods since we don't want to restart the protocol
await self._stop_process()
self._start_process()
self.status.restart_count += 1
self.status.last_restart_time = time.time()
self.status.last_restart_logs = restart_logs
self.status.inference_status.restart_count += 1
self.status.inference_status.last_restart_time = time.time()
self.status.inference_status.last_restart_logs = restart_logs
if last_error:
self.status.last_error = last_error
error_msg, error_time = last_error
self.status.inference_status.last_error = error_msg
self.status.inference_status.last_error_time = error_time

await self._emit_monitoring_event({
"type": "restart",
"pipeline": self.pipeline,
"restart_count": self.status.restart_count,
"restart_time": self.status.last_restart_time,
"restart_count": self.status.inference_status.restart_count,
"restart_time": self.status.inference_status.last_restart_time,
"restart_logs": restart_logs,
"last_error": last_error
})

logging.info(
f"PipelineProcess restarted. Restart count: {self.status.restart_count}"
f"PipelineProcess restarted. Restart count: {self.status.inference_status.restart_count}"
)
except Exception:
logging.error(f"Error restarting pipeline process", exc_info=True)
@@ -158,15 +190,14 @@ async def update_params(self, params: dict):
self.params = params
if self.process:
self.process.update_params(params)
self.status.last_params_update_time = time.time()
self.status.update_params(params)

await self._emit_monitoring_event({
"type": "params_update",
"pipeline": self.pipeline,
"params": params,
"params_hash": self.status.last_params_hash,
"update_time": self.status.last_params_update_time
"params_hash": self.status.inference_status.last_params_hash,
"update_time": self.status.inference_status.last_params_update_time
})

async def report_status_loop(self):
@@ -180,15 +211,45 @@ async def report_status_loop(self):
await asyncio.sleep(next_report - current_time)
next_report += status_report_interval

new_state = self._current_state()
if new_state != self.status.state:
self.status.state = new_state
self.status.last_state_update_time = current_time
logging.info(f"Pipeline state changed to {new_state}")

event = self.status.model_dump()
# Clear the large transient fields after reporting them once
self.status.last_params = None
self.status.last_restart_logs = None
self.status.inference_status.last_params = None
self.status.inference_status.last_restart_logs = None
await self._emit_monitoring_event(event)

def _current_state(self) -> str:
current_time = time.time()
input = self.status.input_status
if not input.last_input_time or current_time - input.last_input_time > 60:
return PipelineState.OFFLINE
elif current_time - input.last_input_time > 2 or input.fps < 15:
return PipelineState.DEGRADED_INPUT

inference = self.status.inference_status
pipeline_load_time = max(self.status.start_time, inference.last_params_update_time or 0)
if not inference.last_output_time and current_time - pipeline_load_time < 30:
# 30s grace period for the pipeline to start
return PipelineState.ONLINE

delayed_frames = not inference.last_output_time or current_time - inference.last_output_time > 5
low_fps = inference.fps < min(10, 0.8 * input.fps)
recent_restart = inference.last_restart_time and current_time - inference.last_restart_time < 60
recent_error = inference.last_error_time and current_time - inference.last_error_time < 15
if delayed_frames or low_fps or recent_restart or recent_error:
return PipelineState.DEGRADED_INFERENCE

return PipelineState.ONLINE
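
As a reading aid for the checks above, a hypothetical walkthrough (ages are relative to the current time; all numbers invented):

# input last seen 90s ago, or never                  -> OFFLINE
# input seen 3s ago                                   -> DEGRADED_INPUT (input stalled > 2s)
# input seen 1s ago at 10 fps                         -> DEGRADED_INPUT (fps < 15)
# input healthy, no output yet, loaded 10s ago        -> ONLINE (within the 30s grace period)
# input healthy, last output 8s ago                   -> DEGRADED_INFERENCE (output stalled > 5s)
# input at 30 fps, output at 20 fps, restart 30s ago  -> DEGRADED_INFERENCE (recent restart)
# input at 30 fps, output at 28 fps, no recent errors -> ONLINE
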

async def _emit_monitoring_event(self, event: dict):
"""Protected method to emit monitoring event with lock"""
event["timestamp"] = _timestamp_to_ms(time.time())
logging.info(f"Emitting monitoring event: {event}")
async with self.report_status_lock:
try:
await self.protocol.emit_monitoring_event(event)
Expand All @@ -202,11 +263,11 @@ async def monitor_loop(self, done: Event):
if not self.process:
return

error_info = self.process.get_last_error()
if error_info:
error_msg, error_time = error_info
self.status.last_error = error_msg
self.status.last_error_time = error_time
last_error = self.process.get_last_error()
if last_error:
error_msg, error_time = last_error
self.status.inference_status.last_error = error_msg
self.status.inference_status.last_error_time = error_time
await self._emit_monitoring_event({
"type": "error",
"pipeline": self.pipeline,
Expand All @@ -215,8 +276,8 @@ async def monitor_loop(self, done: Event):
})

current_time = time.time()
last_input_time = self.status.last_input_time or start_time
last_output_time = self.status.last_output_time or start_time
last_input_time = self.status.input_status.last_input_time or start_time
last_output_time = self.status.inference_status.last_output_time or start_time
last_params_update_time = self.status.last_params_update_time or start_time

time_since_last_input = current_time - last_input_time
@@ -280,14 +341,14 @@ async def run_ingress_loop(self, done: Event):

logging.debug(f"Sending input frame. Scaled from {width}x{height} to {frame.size[0]}x{frame.size[1]}")
self.process.send_input(frame)
self.status.last_input_time = time.time() # Track time after send completes
self.status.input_status.last_input_time = time.time() # Track time after send completes

# Increment frame count and measure FPS
frame_count += 1
elapsed_time = time.time() - start_time
if elapsed_time >= fps_log_interval:
self.status.input_fps = frame_count / elapsed_time
logging.info(f"Input FPS: {self.status.input_fps:.2f}")
self.status.input_status.fps = frame_count / elapsed_time
logging.info(f"Input FPS: {self.status.input_status.fps:.2f}")
frame_count = 0
start_time = time.time()
# automatically stop the streamer when the ingress ends cleanly
@@ -308,7 +369,7 @@ async def gen_output_frames() -> AsyncGenerator[Image.Image, None]:
if not output_image:
break

self.status.last_output_time = time.time() # Track time after receive completes
self.status.inference_status.last_output_time = time.time() # Track time after receive completes
logging.debug(
f"Output image received out_width: {output_image.width}, out_height: {output_image.height}"
)
@@ -319,8 +380,8 @@
frame_count += 1
elapsed_time = time.time() - start_time
if elapsed_time >= fps_log_interval:
self.status.output_fps = frame_count / elapsed_time
logging.info(f"Output FPS: {self.status.output_fps:.2f}")
self.status.inference_status.fps = frame_count / elapsed_time
logging.info(f"Output FPS: {self.status.inference_status.fps:.2f}")
frame_count = 0
start_time = time.time()
