Skip to content

Commit

Permalink
Add Reflexion agent (#16)
Browse files Browse the repository at this point in the history
* updated llm interface

* added reflexion agent framework

* remove object creation in args

* fixed typign issue

* fixed overlapping test cases

* added chat and call tests
  • Loading branch information
dillonalaird authored Mar 16, 2024
1 parent 50348dc commit f883600
Show file tree
Hide file tree
Showing 8 changed files with 331 additions and 3 deletions.
33 changes: 33 additions & 0 deletions tests/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,39 @@ def test_generate_with_mock(openai_llm_mock): # noqa: F811
)


@pytest.mark.parametrize(
"openai_llm_mock", ["mocked response"], indirect=["openai_llm_mock"]
)
def test_chat_with_mock(openai_llm_mock): # noqa: F811
llm = OpenAILLM()
response = llm.chat([{"role": "user", "content": "test prompt"}])
assert response == "mocked response"
openai_llm_mock.chat.completions.create.assert_called_once_with(
model="gpt-4-turbo-preview",
messages=[{"role": "user", "content": "test prompt"}],
)


@pytest.mark.parametrize(
"openai_llm_mock", ["mocked response"], indirect=["openai_llm_mock"]
)
def test_call_with_mock(openai_llm_mock): # noqa: F811
llm = OpenAILLM()
response = llm("test prompt")
assert response == "mocked response"
openai_llm_mock.chat.completions.create.assert_called_once_with(
model="gpt-4-turbo-preview",
messages=[{"role": "user", "content": "test prompt"}],
)

response = llm([{"role": "user", "content": "test prompt"}])
assert response == "mocked response"
openai_llm_mock.chat.completions.create.assert_called_with(
model="gpt-4-turbo-preview",
messages=[{"role": "user", "content": "test prompt"}],
)


@pytest.mark.parametrize(
"openai_llm_mock",
['{"Parameters": {"prompt": "cat"}}'],
Expand Down
4 changes: 2 additions & 2 deletions tests/test_lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_generate_classifier(openai_lmm_mock): # noqa: F811
['{"Parameters": {"prompt": "cat"}}'],
indirect=["openai_lmm_mock"],
)
def test_generate_classifier(openai_lmm_mock): # noqa: F811
def test_generate_detector(openai_lmm_mock): # noqa: F811
lmm = OpenAILMM()
prompt = "Can you generate a cat classifier?"
detector = lmm.generate_detector(prompt)
Expand All @@ -64,7 +64,7 @@ def test_generate_classifier(openai_lmm_mock): # noqa: F811
['{"Parameters": {"prompt": "cat"}}'],
indirect=["openai_lmm_mock"],
)
def test_generate_classifier(openai_lmm_mock): # noqa: F811
def test_generate_segmentor(openai_lmm_mock): # noqa: F811
lmm = OpenAILMM()
prompt = "Can you generate a cat classifier?"
segmentor = lmm.generate_segmentor(prompt)
Expand Down
1 change: 1 addition & 0 deletions vision_agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from .emb import Embedder, OpenAIEmb, SentenceTransformerEmb, get_embedder
from .llm import LLM, OpenAILLM
from .lmm import LMM, LLaVALMM, OpenAILMM, get_lmm
from .agent import Agent
2 changes: 2 additions & 0 deletions vision_agent/agent/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .agent import Agent
from .reflexion import Reflexion
8 changes: 8 additions & 0 deletions vision_agent/agent/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Union


class Agent(ABC):
@abstractmethod
def __call__(self, input: Union[List[Dict[str, str]], str]) -> str:
pass
162 changes: 162 additions & 0 deletions vision_agent/agent/reflexion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
import re
from typing import Dict, List, Optional, Tuple, Union

from vision_agent import LLM, OpenAILLM

from .agent import Agent
from .reflexion_prompts import (
CHECK_FINSH,
COT_AGENT_REFLECT_INSTRUCTION,
COT_REFLECT_INSTRUCTION,
COT_SIMPLE_REFLECTION,
COTQA_SIMPLE6,
REFLECTION_HEADER,
)


def format_step(step: str) -> str:
return step.strip("\n").strip().replace("\n", "")


def parse_action(input: str) -> Tuple[str, str]:
pattern = r"^(\w+)\[(.+)\]$"
match = re.match(pattern, input)

if match:
action_type = match.group(1)
argument = match.group(2)
return action_type, argument

raise ValueError(f"Invalid action: {input}")


def format_reflections(reflections: List[str], header: str = REFLECTION_HEADER) -> str:
if reflections == []:
return ""
else:
return (
header + "Reflections:\n- " + "\n- ".join([r.strip() for r in reflections])
)


def format_chat(chat: List[Dict[str, str]]) -> str:
chat_str = ""
for c in chat:
chat_str += c["role"] + ": " + c["content"] + "\n"
return chat_str.strip()


class Reflexion(Agent):
def __init__(
self,
cot_examples: str = COTQA_SIMPLE6,
reflect_examples: str = COT_SIMPLE_REFLECTION,
agent_prompt: str = COT_AGENT_REFLECT_INSTRUCTION,
reflect_prompt: str = COT_REFLECT_INSTRUCTION,
finsh_prompt: str = CHECK_FINSH,
self_reflect_llm: Optional[LLM] = None,
action_agent: Optional[Union[Agent, LLM]] = None,
):
self.agent_prompt = agent_prompt
self.reflect_prompt = reflect_prompt
self.finsh_prompt = finsh_prompt
self.cot_examples = cot_examples
self.refelct_examples = reflect_examples
self.reflections: List[str] = []

if self_reflect_llm is None:
self.self_reflect_llm = OpenAILLM()
if action_agent is None:
self.action_agent = OpenAILLM()

def __call__(self, input: Union[List[Dict[str, str]], str]) -> str:
if isinstance(input, str):
input = [{"role": "user", "content": input}]
return self.chat(input)

def chat(self, chat: List[Dict[str, str]]) -> str:
if len(chat) == 0 or chat[0]["role"] != "user":
raise ValueError(
f"Invalid chat. Should start with user and then assistant and contain at least one entry {chat}"
)
question = chat[0]["content"]
if len(chat) == 1:
results = self._step(question)
self.last_scratchpad = results["scratchpad"]
return results["action_arg"]

# Observe
chat_str = format_chat(chat)
is_correct = self.prompt_finish(chat_str)
self.last_scratchpad += "\nObservation: "
if is_correct:
self.last_scratchpad += "Answer is CORRECT"
return self.self_reflect_llm(chat)
else:
self.last_scratchpad += "Answer is INCORRECT"
chat_context = "The previous conversation was:\n" + chat_str
reflections = self.reflect(question, chat_context, self.last_scratchpad)
results = self._step(question, reflections)
self.last_scratchpad = results["scratchpad"]
return results["action_arg"]

def _step(self, question: str, reflections: str = "") -> Dict[str, str]:
# Think
scratchpad = "\nThought:"
scratchpad += " " + self.prompt_agent(question, reflections, scratchpad)

# Act
scratchpad += "\nAction:"
action = self.prompt_agent(question, reflections, scratchpad)
scratchpad += " " + action
action_type, argument = parse_action(action)
return {
"scratchpad": scratchpad,
"action_type": action_type,
"action_arg": argument,
}

def reflect(self, question: str, context: str, scratchpad: str) -> str:
self.reflections += [self.prompt_reflection(question, context, scratchpad)]
return format_reflections(self.reflections)

def prompt_agent(self, question: str, reflections: str, scratchpad: str) -> str:
return format_step(
self.action_agent(
self._build_agent_prompt(question, reflections, scratchpad)
)
)

def prompt_reflection(
self, question: str, context: str = "", scratchpad: str = ""
) -> str:
return format_step(
self.self_reflect_llm(
self._build_reflect_prompt(question, context, scratchpad)
)
)

def prompt_finish(self, chat: str) -> bool:
answer = self.action_agent(self.finsh_prompt.format(chat=chat))
return "true" in answer.lower()

def _build_agent_prompt(
self, question: str, reflections: str, scratchpad: str
) -> str:
return self.agent_prompt.format(
examples=self.cot_examples,
reflections=reflections,
context="",
question=question,
scratchpad=scratchpad,
)

def _build_reflect_prompt(
self, question: str, context: str = "", scratchpad: str = ""
) -> str:
return self.reflect_prompt.format(
examples=self.refelct_examples,
context=context,
question=question,
scratchpad=scratchpad,
)
100 changes: 100 additions & 0 deletions vision_agent/agent/reflexion_prompts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
COT_AGENT_REFLECT_INSTRUCTION = """Solve a question answering task by having a Thought, then Finish with your answer. Thought can reason about the current situation. Finish[answer] returns the answer and finishes the task. You will be given context that you should use to help you answer the question.
Here are some examples:
{examples}
(END OF EXAMPLES)
{reflections}
Relevant Context: {context}
Question: {question}{scratchpad}"""


COT_REFLECT_INSTRUCTION = """You are an advanced reasoning agent that can improve based on self refection. You will be given a previous reasoning trial in which you were given access to relevant context and a question to answer. You were unsuccessful in answering the question either because you guessed the wrong answer with Finish[<answer>] or there is a phrasing discrepancy with your provided answer and the answer key. In a few sentences, Diagnose a possible reason for failure or phrasing discrepancy and devise a new, concise, high level plan that aims to mitigate the same failure. Use complete sentences.
Here are some examples:
{examples}
(END OF EXAMPLES)
Previous trial:
Relevant Context: {context}
Question: {question}{scratchpad}
Reflection:"""


REFLECTION_HEADER = "You have attempted to answer following question before and failed. The following reflection(s) give a plan to avoid failing to answer the question in the same way you did previously. Use them to improve your strategy of correctly answering the given question.\n"


COT = """Relevant Context: The Nile River is the longest river in the world, spanning approximately 6,650 kilometers (4,132 miles) in length. It flows through eleven countries in northeastern Africa, including Egypt, Sudan, and Uganda.
Question: What is the longest river in the world?
Thought: The question asks for the longest river in the world, which I know is the Nile River based on the context provided.
Action: Finish[Nile River]
Relevant Context: Ludwig van Beethoven was a German composer and pianist who was a crucial figure in the transition between the Classical and Romantic eras in Western classical music. One of his most famous compositions is the Symphony No. 9, also known as the "Choral" symphony.
Question: Which composer created Symphony No. 9?
Thought: The question is asking for the composer of Symphony No. 9. Based on the context, I know that Ludwig van Beethoven composed this symphony.
Action: Finish[Ludwig van Beethoven]
Relevant Context: Photosynthesis is the process by which green plants and some other organisms convert light energy into chemical energy. During this process, plants absorb sunlight, carbon dioxide, and water to produce glucose and oxygen.
Question: What do plants produce during photosynthesis?
Thought: The question is asking about the products of photosynthesis. From the context, I know that plants produce glucose and oxygen during this process.
Action: Finish[Glucose and Oxygen]
"""

COT_REFLECT = """
Relevant Context: Ernest Hemingway's novel "The Old Man and the Sea" tells the story of Santiago, an aging Cuban fisherman, who struggles to catch a giant marlin in the Gulf Stream. The book won the Pulitzer Prize for Fiction in 1953 and contributed to Hemingway's Nobel Prize for Literature in 1954.
Question: Which literary award did "The Old Man and the Sea" contribute to Hemingway winning?
Thought: The question is asking which award "The Old Man and the Sea" contributed to Hemingway winning. Based on the context, I know the novel won the Pulitzer Prize for Fiction and contributed to his Nobel Prize for Literature.
Action: Finish[Pulitzer Prize for Fiction]
Reflection: My answer was correct based on the context, but may not be the exact answer stored by the grading environment. Next time, I should try to provide a less verbose answer like "Pulitzer Prize" or "Nobel Prize."
Context: On 14 October 1947, Chuck Yeager, a United States Air Force test pilot, became the first person to break the sound barrier by flying the Bell X-1 experimental aircraft at an altitude of 45,000 feet.
Charles Elwood "Chuck" Yeager (13 February 1923 - 7 December 2020) was a United States Air Force officer, flying ace, and test pilot. He is best known for becoming the first person to break the sound barrier, which he achieved in the Bell X-1 aircraft named Glamorous Glennis. Yeager was also a distinguished fighter pilot during World War II and was credited with shooting down at least 12 enemy aircraft. In 1973, he was inducted into the National Aviation Hall of Fame for his significant contributions to aviation.
Question: Who is the first person to break the sound barrier?
Thought: The question is asking for the first person to break the sound barrier. From the context, I know that Chuck Yeager, a United States Air Force test pilot, was the first person to break the sound barrier.
Action: Finish[Chuck Yeager]
Reflection: Upon reflecting on the incorrect answer I provided, I realize that I may not have given the full name of the individual in question. In the context, both the given name and the nickname were mentioned, and I only used the nickname in my response. This could have been the reason my answer was deemed incorrect. Moving forward, when attempting this question again or similar questions, I will make sure to include the complete name of the person, which consists of their given name, any middle names, and their nickname (if applicable). This will help ensure that my answer is more accurate and comprehensive."""

COTQA_SIMPLE6 = """
Question: What is the elevation range for the area that the eastern sector of the Colorado orogeny extends into?
Thought: Let's think step by step. The eastern sector of Colorado orogeny extends into the High Plains. High Plains rise in elevation from around 1,800 to 7,000 ft, so the answer is 1,800 to 7,000 ft.
Action: Finish[1,800 to 7,000 ft]
Question: Musician and satirist Allie Goertz wrote a song about the "The Simpsons" character Milhouse, who Matt Groening named after who?
Thought: Let's think step by step. Milhouse was named after U.S. president Richard Nixon, so the answer is Richard Nixon.
Action: Finish[Richard Nixon]
Question: Which documentary is about Finnish rock groups, Adam Clayton Powell or The Saimaa Gesture?
Thought: Let's think step by step. Adam Clayton Powell (film) is a documentary about an African-American politician, not Finnish rock groups. So the documentary about Finnish rock groups must instead be The Saimaa Gesture.
Action: Finish[The Saimaa Gesture]
Question: What profession does Nicholas Ray and Elia Kazan have in common?
Thought: Let's think step by step. Professions of Nicholas Ray are director, screenwriter, and actor. Professions of Elia Kazan are director, producer, screenwriter, and actor. So profession Nicholas Ray and Elia Kazan have in common is director, screenwriter, and actor.
Action: Finish[director, screenwriter, actor]
Question: Which magazine was started first Arthur's Magazine or First for Women?
Thought: Let's think step by step. Arthur's Magazine was started in 1844. First for Women was started in 1989. 1844 (Arthur's Magazine) < 1989 (First for Women), so Arthur's Magazine was started first.
Action: Finish[Arthur's Magazine]
Question: Were Pavel Urysohn and Leonid Levin known for the same type of work?
Thought: Let's think step by step. Pavel Urysohn is a mathematician. Leonid Levin is a mathematician and computer scientist. So Pavel Urysohn and Leonid Levin have the same type of work.
Action: Finish[Yes]"""

COT_SIMPLE_REFLECTION = """
Question: What is the elevation range for the area that the eastern sector of the Colorado orogeny extends into?
Thought: Let's think step by step. The eastern sector of Colorado orogeny extends into the Rocky Mountains. The Rocky Mountains rise in elevation from around 1,800 to 14,000 ft, so the answer is 1,800 to 14,000 ft.
Action: Finish[1,800 to 14,000 ft]
Reflection: My reasoning for the elevation range of the eastern sector of the Colorado orogeny failed because I incorrectly associated it with the Rocky Mountains instead of the High Plains. In the future, when attempting this question, I should focus on better understanding the geography and geological formations involved in the Colorado orogeny to avoid confusion.
Question: Musician and satirist Allie Goertz wrote a song about the "The Simpsons" character Milhouse, who Matt Groening named after who?
Thought: Let's think step by step. Milhouse was named after a prominent cartoon character, Mickey Mouse, so the answer is Mickey Mouse.
Action: Finish[Mickey Mouse]
Reflection: My reasoning for the naming of the character Milhouse in "The Simpsons" failed because I incorrectly assumed that the character was named after a prominent cartoon character, Mickey Mouse. In the future, when attempting this question, I should focus on better researching the background of "The Simpsons" and Matt Groening's influences in naming characters to avoid making incorrect assumptions.
"""

CHECK_FINSH = """
You are an advanced reasoning agent, your job is to look at a conversation between a human and another agent and determine if the agent correctly answered the human's quesiton. If the agent correctly answered the question, return True. If the agent did not correctly answer the question or you are unsure, return False.
{chat}"""
24 changes: 23 additions & 1 deletion vision_agent/llm/llm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
from abc import ABC, abstractmethod
from typing import Mapping, cast
from typing import Dict, List, Mapping, Union, cast

from openai import OpenAI

from vision_agent.tools import (
Expand All @@ -18,6 +19,14 @@ class LLM(ABC):
def generate(self, prompt: str) -> str:
pass

@abstractmethod
def chat(self, chat: List[Dict[str, str]]) -> str:
pass

@abstractmethod
def __call__(self, input: Union[str, List[Dict[str, str]]]) -> str:
pass


class OpenAILLM(LLM):
r"""An LLM class for any OpenAI LLM model."""
Expand All @@ -36,6 +45,19 @@ def generate(self, prompt: str) -> str:

return cast(str, response.choices[0].message.content)

def chat(self, chat: List[Dict[str, str]]) -> str:
response = self.client.chat.completions.create(
model=self.model_name,
messages=chat, # type: ignore
)

return cast(str, response.choices[0].message.content)

def __call__(self, input: Union[str, List[Dict[str, str]]]) -> str:
if isinstance(input, str):
return self.generate(input)
return self.chat(input)

def generate_classifier(self, prompt: str) -> ImageTool:
prompt = CHOOSE_PARAMS.format(api_doc=CLIP.doc, question=prompt)
response = self.client.chat.completions.create(
Expand Down

0 comments on commit f883600

Please sign in to comment.