Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 52 additions & 21 deletions server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,25 +51,31 @@ class VideoStreamTrack(MediaStreamTrack):

kind = "video"

def __init__(self, track: MediaStreamTrack, pipeline: Pipeline):
def __init__(self, track: MediaStreamTrack | None, pipeline: Pipeline):
"""Initialize the VideoStreamTrack.

Args:
track: The underlying media stream track.
track: The underlying media stream track (None if generative).
pipeline: The processing pipeline to apply to each video frame.
"""
super().__init__()
self.track = track
self.pipeline = pipeline
self.fps_meter = FPSMeter(metrics_manager=app["metrics_manager"], track_id=track.id)

track_id = track.id if track else self.id
self.fps_meter = FPSMeter(metrics_manager=app["metrics_manager"], track_id=track_id)
self.running = True
self.collect_task = asyncio.create_task(self.collect_frames())

# Add cleanup when track ends
@track.on("ended")
async def on_ended():
logger.info("Source video track ended, stopping collection")
await cancel_collect_frames(self)

if track:
self.collect_task = asyncio.create_task(self.collect_frames())

# Add cleanup when track ends
@track.on("ended")
async def on_ended():
logger.info("Source video track ended, stopping collection")
await cancel_collect_frames(self)
else:
self.collect_task = None

async def collect_frames(self):
"""Collect video frames from the underlying track and pass them to
Expand Down Expand Up @@ -153,19 +159,25 @@ async def recv(self):
class AudioStreamTrack(MediaStreamTrack):
kind = "audio"

def __init__(self, track: MediaStreamTrack, pipeline):
def __init__(self, track: MediaStreamTrack | None, pipeline):
super().__init__()
self.track = track
self.pipeline = pipeline
self.running = True
logger.info(f"AudioStreamTrack created for track {track.id}")
self.collect_task = asyncio.create_task(self.collect_frames())

# Add cleanup when track ends
@track.on("ended")
async def on_ended():
logger.info("Source audio track ended, stopping collection")
await cancel_collect_frames(self)

track_id = track.id if track else self.id
logger.info(f"AudioStreamTrack created for track {track_id}")

if track:
self.collect_task = asyncio.create_task(self.collect_frames())

# Add cleanup when track ends
@track.on("ended")
async def on_ended():
logger.info("Source audio track ended, stopping collection")
await cancel_collect_frames(self)
else:
self.collect_task = None

async def collect_frames(self):
"""Collect audio frames from the underlying track and pass them to
Expand Down Expand Up @@ -285,7 +297,18 @@ async def offer(request):
# Add transceivers for both audio and video if present in the offer
if "m=video" in offer.sdp:
logger.debug("[Offer] Adding video transceiver")
video_transceiver = pc.addTransceiver("video", direction="sendrecv")

track_or_kind = "video"
if not is_noop_mode and not pipeline.accepts_video_input() and pipeline.produces_video_output():
logger.info("[Offer] Creating Generative Video Track")
gen_track = VideoStreamTrack(None, pipeline)
tracks["video"] = gen_track
track_or_kind = gen_track

# Store video track in app for stats
request.app["video_tracks"][gen_track.id] = gen_track

video_transceiver = pc.addTransceiver(track_or_kind, direction="sendrecv")
caps = RTCRtpSender.getCapabilities("video")
prefs = list(filter(lambda x: x.name == "H264", caps.codecs))
video_transceiver.setCodecPreferences(prefs)
Expand All @@ -296,7 +319,15 @@ async def offer(request):

if "m=audio" in offer.sdp:
logger.debug("[Offer] Adding audio transceiver")
audio_transceiver = pc.addTransceiver("audio", direction="sendrecv")

track_or_kind = "audio"
if not is_noop_mode and not pipeline.accepts_audio_input() and pipeline.produces_audio_output():
logger.info("[Offer] Creating Generative Audio Track")
gen_track = AudioStreamTrack(None, pipeline)
tracks["audio"] = gen_track
track_or_kind = gen_track

audio_transceiver = pc.addTransceiver(track_or_kind, direction="sendrecv")
audio_caps = RTCRtpSender.getCapabilities("audio")
# Prefer Opus for audio
audio_prefs = [codec for codec in audio_caps.codecs if codec.name == "opus"]
Expand Down
2 changes: 1 addition & 1 deletion src/comfystream/modalities.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class WorkflowModality(TypedDict):
"audio_input": {"LoadAudioTensor"},
"audio_output": {"SaveAudioTensor"},
# Text nodes
"text_input": set(), # No text input nodes currently
"text_input": {"PrimitiveString"}, # Basic text input node
"text_output": {"SaveTextTensor"},
}

Expand Down
14 changes: 14 additions & 0 deletions src/comfystream/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Dict, List, Optional, Set, Union

import av
from fractions import Fraction
import numpy as np
import torch

Expand Down Expand Up @@ -76,6 +77,7 @@ def __init__(
self._warmup_task: Optional[asyncio.Task] = None
self._warmup_completed = False
self._last_warmup_resolution: Optional[tuple[int, int]] = None
self._generated_pts = 0

@property
def state(self) -> PipelineState:
Expand Down Expand Up @@ -682,6 +684,18 @@ async def get_processed_video_frame(self) -> av.VideoFrame:
Returns:
The processed video frame, or original frame if no processing needed
"""
# Handle generative video case (no input, but produces output)
if not self.accepts_video_input() and self.produces_video_output():
async with temporary_log_level("comfy", self._comfyui_inference_log_level):
out_tensor = await self.client.get_video_output()

processed_frame = self.video_postprocess(out_tensor)
processed_frame.pts = self._generated_pts
processed_frame.time_base = Fraction(1, 30)
self._generated_pts += 1

return processed_frame

frame = await self.video_incoming_frames.get()

# Skip frames that were marked as skipped
Expand Down