From 680554a393304c828f1a7233b6978f3e8ba1df3d Mon Sep 17 00:00:00 2001
From: Dillon Laird
Date: Sat, 16 Mar 2024 15:06:52 -0700
Subject: [PATCH 1/3] added images for reflexion

---
 tests/test_lmm.py               |  39 +++++++++
 vision_agent/agent/agent.py     |   7 +-
 vision_agent/agent/reflexion.py | 143 +++++++++++++++++++++++++++-----
 vision_agent/data/data.py       |  28 ++++---
 vision_agent/lmm/lmm.py         |  74 ++++++++++++++++-
 5 files changed, 251 insertions(+), 40 deletions(-)

diff --git a/tests/test_lmm.py b/tests/test_lmm.py
index 0dff8466..c1390726 100644
--- a/tests/test_lmm.py
+++ b/tests/test_lmm.py
@@ -33,6 +33,45 @@ def test_generate_with_mock(openai_lmm_mock):  # noqa: F811
     )
 
 
+@pytest.mark.parametrize(
+    "openai_lmm_mock", ["mocked response"], indirect=["openai_lmm_mock"]
+)
+def test_chat_with_mock(openai_lmm_mock):  # noqa: F811
+    lmm = OpenAILMM()
+    response = lmm.chat([{"role": "user", "content": "test prompt"}])
+    assert response == "mocked response"
+    assert (
+        openai_lmm_mock.chat.completions.create.call_args.kwargs["messages"][0][
+            "content"
+        ][0]["text"]
+        == "test prompt"
+    )
+
+
+@pytest.mark.parametrize(
+    "openai_lmm_mock", ["mocked response"], indirect=["openai_lmm_mock"]
+)
+def test_call_with_mock(openai_lmm_mock):  # noqa: F811
+    lmm = OpenAILMM()
+    response = lmm("test prompt")
+    assert response == "mocked response"
+    assert (
+        openai_lmm_mock.chat.completions.create.call_args.kwargs["messages"][0][
+            "content"
+        ][0]["text"]
+        == "test prompt"
+    )
+
+    response = lmm([{"role": "user", "content": "test prompt"}])
+    assert response == "mocked response"
+    assert (
+        openai_lmm_mock.chat.completions.create.call_args.kwargs["messages"][0][
+            "content"
+        ][0]["text"]
+        == "test prompt"
+    )
+
+
 @pytest.mark.parametrize(
     "openai_lmm_mock",
     ['{"Parameters": {"prompt": "cat"}}'],
diff --git a/vision_agent/agent/agent.py b/vision_agent/agent/agent.py
index 7de17496..05421e41 100644
--- a/vision_agent/agent/agent.py
+++ b/vision_agent/agent/agent.py
@@ -1,8 +1,11 @@
 from abc import ABC, abstractmethod
-from typing import Dict, List, Union
+from pathlib import Path
+from typing import Dict, List, Optional, Union
 
 
 class Agent(ABC):
     @abstractmethod
-    def __call__(self, input: Union[List[Dict[str, str]], str]) -> str:
+    def __call__(
+        self, input: Union[List[Dict[str, str]], str], image: Optional[Union[str, Path]] = None
+    ) -> str:
         pass
diff --git a/vision_agent/agent/reflexion.py b/vision_agent/agent/reflexion.py
index dc9b3bce..8c89f024 100644
--- a/vision_agent/agent/reflexion.py
+++ b/vision_agent/agent/reflexion.py
@@ -1,7 +1,10 @@
+import logging
 import re
+import sys
+from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Union
 
-from vision_agent import LLM, OpenAILLM
+from vision_agent import LLM, LMM, OpenAILLM
 
 from .agent import Agent
 from .reflexion_prompts import (
@@ -13,6 +16,10 @@
     REFLECTION_HEADER,
 )
 
+logging.basicConfig(stream=sys.stdout)
+
+_LOGGER = logging.getLogger(__name__)
+
 
 def format_step(step: str) -> str:
     return step.strip("\n").strip().replace("\n", "")
@@ -27,6 +34,7 @@ def parse_action(input: str) -> Tuple[str, str]:
         argument = match.group(2)
         return action_type, argument
 
+    _LOGGER.error(f"Invalid action: {input}")
     raise ValueError(f"Invalid action: {input}")
 
 
@@ -47,6 +55,29 @@ def format_chat(chat: List[Dict[str, str]]) -> str:
 
 
 class Reflexion(Agent):
+    r"""This is an implementation of the Reflexion paper https://arxiv.org/abs/2303.11366
+    based on the original implementation https://github.com/noahshinn/reflexion in the
+    hotpotqa folder. There are several differences between this implementation and the
+    original one. Because we do not have instant feedback on whether or not the agent
+    was correct, we use user feedback to determine if the agent was correct. The user
+    feedback is evaluated by the self_reflect_llm with a new prompt.
+
+    Examples::
+        >>> from vision_agent.agent import Reflexion
+        >>> agent = Reflexion()
+        >>> question = "How many tires does a truck have?"
+        >>> resp = agent(question)
+        >>> print(resp)
+        >>> "18"
+        >>> resp = agent([
+        >>>     {"role": "user", "content": question},
+        >>>     {"role": "assistant", "content": resp},
+        >>>     {"role": "user", "content": "No I mean those regular trucks but where the back tires are double."}
+        >>> ])
+        >>> print(resp)
+        >>> "6"
+    """
+
     def __init__(
         self,
         cot_examples: str = COTQA_SIMPLE6,
@@ -54,8 +85,9 @@ def __init__(
         agent_prompt: str = COT_AGENT_REFLECT_INSTRUCTION,
         reflect_prompt: str = COT_REFLECT_INSTRUCTION,
         finsh_prompt: str = CHECK_FINSH,
-        self_reflect_llm: Optional[LLM] = None,
-        action_agent: Optional[Union[Agent, LLM]] = None,
+        self_reflect_llm: Optional[Union[LLM, LMM]] = None,
+        action_agent: Optional[Union[Agent, LLM, LMM]] = None,
+        verbose: bool = False,
     ):
         self.agent_prompt = agent_prompt
         self.reflect_prompt = reflect_prompt
@@ -63,25 +95,48 @@ def __init__(
         self.cot_examples = cot_examples
         self.refelct_examples = reflect_examples
         self.reflections: List[str] = []
+        if verbose:
+            _LOGGER.setLevel(logging.INFO)
 
-        if self_reflect_llm is None:
-            self.self_reflect_llm = OpenAILLM()
-        if action_agent is None:
-            self.action_agent = OpenAILLM()
+        if isinstance(self_reflect_llm, LLM) and not isinstance(action_agent, LLM):
+            raise ValueError(
+                "If self_reflect_llm is an LLM, then action_agent must also be an LLM."
+            )
+        if isinstance(self_reflect_llm, LMM) and isinstance(action_agent, LLM):
+            raise ValueError(
+                "If self_reflect_llm is an LMM, then action_agent must also be an agent or LMM."
+            )
+
+        self.self_reflect_llm = (
+            OpenAILLM() if self_reflect_llm is None else self_reflect_llm
+        )
+        self.action_agent = OpenAILLM() if action_agent is None else action_agent
 
-    def __call__(self, input: Union[List[Dict[str, str]], str]) -> str:
+    def __call__(
+        self,
+        input: Union[str, List[Dict[str, str]]],
+        image: Optional[Union[str, Path]] = None,
+    ) -> str:
         if isinstance(input, str):
             input = [{"role": "user", "content": input}]
-        return self.chat(input)
+        return self.chat(input, image)
 
-    def chat(self, chat: List[Dict[str, str]]) -> str:
+    def chat(
+        self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
+    ) -> str:
         if len(chat) == 0 or chat[0]["role"] != "user":
             raise ValueError(
-                f"Invalid chat. Should start with user and then assistant and contain at least one entry {chat}"
+                f"Invalid chat. Should start with user and alternate between user "
+                f"and assistant and contain at least one entry {chat}"
             )
+        if image is not None and isinstance(self.action_agent, LLM):
+            raise ValueError(
+                "If image is provided, then action_agent must be an agent or LMM."
+            )
+
         question = chat[0]["content"]
         if len(chat) == 1:
-            results = self._step(question)
+            results = self._step(question, image=image)
             self.last_scratchpad = results["scratchpad"]
             return results["action_arg"]
@@ -95,19 +150,29 @@ def chat(self, chat: List[Dict[str, str]]) -> str:
         else:
             self.last_scratchpad += "Answer is INCORRECT"
             chat_context = "The previous conversation was:\n" + chat_str
-            reflections = self.reflect(question, chat_context, self.last_scratchpad)
-            results = self._step(question, reflections)
+            reflections = self.reflect(
+                question, chat_context, self.last_scratchpad, image
+            )
+            _LOGGER.info(f" {reflections}")
+            results = self._step(question, reflections, image=image)
             self.last_scratchpad = results["scratchpad"]
             return results["action_arg"]
 
-    def _step(self, question: str, reflections: str = "") -> Dict[str, str]:
+    def _step(
+        self,
+        question: str,
+        reflections: str = "",
+        image: Optional[Union[str, Path]] = None,
+    ) -> Dict[str, str]:
         # Think
         scratchpad = "\nThought:"
-        scratchpad += " " + self.prompt_agent(question, reflections, scratchpad)
+        scratchpad += " " + self.prompt_agent(question, reflections, scratchpad, image)
+        _LOGGER.info(f" {scratchpad}")
 
         # Act
         scratchpad += "\nAction:"
-        action = self.prompt_agent(question, reflections, scratchpad)
+        action = self.prompt_agent(question, reflections, scratchpad, image)
+        _LOGGER.info(f" {action}")
         scratchpad += " " + action
         action_type, argument = parse_action(action)
         return {
@@ -116,23 +181,55 @@
             "scratchpad": scratchpad,
             "action_type": action_type,
             "action_arg": argument,
         }
 
-    def reflect(self, question: str, context: str, scratchpad: str) -> str:
-        self.reflections += [self.prompt_reflection(question, context, scratchpad)]
+    def reflect(
+        self,
+        question: str,
+        context: str,
+        scratchpad: str,
+        image: Optional[Union[str, Path]],
+    ) -> str:
+        self.reflections += [
+            self.prompt_reflection(question, context, scratchpad, image)
+        ]
         return format_reflections(self.reflections)
 
-    def prompt_agent(self, question: str, reflections: str, scratchpad: str) -> str:
+    def prompt_agent(
+        self,
+        question: str,
+        reflections: str,
+        scratchpad: str,
+        image: Optional[Union[str, Path]] = None,
+    ) -> str:
+        if isinstance(self.action_agent, LLM):
+            return format_step(
+                self.action_agent(
+                    self._build_agent_prompt(question, reflections, scratchpad)
+                )
+            )
         return format_step(
             self.action_agent(
-                self._build_agent_prompt(question, reflections, scratchpad)
+                self._build_agent_prompt(question, reflections, scratchpad),
+                image=image,
             )
         )
 
     def prompt_reflection(
-        self, question: str, context: str = "", scratchpad: str = ""
+        self,
+        question: str,
+        context: str = "",
+        scratchpad: str = "",
+        image: Optional[Union[str, Path]] = None,
     ) -> str:
+        if isinstance(self.self_reflect_llm, LLM):
+            return format_step(
+                self.self_reflect_llm(
+                    self._build_reflect_prompt(question, context, scratchpad)
+                )
+            )
         return format_step(
             self.self_reflect_llm(
-                self._build_reflect_prompt(question, context, scratchpad)
+                self._build_reflect_prompt(question, context, scratchpad),
+                image=image,
             )
         )
diff --git a/vision_agent/data/data.py b/vision_agent/data/data.py
index 4548b42b..6b51488b 100644
--- a/vision_agent/data/data.py
+++ b/vision_agent/data/data.py
@@ -22,10 +22,12 @@ class DataStore:
     r"""A class to store and manage image data along with its generated metadata from an LMM."""
 
     def __init__(self, df: pd.DataFrame):
-        r"""Initializes the DataStore with a DataFrame containing image paths and image IDs. If the image IDs are not present, they are generated using UUID4. The DataFrame must contain an 'image_paths' column.
+        r"""Initializes the DataStore with a DataFrame containing image paths and image
+        IDs. If the image IDs are not present, they are generated using UUID4. The
+        DataFrame must contain an 'image_paths' column.
 
         Args:
-            df (pd.DataFrame): The DataFrame containing "image_paths" and "image_id" columns.
+            df: The DataFrame containing "image_paths" and "image_id" columns.
         """
         self.df = df
         self.lmm: Optional[LMM] = None
@@ -47,12 +49,14 @@ def add_lmm(self, lmm: LMM) -> Self:
     def add_column(
         self, name: str, prompt: str, func: Optional[Callable[[str], str]] = None
     ) -> Self:
-        r"""Adds a new column to the DataFrame containing the generated metadata from the LMM.
+        r"""Adds a new column to the DataFrame containing the generated metadata from
+        the LMM.
 
         Args:
-            name (str): The name of the column to be added.
-            prompt (str): The prompt to be used to generate the metadata.
-            func (Optional[Callable[[Any], Any]]): A Python function to be applied on the output of `lmm.generate`. Defaults to None.
+            name: The name of the column to be added.
+            prompt: The prompt to be used to generate the metadata.
+            func: A Python function to be applied on the output of `lmm.generate`.
+                Defaults to None.
         """
         if self.lmm is None:
             raise ValueError("LMM not set yet")
@@ -67,10 +71,11 @@ def add_column(
         return self
 
     def build_index(self, target_col: str) -> Self:
-        r"""This will generate embeddings for the `target_col` and build a searchable index over them, so next time you run search it will search over this index.
+        r"""This will generate embeddings for the `target_col` and build a searchable
+        index over them, so next time you run search it will search over this index.
 
         Args:
-            target_col (str): The column name containing the data to be indexed."""
+            target_col: The column name containing the data to be indexed."""
         if self.emb is None:
             raise ValueError("Embedder not set yet")
 
@@ -92,11 +97,12 @@ def get_embeddings(self) -> npt.NDArray[np.float32]:
         )
 
     def search(self, query: str, top_k: int = 10) -> List[Dict]:
-        r"""Searches the index for the most similar images to the query and returns the top_k results.
+        r"""Searches the index for the most similar images to the query and returns
+        the top_k results.
 
         Args:
-            query (str): The query to search for.
-            top_k (int, optional): The number of results to return. Defaults to 10."""
+            query: The query to search for.
+            top_k: The number of results to return.
+                Defaults to 10."""
         if self.index is None:
             raise ValueError("Index not built yet")
         if self.emb is None:
diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py
index b6c20e27..2023000c 100644
--- a/vision_agent/lmm/lmm.py
+++ b/vision_agent/lmm/lmm.py
@@ -17,8 +17,6 @@
     ImageTool,
 )
 
-logging.basicConfig(level=logging.INFO)
-
 _LOGGER = logging.getLogger(__name__)
 
 _LLAVA_ENDPOINT = "https://svtswgdnleslqcsjvilau4p6u40jwrkn.lambda-url.us-east-2.on.aws"
@@ -35,6 +33,20 @@ class LMM(ABC):
     def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str:
         pass
 
+    @abstractmethod
+    def chat(
+        self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
+    ) -> str:
+        pass
+
+    @abstractmethod
+    def __call__(
+        self,
+        input: Union[str, List[Dict[str, str]]],
+        image: Optional[Union[str, Path]] = None,
+    ) -> str:
+        pass
+
 
 class LLaVALMM(LMM):
     r"""An LMM class for the LLaVA-1.6 34B model."""
@@ -42,6 +54,20 @@
     def __init__(self, model_name: str):
        self.model_name = model_name
 
+    def __call__(
+        self,
+        input: Union[str, List[Dict[str, str]]],
+        image: Optional[Union[str, Path]] = None,
+    ) -> str:
+        if isinstance(input, str):
+            return self.generate(input, image)
+        return self.chat(input, image)
+
+    def chat(
+        self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
+    ) -> str:
+        raise NotImplementedError("Chat not supported for LLaVA")
+
     def generate(
         self,
         prompt: str,
@@ -71,10 +97,50 @@ def generate(
 class OpenAILMM(LMM):
     r"""An LMM class for the OpenAI GPT-4 Vision model."""
 
-    def __init__(self, model_name: str = "gpt-4-vision-preview"):
+    def __init__(
+        self, model_name: str = "gpt-4-vision-preview", max_tokens: int = 1024
+    ):
         self.model_name = model_name
+        self.max_tokens = max_tokens
         self.client = OpenAI()
 
+    def __call__(
+        self,
+        input: Union[str, List[Dict[str, str]]],
+        image: Optional[Union[str, Path]] = None,
+    ) -> str:
+        if isinstance(input, str):
+            return self.generate(input, image)
+        return self.chat(input, image)
+
+    def chat(
+        self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None
+    ) -> str:
+        fixed_chat = []
+        for c in chat:
+            fixed_c = {"role": c["role"]}
+            fixed_c["content"] = [{"type": "text", "text": c["content"]}]  # type: ignore
+            fixed_chat.append(fixed_c)
+
+        if image:
+            extension = Path(image).suffix.lstrip(".")
+            encoded_image = encode_image(image)
+            fixed_chat[0]["content"].append(  # type: ignore
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/{extension};base64,{encoded_image}",
+                        "detail": "low",
+                    },
+                },
+            )
+
+        response = self.client.chat.completions.create(
+            model=self.model_name, messages=fixed_chat, max_tokens=self.max_tokens  # type: ignore
+        )
+
+        return cast(str, response.choices[0].message.content)
+
     def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str:
         message: List[Dict[str, Any]] = [
             {
@@ -98,7 +164,7 @@
         )
 
         response = self.client.chat.completions.create(
-            model=self.model_name, messages=message  # type: ignore
+            model=self.model_name, messages=message, max_tokens=self.max_tokens  # type: ignore
         )
 
         return cast(str, response.choices[0].message.content)

From a85235470474becfa0069c8caba2cc68ea54bb07 Mon Sep 17 00:00:00 2001
From: Dillon Laird
Date: Sat, 16 Mar 2024 19:56:20 -0700
Subject: [PATCH 2/3] formatting fix

---
 vision_agent/agent/agent.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/vision_agent/agent/agent.py b/vision_agent/agent/agent.py
index 05421e41..5054f170 100644
--- a/vision_agent/agent/agent.py
+++ b/vision_agent/agent/agent.py
@@ -6,6 +6,8 @@
 class Agent(ABC):
     @abstractmethod
     def __call__(
-        self, input: Union[List[Dict[str, str]], str], image: Optional[Union[str, Path]] = None
+        self,
+        input: Union[List[Dict[str, str]], str],
+        image: Optional[Union[str, Path]] = None,
     ) -> str:
         pass

From 8ad60f6507e806bed7914b298310c6d10dd6eedb Mon Sep 17 00:00:00 2001
From: Dillon Laird
Date: Sun, 17 Mar 2024 11:17:35 -0700
Subject: [PATCH 3/3] added more comments, fixed regex

---
 vision_agent/agent/reflexion.py | 48 ++++++++++++++++++++++----------
 1 file changed, 34 insertions(+), 14 deletions(-)

diff --git a/vision_agent/agent/reflexion.py b/vision_agent/agent/reflexion.py
index 8c89f024..4d0d05c3 100644
--- a/vision_agent/agent/reflexion.py
+++ b/vision_agent/agent/reflexion.py
@@ -26,8 +26,10 @@ def format_step(step: str) -> str:
 
 
 def parse_action(input: str) -> Tuple[str, str]:
-    pattern = r"^(\w+)\[(.+)\]$"
-    match = re.match(pattern, input)
+    # Make the pattern slightly less strict; the LMMs are not as good at following
+    # instructions, so they often fail on the original regex.
+    pattern = r"(\w+)\[(.+)\]"
+    match = re.search(pattern, input)
 
     if match:
         action_type = match.group(1)
@@ -60,7 +62,10 @@ class Reflexion(Agent):
     hotpotqa folder. There are several differences between this implementation and the
     original one. Because we do not have instant feedback on whether or not the agent
     was correct, we use user feedback to determine if the agent was correct. The user
-    feedback is evaluated by the self_reflect_llm with a new prompt.
+    feedback is evaluated by the self_reflect_model with a new prompt. We also expand
+    Reflexion to include the ability to use an image as input to the action_agent and the
+    self_reflect_model. Using Reflexion with LMMs may not work well: if it gets it wrong
+    the first time, chances are it can't actually see the thing you want it to see.
 
     Examples::
         >>> from vision_agent.agent import Reflexion
@@ -76,6 +81,21 @@ class Reflexion(Agent):
         >>> ])
         >>> print(resp)
         >>> "6"
+        >>> agent = Reflexion(
+        >>>     self_reflect_model=va.lmm.OpenAILMM(),
+        >>>     action_agent=va.lmm.OpenAILMM()
+        >>> )
+        >>> question = "How many hearts are in this image?"
+        >>> resp = agent(question, image="cards.png")
+        >>> print(resp)
+        >>> "6"
+        >>> resp = agent([
+        >>>     {"role": "user", "content": question},
+        >>>     {"role": "assistant", "content": resp},
+        >>>     {"role": "user", "content": "No, please count the hearts on the bottom card."}
+        >>> ], image="cards.png")
+        >>> print(resp)
+        >>> "4"
     """
 
     def __init__(
@@ -85,7 +105,7 @@ class Reflexion(Agent):
         agent_prompt: str = COT_AGENT_REFLECT_INSTRUCTION,
         reflect_prompt: str = COT_REFLECT_INSTRUCTION,
         finsh_prompt: str = CHECK_FINSH,
-        self_reflect_llm: Optional[Union[LLM, LMM]] = None,
+        self_reflect_model: Optional[Union[LLM, LMM]] = None,
         action_agent: Optional[Union[Agent, LLM, LMM]] = None,
         verbose: bool = False,
     ):
@@ -98,17 +118,17 @@ def __init__(
         if verbose:
             _LOGGER.setLevel(logging.INFO)
 
-        if isinstance(self_reflect_llm, LLM) and not isinstance(action_agent, LLM):
+        if isinstance(self_reflect_model, LLM) and not isinstance(action_agent, LLM):
             raise ValueError(
-                "If self_reflect_llm is an LLM, then action_agent must also be an LLM."
+                "If self_reflect_model is an LLM, then action_agent must also be an LLM."
             )
-        if isinstance(self_reflect_llm, LMM) and isinstance(action_agent, LLM):
+        if isinstance(self_reflect_model, LMM) and isinstance(action_agent, LLM):
             raise ValueError(
-                "If self_reflect_llm is an LMM, then action_agent must also be an agent or LMM."
+                "If self_reflect_model is an LMM, then action_agent must also be an agent or LMM."
             )
 
-        self.self_reflect_llm = (
-            OpenAILLM() if self_reflect_llm is None else self_reflect_llm
+        self.self_reflect_model = (
+            OpenAILLM() if self_reflect_model is None else self_reflect_model
         )
         self.action_agent = OpenAILLM() if action_agent is None else action_agent
 
@@ -146,7 +166,7 @@ def chat(
         self.last_scratchpad += "\nObservation: "
         if is_correct:
             self.last_scratchpad += "Answer is CORRECT"
-            return self.self_reflect_llm(chat)
+            return self.self_reflect_model(chat)
         else:
             self.last_scratchpad += "Answer is INCORRECT"
             chat_context = "The previous conversation was:\n" + chat_str
@@ -220,14 +240,14 @@ def prompt_reflection(
         scratchpad: str = "",
         image: Optional[Union[str, Path]] = None,
     ) -> str:
-        if isinstance(self.self_reflect_llm, LLM):
+        if isinstance(self.self_reflect_model, LLM):
             return format_step(
-                self.self_reflect_llm(
+                self.self_reflect_model(
                     self._build_reflect_prompt(question, context, scratchpad)
                 )
            )
         return format_step(
-            self.self_reflect_llm(
+            self.self_reflect_model(
                 self._build_reflect_prompt(question, context, scratchpad),
                 image=image,
             )
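
A note on the regex change in PATCH 3/3: dropping the ^...$ anchors and switching
from re.match to re.search lets parse_action recover an action that is embedded in
a longer model response, instead of requiring the action to be the entire output.
Below is a minimal standalone sketch of the difference; the sample response string
is hypothetical:

    import re

    response = "I am confident in the count. Finish[18]"

    # The old anchored pattern had to match the whole string, so this fails.
    assert re.match(r"^(\w+)\[(.+)\]$", response) is None

    # The relaxed pattern finds the action anywhere in the response.
    match = re.search(r"(\w+)\[(.+)\]", response)
    assert match is not None
    print(match.group(1), match.group(2))  # Finish 18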
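
For the OpenAILMM chat path added in PATCH 1/3: the image is base64-encoded and
attached to the first message as an image_url content part with detail set to
"low", and responses are capped at max_tokens. A usage sketch follows; it assumes
OpenAILMM is importable as below, that OPENAI_API_KEY is set in the environment,
and that a local cards.png exists (the file name and prompts are illustrative
only):

    from vision_agent.lmm import OpenAILMM

    lmm = OpenAILMM()  # defaults to gpt-4-vision-preview with max_tokens=1024

    # A plain string routes through generate() via __call__.
    print(lmm("Describe a standard deck of playing cards."))

    # A chat list plus an image routes through chat(); only the first
    # message carries the image.
    resp = lmm.chat(
        [{"role": "user", "content": "How many hearts are in this image?"}],
        image="cards.png",
    )
    print(resp)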