Python: add chat system message to complete settings (#2726)
### Motivation and Context
This PR parallels work done in .NET for
#2671 and closes
#2722.

This PR allows users to set the chat system message using the semantic
function prompt template.
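
For example, a minimal sketch mirroring the new unit tests in this PR (the template string and prompt text are illustrative):

```python
from semantic_kernel.semantic_functions.chat_prompt_template import ChatPromptTemplate
from semantic_kernel.semantic_functions.prompt_template_config import (
    PromptTemplateConfig,
)

# Declare the system message alongside the other completion settings.
config = PromptTemplateConfig(
    completion=PromptTemplateConfig.CompletionConfig(
        chat_system_prompt="You are a helpful travel assistant.",
    )
)

# The chat template now prepends the system message automatically.
chat_template = ChatPromptTemplate("{{$user_input}}", None, prompt_config=config)
assert chat_template.messages[0]["role"] == "system"
```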

### Description

* added a new `chat_system_prompt` parameter to
`CompleteRequestSettings` (defaults to "Assistant is a large language
model.") and to `PromptTemplateConfig.CompletionConfig` (defaults to
`None`)
* added handling of this new parameter to the `chat_prompt_template`
* added corresponding tests
* added a small helper function for deserializing prompt configs: token
selection bias keys need to be deserialized to ints, since `json.loads`
otherwise leaves the keys of the dict as strings (illustrated below)
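
A standalone sketch of the round-trip problem and the `object_hook` fix (the helper name here is hypothetical):

```python
import json

raw = '{"token_selection_biases": {"1": -100}}'

# Plain json.loads leaves dict keys as strings...
assert json.loads(raw)["token_selection_biases"] == {"1": -100}


def keys_to_int(obj: dict) -> dict:
    # Convert digit-only keys back to ints; leave other keys alone.
    return {int(k) if k.isdigit() else k: v for k, v in obj.items()}


# ...while an object_hook restores the Dict[int, int] shape.
assert json.loads(raw, object_hook=keys_to_int)["token_selection_biases"] == {1: -100}
```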

Unrelated:
* added the missing `__init__.py` file to the weaviate connector's directory

### Contribution Checklist


- [ ] The code builds clean without any errors or warnings
- [ ] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [ ] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄
awharrison-28 authored Sep 6, 2023
1 parent a3ff069 commit 7313258
Showing 8 changed files with 299 additions and 4 deletions.
python/semantic_kernel/connectors/ai/complete_request_settings.py
@@ -20,6 +20,7 @@ class CompleteRequestSettings:
number_of_responses: int = 1
logprobs: int = 0
token_selection_biases: Dict[int, int] = field(default_factory=dict)
chat_system_prompt: str = "Assistant is a large language model."

def update_from_completion_config(
self, completion_config: "PromptTemplateConfig.CompletionConfig"
@@ -33,6 +34,9 @@ def update_from_completion_config(
self.number_of_responses = completion_config.number_of_responses
self.token_selection_biases = completion_config.token_selection_biases

if completion_config.chat_system_prompt:
self.chat_system_prompt = completion_config.chat_system_prompt

@staticmethod
def from_completion_config(
completion_config: "PromptTemplateConfig.CompletionConfig",
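The guard above means an unset config leaves the settings' default in place; a minimal sketch of that behavior (import paths taken from the new tests):

```python
from semantic_kernel.connectors.ai.complete_request_settings import (
    CompleteRequestSettings,
)
from semantic_kernel.semantic_functions.prompt_template_config import (
    PromptTemplateConfig,
)

settings = CompleteRequestSettings()
# CompletionConfig.chat_system_prompt defaults to None, so updating from a
# default config does not clobber the settings' built-in system prompt.
settings.update_from_completion_config(PromptTemplateConfig.CompletionConfig())
assert settings.chat_system_prompt == "Assistant is a large language model."
```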
6 changes: 6 additions & 0 deletions python/semantic_kernel/connectors/memory/weaviate/__init__.py
@@ -0,0 +1,6 @@
# Copyright (c) Microsoft. All rights reserved
from semantic_kernel.connectors.memory.weaviate.weaviate_memory_store import (
WeaviateMemoryStore,
)

__all__ = ["WeaviateMemoryStore"]
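
With the initializer in place, the store is importable from the package root (a one-line sketch; actually constructing the store requires a running or embedded Weaviate instance):

```python
# Without an __init__.py, packaging tools can skip this directory,
# which breaks the import for installed users.
from semantic_kernel.connectors.memory.weaviate import WeaviateMemoryStore
```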
python/semantic_kernel/connectors/memory/weaviate/weaviate_memory_store.py
@@ -6,12 +6,12 @@
from typing import List, Optional, Tuple

import numpy as np
import weaviate
from weaviate.embedded import EmbeddedOptions

import weaviate
from semantic_kernel.memory.memory_record import MemoryRecord
from semantic_kernel.memory.memory_store_base import MemoryStoreBase
from semantic_kernel.utils.null_logger import NullLogger
from weaviate.embedded import EmbeddedOptions

SCHEMA = {
"class": "MemoryRecord",
python/semantic_kernel/semantic_functions/chat_prompt_template.py
@@ -27,6 +27,8 @@ def __init__(
) -> None:
super().__init__(template, template_engine, prompt_config, log)
self._messages = []
if self._prompt_config.completion.chat_system_prompt:
self.add_system_message(self._prompt_config.completion.chat_system_prompt)

async def render_async(self, context: "SKContext") -> str:
raise NotImplementedError(
@@ -81,6 +83,12 @@ def restore(
) -> "ChatPromptTemplate":
"""Restore a ChatPromptTemplate from a list of role and message pairs."""
chat_template = cls(template, template_engine, prompt_config, log)

if prompt_config.completion.chat_system_prompt:
chat_template.add_system_message(
prompt_config.completion.chat_system_prompt
)

for message in messages:
chat_template.add_message(message["role"], message["message"])

python/semantic_kernel/semantic_functions/prompt_template_config.py
@@ -16,6 +16,7 @@ class CompletionConfig:
number_of_responses: int = 1
stop_sequences: List[str] = field(default_factory=list)
token_selection_biases: Dict[int, int] = field(default_factory=dict)
chat_system_prompt: str = None

@dataclass
class InputParameter:
@@ -60,6 +61,7 @@ def from_dict(data: dict) -> "PromptTemplateConfig":
config.completion.token_selection_biases = completion_dict.get(
"token_selection_biases", {}
)
config.completion.chat_system_prompt = completion_dict.get("chat_system_prompt")

config.default_services = data.get("default_services", [])

@@ -102,7 +104,12 @@ def from_json(json_str: str) -> "PromptTemplateConfig":
def from_json(json_str: str) -> "PromptTemplateConfig":
import json

return PromptTemplateConfig.from_dict(json.loads(json_str))
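# json.loads always yields string keys; digit keys are coerced back to
# int so token_selection_biases loads as Dict[int, int].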
def keystoint(d):
return {int(k) if k.isdigit() else k: v for k, v in d.items()}

return PromptTemplateConfig.from_dict(
json.loads(json_str, object_hook=keystoint)
)

@staticmethod
def from_completion_parameters(
@@ -114,6 +121,7 @@ def from_completion_parameters(
number_of_responses: int = 1,
stop_sequences: List[str] = [],
token_selection_biases: Dict[int, int] = {},
chat_system_prompt: str = None,
) -> "PromptTemplateConfig":
config = PromptTemplateConfig()
config.completion.temperature = temperature
@@ -124,4 +132,5 @@
config.completion.number_of_responses = number_of_responses
config.completion.stop_sequences = stop_sequences
config.completion.token_selection_biases = token_selection_biases
config.completion.chat_system_prompt = chat_system_prompt
return config
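
The new parameter threads through the factory like the existing ones; a short sketch (only parameters that appear in this diff are passed, everything else keeps its default):

```python
from semantic_kernel.semantic_functions.prompt_template_config import (
    PromptTemplateConfig,
)

config = PromptTemplateConfig.from_completion_parameters(
    temperature=0.5,
    chat_system_prompt="You are a concise assistant.",
)
assert config.completion.chat_system_prompt == "You are a concise assistant."
```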
2 changes: 1 addition & 1 deletion python/tests/conftest.py
@@ -68,7 +68,7 @@ def get_aoai_config():
endpoint = os.environ["AzureOpenAI__Endpoint"]
else:
# Load credentials from .env file
deployment_name, api_key, endpoint, _ = sk.azure_openai_settings_from_dot_env()
deployment_name, api_key, endpoint = sk.azure_openai_settings_from_dot_env()
deployment_name = "text-embedding-ada-002"

return deployment_name, api_key, endpoint
95 changes: 95 additions & 0 deletions python/tests/unit/ai/test_request_settings.py
@@ -0,0 +1,95 @@
# Copyright (c) Microsoft. All rights reserved.

from semantic_kernel.connectors.ai.chat_request_settings import ChatRequestSettings
from semantic_kernel.connectors.ai.complete_request_settings import (
CompleteRequestSettings,
)


def test_default_complete_request_settings():
settings = CompleteRequestSettings()
assert settings.temperature == 0.0
assert settings.top_p == 1.0
assert settings.presence_penalty == 0.0
assert settings.frequency_penalty == 0.0
assert settings.max_tokens == 256
assert settings.stop_sequences == []
assert settings.number_of_responses == 1
assert settings.logprobs == 0
assert settings.token_selection_biases == {}
assert settings.chat_system_prompt == "Assistant is a large language model."


def test_custom_complete_request_settings():
settings = CompleteRequestSettings(
temperature=0.5,
top_p=0.5,
presence_penalty=0.5,
frequency_penalty=0.5,
max_tokens=128,
stop_sequences=["\n"],
number_of_responses=2,
logprobs=1,
token_selection_biases={1: 1},
chat_system_prompt="Hello",
)
assert settings.temperature == 0.5
assert settings.top_p == 0.5
assert settings.presence_penalty == 0.5
assert settings.frequency_penalty == 0.5
assert settings.max_tokens == 128
assert settings.stop_sequences == ["\n"]
assert settings.number_of_responses == 2
assert settings.logprobs == 1
assert settings.token_selection_biases == {1: 1}
assert settings.chat_system_prompt == "Hello"


def test_default_chat_request_settings():
settings = ChatRequestSettings()
assert settings.temperature == 0.0
assert settings.top_p == 1.0
assert settings.presence_penalty == 0.0
assert settings.frequency_penalty == 0.0
assert settings.max_tokens == 256
assert settings.stop_sequences == []
assert settings.number_of_responses == 1
assert settings.token_selection_biases == {}


def test_complete_request_settings_from_default_completion_config():
settings = CompleteRequestSettings()
chat_settings = ChatRequestSettings.from_completion_config(settings)
assert chat_settings.temperature == 0.0
assert chat_settings.top_p == 1.0
assert chat_settings.presence_penalty == 0.0
assert chat_settings.frequency_penalty == 0.0
assert chat_settings.max_tokens == 256
assert chat_settings.stop_sequences == []
assert chat_settings.number_of_responses == 1
assert chat_settings.token_selection_biases == {}


def test_chat_request_settings_from_custom_completion_config():
settings = CompleteRequestSettings(
temperature=0.5,
top_p=0.5,
presence_penalty=0.5,
frequency_penalty=0.5,
max_tokens=128,
stop_sequences=["\n"],
number_of_responses=2,
logprobs=1,
token_selection_biases={1: 1},
chat_system_prompt="Hello",
)
chat_settings = ChatRequestSettings.from_completion_config(settings)
assert chat_settings.temperature == 0.5
assert chat_settings.top_p == 0.5
assert chat_settings.presence_penalty == 0.5
assert chat_settings.frequency_penalty == 0.5
assert chat_settings.max_tokens == 128
assert chat_settings.stop_sequences == ["\n"]
assert chat_settings.number_of_responses == 2
assert chat_settings.token_selection_biases == {1: 1}
173 changes: 173 additions & 0 deletions python/tests/unit/skill_definition/test_prompt_templates.py
@@ -0,0 +1,173 @@
# Copyright (c) Microsoft. All rights reserved.

import json

import pytest

from semantic_kernel.semantic_functions.chat_prompt_template import ChatPromptTemplate
from semantic_kernel.semantic_functions.prompt_template_config import (
PromptTemplateConfig,
)


def test_default_prompt_template_config():
prompt_template_config = PromptTemplateConfig()
assert prompt_template_config.schema == 1
assert prompt_template_config.type == "completion"
assert prompt_template_config.description == ""
assert prompt_template_config.completion.temperature == 0.0
assert prompt_template_config.completion.top_p == 1.0
assert prompt_template_config.completion.presence_penalty == 0.0
assert prompt_template_config.completion.frequency_penalty == 0.0
assert prompt_template_config.completion.max_tokens == 256
assert prompt_template_config.completion.number_of_responses == 1
assert prompt_template_config.completion.stop_sequences == []
assert prompt_template_config.completion.token_selection_biases == {}
assert prompt_template_config.completion.chat_system_prompt is None


def test_default_chat_prompt_template_from_empty_dict():
with pytest.raises(KeyError):
_ = PromptTemplateConfig().from_dict({})


def test_default_chat_prompt_template_from_empty_string():
with pytest.raises(json.decoder.JSONDecodeError):
_ = PromptTemplateConfig().from_json("")


def test_default_chat_prompt_template_from_empty_json():
with pytest.raises(KeyError):
_ = PromptTemplateConfig().from_json("{}")


def test_custom_prompt_template_config():
prompt_template_config = PromptTemplateConfig(
schema=2,
type="completion2",
description="Custom description.",
completion=PromptTemplateConfig.CompletionConfig(
temperature=0.5,
top_p=0.5,
presence_penalty=0.5,
frequency_penalty=0.5,
max_tokens=128,
number_of_responses=2,
stop_sequences=["\n"],
token_selection_biases={1: 1},
chat_system_prompt="Custom system prompt.",
),
)
assert prompt_template_config.schema == 2
assert prompt_template_config.type == "completion2"
assert prompt_template_config.description == "Custom description."
assert prompt_template_config.completion.temperature == 0.5
assert prompt_template_config.completion.top_p == 0.5
assert prompt_template_config.completion.presence_penalty == 0.5
assert prompt_template_config.completion.frequency_penalty == 0.5
assert prompt_template_config.completion.max_tokens == 128
assert prompt_template_config.completion.number_of_responses == 2
assert prompt_template_config.completion.stop_sequences == ["\n"]
assert prompt_template_config.completion.token_selection_biases == {1: 1}
assert (
prompt_template_config.completion.chat_system_prompt == "Custom system prompt."
)


def test_custom_prompt_template_config_from_dict():
prompt_template_dict = {
"schema": 2,
"type": "completion2",
"description": "Custom description.",
"completion": {
"temperature": 0.5,
"top_p": 0.5,
"presence_penalty": 0.5,
"frequency_penalty": 0.5,
"max_tokens": 128,
"number_of_responses": 2,
"stop_sequences": ["\n"],
"token_selection_biases": {1: 1},
"chat_system_prompt": "Custom system prompt.",
},
}
prompt_template_config = PromptTemplateConfig().from_dict(prompt_template_dict)
assert prompt_template_config.schema == 2
assert prompt_template_config.type == "completion2"
assert prompt_template_config.description == "Custom description."
assert prompt_template_config.completion.temperature == 0.5
assert prompt_template_config.completion.top_p == 0.5
assert prompt_template_config.completion.presence_penalty == 0.5
assert prompt_template_config.completion.frequency_penalty == 0.5
assert prompt_template_config.completion.max_tokens == 128
assert prompt_template_config.completion.number_of_responses == 2
assert prompt_template_config.completion.stop_sequences == ["\n"]
assert prompt_template_config.completion.token_selection_biases == {1: 1}
assert (
prompt_template_config.completion.chat_system_prompt == "Custom system prompt."
)


def test_custom_prompt_template_config_from_json():
prompt_template_json = """
{
"schema": 2,
"type": "completion2",
"description": "Custom description.",
"completion": {
"temperature": 0.5,
"top_p": 0.5,
"presence_penalty": 0.5,
"frequency_penalty": 0.5,
"max_tokens": 128,
"number_of_responses": 2,
"stop_sequences": ["s"],
"token_selection_biases": {"1": 1},
"chat_system_prompt": "Custom system prompt."
}
}
"""
prompt_template_config = PromptTemplateConfig().from_json(prompt_template_json)
assert prompt_template_config.schema == 2
assert prompt_template_config.type == "completion2"
assert prompt_template_config.description == "Custom description."
assert prompt_template_config.completion.temperature == 0.5
assert prompt_template_config.completion.top_p == 0.5
assert prompt_template_config.completion.presence_penalty == 0.5
assert prompt_template_config.completion.frequency_penalty == 0.5
assert prompt_template_config.completion.max_tokens == 128
assert prompt_template_config.completion.number_of_responses == 2
assert prompt_template_config.completion.stop_sequences == ["s"]
assert prompt_template_config.completion.token_selection_biases == {1: 1}
assert (
prompt_template_config.completion.chat_system_prompt == "Custom system prompt."
)


def test_chat_prompt_template():
chat_prompt_template = ChatPromptTemplate(
"{{$user_input}}",
None,
prompt_config=PromptTemplateConfig(),
)

assert chat_prompt_template._messages == []


def test_chat_prompt_template_with_system_prompt():
prompt_template_config = PromptTemplateConfig(
completion=PromptTemplateConfig.CompletionConfig(
chat_system_prompt="Custom system prompt.",
)
)

chat_prompt_template = ChatPromptTemplate(
"{{$user_input}}",
None,
prompt_config=prompt_template_config,
)

assert len(chat_prompt_template.messages) == 1
assert chat_prompt_template.messages[0]["role"] == "system"
assert chat_prompt_template.messages[0]["message"] == "Custom system prompt."
