Add ability to use image for Reflexion (#17)
* added images for reflexion

* formatting fix

* added more comments, fixed regex
dillonalaird authored Mar 17, 2024
1 parent 04b6411 commit a6cec71
Showing 5 changed files with 278 additions and 44 deletions.
39 changes: 39 additions & 0 deletions tests/test_lmm.py
@@ -33,6 +33,45 @@ def test_generate_with_mock(openai_lmm_mock): # noqa: F811
)


@pytest.mark.parametrize(
"openai_lmm_mock", ["mocked response"], indirect=["openai_lmm_mock"]
)
def test_chat_with_mock(openai_lmm_mock): # noqa: F811
lmm = OpenAILMM()
response = lmm.chat([{"role": "user", "content": "test prompt"}])
assert response == "mocked response"
assert (
openai_lmm_mock.chat.completions.create.call_args.kwargs["messages"][0][
"content"
][0]["text"]
== "test prompt"
)


@pytest.mark.parametrize(
"openai_lmm_mock", ["mocked response"], indirect=["openai_lmm_mock"]
)
def test_call_with_mock(openai_lmm_mock): # noqa: F811
lmm = OpenAILMM()
response = lmm("test prompt")
assert response == "mocked response"
assert (
openai_lmm_mock.chat.completions.create.call_args.kwargs["messages"][0][
"content"
][0]["text"]
== "test prompt"
)

response = lmm([{"role": "user", "content": "test prompt"}])
assert response == "mocked response"
assert (
openai_lmm_mock.chat.completions.create.call_args.kwargs["messages"][0][
"content"
][0]["text"]
== "test prompt"
)


@pytest.mark.parametrize(
"openai_lmm_mock",
['{"Parameters": {"prompt": "cat"}}'],
9 changes: 7 additions & 2 deletions vision_agent/agent/agent.py
@@ -1,8 +1,13 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Union
from pathlib import Path
from typing import Dict, List, Optional, Union


class Agent(ABC):
@abstractmethod
def __call__(self, input: Union[List[Dict[str, str]], str]) -> str:
def __call__(
self,
input: Union[List[Dict[str, str]], str],
image: Optional[Union[str, Path]] = None,
) -> str:
pass
172 changes: 145 additions & 27 deletions vision_agent/agent/reflexion.py
@@ -1,7 +1,10 @@
import logging
import re
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

from vision_agent import LLM, OpenAILLM
from vision_agent import LLM, LMM, OpenAILLM

from .agent import Agent
from .reflexion_prompts import (
@@ -13,20 +16,27 @@
REFLECTION_HEADER,
)

logging.basicConfig(stream=sys.stdout)

_LOGGER = logging.getLogger(__name__)


def format_step(step: str) -> str:
return step.strip("\n").strip().replace("\n", "")


def parse_action(input: str) -> Tuple[str, str]:
pattern = r"^(\w+)\[(.+)\]$"
match = re.match(pattern, input)
# Make the pattern slightly less strict; LMMs are not as good at following
# instructions, so they would often fail on the original anchored regex.
pattern = r"(\w+)\[(.+)\]"
match = re.search(pattern, input)

if match:
action_type = match.group(1)
argument = match.group(2)
return action_type, argument

_LOGGER.error(f"Invalid action: {input}")
raise ValueError(f"Invalid action: {input}")
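
For illustration only (not part of this commit), the relaxed pattern used with re.search tolerates the extra text LMMs tend to wrap around an action, where the original anchored re.match pattern would not; the Finish action name below is just a placeholder:

import re

action = "I will answer now. Finish[the truck has 6 tires]"
assert re.match(r"^(\w+)\[(.+)\]$", action) is None  # original anchored pattern: no match
match = re.search(r"(\w+)\[(.+)\]", action)  # relaxed pattern from this commit
assert match is not None
assert (match.group(1), match.group(2)) == ("Finish", "the truck has 6 tires")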


@@ -47,41 +57,107 @@ def format_chat(chat: List[Dict[str, str]]) -> str:


class Reflexion(Agent):
r"""This is an implementation of the Reflexion paper https://arxiv.org/abs/2303.11366
based on the original implementation https://github.com/noahshinn/reflexion in the
hotpotqa folder. There are several differences between this implementation and the
original one. Because we do not have instant feedback on whether the agent's answer
was correct, we instead rely on user feedback, which is evaluated by the
self_reflect_model with a new prompt. We also expand Reflexion with the ability to
pass an image as input to the action_agent and the self_reflect_model. Note that
Reflexion may not work well with LMMs: if the model gets the answer wrong on the
first attempt, chances are it cannot actually see the thing you want it to see.
Examples::
>>> from vision_agent.agent import Reflexion
>>> agent = Reflexion()
>>> question = "How many tires does a truck have?"
>>> resp = agent(question)
>>> print(resp)
>>> "18"
>>> resp = agent([
>>> {"role": "user", "content": question},
>>> {"role": "assistant", "content": resp},
>>> {"role": "user", "content": "No I mean those regular trucks but where the back tires are double."}
>>> ])
>>> print(resp)
>>> "6"
>>> import vision_agent as va
>>> agent = Reflexion(
>>> self_reflect_model=va.lmm.OpenAILMM(),
>>> action_agent=va.lmm.OpenAILMM()
>>> )
>>> question = "How many hearts are in this image?"
>>> resp = agent(question, image="cards.png")
>>> print(resp)
>>> "6"
>>> resp = agent([
>>> {"role": "user", "content": question},
>>> {"role": "assistant", "content": resp},
>>> {"role": "user", "content": "No, please count the hearts on the bottom card."}
>>> ], image="cards.png")
>>> print(resp)
>>> "4"
"""

def __init__(
self,
cot_examples: str = COTQA_SIMPLE6,
reflect_examples: str = COT_SIMPLE_REFLECTION,
agent_prompt: str = COT_AGENT_REFLECT_INSTRUCTION,
reflect_prompt: str = COT_REFLECT_INSTRUCTION,
finsh_prompt: str = CHECK_FINSH,
self_reflect_llm: Optional[LLM] = None,
action_agent: Optional[Union[Agent, LLM]] = None,
self_reflect_model: Optional[Union[LLM, LMM]] = None,
action_agent: Optional[Union[Agent, LLM, LMM]] = None,
verbose: bool = False,
):
self.agent_prompt = agent_prompt
self.reflect_prompt = reflect_prompt
self.finsh_prompt = finsh_prompt
self.cot_examples = cot_examples
self.refelct_examples = reflect_examples
self.reflections: List[str] = []
if verbose:
_LOGGER.setLevel(logging.INFO)

if isinstance(self_reflect_model, LLM) and not isinstance(action_agent, LLM):
raise ValueError(
"If self_reflect_model is an LLM, then action_agent must also be an LLM."
)
if isinstance(self_reflect_model, LMM) and isinstance(action_agent, LLM):
raise ValueError(
"If self_reflect_model is an LMM, then action_agent must also be an agent or LMM."
)

if self_reflect_llm is None:
self.self_reflect_llm = OpenAILLM()
if action_agent is None:
self.action_agent = OpenAILLM()
self.self_reflect_model = (
OpenAILLM() if self_reflect_model is None else self_reflect_model
)
self.action_agent = OpenAILLM() if action_agent is None else action_agent
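
A hedged sketch (not part of this commit) of the pairing rule the constructor above enforces; it assumes OpenAILLM and OpenAILMM subclass only LLM and LMM respectively, and that valid API credentials are configured:

from vision_agent import OpenAILLM
from vision_agent.agent import Reflexion
from vision_agent.lmm import OpenAILMM

Reflexion()  # defaults: both models fall back to OpenAILLM
Reflexion(self_reflect_model=OpenAILMM(), action_agent=OpenAILMM())  # LMM + LMM is accepted
Reflexion(self_reflect_model=OpenAILLM(), action_agent=OpenAILMM())  # raises ValueError (action_agent must also be an LLM)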

def __call__(self, input: Union[List[Dict[str, str]], str]) -> str:
def __call__(
self,
input: Union[str, List[Dict[str, str]]],
image: Optional[Union[str, Path]] = None,
) -> str:
if isinstance(input, str):
input = [{"role": "user", "content": input}]
return self.chat(input)
return self.chat(input, image)

def chat(self, chat: List[Dict[str, str]]) -> str:
def chat(
self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
) -> str:
if len(chat) == 0 or chat[0]["role"] != "user":
raise ValueError(
f"Invalid chat. Should start with user and then assistant and contain at least one entry {chat}"
f"Invalid chat. Should start with user and alternate between user"
f"and assistant and contain at least one entry {chat}"
)
if image is not None and isinstance(self.action_agent, LLM):
raise ValueError(
"If image is provided, then action_agent must be an agent or LMM."
)

question = chat[0]["content"]
if len(chat) == 1:
results = self._step(question)
results = self._step(question, image=image)
self.last_scratchpad = results["scratchpad"]
return results["action_arg"]

@@ -91,23 +167,33 @@ def chat(self, chat: List[Dict[str, str]]) -> str:
self.last_scratchpad += "\nObservation: "
if is_correct:
self.last_scratchpad += "Answer is CORRECT"
return self.self_reflect_llm(chat)
return self.self_reflect_model(chat)
else:
self.last_scratchpad += "Answer is INCORRECT"
chat_context = "The previous conversation was:\n" + chat_str
reflections = self.reflect(question, chat_context, self.last_scratchpad)
results = self._step(question, reflections)
reflections = self.reflect(
question, chat_context, self.last_scratchpad, image
)
_LOGGER.info(f" {reflections}")
results = self._step(question, reflections, image=image)
self.last_scratchpad = results["scratchpad"]
return results["action_arg"]

def _step(self, question: str, reflections: str = "") -> Dict[str, str]:
def _step(
self,
question: str,
reflections: str = "",
image: Optional[Union[str, Path]] = None,
) -> Dict[str, str]:
# Think
scratchpad = "\nThought:"
scratchpad += " " + self.prompt_agent(question, reflections, scratchpad)
scratchpad += " " + self.prompt_agent(question, reflections, scratchpad, image)
_LOGGER.info(f" {scratchpad}")

# Act
scratchpad += "\nAction:"
action = self.prompt_agent(question, reflections, scratchpad)
action = self.prompt_agent(question, reflections, scratchpad, image)
_LOGGER.info(f" {action}")
scratchpad += " " + action
action_type, argument = parse_action(action)
return {
@@ -116,23 +202,55 @@ def _step(self, question: str, reflections: str = "") -> Dict[str, str]:
"action_arg": argument,
}

def reflect(self, question: str, context: str, scratchpad: str) -> str:
self.reflections += [self.prompt_reflection(question, context, scratchpad)]
def reflect(
self,
question: str,
context: str,
scratchpad: str,
image: Optional[Union[str, Path]],
) -> str:
self.reflections += [
self.prompt_reflection(question, context, scratchpad, image)
]
return format_reflections(self.reflections)

def prompt_agent(self, question: str, reflections: str, scratchpad: str) -> str:
def prompt_agent(
self,
question: str,
reflections: str,
scratchpad: str,
image: Optional[Union[str, Path]] = None,
) -> str:
if isinstance(self.action_agent, LLM):
return format_step(
self.action_agent(
self._build_agent_prompt(question, reflections, scratchpad)
)
)
return format_step(
self.action_agent(
self._build_agent_prompt(question, reflections, scratchpad)
self._build_agent_prompt(question, reflections, scratchpad),
image=image,
)
)

def prompt_reflection(
self, question: str, context: str = "", scratchpad: str = ""
self,
question: str,
context: str = "",
scratchpad: str = "",
image: Optional[Union[str, Path]] = None,
) -> str:
if isinstance(self.self_reflect_model, LLM):
return format_step(
self.self_reflect_model(
self._build_reflect_prompt(question, context, scratchpad)
)
)
return format_step(
self.self_reflect_llm(
self._build_reflect_prompt(question, context, scratchpad)
self.self_reflect_model(
self._build_reflect_prompt(question, context, scratchpad),
image=image,
)
)

28 changes: 17 additions & 11 deletions vision_agent/data/data.py
@@ -22,10 +22,12 @@ class DataStore:
r"""A class to store and manage image data along with its generated metadata from an LMM."""

def __init__(self, df: pd.DataFrame):
r"""Initializes the DataStore with a DataFrame containing image paths and image IDs. If the image IDs are not present, they are generated using UUID4. The DataFrame must contain an 'image_paths' column.
r"""Initializes the DataStore with a DataFrame containing image paths and image
IDs. If the image IDs are not present, they are generated using UUID4. The
DataFrame must contain an 'image_paths' column.
Args:
df (pd.DataFrame): The DataFrame containing "image_paths" and "image_id" columns.
df: The DataFrame containing "image_paths" and "image_id" columns.
"""
self.df = df
self.lmm: Optional[LMM] = None
@@ -47,12 +49,14 @@ def add_lmm(self, lmm: LMM) -> Self:
def add_column(
self, name: str, prompt: str, func: Optional[Callable[[str], str]] = None
) -> Self:
r"""Adds a new column to the DataFrame containing the generated metadata from the LMM.
r"""Adds a new column to the DataFrame containing the generated metadata from
the LMM.
Args:
name (str): The name of the column to be added.
prompt (str): The prompt to be used to generate the metadata.
func (Optional[Callable[[Any], Any]]): A Python function to be applied on the output of `lmm.generate`. Defaults to None.
name: The name of the column to be added.
prompt: The prompt to be used to generate the metadata.
func: A Python function to be applied on the output of `lmm.generate`.
Defaults to None.
"""
if self.lmm is None:
raise ValueError("LMM not set yet")
@@ -67,10 +71,11 @@ def add_column(
return self

def build_index(self, target_col: str) -> Self:
r"""This will generate embeddings for the `target_col` and build a searchable index over them, so next time you run search it will search over this index.
r"""This will generate embeddings for the `target_col` and build a searchable
index over them, so next time you run search it will search over this index.
Args:
target_col (str): The column name containing the data to be indexed."""
target_col: The column name containing the data to be indexed."""
if self.emb is None:
raise ValueError("Embedder not set yet")

@@ -92,11 +97,12 @@ def get_embeddings(self) -> npt.NDArray[np.float32]:
)

def search(self, query: str, top_k: int = 10) -> List[Dict]:
r"""Searches the index for the most similar images to the query and returns the top_k results.
r"""Searches the index for the most similar images to the query and returns
the top_k results.
Args:
query (str): The query to search for.
top_k (int, optional): The number of results to return. Defaults to 10."""
query: The query to search for.
top_k: The number of results to return. Defaults to 10."""
if self.index is None:
raise ValueError("Index not built yet")
if self.emb is None:
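
To make the DataStore docstrings above concrete, here is a hedged end-to-end sketch; the import paths, the embedder setup step, and the prompt text are assumptions, not part of this commit:

import pandas as pd

from vision_agent.data import DataStore  # import path assumed
from vision_agent.lmm import OpenAILMM

# DataStore generates an image_id column via UUID4 when it is missing.
df = pd.DataFrame({"image_paths": ["cards.png", "truck.jpg"]})
ds = DataStore(df)
ds.add_lmm(OpenAILMM())  # LMM used by add_column

# Generate a metadata column from the LMM, then index and search it.
ds.add_column(name="caption", prompt="Describe the image in one sentence.")
# build_index assumes an embedder has already been attached to the DataStore.
ds.build_index(target_col="caption")
results = ds.search("a playing card with hearts", top_k=5)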