@@ -668,24 +668,67 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
668
668
detail = f"Pipeline { form_data .model } not found" ,
669
669
)
670
670
671
- # Get the pipeline details
672
671
pipeline = app .state .PIPELINES [form_data .model ]
673
672
pipeline_id = form_data .model
674
673
675
- # Get the appropriate pipe function
676
674
if pipeline ["type" ] == "manifold" :
677
675
manifold_id , pipeline_id = pipeline_id .split ("." , 1 )
678
676
pipe = PIPELINE_MODULES [manifold_id ].pipe
679
677
else :
680
678
pipe = PIPELINE_MODULES [pipeline_id ].pipe
681
679
682
- # Check if pipe is async or sync
683
680
is_async = inspect .iscoroutinefunction (pipe )
681
+ is_async_gen = inspect .isasyncgenfunction (pipe )
682
+
683
+ # Helper function to ensure line is a string
684
def ensure_string(line):
    """Return *line* as a ``str`` so callers can use string methods on it.

    Args:
        line: A chunk yielded by a pipe. May be ``str``, ``bytes``,
            ``bytearray``, or any other object.

    Returns:
        The UTF-8 decoding of ``bytes``/``bytearray`` input; otherwise
        ``str(line)``.
    """
    if isinstance(line, (bytes, bytearray)):
        # errors="replace": a malformed (or chunk-split) multi-byte sequence
        # becomes U+FFFD instead of raising UnicodeDecodeError, which would
        # abort the response part-way through.
        return line.decode("utf-8", errors="replace")
    return str(line)
684
688
685
689
if form_data .stream :
686
690
async def stream_content ():
687
- # Handle async pipe
688
- if is_async :
691
+ if is_async_gen :
692
+ pipe_gen = pipe (
693
+ user_message = user_message ,
694
+ model_id = pipeline_id ,
695
+ messages = messages ,
696
+ body = form_data .model_dump (),
697
+ )
698
+
699
+ async for line in pipe_gen :
700
+ if isinstance (line , BaseModel ):
701
+ line = line .model_dump_json ()
702
+ line = f"data: { line } "
703
+
704
+ line = ensure_string (line )
705
+ logging .info (f"stream_content:AsyncGeneratorFunction:{ line } " )
706
+
707
+ if line .startswith ("data:" ):
708
+ yield f"{ line } \n \n "
709
+ else :
710
+ line = stream_message_template (form_data .model , line )
711
+ yield f"data: { json .dumps (line )} \n \n "
712
+
713
+ finish_message = {
714
+ "id" : f"{ form_data .model } -{ str (uuid .uuid4 ())} " ,
715
+ "object" : "chat.completion.chunk" ,
716
+ "created" : int (time .time ()),
717
+ "model" : form_data .model ,
718
+ "choices" : [
719
+ {
720
+ "index" : 0 ,
721
+ "delta" : {},
722
+ "logprobs" : None ,
723
+ "finish_reason" : "stop" ,
724
+ }
725
+ ],
726
+ }
727
+
728
+ yield f"data: { json .dumps (finish_message )} \n \n "
729
+ yield f"data: [DONE]"
730
+
731
+ elif is_async :
689
732
res = await pipe (
690
733
user_message = user_message ,
691
734
model_id = pipeline_id ,
@@ -695,24 +738,18 @@ async def stream_content():
695
738
696
739
logging .info (f"stream:true:async:{ res } " )
697
740
698
- # Handle async string response
699
741
if isinstance (res , str ):
700
742
message = stream_message_template (form_data .model , res )
701
743
logging .info (f"stream_content:str:async:{ message } " )
702
744
yield f"data: { json .dumps (message )} \n \n "
703
745
704
- # Handle async generators/iterators
705
746
elif inspect .isasyncgen (res ):
706
747
async for line in res :
707
748
if isinstance (line , BaseModel ):
708
749
line = line .model_dump_json ()
709
750
line = f"data: { line } "
710
751
711
- try :
712
- line = line .decode ("utf-8" )
713
- except :
714
- pass
715
-
752
+ line = ensure_string (line )
716
753
logging .info (f"stream_content:AsyncGenerator:{ line } " )
717
754
718
755
if line .startswith ("data:" ):
@@ -721,7 +758,6 @@ async def stream_content():
721
758
line = stream_message_template (form_data .model , line )
722
759
yield f"data: { json .dumps (line )} \n \n "
723
760
724
- # Send finish message for async responses
725
761
if isinstance (res , str ) or inspect .isasyncgen (res ):
726
762
finish_message = {
727
763
"id" : f"{ form_data .model } -{ str (uuid .uuid4 ())} " ,
@@ -741,9 +777,7 @@ async def stream_content():
741
777
yield f"data: { json .dumps (finish_message )} \n \n "
742
778
yield f"data: [DONE]"
743
779
744
- # Handle sync pipe (existing implementation)
745
780
else :
746
- # Use a threadpool for synchronous functions to avoid blocking
747
781
def sync_job ():
748
782
res = pipe (
749
783
user_message = user_message ,
@@ -767,11 +801,7 @@ def sync_job():
767
801
line = line .model_dump_json ()
768
802
line = f"data: { line } "
769
803
770
- try :
771
- line = line .decode ("utf-8" )
772
- except :
773
- pass
774
-
804
+ line = ensure_string (line )
775
805
logging .info (f"stream_content:Generator:{ line } " )
776
806
777
807
if line .startswith ("data:" ):
@@ -801,9 +831,38 @@ def sync_job():
801
831
802
832
return StreamingResponse (stream_content (), media_type = "text/event-stream" )
803
833
else :
804
- # Non-streaming response
805
- if is_async :
806
- # Handle async pipe for non-streaming case
834
+ if is_async_gen :
835
+ pipe_gen = pipe (
836
+ user_message = user_message ,
837
+ model_id = pipeline_id ,
838
+ messages = messages ,
839
+ body = form_data .model_dump (),
840
+ )
841
+
842
+ message = ""
843
+ async for stream in pipe_gen :
844
+ stream = ensure_string (stream )
845
+ message = f"{ message } { stream } "
846
+
847
+ logging .info (f"stream:false:async_gen_function:{ message } " )
848
+ return {
849
+ "id" : f"{ form_data .model } -{ str (uuid .uuid4 ())} " ,
850
+ "object" : "chat.completion" ,
851
+ "created" : int (time .time ()),
852
+ "model" : form_data .model ,
853
+ "choices" : [
854
+ {
855
+ "index" : 0 ,
856
+ "message" : {
857
+ "role" : "assistant" ,
858
+ "content" : message ,
859
+ },
860
+ "logprobs" : None ,
861
+ "finish_reason" : "stop" ,
862
+ }
863
+ ],
864
+ }
865
+ elif is_async :
807
866
res = await pipe (
808
867
user_message = user_message ,
809
868
model_id = pipeline_id ,
@@ -822,9 +881,9 @@ def sync_job():
822
881
if isinstance (res , str ):
823
882
message = res
824
883
825
- # Handle async generator
826
884
elif inspect .isasyncgen (res ):
827
885
async for stream in res :
886
+ stream = ensure_string (stream )
828
887
message = f"{ message } { stream } "
829
888
830
889
logging .info (f"stream:false:async:{ message } " )
@@ -846,7 +905,6 @@ def sync_job():
846
905
],
847
906
}
848
907
else :
849
- # Use existing implementation for sync pipes
850
908
def job ():
851
909
res = pipe (
852
910
user_message = user_message ,
@@ -868,6 +926,7 @@ def job():
868
926
869
927
if isinstance (res , Generator ):
870
928
for stream in res :
929
+ stream = ensure_string (stream )
871
930
message = f"{ message } { stream } "
872
931
873
932
logging .info (f"stream:false:sync:{ message } " )
@@ -889,4 +948,4 @@ def job():
889
948
],
890
949
}
891
950
892
- return await run_in_threadpool (job )
951
+ return await run_in_threadpool (job )
0 commit comments