From 680554a393304c828f1a7233b6978f3e8ba1df3d Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Sat, 16 Mar 2024 15:06:52 -0700 Subject: [PATCH] added images for reflexion --- tests/test_lmm.py | 39 +++++++++ vision_agent/agent/agent.py | 7 +- vision_agent/agent/reflexion.py | 143 +++++++++++++++++++++++++++----- vision_agent/data/data.py | 28 ++++--- vision_agent/lmm/lmm.py | 74 ++++++++++++++++- 5 files changed, 251 insertions(+), 40 deletions(-) diff --git a/tests/test_lmm.py b/tests/test_lmm.py index 0dff8466..c1390726 100644 --- a/tests/test_lmm.py +++ b/tests/test_lmm.py @@ -33,6 +33,45 @@ def test_generate_with_mock(openai_lmm_mock): # noqa: F811 ) +@pytest.mark.parametrize( + "openai_lmm_mock", ["mocked response"], indirect=["openai_lmm_mock"] +) +def test_chat_with_mock(openai_lmm_mock): # noqa: F811 + lmm = OpenAILMM() + response = lmm.chat([{"role": "user", "content": "test prompt"}]) + assert response == "mocked response" + assert ( + openai_lmm_mock.chat.completions.create.call_args.kwargs["messages"][0][ + "content" + ][0]["text"] + == "test prompt" + ) + + +@pytest.mark.parametrize( + "openai_lmm_mock", ["mocked response"], indirect=["openai_lmm_mock"] +) +def test_call_with_mock(openai_lmm_mock): # noqa: F811 + lmm = OpenAILMM() + response = lmm("test prompt") + assert response == "mocked response" + assert ( + openai_lmm_mock.chat.completions.create.call_args.kwargs["messages"][0][ + "content" + ][0]["text"] + == "test prompt" + ) + + response = lmm([{"role": "user", "content": "test prompt"}]) + assert response == "mocked response" + assert ( + openai_lmm_mock.chat.completions.create.call_args.kwargs["messages"][0][ + "content" + ][0]["text"] + == "test prompt" + ) + + @pytest.mark.parametrize( "openai_lmm_mock", ['{"Parameters": {"prompt": "cat"}}'], diff --git a/vision_agent/agent/agent.py b/vision_agent/agent/agent.py index 7de17496..05421e41 100644 --- a/vision_agent/agent/agent.py +++ b/vision_agent/agent/agent.py @@ -1,8 +1,11 @@ from abc import ABC, abstractmethod -from typing import Dict, List, Union +from pathlib import Path +from typing import Dict, List, Optional, Union class Agent(ABC): @abstractmethod - def __call__(self, input: Union[List[Dict[str, str]], str]) -> str: + def __call__( + self, input: Union[List[Dict[str, str]], str], image: Optional[Union[str, Path]] = None + ) -> str: pass diff --git a/vision_agent/agent/reflexion.py b/vision_agent/agent/reflexion.py index dc9b3bce..8c89f024 100644 --- a/vision_agent/agent/reflexion.py +++ b/vision_agent/agent/reflexion.py @@ -1,7 +1,10 @@ +import logging import re +import sys +from pathlib import Path from typing import Dict, List, Optional, Tuple, Union -from vision_agent import LLM, OpenAILLM +from vision_agent import LLM, LMM, OpenAILLM from .agent import Agent from .reflexion_prompts import ( @@ -13,6 +16,10 @@ REFLECTION_HEADER, ) +logging.basicConfig(stream=sys.stdout) + +_LOGGER = logging.getLogger(__name__) + def format_step(step: str) -> str: return step.strip("\n").strip().replace("\n", "") @@ -27,6 +34,7 @@ def parse_action(input: str) -> Tuple[str, str]: argument = match.group(2) return action_type, argument + _LOGGER.error(f"Invalid action: {input}") raise ValueError(f"Invalid action: {input}") @@ -47,6 +55,29 @@ def format_chat(chat: List[Dict[str, str]]) -> str: class Reflexion(Agent): + r"""This is an implementation of the Reflexion paper https://arxiv.org/abs/2303.11366 + based on the original implementation https://github.com/noahshinn/reflexion in the + hotpotqa folder. 
There are several differences between this implementation and the + original one. Because we do not have instant feedback on whether or not the agent + was correct, we use user feedback to determine if the agent was correct. The user + feedback is evaluated by the self_reflect_llm with a new prompt. + + Examples:: + >>> from vision_agent.agent import Reflexion + >>> agent = Reflexion() + >>> question = "How many tires does a truck have?" + >>> resp = agent(question) + >>> print(resp) + >>> "18" + >>> resp = agent([ + >>> {"role": "user", "content": question}, + >>> {"role": "assistant", "content": resp}, + >>> {"role": "user", "content": "No I mean those regular trucks but where the back tires are double."} + >>> ]) + >>> print(resp) + >>> "6" + """ + def __init__( self, cot_examples: str = COTQA_SIMPLE6, @@ -54,8 +85,9 @@ def __init__( agent_prompt: str = COT_AGENT_REFLECT_INSTRUCTION, reflect_prompt: str = COT_REFLECT_INSTRUCTION, finsh_prompt: str = CHECK_FINSH, - self_reflect_llm: Optional[LLM] = None, - action_agent: Optional[Union[Agent, LLM]] = None, + self_reflect_llm: Optional[Union[LLM, LMM]] = None, + action_agent: Optional[Union[Agent, LLM, LMM]] = None, + verbose: bool = False, ): self.agent_prompt = agent_prompt self.reflect_prompt = reflect_prompt @@ -63,25 +95,48 @@ def __init__( self.cot_examples = cot_examples self.refelct_examples = reflect_examples self.reflections: List[str] = [] + if verbose: + _LOGGER.setLevel(logging.INFO) - if self_reflect_llm is None: - self.self_reflect_llm = OpenAILLM() - if action_agent is None: - self.action_agent = OpenAILLM() + if isinstance(self_reflect_llm, LLM) and not isinstance(action_agent, LLM): + raise ValueError( + "If self_reflect_llm is an LLM, then action_agent must also be an LLM." + ) + if isinstance(self_reflect_llm, LMM) and isinstance(action_agent, LLM): + raise ValueError( + "If self_reflect_llm is an LMM, then action_agent must also be an agent or LMM." + ) + + self.self_reflect_llm = ( + OpenAILLM() if self_reflect_llm is None else self_reflect_llm + ) + self.action_agent = OpenAILLM() if action_agent is None else action_agent - def __call__(self, input: Union[List[Dict[str, str]], str]) -> str: + def __call__( + self, + input: Union[str, List[Dict[str, str]]], + image: Optional[Union[str, Path]] = None, + ) -> str: if isinstance(input, str): input = [{"role": "user", "content": input}] - return self.chat(input) + return self.chat(input, image) - def chat(self, chat: List[Dict[str, str]]) -> str: + def chat( + self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None + ) -> str: if len(chat) == 0 or chat[0]["role"] != "user": raise ValueError( - f"Invalid chat. Should start with user and then assistant and contain at least one entry {chat}" + f"Invalid chat. Should start with user and alternate between user" + f"and assistant and contain at least one entry {chat}" ) + if image is not None and isinstance(self.action_agent, LLM): + raise ValueError( + "If image is provided, then action_agent must be an agent or LMM." 
+ ) + question = chat[0]["content"] if len(chat) == 1: - results = self._step(question) + results = self._step(question, image=image) self.last_scratchpad = results["scratchpad"] return results["action_arg"] @@ -95,19 +150,29 @@ def chat(self, chat: List[Dict[str, str]]) -> str: else: self.last_scratchpad += "Answer is INCORRECT" chat_context = "The previous conversation was:\n" + chat_str - reflections = self.reflect(question, chat_context, self.last_scratchpad) - results = self._step(question, reflections) + reflections = self.reflect( + question, chat_context, self.last_scratchpad, image + ) + _LOGGER.info(f" {reflections}") + results = self._step(question, reflections, image=image) self.last_scratchpad = results["scratchpad"] return results["action_arg"] - def _step(self, question: str, reflections: str = "") -> Dict[str, str]: + def _step( + self, + question: str, + reflections: str = "", + image: Optional[Union[str, Path]] = None, + ) -> Dict[str, str]: # Think scratchpad = "\nThought:" - scratchpad += " " + self.prompt_agent(question, reflections, scratchpad) + scratchpad += " " + self.prompt_agent(question, reflections, scratchpad, image) + _LOGGER.info(f" {scratchpad}") # Act scratchpad += "\nAction:" - action = self.prompt_agent(question, reflections, scratchpad) + action = self.prompt_agent(question, reflections, scratchpad, image) + _LOGGER.info(f" {action}") scratchpad += " " + action action_type, argument = parse_action(action) return { @@ -116,23 +181,55 @@ def _step(self, question: str, reflections: str = "") -> Dict[str, str]: "action_arg": argument, } - def reflect(self, question: str, context: str, scratchpad: str) -> str: - self.reflections += [self.prompt_reflection(question, context, scratchpad)] + def reflect( + self, + question: str, + context: str, + scratchpad: str, + image: Optional[Union[str, Path]], + ) -> str: + self.reflections += [ + self.prompt_reflection(question, context, scratchpad, image) + ] return format_reflections(self.reflections) - def prompt_agent(self, question: str, reflections: str, scratchpad: str) -> str: + def prompt_agent( + self, + question: str, + reflections: str, + scratchpad: str, + image: Optional[Union[str, Path]] = None, + ) -> str: + if isinstance(self.action_agent, LLM): + return format_step( + self.action_agent( + self._build_agent_prompt(question, reflections, scratchpad) + ) + ) return format_step( self.action_agent( - self._build_agent_prompt(question, reflections, scratchpad) + self._build_agent_prompt(question, reflections, scratchpad), + image=image, ) ) def prompt_reflection( - self, question: str, context: str = "", scratchpad: str = "" + self, + question: str, + context: str = "", + scratchpad: str = "", + image: Optional[Union[str, Path]] = None, ) -> str: + if isinstance(self.self_reflect_llm, LLM): + return format_step( + self.self_reflect_llm( + self._build_reflect_prompt(question, context, scratchpad) + ) + ) return format_step( self.self_reflect_llm( - self._build_reflect_prompt(question, context, scratchpad) + self._build_reflect_prompt(question, context, scratchpad), + image=image, ) ) diff --git a/vision_agent/data/data.py b/vision_agent/data/data.py index 4548b42b..6b51488b 100644 --- a/vision_agent/data/data.py +++ b/vision_agent/data/data.py @@ -22,10 +22,12 @@ class DataStore: r"""A class to store and manage image data along with its generated metadata from an LMM.""" def __init__(self, df: pd.DataFrame): - r"""Initializes the DataStore with a DataFrame containing image paths and image IDs. 
If the image IDs are not present, they are generated using UUID4. The DataFrame must contain an 'image_paths' column. + r"""Initializes the DataStore with a DataFrame containing image paths and image + IDs. If the image IDs are not present, they are generated using UUID4. The + DataFrame must contain an 'image_paths' column. Args: - df (pd.DataFrame): The DataFrame containing "image_paths" and "image_id" columns. + df: The DataFrame containing "image_paths" and "image_id" columns. """ self.df = df self.lmm: Optional[LMM] = None @@ -47,12 +49,14 @@ def add_lmm(self, lmm: LMM) -> Self: def add_column( self, name: str, prompt: str, func: Optional[Callable[[str], str]] = None ) -> Self: - r"""Adds a new column to the DataFrame containing the generated metadata from the LMM. + r"""Adds a new column to the DataFrame containing the generated metadata from + the LMM. Args: - name (str): The name of the column to be added. - prompt (str): The prompt to be used to generate the metadata. - func (Optional[Callable[[Any], Any]]): A Python function to be applied on the output of `lmm.generate`. Defaults to None. + name: The name of the column to be added. + prompt: The prompt to be used to generate the metadata. + func: A Python function to be applied on the output of `lmm.generate`. + Defaults to None. """ if self.lmm is None: raise ValueError("LMM not set yet") @@ -67,10 +71,11 @@ def add_column( return self def build_index(self, target_col: str) -> Self: - r"""This will generate embeddings for the `target_col` and build a searchable index over them, so next time you run search it will search over this index. + r"""This will generate embeddings for the `target_col` and build a searchable + index over them, so next time you run search it will search over this index. Args: - target_col (str): The column name containing the data to be indexed.""" + target_col: The column name containing the data to be indexed.""" if self.emb is None: raise ValueError("Embedder not set yet") @@ -92,11 +97,12 @@ def get_embeddings(self) -> npt.NDArray[np.float32]: ) def search(self, query: str, top_k: int = 10) -> List[Dict]: - r"""Searches the index for the most similar images to the query and returns the top_k results. + r"""Searches the index for the most similar images to the query and returns + the top_k results. Args: - query (str): The query to search for. - top_k (int, optional): The number of results to return. Defaults to 10.""" + query: The query to search for. + top_k: The number of results to return. 
Defaults to 10.""" if self.index is None: raise ValueError("Index not built yet") if self.emb is None: diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py index b6c20e27..2023000c 100644 --- a/vision_agent/lmm/lmm.py +++ b/vision_agent/lmm/lmm.py @@ -17,8 +17,6 @@ ImageTool, ) -logging.basicConfig(level=logging.INFO) - _LOGGER = logging.getLogger(__name__) _LLAVA_ENDPOINT = "https://svtswgdnleslqcsjvilau4p6u40jwrkn.lambda-url.us-east-2.on.aws" @@ -35,6 +33,20 @@ class LMM(ABC): def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str: pass + @abstractmethod + def chat( + self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None + ) -> str: + pass + + @abstractmethod + def __call__( + self, + input: Union[str, List[Dict[str, str]]], + image: Optional[Union[str, Path]] = None, + ) -> str: + pass + class LLaVALMM(LMM): r"""An LMM class for the LLaVA-1.6 34B model.""" @@ -42,6 +54,20 @@ class LLaVALMM(LMM): def __init__(self, model_name: str): self.model_name = model_name + def __call__( + self, + input: Union[str, List[Dict[str, str]]], + image: Optional[Union[str, Path]] = None, + ) -> str: + if isinstance(input, str): + return self.generate(input, image) + return self.chat(input, image) + + def chat( + self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None + ) -> str: + raise NotImplementedError("Chat not supported for LLaVA") + def generate( self, prompt: str, @@ -71,10 +97,50 @@ def generate( class OpenAILMM(LMM): r"""An LMM class for the OpenAI GPT-4 Vision model.""" - def __init__(self, model_name: str = "gpt-4-vision-preview"): + def __init__( + self, model_name: str = "gpt-4-vision-preview", max_tokens: int = 1024 + ): self.model_name = model_name + self.max_tokens = max_tokens self.client = OpenAI() + def __call__( + self, + input: Union[str, List[Dict[str, str]]], + image: Optional[Union[str, Path]] = None, + ) -> str: + if isinstance(input, str): + return self.generate(input, image) + return self.chat(input, image) + + def chat( + self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None + ) -> str: + fixed_chat = [] + for c in chat: + fixed_c = {"role": c["role"]} + fixed_c["content"] = [{"type": "text", "text": c["content"]}] # type: ignore + fixed_chat.append(fixed_c) + + if image: + extension = Path(image).suffix + encoded_image = encode_image(image) + fixed_chat[0]["content"].append( # type: ignore + { + "type": "image_url", + "image_url": { + "url": f"data:image/{extension};base64,{encoded_image}", + "detail": "low", + }, + }, + ) + + response = self.client.chat.completions.create( + model=self.model_name, messages=fixed_chat, max_tokens=self.max_tokens # type: ignore + ) + + return cast(str, response.choices[0].message.content) + def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str: message: List[Dict[str, Any]] = [ { @@ -98,7 +164,7 @@ def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str ) response = self.client.chat.completions.create( - model=self.model_name, messages=message # type: ignore + model=self.model_name, messages=message, max_tokens=self.max_tokens # type: ignore ) return cast(str, response.choices[0].message.content)
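
Usage sketch for the image-aware Reflexion flow added above (not part of the patch). The
`from vision_agent.agent import Reflexion` import mirrors the docstring example in the diff;
the OpenAILMM import path and the truck.jpg path are assumptions for illustration.

    from vision_agent.agent import Reflexion
    from vision_agent.lmm.lmm import OpenAILMM  # assumed import path

    # Both models are LMMs, so an image can be threaded through
    # __call__ -> chat -> _step -> prompt_agent / prompt_reflection.
    agent = Reflexion(
        self_reflect_llm=OpenAILMM(),
        action_agent=OpenAILMM(),
        verbose=True,  # enables the new INFO-level scratchpad/reflection logging
    )
    resp = agent("How many tires does this truck have?", image="truck.jpg")
    print(resp)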
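
Similarly, a sketch of the new OpenAILMM.chat and __call__ paths in vision_agent/lmm/lmm.py.
The import path, the image path, and a configured OPENAI_API_KEY for the OpenAI client are
assumptions here.

    from vision_agent.lmm.lmm import OpenAILMM  # assumed import path

    lmm = OpenAILMM(max_tokens=1024)

    # A plain string is routed to generate(); a message list is routed to chat().
    print(lmm("Describe the scene.", image="truck.jpg"))

    # chat() attaches the optional image to the first message as a base64 data URL
    # with detail="low".
    print(lmm.chat([{"role": "user", "content": "How many tires are visible?"}],
                   image="truck.jpg"))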