Commit

Pool Demo (#53)
* visualized output/reflection to handle extract_frames_

* remove ipdb

* added json mode for lmm, upgraded gpt-4-turbo

* updated reflection prompt

* refactor to make function simpler

* updated reflection prompt, add tool usage doc

* fixed format issue

* fixed type issue

* fixed test case
dillonalaird authored Apr 16, 2024
1 parent f871d2f commit 0c6c448
Showing 5 changed files with 136 additions and 60 deletions.
8 changes: 4 additions & 4 deletions tests/test_llm.py
@@ -18,7 +18,7 @@ def test_generate_with_mock(openai_llm_mock): # noqa: F811
response = llm.generate("test prompt")
assert response == "mocked response"
openai_llm_mock.chat.completions.create.assert_called_once_with(
model="gpt-4-turbo-preview",
model="gpt-4-turbo",
messages=[{"role": "user", "content": "test prompt"}],
)

@@ -31,7 +31,7 @@ def test_chat_with_mock(openai_llm_mock): # noqa: F811
response = llm.chat([{"role": "user", "content": "test prompt"}])
assert response == "mocked response"
openai_llm_mock.chat.completions.create.assert_called_once_with(
model="gpt-4-turbo-preview",
model="gpt-4-turbo",
messages=[{"role": "user", "content": "test prompt"}],
)

@@ -44,14 +44,14 @@ def test_call_with_mock(openai_llm_mock): # noqa: F811
response = llm("test prompt")
assert response == "mocked response"
openai_llm_mock.chat.completions.create.assert_called_once_with(
model="gpt-4-turbo-preview",
model="gpt-4-turbo",
messages=[{"role": "user", "content": "test prompt"}],
)

response = llm([{"role": "user", "content": "test prompt"}])
assert response == "mocked response"
openai_llm_mock.chat.completions.create.assert_called_with(
model="gpt-4-turbo-preview",
model="gpt-4-turbo",
messages=[{"role": "user", "content": "test prompt"}],
)

159 changes: 109 additions & 50 deletions vision_agent/agent/vision_agent.py
@@ -37,10 +37,10 @@

def parse_json(s: str) -> Any:
s = (
s.replace(": true", ": True")
.replace(": false", ": False")
.replace(":true", ": True")
.replace(":false", ": False")
s.replace(": True", ": true")
.replace(": False", ": false")
.replace(":True", ": true")
.replace(":False", ": false")
.replace("```", "")
.strip()
)
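
For context, the reordered replace calls above turn Python-style booleans from the model into JSON-style ones before parsing. A minimal sketch of the intended round trip, assuming parse_json ends with a json.loads call (the tail of the function is outside this hunk) and using a hypothetical model output:

import json

raw = '```\n{"Finish": True, "Reflection": "Counts look correct."}\n```'
normalized = (
    raw.replace(": True", ": true")
    .replace(": False", ": false")
    .replace(":True", ": true")
    .replace(":False", ": false")
    .replace("```", "")
    .strip()
)
print(json.loads(normalized))  # {'Finish': True, 'Reflection': 'Counts look correct.'}
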
@@ -62,6 +62,19 @@ def format_tools(tools: Dict[int, Any]) -> str:
return tool_str


def format_tool_usage(tools: Dict[int, Any], tool_result: List[Dict]) -> str:
usage = []
name_to_usage = {v["name"]: v["usage"] for v in tools.values()}
for tool_res in tool_result:
if "tool_name" in tool_res:
usage.append((tool_res["tool_name"], name_to_usage[tool_res["tool_name"]]))

usage_str = ""
for tool_name, tool_usage in usage:
usage_str += f"{tool_name} - {tool_usage}\n"
return usage_str
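
As an illustration, the new format_tool_usage helper renders one "name - usage" line per tool that actually appears in the results. A small sketch with a hypothetical tool registry and result list (the usage string is made up):

tools = {
    0: {
        "name": "grounding_dino_",
        "description": "Detects objects that match a text prompt.",
        "usage": "grounding_dino_(prompt: str, image: str) -> Dict",
    },
}
tool_result = [{"tool_name": "grounding_dino_", "parameters": {}, "call_results": []}]
print(format_tool_usage(tools, tool_result))
# grounding_dino_ - grounding_dino_(prompt: str, image: str) -> Dict
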


def topological_sort(tasks: List[Dict]) -> List[Dict]:
in_degree = {task["id"]: 0 for task in tasks}
for task in tasks:
@@ -255,7 +268,8 @@ def self_reflect(
) -> str:
prompt = VISION_AGENT_REFLECTION.format(
question=question,
tools=format_tools(tools),
tools=format_tools({k: v["description"] for k, v in tools.items()}),
tool_usage=format_tool_usage(tools, tool_result),
tool_results=str(tool_result),
final_answer=final_answer,
)
@@ -268,59 +282,101 @@
return reflect_model(prompt)


def parse_reflect(reflect: str) -> bool:
# GPT-4V has a hard time following directions, so make the criteria less strict
return (
def parse_reflect(reflect: str) -> Any:
reflect = reflect.strip()
try:
return parse_json(reflect)
except Exception:
_LOGGER.error(f"Failed parse json reflection: {reflect}")
# LMMs have a hard time following directions, so make the criteria less strict
finish = (
"finish" in reflect.lower() and len(reflect) < 100
) or "finish" in reflect.lower()[-10:]


def visualize_result(all_tool_results: List[Dict]) -> List[str]:
image_to_data: Dict[str, Dict] = {}
for tool_result in all_tool_results:
if tool_result["tool_name"] not in ["grounding_sam_", "grounding_dino_"]:
continue

parameters = tool_result["parameters"]
# parameters can either be a dictionary or list, parameters can also be malformed
# because the LLM builds them
if isinstance(parameters, dict):
if "image" not in parameters:
continue
parameters = [parameters]
elif isinstance(tool_result["parameters"], list):
if len(tool_result["parameters"]) < 1 or (
"image" not in tool_result["parameters"][0]
):
continue

for param, call_result in zip(parameters, tool_result["call_results"]):
# calls can fail, so we need to check if the call was successful
if not isinstance(call_result, dict):
continue
if "bboxes" not in call_result:
continue

# if the call was successful, then we can add the image data
image = param["image"]
return {"Finish": finish, "Reflection": reflect}
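
With the reflection model now in JSON mode, parse_reflect returns a dictionary on both paths. A small sketch with hypothetical inputs (the JSON path relies on parse_json as sketched earlier):

parse_reflect('{"Finish": true, "Reflection": "The answer was correct."}')
# -> {'Finish': True, 'Reflection': 'The answer was correct.'}

parse_reflect("The count looks right. Finish.")
# Not valid JSON, so it falls back to the keyword heuristic above:
# -> {'Finish': True, 'Reflection': 'The count looks right. Finish.'}
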


def _handle_extract_frames(
image_to_data: Dict[str, Dict], tool_result: Dict
) -> Dict[str, Dict]:
image_to_data = image_to_data.copy()
# handle extract_frames_ case, useful if it extracts frames but doesn't do
# any following processing
for video_file_output in tool_result["call_results"]:
for frame, _ in video_file_output:
image = frame
if image not in image_to_data:
image_to_data[image] = {
"bboxes": [],
"masks": [],
"labels": [],
"scores": [],
}
return image_to_data
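
For example, a hypothetical extract_frames_ result seeds an empty annotation entry per extracted frame, so the frames can still be visualized even when no detection tool ran on them afterwards (file names and the "video" parameter are made up):

tool_result = {
    "tool_name": "extract_frames_",
    "parameters": {"video": "pool_game.mp4"},
    "call_results": [[("frame_0.png", 0.0), ("frame_1.png", 0.5)]],
}
image_to_data = _handle_extract_frames({}, tool_result)
# image_to_data["frame_0.png"] == {"bboxes": [], "masks": [], "labels": [], "scores": []}
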


def _handle_viz_tools(
image_to_data: Dict[str, Dict], tool_result: Dict
) -> Dict[str, Dict]:
image_to_data = image_to_data.copy()

# handle grounding_sam_ and grounding_dino_
parameters = tool_result["parameters"]
# parameters can either be a dictionary or list, parameters can also be malformed
# because the LLM builds them
if isinstance(parameters, dict):
if "image" not in parameters:
return image_to_data
parameters = [parameters]
elif isinstance(tool_result["parameters"], list):
if len(tool_result["parameters"]) < 1 or (
"image" not in tool_result["parameters"][0]
):
return image_to_data

for param, call_result in zip(parameters, tool_result["call_results"]):
# calls can fail, so we need to check if the call was successful
if not isinstance(call_result, dict) or "bboxes" not in call_result:
return image_to_data

# if the call was successful, then we can add the image data
image = param["image"]
if image not in image_to_data:
image_to_data[image] = {
"bboxes": [],
"masks": [],
"labels": [],
"scores": [],
}

image_to_data[image]["bboxes"].extend(call_result["bboxes"])
image_to_data[image]["labels"].extend(call_result["labels"])
image_to_data[image]["scores"].extend(call_result["scores"])
if "masks" in call_result:
image_to_data[image]["masks"].extend(call_result["masks"])

return image_to_data
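
Similarly, a hypothetical grounding_dino_ result shows how detections are merged into the per-image dictionary (all values are made up):

tool_result = {
    "tool_name": "grounding_dino_",
    "parameters": {"prompt": "pool ball", "image": "table.png"},
    "call_results": [
        {"bboxes": [[0.1, 0.2, 0.3, 0.4]], "labels": ["pool ball"], "scores": [0.92]}
    ],
}
image_to_data = _handle_viz_tools({}, tool_result)
# image_to_data["table.png"]["labels"] == ["pool ball"]
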


image_to_data[image]["bboxes"].extend(call_result["bboxes"])
image_to_data[image]["labels"].extend(call_result["labels"])
image_to_data[image]["scores"].extend(call_result["scores"])
if "masks" in call_result:
image_to_data[image]["masks"].extend(call_result["masks"])
def visualize_result(all_tool_results: List[Dict]) -> List[str]:
image_to_data: Dict[str, Dict] = {}
for tool_result in all_tool_results:
# only handle bbox/mask tools or frame extraction
if tool_result["tool_name"] not in [
"grounding_sam_",
"grounding_dino_",
"extract_frames_",
]:
continue

if tool_result["tool_name"] == "extract_frames_":
image_to_data = _handle_extract_frames(image_to_data, tool_result)
else:
image_to_data = _handle_viz_tools(image_to_data, tool_result)

visualized_images = []
for image in image_to_data:
image_path = Path(image)
image_data = image_to_data[image]
for image_str in image_to_data:
image_path = Path(image_str)
image_data = image_to_data[image_str]
image = overlay_masks(image_path, image_data)
image = overlay_bboxes(image, image_data)
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
@@ -374,7 +430,9 @@ def __init__(
OpenAILLM(temperature=0.1) if answer_model is None else answer_model
)
self.reflect_model = (
OpenAILMM(temperature=0.1) if reflect_model is None else reflect_model
OpenAILMM(json_mode=True, temperature=0.1)
if reflect_model is None
else reflect_model
)
self.max_retries = max_retries
self.tools = TOOLS
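
Note that the default reflection model is now built with json_mode=True so its output can be parsed by parse_reflect; a custom reflect_model presumably needs the same setting. A hedged sketch (the temperature value is arbitrary):

agent = VisionAgent()  # default reflect model already uses json_mode=True
agent = VisionAgent(reflect_model=OpenAILMM(json_mode=True, temperature=0.0))
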
@@ -470,11 +528,12 @@ def chat_with_workflow(
visualized_output[0] if len(visualized_output) > 0 else image,
)
self.log_progress(f"Reflection: {reflection}")
if parse_reflect(reflection):
parsed_reflection = parse_reflect(reflection)
if parsed_reflection["Finish"]:
break
else:
reflections += "\n" + reflection
# '<END>' is a symbol to indicate the end of the chat, which is useful for streaming logs.
reflections += "\n" + parsed_reflection["Reflection"]
# '<ANSWER>' is a symbol to indicate the end of the chat, which is useful for streaming logs.
self.log_progress(
f"The Vision Agent has concluded this chat. <ANSWER>{final_answer}</ANSWER>"
)
15 changes: 14 additions & 1 deletion vision_agent/agent/vision_agent_prompts.py
@@ -1,4 +1,14 @@
VISION_AGENT_REFLECTION = """You are an advanced reasoning agent that can improve based on self-reflection. You will be given a previous reasoning trial in which you were given the user's question, the available tools that the agent has, the decomposed tasks and tools that the agent used to answer the question and the final answer the agent provided. You may also receive an image with the visualized bounding boxes or masks with their associated labels and scores from the tools used. You must determine if the agent's answer was correct or incorrect. If the agent's answer was correct, respond with Finish. If the agent's answer was incorrect, you must diagnose a possible reason for failure or phrasing discrepancy and devise a new, concise, concrete plan that aims to mitigate the same failure with the tools available. Do not make vague steps like re-evaluate the threshold, instead make concrete steps like use a threshold of 0.5 or whatever threshold you think would fix this issue. If the task cannot be completed with the existing tools, respond with Finish. Use complete sentences.
VISION_AGENT_REFLECTION = """You are an advanced reasoning agent that can improve based on self-reflection. You will be given a previous reasoning trial in which you were given the user's question, the available tools that the agent has, the decomposed tasks and tools that the agent used to answer the question and the final answer the agent provided. You may also receive an image with the visualized bounding boxes or masks with their associated labels and scores from the tools used.
Please note that:
1. You must ONLY output parsable JSON format. If the agent's output was correct, set "Finish" to true, else set "Finish" to false. An example output looks like:
{{"Finish": true, "Reflection": "The agent's answer was correct."}}
2. You must utilize the image with the visualized bounding boxes or masks and determine if the tools were used correctly or, using your own judgement, utilized incorrectly.
3. If the agent's answer was incorrect, you must diagnose a possible reason for failure or phrasing discrepancy and devise a new, concise, concrete plan that aims to mitigate the same failure with the tools available. An example output looks like:
{{"Finish": false, "Reflection": "I can see from the visualized bounding boxes that the agent's answer was incorrect because the grounding_dino_ tool produced false positive predictions. The agent should use the following tools with the following parameters:
Step 1: Use 'grounding_dino_' with a 'prompt' of 'baby. bed' and a 'box_threshold' of 0.7 to reduce the false positives.
Step 2: Use 'box_iou_' with the baby bounding box and the bed bounding box to determine if the baby is on the bed or not."}}
4. If the task cannot be completed with the existing tools or by adjusting the parameters, set "Finish" to true.
User's question: {question}
@@ -8,6 +18,9 @@
Tasks and tools used:
{tool_results}
API documentation for the tools used:
{tool_usage}
Final answer:
{final_answer}
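
Putting the template together, self_reflect fills it in roughly as below; all values are hypothetical, and the {tools} placeholder sits in the part of the template not shown in this hunk:

prompt = VISION_AGENT_REFLECTION.format(
    question="How many balls are on the pool table?",
    tools="0: Detects objects that match a text prompt.",
    tool_usage="grounding_dino_ - grounding_dino_(prompt: str, image: str) -> Dict",
    tool_results="[{'task': 'Detect the balls', 'tool_name': 'grounding_dino_'}]",
    final_answer="There are 7 balls on the pool table.",
)
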
2 changes: 1 addition & 1 deletion vision_agent/llm/llm.py
@@ -33,7 +33,7 @@ class OpenAILLM(LLM):

def __init__(
self,
model_name: str = "gpt-4-turbo-preview",
model_name: str = "gpt-4-turbo",
api_key: Optional[str] = None,
json_mode: bool = False,
**kwargs: Any
12 changes: 8 additions & 4 deletions vision_agent/lmm/lmm.py
@@ -99,9 +99,10 @@ class OpenAILMM(LMM):

def __init__(
self,
model_name: str = "gpt-4-vision-preview",
model_name: str = "gpt-4-turbo",
api_key: Optional[str] = None,
max_tokens: int = 1024,
json_mode: bool = False,
**kwargs: Any,
):
if not api_key:
@@ -111,7 +112,10 @@ def __init__(

self.client = OpenAI(api_key=api_key)
self.model_name = model_name
self.max_tokens = max_tokens
if "max_tokens" not in kwargs:
kwargs["max_tokens"] = max_tokens
if json_mode:
kwargs["response_format"] = {"type": "json_object"}
self.kwargs = kwargs

def __call__(
@@ -153,7 +157,7 @@ def chat(
)

response = self.client.chat.completions.create(
model=self.model_name, messages=fixed_chat, max_tokens=self.max_tokens, **self.kwargs # type: ignore
model=self.model_name, messages=fixed_chat, **self.kwargs # type: ignore
)

return cast(str, response.choices[0].message.content)
@@ -181,7 +185,7 @@ def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str
)

response = self.client.chat.completions.create(
model=self.model_name, messages=message, max_tokens=self.max_tokens, **self.kwargs # type: ignore
model=self.model_name, messages=message, **self.kwargs # type: ignore
)
return cast(str, response.choices[0].message.content)
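
With this change, JSON mode is requested once at construction time and response_format rides along in self.kwargs for every call. A minimal usage sketch (prompt and image path are hypothetical):

lmm = OpenAILMM(json_mode=True, temperature=0.1)
reflection = lmm.generate(
    "Was the agent's answer correct? Respond in JSON.", image="frame_0.png"
)
# The underlying chat.completions.create call now receives
# response_format={"type": "json_object"} and max_tokens=1024 via self.kwargs.
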

