Skip to content

Commit

Permalink
Python: Azure AI Inference Function Calling (#7035)
Browse files Browse the repository at this point in the history
### Motivation and Context

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users by providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->
Now that the Function Choice abstraction has been implemented in Python,
it is time to extend this feature to other connectors. The first (OAI
and AOAI are not included) connector to be granted this honor is the
Azure AI Inference connector.

### Description

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->
1. Add function calling to Azure AI Inference.

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [ ] The code builds clean without any errors or warnings
- [ ] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [ ] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄
  • Loading branch information
TaoChenOSU authored Jul 9, 2024
1 parent 4e99b76 commit 0ad7f54
Show file tree
Hide file tree
Showing 8 changed files with 455 additions and 132 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.

from typing import Literal
from typing import Any, Literal

from pydantic import Field

Expand Down Expand Up @@ -30,6 +30,9 @@ class AzureAIInferencePromptExecutionSettings(PromptExecutionSettings):
class AzureAIInferenceChatPromptExecutionSettings(AzureAIInferencePromptExecutionSettings):
"""Azure AI Inference Chat Prompt Execution Settings."""

tools: list[dict[str, Any]] | None = Field(None, max_length=64)
tool_choice: str | None = None


@experimental_class
class AzureAIInferenceEmbeddingPromptExecutionSettings(PromptExecutionSettings):
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright (c) Microsoft. All rights reserved.

import logging
from collections.abc import Callable

from azure.ai.inference.models import (
AssistantMessage,
ChatCompletionsFunctionToolCall,
ChatRequestMessage,
FunctionCall,
ImageContentItem,
ImageDetailLevel,
ImageUrl,
SystemMessage,
TextContentItem,
ToolMessage,
UserMessage,
)

from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.function_call_content import FunctionCallContent
from semantic_kernel.contents.function_result_content import FunctionResultContent
from semantic_kernel.contents.image_content import ImageContent
from semantic_kernel.contents.text_content import TextContent
from semantic_kernel.contents.utils.author_role import AuthorRole

logger: logging.Logger = logging.getLogger(__name__)


def _format_system_message(message: ChatMessageContent) -> SystemMessage:
    """Build the client's SystemMessage from a system-role chat message.

    Args:
        message: The system-role chat message to convert.

    Returns:
        A SystemMessage carrying the message's text content.
    """
    return SystemMessage(content=message.content)


def _format_user_message(message: ChatMessageContent) -> UserMessage:
    """Build the client's UserMessage from a user-role chat message.

    When the message carries image items the content must be sent as a list of
    content items; otherwise the content must be a plain string, or the service
    call will error.

    Args:
        message: The user-role chat message to convert.

    Returns:
        The formatted user message.
    """
    contains_image = any(isinstance(item, ImageContent) for item in message.items)
    if not contains_image:
        return UserMessage(content=message.content)

    content_items: list[TextContentItem | ImageContentItem] = []
    for item in message.items:
        if isinstance(item, TextContent):
            content_items.append(TextContentItem(text=item.text))
        elif isinstance(item, ImageContent) and (item.data_uri or item.uri):
            url = item.data_uri or str(item.uri)
            content_items.append(
                ImageContentItem(image_url=ImageUrl(url=url, detail=ImageDetailLevel.Auto))
            )
        else:
            # Anything else (including an ImageContent with no usable URI) is dropped with a warning.
            logger.warning(
                "Unsupported item type in User message while formatting chat history for Azure AI"
                f" Inference: {type(item)}"
            )

    return UserMessage(content=content_items)


def _format_assistant_message(message: ChatMessageContent) -> AssistantMessage:
    """Build the client's AssistantMessage from an assistant-role chat message.

    Args:
        message: The assistant-role chat message to convert.

    Returns:
        The formatted assistant message, with any tool calls attached.
    """
    content_items = []
    tool_calls = []

    for item in message.items:
        if isinstance(item, TextContent):
            content_items.append(TextContentItem(text=item.text))
        elif isinstance(item, FunctionCallContent):
            call = ChatCompletionsFunctionToolCall(
                id=item.id,
                function=FunctionCall(name=item.name, arguments=item.arguments),
            )
            tool_calls.append(call)
        else:
            logger.warning(
                "Unsupported item type in Assistant message while formatting chat history for Azure AI"
                f" Inference: {type(item)}"
            )

    # tool_calls cannot be an empty list, so pass None when there are no calls.
    return AssistantMessage(content=content_items, tool_calls=tool_calls or None)


def _format_tool_message(message: ChatMessageContent) -> ToolMessage:
    """Format a tool message to the expected object for the client.

    The message is expected to carry exactly one item, a FunctionResultContent;
    deviations are logged as warnings and the first item is used regardless.

    Args:
        message: The tool message.

    Returns:
        The formatted tool message.

    Raises:
        ValueError: If the message carries no items at all.
    """
    if len(message.items) != 1:
        logger.warning(
            "Unsupported number of items in Tool message while formatting chat history for Azure AI"
            f" Inference: {len(message.items)}"
        )

    if not message.items:
        # Indexing below would otherwise raise an opaque IndexError; fail with a clear message.
        raise ValueError("Tool message must contain at least one item.")

    item = message.items[0]
    if not isinstance(item, FunctionResultContent):
        logger.warning(
            "Unsupported item type in Tool message while formatting chat history for Azure AI"
            f" Inference: {type(item)}"
        )

    # The API expects the result to be a string, so we need to convert it to a string
    return ToolMessage(content=str(item.result), tool_call_id=item.id)


# Maps each chat author role to the formatter that converts a ChatMessageContent
# into the corresponding Azure AI Inference ChatRequestMessage subtype.
MESSAGE_CONVERTERS: dict[AuthorRole, Callable[[ChatMessageContent], ChatRequestMessage]] = {
    AuthorRole.SYSTEM: _format_system_message,
    AuthorRole.USER: _format_user_message,
    AuthorRole.ASSISTANT: _format_assistant_message,
    AuthorRole.TOOL: _format_tool_message,
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _prepare_chat_history_for_request(
chat_history: "ChatHistory",
role_key: str = "role",
content_key: str = "content",
) -> list[dict[str, str | None]]:
) -> Any:
"""Prepare the chat history for a request.
Allowing customization of the key names for role/author, and optionally overriding the role.
Expand All @@ -68,12 +68,14 @@ def _prepare_chat_history_for_request(
They require a "tool_call_id" and (function) "name" key, and the "metadata" key should
be removed. The "encoding" key should also be removed.
Override this method to customize the formatting of the chat history for a request.
Args:
chat_history (ChatHistory): The chat history to prepare.
role_key (str): The key name for the role/author.
content_key (str): The key name for the content/message.
Returns:
List[Dict[str, Optional[str]]]: The prepared chat history.
prepared_chat_history (Any): The prepared chat history for a request.
"""
return [message.to_dict(role_key=role_key, content_key=content_key) for message in chat_history.messages]
28 changes: 10 additions & 18 deletions python/semantic_kernel/connectors/ai/function_calling_utils.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,23 @@
# Copyright (c) Microsoft. All rights reserved.

import logging
from typing import TYPE_CHECKING, Any
from typing import Any

from semantic_kernel.connectors.ai.open_ai.prompt_execution_settings.open_ai_prompt_execution_settings import (
OpenAIChatPromptExecutionSettings,
)
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionCallChoiceConfiguration
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
from semantic_kernel.functions.kernel_function_metadata import KernelFunctionMetadata

if TYPE_CHECKING:
from semantic_kernel.connectors.ai.function_choice_behavior import (
FunctionCallChoiceConfiguration,
)
from semantic_kernel.connectors.ai.open_ai.prompt_execution_settings.open_ai_prompt_execution_settings import (
OpenAIChatPromptExecutionSettings,
)

logger = logging.getLogger(__name__)


def update_settings_from_function_call_configuration(
function_choice_configuration: "FunctionCallChoiceConfiguration",
settings: "OpenAIChatPromptExecutionSettings",
function_choice_configuration: FunctionCallChoiceConfiguration,
settings: PromptExecutionSettings,
type: str,
) -> None:
"""Update the settings from a FunctionChoiceConfiguration."""
if function_choice_configuration.available_functions:
if (
function_choice_configuration.available_functions
and hasattr(settings, "tool_choice")
and hasattr(settings, "tools")
):
settings.tool_choice = type
settings.tools = [
kernel_function_metadata_to_function_call_format(f)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,8 @@

from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase
from semantic_kernel.connectors.ai.function_call_behavior import FunctionCallBehavior
from semantic_kernel.connectors.ai.function_calling_utils import (
update_settings_from_function_call_configuration,
)
from semantic_kernel.connectors.ai.function_choice_behavior import (
FunctionChoiceBehavior,
)
from semantic_kernel.connectors.ai.function_calling_utils import update_settings_from_function_call_configuration
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior
from semantic_kernel.connectors.ai.open_ai.prompt_execution_settings.open_ai_prompt_execution_settings import (
OpenAIChatPromptExecutionSettings,
)
Expand All @@ -33,10 +29,7 @@
from semantic_kernel.contents.text_content import TextContent
from semantic_kernel.contents.utils.author_role import AuthorRole
from semantic_kernel.contents.utils.finish_reason import FinishReason
from semantic_kernel.exceptions import (
ServiceInvalidExecutionSettingsError,
ServiceInvalidResponseError,
)
from semantic_kernel.exceptions import ServiceInvalidExecutionSettingsError, ServiceInvalidResponseError
from semantic_kernel.filters.auto_function_invocation.auto_function_invocation_context import (
AutoFunctionInvocationContext,
)
Expand Down
52 changes: 52 additions & 0 deletions python/tests/integration/completions/test_chat_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,58 @@ def services() -> dict[str, tuple[ChatCompletionClientBase, type[PromptExecution
["house", "germany"],
id="azure_ai_inference_image_input_file",
),
pytest.param(
"azure_ai_inference",
{
"function_choice_behavior": FunctionChoiceBehavior.Auto(
auto_invoke=True, filters={"excluded_plugins": ["chat"]}
)
},
[
ChatMessageContent(role=AuthorRole.USER, items=[TextContent(text="What is 3+345?")]),
],
["348"],
id="azure_ai_inference_tool_call_auto",
),
pytest.param(
"azure_ai_inference",
{
"function_choice_behavior": FunctionChoiceBehavior.Auto(
auto_invoke=False, filters={"excluded_plugins": ["chat"]}
)
},
[
ChatMessageContent(role=AuthorRole.USER, items=[TextContent(text="What is 3+345?")]),
],
["348"],
id="azure_ai_inference_tool_call_non_auto",
),
pytest.param(
"azure_ai_inference",
{},
[
[
ChatMessageContent(
role=AuthorRole.USER,
items=[TextContent(text="What was our 2024 revenue?")],
),
ChatMessageContent(
role=AuthorRole.ASSISTANT,
items=[
FunctionCallContent(
id="fin", name="finance-search", arguments='{"company": "contoso", "year": 2024}'
)
],
),
ChatMessageContent(
role=AuthorRole.TOOL,
items=[FunctionResultContent(id="fin", name="finance-search", result="1.2B")],
),
],
],
["1.2"],
id="azure_ai_inference_tool_call_flow",
),
pytest.param(
"mistral_ai",
{},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ async def test_process_tool_calls_with_continuation_on_malformed_arguments():
ai_model_id="test_model_id", service_id="test", client=MagicMock(spec=AsyncOpenAI)
)

with patch("semantic_kernel.connectors.ai.function_calling_utils.logger", autospec=True):
with patch("semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion_base.logger", autospec=True):
await chat_completion_base._process_function_call(
tool_call_mock,
chat_history_mock,
Expand Down

0 comments on commit 0ad7f54

Please sign in to comment.