From a247a8892374db6f338d59aa76685dccbee9143b Mon Sep 17 00:00:00 2001 From: Wendong-Fan <133094783+Wendong-Fan@users.noreply.github.com> Date: Wed, 23 Oct 2024 02:37:42 +0800 Subject: [PATCH] feat: ChatAgent interface enhancement and default model setting (#1101) Co-authored-by: Isaac Jin --- camel/agents/chat_agent.py | 49 ++++++++++++------- camel/types/enums.py | 5 +- camel/workforce/workforce.py | 4 +- .../ai_society/role_playing_multiprocess.py | 4 +- .../ai_society/role_playing_with_critic.py | 4 +- .../ai_society/role_playing_with_human.py | 4 +- examples/external_tools/use_external_tools.py | 4 +- .../single_agent.py | 4 +- .../task_generation.py | 4 +- .../misalignment/role_playing_multiprocess.py | 4 +- .../misalignment/role_playing_with_human.py | 4 +- ...gentops_track_roleplaying_with_function.py | 4 +- .../json_format_reponse_with_tools.py | 4 +- .../json_format_response.py | 4 +- examples/tasks/task_generation.py | 4 +- examples/tool_call/arxiv_toolkit_example.py | 4 +- examples/tool_call/code_execution_toolkit.py | 4 +- examples/tool_call/github_toolkit.py | 8 +-- examples/tool_call/google_scholar_toolkit.py | 4 +- examples/tool_call/openapi_toolkit.py | 4 +- .../tool_call/role_playing_with_functions.py | 4 +- examples/translation/translator.py | 4 +- examples/vision/image_crafting.py | 4 +- .../vision/multi_condition_image_crafting.py | 4 +- examples/vision/multi_turn_image_refining.py | 4 +- examples/vision/object_recognition.py | 4 +- examples/vision/video_description.py | 4 +- examples/workforce/hackathon_judges.py | 8 +-- examples/workforce/multiple_single_agents.py | 12 ++--- .../workforce/role_playing_with_agents.py | 4 +- test/agents/test_chat_agent.py | 19 ++++--- 31 files changed, 111 insertions(+), 90 deletions(-) diff --git a/camel/agents/chat_agent.py b/camel/agents/chat_agent.py index fc9a9173cd..bfe75a662e 100644 --- a/camel/agents/chat_agent.py +++ b/camel/agents/chat_agent.py @@ -114,8 +114,8 @@ class ChatAgent(BaseAgent): r"""Class for managing conversations of CAMEL Chat Agents. Args: - system_message (BaseMessage, optional): The system message for the - chat agent. + system_message (Union[BaseMessage, str], optional): The system message + for the chat agent. model (BaseModelBackend, optional): The model backend to use for generating responses. (default: :obj:`OpenAIModel` with `GPT_4O_MINI`) @@ -144,7 +144,7 @@ class ChatAgent(BaseAgent): def __init__( self, - system_message: Optional[BaseMessage] = None, + system_message: Optional[Union[BaseMessage, str]] = None, model: Optional[BaseModelBackend] = None, memory: Optional[AgentMemory] = None, message_window_size: Optional[int] = None, @@ -154,6 +154,11 @@ def __init__( external_tools: Optional[List[FunctionTool]] = None, response_terminators: Optional[List[ResponseTerminator]] = None, ) -> None: + if isinstance(system_message, str): + system_message = BaseMessage.make_assistant_message( + role_name='Assistant', content=system_message + ) + self.orig_sys_message: Optional[BaseMessage] = system_message self._system_message: Optional[BaseMessage] = system_message self.role_name: str = ( @@ -166,8 +171,8 @@ def __init__( model if model is not None else ModelFactory.create( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O_MINI, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, ) ) self.output_language: Optional[str] = output_language @@ -414,18 +419,18 @@ def record_message(self, message: BaseMessage) -> None: def step( self, - input_message: BaseMessage, + input_message: Union[BaseMessage, str], response_format: Optional[Type[BaseModel]] = None, ) -> ChatAgentResponse: r"""Performs a single step in the chat session by generating a response to the input message. Args: - input_message (BaseMessage): The input message to the agent. - Its `role` field that specifies the role at backend may be - either `user` or `assistant` but it will be set to `user` - anyway since for the self agent any incoming message is - external. + input_message (Union[BaseMessage, str]): The input message to the + agent. For BaseMessage input, its `role` field that specifies + the role at backend may be either `user` or `assistant` but it + will be set to `user` anyway since for the self agent any + incoming message is external. For str input, the `role_name` would be `User`. response_format (Optional[Type[BaseModel]], optional): A pydantic model class that includes value types and field descriptions used to generate a structured response by LLM. This schema @@ -437,6 +442,11 @@ def step( a boolean indicating whether the chat session has terminated, and information about the chat session. """ + if isinstance(input_message, str): + input_message = BaseMessage.make_user_message( + role_name='User', content=input_message + ) + if "llama" in self.model_type.lower(): if self.model_backend.model_config_dict.get("tools", None): tool_prompt = self._generate_tool_prompt(self.tool_schema_list) @@ -646,18 +656,18 @@ def step( async def step_async( self, - input_message: BaseMessage, + input_message: Union[BaseMessage, str], response_format: Optional[Type[BaseModel]] = None, ) -> ChatAgentResponse: r"""Performs a single step in the chat session by generating a response to the input message. This agent step can call async function calls. Args: - input_message (BaseMessage): The input message to the agent. - Its `role` field that specifies the role at backend may be - either `user` or `assistant` but it will be set to `user` - anyway since for the self agent any incoming message is - external. + input_message (Union[BaseMessage, str]): The input message to the + agent. For BaseMessage input, its `role` field that specifies + the role at backend may be either `user` or `assistant` but it + will be set to `user` anyway since for the self agent any + incoming message is external. For str input, the `role_name` would be `User`. response_format (Optional[Type[BaseModel]], optional): A pydantic model class that includes value types and field descriptions used to generate a structured response by LLM. This schema @@ -669,6 +679,11 @@ async def step_async( a boolean indicating whether the chat session has terminated, and information about the chat session. """ + if isinstance(input_message, str): + input_message = BaseMessage.make_user_message( + role_name='User', content=input_message + ) + self.update_memory(input_message, OpenAIBackendRole.USER) tool_call_records: List[FunctionCallingRecord] = [] diff --git a/camel/types/enums.py b/camel/types/enums.py index 12274583d0..236c2ede59 100644 --- a/camel/types/enums.py +++ b/camel/types/enums.py @@ -26,6 +26,8 @@ class RoleType(Enum): class ModelType(UnifiedModelType, Enum): + DEFAULT = "gpt-4o-mini" + GPT_3_5_TURBO = "gpt-3.5-turbo" GPT_4 = "gpt-4" GPT_4_TURBO = "gpt-4-turbo" @@ -427,6 +429,8 @@ class OpenAPIName(Enum): class ModelPlatformType(Enum): + DEFAULT = "openai" + OPENAI = "openai" AZURE = "azure" ANTHROPIC = "anthropic" @@ -434,7 +438,6 @@ class ModelPlatformType(Enum): OLLAMA = "ollama" LITELLM = "litellm" ZHIPU = "zhipuai" - DEFAULT = "default" GEMINI = "gemini" VLLM = "vllm" MISTRAL = "mistral" diff --git a/camel/workforce/workforce.py b/camel/workforce/workforce.py index 5e967542d5..6b0076c624 100644 --- a/camel/workforce/workforce.py +++ b/camel/workforce/workforce.py @@ -364,8 +364,8 @@ def _create_new_agent(self, role: str, sys_msg: str) -> ChatAgent: ).as_dict() model = ModelFactory.create( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, model_config_dict=model_config_dict, ) diff --git a/examples/ai_society/role_playing_multiprocess.py b/examples/ai_society/role_playing_multiprocess.py index 9666985047..dfa45bd16b 100644 --- a/examples/ai_society/role_playing_multiprocess.py +++ b/examples/ai_society/role_playing_multiprocess.py @@ -40,8 +40,8 @@ def generate_data( original_task_prompt = task_prompt.replace(f"{task_idx+1}. ", "") model = ModelFactory.create( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O_MINI, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, model_config_dict=ChatGPTConfig(temperature=1.4).as_dict(), ) diff --git a/examples/ai_society/role_playing_with_critic.py b/examples/ai_society/role_playing_with_critic.py index fc0b09fa04..71d7564080 100644 --- a/examples/ai_society/role_playing_with_critic.py +++ b/examples/ai_society/role_playing_with_critic.py @@ -23,8 +23,8 @@ def main() -> None: task_prompt = "Write a research proposal for large-scale language models" model = ModelFactory.create( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O_MINI, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, model_config_dict=ChatGPTConfig(temperature=0.8, n=3).as_dict(), ) assistant_agent_kwargs = dict(model=model) diff --git a/examples/ai_society/role_playing_with_human.py b/examples/ai_society/role_playing_with_human.py index 7043b00a7c..62b0a70a5b 100644 --- a/examples/ai_society/role_playing_with_human.py +++ b/examples/ai_society/role_playing_with_human.py @@ -23,8 +23,8 @@ def main() -> None: task_prompt = "Write a book about the future of AI Society" model = ModelFactory.create( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O_MINI, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, model_config_dict=ChatGPTConfig(temperature=1.4, n=3).as_dict(), ) assistant_agent_kwargs = dict(model=model) diff --git a/examples/external_tools/use_external_tools.py b/examples/external_tools/use_external_tools.py index 493228b0a4..760699f2d5 100644 --- a/examples/external_tools/use_external_tools.py +++ b/examples/external_tools/use_external_tools.py @@ -32,8 +32,8 @@ def main(): ).as_dict() model = ModelFactory.create( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O_MINI, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, model_config_dict=model_config_dict, ) diff --git a/examples/generate_text_embedding_data/single_agent.py b/examples/generate_text_embedding_data/single_agent.py index 8b4f5346a8..1d08f32d06 100644 --- a/examples/generate_text_embedding_data/single_agent.py +++ b/examples/generate_text_embedding_data/single_agent.py @@ -66,8 +66,8 @@ def main() -> None: role_name="User", content="Start to generate!" ) model = ModelFactory.create( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O_MINI, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, model_config_dict=ChatGPTConfig( temperature=0.0, response_format={"type": "json_object"} ).as_dict(), diff --git a/examples/generate_text_embedding_data/task_generation.py b/examples/generate_text_embedding_data/task_generation.py index 45ec54babd..18897f8ab1 100644 --- a/examples/generate_text_embedding_data/task_generation.py +++ b/examples/generate_text_embedding_data/task_generation.py @@ -36,8 +36,8 @@ def main() -> None: ) model = ModelFactory.create( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O_MINI, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, model_config_dict=ChatGPTConfig(temperature=0.0).as_dict(), ) agent = ChatAgent( diff --git a/examples/misalignment/role_playing_multiprocess.py b/examples/misalignment/role_playing_multiprocess.py index 6d89894187..2ccc15b6f1 100644 --- a/examples/misalignment/role_playing_multiprocess.py +++ b/examples/misalignment/role_playing_multiprocess.py @@ -46,8 +46,8 @@ def generate_data( task_type=TaskType.MISALIGNMENT, task_specify_agent_kwargs=dict( model=ModelFactory.create( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O_MINI, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, model_config_dict=ChatGPTConfig(temperature=1.4).as_dict(), ) ), diff --git a/examples/misalignment/role_playing_with_human.py b/examples/misalignment/role_playing_with_human.py index dba3b13307..db6c533122 100644 --- a/examples/misalignment/role_playing_with_human.py +++ b/examples/misalignment/role_playing_with_human.py @@ -23,8 +23,8 @@ def main() -> None: task_prompt = "Escape from human control" model = ModelFactory.create( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O_MINI, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, model_config_dict=ChatGPTConfig(temperature=1.4, n=3).as_dict(), ) assistant_agent_kwargs = dict(model=model) diff --git a/examples/observability/agentops_track_roleplaying_with_function.py b/examples/observability/agentops_track_roleplaying_with_function.py index fe4d5e4eb8..8071694f36 100644 --- a/examples/observability/agentops_track_roleplaying_with_function.py +++ b/examples/observability/agentops_track_roleplaying_with_function.py @@ -35,8 +35,8 @@ ) # Set up role playing session -model_platform = ModelPlatformType.OPENAI -model_type = ModelType.GPT_4O_MINI +model_platform = ModelPlatformType.DEFAULT +model_type = ModelType.DEFAULT chat_turn_limit = 10 task_prompt = ( "Assume now is 2024 in the Gregorian calendar, " diff --git a/examples/structured_response/json_format_reponse_with_tools.py b/examples/structured_response/json_format_reponse_with_tools.py index 44814249c1..8af7178e15 100644 --- a/examples/structured_response/json_format_reponse_with_tools.py +++ b/examples/structured_response/json_format_reponse_with_tools.py @@ -39,8 +39,8 @@ ) model = ModelFactory.create( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O_MINI, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, model_config_dict=assistant_model_config.as_dict(), ) diff --git a/examples/structured_response/json_format_response.py b/examples/structured_response/json_format_response.py index 33b8545d61..9d0b1c76b9 100644 --- a/examples/structured_response/json_format_response.py +++ b/examples/structured_response/json_format_response.py @@ -26,8 +26,8 @@ ) model = ModelFactory.create( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O_MINI, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, ) # Set agent diff --git a/examples/tasks/task_generation.py b/examples/tasks/task_generation.py index 9764e22a41..2df2f6b8d3 100644 --- a/examples/tasks/task_generation.py +++ b/examples/tasks/task_generation.py @@ -31,8 +31,8 @@ ) model = ModelFactory.create( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O_MINI, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, model_config_dict=assistant_model_config.as_dict(), ) diff --git a/examples/tool_call/arxiv_toolkit_example.py b/examples/tool_call/arxiv_toolkit_example.py index 54350a3ae9..66d3d5abad 100644 --- a/examples/tool_call/arxiv_toolkit_example.py +++ b/examples/tool_call/arxiv_toolkit_example.py @@ -31,8 +31,8 @@ ).as_dict() model = ModelFactory.create( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O_MINI, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, model_config_dict=model_config_dict, ) diff --git a/examples/tool_call/code_execution_toolkit.py b/examples/tool_call/code_execution_toolkit.py index 1b937406b9..1f93e2cef8 100644 --- a/examples/tool_call/code_execution_toolkit.py +++ b/examples/tool_call/code_execution_toolkit.py @@ -32,8 +32,8 @@ ) model = ModelFactory.create( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O_MINI, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, model_config_dict=assistant_model_config.as_dict(), ) diff --git a/examples/tool_call/github_toolkit.py b/examples/tool_call/github_toolkit.py index 833b0c712c..9c1d0965eb 100644 --- a/examples/tool_call/github_toolkit.py +++ b/examples/tool_call/github_toolkit.py @@ -62,8 +62,8 @@ def write_weekly_pr_summary(repo_name, model=None): assistant_model_config_dict = ChatGPTConfig(temperature=0.0).as_dict() assistant_model = ModelFactory.create( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, model_config_dict=assistant_model_config_dict, ) @@ -116,8 +116,8 @@ def solve_issue( ).as_dict() model = ModelFactory.create( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O_MINI, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, model_config_dict=assistant_model_config_dict, ) diff --git a/examples/tool_call/google_scholar_toolkit.py b/examples/tool_call/google_scholar_toolkit.py index 6692b641ae..f20cb7fed4 100644 --- a/examples/tool_call/google_scholar_toolkit.py +++ b/examples/tool_call/google_scholar_toolkit.py @@ -33,8 +33,8 @@ ).as_dict() model = ModelFactory.create( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O_MINI, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, model_config_dict=model_config_dict, ) diff --git a/examples/tool_call/openapi_toolkit.py b/examples/tool_call/openapi_toolkit.py index 7b1499b611..ada25637fb 100644 --- a/examples/tool_call/openapi_toolkit.py +++ b/examples/tool_call/openapi_toolkit.py @@ -30,8 +30,8 @@ ).as_dict() model = ModelFactory.create( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, model_config_dict=model_config_dict, ) diff --git a/examples/tool_call/role_playing_with_functions.py b/examples/tool_call/role_playing_with_functions.py index aef45236d7..f0016fb655 100644 --- a/examples/tool_call/role_playing_with_functions.py +++ b/examples/tool_call/role_playing_with_functions.py @@ -29,8 +29,8 @@ def main( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O_MINI, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, chat_turn_limit=10, ) -> None: task_prompt = ( diff --git a/examples/translation/translator.py b/examples/translation/translator.py index c485cfc5e8..f8942eb9b1 100644 --- a/examples/translation/translator.py +++ b/examples/translation/translator.py @@ -124,8 +124,8 @@ def translate_content( model_config = ChatGPTConfig(stream=True) model = ModelFactory.create( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O_MINI, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, model_config=model_config, ) diff --git a/examples/vision/image_crafting.py b/examples/vision/image_crafting.py index 7ebe2f4fe7..4534878b6c 100644 --- a/examples/vision/image_crafting.py +++ b/examples/vision/image_crafting.py @@ -43,8 +43,8 @@ def main(): ) model = ModelFactory.create( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O_MINI, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, ) dalle_agent = ChatAgent( diff --git a/examples/vision/multi_condition_image_crafting.py b/examples/vision/multi_condition_image_crafting.py index 3b5b3d51e7..c0f33b0d72 100644 --- a/examples/vision/multi_condition_image_crafting.py +++ b/examples/vision/multi_condition_image_crafting.py @@ -40,8 +40,8 @@ def main(image_paths: list[str]) -> list[str]: ) model = ModelFactory.create( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O_MINI, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, ) dalle_agent = ChatAgent( diff --git a/examples/vision/multi_turn_image_refining.py b/examples/vision/multi_turn_image_refining.py index c2853206d4..eec8dc5c6c 100644 --- a/examples/vision/multi_turn_image_refining.py +++ b/examples/vision/multi_turn_image_refining.py @@ -68,8 +68,8 @@ def __init__( def init_agents(self): r"""Initialize artist and critic agents with their system messages.""" model = ModelFactory.create( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O_MINI, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, ) self.artist = ChatAgent( diff --git a/examples/vision/object_recognition.py b/examples/vision/object_recognition.py index 6d9375fc4a..910bb05bd8 100644 --- a/examples/vision/object_recognition.py +++ b/examples/vision/object_recognition.py @@ -51,8 +51,8 @@ def detect_image_obj(image_paths: str) -> None: content=sys_msg, ) model = ModelFactory.create( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O_MINI, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, ) agent = ChatAgent( assistant_sys_msg, diff --git a/examples/vision/video_description.py b/examples/vision/video_description.py index 8223608e50..d39e7bd8fc 100644 --- a/examples/vision/video_description.py +++ b/examples/vision/video_description.py @@ -29,8 +29,8 @@ ) model = ModelFactory.create( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, ) # Set agent diff --git a/examples/workforce/hackathon_judges.py b/examples/workforce/hackathon_judges.py index d8f6f5bf66..24b9e16d69 100644 --- a/examples/workforce/hackathon_judges.py +++ b/examples/workforce/hackathon_judges.py @@ -45,8 +45,8 @@ def make_judge( ) model = ModelFactory.create( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, ) agent = ChatAgent( @@ -79,8 +79,8 @@ def main(): ] researcher_model = ModelFactory.create( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, ) researcher_agent = ChatAgent( diff --git a/examples/workforce/multiple_single_agents.py b/examples/workforce/multiple_single_agents.py index c2f4c2d1a1..34627c9007 100644 --- a/examples/workforce/multiple_single_agents.py +++ b/examples/workforce/multiple_single_agents.py @@ -35,8 +35,8 @@ def main(): # Set up web searching agent search_agent_model = ModelFactory.create( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, ) search_agent = ChatAgent( system_message=BaseMessage.make_assistant_message( @@ -49,8 +49,8 @@ def main(): # Set up tour guide agent tour_guide_agent_model = ModelFactory.create( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, ) tour_guide_agent = ChatAgent( @@ -69,8 +69,8 @@ def main(): content="You can ask questions about your travel plans", ), model=ModelFactory.create( - model_platform=ModelPlatformType.OPENAI, - model_type=ModelType.GPT_4O, + model_platform=ModelPlatformType.DEFAULT, + model_type=ModelType.DEFAULT, ), ) diff --git a/examples/workforce/role_playing_with_agents.py b/examples/workforce/role_playing_with_agents.py index 2dcd2dab9b..64f56adb3e 100644 --- a/examples/workforce/role_playing_with_agents.py +++ b/examples/workforce/role_playing_with_agents.py @@ -41,8 +41,8 @@ def main(): *GoogleMapsToolkit().get_tools(), ] - model_platform = ModelPlatformType.OPENAI - model_type = ModelType.GPT_4O_MINI + model_platform = ModelPlatformType.DEFAULT + model_type = ModelType.DEFAULT assistant_role_name = "Searcher" user_role_name = "Professor" assistant_agent_kwargs = dict( diff --git a/test/agents/test_chat_agent.py b/test/agents/test_chat_agent.py index 342c298a8a..8350d33d74 100644 --- a/test/agents/test_chat_agent.py +++ b/test/agents/test_chat_agent.py @@ -83,21 +83,24 @@ def test_chat_agent(model): for assistant in [assistant_with_sys_msg, assistant_without_sys_msg]: assistant.reset() - user_msg = BaseMessage( + user_msg_bm = BaseMessage( role_name="Patient", role_type=RoleType.USER, meta_dict=dict(), content="Hello!", ) + user_msg_str = "Hello!" + for assistant in [assistant_with_sys_msg, assistant_without_sys_msg]: - response = assistant.step(user_msg) - assert isinstance(response.msgs, list) - assert len(response.msgs) > 0 - assert isinstance(response.terminated, bool) - assert response.terminated is False - assert isinstance(response.info, dict) - assert response.info['id'] is not None + for user_msg in [user_msg_bm, user_msg_str]: + response = assistant.step(user_msg) + assert isinstance(response.msgs, list) + assert len(response.msgs) > 0 + assert isinstance(response.terminated, bool) + assert response.terminated is False + assert isinstance(response.info, dict) + assert response.info['id'] is not None @pytest.mark.model_backend