diff --git a/flaml/autogen/agent/human_proxy_agent.py b/flaml/autogen/agent/human_proxy_agent.py index ebb360c68c..237330a442 100644 --- a/flaml/autogen/agent/human_proxy_agent.py +++ b/flaml/autogen/agent/human_proxy_agent.py @@ -1,5 +1,6 @@ from .agent import Agent from flaml.autogen.code_utils import extract_code, execute_code +from collections import defaultdict class HumanProxyAgent(Agent): @@ -7,7 +8,7 @@ class HumanProxyAgent(Agent): DEFAULT_SYSTEM_MESSAGE = """You are human agent. You can execute_code or give feedback to the sender. """ - MAX_TURN_NUM = 100 # maximum number of turns in one conversation session (subject to future change) + MAX_CONSECUTIVE_AUTO_REPLY = 100 # maximum number of consecutive auto replies (subject to future change) def __init__( self, @@ -15,7 +16,7 @@ def __init__( system_message="", work_dir=None, human_input_mode="ALWAYS", - max_turn_num=None, + max_consecutive_auto_reply=None, is_termination_msg=None, **config, ): @@ -29,15 +30,15 @@ def __init__( When "ALWAYS", the agent will ask for human input every time a message is received. When "TERMINATE", the agent will ask for human input only when a termination message is received. When "NEVER", the agent will never ask for human input. - max_turn_num (int): the maximum number of turns in one conversation session. - default: None (no limit provided, class attribute MAX_TURN_NUM will be used as the limit). + max_consecutive_auto_reply (int): the maximum number of consecutive auto replies. + default: None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case). The limit only plays a role when human_input_mode is not "ALWAYS". is_termination_msg (function): a function that takes a message and returns a boolean value. This function is used to determine if a received message is a termination message. config (dict): other configurations. 
- The conversation stops when a termination message is received or the number of turns larger than - the provided max_turn_num or the human input is "exit". + The conversation stops when the human input is "exit", or no human input is provided and a termination message is received, + or the number of consecutive auto replies reaches the provided max_consecutive_auto_reply (when human_input_mode is not "ALWAYS"). """ super().__init__(name, system_message) self._work_dir = work_dir @@ -46,14 +47,17 @@ def __init__( is_termination_msg if is_termination_msg is not None else (lambda x: x == "TERMINATE") ) self._config = config - self._max_turn_num = max_turn_num if max_turn_num is not None else self.MAX_TURN_NUM - self._conversation_turn_counter = {} + self._max_consecutive_auto_reply = ( + max_consecutive_auto_reply if max_consecutive_auto_reply is not None else self.MAX_CONSECUTIVE_AUTO_REPLY + ) + self._consecutive_auto_reply_counter = defaultdict(int) def receive(self, message, sender): """Receive a message from the sender agent. Every time a message is received, the human agent will give feedback. - The conversation stops when a termination message is received or the number of turns larger than - the provided max_turn_num or the human input is "exit". + + The conversation stops when the human input is "exit", or no human input is provided and a termination message is received, + or the number of consecutive auto replies reaches the provided max_consecutive_auto_reply (when human_input_mode is not "ALWAYS"). 
""" super().receive(message, sender) # to determine if the message is a termination message using a function @@ -63,6 +67,9 @@ def receive(self, message, sender): if self._human_input_mode == "ALWAYS" or terminate and self._human_input_mode == "TERMINATE" else "" ) + # reset the consecutive_auto_reply_counter + if self._human_input_mode != "ALWAYS" and feedback: + self._consecutive_auto_reply_counter[sender.name] = 0 if feedback and feedback != "exit": self._send(feedback, sender) elif ( @@ -70,11 +77,11 @@ def receive(self, message, sender): or feedback == "exit" or ( self._human_input_mode != "ALWAYS" - and (len(self._conversations[sender.name]) + 1) / 2 >= self._max_turn_num + and self._consecutive_auto_reply_counter[sender.name] >= self._max_consecutive_auto_reply ) ): - # note that len(self._conversations[sender.name])+1)/2 is the number of turns in the conversation return + self._consecutive_auto_reply_counter[sender.name] += 1 # try to execute the code code, lang = extract_code(message) if lang == "unknown": @@ -93,7 +100,7 @@ def receive(self, message, sender): exitcode, logs = execute_code(code, work_dir=self._work_dir, filename=filename) else: # TODO: could this happen? 
- exitcode = 1 + exitcode, logs = 1, "unknown language" raise NotImplementedError exitcode2str = "execution succeeded" if exitcode == 0 else "execution failed" self._send(f"exitcode: {exitcode} ({exitcode2str})\nCode output: {logs.decode('utf-8')}", sender) diff --git a/test/autogen/test_agent.py b/test/autogen/test_agent.py index 81e70805fb..9b3bec5b54 100644 --- a/test/autogen/test_agent.py +++ b/test/autogen/test_agent.py @@ -6,7 +6,7 @@ def test_extract_code(): print(extract_code("```bash\npython temp.py\n```")) -def test_coding_agent(human_input_mode="NEVER", max_turn_num=10): +def test_coding_agent(human_input_mode="NEVER", max_consecutive_auto_reply=10): try: import openai except ImportError: @@ -20,7 +20,7 @@ def test_coding_agent(human_input_mode="NEVER", max_turn_num=10): user = HumanProxyAgent( "user", human_input_mode=human_input_mode, - max_turn_num=max_turn_num, + max_consecutive_auto_reply=max_consecutive_auto_reply, is_termination_msg=lambda x: x.rstrip().endswith("TERMINATE"), ) # agent.receive("""Find $a+b+c$, given that $x+y\\neq -1$ and \\begin{align*} @@ -52,7 +52,7 @@ def test_coding_agent(human_input_mode="NEVER", max_turn_num=10): oai.ChatCompletion.stop_logging() -def test_tsp(human_input_mode="NEVER", max_turn_num=10): +def test_tsp(human_input_mode="NEVER", max_consecutive_auto_reply=10): try: import openai except ImportError: @@ -69,7 +69,10 @@ def test_tsp(human_input_mode="NEVER", max_turn_num=10): oai.ChatCompletion.start_logging() agent = PythonAgent("coding_agent", temperature=0) user = HumanProxyAgent( - "user", work_dir="test/autogen", human_input_mode=human_input_mode, max_turn_num=max_turn_num + "user", + work_dir="test/autogen", + human_input_mode=human_input_mode, + max_consecutive_auto_reply=max_consecutive_auto_reply, ) with open("test/autogen/tsp_prompt.txt", "r") as f: prompt = f.read() @@ -91,4 +94,4 @@ def test_tsp(human_input_mode="NEVER", max_turn_num=10): # openai.api_key = "" # test_extract_code() 
test_coding_agent(human_input_mode="TERMINATE") - test_tsp(human_input_mode="NEVER", max_turn_num=2) + test_tsp(human_input_mode="NEVER", max_consecutive_auto_reply=2) diff --git a/test/autogen/test_human_proxy_agent.py b/test/autogen/test_human_proxy_agent.py index 507311bd83..335552f8e6 100644 --- a/test/autogen/test_human_proxy_agent.py +++ b/test/autogen/test_human_proxy_agent.py @@ -12,7 +12,7 @@ def test_human_agent(): conversations = {} oai.ChatCompletion.start_logging(conversations) agent = ChatAgent("chat_agent") - user = HumanProxyAgent("human_user", human_input_mode="NEVER", max_turn_num=2) + user = HumanProxyAgent("human_user", human_input_mode="NEVER", max_consecutive_auto_reply=2) agent.receive( """Write python code to solve the equation x^3=125. You must write code in the following format. You must always print the result. Wait for me to return the result.