Skip to content

Commit

Permalink
feat: ChatAgent interface enhancement and default model setting (#1101)
Browse files Browse the repository at this point in the history
Co-authored-by: Isaac Jin <[email protected]>
  • Loading branch information
Wendong-Fan and WHALEEYE authored Oct 22, 2024
1 parent 87c483a commit a247a88
Show file tree
Hide file tree
Showing 31 changed files with 111 additions and 90 deletions.
49 changes: 32 additions & 17 deletions camel/agents/chat_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ class ChatAgent(BaseAgent):
r"""Class for managing conversations of CAMEL Chat Agents.
Args:
system_message (BaseMessage, optional): The system message for the
chat agent.
system_message (Union[BaseMessage, str], optional): The system message
for the chat agent.
model (BaseModelBackend, optional): The model backend to use for
generating responses. (default: :obj:`OpenAIModel` with
`GPT_4O_MINI`)
Expand Down Expand Up @@ -144,7 +144,7 @@ class ChatAgent(BaseAgent):

def __init__(
self,
system_message: Optional[BaseMessage] = None,
system_message: Optional[Union[BaseMessage, str]] = None,
model: Optional[BaseModelBackend] = None,
memory: Optional[AgentMemory] = None,
message_window_size: Optional[int] = None,
Expand All @@ -154,6 +154,11 @@ def __init__(
external_tools: Optional[List[FunctionTool]] = None,
response_terminators: Optional[List[ResponseTerminator]] = None,
) -> None:
if isinstance(system_message, str):
system_message = BaseMessage.make_assistant_message(
role_name='Assistant', content=system_message
)

self.orig_sys_message: Optional[BaseMessage] = system_message
self._system_message: Optional[BaseMessage] = system_message
self.role_name: str = (
Expand All @@ -166,8 +171,8 @@ def __init__(
model
if model is not None
else ModelFactory.create(
model_platform=ModelPlatformType.OPENAI,
model_type=ModelType.GPT_4O_MINI,
model_platform=ModelPlatformType.DEFAULT,
model_type=ModelType.DEFAULT,
)
)
self.output_language: Optional[str] = output_language
Expand Down Expand Up @@ -414,18 +419,18 @@ def record_message(self, message: BaseMessage) -> None:

def step(
self,
input_message: BaseMessage,
input_message: Union[BaseMessage, str],
response_format: Optional[Type[BaseModel]] = None,
) -> ChatAgentResponse:
r"""Performs a single step in the chat session by generating a response
to the input message.
Args:
input_message (BaseMessage): The input message to the agent.
Its `role` field that specifies the role at backend may be
either `user` or `assistant` but it will be set to `user`
anyway since for the self agent any incoming message is
external.
input_message (Union[BaseMessage, str]): The input message to the
agent. For BaseMessage input, its `role` field that specifies
the role at backend may be either `user` or `assistant` but it
will be set to `user` anyway since for the self agent any
incoming message is external. For str input, the
                `role_name` will be set to `User`.
response_format (Optional[Type[BaseModel]], optional): A pydantic
model class that includes value types and field descriptions
used to generate a structured response by LLM. This schema
Expand All @@ -437,6 +442,11 @@ def step(
a boolean indicating whether the chat session has terminated,
and information about the chat session.
"""
if isinstance(input_message, str):
input_message = BaseMessage.make_user_message(
role_name='User', content=input_message
)

if "llama" in self.model_type.lower():
if self.model_backend.model_config_dict.get("tools", None):
tool_prompt = self._generate_tool_prompt(self.tool_schema_list)
Expand Down Expand Up @@ -646,18 +656,18 @@ def step(

async def step_async(
self,
input_message: BaseMessage,
input_message: Union[BaseMessage, str],
response_format: Optional[Type[BaseModel]] = None,
) -> ChatAgentResponse:
r"""Performs a single step in the chat session by generating a response
to the input message. This agent step can call async function calls.
Args:
input_message (BaseMessage): The input message to the agent.
Its `role` field that specifies the role at backend may be
either `user` or `assistant` but it will be set to `user`
anyway since for the self agent any incoming message is
external.
input_message (Union[BaseMessage, str]): The input message to the
agent. For BaseMessage input, its `role` field that specifies
the role at backend may be either `user` or `assistant` but it
will be set to `user` anyway since for the self agent any
incoming message is external. For str input, the
                `role_name` will be set to `User`.
response_format (Optional[Type[BaseModel]], optional): A pydantic
model class that includes value types and field descriptions
used to generate a structured response by LLM. This schema
Expand All @@ -669,6 +679,11 @@ async def step_async(
a boolean indicating whether the chat session has terminated,
and information about the chat session.
"""
if isinstance(input_message, str):
input_message = BaseMessage.make_user_message(
role_name='User', content=input_message
)

self.update_memory(input_message, OpenAIBackendRole.USER)

tool_call_records: List[FunctionCallingRecord] = []
Expand Down
5 changes: 4 additions & 1 deletion camel/types/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class RoleType(Enum):


class ModelType(UnifiedModelType, Enum):
DEFAULT = "gpt-4o-mini"

GPT_3_5_TURBO = "gpt-3.5-turbo"
GPT_4 = "gpt-4"
GPT_4_TURBO = "gpt-4-turbo"
Expand Down Expand Up @@ -427,14 +429,15 @@ class OpenAPIName(Enum):


class ModelPlatformType(Enum):
DEFAULT = "openai"

OPENAI = "openai"
AZURE = "azure"
ANTHROPIC = "anthropic"
GROQ = "groq"
OLLAMA = "ollama"
LITELLM = "litellm"
ZHIPU = "zhipuai"
DEFAULT = "default"
GEMINI = "gemini"
VLLM = "vllm"
MISTRAL = "mistral"
Expand Down
4 changes: 2 additions & 2 deletions camel/workforce/workforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,8 @@ def _create_new_agent(self, role: str, sys_msg: str) -> ChatAgent:
).as_dict()

model = ModelFactory.create(
model_platform=ModelPlatformType.OPENAI,
model_type=ModelType.GPT_4O,
model_platform=ModelPlatformType.DEFAULT,
model_type=ModelType.DEFAULT,
model_config_dict=model_config_dict,
)

Expand Down
4 changes: 2 additions & 2 deletions examples/ai_society/role_playing_multiprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def generate_data(
original_task_prompt = task_prompt.replace(f"{task_idx+1}. ", "")

model = ModelFactory.create(
model_platform=ModelPlatformType.OPENAI,
model_type=ModelType.GPT_4O_MINI,
model_platform=ModelPlatformType.DEFAULT,
model_type=ModelType.DEFAULT,
model_config_dict=ChatGPTConfig(temperature=1.4).as_dict(),
)

Expand Down
4 changes: 2 additions & 2 deletions examples/ai_society/role_playing_with_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
def main() -> None:
task_prompt = "Write a research proposal for large-scale language models"
model = ModelFactory.create(
model_platform=ModelPlatformType.OPENAI,
model_type=ModelType.GPT_4O_MINI,
model_platform=ModelPlatformType.DEFAULT,
model_type=ModelType.DEFAULT,
model_config_dict=ChatGPTConfig(temperature=0.8, n=3).as_dict(),
)
assistant_agent_kwargs = dict(model=model)
Expand Down
4 changes: 2 additions & 2 deletions examples/ai_society/role_playing_with_human.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
def main() -> None:
task_prompt = "Write a book about the future of AI Society"
model = ModelFactory.create(
model_platform=ModelPlatformType.OPENAI,
model_type=ModelType.GPT_4O_MINI,
model_platform=ModelPlatformType.DEFAULT,
model_type=ModelType.DEFAULT,
model_config_dict=ChatGPTConfig(temperature=1.4, n=3).as_dict(),
)
assistant_agent_kwargs = dict(model=model)
Expand Down
4 changes: 2 additions & 2 deletions examples/external_tools/use_external_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def main():
).as_dict()

model = ModelFactory.create(
model_platform=ModelPlatformType.OPENAI,
model_type=ModelType.GPT_4O_MINI,
model_platform=ModelPlatformType.DEFAULT,
model_type=ModelType.DEFAULT,
model_config_dict=model_config_dict,
)

Expand Down
4 changes: 2 additions & 2 deletions examples/generate_text_embedding_data/single_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def main() -> None:
role_name="User", content="Start to generate!"
)
model = ModelFactory.create(
model_platform=ModelPlatformType.OPENAI,
model_type=ModelType.GPT_4O_MINI,
model_platform=ModelPlatformType.DEFAULT,
model_type=ModelType.DEFAULT,
model_config_dict=ChatGPTConfig(
temperature=0.0, response_format={"type": "json_object"}
).as_dict(),
Expand Down
4 changes: 2 additions & 2 deletions examples/generate_text_embedding_data/task_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def main() -> None:
)

model = ModelFactory.create(
model_platform=ModelPlatformType.OPENAI,
model_type=ModelType.GPT_4O_MINI,
model_platform=ModelPlatformType.DEFAULT,
model_type=ModelType.DEFAULT,
model_config_dict=ChatGPTConfig(temperature=0.0).as_dict(),
)
agent = ChatAgent(
Expand Down
4 changes: 2 additions & 2 deletions examples/misalignment/role_playing_multiprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def generate_data(
task_type=TaskType.MISALIGNMENT,
task_specify_agent_kwargs=dict(
model=ModelFactory.create(
model_platform=ModelPlatformType.OPENAI,
model_type=ModelType.GPT_4O_MINI,
model_platform=ModelPlatformType.DEFAULT,
model_type=ModelType.DEFAULT,
model_config_dict=ChatGPTConfig(temperature=1.4).as_dict(),
)
),
Expand Down
4 changes: 2 additions & 2 deletions examples/misalignment/role_playing_with_human.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
def main() -> None:
task_prompt = "Escape from human control"
model = ModelFactory.create(
model_platform=ModelPlatformType.OPENAI,
model_type=ModelType.GPT_4O_MINI,
model_platform=ModelPlatformType.DEFAULT,
model_type=ModelType.DEFAULT,
model_config_dict=ChatGPTConfig(temperature=1.4, n=3).as_dict(),
)
assistant_agent_kwargs = dict(model=model)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
)

# Set up role playing session
model_platform = ModelPlatformType.OPENAI
model_type = ModelType.GPT_4O_MINI
model_platform = ModelPlatformType.DEFAULT
model_type = ModelType.DEFAULT
chat_turn_limit = 10
task_prompt = (
"Assume now is 2024 in the Gregorian calendar, "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@
)

model = ModelFactory.create(
model_platform=ModelPlatformType.OPENAI,
model_type=ModelType.GPT_4O_MINI,
model_platform=ModelPlatformType.DEFAULT,
model_type=ModelType.DEFAULT,
model_config_dict=assistant_model_config.as_dict(),
)

Expand Down
4 changes: 2 additions & 2 deletions examples/structured_response/json_format_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
)

model = ModelFactory.create(
model_platform=ModelPlatformType.OPENAI,
model_type=ModelType.GPT_4O_MINI,
model_platform=ModelPlatformType.DEFAULT,
model_type=ModelType.DEFAULT,
)

# Set agent
Expand Down
4 changes: 2 additions & 2 deletions examples/tasks/task_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
)

model = ModelFactory.create(
model_platform=ModelPlatformType.OPENAI,
model_type=ModelType.GPT_4O_MINI,
model_platform=ModelPlatformType.DEFAULT,
model_type=ModelType.DEFAULT,
model_config_dict=assistant_model_config.as_dict(),
)

Expand Down
4 changes: 2 additions & 2 deletions examples/tool_call/arxiv_toolkit_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
).as_dict()

model = ModelFactory.create(
model_platform=ModelPlatformType.OPENAI,
model_type=ModelType.GPT_4O_MINI,
model_platform=ModelPlatformType.DEFAULT,
model_type=ModelType.DEFAULT,
model_config_dict=model_config_dict,
)

Expand Down
4 changes: 2 additions & 2 deletions examples/tool_call/code_execution_toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
)

model = ModelFactory.create(
model_platform=ModelPlatformType.OPENAI,
model_type=ModelType.GPT_4O_MINI,
model_platform=ModelPlatformType.DEFAULT,
model_type=ModelType.DEFAULT,
model_config_dict=assistant_model_config.as_dict(),
)

Expand Down
8 changes: 4 additions & 4 deletions examples/tool_call/github_toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ def write_weekly_pr_summary(repo_name, model=None):
assistant_model_config_dict = ChatGPTConfig(temperature=0.0).as_dict()

assistant_model = ModelFactory.create(
model_platform=ModelPlatformType.OPENAI,
model_type=ModelType.GPT_4O,
model_platform=ModelPlatformType.DEFAULT,
model_type=ModelType.DEFAULT,
model_config_dict=assistant_model_config_dict,
)

Expand Down Expand Up @@ -116,8 +116,8 @@ def solve_issue(
).as_dict()

model = ModelFactory.create(
model_platform=ModelPlatformType.OPENAI,
model_type=ModelType.GPT_4O_MINI,
model_platform=ModelPlatformType.DEFAULT,
model_type=ModelType.DEFAULT,
model_config_dict=assistant_model_config_dict,
)

Expand Down
4 changes: 2 additions & 2 deletions examples/tool_call/google_scholar_toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
).as_dict()

model = ModelFactory.create(
model_platform=ModelPlatformType.OPENAI,
model_type=ModelType.GPT_4O_MINI,
model_platform=ModelPlatformType.DEFAULT,
model_type=ModelType.DEFAULT,
model_config_dict=model_config_dict,
)

Expand Down
4 changes: 2 additions & 2 deletions examples/tool_call/openapi_toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
).as_dict()

model = ModelFactory.create(
model_platform=ModelPlatformType.OPENAI,
model_type=ModelType.GPT_4O,
model_platform=ModelPlatformType.DEFAULT,
model_type=ModelType.DEFAULT,
model_config_dict=model_config_dict,
)

Expand Down
4 changes: 2 additions & 2 deletions examples/tool_call/role_playing_with_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@


def main(
model_platform=ModelPlatformType.OPENAI,
model_type=ModelType.GPT_4O_MINI,
model_platform=ModelPlatformType.DEFAULT,
model_type=ModelType.DEFAULT,
chat_turn_limit=10,
) -> None:
task_prompt = (
Expand Down
4 changes: 2 additions & 2 deletions examples/translation/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ def translate_content(
model_config = ChatGPTConfig(stream=True)

model = ModelFactory.create(
model_platform=ModelPlatformType.OPENAI,
model_type=ModelType.GPT_4O_MINI,
model_platform=ModelPlatformType.DEFAULT,
model_type=ModelType.DEFAULT,
model_config=model_config,
)

Expand Down
4 changes: 2 additions & 2 deletions examples/vision/image_crafting.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def main():
)

model = ModelFactory.create(
model_platform=ModelPlatformType.OPENAI,
model_type=ModelType.GPT_4O_MINI,
model_platform=ModelPlatformType.DEFAULT,
model_type=ModelType.DEFAULT,
)

dalle_agent = ChatAgent(
Expand Down
4 changes: 2 additions & 2 deletions examples/vision/multi_condition_image_crafting.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def main(image_paths: list[str]) -> list[str]:
)

model = ModelFactory.create(
model_platform=ModelPlatformType.OPENAI,
model_type=ModelType.GPT_4O_MINI,
model_platform=ModelPlatformType.DEFAULT,
model_type=ModelType.DEFAULT,
)

dalle_agent = ChatAgent(
Expand Down
Loading

0 comments on commit a247a88

Please sign in to comment.