From 78971c4037e8cb9aa07b56168392b7a93c3c2913 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Fri, 5 Apr 2024 16:29:36 -0700 Subject: [PATCH] added api_key in init arg --- vision_agent/llm/llm.py | 6 +++++- vision_agent/lmm/lmm.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/vision_agent/llm/llm.py b/vision_agent/llm/llm.py index e97bcdeb..d2eef9c6 100644 --- a/vision_agent/llm/llm.py +++ b/vision_agent/llm/llm.py @@ -33,11 +33,15 @@ class OpenAILLM(LLM): def __init__( self, model_name: str = "gpt-4-turbo-preview", + api_key: str = "", json_mode: bool = False, **kwargs: Any ): self.model_name = model_name - self.client = OpenAI() + if api_key: + self.client = OpenAI(api_key=api_key) + else: + self.client = OpenAI() self.kwargs = kwargs if json_mode: self.kwargs["response_format"] = {"type": "json_object"} diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py index 7ae65eb2..0bff8e85 100644 --- a/vision_agent/lmm/lmm.py +++ b/vision_agent/lmm/lmm.py @@ -99,12 +99,16 @@ class OpenAILMM(LMM): def __init__( self, model_name: str = "gpt-4-vision-preview", + api_key: str = "", max_tokens: int = 1024, **kwargs: Any, ): self.model_name = model_name self.max_tokens = max_tokens - self.client = OpenAI() + if api_key: + self.client = OpenAI(api_key=api_key) + else: + self.client = OpenAI() self.kwargs = kwargs def __call__(