# Review preamble. It sits before the first "diff --git" header, where patch
# tools skip leading text, so it does not alter the patch itself.
#
# What this patch does:
#   * tests/test_tools.py: adds RETRIES = 3 and wraps the body of each of the
#     twelve tests shown in a loop that re-runs the tool call and its
#     assertions, re-raising the last exception once RETRIES attempts have
#     failed. `except Exception` also catches AssertionError, so wrong
#     results are retried as well as raised errors. test_owl no longer
#     passes box_threshold=0.15 but still expects 25 detections.
#   * vision_agent/lmm/__init__.py: additionally re-exports OllamaLMM.
#   * vision_agent/lmm/lmm.py: moves `import requests` next to the other
#     third-party import, collapses the `data` dict in chat(), adds a
#     trailing comma and `# type: ignore` markers, replaces whitespace-only
#     lines with empty ones and adds the missing newline at end of file.
#
# NOTE(review): in generate(), the context lines show
# `json_data = json.dumps(data)` running BEFORE the loop that appends to
# data["images"], and requests.post() sends json_data. As far as the visible
# context shows, appended images never reach the request body. This patch
# does not change that; it needs a follow-up fix.
#
# NOTE(review): this file was recovered from a copy whose line breaks had
# been collapsed. Indentation and blank context lines were rebuilt to match
# the hunk header counts. The exact position of the three blank context
# lines in the last hunk and the width of the whitespace-only removed lines
# are assumptions -- confirm against the original commit before applying.
diff --git a/tests/test_tools.py b/tests/test_tools.py
index ed7670cb..22dcdd7d 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -16,107 +16,212 @@
     owl_v2,
 )
 
+RETRIES = 3
+
 
 def test_grounding_dino():
     img = ski.data.coins()
-    result = grounding_dino(
-        prompt="coin",
-        image=img,
-    )
-    assert len(result) == 24
-    assert [res["label"] for res in result] == ["coin"] * 24
+    count = 0
+    while count < RETRIES:
+        try:
+            result = grounding_dino(
+                prompt="coin",
+                image=img,
+            )
+            assert len(result) == 24
+            assert [res["label"] for res in result] == ["coin"] * 24
+            break
+        except Exception as e:
+            count += 1
+            if count == RETRIES:
+                raise e
 
 
 def test_grounding_dino_tiny():
     img = ski.data.coins()
-    result = grounding_dino(prompt="coin", image=img, model_size="tiny")
-    assert len(result) == 24
-    assert [res["label"] for res in result] == ["coin"] * 24
+    count = 0
+    while count < RETRIES:
+        try:
+            result = grounding_dino(
+                prompt="coin",
+                image=img,
+                model_size="tiny",
+            )
+            assert len(result) == 24
+            assert [res["label"] for res in result] == ["coin"] * 24
+            break
+        except Exception as e:
+            count += 1
+            if count == RETRIES:
+                raise e
 
 
 def test_owl():
     img = ski.data.coins()
-    result = owl_v2(prompt="coin", image=img, box_threshold=0.15)
-    assert len(result) == 25
-    assert [res["label"] for res in result] == ["coin"] * 25
+    count = 0
+    while count < RETRIES:
+        try:
+            result = owl_v2(
+                prompt="coin",
+                image=img,
+            )
+            assert len(result) == 25
+            assert [res["label"] for res in result] == ["coin"] * 25
+            break
+        except Exception as e:
+            count += 1
+            if count == RETRIES:
+                raise e
 
 
 def test_grounding_sam():
     img = ski.data.coins()
-    result = grounding_sam(
-        prompt="coin",
-        image=img,
-    )
-    assert len(result) == 24
-    assert [res["label"] for res in result] == ["coin"] * 24
-    assert len([res["mask"] for res in result]) == 24
+    count = 0
+    while count < RETRIES:
+        try:
+            result = grounding_sam(
+                prompt="coin",
+                image=img,
+            )
+            assert len(result) == 24
+            assert [res["label"] for res in result] == ["coin"] * 24
+            assert len([res["mask"] for res in result]) == 24
+            break
+        except Exception as e:
+            count += 1
+            if count == RETRIES:
+                raise e
 
 
 def test_clip():
     img = ski.data.coins()
-    result = clip(
-        classes=["coins", "notes"],
-        image=img,
-    )
-    assert result["scores"] == [0.9999, 0.0001]
+    count = 0
+    while count < RETRIES:
+        try:
+            result = clip(
+                classes=["coins", "notes"],
+                image=img,
+            )
+            assert result["scores"] == [0.9999, 0.0001]
+            break
+        except Exception as e:
+            count += 1
+            if count == RETRIES:
+                raise e
 
 
 def test_vit_classification():
     img = ski.data.coins()
-    result = vit_image_classification(
-        image=img,
-    )
-    assert "typewriter keyboard" in result["labels"]
+    count = 0
+    while count < RETRIES:
+        try:
+            result = vit_image_classification(
+                image=img,
+            )
+            assert "typewriter keyboard" in result["labels"]
+            break
+        except Exception as e:
+            count += 1
+            if count == RETRIES:
+                raise e
 
 
 def test_nsfw_classification():
     img = ski.data.coins()
-    result = vit_nsfw_classification(
-        image=img,
-    )
-    assert result["labels"] == "normal"
+    count = 0
+    while count < RETRIES:
+        try:
+            result = vit_nsfw_classification(
+                image=img,
+            )
+            assert result["labels"] == "normal"
+            break
+        except Exception as e:
+            count += 1
+            if count == RETRIES:
+                raise e
 
 
 def test_image_caption() -> None:
     img = ski.data.rocket()
-    result = blip_image_caption(
-        image=img,
-    )
-    assert result.strip() == "a rocket on a stand"
+    count = 0
+    while count < RETRIES:
+        try:
+            result = blip_image_caption(
+                image=img,
+            )
+            assert result.strip() == "a rocket on a stand"
+            break
+        except Exception as e:
+            count += 1
+            if count == RETRIES:
+                raise e
 
 
 def test_loca_zero_shot_counting() -> None:
     img = ski.data.coins()
-    result = loca_zero_shot_counting(
-        image=img,
-    )
-    assert result["count"] == 21
+    count = 0
+    while count < RETRIES:
+        try:
+            result = loca_zero_shot_counting(
+                image=img,
+            )
+            assert result["count"] == 21
+            break
+        except Exception as e:
+            count += 1
+            if count == RETRIES:
+                raise e
 
 
 def test_loca_visual_prompt_counting() -> None:
     img = ski.data.coins()
-    result = loca_visual_prompt_counting(
-        visual_prompt={"bbox": [85, 106, 122, 145]},
-        image=img,
-    )
-    assert result["count"] == 25
+    count = 0
+    while count < RETRIES:
+        try:
+            result = loca_visual_prompt_counting(
+                visual_prompt={"bbox": [85, 106, 122, 145]},
+                image=img,
+            )
+            assert result["count"] == 25
+            break
+        except Exception as e:
+            count += 1
+            if count == RETRIES:
+                raise e
 
 
 def test_git_vqa_v2() -> None:
     img = ski.data.rocket()
-    result = git_vqa_v2(
-        prompt="Is the scene captured during day or night ?",
-        image=img,
-    )
-    assert result.strip() == "night"
+    count = 0
+    while count < RETRIES:
+        try:
+            result = git_vqa_v2(
+                prompt="Is the scene captured during day or night ?",
+                image=img,
+            )
+            assert result.strip() == "night"
+            break
+        except Exception as e:
+            count += 1
+            if count == RETRIES:
+                raise e
 
 
 def test_ocr() -> None:
     img = ski.data.page()
-    result = ocr(
-        image=img,
-    )
-    assert any("Region-based segmentation" in res["label"] for res in result)
+    count = 0
+    while count < RETRIES:
+        try:
+            result = ocr(
+                image=img,
+            )
+            assert any("Region-based segmentation" in res["label"] for res in result)
+            break
+        except Exception as e:
+            count += 1
+            if count == RETRIES:
+                raise e
 
 
 def test_mask_distance():
diff --git a/vision_agent/lmm/__init__.py b/vision_agent/lmm/__init__.py
index 6d41888c..664e7e4a 100644
--- a/vision_agent/lmm/__init__.py
+++ b/vision_agent/lmm/__init__.py
@@ -1 +1 @@
-from .lmm import LMM, AzureOpenAILMM, Message, OpenAILMM
+from .lmm import LMM, AzureOpenAILMM, Message, OllamaLMM, OpenAILMM
diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py
index 9ca5c581..71c632dd 100644
--- a/vision_agent/lmm/lmm.py
+++ b/vision_agent/lmm/lmm.py
@@ -1,12 +1,12 @@
 import base64
 import json
-import requests
 import logging
 import os
 from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Union, cast
 
+import requests
 from openai import AzureOpenAI, OpenAI
 
 import vision_agent.tools as T
@@ -292,7 +292,7 @@ def __call__(
         if isinstance(input, str):
             return self.generate(input)
         return self.chat(input)
-    
+
     def chat(
         self,
         chat: List[Message],
@@ -315,18 +315,14 @@ def chat(
         url = f"{self.url}/chat"
         model = self.model_name
         messages = fixed_chat
-        data = {
-            "model": model,
-            "messages": messages,
-            "stream": self.stream
-        }
+        data = {"model": model, "messages": messages, "stream": self.stream}
         json_data = json.dumps(data)
         response = requests.post(url, data=json_data)
         if response.status_code != 200:
             raise ValueError(f"Request failed with status code {response.status_code}")
         response = response.json()
-        return response["message"]["content"]
-        
+        return response["message"]["content"]  # type: ignore
+
     def generate(
         self,
         prompt: str,
@@ -338,19 +334,18 @@ def generate(
             "model": self.model_name,
             "prompt": prompt,
             "images": [],
-            "stream": self.stream
+            "stream": self.stream,
         }
 
         json_data = json.dumps(data)
         if media and len(media) > 0:
             for m in media:
-                data["images"].append(encode_image(m))
+                data["images"].append(encode_image(m))  # type: ignore
 
         response = requests.post(url, data=json_data)
 
         if response.status_code != 200:
             raise ValueError(f"Request failed with status code {response.status_code}")
-        
+
         response = response.json()
-        return response["response"]
-    
\ No newline at end of file
+        return response["response"]  # type: ignore