diff --git a/backend/app/api/routes/teams.py b/backend/app/api/routes/teams.py
index 332fa9a..d2d9e66 100644
--- a/backend/app/api/routes/teams.py
+++ b/backend/app/api/routes/teams.py
@@ -224,6 +224,7 @@ async def public_stream(
     team_chat: TeamChatPublic,
     thread_id: str,
     team: CurrentTeam,
+    streaming: bool = True,
 ) -> StreamingResponse:
     """
     Stream a response from a team using a given message or an interrupt decision. Requires an API key for authentication.
@@ -233,6 +234,7 @@ async def public_stream(
     Parameters:
     - `team_id` (int): The ID of the team to which the message is being sent. Must be a valid team ID.
     - `thread_id` (str): The ID of the thread where the message will be posted. If the thread ID does not exist, a new thread will be created.
+    - `streaming` (bool, optional): A flag to enable or disable streaming mode. If `True` (default), the messages will be streamed in chunks.
 
     Request Body (JSON):
     - The request body should be a JSON object containing either the `message` or `interrupt` field:
@@ -277,6 +279,6 @@ async def public_stream(
     messages = [team_chat.message] if team_chat.message else []
 
     return StreamingResponse(
-        generator(team, members, messages, thread_id, team_chat.interrupt),
+        generator(team, members, messages, thread_id, team_chat.interrupt, streaming),
         media_type="text/event-stream",
     )
diff --git a/backend/app/core/graph/build.py b/backend/app/core/graph/build.py
index d50bf4a..808dd01 100644
--- a/backend/app/core/graph/build.py
+++ b/backend/app/core/graph/build.py
@@ -502,6 +502,7 @@ async def generator(
     messages: list[ChatMessage],
     thread_id: str,
     interrupt: Interrupt | None = None,
+    streaming: bool = True,
 ) -> AsyncGenerator[Any, Any]:
     """Create the graph and stream responses as JSON."""
     formatted_messages = [
@@ -602,7 +603,7 @@ async def generator(
         ]
     }
     async for event in root.astream_events(state, version="v2", config=config):
-        response = event_to_response(event)
+        response = event_to_response(event, streaming)
         if response:
             formatted_output = f"data: {response.model_dump_json()}\n\n"
             yield formatted_output
diff --git a/backend/app/core/graph/messages.py b/backend/app/core/graph/messages.py
index 7ae3dd8..f933d45 100644
--- a/backend/app/core/graph/messages.py
+++ b/backend/app/core/graph/messages.py
@@ -38,24 +38,30 @@ def get_message_type(message: Any) -> str | None:
     return None
 
 
-def event_to_response(event: StreamEvent) -> ChatResponse | None:
+def event_to_response(event: StreamEvent, streaming: bool) -> ChatResponse | None:
     """Convert event to ChatResponse"""
     kind = event["event"]
     id = event["run_id"]
-    if kind == "on_chat_model_stream":
+    # Either listen to stream or end based on streaming arg
+    chat_model_event_kind = "on_chat_model_stream" if streaming else "on_chat_model_end"
+    if kind == chat_model_event_kind:
         name = event["metadata"]["langgraph_node"]
-        message_chunk: AIMessageChunk = event["data"]["chunk"]
-        type = get_message_type(message_chunk)
+        chat_message: AIMessage | AIMessageChunk = (
+            event["data"]["chunk"]
+            if kind == "on_chat_model_stream"
+            else event["data"]["output"]
+        )
+        type = get_message_type(chat_message)
         content: str = ""
-        if isinstance(message_chunk.content, list):
-            for c in message_chunk.content:
+        if isinstance(chat_message.content, list):
+            for c in chat_message.content:
                 if isinstance(c, str):
                     content += c
                 elif isinstance(c, dict):
                     content += c.get("text", "")
         else:
-            content = message_chunk.content
-        tool_calls = message_chunk.tool_calls
+            content = chat_message.content
+        tool_calls = chat_message.tool_calls
         if content and type:
             return ChatResponse(
                 type=type, id=id, name=name, content=content, tool_calls=tool_calls
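
For reviewers, below is a minimal client-side sketch of exercising the new `streaming` flag end to end. The route prefix, auth header name, and request-body shape are illustrative assumptions and are not taken from this patch; only the `streaming` query parameter and the `data: ...` SSE framing come from the code above.

```python
# Hypothetical client sketch for the new `streaming` query parameter.
# Assumptions (not shown in this diff): the URL prefix, the `x-api-key`
# header name, and the request-body schema are placeholders.
import httpx

BASE_URL = "http://localhost:8000/api/v1"  # assumed prefix
API_KEY = "sk-..."  # the endpoint requires an API key


def stream_chat(team_id: int, thread_id: str, message: str, streaming: bool) -> None:
    # With streaming=True the server emits one SSE event per model chunk
    # (on_chat_model_stream); with streaming=False, one event per completed
    # model call (on_chat_model_end).
    with httpx.stream(
        "POST",
        f"{BASE_URL}/teams/{team_id}/stream-public/{thread_id}",  # assumed path
        params={"streaming": streaming},
        headers={"x-api-key": API_KEY},  # assumed header name
        json={"message": {"type": "human", "content": message}},  # assumed schema
        timeout=None,
    ) as response:
        for line in response.iter_lines():
            # The generator frames each ChatResponse as an SSE data line.
            if line.startswith("data: "):
                print(line.removeprefix("data: "))
```

Because `streaming: bool = True` is declared as a plain typed parameter on the endpoint, FastAPI exposes it as an optional query parameter, so existing clients that omit it keep the current chunked behavior; passing `streaming=false` switches `event_to_response` to match `on_chat_model_end` events and emit one message per completed model call.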