Skip to content

Commit

Permalink
add better error handling for json decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Mar 7, 2024
1 parent b059ea1 commit 27d0a5a
Showing 1 changed file with 21 additions and 3 deletions.
24 changes: 21 additions & 3 deletions vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,13 @@ def generate_classifier(self, prompt: str) -> ImageTool:
{"role": "user", "content": prompt},
],
)
prompt = json.loads(response.choices[0].message.content)["prompt"]

try:
prompt = json.loads(cast(str, response.choices[0].message.content))["prompt"]
except json.JSONDecodeError:
_LOGGER.error(f"Failed to decode response: {response.choices[0].message.content}")
raise ValueError("Failed to decode response")

return CLIP(prompt)

def generate_detector(self, prompt: str) -> ImageTool:
Expand All @@ -123,7 +129,13 @@ def generate_detector(self, prompt: str) -> ImageTool:
{"role": "user", "content": prompt},
],
)
prompt = json.loads(response.choices[0].message.content)["prompt"]

try:
prompt = json.loads(cast(str, response.choices[0].message.content))["prompt"]
except json.JSONDecodeError:
_LOGGER.error(f"Failed to decode response: {response.choices[0].message.content}")
raise ValueError("Failed to decode response")

return GroundingDINO(prompt)

def generate_segmentor(self, prompt: str) -> ImageTool:
Expand All @@ -136,7 +148,13 @@ def generate_segmentor(self, prompt: str) -> ImageTool:
{"role": "user", "content": prompt},
],
)
prompt = json.loads(response.choices[0].message.content)["prompt"]

try:
prompt = json.loads(cast(str, response.choices[0].message.content))["prompt"]
except json.JSONDecodeError:
_LOGGER.error(f"Failed to decode response: {response.choices[0].message.content}")
raise ValueError("Failed to decode response")

return GroundingSAM(prompt)


Expand Down

0 comments on commit 27d0a5a

Please sign in to comment.