diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 1d2ceca..e1c5488 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -6,21 +6,22 @@ jobs:
   lint:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - uses: psf/black@stable
   test:
     runs-on: ubuntu-latest
     needs: lint
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - name: Set up Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
-          python-version: '3.11'
+          python-version: '3.10'
           cache: 'pip'
       - name: Install pip dependencies
         run: |
-          pip install -r requirements.txt
+          pip install --upgrade pip setuptools
+          pip install -r requirements_test.txt
           pip install pytest
       - name: Download spaCy model
         run: python -m spacy download en_core_web_sm
diff --git a/requirements.txt b/requirements.txt
index 982f6b0..31360ca 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,8 +3,9 @@ argparse
 func_timeout
 mistralai
 mysql-connector-python
+numpy==2.1.2
 openai>=1.1.0
-pandas
+pandas==2.2.3
 pandas-gbq
 peft
 psycopg2-binary
@@ -15,11 +16,11 @@ sentence-transformers
 snowflake-connector-python
 spacy
 sqlalchemy
-tiktoken==0.7.0
+tiktoken
 together
-torch
+torch==2.4.0
 tqdm
 transformers
 sqlparse
 sqlglot
-vllm; sys_platform != 'darwin'
+vllm==0.6.3.post1; sys_platform != 'darwin'
diff --git a/requirements_test.txt b/requirements_test.txt
new file mode 100644
index 0000000..8aeada3
--- /dev/null
+++ b/requirements_test.txt
@@ -0,0 +1,13 @@
+func_timeout
+numpy
+openai
+pandas
+psycopg2-binary
+pysqlite3
+sentence_transformers
+snowflake-connector-python
+spacy==3.7.2
+sqlalchemy
+sqlglot
+torch
+tqdm
\ No newline at end of file
diff --git a/run_model_cot.sh b/run_model_cot.sh
index 35d62b5..3c09d67 100755
--- a/run_model_cot.sh
+++ b/run_model_cot.sh
@@ -49,7 +49,6 @@ for model_name in "${model_names[@]}"; do
     --api_url "http://localhost:${PORT}/generate" \
     --api_type "vllm" \
     -p 10 \
-    --cot_table_alias "prealias" \
     --logprobs
   # finally, kill the api server
   pkill -9 -f "python3 utils/api_server.py.*--port ${PORT}"
diff --git a/tests/test_utils_pruning.py b/tests/test_utils_pruning.py
index 63cb564..468548d 100644
--- a/tests/test_utils_pruning.py
+++ b/tests/test_utils_pruning.py
@@ -33,6 +33,11 @@ def test_metadata():
             "airport.airport_name,text,name of airport",
             "flight.airport_name,text,name of the airport",
         ],
+        "FAC": [
+            "country.name,text,country name",
+            "airport.airport_name,text,name of airport",
+            "flight.airport_name,text,name of the airport",
+        ],
         "PERSON": ["flight.pilot_name,text,name of the pilot"],
     }
     column_join = {("airport", "country"): [("airport.country_id", "country.id")]}
diff --git a/utils/api_server.py b/utils/api_server.py
index ea1d009..a2d21a5 100644
--- a/utils/api_server.py
+++ b/utils/api_server.py
@@ -55,17 +55,32 @@ async def generate(request: Request) -> Response:
     sql_lora_path = request_dict.pop("sql_lora_path", None)
     request_dict.pop("sql_lora_name", None)
     lora_request = (
-        LoRARequest("sql_adapter", 1, sql_lora_path) if sql_lora_path else None
+        LoRARequest(lora_name="sql_adapter", lora_int_id=1, lora_path=sql_lora_path)
+        if sql_lora_path
+        else None
     )
+    if vllm_version >= "0.6.2":
+        # remove use_beam_search if present as it's no longer supported
+        # see https://github.com/vllm-project/vllm/releases/tag/v0.6.2
+        if "use_beam_search" in request_dict:
+            request_dict.pop("use_beam_search")
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
     tokenizer = await engine.get_tokenizer()
     prompt_token_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
-    # print(f"prompt_token_ids: {prompt_token_ids}")
     if prompt_token_ids[0] != tokenizer.bos_token_id:
         prompt_token_ids = [tokenizer.bos_token_id] + prompt_token_ids
-    if vllm_version >= "0.4.2":
+    if vllm_version >= "0.6.3":
+        from vllm import TokensPrompt
+
+        results_generator = engine.generate(
+            prompt=TokensPrompt(prompt_token_ids=prompt_token_ids),
+            sampling_params=sampling_params,
+            request_id=request_id,
+            lora_request=lora_request,
+        )
+    elif vllm_version >= "0.4.2":
         results_generator = engine.generate(
             inputs={"prompt_token_ids": prompt_token_ids},
             sampling_params=sampling_params,
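A note on the `vllm_version` gates added in `utils/api_server.py` above: comparing version strings lexicographically works for the releases branched on here, but breaks once any component reaches two digits (as strings, `"0.10.0" >= "0.6.3"` is False). A minimal sketch of a more robust gate, assuming the `packaging` library is available (it is pulled in by pip/setuptools, which the workflow now upgrades); the comparison targets are the same thresholds used in the diff:

    # Sketch only -- hypothetical hardening of the version gate, not part of the diff above.
    from importlib.metadata import version as installed_version

    from packaging.version import Version

    # e.g. Version("0.6.3.post1"); packaging handles post-releases correctly
    vllm_version = Version(installed_version("vllm"))

    if vllm_version >= Version("0.6.3"):
        pass  # TokensPrompt code path
    elif vllm_version >= Version("0.4.2"):
        pass  # inputs={"prompt_token_ids": ...} code path
    else:
        pass  # legacy positional-argument code path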