added more comments, fixed regex
dillonalaird committed Mar 17, 2024
1 parent a852354 commit 8ad60f6
Showing 1 changed file with 35 additions and 14 deletions.
49 changes: 35 additions & 14 deletions vision_agent/agent/reflexion.py
@@ -26,8 +26,10 @@ def format_step(step: str) -> str:
 
 
 def parse_action(input: str) -> Tuple[str, str]:
-    pattern = r"^(\w+)\[(.+)\]$"
-    match = re.match(pattern, input)
+    # Make the pattern slightly less strict; the LMMs are not as good at following
+    # instructions, so they would often fail on the original regex.
+    pattern = r"(\w+)\[(.+)\]"
+    match = re.search(pattern, input)
 
     if match:
         action_type = match.group(1)
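To make the regex change concrete, here is a minimal standalone sketch (not part of the commit; the helper names and error message are illustrative) contrasting the old anchored pattern under re.match with the new unanchored pattern under re.search. The relaxed version still recovers the action when the model wraps it in extra text.

    import re
    from typing import Tuple

    OLD_PATTERN = r"^(\w+)\[(.+)\]$"  # anchored: the whole string must be the action
    NEW_PATTERN = r"(\w+)\[(.+)\]"    # unanchored: the action may appear anywhere

    def parse_action_old(text: str) -> Tuple[str, str]:
        # re.match plus ^...$ succeeds only if the entire string is the action.
        m = re.match(OLD_PATTERN, text)
        if m is None:
            raise ValueError(f"Could not parse action: {text}")
        return m.group(1), m.group(2)

    def parse_action_new(text: str) -> Tuple[str, str]:
        # re.search finds the pattern anywhere, so chatty LMM output still parses.
        m = re.search(NEW_PATTERN, text)
        if m is None:
            raise ValueError(f"Could not parse action: {text}")
        return m.group(1), m.group(2)

    output = "Sure, here is my action: Finish[6]"
    # parse_action_old(output)  -> raises ValueError (extra text around the action)
    # parse_action_new(output)  -> ("Finish", "6")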
@@ -60,7 +62,10 @@ class Reflexion(Agent):
     hotpotqa folder. There are several differences between this implementation and the
     original one. Because we do not have instant feedback on whether or not the agent
     was correct, we use user feedback to determine if the agent was correct. The user
-    feedback is evaluated by the self_reflect_llm with a new prompt.
+    feedback is evaluated by the self_reflect_model with a new prompt. We also expand
+    Reflexion to include the ability to use an image as input to the action_agent and the
+    self_reflect_model. Using Reflexion with LMMs may not work well; if it gets it wrong
+    the first time, chances are it can't actually see the thing you want it to see.
 
     Examples::
         >>> from vision_agent.agent import Reflexion
@@ -76,6 +81,22 @@ class Reflexion(Agent):
         >>> ])
         >>> print(resp)
         >>> "6"
+        >>> agent = Reflexion(
+        >>>     self_reflect_model=va.lmm.OpenAILMM(),
+        >>>     action_agent=va.lmm.OpenAILMM()
+        >>> )
+        >>> question = "How many hearts are in this image?"
+        >>> resp = agent(question, image="cards.png")
+        >>> print(resp)
+        >>> "6"
+        >>> resp = agent([
+        >>>     {"role": "user", "content": question},
+        >>>     {"role": "assistant", "content": resp},
+        >>>     {"role": "user", "content": "No, please count the hearts on the bottom card."}
+        >>> ], image="cards.png")
+        >>> print(resp)
+        >>> "4"
+    )
     """
 
     def __init__(
@@ -85,7 +106,7 @@ def __init__(
         agent_prompt: str = COT_AGENT_REFLECT_INSTRUCTION,
         reflect_prompt: str = COT_REFLECT_INSTRUCTION,
         finsh_prompt: str = CHECK_FINSH,
-        self_reflect_llm: Optional[Union[LLM, LMM]] = None,
+        self_reflect_model: Optional[Union[LLM, LMM]] = None,
         action_agent: Optional[Union[Agent, LLM, LMM]] = None,
         verbose: bool = False,
     ):
@@ -98,17 +119,17 @@ def __init__(
         if verbose:
             _LOGGER.setLevel(logging.INFO)
 
-        if isinstance(self_reflect_llm, LLM) and not isinstance(action_agent, LLM):
+        if isinstance(self_reflect_model, LLM) and not isinstance(action_agent, LLM):
             raise ValueError(
-                "If self_reflect_llm is an LLM, then action_agent must also be an LLM."
+                "If self_reflect_model is an LLM, then action_agent must also be an LLM."
             )
-        if isinstance(self_reflect_llm, LMM) and isinstance(action_agent, LLM):
+        if isinstance(self_reflect_model, LMM) and isinstance(action_agent, LLM):
             raise ValueError(
-                "If self_reflect_llm is an LMM, then action_agent must also be an agent or LMM."
+                "If self_reflect_model is an LMM, then action_agent must also be an agent or LMM."
             )
 
-        self.self_reflect_llm = (
-            OpenAILLM() if self_reflect_llm is None else self_reflect_llm
+        self.self_reflect_model = (
+            OpenAILLM() if self_reflect_model is None else self_reflect_model
         )
         self.action_agent = OpenAILLM() if action_agent is None else action_agent
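As an aside (not part of the diff), the two isinstance checks above pin down which model pairings are allowed: a text-only reflector only works with a text-only action agent, and a multimodal reflector needs a multimodal action agent (or another Agent). A hypothetical usage sketch, assuming OpenAILLM is exposed as va.llm.OpenAILLM alongside the va.lmm.OpenAILMM used in the docstring examples:

    import vision_agent as va
    from vision_agent.agent import Reflexion

    # Text-only on both sides: allowed, behaves like the original text Reflexion.
    text_agent = Reflexion(
        self_reflect_model=va.llm.OpenAILLM(),
        action_agent=va.llm.OpenAILLM(),
    )

    # Multimodal on both sides: allowed, and image=... can be passed at call time.
    mm_agent = Reflexion(
        self_reflect_model=va.lmm.OpenAILMM(),
        action_agent=va.lmm.OpenAILMM(),
    )

    # Mixed: a text-only reflector with a multimodal action agent trips the first
    # check above and raises ValueError.
    # bad_agent = Reflexion(
    #     self_reflect_model=va.llm.OpenAILLM(),
    #     action_agent=va.lmm.OpenAILMM(),
    # )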

@@ -146,7 +167,7 @@ def chat(
         self.last_scratchpad += "\nObservation: "
         if is_correct:
             self.last_scratchpad += "Answer is CORRECT"
-            return self.self_reflect_llm(chat)
+            return self.self_reflect_model(chat)
         else:
             self.last_scratchpad += "Answer is INCORRECT"
             chat_context = "The previous conversation was:\n" + chat_str
@@ -220,14 +241,14 @@ def prompt_reflection(
         scratchpad: str = "",
         image: Optional[Union[str, Path]] = None,
     ) -> str:
-        if isinstance(self.self_reflect_llm, LLM):
+        if isinstance(self.self_reflect_model, LLM):
             return format_step(
-                self.self_reflect_llm(
+                self.self_reflect_model(
                     self._build_reflect_prompt(question, context, scratchpad)
                 )
             )
         return format_step(
-            self.self_reflect_llm(
+            self.self_reflect_model(
                 self._build_reflect_prompt(question, context, scratchpad),
                 image=image,
             )
