diff --git a/vision_agent/agent/reflexion.py b/vision_agent/agent/reflexion.py index 360846d2..c9cf680b 100644 --- a/vision_agent/agent/reflexion.py +++ b/vision_agent/agent/reflexion.py @@ -54,8 +54,8 @@ def __init__( agent_prompt: str = COT_AGENT_REFLECT_INSTRUCTION, reflect_prompt: str = COT_REFLECT_INSTRUCTION, finsh_prompt: str = CHECK_FINSH, - self_reflect_llm: LLM = OpenAILLM(), - action_agent: Union[Agent, LLM] = OpenAILLM(), + self_reflect_llm: Optional[LLM] = None, + action_agent: Optional[Union[Agent, LLM]] = None, ): self.agent_prompt = agent_prompt self.reflect_prompt = reflect_prompt @@ -66,6 +66,11 @@ def __init__( self.action_agent = action_agent self.reflections: List[str] = [] + if self_reflect_llm is None: + self.self_reflect_llm = OpenAILLM() + if action_agent is None: + self.action_agent = OpenAILLM() + def __call__(self, input: Union[List[Dict[str, str]], str]) -> str: if isinstance(input, str): input = [{"role": "user", "content": input}]