From 8ad60f6507e806bed7914b298310c6d10dd6eedb Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Sun, 17 Mar 2024 11:17:35 -0700 Subject: [PATCH] added more comments, fixed regex --- vision_agent/agent/reflexion.py | 49 +++++++++++++++++++++++---------- 1 file changed, 35 insertions(+), 14 deletions(-) diff --git a/vision_agent/agent/reflexion.py b/vision_agent/agent/reflexion.py index 8c89f024..4d0d05c3 100644 --- a/vision_agent/agent/reflexion.py +++ b/vision_agent/agent/reflexion.py @@ -26,8 +26,10 @@ def format_step(step: str) -> str: def parse_action(input: str) -> Tuple[str, str]: - pattern = r"^(\w+)\[(.+)\]$" - match = re.match(pattern, input) + # Make the pattern slightly less strict, the LMMs are not as good at following + # instructions so they often would fail on the original regex. + pattern = r"(\w+)\[(.+)\]" + match = re.search(pattern, input) if match: action_type = match.group(1) @@ -60,7 +62,10 @@ class Reflexion(Agent): hotpotqa folder. There are several differences between this implementation and the original one. Because we do not have instant feedback on whether or not the agent was correct, we use user feedback to determine if the agent was correct. The user - feedback is evaluated by the self_reflect_llm with a new prompt. + feedback is evaluated by the self_reflect_model with a new prompt. We also expand + Reflexion to include the ability to use an image as input to the action_agent and the + self_reflect_model. Using Reflexion with LMMs may not work well, if it gets it wrong + the first time, chances are it can't actually see the thing you want it to see. Examples:: >>> from vision_agent.agent import Reflexion @@ -76,6 +81,22 @@ class Reflexion(Agent): >>> ]) >>> print(resp) >>> "6" + >>> agent = Reflexion( + >>> self_reflect_model=va.lmm.OpenAILMM(), + >>> action_agent=va.lmm.OpenAILMM() + >>> ) + >>> quesiton = "How many hearts are in this image?" + >>> resp = agent(question, image="cards.png") + >>> print(resp) + >>> "6" + >>> resp = agent([ + >>> {"role": "user", "content": question}, + >>> {"role": "assistant", "content": resp}, + >>> {"role": "user", "content": "No, please count the hearts on the bottom card."} + >>> ], image="cards.png") + >>> print(resp) + >>> "4" + ) """ def __init__( @@ -85,7 +106,7 @@ def __init__( agent_prompt: str = COT_AGENT_REFLECT_INSTRUCTION, reflect_prompt: str = COT_REFLECT_INSTRUCTION, finsh_prompt: str = CHECK_FINSH, - self_reflect_llm: Optional[Union[LLM, LMM]] = None, + self_reflect_model: Optional[Union[LLM, LMM]] = None, action_agent: Optional[Union[Agent, LLM, LMM]] = None, verbose: bool = False, ): @@ -98,17 +119,17 @@ def __init__( if verbose: _LOGGER.setLevel(logging.INFO) - if isinstance(self_reflect_llm, LLM) and not isinstance(action_agent, LLM): + if isinstance(self_reflect_model, LLM) and not isinstance(action_agent, LLM): raise ValueError( - "If self_reflect_llm is an LLM, then action_agent must also be an LLM." + "If self_reflect_model is an LLM, then action_agent must also be an LLM." ) - if isinstance(self_reflect_llm, LMM) and isinstance(action_agent, LLM): + if isinstance(self_reflect_model, LMM) and isinstance(action_agent, LLM): raise ValueError( - "If self_reflect_llm is an LMM, then action_agent must also be an agent or LMM." + "If self_reflect_model is an LMM, then action_agent must also be an agent or LMM." ) - self.self_reflect_llm = ( - OpenAILLM() if self_reflect_llm is None else self_reflect_llm + self.self_reflect_model = ( + OpenAILLM() if self_reflect_model is None else self_reflect_model ) self.action_agent = OpenAILLM() if action_agent is None else action_agent @@ -146,7 +167,7 @@ def chat( self.last_scratchpad += "\nObservation: " if is_correct: self.last_scratchpad += "Answer is CORRECT" - return self.self_reflect_llm(chat) + return self.self_reflect_model(chat) else: self.last_scratchpad += "Answer is INCORRECT" chat_context = "The previous conversation was:\n" + chat_str @@ -220,14 +241,14 @@ def prompt_reflection( scratchpad: str = "", image: Optional[Union[str, Path]] = None, ) -> str: - if isinstance(self.self_reflect_llm, LLM): + if isinstance(self.self_reflect_model, LLM): return format_step( - self.self_reflect_llm( + self.self_reflect_model( self._build_reflect_prompt(question, context, scratchpad) ) ) return format_step( - self.self_reflect_llm( + self.self_reflect_model( self._build_reflect_prompt(question, context, scratchpad), image=image, )