diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py index 1b5205f79d610..7ee9a1651400d 100644 --- a/tools/server/tests/unit/test_chat_completion.py +++ b/tools/server/tests/unit/test_chat_completion.py @@ -132,6 +132,28 @@ def test_chat_template(): assert res.body["__verbose"]["prompt"] == " <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" +@pytest.mark.parametrize("prefill,re_prefill", [ + ("Whill", "Whill"), + ([{"type": "text", "text": "Wh"}, {"type": "text", "text": "ill"}], "Whill"), +]) +def test_chat_template_assistant_prefill(prefill, re_prefill): + global server + server.chat_template = "llama3" + server.debug = True # to get the "__verbose" object in the response + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": 8, + "messages": [ + {"role": "system", "content": "Book"}, + {"role": "user", "content": "What is the best book"}, + {"role": "assistant", "content": prefill}, + ] + }) + assert res.status_code == 200 + assert "__verbose" in res.body + assert res.body["__verbose"]["prompt"] == f" <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{re_prefill}" + + def test_apply_chat_template(): global server server.chat_template = "command-r" @@ -228,6 +250,7 @@ def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re [{"role": "system", "content": 123}], # [{"content": "hello"}], # TODO: should not be a valid case [{"role": "system", "content": "test"}, {}], + [{"role": "user", "content": "test"}, {"role": "assistant", "content": "test"}, {"role": "assistant", "content": "test"}], ]) def test_invalid_chat_completion_req(messages): global server diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index f8fab2c86664e..6add39830cba1 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -779,7 +779,13 @@ static json oaicompat_chat_params_parse( /* Append assistant prefilled message */ if (prefill_assistant_message) { - chat_params.prompt += last_message.content; + if (!last_message.content_parts.empty()) { + for (auto & p : last_message.content_parts) { + chat_params.prompt += p.text; + } + } else { + chat_params.prompt += last_message.content; + } } llama_params["chat_format"] = static_cast(chat_params.format);