diff --git a/vision_agent/llm/llm.py b/vision_agent/llm/llm.py index e97bcdeb..d2eef9c6 100644 --- a/vision_agent/llm/llm.py +++ b/vision_agent/llm/llm.py @@ -33,11 +33,15 @@ class OpenAILLM(LLM): def __init__( self, model_name: str = "gpt-4-turbo-preview", + api_key: str = "", json_mode: bool = False, **kwargs: Any ): self.model_name = model_name - self.client = OpenAI() + if api_key: + self.client = OpenAI(api_key=api_key) + else: + self.client = OpenAI() self.kwargs = kwargs if json_mode: self.kwargs["response_format"] = {"type": "json_object"} diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py index 7ae65eb2..0bff8e85 100644 --- a/vision_agent/lmm/lmm.py +++ b/vision_agent/lmm/lmm.py @@ -99,12 +99,16 @@ class OpenAILMM(LMM): def __init__( self, model_name: str = "gpt-4-vision-preview", + api_key: str = "", max_tokens: int = 1024, **kwargs: Any, ): self.model_name = model_name self.max_tokens = max_tokens - self.client = OpenAI() + if api_key: + self.client = OpenAI(api_key=api_key) + else: + self.client = OpenAI() self.kwargs = kwargs def __call__(