From 4f01bae893b96699edb55f2d6c21f4874cce32a6 Mon Sep 17 00:00:00 2001
From: Dillon Laird
Date: Tue, 19 Mar 2024 16:16:48 -0700
Subject: [PATCH] fix doc

---
 vision_agent/llm/llm.py | 6 ++++--
 vision_agent/lmm/lmm.py | 9 ++++++---
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/vision_agent/llm/llm.py b/vision_agent/llm/llm.py
index 412369ca..91e3ebe8 100644
--- a/vision_agent/llm/llm.py
+++ b/vision_agent/llm/llm.py
@@ -64,7 +64,8 @@ def __call__(self, input: Union[str, List[Dict[str, str]]]) -> str:
         return self.chat(input)
 
     def generate_classifier(self, prompt: str) -> ImageTool:
-        prompt = CHOOSE_PARAMS.format(api_doc=CLIP.description, question=prompt)
+        api_doc = CLIP.description + "\n" + str(CLIP.usage)
+        prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=prompt)
         response = self.client.chat.completions.create(
             model=self.model_name,
             response_format={"type": "json_object"},
@@ -97,7 +98,8 @@ def generate_detector(self, params: str) -> ImageTool:
         return GroundingDINO(**cast(Mapping, params))
 
     def generate_segmentor(self, params: str) -> ImageTool:
-        params = CHOOSE_PARAMS.format(api_doc=GroundingSAM.description, question=params)
+        api_doc = GroundingSAM.description + "\n" + str(GroundingSAM.usage)
+        params = CHOOSE_PARAMS.format(api_doc=api_doc, question=params)
         response = self.client.chat.completions.create(
             model=self.model_name,
             response_format={"type": "json_object"},
diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py
index 2023000c..58ee1d65 100644
--- a/vision_agent/lmm/lmm.py
+++ b/vision_agent/lmm/lmm.py
@@ -169,7 +169,8 @@ def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str
         return cast(str, response.choices[0].message.content)
 
     def generate_classifier(self, prompt: str) -> ImageTool:
-        prompt = CHOOSE_PARAMS.format(api_doc=CLIP.doc, question=prompt)
+        api_doc = CLIP.description + "\n" + str(CLIP.usage)
+        prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=prompt)
         response = self.client.chat.completions.create(
             model=self.model_name,
             messages=[
@@ -191,7 +192,8 @@ def generate_classifier(self, prompt: str) -> ImageTool:
         return CLIP(**cast(Mapping, prompt))
 
     def generate_detector(self, params: str) -> ImageTool:
-        params = CHOOSE_PARAMS.format(api_doc=GroundingDINO.doc, question=params)
+        api_doc = GroundingDINO.description + "\n" + str(GroundingDINO.usage)
+        params = CHOOSE_PARAMS.format(api_doc=api_doc, question=params)
         response = self.client.chat.completions.create(
             model=self.model_name,
             messages=[
@@ -213,7 +215,8 @@ def generate_detector(self, params: str) -> ImageTool:
         return GroundingDINO(**cast(Mapping, params))
 
     def generate_segmentor(self, prompt: str) -> ImageTool:
-        prompt = CHOOSE_PARAMS.format(api_doc=GroundingSAM.doc, question=prompt)
+        api_doc = GroundingSAM.description + "\n" + str(GroundingSAM.usage)
+        prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=prompt)
         response = self.client.chat.completions.create(
             model=self.model_name,
             messages=[
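
A minimal standalone sketch of the prompt construction this patch introduces, assuming each tool class (e.g. CLIP) exposes `description` and `usage` attributes and that CHOOSE_PARAMS is a format string with `api_doc` and `question` placeholders; the literal strings and the usage dict below are hypothetical stand-ins, only the construction pattern mirrors the patch.

# Sketch only: CHOOSE_PARAMS text and the CLIP attribute values are placeholders,
# not the library's actual definitions.
CHOOSE_PARAMS = (
    "API doc:\n{api_doc}\n"
    "Choose parameters for this question: {question}"
)

class CLIP:
    description = "Classifies an image against a list of text labels."  # placeholder
    usage = {"required_parameters": [{"name": "prompt", "type": "List[str]"}]}  # placeholder

# The change: include both the tool description and its usage schema in the prompt,
# instead of the description (or a nonexistent `.doc` attribute) alone.
api_doc = CLIP.description + "\n" + str(CLIP.usage)
prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question="Is there a dog in the image?")
print(prompt)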