From 27d0a5a75f8b57d2872e39b079cf5bbe20fef2fe Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Wed, 6 Mar 2024 16:39:54 -0800 Subject: [PATCH] add better error handling for json decoding --- vision_agent/lmm/lmm.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py index 8992072d..c1e402f7 100644 --- a/vision_agent/lmm/lmm.py +++ b/vision_agent/lmm/lmm.py @@ -110,7 +110,13 @@ def generate_classifier(self, prompt: str) -> ImageTool: {"role": "user", "content": prompt}, ], ) - prompt = json.loads(response.choices[0].message.content)["prompt"] + + try: + prompt = json.loads(cast(str, response.choices[0].message.content))["prompt"] + except json.JSONDecodeError: + _LOGGER.error(f"Failed to decode response: {response.choices[0].message.content}") + raise ValueError("Failed to decode response") + return CLIP(prompt) def generate_detector(self, prompt: str) -> ImageTool: @@ -123,7 +129,13 @@ def generate_detector(self, prompt: str) -> ImageTool: {"role": "user", "content": prompt}, ], ) - prompt = json.loads(response.choices[0].message.content)["prompt"] + + try: + prompt = json.loads(cast(str, response.choices[0].message.content))["prompt"] + except json.JSONDecodeError: + _LOGGER.error(f"Failed to decode response: {response.choices[0].message.content}") + raise ValueError("Failed to decode response") + return GroundingDINO(prompt) def generate_segmentor(self, prompt: str) -> ImageTool: @@ -136,7 +148,13 @@ def generate_segmentor(self, prompt: str) -> ImageTool: {"role": "user", "content": prompt}, ], ) - prompt = json.loads(response.choices[0].message.content)["prompt"] + + try: + prompt = json.loads(cast(str, response.choices[0].message.content))["prompt"] + except json.JSONDecodeError: + _LOGGER.error(f"Failed to decode response: {response.choices[0].message.content}") + raise ValueError("Failed to decode response") + return GroundingSAM(prompt)