# Python: add chat system message to complete settings (#2726)
### Motivation and Context

This PR parallels work done in .NET for #2671 and closes #2722. It allows users to set the chat system message using the semantic function prompt template.

### Description

* Added a new `chat_system_prompt` parameter to `complete_request_settings` and `chat_request_settings` (defaulting to "Assistant is a large language model."); in the prompt template's completion config it defaults to `None`.
* Added handling of the new parameter to `chat_prompt_template`.
* Added corresponding tests.
* Added a small helper function for deserializing prompt configs: token selection biases need to be deserialized to ints, because `json.loads` deserializes the keys of the mapping to strings (see the sketch below).

Unrelated:

* Added the missing `__init__.py` file to the Weaviate connector's directory.

### Contribution Checklist

- [ ] The code builds clean without any errors or warnings
- [ ] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations
- [ ] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄
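The deserialization issue behind the helper is easy to reproduce: JSON object keys are always strings, so `json.loads` returns string keys even when a config was authored with integer token IDs. A minimal sketch of the coercion involved (this snippet is illustrative; the helper's actual name and signature live in the diff):

```python
import json

# JSON object keys are always strings, so integer token IDs round-trip
# as strings unless they are coerced back.
raw = json.loads('{"token_selection_biases": {"1": 1, "2": -100}}')

biases = {int(token_id): bias for token_id, bias in raw["token_selection_biases"].items()}

assert biases == {1: 1, 2: -100}
```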
1 parent a3ff069 · commit 7313258

Showing 8 changed files with 299 additions and 4 deletions.
#### `python/semantic_kernel/connectors/memory/weaviate/__init__.py` (6 additions, 0 deletions)
```python
# Copyright (c) Microsoft. All rights reserved
from semantic_kernel.connectors.memory.weaviate.weaviate_memory_store import (
    WeaviateMemoryStore,
)

__all__ = ["WeaviateMemoryStore"]
```
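With the initializer in place, `WeaviateMemoryStore` can be imported from the connector package directly. A minimal usage sketch (the import path comes from the diff above; the rest is illustrative):

```python
# The new __init__.py re-exports the store at the package root,
# so this shorter import path now works.
from semantic_kernel.connectors.memory.weaviate import WeaviateMemoryStore

print(WeaviateMemoryStore.__name__)  # -> "WeaviateMemoryStore"
```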
#### Unit tests for request settings (95 additions, 0 deletions)
```python
# Copyright (c) Microsoft. All rights reserved.

from semantic_kernel.connectors.ai.chat_request_settings import ChatRequestSettings
from semantic_kernel.connectors.ai.complete_request_settings import (
    CompleteRequestSettings,
)


def test_default_complete_request_settings():
    settings = CompleteRequestSettings()
    assert settings.temperature == 0.0
    assert settings.top_p == 1.0
    assert settings.presence_penalty == 0.0
    assert settings.frequency_penalty == 0.0
    assert settings.max_tokens == 256
    assert settings.stop_sequences == []
    assert settings.number_of_responses == 1
    assert settings.logprobs == 0
    assert settings.token_selection_biases == {}
    assert settings.chat_system_prompt == "Assistant is a large language model."


def test_custom_complete_request_settings():
    settings = CompleteRequestSettings(
        temperature=0.5,
        top_p=0.5,
        presence_penalty=0.5,
        frequency_penalty=0.5,
        max_tokens=128,
        stop_sequences=["\n"],
        number_of_responses=2,
        logprobs=1,
        token_selection_biases={1: 1},
        chat_system_prompt="Hello",
    )
    assert settings.temperature == 0.5
    assert settings.top_p == 0.5
    assert settings.presence_penalty == 0.5
    assert settings.frequency_penalty == 0.5
    assert settings.max_tokens == 128
    assert settings.stop_sequences == ["\n"]
    assert settings.number_of_responses == 2
    assert settings.logprobs == 1
    assert settings.token_selection_biases == {1: 1}
    assert settings.chat_system_prompt == "Hello"


def test_default_chat_request_settings():
    settings = ChatRequestSettings()
    assert settings.temperature == 0.0
    assert settings.top_p == 1.0
    assert settings.presence_penalty == 0.0
    assert settings.frequency_penalty == 0.0
    assert settings.max_tokens == 256
    assert settings.stop_sequences == []
    assert settings.number_of_responses == 1
    assert settings.token_selection_biases == {}


def test_complete_request_settings_from_default_completion_config():
    settings = CompleteRequestSettings()
    chat_settings = ChatRequestSettings.from_completion_config(settings)
    assert chat_settings.temperature == 0.0
    assert chat_settings.top_p == 1.0
    assert chat_settings.presence_penalty == 0.0
    assert chat_settings.frequency_penalty == 0.0
    assert chat_settings.max_tokens == 256
    assert chat_settings.stop_sequences == []
    assert chat_settings.number_of_responses == 1
    assert chat_settings.token_selection_biases == {}


def test_chat_request_settings_from_custom_completion_config():
    settings = CompleteRequestSettings(
        temperature=0.5,
        top_p=0.5,
        presence_penalty=0.5,
        frequency_penalty=0.5,
        max_tokens=128,
        stop_sequences=["\n"],
        number_of_responses=2,
        logprobs=1,
        token_selection_biases={1: 1},
        chat_system_prompt="Hello",
    )
    chat_settings = ChatRequestSettings.from_completion_config(settings)
    assert chat_settings.temperature == 0.5
    assert chat_settings.top_p == 0.5
    assert chat_settings.presence_penalty == 0.5
    assert chat_settings.frequency_penalty == 0.5
    assert chat_settings.max_tokens == 128
    assert chat_settings.stop_sequences == ["\n"]
    assert chat_settings.number_of_responses == 2
    assert chat_settings.token_selection_biases == {1: 1}
```
#### `python/tests/unit/skill_definition/test_prompt_templates.py` (173 additions, 0 deletions)
```python
# Copyright (c) Microsoft. All rights reserved.

import json

import pytest

from semantic_kernel.semantic_functions.chat_prompt_template import ChatPromptTemplate
from semantic_kernel.semantic_functions.prompt_template_config import (
    PromptTemplateConfig,
)


def test_default_prompt_template_config():
    prompt_template_config = PromptTemplateConfig()
    assert prompt_template_config.schema == 1
    assert prompt_template_config.type == "completion"
    assert prompt_template_config.description == ""
    assert prompt_template_config.completion.temperature == 0.0
    assert prompt_template_config.completion.top_p == 1.0
    assert prompt_template_config.completion.presence_penalty == 0.0
    assert prompt_template_config.completion.frequency_penalty == 0.0
    assert prompt_template_config.completion.max_tokens == 256
    assert prompt_template_config.completion.number_of_responses == 1
    assert prompt_template_config.completion.stop_sequences == []
    assert prompt_template_config.completion.token_selection_biases == {}
    assert prompt_template_config.completion.chat_system_prompt is None


def test_default_chat_prompt_template_from_empty_dict():
    with pytest.raises(KeyError):
        _ = PromptTemplateConfig().from_dict({})


def test_default_chat_prompt_template_from_empty_string():
    with pytest.raises(json.decoder.JSONDecodeError):
        _ = PromptTemplateConfig().from_json("")


def test_default_chat_prompt_template_from_empty_json():
    with pytest.raises(KeyError):
        _ = PromptTemplateConfig().from_json("{}")


def test_custom_prompt_template_config():
    prompt_template_config = PromptTemplateConfig(
        schema=2,
        type="completion2",
        description="Custom description.",
        completion=PromptTemplateConfig.CompletionConfig(
            temperature=0.5,
            top_p=0.5,
            presence_penalty=0.5,
            frequency_penalty=0.5,
            max_tokens=128,
            number_of_responses=2,
            stop_sequences=["\n"],
            token_selection_biases={1: 1},
            chat_system_prompt="Custom system prompt.",
        ),
    )
    assert prompt_template_config.schema == 2
    assert prompt_template_config.type == "completion2"
    assert prompt_template_config.description == "Custom description."
    assert prompt_template_config.completion.temperature == 0.5
    assert prompt_template_config.completion.top_p == 0.5
    assert prompt_template_config.completion.presence_penalty == 0.5
    assert prompt_template_config.completion.frequency_penalty == 0.5
    assert prompt_template_config.completion.max_tokens == 128
    assert prompt_template_config.completion.number_of_responses == 2
    assert prompt_template_config.completion.stop_sequences == ["\n"]
    assert prompt_template_config.completion.token_selection_biases == {1: 1}
    assert (
        prompt_template_config.completion.chat_system_prompt == "Custom system prompt."
    )


def test_custom_prompt_template_config_from_dict():
    prompt_template_dict = {
        "schema": 2,
        "type": "completion2",
        "description": "Custom description.",
        "completion": {
            "temperature": 0.5,
            "top_p": 0.5,
            "presence_penalty": 0.5,
            "frequency_penalty": 0.5,
            "max_tokens": 128,
            "number_of_responses": 2,
            "stop_sequences": ["\n"],
            "token_selection_biases": {1: 1},
            "chat_system_prompt": "Custom system prompt.",
        },
    }
    prompt_template_config = PromptTemplateConfig().from_dict(prompt_template_dict)
    assert prompt_template_config.schema == 2
    assert prompt_template_config.type == "completion2"
    assert prompt_template_config.description == "Custom description."
    assert prompt_template_config.completion.temperature == 0.5
    assert prompt_template_config.completion.top_p == 0.5
    assert prompt_template_config.completion.presence_penalty == 0.5
    assert prompt_template_config.completion.frequency_penalty == 0.5
    assert prompt_template_config.completion.max_tokens == 128
    assert prompt_template_config.completion.number_of_responses == 2
    assert prompt_template_config.completion.stop_sequences == ["\n"]
    assert prompt_template_config.completion.token_selection_biases == {1: 1}
    assert (
        prompt_template_config.completion.chat_system_prompt == "Custom system prompt."
    )


def test_custom_prompt_template_config_from_json():
    prompt_template_json = """
    {
        "schema": 2,
        "type": "completion2",
        "description": "Custom description.",
        "completion": {
            "temperature": 0.5,
            "top_p": 0.5,
            "presence_penalty": 0.5,
            "frequency_penalty": 0.5,
            "max_tokens": 128,
            "number_of_responses": 2,
            "stop_sequences": ["s"],
            "token_selection_biases": {"1": 1},
            "chat_system_prompt": "Custom system prompt."
        }
    }
    """
    prompt_template_config = PromptTemplateConfig().from_json(prompt_template_json)
    assert prompt_template_config.schema == 2
    assert prompt_template_config.type == "completion2"
    assert prompt_template_config.description == "Custom description."
    assert prompt_template_config.completion.temperature == 0.5
    assert prompt_template_config.completion.top_p == 0.5
    assert prompt_template_config.completion.presence_penalty == 0.5
    assert prompt_template_config.completion.frequency_penalty == 0.5
    assert prompt_template_config.completion.max_tokens == 128
    assert prompt_template_config.completion.number_of_responses == 2
    assert prompt_template_config.completion.stop_sequences == ["s"]
    # The JSON string keys are coerced back to ints by the new helper.
    assert prompt_template_config.completion.token_selection_biases == {1: 1}
    assert (
        prompt_template_config.completion.chat_system_prompt == "Custom system prompt."
    )


def test_chat_prompt_template():
    chat_prompt_template = ChatPromptTemplate(
        "{{$user_input}}",
        None,
        prompt_config=PromptTemplateConfig(),
    )

    assert chat_prompt_template._messages == []


def test_chat_prompt_template_with_system_prompt():
    prompt_template_config = PromptTemplateConfig(
        completion=PromptTemplateConfig.CompletionConfig(
            chat_system_prompt="Custom system prompt.",
        )
    )

    chat_prompt_template = ChatPromptTemplate(
        "{{$user_input}}",
        None,
        prompt_config=prompt_template_config,
    )

    # The system prompt is seeded as the first chat message.
    assert len(chat_prompt_template.messages) == 1
    assert chat_prompt_template.messages[0]["role"] == "system"
    assert chat_prompt_template.messages[0]["message"] == "Custom system prompt."
```