Better Ollama support (#208)

* fixed streaming for ollama

* removed old docs

* add ollama sim

* add json mode for ollama

* added options to ollama

* added ollama vision agent coder

* change context window to 128k

* better json parsing

* updated README

* fix type error

* fix type error

* use llama3.1 for ollamavisionagentcoder

* remove debug

* fixed merge issues

* added extra checks around picking plan

* added more docs on ollama

* more docs for ollama

* added tests for ollama

* remove debug
dillonalaird authored Aug 27, 2024
1 parent da3eed1 commit e7c5615
Showing 11 changed files with 345 additions and 93 deletions.
63 changes: 48 additions & 15 deletions README.md
@@ -168,20 +168,18 @@ result = agent.chat_with_workflow(conv)

### Tools
There are a variety of tools for the model or the user to use. Some are executed locally
-while others are hosted for you. You can also ask an LMM directly to build a tool for
-you. For example:
+while others are hosted for you. You can easily access them yourself, for example if
+you want to run `owl_v2` and visualize the output you can run:

```python
->>> import vision_agent as va
->>> lmm = va.lmm.OpenAILMM()
->>> detector = lmm.generate_detector("Can you build a jar detector for me?")
->>> detector(va.tools.load_image("jar.jpg"))
-[{"labels": ["jar",],
-"scores": [0.99],
-"bboxes": [
-[0.58, 0.2, 0.72, 0.45],
-]
-}]
+import vision_agent.tools as T
+import matplotlib.pyplot as plt
+
+image = T.load_image("dogs.jpg")
+dets = T.owl_v2("dogs", image)
+viz = T.overlay_bounding_boxes(image, dets)
+plt.imshow(viz)
+plt.show()
```

You can also add custom tools to the agent:
@@ -214,6 +212,41 @@ function. Make sure the documentation is in the same format above with description,
`Parameters:`, `Returns:`, and `Example\n-------`. You can find an example use case
[here](examples/custom_tools/) as this is what the agent uses to pick and use the tool.
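
For illustration only (not part of this diff), a custom tool's docstring might follow the
required format like the hypothetical sketch below; the tool name, signature, and return
shape are made up, and the linked example shows how registration actually works:

```python
import numpy as np


def template_match(image: np.ndarray, template: np.ndarray) -> list:
    """'template_match' finds regions in 'image' that look similar to 'template'.

    Parameters:
        image (np.ndarray): The image to search in.
        template (np.ndarray): The patch to search for.

    Returns:
        list: A list of dicts such as {"bbox": [x1, y1, x2, y2], "score": 0.98}.

    Example
    -------
        >>> template_match(image, template)
        [{"bbox": [0.1, 0.2, 0.3, 0.4], "score": 0.98}]
    """
    # Placeholder body; a real tool would implement the matching logic here.
    return []
```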

## Additional LLMs
### Ollama
We also provide `OllamaVisionAgentCoder`, a `VisionAgentCoder` that uses Ollama. To get
started, you must first download a few models:

```bash
ollama pull llama3.1
ollama pull mxbai-embed-large
```

`llama3.1` is used for the `OllamaLMM` that backs `OllamaVisionAgentCoder`. Normally we
would use an actual LMM such as `llava`, but `llava` cannot handle the long context
lengths required by the agent. Since `llama3.1` cannot handle images, you may see some
performance degradation. `mxbai-embed-large` is the embedding model used to look up
tools. You can use it just like you would use `VisionAgentCoder`:

```python
>>> import vision_agent as va
>>> agent = va.agent.OllamaVisionAgentCoder()
>>> agent("Count the apples in the image", media="apples.jpg")
```
> WARNING: VisionAgent doesn't work well unless the underlying LMM is sufficiently powerful. Do not expect good results or even working code with smaller models like Llama 3.1 8B.
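
If you want to call the Ollama-backed LMM directly rather than through the coder agent,
the unit tests added in this commit suggest a minimal sketch (the prompt is just an
example, and a running local Ollama server with the models above pulled is assumed):

```python
from vision_agent.lmm.lmm import OllamaLMM

# Assumes a local Ollama server is running with the models pulled above;
# OllamaLMM's default settings are used here.
lmm = OllamaLMM()
response = lmm.chat([{"role": "user", "content": "Summarize what owl_v2 does in one sentence."}])
print(response)
```
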
### Azure OpenAI
We also provide an `AzureVisionAgentCoder` that uses Azure OpenAI models. To get started,
follow the Azure Setup section below. You can use it just like you would use
`VisionAgentCoder`:

```python
>>> import vision_agent as va
>>> agent = va.agent.AzureVisionAgentCoder()
>>> agent("Count the apples in the image", media="apples.jpg")
```


### Azure Setup
If you want to use Azure OpenAI models, you need to have two OpenAI model deployments:

Expand Down Expand Up @@ -252,6 +285,6 @@ agent = va.agent.AzureVisionAgentCoder()
2. Follow the instructions to purchase and manage your API credits.
3. Ensure your API key is correctly configured in your project settings.

-Failure to have sufficient API credits may result in limited or no functionality for the features that rely on the OpenAI API.
-
-For more details on managing your API usage and credits, please refer to the OpenAI API documentation.
+Failure to have sufficient API credits may result in limited or no functionality for
+the features that rely on the OpenAI API. For more details on managing your API usage
+and credits, please refer to the OpenAI API documentation.
67 changes: 52 additions & 15 deletions docs/index.md
@@ -1,4 +1,9 @@
# 🔍🤖 Vision Agent
[![](https://dcbadge.vercel.app/api/server/wPdN8RCYew?compact=true&style=flat)](https://discord.gg/wPdN8RCYew)
![ci_status](https://github.com/landing-ai/vision-agent/actions/workflows/ci_cd.yml/badge.svg)
[![PyPI version](https://badge.fury.io/py/vision-agent.svg)](https://badge.fury.io/py/vision-agent)
![version](https://img.shields.io/pypi/pyversions/vision-agent)
</div>

Vision Agent is a library that helps you utilize agent frameworks to generate code to
solve your vision task. Many current vision problems can easily take hours or days to
@@ -160,20 +165,18 @@ result = agent.chat_with_workflow(conv)

### Tools
There are a variety of tools for the model or the user to use. Some are executed locally
-while others are hosted for you. You can also ask an LMM directly to build a tool for
-you. For example:
+while others are hosted for you. You can easily access them yourself, for example if
+you want to run `owl_v2` and visualize the output you can run:

```python
->>> import vision_agent as va
->>> lmm = va.lmm.OpenAILMM()
->>> detector = lmm.generate_detector("Can you build a jar detector for me?")
->>> detector(va.tools.load_image("jar.jpg"))
-[{"labels": ["jar",],
-"scores": [0.99],
-"bboxes": [
-[0.58, 0.2, 0.72, 0.45],
-]
-}]
+import vision_agent.tools as T
+import matplotlib.pyplot as plt
+
+image = T.load_image("dogs.jpg")
+dets = T.owl_v2("dogs", image)
+viz = T.overlay_bounding_boxes(image, dets)
+plt.imshow(viz)
+plt.show()
```

You can also add custom tools to the agent:
@@ -206,6 +209,40 @@ function. Make sure the documentation is in the same format above with description,
`Parameters:`, `Returns:`, and `Example\n-------`. You can find an example use case
[here](examples/custom_tools/) as this is what the agent uses to pick and use the tool.

## Additional LLMs
### Ollama
We also provide `OllamaVisionAgentCoder`, a `VisionAgentCoder` that uses Ollama. To get
started, you must first download a few models:

```bash
ollama pull llama3.1
ollama pull mxbai-embed-large
```

`llama3.1` is used for the `OllamaLMM` that backs `OllamaVisionAgentCoder`. Normally we
would use an actual LMM such as `llava`, but `llava` cannot handle the long context
lengths required by the agent. Since `llama3.1` cannot handle images, you may see some
performance degradation. `mxbai-embed-large` is the embedding model used to look up
tools. You can use it just like you would use `VisionAgentCoder`:

```python
>>> import vision_agent as va
>>> agent = va.agent.OllamaVisionAgentCoder()
>>> agent("Count the apples in the image", media="apples.jpg")
```

### Azure OpenAI
We also provide an `AzureVisionAgentCoder` that uses Azure OpenAI models. To get started,
follow the Azure Setup section below. You can use it just like you would use
`VisionAgentCoder`:

```python
>>> import vision_agent as va
>>> agent = va.agent.AzureVisionAgentCoder()
>>> agent("Count the apples in the image", media="apples.jpg")
```
> WARNING: VisionAgent doesn't work well unless the underlying LMM is sufficiently powerful. Do not expect good results or even working code with smaller models like Llama 3.1 8B.

### Azure Setup
If you want to use Azure OpenAI models, you need to have two OpenAI model deployments:

@@ -244,6 +281,6 @@ agent = va.agent.AzureVisionAgentCoder()
2. Follow the instructions to purchase and manage your API credits.
3. Ensure your API key is correctly configured in your project settings.

-Failure to have sufficient API credits may result in limited or no functionality for the features that rely on the OpenAI API.
-
-For more details on managing your API usage and credits, please refer to the OpenAI API documentation.
+Failure to have sufficient API credits may result in limited or no functionality for
+the features that rely on the OpenAI API. For more details on managing your API usage
+and credits, please refer to the OpenAI API documentation.
20 changes: 0 additions & 20 deletions docs/lmms.md

This file was deleted.

24 changes: 24 additions & 0 deletions tests/unit/fixtures.py
@@ -31,3 +31,27 @@ def generator():
        mock_instance = mock.return_value
        mock_instance.chat.completions.create.return_value = mock_generate()
        yield mock_instance


@pytest.fixture
def generate_ollama_lmm_mock(request):
    content = request.param

    mock_resp = MagicMock()
    mock_resp.status_code = 200
    mock_resp.json.return_value = {"response": content}
    with patch("vision_agent.lmm.lmm.requests.post") as mock:
        mock.return_value = mock_resp
        yield mock


@pytest.fixture
def chat_ollama_lmm_mock(request):
    content = request.param

    mock_resp = MagicMock()
    mock_resp.status_code = 200
    mock_resp.json.return_value = {"message": {"content": content}}
    with patch("vision_agent.lmm.lmm.requests.post") as mock:
        mock.return_value = mock_resp
        yield mock
34 changes: 32 additions & 2 deletions tests/unit/test_lmm.py
@@ -1,13 +1,18 @@
import json
import tempfile
from unittest.mock import patch

import numpy as np
import pytest
from PIL import Image

-from vision_agent.lmm.lmm import OpenAILMM
+from vision_agent.lmm.lmm import OllamaLMM, OpenAILMM

-from .fixtures import openai_lmm_mock  # noqa: F401
+from .fixtures import (  # noqa: F401
+    chat_ollama_lmm_mock,
+    generate_ollama_lmm_mock,
+    openai_lmm_mock,
+)


def create_temp_image(image_format="jpeg"):
@@ -135,6 +140,31 @@ def test_call_with_mock_stream(openai_lmm_mock):  # noqa: F811
)


@pytest.mark.parametrize(
    "generate_ollama_lmm_mock",
    ["mocked response"],
    indirect=["generate_ollama_lmm_mock"],
)
def test_generate_ollama_mock(generate_ollama_lmm_mock):  # noqa: F811
    temp_image = create_temp_image()
    lmm = OllamaLMM()
    response = lmm.generate("test prompt", media=[temp_image])
    assert response == "mocked response"
    call_args = json.loads(generate_ollama_lmm_mock.call_args.kwargs["data"])
    assert call_args["prompt"] == "test prompt"


@pytest.mark.parametrize(
    "chat_ollama_lmm_mock", ["mocked response"], indirect=["chat_ollama_lmm_mock"]
)
def test_chat_ollama_mock(chat_ollama_lmm_mock):  # noqa: F811
    lmm = OllamaLMM()
    response = lmm.chat([{"role": "user", "content": "test prompt"}])
    assert response == "mocked response"
    call_args = json.loads(chat_ollama_lmm_mock.call_args.kwargs["data"])
    assert call_args["messages"][0]["content"] == "test prompt"


@pytest.mark.parametrize(
"openai_lmm_mock",
['{"Parameters": {"prompt": "cat"}}'],
6 changes: 5 additions & 1 deletion vision_agent/agent/__init__.py
@@ -1,3 +1,7 @@
from .agent import Agent
from .vision_agent import VisionAgent
-from .vision_agent_coder import AzureVisionAgentCoder, VisionAgentCoder
+from .vision_agent_coder import (
+    AzureVisionAgentCoder,
+    OllamaVisionAgentCoder,
+    VisionAgentCoder,
+)
27 changes: 25 additions & 2 deletions vision_agent/agent/agent_utils.py
@@ -1,9 +1,24 @@
import json
import logging
import re
import sys
-from typing import Any, Dict
+from typing import Any, Dict, Optional

logging.basicConfig(stream=sys.stdout)
_LOGGER = logging.getLogger(__name__)


def _extract_sub_json(json_str: str) -> Optional[Dict[str, Any]]:
    json_pattern = r"\{.*\}"
    match = re.search(json_pattern, json_str, re.DOTALL)
    if match:
        json_str = match.group()
        try:
            json_dict = json.loads(json_str)
            return json_dict  # type: ignore
        except json.JSONDecodeError:
            return None
    return None


def extract_json(json_str: str) -> Dict[str, Any]:
@@ -18,8 +33,16 @@ def extract_json(json_str: str) -> Dict[str, Any]:
            json_str = json_str[json_str.find("```") + len("```") :]
            # get the last ``` not one from an intermediate string
            json_str = json_str[: json_str.find("}```")]
+        try:
+            json_dict = json.loads(json_str)
+        except json.JSONDecodeError as e:
+            json_dict = _extract_sub_json(json_str)
+            if json_dict is not None:
+                return json_dict  # type: ignore
+            error_msg = f"Could not extract JSON from the given str: {json_str}"
+            _LOGGER.exception(error_msg)
+            raise ValueError(error_msg) from e

-        json_dict = json.loads(json_str)
    return json_dict  # type: ignore
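
To make the new fallback concrete, here is a small standalone sketch of the idea
`_extract_sub_json` implements (the reply string is made up): pull out the outermost
`{...}` span and parse it when the raw reply is not valid JSON on its own.

```python
import json
import re

# Hypothetical LLM reply: valid JSON buried in surrounding prose.
reply = 'Sure, here is the plan: {"plan": "count the apples", "tool": "owl_v2"} Hope that helps!'

# Same idea as _extract_sub_json above: grab the outermost {...} span and
# try to parse it after json.loads on the full reply fails.
match = re.search(r"\{.*\}", reply, re.DOTALL)
print(json.loads(match.group()) if match else None)
# -> {'plan': 'count the apples', 'tool': 'owl_v2'}
```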


