diff --git a/vision_agent/llm/llm.py b/vision_agent/llm/llm.py
index d2eef9c6..1ccc5809 100644
--- a/vision_agent/llm/llm.py
+++ b/vision_agent/llm/llm.py
@@ -1,8 +1,9 @@
 import json
+import os
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, List, Mapping, Union, cast
+from typing import Any, Callable, Dict, List, Mapping, Optional, Union, cast
 
-from openai import OpenAI
+from openai import AzureOpenAI, OpenAI
 
 from vision_agent.tools import (
     CHOOSE_PARAMS,
@@ -33,15 +34,18 @@ class OpenAILLM(LLM):
     def __init__(
         self,
         model_name: str = "gpt-4-turbo-preview",
-        api_key: str = "",
+        api_key: Optional[str] = None,
         json_mode: bool = False,
         **kwargs: Any
     ):
+        if not api_key:
+            api_key = os.getenv("OPENAI_API_KEY")
+
+        if not api_key:
+            raise ValueError("OpenAI API key is required.")
+
+        self.client = OpenAI(api_key=api_key)
         self.model_name = model_name
-        if api_key:
-            self.client = OpenAI(api_key=api_key)
-        else:
-            self.client = OpenAI()
         self.kwargs = kwargs
         if json_mode:
             self.kwargs["response_format"] = {"type": "json_object"}
@@ -124,3 +128,39 @@ def generate_segmentor(self, question: str) -> Callable:
         ]
 
         return lambda x: GroundingSAM()(**{"prompt": params["prompt"], "image": x})
+
+
+class AzureOpenAILLM(OpenAILLM):
+    """``OpenAILLM`` variant whose client is an ``AzureOpenAI`` instance.
+
+    ``api_key`` and ``azure_endpoint`` fall back to the ``AZURE_OPENAI_API_KEY``
+    and ``AZURE_OPENAI_ENDPOINT`` environment variables; a ``ValueError`` is
+    raised when either is still missing.
+    """
+
+    def __init__(
+        self,
+        model_name: str = "gpt-4-turbo-preview",
+        api_key: Optional[str] = None,
+        api_version: str = "2024-02-01",
+        azure_endpoint: Optional[str] = None,
+        json_mode: bool = False,
+        **kwargs: Any
+    ):
+        if not api_key:
+            api_key = os.getenv("AZURE_OPENAI_API_KEY")
+        if not azure_endpoint:
+            azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
+
+        if not api_key:
+            raise ValueError("Azure OpenAI API key is required.")
+        if not azure_endpoint:
+            raise ValueError("Azure OpenAI endpoint is required.")
+
+        self.client = AzureOpenAI(
+            api_key=api_key, api_version=api_version, azure_endpoint=azure_endpoint
+        )
+        self.model_name = model_name
+        self.kwargs = kwargs
+        if json_mode:
+            self.kwargs["response_format"] = {"type": "json_object"}
diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py
index 0bff8e85..3eee8766 100644
--- a/vision_agent/lmm/lmm.py
+++ b/vision_agent/lmm/lmm.py
@@ -1,12 +1,13 @@
 import base64
 import json
 import logging
+import os
 from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Union, cast
 
 import requests
-from openai import OpenAI
+from openai import AzureOpenAI, OpenAI
 
 from vision_agent.tools import (
     CHOOSE_PARAMS,
@@ -99,16 +100,19 @@ class OpenAILMM(LMM):
     def __init__(
         self,
         model_name: str = "gpt-4-vision-preview",
-        api_key: str = "",
+        api_key: Optional[str] = None,
         max_tokens: int = 1024,
         **kwargs: Any,
     ):
+        if not api_key:
+            api_key = os.getenv("OPENAI_API_KEY")
+
+        if not api_key:
+            raise ValueError("OpenAI API key is required.")
+
+        self.client = OpenAI(api_key=api_key)
         self.model_name = model_name
         self.max_tokens = max_tokens
-        if api_key:
-            self.client = OpenAI(api_key=api_key)
-        else:
-            self.client = OpenAI()
         self.kwargs = kwargs
 
     def __call__(
@@ -252,6 +256,41 @@ def generate_segmentor(self, question: str) -> Callable:
         return lambda x: GroundingSAM()(**{"prompt": params["prompt"], "image": x})
 
 
+class AzureOpenAILMM(OpenAILMM):
+    """``OpenAILMM`` variant whose client is an ``AzureOpenAI`` instance.
+
+    ``api_key`` and ``azure_endpoint`` fall back to the ``AZURE_OPENAI_API_KEY``
+    and ``AZURE_OPENAI_ENDPOINT`` environment variables; a ``ValueError`` is
+    raised when either is still missing.
+    """
+
+    def __init__(
+        self,
+        model_name: str = "gpt-4-vision-preview",
+        api_key: Optional[str] = None,
+        api_version: str = "2024-02-01",
+        azure_endpoint: Optional[str] = None,
+        max_tokens: int = 1024,
+        **kwargs: Any,
+    ):
+        if not api_key:
+            api_key = os.getenv("AZURE_OPENAI_API_KEY")
+        if not azure_endpoint:
+            azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
+
+        if not api_key:
+            raise ValueError("Azure OpenAI API key is required.")
+        if not azure_endpoint:
+            raise ValueError("Azure OpenAI endpoint is required.")
+
+        self.client = AzureOpenAI(
+            api_key=api_key, api_version=api_version, azure_endpoint=azure_endpoint
+        )
+        self.model_name = model_name
+        self.max_tokens = max_tokens
+        self.kwargs = kwargs
+
+
 def get_lmm(name: str) -> LMM:
     if name == "openai":
         return OpenAILMM(name)