Skip to content

Commit

Permalink
changed the param name to self_reflect
Browse files Browse the repository at this point in the history
  • Loading branch information
shankar-vision-eng committed Apr 30, 2024
1 parent ef65ff2 commit 5bb6283
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ def __call__(
image: Optional[Union[str, Path]] = None,
reference_data: Optional[Dict[str, str]] = None,
visualize_output: Optional[bool] = False,
reflect_output: Optional[bool] = True,
self_reflect: Optional[bool] = True,
) -> str:
"""Invoke the vision agent.
Expand All @@ -502,6 +502,7 @@ def __call__(
{"image": "image.jpg", "mask": "mask.jpg", "bbox": [0.1, 0.2, 0.1, 0.2]}
where the bounding box coordinates are normalized.
visualize_output: Whether to visualize the output.
self_reflect: boolean to enable and disable self reflection.
Returns:
The result of the vision agent in text.
Expand All @@ -513,6 +514,7 @@ def __call__(
image=image,
visualize_output=visualize_output,
reference_data=reference_data,
self_reflect=self_reflect,
)

def log_progress(self, description: str) -> None:
Expand All @@ -539,7 +541,7 @@ def chat_with_workflow(
image: Optional[Union[str, Path]] = None,
reference_data: Optional[Dict[str, str]] = None,
visualize_output: Optional[bool] = False,
reflect_output: Optional[bool] = True,
self_reflect: Optional[bool] = True,
) -> Tuple[str, List[Dict]]:
"""Chat with the vision agent and return the final answer and all tool results.
Expand All @@ -552,6 +554,7 @@ def chat_with_workflow(
{"image": "image.jpg", "mask": "mask.jpg", "bbox": [0.1, 0.2, 0.1, 0.2]}
where the bounding box coordinates are normalized.
visualize_output: Whether to visualize the output.
self_reflect: boolean to enable and disable self reflection.
Returns:
A tuple where the first item is the final answer and the second item is a
Expand Down Expand Up @@ -628,7 +631,7 @@ def chat_with_workflow(
else:
reflection_images = None

if reflect_output:
if self_reflect:
reflection = self_reflect(
self.reflect_model,
question,
Expand All @@ -644,7 +647,7 @@ def chat_with_workflow(
else:
reflections += "\n" + parsed_reflection["Reflection"]
else:
self.log_progress("Reflection skipped based on user request.")
self.log_progress("Self Reflection skipped based on user request.")
break
# '<ANSWER>' is a symbol to indicate the end of the chat, which is useful for streaming logs.
self.log_progress(
Expand All @@ -667,14 +670,14 @@ def chat(
image: Optional[Union[str, Path]] = None,
reference_data: Optional[Dict[str, str]] = None,
visualize_output: Optional[bool] = False,
reflect_output: Optional[bool] = True,
self_reflect: Optional[bool] = True,
) -> str:
answer, _ = self.chat_with_workflow(
chat,
image=image,
visualize_output=visualize_output,
reference_data=reference_data,
reflect_output=reflect_output,
self_reflect=self_reflect,
)
return answer

Expand Down

0 comments on commit 5bb6283

Please sign in to comment.