Skip to content

Commit

Permalink
Fix Ollama minor issues, make tests more stable (#145)
Browse files Browse the repository at this point in the history
* fixed a few minor issues

* make tests more robust

* fix mypy
  • Loading branch information
dillonalaird authored Jun 19, 2024
1 parent 4890c10 commit ea81dab
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 69 deletions.
213 changes: 159 additions & 54 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,107 +16,212 @@
owl_v2,
)

RETRIES = 3


def test_grounding_dino():
img = ski.data.coins()
result = grounding_dino(
prompt="coin",
image=img,
)
assert len(result) == 24
assert [res["label"] for res in result] == ["coin"] * 24
count = 0
while count < RETRIES:
try:
result = grounding_dino(
prompt="coin",
image=img,
)
assert len(result) == 24
assert [res["label"] for res in result] == ["coin"] * 24
break
except Exception as e:
count += 1
if count == RETRIES:
raise e


def test_grounding_dino_tiny():
img = ski.data.coins()
result = grounding_dino(prompt="coin", image=img, model_size="tiny")
assert len(result) == 24
assert [res["label"] for res in result] == ["coin"] * 24
count = 0
while count < RETRIES:
try:
result = grounding_dino(
prompt="coin",
image=img,
model_size="tiny",
)
assert len(result) == 24
assert [res["label"] for res in result] == ["coin"] * 24
break
except Exception as e:
count += 1
if count == RETRIES:
raise e


def test_owl():
img = ski.data.coins()
result = owl_v2(prompt="coin", image=img, box_threshold=0.15)
assert len(result) == 25
assert [res["label"] for res in result] == ["coin"] * 25
count = 0
while count < RETRIES:
try:
result = owl_v2(
prompt="coin",
image=img,
)
assert len(result) == 25
assert [res["label"] for res in result] == ["coin"] * 25
break
except Exception as e:
count += 1
if count == RETRIES:
raise e


def test_grounding_sam():
img = ski.data.coins()
result = grounding_sam(
prompt="coin",
image=img,
)
assert len(result) == 24
assert [res["label"] for res in result] == ["coin"] * 24
assert len([res["mask"] for res in result]) == 24
count = 0
while count < RETRIES:
try:
result = grounding_sam(
prompt="coin",
image=img,
)
assert len(result) == 24
assert [res["label"] for res in result] == ["coin"] * 24
assert len([res["mask"] for res in result]) == 24
break
except Exception as e:
count += 1
if count == RETRIES:
raise e


def test_clip():
img = ski.data.coins()
result = clip(
classes=["coins", "notes"],
image=img,
)
assert result["scores"] == [0.9999, 0.0001]
count = 0
while count < RETRIES:
try:
result = clip(
classes=["coins", "notes"],
image=img,
)
assert result["scores"] == [0.9999, 0.0001]
break
except Exception as e:
count += 1
if count == RETRIES:
raise e


def test_vit_classification():
img = ski.data.coins()
result = vit_image_classification(
image=img,
)
assert "typewriter keyboard" in result["labels"]
count = 0
while count < RETRIES:
try:
result = vit_image_classification(
image=img,
)
assert "typewriter keyboard" in result["labels"]
break
except Exception as e:
count += 1
if count == RETRIES:
raise e


def test_nsfw_classification():
img = ski.data.coins()
result = vit_nsfw_classification(
image=img,
)
assert result["labels"] == "normal"
count = 0
while count < RETRIES:
try:
result = vit_nsfw_classification(
image=img,
)
assert result["labels"] == "normal"
break
except Exception as e:
count += 1
if count == RETRIES:
raise e


def test_image_caption() -> None:
img = ski.data.rocket()
result = blip_image_caption(
image=img,
)
assert result.strip() == "a rocket on a stand"
count = 0
while count < RETRIES:
try:
result = blip_image_caption(
image=img,
)
assert result.strip() == "a rocket on a stand"
break
except Exception as e:
count += 1
if count == RETRIES:
raise e


def test_loca_zero_shot_counting() -> None:
img = ski.data.coins()
result = loca_zero_shot_counting(
image=img,
)
assert result["count"] == 21
count = 0
while count < RETRIES:
try:
result = loca_zero_shot_counting(
image=img,
)
assert result["count"] == 21
break
except Exception as e:
count += 1
if count == RETRIES:
raise e


def test_loca_visual_prompt_counting() -> None:
img = ski.data.coins()
result = loca_visual_prompt_counting(
visual_prompt={"bbox": [85, 106, 122, 145]},
image=img,
)
assert result["count"] == 25
count = 0
while count < RETRIES:
try:
result = loca_visual_prompt_counting(
visual_prompt={"bbox": [85, 106, 122, 145]},
image=img,
)
assert result["count"] == 25
break
except Exception as e:
count += 1
if count == RETRIES:
raise e


def test_git_vqa_v2() -> None:
img = ski.data.rocket()
result = git_vqa_v2(
prompt="Is the scene captured during day or night ?",
image=img,
)
assert result.strip() == "night"
count = 0
while count < RETRIES:
try:
result = git_vqa_v2(
prompt="Is the scene captured during day or night ?",
image=img,
)
assert result.strip() == "night"
break
except Exception as e:
count += 1
if count == RETRIES:
raise e


def test_ocr() -> None:
img = ski.data.page()
result = ocr(
image=img,
)
assert any("Region-based segmentation" in res["label"] for res in result)
count = 0
while count < RETRIES:
try:
result = ocr(
image=img,
)
assert any("Region-based segmentation" in res["label"] for res in result)
break
except Exception as e:
count += 1
if count == RETRIES:
raise e


def test_mask_distance():
Expand Down
2 changes: 1 addition & 1 deletion vision_agent/lmm/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .lmm import LMM, AzureOpenAILMM, Message, OpenAILMM
from .lmm import LMM, AzureOpenAILMM, Message, OllamaLMM, OpenAILMM
23 changes: 9 additions & 14 deletions vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import base64
import json
import requests
import logging
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union, cast

import requests
from openai import AzureOpenAI, OpenAI

import vision_agent.tools as T
Expand Down Expand Up @@ -292,7 +292,7 @@ def __call__(
if isinstance(input, str):
return self.generate(input)
return self.chat(input)

def chat(
self,
chat: List[Message],
Expand All @@ -315,18 +315,14 @@ def chat(
url = f"{self.url}/chat"
model = self.model_name
messages = fixed_chat
data = {
"model": model,
"messages": messages,
"stream": self.stream
}
data = {"model": model, "messages": messages, "stream": self.stream}
json_data = json.dumps(data)
response = requests.post(url, data=json_data)
if response.status_code != 200:
raise ValueError(f"Request failed with status code {response.status_code}")
response = response.json()
return response["message"]["content"]
return response["message"]["content"] # type: ignore

def generate(
self,
prompt: str,
Expand All @@ -338,19 +334,18 @@ def generate(
"model": self.model_name,
"prompt": prompt,
"images": [],
"stream": self.stream
"stream": self.stream,
}

json_data = json.dumps(data)
if media and len(media) > 0:
for m in media:
data["images"].append(encode_image(m))
data["images"].append(encode_image(m)) # type: ignore

response = requests.post(url, data=json_data)

if response.status_code != 200:
raise ValueError(f"Request failed with status code {response.status_code}")

response = response.json()
return response["response"]

return response["response"] # type: ignore

0 comments on commit ea81dab

Please sign in to comment.