fix: Avoid thread starvation on many concurrent requests by making use of asyncio to lock llama_proxy context #1798

Merged
34 changes: 14 additions & 20 deletions llama_cpp/server/app.py
@@ -5,7 +5,7 @@
import typing
import contextlib

- from threading import Lock
+ from anyio import Lock
from functools import partial
from typing import Iterator, List, Optional, Union, Dict

@@ -70,14 +70,14 @@ def set_llama_proxy(model_settings: List[ModelSettings]):
_llama_proxy = LlamaProxy(models=model_settings)


- def get_llama_proxy():
+ async def get_llama_proxy():
# NOTE: This double lock allows the currently streaming llama model to
# check if any other requests are pending in the same thread and cancel
# the stream if so.
- llama_outer_lock.acquire()
+ await llama_outer_lock.acquire()
release_outer_lock = True
try:
- llama_inner_lock.acquire()
+ await llama_inner_lock.acquire()
try:
llama_outer_lock.release()
release_outer_lock = False
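
For reference, a minimal sketch of the dependency as it reads after this change, assuming the parts the diff does not show (the `yield` and the release logic) keep their original shape:

```python
# Sketch only: llama_outer_lock, llama_inner_lock and _llama_proxy mirror the
# module-level names in llama_cpp/server/app.py; the try/finally shape below
# is assumed from the visible context lines.
from anyio import Lock

llama_outer_lock = Lock()
llama_inner_lock = Lock()
_llama_proxy = None  # set by set_llama_proxy()


async def get_llama_proxy():
    # Double lock: the outer lock lets a currently streaming request notice
    # that another request is waiting and cancel its stream; the inner lock
    # serializes access to the shared LlamaProxy. anyio.Lock.acquire() is
    # awaited, so a waiting request parks on the event loop instead of
    # blocking a worker thread.
    await llama_outer_lock.acquire()
    release_outer_lock = True
    try:
        await llama_inner_lock.acquire()
        try:
            llama_outer_lock.release()
            release_outer_lock = False
            yield _llama_proxy
        finally:
            llama_inner_lock.release()
    finally:
        if release_outer_lock:
            llama_outer_lock.release()
```

With threading.Lock, every blocked acquire() pinned one of the server's worker threads; awaiting anyio.Lock keeps those threads free for other requests, which is the starvation the PR title refers to.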
@@ -159,7 +159,7 @@ async def get_event_publisher(
request: Request,
inner_send_chan: MemoryObjectSendStream[typing.Any],
iterator: Iterator[typing.Any],
- on_complete: typing.Optional[typing.Callable[[], None]] = None,
+ on_complete: typing.Optional[typing.Callable[[], typing.Awaitable[None]]] = None,
):
server_settings = next(get_server_settings())
interrupt_requests = (
@@ -182,7 +182,7 @@
raise e
finally:
if on_complete:
- on_complete()
+ await on_complete()


def _logit_bias_tokens_to_input_ids(
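
The callback contract changes with it: on_complete is now an awaitable callable (in practice AsyncExitStack.aclose) that the publisher awaits once streaming finishes. A self-contained sketch of that shape; the publish/main names are illustrative, and only the on_complete signature and the awaited call in the finally block come from the diff:

```python
import contextlib
import typing

import anyio


async def publish(
    items: typing.Iterable[typing.Any],
    on_complete: typing.Optional[typing.Callable[[], typing.Awaitable[None]]] = None,
) -> None:
    try:
        for item in items:
            print(item)  # stand-in for pushing an SSE event to the client
    finally:
        if on_complete:
            await on_complete()  # async cleanup now runs on the event loop


async def main() -> None:
    exit_stack = contextlib.AsyncExitStack()
    # enter_async_context() calls would register resources to release here.
    await publish(["chunk-1", "chunk-2"], on_complete=exit_stack.aclose)


anyio.run(main)
```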
@@ -267,10 +267,8 @@ async def create_completion(
request: Request,
body: CreateCompletionRequest,
) -> llama_cpp.Completion:
- exit_stack = contextlib.ExitStack()
- llama_proxy = await run_in_threadpool(
-     lambda: exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)())
- )
+ exit_stack = contextlib.AsyncExitStack()
+ llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
if llama_proxy is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
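
The switch from ExitStack plus run_in_threadpool to AsyncExitStack plus asynccontextmanager is what lets the route acquire the proxy without tying up a worker thread. A self-contained sketch of the mechanism; get_resource and handler are stand-ins for get_llama_proxy and the route body, while the contextlib calls mirror the diff:

```python
import contextlib

import anyio


async def get_resource():
    # Async generator dependency, like get_llama_proxy: acquire, yield, release.
    yield "resource"


async def handler() -> None:
    exit_stack = contextlib.AsyncExitStack()
    # asynccontextmanager() wraps the async generator into a context manager;
    # enter_async_context() keeps it open until exit_stack.aclose() is awaited,
    # so acquisition happens on the event loop rather than in run_in_threadpool.
    resource = await exit_stack.enter_async_context(
        contextlib.asynccontextmanager(get_resource)()
    )
    try:
        print(resource)  # stand-in for running the completion
    finally:
        await exit_stack.aclose()


anyio.run(handler)
```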
@@ -332,7 +330,6 @@ async def create_completion(
def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
yield first_response
yield from iterator_or_completion
- exit_stack.close()

send_chan, recv_chan = anyio.create_memory_object_stream(10)
return EventSourceResponse(
@@ -342,13 +339,13 @@ def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
request=request,
inner_send_chan=send_chan,
iterator=iterator(),
- on_complete=exit_stack.close,
+ on_complete=exit_stack.aclose,
),
sep="\n",
ping_message_factory=_ping_message_factory,
)
else:
- exit_stack.close()
+ await exit_stack.aclose()
return iterator_or_completion
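
Because cleanup is now aclose(), a coroutine, it can no longer run inside the synchronous iterator; the diff instead hands exit_stack.aclose to the publisher through on_complete. A sketch of that ordering with anyio's memory object stream; producer and the chunk values are illustrative, the point is that aclose() is awaited only after the stream has been produced in full:

```python
import contextlib

import anyio


async def main() -> None:
    exit_stack = contextlib.AsyncExitStack()
    send_chan, recv_chan = anyio.create_memory_object_stream(10)

    async def producer() -> None:
        try:
            async with send_chan:
                for chunk in ("first", "second", "third"):
                    await send_chan.send(chunk)
        finally:
            # Equivalent of on_complete=exit_stack.aclose: async cleanup runs
            # once streaming has finished.
            await exit_stack.aclose()

    async with anyio.create_task_group() as tg:
        tg.start_soon(producer)
        async with recv_chan:
            async for chunk in recv_chan:
                print(chunk)


anyio.run(main)
```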


@@ -477,10 +474,8 @@ async def create_chat_completion(
# where the dependency is cleaned up before a StreamingResponse
# is complete.
# https://github.com/tiangolo/fastapi/issues/11143
- exit_stack = contextlib.ExitStack()
- llama_proxy = await run_in_threadpool(
-     lambda: exit_stack.enter_context(contextlib.contextmanager(get_llama_proxy)())
- )
+ exit_stack = contextlib.AsyncExitStack()
+ llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
if llama_proxy is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
@@ -530,7 +525,6 @@ async def create_chat_completion(
def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
yield first_response
yield from iterator_or_completion
- exit_stack.close()

send_chan, recv_chan = anyio.create_memory_object_stream(10)
return EventSourceResponse(
@@ -540,13 +534,13 @@ def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
request=request,
inner_send_chan=send_chan,
iterator=iterator(),
- on_complete=exit_stack.close,
+ on_complete=exit_stack.aclose,
),
sep="\n",
ping_message_factory=_ping_message_factory,
)
else:
- exit_stack.close()
+ await exit_stack.aclose()
return iterator_or_completion
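
Not part of the change, but a quick way to exercise the behaviour this PR targets: fire a batch of concurrent requests at a locally running server and check they all complete instead of starving the worker pool. The URL, payload, and request count are assumptions about a local setup:

```python
# Assumes a llama-cpp-python server listening on localhost:8000 with a model
# loaded; endpoint path and payload fields follow the OpenAI-compatible API.
import asyncio

import httpx


async def one_request(client: httpx.AsyncClient, i: int) -> int:
    resp = await client.post(
        "http://localhost:8000/v1/completions",
        json={"prompt": f"Request {i}:", "max_tokens": 8},
        timeout=120.0,
    )
    return resp.status_code


async def main() -> None:
    async with httpx.AsyncClient() as client:
        codes = await asyncio.gather(*(one_request(client, i) for i in range(32)))
        print(codes)


asyncio.run(main())
```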

