Skip to content

Commit

Permalink
fix doc
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Mar 19, 2024
1 parent 22d5b74 commit 4f01bae
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
6 changes: 4 additions & 2 deletions vision_agent/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def __call__(self, input: Union[str, List[Dict[str, str]]]) -> str:
return self.chat(input)

def generate_classifier(self, prompt: str) -> ImageTool:
prompt = CHOOSE_PARAMS.format(api_doc=CLIP.description, question=prompt)
api_doc = CLIP.description + "\n" + str(CLIP.usage)
prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=prompt)
response = self.client.chat.completions.create(
model=self.model_name,
response_format={"type": "json_object"},
Expand Down Expand Up @@ -97,7 +98,8 @@ def generate_detector(self, params: str) -> ImageTool:
return GroundingDINO(**cast(Mapping, params))

def generate_segmentor(self, params: str) -> ImageTool:
params = CHOOSE_PARAMS.format(api_doc=GroundingSAM.description, question=params)
api_doc = GroundingSAM.description + "\n" + str(GroundingSAM.usage)
params = CHOOSE_PARAMS.format(api_doc=api_doc, question=params)
response = self.client.chat.completions.create(
model=self.model_name,
response_format={"type": "json_object"},
Expand Down
9 changes: 6 additions & 3 deletions vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str
return cast(str, response.choices[0].message.content)

def generate_classifier(self, prompt: str) -> ImageTool:
prompt = CHOOSE_PARAMS.format(api_doc=CLIP.doc, question=prompt)
api_doc = CLIP.description + "\n" + str(CLIP.usage)
prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=prompt)
response = self.client.chat.completions.create(
model=self.model_name,
messages=[
Expand All @@ -191,7 +192,8 @@ def generate_classifier(self, prompt: str) -> ImageTool:
return CLIP(**cast(Mapping, prompt))

def generate_detector(self, params: str) -> ImageTool:
params = CHOOSE_PARAMS.format(api_doc=GroundingDINO.doc, question=params)
api_doc = GroundingDINO.description + "\n" + str(GroundingDINO.usage)
params = CHOOSE_PARAMS.format(api_doc=api_doc, question=params)
response = self.client.chat.completions.create(
model=self.model_name,
messages=[
Expand All @@ -213,7 +215,8 @@ def generate_detector(self, params: str) -> ImageTool:
return GroundingDINO(**cast(Mapping, params))

def generate_segmentor(self, prompt: str) -> ImageTool:
prompt = CHOOSE_PARAMS.format(api_doc=GroundingSAM.doc, question=prompt)
api_doc = GroundingSAM.description + "\n" + str(GroundingSAM.usage)
prompt = CHOOSE_PARAMS.format(api_doc=api_doc, question=prompt)
response = self.client.chat.completions.create(
model=self.model_name,
messages=[
Expand Down

0 comments on commit 4f01bae

Please sign in to comment.