diff --git a/vision_agent/agent/agent_utils.py b/vision_agent/agent/agent_utils.py index e4e678d7..a443c94d 100644 --- a/vision_agent/agent/agent_utils.py +++ b/vision_agent/agent/agent_utils.py @@ -1,13 +1,28 @@ +import re import json import logging import sys -from typing import Any, Dict +from typing import Any, Dict, Optional logging.basicConfig(stream=sys.stdout) _LOGGER = logging.getLogger(__name__) +def _extract_sub_json(json_str: str) -> Optional[Dict[str, Any]]: + json_pattern = r"\{.*\}" + match = re.search(json_pattern, json_str, re.DOTALL) + if match: + json_str = match.group() + try: + json_dict = json.loads(json_str) + return json_dict + except json.JSONDecodeError: + return None + return None + + def extract_json(json_str: str) -> Dict[str, Any]: + __import__("ipdb").set_trace() try: json_dict = json.loads(json_str) except json.JSONDecodeError: @@ -22,6 +37,9 @@ def extract_json(json_str: str) -> Dict[str, Any]: try: json_dict = json.loads(json_str) except json.JSONDecodeError as e: + json_dict = _extract_sub_json(json_str) + if json_dict is not None: + return json_dict error_msg = f"Could not extract JSON from the given str: {json_str}.\nFunction input:\n{input_json_str}" _LOGGER.exception(error_msg) raise ValueError(error_msg) from e