diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 44c3aa08..bbd2c1a5 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -445,6 +445,7 @@ def __call__( self, input: Union[List[Dict[str, str]], str], image: Optional[Union[str, Path]] = None, + reference_data: Optional[Dict[str, str]] = None, visualize_output: Optional[bool] = False, ) -> str: """Invoke the vision agent. @@ -459,7 +460,12 @@ def __call__( """ if isinstance(input, str): input = [{"role": "user", "content": input}] - return self.chat(input, image=image, visualize_output=visualize_output) + return self.chat( + input, + image=image, + visualize_output=visualize_output, + reference_data=reference_data, + ) def log_progress(self, description: str) -> None: _LOGGER.info(description) @@ -563,10 +569,14 @@ def chat( self, chat: List[Dict[str, str]], image: Optional[Union[str, Path]] = None, + reference_data: Optional[Dict[str, str]] = None, visualize_output: Optional[bool] = False, ) -> str: answer, _ = self.chat_with_workflow( - chat, image=image, visualize_output=visualize_output + chat, + image=image, + visualize_output=visualize_output, + reference_data=reference_data, ) return answer