Skip to content

Commit a77e7d1

Browse files
committed
Enhance async handling in generate_openai_chat_completion to support async generators and improve response streaming
1 parent 560b549 commit a77e7d1

File tree

1 file changed

+85
-26
lines changed

1 file changed

+85
-26
lines changed

main.py

Lines changed: 85 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -668,24 +668,67 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
668668
detail=f"Pipeline {form_data.model} not found",
669669
)
670670

671-
# Get the pipeline details
672671
pipeline = app.state.PIPELINES[form_data.model]
673672
pipeline_id = form_data.model
674673

675-
# Get the appropriate pipe function
676674
if pipeline["type"] == "manifold":
677675
manifold_id, pipeline_id = pipeline_id.split(".", 1)
678676
pipe = PIPELINE_MODULES[manifold_id].pipe
679677
else:
680678
pipe = PIPELINE_MODULES[pipeline_id].pipe
681679

682-
# Check if pipe is async or sync
683680
is_async = inspect.iscoroutinefunction(pipe)
681+
is_async_gen = inspect.isasyncgenfunction(pipe)
682+
683+
# Helper function to ensure line is a string
684+
def ensure_string(line):
    """Coerce one chunk produced by a pipe into ``str``.

    Args:
        line: The chunk to convert. May be ``bytes``, ``bytearray``,
            ``str``, or any other object.

    Returns:
        The chunk as text. Byte chunks are decoded as UTF-8, with
        undecodable bytes replaced by U+FFFD; every other value is
        passed through ``str()``.
    """
    if isinstance(line, (bytes, bytearray)):
        # Decode best-effort: a single malformed byte must not raise and
        # abort a response that is already being streamed to the client.
        return line.decode("utf-8", errors="replace")
    return str(line)
684688

685689
if form_data.stream:
686690
async def stream_content():
687-
# Handle async pipe
688-
if is_async:
691+
if is_async_gen:
692+
pipe_gen = pipe(
693+
user_message=user_message,
694+
model_id=pipeline_id,
695+
messages=messages,
696+
body=form_data.model_dump(),
697+
)
698+
699+
async for line in pipe_gen:
700+
if isinstance(line, BaseModel):
701+
line = line.model_dump_json()
702+
line = f"data: {line}"
703+
704+
line = ensure_string(line)
705+
logging.info(f"stream_content:AsyncGeneratorFunction:{line}")
706+
707+
if line.startswith("data:"):
708+
yield f"{line}\n\n"
709+
else:
710+
line = stream_message_template(form_data.model, line)
711+
yield f"data: {json.dumps(line)}\n\n"
712+
713+
finish_message = {
714+
"id": f"{form_data.model}-{str(uuid.uuid4())}",
715+
"object": "chat.completion.chunk",
716+
"created": int(time.time()),
717+
"model": form_data.model,
718+
"choices": [
719+
{
720+
"index": 0,
721+
"delta": {},
722+
"logprobs": None,
723+
"finish_reason": "stop",
724+
}
725+
],
726+
}
727+
728+
yield f"data: {json.dumps(finish_message)}\n\n"
729+
yield f"data: [DONE]"
730+
731+
elif is_async:
689732
res = await pipe(
690733
user_message=user_message,
691734
model_id=pipeline_id,
@@ -695,24 +738,18 @@ async def stream_content():
695738

696739
logging.info(f"stream:true:async:{res}")
697740

698-
# Handle async string response
699741
if isinstance(res, str):
700742
message = stream_message_template(form_data.model, res)
701743
logging.info(f"stream_content:str:async:{message}")
702744
yield f"data: {json.dumps(message)}\n\n"
703745

704-
# Handle async generators/iterators
705746
elif inspect.isasyncgen(res):
706747
async for line in res:
707748
if isinstance(line, BaseModel):
708749
line = line.model_dump_json()
709750
line = f"data: {line}"
710751

711-
try:
712-
line = line.decode("utf-8")
713-
except:
714-
pass
715-
752+
line = ensure_string(line)
716753
logging.info(f"stream_content:AsyncGenerator:{line}")
717754

718755
if line.startswith("data:"):
@@ -721,7 +758,6 @@ async def stream_content():
721758
line = stream_message_template(form_data.model, line)
722759
yield f"data: {json.dumps(line)}\n\n"
723760

724-
# Send finish message for async responses
725761
if isinstance(res, str) or inspect.isasyncgen(res):
726762
finish_message = {
727763
"id": f"{form_data.model}-{str(uuid.uuid4())}",
@@ -741,9 +777,7 @@ async def stream_content():
741777
yield f"data: {json.dumps(finish_message)}\n\n"
742778
yield f"data: [DONE]"
743779

744-
# Handle sync pipe (existing implementation)
745780
else:
746-
# Use a threadpool for synchronous functions to avoid blocking
747781
def sync_job():
748782
res = pipe(
749783
user_message=user_message,
@@ -767,11 +801,7 @@ def sync_job():
767801
line = line.model_dump_json()
768802
line = f"data: {line}"
769803

770-
try:
771-
line = line.decode("utf-8")
772-
except:
773-
pass
774-
804+
line = ensure_string(line)
775805
logging.info(f"stream_content:Generator:{line}")
776806

777807
if line.startswith("data:"):
@@ -801,9 +831,38 @@ def sync_job():
801831

802832
return StreamingResponse(stream_content(), media_type="text/event-stream")
803833
else:
804-
# Non-streaming response
805-
if is_async:
806-
# Handle async pipe for non-streaming case
834+
if is_async_gen:
835+
pipe_gen = pipe(
836+
user_message=user_message,
837+
model_id=pipeline_id,
838+
messages=messages,
839+
body=form_data.model_dump(),
840+
)
841+
842+
message = ""
843+
async for stream in pipe_gen:
844+
stream = ensure_string(stream)
845+
message = f"{message}{stream}"
846+
847+
logging.info(f"stream:false:async_gen_function:{message}")
848+
return {
849+
"id": f"{form_data.model}-{str(uuid.uuid4())}",
850+
"object": "chat.completion",
851+
"created": int(time.time()),
852+
"model": form_data.model,
853+
"choices": [
854+
{
855+
"index": 0,
856+
"message": {
857+
"role": "assistant",
858+
"content": message,
859+
},
860+
"logprobs": None,
861+
"finish_reason": "stop",
862+
}
863+
],
864+
}
865+
elif is_async:
807866
res = await pipe(
808867
user_message=user_message,
809868
model_id=pipeline_id,
@@ -822,9 +881,9 @@ def sync_job():
822881
if isinstance(res, str):
823882
message = res
824883

825-
# Handle async generator
826884
elif inspect.isasyncgen(res):
827885
async for stream in res:
886+
stream = ensure_string(stream)
828887
message = f"{message}{stream}"
829888

830889
logging.info(f"stream:false:async:{message}")
@@ -846,7 +905,6 @@ def sync_job():
846905
],
847906
}
848907
else:
849-
# Use existing implementation for sync pipes
850908
def job():
851909
res = pipe(
852910
user_message=user_message,
@@ -868,6 +926,7 @@ def job():
868926

869927
if isinstance(res, Generator):
870928
for stream in res:
929+
stream = ensure_string(stream)
871930
message = f"{message}{stream}"
872931

873932
logging.info(f"stream:false:sync:{message}")
@@ -889,4 +948,4 @@ def job():
889948
],
890949
}
891950

892-
return await run_in_threadpool(job)
951+
return await run_in_threadpool(job)

0 commit comments

Comments
 (0)