 import typing
 import contextlib

-from threading import Lock
+from anyio import Lock
 from functools import partial
 from typing import Iterator, List, Optional, Union, Dict

@@ -70,14 +70,14 @@ def set_llama_proxy(model_settings: List[ModelSettings]):
     _llama_proxy = LlamaProxy(models=model_settings)


-def get_llama_proxy():
+async def get_llama_proxy():
     # NOTE: This double lock allows the currently streaming llama model to
     # check if any other requests are pending in the same thread and cancel
     # the stream if so.
-    llama_outer_lock.acquire()
+    await llama_outer_lock.acquire()
     release_outer_lock = True
     try:
-        llama_inner_lock.acquire()
+        await llama_inner_lock.acquire()
         try:
             llama_outer_lock.release()
             release_outer_lock = False
@@ -159,7 +159,7 @@ async def get_event_publisher(
     request: Request,
     inner_send_chan: MemoryObjectSendStream[typing.Any],
     iterator: Iterator[typing.Any],
-    on_complete: typing.Optional[typing.Callable[[], None]] = None,
+    on_complete: typing.Optional[typing.Callable[[], typing.Awaitable[None]]] = None,
 ):
     server_settings = next(get_server_settings())
     interrupt_requests = (
@@ -182,7 +182,7 @@ async def get_event_publisher(
                 raise e
         finally:
             if on_complete:
-                on_complete()
+                await on_complete()


 def _logit_bias_tokens_to_input_ids(
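The signature change means on_complete is now expected to be a zero-argument coroutine function, which is exactly what AsyncExitStack.aclose provides later in the diff. A small self-contained sketch of the call pattern; publish and main are hypothetical names, not part of the patch.

import typing
import contextlib
import anyio

async def publish(
    on_complete: typing.Optional[typing.Callable[[], typing.Awaitable[None]]] = None,
) -> None:
    try:
        pass  # stream events to the client here
    finally:
        # The callback returns an awaitable (e.g. AsyncExitStack.aclose()),
        # so it must be awaited rather than just called.
        if on_complete:
            await on_complete()

async def main() -> None:
    stack = contextlib.AsyncExitStack()
    # aclose is a zero-argument coroutine function, so it satisfies
    # Callable[[], Awaitable[None]].
    await publish(on_complete=stack.aclose)

anyio.run(main)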
@@ -267,10 +267,8 @@ async def create_completion(
     request: Request,
     body: CreateCompletionRequest,
 ) -> llama_cpp.Completion:
-    exit_stack = contextlib.ExitStack()
-    llama_proxy = await run_in_threadpool(
-        lambda: exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)())
-    )
+    exit_stack = contextlib.AsyncExitStack()
+    llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
     if llama_proxy is None:
         raise HTTPException(
             status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
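Switching from ExitStack plus run_in_threadpool to AsyncExitStack lets the endpoint enter the async generator directly on the event loop and keep it open until aclose() is awaited. A minimal sketch of the mechanism, using a hypothetical get_resource generator in place of get_llama_proxy:

import contextlib

async def get_resource():
    # Stand-in for get_llama_proxy: acquire async resources before the yield...
    yield "proxy"
    # ...and release them once the owning AsyncExitStack is closed.

async def handler() -> str:
    exit_stack = contextlib.AsyncExitStack()
    # enter_async_context() runs the generator up to its yield and keeps the
    # context manager registered until exit_stack.aclose() is awaited.
    resource = await exit_stack.enter_async_context(
        contextlib.asynccontextmanager(get_resource)()
    )
    try:
        return f"using {resource}"
    finally:
        # The streaming branch in the diff defers this step by handing
        # exit_stack.aclose to get_event_publisher as on_complete instead.
        await exit_stack.aclose()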
@@ -332,7 +330,6 @@ async def create_completion(
         def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
             yield first_response
             yield from iterator_or_completion
-            exit_stack.close()

         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
@@ -342,13 +339,13 @@ def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
                 request=request,
                 inner_send_chan=send_chan,
                 iterator=iterator(),
-                on_complete=exit_stack.close,
+                on_complete=exit_stack.aclose,
             ),
             sep="\n",
             ping_message_factory=_ping_message_factory,
         )
     else:
-        exit_stack.close()
+        await exit_stack.aclose()
         return iterator_or_completion


@@ -477,10 +474,8 @@ async def create_chat_completion(
     # where the dependency is cleaned up before a StreamingResponse
     # is complete.
     # https://github.com/tiangolo/fastapi/issues/11143
-    exit_stack = contextlib.ExitStack()
-    llama_proxy = await run_in_threadpool(
-        lambda: exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)())
-    )
+    exit_stack = contextlib.AsyncExitStack()
+    llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
     if llama_proxy is None:
         raise HTTPException(
             status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
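The linked FastAPI issue is why the proxy is entered on a manually managed exit stack instead of a Depends dependency: a yield dependency can be cleaned up as soon as the endpoint returns, before a StreamingResponse or EventSourceResponse has finished sending its body. A hedged illustration of the problematic shape, with hypothetical names throughout:

from fastapi import Depends, FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

class Model:
    def generate(self):
        yield b"chunk"

async def get_model():
    model = Model()  # hypothetical setup work
    try:
        yield model
    finally:
        # With newer FastAPI versions this cleanup may run as soon as the
        # endpoint returns, before the response body below is consumed.
        pass

@app.get("/stream")
async def stream(model: Model = Depends(get_model)):
    def body():
        # The dependency's finally block may already have run by the time the
        # server iterates this generator; hence the manual AsyncExitStack.
        yield from model.generate()
    return StreamingResponse(body())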
@@ -530,7 +525,6 @@ async def create_chat_completion(
         def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
             yield first_response
             yield from iterator_or_completion
-            exit_stack.close()

         send_chan, recv_chan = anyio.create_memory_object_stream(10)
         return EventSourceResponse(
@@ -540,13 +534,13 @@ def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
                 request=request,
                 inner_send_chan=send_chan,
                 iterator=iterator(),
-                on_complete=exit_stack.close,
+                on_complete=exit_stack.aclose,
             ),
             sep="\n",
             ping_message_factory=_ping_message_factory,
         )
     else:
-        exit_stack.close()
+        await exit_stack.aclose()
         return iterator_or_completion

