From 3e1f25da36e873b2fbf7b5f359ee9151fcedde79 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Fri, 12 Apr 2024 14:35:45 -0700 Subject: [PATCH] formatting fix --- vision_agent/agent/vision_agent.py | 4 +++- vision_agent/image_utils.py | 7 +++++-- vision_agent/tools/tools.py | 26 +++++++++++++++----------- 3 files changed, 23 insertions(+), 14 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index fbbf1bc7..e6c34fdb 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -492,7 +492,9 @@ def chat( image: Optional[Union[str, Path]] = None, visualize_output: Optional[bool] = False, ) -> str: - answer, _ = self.chat_with_workflow(chat, image=image, visualize_output=visualize_output) + answer, _ = self.chat_with_workflow( + chat, image=image, visualize_output=visualize_output + ) return answer def retrieval( diff --git a/vision_agent/image_utils.py b/vision_agent/image_utils.py index 54fdfd0b..d0164e11 100644 --- a/vision_agent/image_utils.py +++ b/vision_agent/image_utils.py @@ -109,7 +109,8 @@ def overlay_bboxes( fontsize = max(12, int(min(width, height) / 40)) draw = ImageDraw.Draw(image) font = ImageFont.truetype( - str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")), fontsize + str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")), + fontsize, ) if "bboxes" not in bboxes: return image.convert("RGB") @@ -147,7 +148,9 @@ def overlay_masks( elif isinstance(image, np.ndarray): image = Image.fromarray(image) - color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(set(masks["labels"]))} + color = { + label: COLORS[i % len(COLORS)] for i, label in enumerate(set(masks["labels"])) + } if "masks" not in masks: return image.convert("RGB") diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 88f7536f..9b1a6740 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -53,9 +53,7 @@ class Tool(ABC): class NoOp(Tool): name = "noop_" - description = ( - "'noop_' is a no-op tool that does nothing if you do not want answer the question directly and not use a tool." - ) + description = "'noop_' is a no-op tool that does nothing if you do not want answer the question directly and not use a tool." usage = { "required_parameters": [], "examples": [ @@ -180,7 +178,10 @@ class GroundingDINO(Tool): }, { "scenario": "Can you detect the person on the left and right? Image name: person.jpg", - "parameters": {"prompt": "left person. right person", "image": "person.jpg"}, + "parameters": { + "prompt": "left person. right person", + "image": "person.jpg", + }, }, { "scenario": "Detect the red shirts and green shirst. Image name: shirts.jpg", @@ -286,7 +287,10 @@ class GroundingSAM(Tool): }, { "scenario": "Can you segment the person on the left and right? Image name: person.jpg", - "parameters": {"prompt": "left person. right person", "image": "person.jpg"}, + "parameters": { + "prompt": "left person. right person", + "image": "person.jpg", + }, }, { "scenario": "Can you build me a tool that segments red shirts and green shirts? Image name: shirts.jpg", @@ -496,9 +500,7 @@ def __call__(self, masks: Union[str, Path]) -> float: class BboxIoU(Tool): name = "bbox_iou_" - description = ( - "'bbox_iou_' returns the intersection over union of two bounding boxes. This is a good tool for determining if two objects are overlapping." - ) + description = "'bbox_iou_' returns the intersection over union of two bounding boxes. This is a good tool for determining if two objects are overlapping." usage = { "required_parameters": [ {"name": "bbox1", "type": "List[int]"}, @@ -602,7 +604,9 @@ class Calculator(Tool): r"""Calculator is a tool that can perform basic arithmetic operations.""" name = "calculator_" - description = "'calculator_' is a tool that can perform basic arithmetic operations." + description = ( + "'calculator_' is a tool that can perform basic arithmetic operations." + ) usage = { "required_parameters": [{"name": "equation", "type": "str"}], "examples": [ @@ -613,8 +617,8 @@ class Calculator(Tool): { "scenario": "If you want to calculate (4 + 2.5) / 2.1", "parameters": {"equation": "(4 + 2.5) / 2.1"}, - } - ] + }, + ], } def __call__(self, equation: str) -> float: