Skip to content

Commit

Permalink
Python: Refactor agent retrieve method to be classmethod. Update test…
Browse files Browse the repository at this point in the history
…s. (#7854)

### Motivation and Context

In working with the agent's retrieve method, it felt awkward to have to
create an agent object and then use that object to then retrieve the
OpenAI assistant that was previously defined.

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

### Description

Making the retrieve method a class method on the Agent class. Refactored
some of the settings creation and client creation methods for use
between multiple methods in the agent class. Updating unit tests.
- Updates openai package to 1.39.0

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [X] The code builds clean without any errors or warnings
- [X] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [X] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄
  • Loading branch information
moonbox3 authored Aug 5, 2024
1 parent 21f8e02 commit f3433b0
Show file tree
Hide file tree
Showing 5 changed files with 197 additions and 101 deletions.
94 changes: 74 additions & 20 deletions python/semantic_kernel/agents/open_ai/azure_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
service_id: str | None = None,
deployment_name: str | None = None,
api_key: str | None = None,
endpoint: str | None = None,
endpoint: HttpsUrl | None = None,
api_version: str | None = None,
ad_token: str | None = None,
ad_token_provider: Callable[[], str | Awaitable[str]] | None = None,
Expand Down Expand Up @@ -100,23 +100,21 @@ def __init__(
Raises:
AgentInitializationError: If the api_key is not provided in the configuration.
"""
try:
azure_openai_settings = AzureOpenAISettings.create(
api_key=api_key,
endpoint=endpoint,
chat_deployment_name=deployment_name,
api_version=api_version,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
except ValidationError as ex:
raise AgentInitializationError("Failed to create Azure OpenAI settings.", ex) from ex
azure_openai_settings = AzureAssistantAgent._create_azure_openai_settings(
api_key=api_key,
endpoint=endpoint,
deployment_name=deployment_name,
api_version=api_version,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)

if not azure_openai_settings.chat_deployment_name:
raise AgentInitializationError("The Azure OpenAI chat_deployment_name is required.")

if not azure_openai_settings.api_key and not ad_token and not ad_token_provider:
raise AgentInitializationError("Please provide either api_key, ad_token or ad_token_provider.")

client = self._create_client(
api_key=azure_openai_settings.api_key.get_secret_value() if azure_openai_settings.api_key else None,
endpoint=azure_openai_settings.endpoint,
Expand Down Expand Up @@ -165,7 +163,7 @@ async def create(
service_id: str | None = None,
deployment_name: str | None = None,
api_key: str | None = None,
endpoint: str | None = None,
endpoint: HttpsUrl | None = None,
api_version: str | None = None,
ad_token: str | None = None,
ad_token_provider: Callable[[], str | Awaitable[str]] | None = None,
Expand Down Expand Up @@ -303,6 +301,42 @@ def _create_client(
default_headers=merged_headers,
)

@staticmethod
def _create_azure_openai_settings(
api_key: str | None = None,
endpoint: HttpsUrl | None = None,
deployment_name: str | None = None,
api_version: str | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
) -> AzureOpenAISettings:
"""Create the Azure OpenAI settings.
Args:
api_key: The Azure OpenAI API key.
endpoint: The Azure OpenAI endpoint.
deployment_name: The Azure OpenAI chat deployment name.
api_version: The Azure OpenAI API version.
env_file_path: The environment file path.
env_file_encoding: The environment file encoding.
Returns:
An instance of the AzureOpenAISettings.
"""
try:
azure_openai_settings = AzureOpenAISettings.create(
api_key=api_key,
endpoint=endpoint,
chat_deployment_name=deployment_name,
api_version=api_version,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
except ValidationError as ex:
raise AgentInitializationError("Failed to create Azure OpenAI settings.", ex) from ex

return azure_openai_settings

async def list_definitions(self) -> AsyncIterable[dict[str, Any]]:
"""List the assistant definitions.
Expand All @@ -313,8 +347,10 @@ async def list_definitions(self) -> AsyncIterable[dict[str, Any]]:
for assistant in assistants.data:
yield self._create_open_ai_assistant_definition(assistant)

@classmethod
async def retrieve(
self,
cls,
*,
id: str,
api_key: str | None = None,
endpoint: HttpsUrl | None = None,
Expand All @@ -324,6 +360,8 @@ async def retrieve(
client: AsyncAzureOpenAI | None = None,
kernel: "Kernel | None" = None,
default_headers: dict[str, str] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
) -> "AzureAssistantAgent":
"""Retrieve an assistant by ID.
Expand All @@ -337,20 +375,36 @@ async def retrieve(
client: The Azure OpenAI client. (optional)
kernel: The Kernel instance. (optional)
default_headers: The default headers. (optional)
env_file_path: The environment file path. (optional)
env_file_encoding: The environment file encoding. (optional)
Returns:
An OpenAIAssistantAgent instance.
An AzureAssistantAgent instance.
"""
client = self._create_client(
azure_openai_settings = AzureAssistantAgent._create_azure_openai_settings(
api_key=api_key,
endpoint=endpoint,
api_version=api_version,
ad_token=ad_token,
ad_token_provider=ad_token_provider,
default_headers=default_headers,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)

if not azure_openai_settings.chat_deployment_name:
raise AgentInitializationError("The Azure OpenAI chat_deployment_name is required.")
if not azure_openai_settings.api_key and not ad_token and not ad_token_provider:
raise AgentInitializationError("Please provide either api_key, ad_token or ad_token_provider.")

if not client:
client = AzureAssistantAgent._create_client(
api_key=api_key,
endpoint=endpoint,
api_version=api_version,
ad_token=ad_token,
ad_token_provider=ad_token_provider,
default_headers=default_headers,
)
assistant = await client.beta.assistants.retrieve(id)
assistant_definition = self._create_open_ai_assistant_definition(assistant)
assistant_definition = OpenAIAssistantBase._create_open_ai_assistant_definition(assistant)
return AzureAssistantAgent(kernel=kernel, **assistant_definition)

# endregion
89 changes: 72 additions & 17 deletions python/semantic_kernel/agents/open_ai/open_ai_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,21 +94,18 @@ def __init__(
Raises:
AgentInitializationError: If the api_key is not provided in the configuration.
"""
try:
openai_settings = OpenAISettings.create(
api_key=api_key,
org_id=org_id,
chat_model_id=ai_model_id,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
except ValidationError as ex:
raise AgentInitializationError("Failed to create OpenAI settings.", ex) from ex
openai_settings = OpenAIAssistantAgent._create_open_ai_settings(
api_key=api_key,
org_id=org_id,
ai_model_id=ai_model_id,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)

if not client and not openai_settings.api_key:
raise AgentInitializationError("The OpenAI API key is required, if a client is not provided.")
if not openai_settings.chat_model_id:
raise AgentInitializationError("The OpenAI model ID is required.")
raise AgentInitializationError("The OpenAI chat model ID is required.")

if not client:
client = self._create_client(
Expand Down Expand Up @@ -271,6 +268,39 @@ def _create_client(
default_headers=merged_headers,
)

@staticmethod
def _create_open_ai_settings(
api_key: str | None = None,
org_id: str | None = None,
ai_model_id: str | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
) -> OpenAISettings:
"""An internal method to create the OpenAI settings from the provided arguments.
Args:
api_key: The OpenAI API key.
org_id: The OpenAI organization ID. (optional)
ai_model_id: The AI model ID. (optional)
env_file_path: The environment file path. (optional)
env_file_encoding: The environment file encoding. (optional)
Returns:
An OpenAI settings instance.
"""
try:
openai_settings = OpenAISettings.create(
api_key=api_key,
org_id=org_id,
chat_model_id=ai_model_id,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
except ValidationError as ex:
raise AgentInitializationError("Failed to create OpenAI settings.", ex) from ex

return openai_settings

async def list_definitions(self) -> AsyncIterable[dict[str, Any]]:
"""List the assistant definitions.
Expand All @@ -281,30 +311,55 @@ async def list_definitions(self) -> AsyncIterable[dict[str, Any]]:
for assistant in assistants.data:
yield self._create_open_ai_assistant_definition(assistant)

@classmethod
async def retrieve(
self,
cls,
*,
id: str,
api_key: str,
kernel: "Kernel | None" = None,
api_key: str | None = None,
org_id: str | None = None,
ai_model_id: str | None = None,
client: AsyncOpenAI | None = None,
default_headers: dict[str, str] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
) -> "OpenAIAssistantAgent":
"""Retrieve an assistant by ID.
Args:
id: The assistant ID.
api_key: The OpenAI API
kernel: The Kernel instance. (optional)
api_key: The OpenAI API key. (optional)
org_id: The OpenAI organization ID. (optional)
ai_model_id: The AI model ID. (optional)
client: The OpenAI client. (optional)
default_headers: The default headers. (optional)
env_file_path: The environment file path. (optional)
env_file_encoding: The environment file encoding. (optional
Returns:
An OpenAIAssistantAgent instance.
"""
client = self._create_client(api_key=api_key, org_id=org_id, default_headers=default_headers)
openai_settings = OpenAIAssistantAgent._create_open_ai_settings(
api_key=api_key,
org_id=org_id,
ai_model_id=ai_model_id,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
if not client and not openai_settings.api_key:
raise AgentInitializationError("The OpenAI API key is required, if a client is not provided.")
if not openai_settings.chat_model_id:
raise AgentInitializationError("The OpenAI chat model ID is required.")
if not client:
client = OpenAIAssistantAgent._create_client(
api_key=openai_settings.api_key.get_secret_value() if openai_settings.api_key else None,
org_id=openai_settings.org_id,
default_headers=default_headers,
)
assistant = await client.beta.assistants.retrieve(id)
assistant_definition = self._create_open_ai_assistant_definition(assistant)
assistant_definition = OpenAIAssistantBase._create_open_ai_assistant_definition(assistant)
return OpenAIAssistantAgent(kernel=kernel, **assistant_definition)

# endregion
12 changes: 7 additions & 5 deletions python/semantic_kernel/agents/open_ai/open_ai_assistant_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class OpenAIAssistantBase(Agent):
Manages the interaction with OpenAI Assistants.
"""

_options_metadata_key: str = "__run_options"
_options_metadata_key: ClassVar[str] = "__run_options"

ai_model_id: str
client: AsyncOpenAI
Expand Down Expand Up @@ -284,7 +284,8 @@ async def create_assistant(

return self.assistant

def _create_open_ai_assistant_definition(self, assistant: "Assistant") -> dict[str, Any]:
@classmethod
def _create_open_ai_assistant_definition(cls, assistant: "Assistant") -> dict[str, Any]:
"""Create an OpenAI Assistant Definition from the provided assistant dictionary.
Args:
Expand All @@ -294,11 +295,11 @@ def _create_open_ai_assistant_definition(self, assistant: "Assistant") -> dict[s
An OpenAI Assistant Definition.
"""
execution_settings = {}
if isinstance(assistant.metadata, dict) and self._options_metadata_key in assistant.metadata:
settings_data = assistant.metadata[self._options_metadata_key]
if isinstance(assistant.metadata, dict) and OpenAIAssistantBase._options_metadata_key in assistant.metadata:
settings_data = assistant.metadata[OpenAIAssistantBase._options_metadata_key]
if isinstance(settings_data, str):
settings_data = json.loads(settings_data)
assistant.metadata[self._options_metadata_key] = settings_data
assistant.metadata[OpenAIAssistantBase._options_metadata_key] = settings_data
execution_settings = {key: value for key, value in settings_data.items()}

file_ids: list[str] = []
Expand Down Expand Up @@ -737,6 +738,7 @@ def _generate_code_interpreter_content(self, agent_name: str, code: str) -> Chat
role=AuthorRole.ASSISTANT,
content=code,
name=agent_name,
metadata={"code": True},
)

def _generate_annotation_content(self, annotation: "Annotation") -> AnnotationContent:
Expand Down
Loading

0 comments on commit f3433b0

Please sign in to comment.