
Commit

fix llm for tools
dillonalaird committed Mar 19, 2024
1 parent 82d192a commit 22d5b74
Showing 1 changed file with 10 additions and 4 deletions.
vision_agent/llm/llm.py (14 changes: 10 additions & 4 deletions)
@@ -31,24 +31,29 @@ def __call__(self, input: Union[str, List[Dict[str, str]]]) -> str:
 class OpenAILLM(LLM):
     r"""An LLM class for any OpenAI LLM model."""
 
-    def __init__(self, model_name: str = "gpt-4-turbo-preview"):
+    def __init__(self, model_name: str = "gpt-4-turbo-preview", json_mode: bool = False):
         self.model_name = model_name
         self.client = OpenAI()
+        self.json_mode = json_mode
 
     def generate(self, prompt: str) -> str:
+        kwargs = {"response_format": {"type": "json_object"}} if self.json_mode else {}
         response = self.client.chat.completions.create(
             model=self.model_name,
             messages=[
                 {"role": "user", "content": prompt},
             ],
+            **kwargs,  # type: ignore
         )
 
         return cast(str, response.choices[0].message.content)
 
     def chat(self, chat: List[Dict[str, str]]) -> str:
+        kwargs = {"response_format": {"type": "json_object"}} if self.json_mode else {}
         response = self.client.chat.completions.create(
             model=self.model_name,
             messages=chat,  # type: ignore
+            **kwargs,  # type: ignore
         )
 
         return cast(str, response.choices[0].message.content)
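
The new json_mode flag forwards OpenAI's JSON mode to both generate and chat. A minimal usage sketch, assuming OPENAI_API_KEY is configured; the import path and prompt text are illustrative, not taken from this commit:

from vision_agent.llm.llm import OpenAILLM  # import path is an assumption

llm = OpenAILLM(json_mode=True)
# json_mode=True adds response_format={"type": "json_object"} to the request,
# so the prompt itself must ask for JSON output.
answer = llm.generate('Return a JSON object with a single key "label" set to "cat".')
print(answer)  # e.g. '{"label": "cat"}'
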
@@ -59,7 +64,7 @@ def __call__(self, input: Union[str, List[Dict[str, str]]]) -> str:
         return self.chat(input)
 
     def generate_classifier(self, prompt: str) -> ImageTool:
-        prompt = CHOOSE_PARAMS.format(api_doc=CLIP.doc, question=prompt)
+        prompt = CHOOSE_PARAMS.format(api_doc=CLIP.description, question=prompt)
         response = self.client.chat.completions.create(
             model=self.model_name,
             response_format={"type": "json_object"},
@@ -75,7 +80,8 @@ def generate_classifier(self, prompt: str) -> ImageTool:
         return CLIP(**cast(Mapping, params))
 
     def generate_detector(self, params: str) -> ImageTool:
-        params = CHOOSE_PARAMS.format(api_doc=GroundingDINO.doc, question=params)
+        api_doc = GroundingDINO.description + "\n" + str(GroundingDINO.usage)
+        params = CHOOSE_PARAMS.format(api_doc=api_doc, question=params)
         response = self.client.chat.completions.create(
             model=self.model_name,
             response_format={"type": "json_object"},
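
The detector prompt's api_doc is now built from GroundingDINO.description plus its usage spec rather than a single doc attribute. An illustrative sketch of that assembly; the attribute values and the CHOOSE_PARAMS template below are stand-ins, not the real ones:

# Stand-in values: only the string assembly mirrors the change above.
class GroundingDINO:
    description = "GroundingDINO detects objects described by a text prompt."
    usage = {"required_parameters": [{"name": "prompt", "type": "str"}]}

CHOOSE_PARAMS = (
    "API documentation:\n{api_doc}\n"
    "Choose parameters for the question below and answer in JSON.\n"
    "Question: {question}"
)

api_doc = GroundingDINO.description + "\n" + str(GroundingDINO.usage)
print(CHOOSE_PARAMS.format(api_doc=api_doc, question="find all the cars"))
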
@@ -91,7 +97,7 @@ def generate_detector(self, params: str) -> ImageTool:
         return GroundingDINO(**cast(Mapping, params))
 
     def generate_segmentor(self, params: str) -> ImageTool:
-        params = CHOOSE_PARAMS.format(api_doc=GroundingSAM.doc, question=params)
+        params = CHOOSE_PARAMS.format(api_doc=GroundingSAM.description, question=params)
         response = self.client.chat.completions.create(
             model=self.model_name,
             response_format={"type": "json_object"},
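
With these fixes, each tool generator asks the model, in JSON mode, to pick parameters from the tool's documentation and then instantiates the tool. A hedged end-to-end sketch; how the returned ImageTool is invoked is not shown in this diff and is assumed here:

# Continues the sketch above; requires a valid OPENAI_API_KEY.
llm = OpenAILLM()
detector = llm.generate_detector("detect all the cars in the image")

# The returned object is a GroundingDINO ImageTool; the call below assumes it
# takes an image path, which this commit does not show.
# results = detector("road.jpg")
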
