diff --git a/README.md b/README.md
index 066f8ac..80d0c1d 100644
--- a/README.md
+++ b/README.md
@@ -193,6 +193,28 @@ python llm_correctness.py \
 
 ```
 
+
+### AzureAI Compatible APIs
+```bash
+export AZUREAI_API_KEY=secret_abcdefg
+export AZUREAI_API_BASE="https://api.endpoints.ai.azure.com/v1"
+
+python token_benchmark_ray.py \
+--model "Llama-2-70b-chat" \
+--mean-input-tokens 550 \
+--stddev-input-tokens 150 \
+--mean-output-tokens 150 \
+--stddev-output-tokens 10 \
+--max-num-completed-requests 2 \
+--timeout 600 \
+--num-concurrent-requests 1 \
+--results-dir "result_outputs" \
+--llm-api azureai \
+--additional-sampling-params '{}'
+
+```
+
+
 see `python token_benchmark_ray.py --help` for more details on the arguments.
 
 ## Correctness Test
@@ -338,6 +360,7 @@ python llm_correctness.py \
 
 ```
 
+
 ## Saving Results
 
 The results of the load test and correctness test are saved in the results directory specified by the `--results-dir` argument. The results are saved in 2 files, one with the summary metrics of the test, and one with metrics from each individual request that is returned.
diff --git a/src/llmperf/common.py b/src/llmperf/common.py
index 3efefa1..621c843 100644
--- a/src/llmperf/common.py
+++ b/src/llmperf/common.py
@@ -5,6 +5,7 @@
 )
 from llmperf.ray_clients.sagemaker_client import SageMakerClient
 from llmperf.ray_clients.vertexai_client import VertexAIClient
+from llmperf.ray_clients.azureai_chat_completion import AzureAIChatCompletionsClient
 from llmperf.ray_llm_client import LLMClient
 
 
@@ -28,6 +29,8 @@ def construct_clients(llm_api: str, num_clients: int) -> List[LLMClient]:
         clients = [SageMakerClient.remote() for _ in range(num_clients)]
     elif llm_api == "vertexai":
         clients = [VertexAIClient.remote() for _ in range(num_clients)]
+    elif llm_api == "azureai":
+        clients = [AzureAIChatCompletionsClient.remote() for _ in range(num_clients)]
     elif llm_api in SUPPORTED_APIS:
         clients = [LiteLLMClient.remote() for _ in range(num_clients)]
     else:
diff --git a/src/llmperf/ray_clients/azureai_chat_completion.py b/src/llmperf/ray_clients/azureai_chat_completion.py
new file mode 100644
index 0000000..04e5d2e
--- /dev/null
+++ b/src/llmperf/ray_clients/azureai_chat_completion.py
@@ -0,0 +1,119 @@
+import json
+import os
+import time
+from typing import Any, Dict
+
+import ray
+import requests
+
+from llmperf.ray_llm_client import LLMClient
+from llmperf.models import RequestConfig
+from llmperf import common_metrics
+
+@ray.remote
+class AzureAIChatCompletionsClient(LLMClient):
+    """Client for AzureAI Chat Completions API."""
+
+    def llm_request(self, request_config: RequestConfig) -> Dict[str, Any]:
+        prompt = request_config.prompt
+        prompt, prompt_len = prompt
+
+        message = [
+            {"role": "system", "content": ""},
+            {"role": "user", "content": prompt},
+        ]
+        model = request_config.model
+        body = {
+            "model": model,
+            "messages": message,
+            "stream": True,
+        }
+        sampling_params = request_config.sampling_params
+        body.update(sampling_params or {})
+        time_to_next_token = []
+        tokens_received = 0
+        ttft = 0
+        error_response_code = -1
+        generated_text = ""
+        error_msg = ""
+        output_throughput = 0
+        total_request_time = 0
+
+        metrics = {}
+
+        metrics[common_metrics.ERROR_CODE] = None
+        metrics[common_metrics.ERROR_MSG] = ""
+
+        start_time = time.monotonic()
+        most_recent_received_token_time = time.monotonic()
+        address = os.environ.get("AZUREAI_API_BASE")
+        if not address:
+            raise ValueError("the environment variable AZUREAI_API_BASE must be set.")
+        key = os.environ.get("AZUREAI_API_KEY")
+        if not key:
+            raise ValueError("the environment variable AZUREAI_API_KEY must be set.")
+        headers = {"Authorization": f"Bearer {key}"}
+        if not address:
+            raise ValueError("No host provided.")
+        if not address.endswith("/"):
+            address = address + "/"
+        address += "chat/completions"
+        try:
+            with requests.post(
+                address,
+                json=body,
+                stream=True,
+                timeout=180,
+                headers=headers,
+            ) as response:
+                if response.status_code != 200:
+                    error_msg = response.text
+                    error_response_code = response.status_code
+                    response.raise_for_status()
+                for chunk in response.iter_lines(chunk_size=None):
+                    chunk = chunk.strip()
+
+                    if not chunk:
+                        continue
+                    stem = "data: "
+                    chunk = chunk[len(stem) :]
+                    if chunk == b"[DONE]":
+                        continue
+                    tokens_received += 1
+                    data = json.loads(chunk)
+
+                    if "error" in data:
+                        error_msg = data["error"]["message"]
+                        error_response_code = data["error"]["code"]
+                        raise RuntimeError(data["error"]["message"])
+
+                    delta = data["choices"][0]["delta"]
+                    if delta.get("content", None):
+                        if not ttft:
+                            ttft = time.monotonic() - start_time
+                            time_to_next_token.append(ttft)
+                        else:
+                            time_to_next_token.append(
+                                time.monotonic() - most_recent_received_token_time
+                            )
+                        most_recent_received_token_time = time.monotonic()
+                        generated_text += delta["content"]
+
+                total_request_time = time.monotonic() - start_time
+                output_throughput = tokens_received / total_request_time
+
+        except Exception as e:
+            metrics[common_metrics.ERROR_MSG] = error_msg
+            metrics[common_metrics.ERROR_CODE] = error_response_code
+            print(f"Warning Or Error: {e}")
+            print(error_response_code)
+
+        metrics[common_metrics.INTER_TOKEN_LAT] = sum(time_to_next_token)  # This should be the same as metrics[common_metrics.E2E_LAT]. Leave it here for now.
+        metrics[common_metrics.TTFT] = ttft
+        metrics[common_metrics.E2E_LAT] = total_request_time
+        metrics[common_metrics.REQ_OUTPUT_THROUGHPUT] = output_throughput
+        metrics[common_metrics.NUM_TOTAL_TOKENS] = tokens_received + prompt_len
+        metrics[common_metrics.NUM_OUTPUT_TOKENS] = tokens_received
+        metrics[common_metrics.NUM_INPUT_TOKENS] = prompt_len
+
+        return metrics, generated_text, request_config
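Below the diff, a minimal sketch (not part of the change) of how the new `azureai` branch of `construct_clients` can be exercised directly, outside `token_benchmark_ray.py`. It assumes `AZUREAI_API_BASE` and `AZUREAI_API_KEY` are exported and that the endpoint serves a deployment named "Llama-2-70b-chat"; the prompt token count and sampling params are illustrative placeholders.

```python
# Sketch only: drives the new "azureai" client path added in common.py.
# Assumes AZUREAI_API_BASE / AZUREAI_API_KEY are set and the deployment name
# "Llama-2-70b-chat" exists behind the endpoint (illustrative values).
import ray

from llmperf.common import construct_clients
from llmperf.models import RequestConfig

ray.init(ignore_reinit_error=True)

# With this change, construct_clients returns AzureAIChatCompletionsClient actors.
clients = construct_clients(llm_api="azureai", num_clients=1)

config = RequestConfig(
    model="Llama-2-70b-chat",
    prompt=("Write one sentence about load testing.", 8),  # (text, token count) pair expected by llm_request
    sampling_params={"max_tokens": 64},
    llm_api="azureai",
)

# llm_request is a Ray actor method: invoke with .remote() and resolve with ray.get().
metrics, generated_text, _ = ray.get(clients[0].llm_request.remote(config))
print(metrics)
print(generated_text)
```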