From 10732af5deacbaa34613c8629e14c61de2c119a2 Mon Sep 17 00:00:00 2001 From: Sergei Zinchenko Date: Sun, 23 Jun 2024 05:41:04 +0200 Subject: [PATCH 1/2] [model_lock_per_request] model lock rewrote to be async lock over request --- llama_cpp/server/__main__.py | 20 +- llama_cpp/server/app.py | 507 +++++++++++++++++------------------ llama_cpp/server/errors.py | 3 +- llama_cpp/server/settings.py | 33 ++- llama_cpp/server/types.py | 49 +++- 5 files changed, 308 insertions(+), 304 deletions(-) diff --git a/llama_cpp/server/__main__.py b/llama_cpp/server/__main__.py index a6f1f4e9c..5a7860113 100644 --- a/llama_cpp/server/__main__.py +++ b/llama_cpp/server/__main__.py @@ -34,7 +34,7 @@ Settings, ServerSettings, ModelSettings, - ConfigFileSettings, + read_config, ) from llama_cpp.server.cli import add_args_from_model, parse_model_from_args @@ -49,28 +49,12 @@ def main(): type=str, help="Path to a config file to load.", ) - server_settings: ServerSettings | None = None - model_settings: list[ModelSettings] = [] args = parser.parse_args() try: # Load server settings from config_file if provided config_file = os.environ.get("CONFIG_FILE", args.config_file) if config_file: - if not os.path.exists(config_file): - raise ValueError(f"Config file {config_file} not found!") - with open(config_file, "rb") as f: - # Check if yaml file - if config_file.endswith(".yaml") or config_file.endswith(".yml"): - import yaml - import json - - config_file_settings = ConfigFileSettings.model_validate_json( - json.dumps(yaml.safe_load(f)) - ) - else: - config_file_settings = ConfigFileSettings.model_validate_json(f.read()) - server_settings = ServerSettings.model_validate(config_file_settings) - model_settings = config_file_settings.models + server_settings, model_settings = read_config(config_file) else: server_settings = parse_model_from_args(ServerSettings, args) model_settings = [parse_model_from_args(ModelSettings, args)] diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py index 4cda4af7a..4454c29f4 100644 --- a/llama_cpp/server/app.py +++ b/llama_cpp/server/app.py @@ -1,30 +1,30 @@ from __future__ import annotations +import collections import os import json - -from threading import Lock +from typing import Annotated, Callable +import asyncio from functools import partial from typing import Iterator, List, Optional, Union, Dict import llama_cpp - import anyio from anyio.streams.memory import MemoryObjectSendStream from starlette.concurrency import run_in_threadpool, iterate_in_threadpool from fastapi import Depends, FastAPI, APIRouter, Request, HTTPException, status, Body from fastapi.middleware import Middleware from fastapi.middleware.cors import CORSMiddleware -from fastapi.security import HTTPBearer +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from sse_starlette.sse import EventSourceResponse from starlette_context.plugins import RequestIdPlugin # type: ignore from starlette_context.middleware import RawContextMiddleware +from llama_cpp.server.settings import read_config from llama_cpp.server.model import ( LlamaProxy, ) from llama_cpp.server.settings import ( - ConfigFileSettings, Settings, ModelSettings, ServerSettings, @@ -42,7 +42,6 @@ ) from llama_cpp.server.errors import RouteErrorHandler - router = APIRouter(route_class=RouteErrorHandler) _server_settings: Optional[ServerSettings] = None @@ -57,64 +56,49 @@ def get_server_settings(): yield _server_settings -_llama_proxy: Optional[LlamaProxy] = None +_llama_proxy_context_manager: Optional[LlamaProxyContextManager] = 
None -llama_outer_lock = Lock() -llama_inner_lock = Lock() +def set_llama_proxy_context_manage(model_settings: List[ModelSettings]): + global _llama_proxy_context_manager + _llama_proxy_context_manager = LlamaProxyContextManager(model_settings) -def set_llama_proxy(model_settings: List[ModelSettings]): - global _llama_proxy - _llama_proxy = LlamaProxy(models=model_settings) +def get_llama_proxy_context_manager(): + return _llama_proxy_context_manager -def get_llama_proxy(): - # NOTE: This double lock allows the currently streaming llama model to - # check if any other requests are pending in the same thread and cancel - # the stream if so. - llama_outer_lock.acquire() - release_outer_lock = True - try: - llama_inner_lock.acquire() - try: - llama_outer_lock.release() - release_outer_lock = False - yield _llama_proxy - finally: - llama_inner_lock.release() - finally: - if release_outer_lock: - llama_outer_lock.release() + +class LlamaProxyContextManager: + _llama_proxy: LlamaProxy + _lock = asyncio.Lock() + + def __init__(self, model_settings: List[ModelSettings]): + self._llama_proxy = LlamaProxy(models=model_settings) + + async def __aenter__(self) -> LlamaProxy: + await self._lock.acquire() + return self._llama_proxy + + async def __aexit__(self, exc_type, exc, tb): + self._lock.release() _ping_message_factory = None + def set_ping_message_factory(factory): - global _ping_message_factory - _ping_message_factory = factory + global _ping_message_factory + _ping_message_factory = factory def create_app( - settings: Settings | None = None, - server_settings: ServerSettings | None = None, - model_settings: List[ModelSettings] | None = None, + settings: Settings | None = None, + server_settings: ServerSettings | None = None, + model_settings: List[ModelSettings] | None = None, ): config_file = os.environ.get("CONFIG_FILE", None) if config_file is not None: - if not os.path.exists(config_file): - raise ValueError(f"Config file {config_file} not found!") - with open(config_file, "rb") as f: - # Check if yaml file - if config_file.endswith(".yaml") or config_file.endswith(".yml"): - import yaml - - config_file_settings = ConfigFileSettings.model_validate_json( - json.dumps(yaml.safe_load(f)) - ) - else: - config_file_settings = ConfigFileSettings.model_validate_json(f.read()) - server_settings = ServerSettings.model_validate(config_file_settings) - model_settings = config_file_settings.models + server_settings, model_settings = read_config(config_file) if server_settings is None and model_settings is None: if settings is None: @@ -123,7 +107,7 @@ def create_app( model_settings = [ModelSettings.model_validate(settings)] assert ( - server_settings is not None and model_settings is not None + server_settings is not None and model_settings is not None ), "server_settings and model_settings must be provided together" set_server_settings(server_settings) @@ -144,7 +128,7 @@ def create_app( app.include_router(router) assert model_settings is not None - set_llama_proxy(model_settings=model_settings) + set_llama_proxy_context_manage(model_settings=model_settings) if server_settings.disable_ping_events: set_ping_message_factory(lambda: bytes()) @@ -153,22 +137,16 @@ def create_app( async def get_event_publisher( - request: Request, - inner_send_chan: MemoryObjectSendStream, - iterator: Iterator, + request: Request, + inner_send_chan: MemoryObjectSendStream, + iterator: collections.AsyncIterable, ): async with inner_send_chan: try: - async for chunk in iterate_in_threadpool(iterator): + async for chunk in 
iterator: await inner_send_chan.send(dict(data=json.dumps(chunk))) if await request.is_disconnected(): raise anyio.get_cancelled_exc_class()() - if ( - next(get_server_settings()).interrupt_requests - and llama_outer_lock.locked() - ): - await inner_send_chan.send(dict(data="[DONE]")) - raise anyio.get_cancelled_exc_class()() await inner_send_chan.send(dict(data="[DONE]")) except anyio.get_cancelled_exc_class() as e: print("disconnected") @@ -178,8 +156,8 @@ async def get_event_publisher( def _logit_bias_tokens_to_input_ids( - llama: llama_cpp.Llama, - logit_bias: Dict[str, float], + llama: llama_cpp.Llama, + logit_bias: Dict[str, float], ) -> Dict[str, float]: to_bias: Dict[str, float] = {} for token, score in logit_bias.items(): @@ -194,8 +172,8 @@ def _logit_bias_tokens_to_input_ids( async def authenticate( - settings: Settings = Depends(get_server_settings), - authorization: Optional[str] = Depends(bearer_scheme), + settings: Settings = Depends(get_server_settings), + authorization: Annotated[Optional[HTTPAuthorizationCredentials], Depends(bearer_scheme)] = None ): # Skip API key check if it's not set in settings if settings.api_key is None: @@ -216,6 +194,86 @@ async def authenticate( openai_v1_tag = "OpenAI V1" +async def completion_async_generator(req: Request, request_body: CreateCompletionRequest, + llama_proxy_context_manager: LlamaProxyContextManager, model_name: str, + exclude: set[str], method: Callable) -> collections.AsyncIterable: + kwargs = request_body.model_dump(exclude=exclude) + if await req.is_disconnected(): + raise anyio.get_cancelled_exc_class()() + async with llama_proxy_context_manager as llama_proxy: + llama = llama_proxy(model_name) + + if request_body.logit_bias is not None: + kwargs["logit_bias"] = ( + _logit_bias_tokens_to_input_ids(llama, request_body.logit_bias) + if request_body.logit_bias_type == "tokens" + else request_body.logit_bias + ) + + if request_body.grammar is not None: + kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(request_body.grammar) + + if request_body.min_tokens > 0: + _min_tokens_logits_processor = llama_cpp.LogitsProcessorList( + [llama_cpp.MinTokensLogitsProcessor(request_body.min_tokens, llama.token_eos())] + ) + if "logits_processor" not in kwargs: + kwargs["logits_processor"] = _min_tokens_logits_processor + else: + kwargs["logits_processor"].extend(_min_tokens_logits_processor) + if await req.is_disconnected(): + raise anyio.get_cancelled_exc_class()() + iterator_or_completion: Union[ + llama_cpp.CreateCompletionResponse | + Iterator[llama_cpp.CreateCompletionStreamResponse], + ] = await run_in_threadpool(method, llama, **kwargs) + if await req.is_disconnected(): + raise anyio.get_cancelled_exc_class()() + if isinstance(iterator_or_completion, Iterator): + async for chunk in iterate_in_threadpool(iterator_or_completion): + if await req.is_disconnected(): + raise anyio.get_cancelled_exc_class()() + yield False, chunk + else: + yield True, iterator_or_completion + + +async def handle_completion_request(request: Request, request_body: CreateCompletionRequest, + llama_proxy_context_manager: LlamaProxyContextManager, model_name: str, + exclude: set[str], method: Callable): + completion_iter = await run_in_threadpool(completion_async_generator, request, request_body, + llama_proxy_context_manager, model_name, + exclude, method) + + first_response = None + complete_response = False + async for response in completion_iter: + complete_response, first_response = response + break + + if complete_response: + return first_response + + 
async def response_async_generator(): + yield first_response + async for cr, item in completion_iter: + yield item + + send_chan, recv_chan = anyio.create_memory_object_stream(10) + + return EventSourceResponse( + recv_chan, + data_sender_callable=partial( # type: ignore + get_event_publisher, + request=request, + inner_send_chan=send_chan, + iterator=response_async_generator(), + ), + sep="\n", + ping_message_factory=_ping_message_factory, + ) + + @router.post( "/v1/completions", summary="Completion", @@ -240,8 +298,11 @@ async def authenticate( "schema": { "type": "string", "title": "Server Side Streaming response, when stream=True. " - + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format", # noqa: E501 - "example": """data: {... see CreateCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""", + + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server" + "-sent_events/Using_server-sent_events#Event_stream_format", + # noqa: E501 + "example": """data: {... see CreateCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: + [DONE]""", } }, }, @@ -256,20 +317,14 @@ async def authenticate( tags=[openai_v1_tag], ) async def create_completion( - request: Request, - body: CreateCompletionRequest, - llama_proxy: LlamaProxy = Depends(get_llama_proxy), + request: Request, + body: CreateCompletionRequest, + llama_proxy_context_manager: LlamaProxyContextManager = Depends(get_llama_proxy_context_manager), ) -> llama_cpp.Completion: if isinstance(body.prompt, list): assert len(body.prompt) <= 1 body.prompt = body.prompt[0] if len(body.prompt) > 0 else "" - llama = llama_proxy( - body.model - if request.url.path != "/v1/engines/copilot-codex/completions" - else "copilot-codex" - ) - exclude = { "n", "best_of", @@ -277,56 +332,14 @@ async def create_completion( "user", "min_tokens", } - kwargs = body.model_dump(exclude=exclude) - - if body.logit_bias is not None: - kwargs["logit_bias"] = ( - _logit_bias_tokens_to_input_ids(llama, body.logit_bias) - if body.logit_bias_type == "tokens" - else body.logit_bias - ) - if body.grammar is not None: - kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar) + model_name = body.model if request.url.path != "/v1/engines/copilot-codex/completions" else "copilot-codex" - if body.min_tokens > 0: - _min_tokens_logits_processor = llama_cpp.LogitsProcessorList( - [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())] - ) - if "logits_processor" not in kwargs: - kwargs["logits_processor"] = _min_tokens_logits_processor - else: - kwargs["logits_processor"].extend(_min_tokens_logits_processor) + method = llama_cpp.Llama.create_completion - iterator_or_completion: Union[ - llama_cpp.CreateCompletionResponse, - Iterator[llama_cpp.CreateCompletionStreamResponse], - ] = await run_in_threadpool(llama, **kwargs) - - if isinstance(iterator_or_completion, Iterator): - # EAFP: It's easier to ask for forgiveness than permission - first_response = await run_in_threadpool(next, iterator_or_completion) - - # If no exception was raised from first_response, we can assume that - # the iterator is valid and we can use it to stream the response. 
- def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]: - yield first_response - yield from iterator_or_completion - - send_chan, recv_chan = anyio.create_memory_object_stream(10) - return EventSourceResponse( - recv_chan, - data_sender_callable=partial( # type: ignore - get_event_publisher, - request=request, - inner_send_chan=send_chan, - iterator=iterator(), - ), - sep="\n", - ping_message_factory=_ping_message_factory, - ) - else: - return iterator_or_completion + return await handle_completion_request(request, body, + llama_proxy_context_manager, model_name, + exclude, method) @router.post( @@ -336,13 +349,14 @@ def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]: tags=[openai_v1_tag], ) async def create_embedding( - request: CreateEmbeddingRequest, - llama_proxy: LlamaProxy = Depends(get_llama_proxy), + request: CreateEmbeddingRequest, + llama_proxy_context: LlamaProxyContextManager = Depends(get_llama_proxy_context_manager), ): - return await run_in_threadpool( - llama_proxy(request.model).create_embedding, - **request.model_dump(exclude={"user"}), - ) + async with llama_proxy_context as llama_proxy: + return await run_in_threadpool( + llama_proxy(request.model).create_embedding, + **request.model_dump(exclude={"user"}), + ) @router.post( @@ -368,8 +382,11 @@ async def create_embedding( "schema": { "type": "string", "title": "Server Side Streaming response, when stream=True" - + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format", # noqa: E501 - "example": """data: {... see CreateChatCompletionResponse ...} \\n\\n data: ... \\n\\n ... data: [DONE]""", + + "See SSE format: https://developer.mozilla.org/en-US/docs/Web/API/Server" + "-sent_events/Using_server-sent_events#Event_stream_format", + # noqa: E501 + "example": """data: {... see CreateChatCompletionResponse ...} \\n\\n data: ... \\n\\n ... 
+ data: [DONE]""", } }, }, @@ -378,78 +395,78 @@ async def create_embedding( tags=[openai_v1_tag], ) async def create_chat_completion( - request: Request, - body: CreateChatCompletionRequest = Body( - openapi_examples={ - "normal": { - "summary": "Chat Completion", - "value": { - "model": "gpt-3.5-turbo", - "messages": [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What is the capital of France?"}, - ], + request: Request, + body: CreateChatCompletionRequest = Body( + openapi_examples={ + "normal": { + "summary": "Chat Completion", + "value": { + "model": "gpt-3.5-turbo", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"}, + ], + }, }, - }, - "json_mode": { - "summary": "JSON Mode", - "value": { - "model": "gpt-3.5-turbo", - "messages": [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Who won the world series in 2020"}, - ], - "response_format": { "type": "json_object" } + "json_mode": { + "summary": "JSON Mode", + "value": { + "model": "gpt-3.5-turbo", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Who won the world series in 2020"}, + ], + "response_format": {"type": "json_object"} + }, }, - }, - "tool_calling": { - "summary": "Tool Calling", - "value": { - "model": "gpt-3.5-turbo", - "messages": [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Extract Jason is 30 years old."}, - ], - "tools": [ - { + "tool_calling": { + "summary": "Tool Calling", + "value": { + "model": "gpt-3.5-turbo", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Extract Jason is 30 years old."}, + ], + "tools": [ + { + "type": "function", + "function": { + "name": "User", + "description": "User record", + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "number"}, + }, + "required": ["name", "age"], + }, + } + } + ], + "tool_choice": { "type": "function", "function": { "name": "User", - "description": "User record", - "parameters": { - "type": "object", - "properties": { - "name": {"type": "string"}, - "age": {"type": "number"}, - }, - "required": ["name", "age"], - }, } } - ], - "tool_choice": { - "type": "function", - "function": { - "name": "User", - } - } + }, }, - }, - "logprobs": { - "summary": "Logprobs", - "value": { - "model": "gpt-3.5-turbo", - "messages": [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What is the capital of France?"}, - ], - "logprobs": True, - "top_logprobs": 10 + "logprobs": { + "summary": "Logprobs", + "value": { + "model": "gpt-3.5-turbo", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"}, + ], + "logprobs": True, + "top_logprobs": 10 + }, }, - }, - } - ), - llama_proxy: LlamaProxy = Depends(get_llama_proxy), + } + ), + llama_proxy_context_manager: LlamaProxyContextManager = Depends(get_llama_proxy_context_manager), ) -> llama_cpp.ChatCompletion: exclude = { "n", @@ -457,55 +474,14 @@ async def create_chat_completion( "user", "min_tokens", } - kwargs = body.model_dump(exclude=exclude) - llama = llama_proxy(body.model) - if body.logit_bias is not None: - kwargs["logit_bias"] = ( - _logit_bias_tokens_to_input_ids(llama, 
body.logit_bias) - if body.logit_bias_type == "tokens" - else body.logit_bias - ) - if body.grammar is not None: - kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar) + model_name = body.model - if body.min_tokens > 0: - _min_tokens_logits_processor = llama_cpp.LogitsProcessorList( - [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())] - ) - if "logits_processor" not in kwargs: - kwargs["logits_processor"] = _min_tokens_logits_processor - else: - kwargs["logits_processor"].extend(_min_tokens_logits_processor) - - iterator_or_completion: Union[ - llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk] - ] = await run_in_threadpool(llama.create_chat_completion, **kwargs) - - if isinstance(iterator_or_completion, Iterator): - # EAFP: It's easier to ask for forgiveness than permission - first_response = await run_in_threadpool(next, iterator_or_completion) - - # If no exception was raised from first_response, we can assume that - # the iterator is valid and we can use it to stream the response. - def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]: - yield first_response - yield from iterator_or_completion - - send_chan, recv_chan = anyio.create_memory_object_stream(10) - return EventSourceResponse( - recv_chan, - data_sender_callable=partial( # type: ignore - get_event_publisher, - request=request, - inner_send_chan=send_chan, - iterator=iterator(), - ), - sep="\n", - ping_message_factory=_ping_message_factory, - ) - else: - return iterator_or_completion + method = llama_cpp.Llama.create_chat_completion + + return await handle_completion_request(request, body, + llama_proxy_context_manager, model_name, + exclude, method) @router.get( @@ -515,20 +491,21 @@ def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]: tags=[openai_v1_tag], ) async def get_models( - llama_proxy: LlamaProxy = Depends(get_llama_proxy), + llama_proxy_context_manager: LlamaProxyContextManager = Depends(get_llama_proxy_context_manager), ) -> ModelList: - return { - "object": "list", - "data": [ - { - "id": model_alias, - "object": "model", - "owned_by": "me", - "permissions": [], - } - for model_alias in llama_proxy - ], - } + async with llama_proxy_context_manager as llama_proxy: + return { + "object": "list", + "data": [ + { + "id": model_alias, + "object": "model", + "owned_by": "me", + "permissions": [], + } + for model_alias in llama_proxy + ], + } extras_tag = "Extras" @@ -541,11 +518,11 @@ async def get_models( tags=[extras_tag], ) async def tokenize( - body: TokenizeInputRequest, - llama_proxy: LlamaProxy = Depends(get_llama_proxy), + body: TokenizeInputRequest, + llama_proxy_context_manager: LlamaProxyContextManager = Depends(get_llama_proxy_context_manager), ) -> TokenizeInputResponse: - tokens = llama_proxy(body.model).tokenize(body.input.encode("utf-8"), special=True) - + async with llama_proxy_context_manager as llama_proxy: + tokens = llama_proxy(body.model).tokenize(body.input.encode("utf-8"), special=True) return TokenizeInputResponse(tokens=tokens) @@ -556,11 +533,11 @@ async def tokenize( tags=[extras_tag], ) async def count_query_tokens( - body: TokenizeInputRequest, - llama_proxy: LlamaProxy = Depends(get_llama_proxy), + body: TokenizeInputRequest, + llama_proxy_context_manager: LlamaProxyContextManager = Depends(get_llama_proxy_context_manager), ) -> TokenizeInputCountResponse: - tokens = llama_proxy(body.model).tokenize(body.input.encode("utf-8"), special=True) - + async with llama_proxy_context_manager as llama_proxy: + tokens = 
llama_proxy(body.model).tokenize(body.input.encode("utf-8"), special=True) return TokenizeInputCountResponse(count=len(tokens)) @@ -571,9 +548,9 @@ async def count_query_tokens( tags=[extras_tag], ) async def detokenize( - body: DetokenizeInputRequest, - llama_proxy: LlamaProxy = Depends(get_llama_proxy), + body: DetokenizeInputRequest, + llama_proxy_context_manager: LlamaProxyContextManager = Depends(get_llama_proxy_context_manager), ) -> DetokenizeInputResponse: - text = llama_proxy(body.model).detokenize(body.tokens).decode("utf-8") - + async with llama_proxy_context_manager as llama_proxy: + text = llama_proxy(body.model).detokenize(body.tokens).decode("utf-8") return DetokenizeInputResponse(text=text) diff --git a/llama_cpp/server/errors.py b/llama_cpp/server/errors.py index fbf9fd80d..a39b9dfaf 100644 --- a/llama_cpp/server/errors.py +++ b/llama_cpp/server/errors.py @@ -84,8 +84,7 @@ def context_length_exceeded( @staticmethod def model_not_found( - request: Union["CreateCompletionRequest", "CreateChatCompletionRequest"], - match, # type: Match[str] # type: ignore + match, # type: Match[str] # type: ignore ) -> Tuple[int, ErrorResponse]: """Formatter for model_not_found error""" diff --git a/llama_cpp/server/settings.py b/llama_cpp/server/settings.py index 4d924f337..0806be3a4 100644 --- a/llama_cpp/server/settings.py +++ b/llama_cpp/server/settings.py @@ -1,6 +1,7 @@ from __future__ import annotations import multiprocessing +import os from typing import Optional, List, Literal, Union, Dict, cast from typing_extensions import Self @@ -56,7 +57,8 @@ class ModelSettings(BaseSettings): ) kv_overrides: Optional[List[str]] = Field( default=None, - description="List of model kv overrides in the format key=type:value where type is one of (bool, int, float). Valid true values are (true, TRUE, 1), otherwise false.", + description="List of model kv overrides in the format key=type:value where type is one of (bool, int, " + "float). Valid true values are (true, TRUE, 1), otherwise false.", ) rpc_servers: Optional[str] = Field( default=None, @@ -112,7 +114,8 @@ class ModelSettings(BaseSettings): # LoRA Params lora_base: Optional[str] = Field( default=None, - description="Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.", + description="Optional path to base model, useful if using a quantized base model and you want to apply LoRA " + "to an f16 model.", ) lora_path: Optional[str] = Field( default=None, @@ -152,7 +155,8 @@ class ModelSettings(BaseSettings): ) hf_pretrained_model_name_or_path: Optional[str] = Field( default=None, - description="The model name or path to a pretrained HuggingFace tokenizer model. Same as you would pass to AutoTokenizer.from_pretrained().", + description="The model name or path to a pretrained HuggingFace tokenizer model. Same as you would pass to " + "AutoTokenizer.from_pretrained().", ) # Loading from HuggingFace Model Hub hf_model_repo_id: Optional[str] = Field( @@ -211,10 +215,6 @@ class ServerSettings(BaseSettings): default=None, description="API key for authentication. 
If set all requests need to be authenticated.", ) - interrupt_requests: bool = Field( - default=True, - description="Whether to interrupt requests when a new request is received.", - ) disable_ping_events: bool = Field( default=False, description="Disable EventSource pings (may be needed for some clients).", @@ -233,3 +233,22 @@ class ConfigFileSettings(ServerSettings): """Configuration file format settings.""" models: List[ModelSettings] = Field(default=[], description="Model configs") + + +def read_config(config_file: str) -> tuple[ServerSettings, list[ModelSettings]]: + if not os.path.exists(config_file): + raise ValueError(f"Config file {config_file} not found!") + with open(config_file, "rb") as f: + # Check if yaml file + if config_file.endswith(".yaml") or config_file.endswith(".yml"): + import yaml + import json + + config_file_settings = ConfigFileSettings.model_validate_json( + json.dumps(yaml.safe_load(f)) + ) + else: + config_file_settings = ConfigFileSettings.model_validate_json(f.read()) + server_settings = ServerSettings.model_validate(config_file_settings) + model_settings = config_file_settings.models + return server_settings, model_settings diff --git a/llama_cpp/server/types.py b/llama_cpp/server/types.py index a75f9e55b..60d4d8f3b 100644 --- a/llama_cpp/server/types.py +++ b/llama_cpp/server/types.py @@ -19,21 +19,32 @@ min_tokens_field = Field( default=0, ge=0, - description="The minimum number of tokens to generate. It may return fewer tokens if another condition is met (e.g. max_tokens, stop).", + description="The minimum number of tokens to generate. It may return fewer tokens if another condition is met (" + "e.g. max_tokens, stop).", ) temperature_field = Field( default=0.8, description="Adjust the randomness of the generated text.\n\n" - + "Temperature is a hyperparameter that controls the randomness of the generated text. It affects the probability distribution of the model's output tokens. A higher temperature (e.g., 1.5) makes the output more random and creative, while a lower temperature (e.g., 0.5) makes the output more focused, deterministic, and conservative. The default value is 0.8, which provides a balance between randomness and determinism. At the extreme, a temperature of 0 will always pick the most likely next token, leading to identical outputs in each run.", + + "Temperature is a hyperparameter that controls the randomness of the generated text. It affects the probability " + "distribution of the model's output tokens. A higher temperature (e.g., 1.5) makes the output more random and " + "creative, while a lower temperature (e.g., 0.5) makes the output more focused, deterministic, " + "and conservative. The default value is 0.8, which provides a balance between randomness and determinism. At " + "the extreme, a temperature of 0 will always pick the most likely next token, leading to identical outputs in " + "each run.", ) top_p_field = Field( default=0.95, ge=0.0, le=1.0, - description="Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P.\n\n" - + "Top-p sampling, also known as nucleus sampling, is another text generation method that selects the next token from a subset of tokens that together have a cumulative probability of at least p. This method provides a balance between diversity and quality by considering both the probabilities of tokens and the number of tokens to sample from. 
A higher value for top_p (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text.", + description="Limit the next token selection to a subset of tokens with a cumulative probability above a threshold " + "P.\n\n" + + "Top-p sampling, also known as nucleus sampling, is another text generation method that selects the next token " + "from a subset of tokens that together have a cumulative probability of at least p. This method provides a " + "balance between diversity and quality by considering both the probabilities of tokens and the number of tokens " + "to sample from. A higher value for top_p (e.g., 0.95) will lead to more diverse text, while a lower value (" + "e.g., 0.5) will generate more focused and conservative text.", ) min_p_field = Field( @@ -41,7 +52,10 @@ ge=0.0, le=1.0, description="Sets a minimum base probability threshold for token selection.\n\n" - + "The Min-P sampling method was designed as an alternative to Top-P, and aims to ensure a balance of quality and variety. The parameter min_p represents the minimum probability for a token to be considered, relative to the probability of the most likely token. For example, with min_p=0.05 and the most likely token having a probability of 0.9, logits with a value less than 0.045 are filtered out.", + + "The Min-P sampling method was designed as an alternative to Top-P, and aims to ensure a balance of quality and " + "variety. The parameter min_p represents the minimum probability for a token to be considered, relative to the " + "probability of the most likely token. For example, with min_p=0.05 and the most likely token having a " + "probability of 0.9, logits with a value less than 0.045 are filtered out.", ) stop_field = Field( @@ -58,28 +72,37 @@ default=40, ge=0, description="Limit the next token selection to the K most probable tokens.\n\n" - + "Top-k sampling is a text generation method that selects the next token only from the top k most likely tokens predicted by the model. It helps reduce the risk of generating low-probability or nonsensical tokens, but it may also limit the diversity of the output. A higher value for top_k (e.g., 100) will consider more tokens and lead to more diverse text, while a lower value (e.g., 10) will focus on the most probable tokens and generate more conservative text.", + + "Top-k sampling is a text generation method that selects the next token only from the top k most likely tokens " + "predicted by the model. It helps reduce the risk of generating low-probability or nonsensical tokens, " + "but it may also limit the diversity of the output. A higher value for top_k (e.g., 100) will consider more " + "tokens and lead to more diverse text, while a lower value (e.g., 10) will focus on the most probable tokens " + "and generate more conservative text.", ) repeat_penalty_field = Field( default=1.1, ge=0.0, - description="A penalty applied to each token that is already generated. This helps prevent the model from repeating itself.\n\n" - + "Repeat penalty is a hyperparameter used to penalize the repetition of token sequences during text generation. It helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient.", + description="A penalty applied to each token that is already generated. 
This helps prevent the model from " + "repeating itself.\n\n" + + "Repeat penalty is a hyperparameter used to penalize the repetition of token sequences during text generation. " + "It helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., " + "1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient.", ) presence_penalty_field = Field( default=0.0, ge=-2.0, le=2.0, - description="Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.", + description="Positive values penalize new tokens based on whether they appear in the text so far, increasing the " + "model's likelihood to talk about new topics.", ) frequency_penalty_field = Field( default=0.0, ge=-2.0, le=2.0, - description="Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.", + description="Positive values penalize new tokens based on their existing frequency in the text so far, decreasing " + "the model's likelihood to repeat the same line verbatim.", ) mirostat_mode_field = Field( @@ -93,7 +116,8 @@ default=5.0, ge=0.0, le=10.0, - description="Mirostat target entropy, i.e. the target perplexity - lower values produce focused and coherent text, larger values produce more diverse and less coherent text", + description="Mirostat target entropy, i.e. the target perplexity - lower values produce focused and coherent " + "text, larger values produce more diverse and less coherent text", ) mirostat_eta_field = Field( @@ -221,7 +245,8 @@ class CreateChatCompletionRequest(BaseModel): top_logprobs: Optional[int] = Field( default=None, ge=0, - description="The number of logprobs to generate. If None, no logprobs are generated. logprobs need to set to True.", + description="The number of logprobs to generate. If None, no logprobs are generated. logprobs need to set to " + "True.", ) temperature: float = temperature_field top_p: float = top_p_field From 71e28b74f3ca06509bc9c7311d63761c2778b744 Mon Sep 17 00:00:00 2001 From: Sergei Zinchenko Date: Sun, 23 Jun 2024 06:19:46 +0200 Subject: [PATCH 2/2] [model_lock_per_request] added limit_concurrency for uvicorn --- llama_cpp/server/__main__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llama_cpp/server/__main__.py b/llama_cpp/server/__main__.py index 5a7860113..6d1101f91 100644 --- a/llama_cpp/server/__main__.py +++ b/llama_cpp/server/__main__.py @@ -74,6 +74,7 @@ def main(): port=int(os.getenv("PORT", server_settings.port)), ssl_keyfile=server_settings.ssl_keyfile, ssl_certfile=server_settings.ssl_certfile, + limit_concurrency=10 )
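Note on the locking model in PATCH 1/2: the old scheme used two threading.Lock objects (llama_outer_lock / llama_inner_lock) plus the interrupt_requests setting so that a newly arriving request could cancel an in-flight stream; the patch replaces this with a single asyncio.Lock held by LlamaProxyContextManager for as long as a request uses the model, so requests are simply served one at a time. Below is a minimal, self-contained sketch of that acquire-on-enter / release-on-exit pattern; all names in it (FakeModel, ResourceGuard, handler, main) are hypothetical stand-ins rather than code from the patch, and the lock is an instance attribute here whereas the patch declares it at class level (equivalent while there is a single LlamaProxyContextManager instance).

import asyncio


class FakeModel:
    """Stand-in for a llama.cpp model; generate() simulates a slow call."""

    async def generate(self, prompt: str) -> str:
        await asyncio.sleep(0.1)
        return f"echo: {prompt}"


class ResourceGuard:
    """Async context manager that serialises access to one shared model."""

    def __init__(self, model: FakeModel) -> None:
        self._model = model
        self._lock = asyncio.Lock()

    async def __aenter__(self) -> FakeModel:
        await self._lock.acquire()   # each request waits here for its turn
        return self._model

    async def __aexit__(self, exc_type, exc, tb) -> None:
        self._lock.release()         # released even if the handler raised


async def handler(guard: ResourceGuard, prompt: str) -> str:
    async with guard as model:       # only one request holds the model at a time
        return await model.generate(prompt)


async def main() -> None:
    guard = ResourceGuard(FakeModel())
    print(await asyncio.gather(*(handler(guard, f"req {i}") for i in range(3))))


if __name__ == "__main__":
    asyncio.run(main())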
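Note on the streaming path in PATCH 1/2: completion_async_generator yields (is_complete, payload) pairs, and handle_completion_request awaits only the first pair to decide between returning a complete response and streaming the remaining chunks through an anyio memory stream into EventSourceResponse. The sketch below isolates that peek-the-first-item pattern in plain asyncio; produce(), handle() and the list returned in the streaming branch are illustrative stand-ins, not the patch's SSE plumbing.

import asyncio
from typing import AsyncIterator, List, Tuple, Union


async def produce(stream: bool) -> AsyncIterator[Tuple[bool, object]]:
    """Yield (is_complete, payload) pairs, mimicking completion_async_generator."""
    if not stream:
        yield True, {"text": "full response"}       # single complete payload
        return
    for i in range(3):
        await asyncio.sleep(0)
        yield False, {"delta": f"chunk {i}"}        # streamed chunks


async def handle(stream: bool) -> Union[object, List[object]]:
    gen = produce(stream)
    is_complete, first = await gen.__anext__()      # peek the first item
    if is_complete:
        return first                                # plain (non-streaming) response

    async def rest() -> AsyncIterator[object]:
        yield first                                 # re-attach the peeked chunk
        async for _, item in gen:
            yield item

    return [item async for item in rest()]          # stands in for SSE streaming


if __name__ == "__main__":
    print(asyncio.run(handle(stream=False)))
    print(asyncio.run(handle(stream=True)))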
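Note on PATCH 2/2: uvicorn's limit_concurrency caps the number of concurrent connections and tasks; requests beyond the cap are answered by uvicorn with HTTP 503 Service Unavailable instead of queueing indefinitely behind the model lock. In the patch the value 10 is hardcoded in __main__.py. A stand-alone example of the flag follows; the FastAPI app in it is a placeholder, not the llama.cpp server.

import asyncio

import uvicorn
from fastapi import FastAPI

app = FastAPI()


@app.get("/slow")
async def slow() -> dict:
    await asyncio.sleep(5)   # simulates a long generation holding the model lock
    return {"ok": True}


if __name__ == "__main__":
    # With limit_concurrency=10, requests arriving while ten are already in
    # flight are rejected with 503 by uvicorn before they reach the app.
    uvicorn.run(app, host="127.0.0.1", port=8000, limit_concurrency=10)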