Skip to content

Commit

Permalink
added azure openai
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Apr 6, 2024
1 parent 78971c4 commit fe1ff8d
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 13 deletions.
47 changes: 40 additions & 7 deletions vision_agent/llm/llm.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import json
import os
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Mapping, Union, cast
from typing import Any, Callable, Dict, List, Mapping, Optional, Union, cast

from openai import OpenAI
from openai import AzureOpenAI, OpenAI

from vision_agent.tools import (
CHOOSE_PARAMS,
Expand Down Expand Up @@ -33,15 +34,18 @@ class OpenAILLM(LLM):
def __init__(
self,
model_name: str = "gpt-4-turbo-preview",
api_key: str = "",
api_key: Optional[str] = None,
json_mode: bool = False,
**kwargs: Any
):
if not api_key:
api_key = os.getenv("OPENAI_API_KEY")

if not api_key:
raise ValueError("OpenAI API key is required.")

self.client = OpenAI(api_key=api_key)
self.model_name = model_name
if api_key:
self.client = OpenAI(api_key=api_key)
else:
self.client = OpenAI()
self.kwargs = kwargs
if json_mode:
self.kwargs["response_format"] = {"type": "json_object"}
Expand Down Expand Up @@ -124,3 +128,32 @@ def generate_segmentor(self, question: str) -> Callable:
]

return lambda x: GroundingSAM()(**{"prompt": params["prompt"], "image": x})


class AzureOpenAILLM(OpenAILLM):
def __init__(
self,
model_name: str = "gpt-4-turbo-preview",
api_key: Optional[str] = None,
api_version: str = "2024-02-01",
azure_endpoint: Optional[str] = None,
json_mode: bool = False,
**kwargs: Any
):
if not api_key:
api_key = os.getenv("AZURE_OPENAI_API_KEY")
if not azure_endpoint:
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")

if not api_key:
raise ValueError("Azure OpenAI API key is required.")
if not azure_endpoint:
raise ValueError("Azure OpenAI endpoint is required.")

self.client = AzureOpenAI(
api_key=api_key, api_version=api_version, azure_endpoint=azure_endpoint
)
self.model_name = model_name
self.kwargs = kwargs
if json_mode:
self.kwargs["response_format"] = {"type": "json_object"}
44 changes: 38 additions & 6 deletions vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import base64
import json
import logging
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union, cast

import requests
from openai import OpenAI
from openai import AzureOpenAI, OpenAI

from vision_agent.tools import (
CHOOSE_PARAMS,
Expand Down Expand Up @@ -99,16 +100,19 @@ class OpenAILMM(LMM):
def __init__(
self,
model_name: str = "gpt-4-vision-preview",
api_key: str = "",
api_key: Optional[str] = None,
max_tokens: int = 1024,
**kwargs: Any,
):
if not api_key:
api_key = os.getenv("OPENAI_API_KEY")

if not api_key:
raise ValueError("OpenAI API key is required.")

self.client = OpenAI(api_key=api_key)
self.model_name = model_name
self.max_tokens = max_tokens
if api_key:
self.client = OpenAI(api_key=api_key)
else:
self.client = OpenAI()
self.kwargs = kwargs

def __call__(
Expand Down Expand Up @@ -252,6 +256,34 @@ def generate_segmentor(self, question: str) -> Callable:
return lambda x: GroundingSAM()(**{"prompt": params["prompt"], "image": x})


class AzureOpenAILMM(OpenAILMM):
def __init__(
self,
model_name: str = "gpt-4-vision-preview",
api_key: Optional[str] = None,
api_version: str = "2021-02-01",
azure_endpoint: Optional[str] = None,
max_tokens: int = 1024,
**kwargs: Any,
):
if not api_key:
api_key = os.getenv("OPENAI_API_KEY")
if not azure_endpoint:
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")

if not api_key:
raise ValueError("OpenAI API key is required.")
if not azure_endpoint:
raise ValueError("Azure OpenAI endpoint is required.")

self.client = AzureOpenAI(
api_key=api_key, api_version=api_version, azure_endpoint=azure_endpoint
)
self.model_name = model_name
self.max_tokens = max_tokens
self.kwargs = kwargs


def get_lmm(name: str) -> LMM:
if name == "openai":
return OpenAILMM(name)
Expand Down

0 comments on commit fe1ff8d

Please sign in to comment.