From 49af5028918523b6ce21ea8ca79a8ad2f5d5c177 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Tue, 9 Apr 2024 09:39:02 -0700 Subject: [PATCH] Add Azure OpenAI (#42) * added api_key in init arg * added azure openai * added azure openai to modules * fix for passing tests * added azure openai to readme * fixed typo --- README.md | 23 +++++++++++++++++++- vision_agent/llm/__init__.py | 2 +- vision_agent/llm/llm.py | 41 +++++++++++++++++++++++++++++++++--- vision_agent/lmm/__init__.py | 2 +- vision_agent/lmm/lmm.py | 39 ++++++++++++++++++++++++++++++++-- 5 files changed, 99 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 181be5ed..bff0a12a 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,8 @@ To get started, you can install the library using pip: pip install vision-agent ``` -Ensure you have an OpenAI API key and set it as an environment variable: +Ensure you have an OpenAI API key and set it as an environment variable (if you are +using Azure OpenAI please see the additional setup section): ```bash export OPENAI_API_KEY="your-api-key" @@ -109,3 +110,23 @@ you. For example: It also has a basic set of calculate tools such as add, subtract, multiply and divide. + +### Additional Setup +If you want to use Azure OpenAI models, you can set the environment variable: + +```bash +export AZURE_OPENAI_API_KEY="your-api-key" +export AZURE_OPENAI_ENDPOINT="your-endpoint" +``` + +You can then run Vision Agent using the Azure OpenAI models: + +```python +>>> import vision_agent as va +>>> agent = va.agent.VisionAgent( +>>> task_model=va.llm.AzureOpenAILLM(), +>>> answer_model=va.lmm.AzureOpenAILMM(), +>>> reflection_model=va.lmm.AzureOpenAILMM(), +>>> ) +``` + diff --git a/vision_agent/llm/__init__.py b/vision_agent/llm/__init__.py index dd5f5c54..e482f69d 100644 --- a/vision_agent/llm/__init__.py +++ b/vision_agent/llm/__init__.py @@ -1 +1 @@ -from .llm import LLM, OpenAILLM +from .llm import LLM, AzureOpenAILLM, OpenAILLM diff --git a/vision_agent/llm/llm.py b/vision_agent/llm/llm.py index e97bcdeb..9022ef73 100644 --- a/vision_agent/llm/llm.py +++ b/vision_agent/llm/llm.py @@ -1,8 +1,9 @@ import json +import os from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Mapping, Union, cast +from typing import Any, Callable, Dict, List, Mapping, Optional, Union, cast -from openai import OpenAI +from openai import AzureOpenAI, OpenAI from vision_agent.tools import ( CHOOSE_PARAMS, @@ -33,11 +34,16 @@ class OpenAILLM(LLM): def __init__( self, model_name: str = "gpt-4-turbo-preview", + api_key: Optional[str] = None, json_mode: bool = False, **kwargs: Any ): + if not api_key: + self.client = OpenAI() + else: + self.client = OpenAI(api_key=api_key) + self.model_name = model_name - self.client = OpenAI() self.kwargs = kwargs if json_mode: self.kwargs["response_format"] = {"type": "json_object"} @@ -120,3 +126,32 @@ def generate_segmentor(self, question: str) -> Callable: ] return lambda x: GroundingSAM()(**{"prompt": params["prompt"], "image": x}) + + +class AzureOpenAILLM(OpenAILLM): + def __init__( + self, + model_name: str = "gpt-4-turbo-preview", + api_key: Optional[str] = None, + api_version: str = "2024-02-01", + azure_endpoint: Optional[str] = None, + json_mode: bool = False, + **kwargs: Any + ): + if not api_key: + api_key = os.getenv("AZURE_OPENAI_API_KEY") + if not azure_endpoint: + azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT") + + if not api_key: + raise ValueError("Azure OpenAI API key is required.") + if not azure_endpoint: + raise ValueError("Azure OpenAI endpoint is required.") + + self.client = AzureOpenAI( + api_key=api_key, api_version=api_version, azure_endpoint=azure_endpoint + ) + self.model_name = model_name + self.kwargs = kwargs + if json_mode: + self.kwargs["response_format"] = {"type": "json_object"} diff --git a/vision_agent/lmm/__init__.py b/vision_agent/lmm/__init__.py index 26bc23c1..9c7ace7a 100644 --- a/vision_agent/lmm/__init__.py +++ b/vision_agent/lmm/__init__.py @@ -1 +1 @@ -from .lmm import LMM, LLaVALMM, OpenAILMM, get_lmm +from .lmm import LMM, AzureOpenAILMM, LLaVALMM, OpenAILMM, get_lmm diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py index 7ae65eb2..0d63b158 100644 --- a/vision_agent/lmm/lmm.py +++ b/vision_agent/lmm/lmm.py @@ -1,12 +1,13 @@ import base64 import json import logging +import os from abc import ABC, abstractmethod from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Union, cast import requests -from openai import OpenAI +from openai import AzureOpenAI, OpenAI from vision_agent.tools import ( CHOOSE_PARAMS, @@ -99,12 +100,18 @@ class OpenAILMM(LMM): def __init__( self, model_name: str = "gpt-4-vision-preview", + api_key: Optional[str] = None, max_tokens: int = 1024, **kwargs: Any, ): + if not api_key: + self.client = OpenAI() + else: + self.client = OpenAI(api_key=api_key) + + self.client = OpenAI(api_key=api_key) self.model_name = model_name self.max_tokens = max_tokens - self.client = OpenAI() self.kwargs = kwargs def __call__( @@ -248,6 +255,34 @@ def generate_segmentor(self, question: str) -> Callable: return lambda x: GroundingSAM()(**{"prompt": params["prompt"], "image": x}) +class AzureOpenAILMM(OpenAILMM): + def __init__( + self, + model_name: str = "gpt-4-vision-preview", + api_key: Optional[str] = None, + api_version: str = "2024-02-01", + azure_endpoint: Optional[str] = None, + max_tokens: int = 1024, + **kwargs: Any, + ): + if not api_key: + api_key = os.getenv("AZURE_OPENAI_API_KEY") + if not azure_endpoint: + azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT") + + if not api_key: + raise ValueError("OpenAI API key is required.") + if not azure_endpoint: + raise ValueError("Azure OpenAI endpoint is required.") + + self.client = AzureOpenAI( + api_key=api_key, api_version=api_version, azure_endpoint=azure_endpoint + ) + self.model_name = model_name + self.max_tokens = max_tokens + self.kwargs = kwargs + + def get_lmm(name: str) -> LMM: if name == "openai": return OpenAILMM(name)