diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 1d2ceca..654a88e 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -6,17 +6,17 @@ jobs:
   lint:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - uses: psf/black@stable
   test:
     runs-on: ubuntu-latest
     needs: lint
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - name: Set up Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
-          python-version: '3.11'
+          python-version: '3.10'
           cache: 'pip'
       - name: Install pip dependencies
         run: |
diff --git a/requirements.txt b/requirements.txt
index 982f6b0..cd04fb3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,7 +15,7 @@ sentence-transformers
 snowflake-connector-python
 spacy
 sqlalchemy
-tiktoken==0.7.0
+tiktoken
 together
 torch
 tqdm
diff --git a/run_model_cot.sh b/run_model_cot.sh
index 35d62b5..3c09d67 100755
--- a/run_model_cot.sh
+++ b/run_model_cot.sh
@@ -49,7 +49,6 @@ for model_name in "${model_names[@]}"; do
     --api_url "http://localhost:${PORT}/generate" \
     --api_type "vllm" \
     -p 10 \
-    --cot_table_alias "prealias" \
     --logprobs
   # finally, kill the api server
   pkill -9 -f "python3 utils/api_server.py.*--port ${PORT}"
diff --git a/utils/api_server.py b/utils/api_server.py
index ea1d009..a2d21a5 100644
--- a/utils/api_server.py
+++ b/utils/api_server.py
@@ -55,17 +55,32 @@ async def generate(request: Request) -> Response:
     sql_lora_path = request_dict.pop("sql_lora_path", None)
     request_dict.pop("sql_lora_name", None)
     lora_request = (
-        LoRARequest("sql_adapter", 1, sql_lora_path) if sql_lora_path else None
+        LoRARequest(lora_name="sql_adapter", lora_int_id=1, lora_path=sql_lora_path)
+        if sql_lora_path
+        else None
     )
+    if vllm_version >= "0.6.2":
+        # remove use_beam_search if present as it's no longer supported
+        # see https://github.com/vllm-project/vllm/releases/tag/v0.6.2
+        if "use_beam_search" in request_dict:
+            request_dict.pop("use_beam_search")
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
     tokenizer = await engine.get_tokenizer()
     prompt_token_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
-    # print(f"prompt_token_ids: {prompt_token_ids}")
     if prompt_token_ids[0] != tokenizer.bos_token_id:
         prompt_token_ids = [tokenizer.bos_token_id] + prompt_token_ids
-    if vllm_version >= "0.4.2":
+    if vllm_version >= "0.6.3":
+        from vllm import TokensPrompt
+
+        results_generator = engine.generate(
+            prompt=TokensPrompt(prompt_token_ids=prompt_token_ids),
+            sampling_params=sampling_params,
+            request_id=request_id,
+            lora_request=lora_request,
+        )
+    elif vllm_version >= "0.4.2":
         results_generator = engine.generate(
             inputs={"prompt_token_ids": prompt_token_ids},
             sampling_params=sampling_params,