Skip to content

Commit

Permalink
Fix AzureVisionAgent (#133)
Browse files Browse the repository at this point in the history
  • Loading branch information
humpydonkey authored Jun 13, 2024
1 parent bdb2662 commit fe97056
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 4 deletions.
14 changes: 13 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -182,13 +182,25 @@ ensure the documentation is in the same format above with description, `Paramete
`Returns:`, and `Example\n-------`. You can find an example use case [here](examples/custom_tools/).

### Azure Setup
If you want to use Azure OpenAI models, you can set the environment variable:
If you want to use Azure OpenAI models, you need to have two OpenAI model deployments:

1. OpenAI GPT-4o model
2. OpenAI text embedding model


Then you can set the following environment variables:

```bash
export AZURE_OPENAI_API_KEY="your-api-key"
export AZURE_OPENAI_ENDPOINT="your-endpoint"
# The deployment name of your OpenAI chat model
export AZURE_OPENAI_CHAT_MODEL_DEPLOYMENT_NAME="your_gpt4o_model_deployment_name"
# The deployment name of your OpenAI text embedding model
export AZURE_OPENAI_EMBEDDING_MODEL_DEPLOYMENT_NAME="your_embedding_model_deployment_name"
```

> NOTE: make sure your Azure model deployments have enough quota (tokens per minute) to support Vision Agent's requests.

You can then run Vision Agent using the Azure OpenAI models:

```python
Expand Down
10 changes: 8 additions & 2 deletions vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def generate_image_qa_tool(self, question: str) -> Callable:
class AzureOpenAILMM(OpenAILMM):
def __init__(
self,
model_name: str = "gpt-4o",
model_name: Optional[str] = None,
api_key: Optional[str] = None,
api_version: str = "2024-02-01",
azure_endpoint: Optional[str] = None,
Expand All @@ -245,14 +245,20 @@ def __init__(
api_key = os.getenv("AZURE_OPENAI_API_KEY")
if not azure_endpoint:
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
if not model_name:
model_name = os.getenv("AZURE_OPENAI_CHAT_MODEL_DEPLOYMENT_NAME")

if not api_key:
raise ValueError("OpenAI API key is required.")
if not azure_endpoint:
raise ValueError("Azure OpenAI endpoint is required.")
if not model_name:
raise ValueError("Azure OpenAI chat model deployment name is required.")

self.client = AzureOpenAI(
api_key=api_key, api_version=api_version, azure_endpoint=azure_endpoint
api_key=api_key,
api_version=api_version,
azure_endpoint=azure_endpoint,
)
self.model_name = model_name

Expand Down
8 changes: 7 additions & 1 deletion vision_agent/utils/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,23 @@ def __init__(
api_key: Optional[str] = None,
api_version: str = "2024-02-01",
azure_endpoint: Optional[str] = None,
model: str = "text-embedding-3-small",
model: Optional[str] = None,
) -> None:
if not api_key:
api_key = os.getenv("AZURE_OPENAI_API_KEY")
if not azure_endpoint:
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
if not model:
model = os.getenv("AZURE_OPENAI_EMBEDDING_MODEL_DEPLOYMENT_NAME")

if not api_key:
raise ValueError("Azure OpenAI API key is required.")
if not azure_endpoint:
raise ValueError("Azure OpenAI endpoint is required.")
if not model:
raise ValueError(
"Azure OpenAI embedding model deployment name is required."
)

self.df = df
self.client = AzureOpenAI(
Expand Down

0 comments on commit fe97056

Please sign in to comment.