Skip to content

Commit 48643b3

Browse files
author
Kye
committed
tests for yi, stable diffusion, timm models, etc
Former-commit-id: dfea671
1 parent 59f3b4c commit 48643b3

File tree

10 files changed

+1024
-124
lines changed

10 files changed

+1024
-124
lines changed

swarms/models/autotemp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import re
22
from concurrent.futures import ThreadPoolExecutor, as_completed
3-
from swarms.models.auto_temp import OpenAIChat
3+
from swarms.models.openai_models import OpenAIChat
44

55

66
class AutoTempAgent:

swarms/models/simple_ada.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
client = OpenAI()
44

5+
56
def get_ada_embeddings(text: str, model: str = "text-embedding-ada-002"):
67
"""
78
Simple function to get embeddings from ada

tests/models/bioclip.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
# Import necessary modules and define fixtures if needed
2+
import os
3+
import pytest
4+
import torch
5+
from PIL import Image
6+
from swarms.models.bioclip import BioClip
7+
8+
9+
# Define fixtures if needed
10+
@pytest.fixture
def sample_image_path(tmp_path):
    """Create a small RGB JPEG on disk and return its path as a string.

    The image is written into pytest's per-test temporary directory so that
    tests opening the path (for example via ``Image.open``) find a real,
    decodable file instead of a placeholder path that does not exist.
    """
    image_path = tmp_path / "sample_image.jpg"
    # A uniform grey 224x224 image is enough to exercise loading/preprocessing.
    Image.new("RGB", (224, 224), color=(128, 128, 128)).save(image_path)
    return str(image_path)
13+
14+
15+
@pytest.fixture(scope="module")
def clip_instance():
    """Build one BioClip wrapper and share it across the tests in this module.

    Module scope avoids re-instantiating the model for every single test;
    none of the tests in this module mutate the instance.
    """
    return BioClip("microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224")
18+
19+
20+
# Basic tests for the BioClip class
21+
def test_clip_initialization(clip_instance):
    """The wrapper should expose a torch model plus its helper attributes."""
    assert isinstance(clip_instance.model, torch.nn.Module)

    expected_attributes = (
        "model_path",
        "preprocess_train",
        "preprocess_val",
        "tokenizer",
        "device",
    )
    for attribute in expected_attributes:
        assert hasattr(clip_instance, attribute)
28+
29+
30+
def test_clip_call_method(clip_instance, sample_image_path):
    """Calling the wrapper should return a dict with one entry per label."""
    candidate_labels = [
        "adenocarcinoma histopathology",
        "brain MRI",
        "covid line chart",
        "squamous cell carcinoma histopathology",
        "immunohistochemistry histopathology",
        "bone X-ray",
        "chest X-ray",
        "pie chart",
        "hematoxylin and eosin histopathology",
    ]

    output = clip_instance(sample_image_path, candidate_labels)

    assert isinstance(output, dict)
    assert len(output) == len(candidate_labels)
45+
46+
47+
def test_clip_plot_image_with_metadata(clip_instance, sample_image_path):
    """Smoke test: plotting with a metadata dict should not raise."""
    top_probs = {"label1": 0.75, "label2": 0.65}
    metadata = {"filename": "sample_image.jpg", "top_probs": top_probs}
    clip_instance.plot_image_with_metadata(sample_image_path, metadata)
53+
54+
55+
# More test cases can be added to cover additional functionality and edge cases
56+
57+
58+
# Parameterized tests for different image and label combinations
59+
@pytest.mark.parametrize(
    "image_path, labels",
    [
        ("image1.jpg", ["label1", "label2"]),
        ("image2.jpg", ["label3", "label4"]),
        # Add more image and label combinations
    ],
)
def test_clip_parameterized_calls(clip_instance, image_path, labels, tmp_path):
    """Call the wrapper for several image/label combinations.

    ``image_path`` is only a file name; the image itself is generated inside
    pytest's temporary directory so the test never depends on files that
    happen to exist in the working directory.
    """
    image_file = tmp_path / image_path
    Image.new("RGB", (224, 224)).save(image_file)

    result = clip_instance(str(image_file), labels)

    assert isinstance(result, dict)
    assert len(result) == len(labels)
71+
72+
73+
# Test image preprocessing
74+
def test_clip_image_preprocessing(clip_instance, sample_image_path):
    """The validation preprocessing transform should return a torch tensor."""
    # Open the image in a context manager so the file handle is released even
    # if preprocessing raises.
    with Image.open(sample_image_path) as image:
        processed_image = clip_instance.preprocess_val(image)
    assert isinstance(processed_image, torch.Tensor)
78+
79+
80+
# Test label tokenization
81+
def test_clip_label_tokenization(clip_instance):
    """Tokenizing a list of labels should give a tensor with one row per label."""
    label_batch = ["label1", "label2"]
    tokens = clip_instance.tokenizer(label_batch)
    assert isinstance(tokens, torch.Tensor)
    assert tokens.shape[0] == len(label_batch)
86+
87+
88+
# More tests can be added to cover other methods and edge cases
89+
90+
91+
# End-to-end tests with actual images and labels
92+
def test_clip_end_to_end(clip_instance, sample_image_path):
    """Run the wrapper on the sample image against the full label set."""
    all_labels = [
        "adenocarcinoma histopathology",
        "brain MRI",
        "covid line chart",
        "squamous cell carcinoma histopathology",
        "immunohistochemistry histopathology",
        "bone X-ray",
        "chest X-ray",
        "pie chart",
        "hematoxylin and eosin histopathology",
    ]

    prediction = clip_instance(sample_image_path, all_labels)

    assert isinstance(prediction, dict)
    assert len(prediction) == len(all_labels)
107+
108+
109+
# Test label tokenization with a large number of labels
110+
def test_clip_long_labels(clip_instance):
    """Tokenize 100 generated labels and check one output row per label."""
    generated_labels = []
    for index in range(100):
        generated_labels.append("label" + str(index))

    tokens = clip_instance.tokenizer(generated_labels)

    assert isinstance(tokens, torch.Tensor)
    assert tokens.shape[0] == len(generated_labels)
115+
116+
117+
# Test handling of multiple image files
118+
def test_clip_multiple_images(clip_instance, sample_image_path, tmp_path):
    """Pass a list of image paths and expect one result dict per image.

    The second image is generated in pytest's temporary directory rather than
    referenced by a bare file name that does not exist on disk.
    """
    labels = ["label1", "label2"]

    second_image = tmp_path / "image2.jpg"
    Image.new("RGB", (224, 224)).save(second_image)
    image_paths = [sample_image_path, str(second_image)]

    results = clip_instance(image_paths, labels)

    assert isinstance(results, list)
    assert len(results) == len(image_paths)
    for result in results:
        assert isinstance(result, dict)
        assert len(result) == len(labels)
127+
128+
129+
# Test model inference performance
130+
def test_clip_inference_performance(clip_instance, sample_image_path, benchmark):
    """Time a call to the wrapper through the ``benchmark`` fixture."""
    benchmark_labels = [
        "adenocarcinoma histopathology",
        "brain MRI",
        "covid line chart",
        "squamous cell carcinoma histopathology",
        "immunohistochemistry histopathology",
        "bone X-ray",
        "chest X-ray",
        "pie chart",
        "hematoxylin and eosin histopathology",
    ]

    # ``benchmark`` invokes the callable with the given arguments and hands
    # back whatever the callable returned.
    timed_output = benchmark(clip_instance, sample_image_path, benchmark_labels)

    assert isinstance(timed_output, dict)
    assert len(timed_output) == len(benchmark_labels)
145+
146+
147+
# Test different preprocessing pipelines
148+
def test_clip_preprocessing_pipelines(clip_instance, sample_image_path):
    """Both the training and validation transforms should return tensors."""
    # Keep the file open only while the transforms read from it.
    with Image.open(sample_image_path) as image:
        # Test preprocessing for training
        processed_image_train = clip_instance.preprocess_train(image)
        # Test preprocessing for validation
        processed_image_val = clip_instance.preprocess_val(image)

    assert isinstance(processed_image_train, torch.Tensor)
    assert isinstance(processed_image_val, torch.Tensor)
159+
160+
161+
# ...

tests/models/distill_whisper.py

Lines changed: 114 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import os
22
import tempfile
33
from functools import wraps
4-
from unittest.mock import patch
4+
from unittest.mock import AsyncMock, MagicMock, patch
55

66
import numpy as np
77
import pytest
88
import torch
9+
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
910

10-
from swarms.models.distill_whisperx import DistilWhisperModel, async_retry
11+
from swarms.models.distilled_whisperx import DistilWhisperModel, async_retry
1112

1213

1314
@pytest.fixture
@@ -150,5 +151,114 @@ def test_create_audio_file():
150151
os.remove(audio_file_path)
151152

152153

153-
if __name__ == "__main__":
154-
pytest.main()
154+
# test_distilled_whisperx.py
155+
156+
157+
# Fixtures for setting up model, processor, and audio files
158+
@pytest.fixture(scope="module")
def model_id():
    """Model identifier string passed to ``DistilWhisperModel`` in this module."""
    identifier = "distil-whisper/distil-large-v2"
    return identifier
161+
162+
163+
@pytest.fixture(scope="module")
def whisper_model(model_id):
    """Module-wide ``DistilWhisperModel`` built from the ``model_id`` fixture."""
    wrapper = DistilWhisperModel(model_id)
    return wrapper
166+
167+
168+
@pytest.fixture(scope="session")
def audio_file_path(tmp_path_factory):
    """Write a short silent audio file and return its path as a string.

    The file is one second of 16 kHz, mono, 16-bit PCM silence stored as WAV,
    because the standard library can write WAV but not MP3. It lives in a
    session-scoped temporary directory provided by ``tmp_path_factory``.

    NOTE(review): assumes the code under test can decode WAV input — confirm.
    """
    import wave  # local import: only this fixture needs it

    sample_rate = 16000
    audio_path = tmp_path_factory.mktemp("audio") / "valid_audio.wav"
    with wave.open(str(audio_path), "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)  # 2 bytes per sample -> 16-bit PCM
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(b"\x00\x00" * sample_rate)
    return str(audio_path)
173+
174+
175+
@pytest.fixture(scope="session")
def invalid_audio_file_path():
    """Path string used by the failure-path tests as an invalid audio input."""
    bogus_path = "path/to/invalid_audio.mp3"
    return bogus_path
178+
179+
180+
@pytest.fixture(scope="session")
def audio_dict():
    """In-memory audio input: random samples plus their sampling rate."""
    sampling_rate = 16000
    # NOTE(review): the samples are a 2-D torch tensor of shape (1, 16000);
    # confirm that ``transcribe`` accepts this rather than a 1-D array.
    samples = torch.randn(1, 16000)
    return {"array": samples, "sampling_rate": sampling_rate}
184+
185+
186+
# Test initialization
187+
def test_initialization(whisper_model):
    """After construction both ``model`` and ``processor`` must be set."""
    model, processor = whisper_model.model, whisper_model.processor
    assert model is not None
    assert processor is not None
190+
191+
192+
# Test successful transcription with file path
193+
def test_transcribe_with_file_path(whisper_model, audio_file_path):
    """``transcribe`` given a file path should return a string."""
    text = whisper_model.transcribe(audio_file_path)
    assert isinstance(text, str)
196+
197+
198+
# Test successful transcription with audio dict
199+
def test_transcribe_with_audio_dict(whisper_model, audio_dict):
    """``transcribe`` given an audio dictionary should return a string."""
    text = whisper_model.transcribe(audio_dict)
    assert isinstance(text, str)
202+
203+
204+
# Test for file not found error
205+
def test_file_not_found(whisper_model, invalid_audio_file_path):
    """Transcribing the invalid path is expected to raise an exception."""
    missing_path = invalid_audio_file_path
    with pytest.raises(Exception):
        whisper_model.transcribe(missing_path)
208+
209+
210+
# Asynchronous tests
211+
@pytest.mark.asyncio
async def test_async_transcription_success(whisper_model, audio_file_path):
    """Awaiting ``async_transcribe`` on the audio file should yield a string."""
    text = await whisper_model.async_transcribe(audio_file_path)
    assert isinstance(text, str)
215+
216+
217+
@pytest.mark.asyncio
async def test_async_transcription_failure(whisper_model, invalid_audio_file_path):
    """Awaiting ``async_transcribe`` on the invalid path is expected to raise."""
    missing_path = invalid_audio_file_path
    with pytest.raises(Exception):
        await whisper_model.async_transcribe(missing_path)
221+
222+
223+
# Testing real-time transcription simulation
224+
def test_real_time_transcription(whisper_model, audio_file_path, capsys):
    """``real_time_transcribe`` should print its start banner to stdout."""
    whisper_model.real_time_transcribe(audio_file_path, chunk_duration=1)
    stdout_text = capsys.readouterr().out
    assert "Starting real-time transcription..." in stdout_text
228+
229+
230+
# Testing retry decorator for asynchronous function
231+
@pytest.mark.asyncio
async def test_async_retry():
    """A coroutine that always fails should still raise through ``async_retry``."""
    retry_decorator = async_retry(max_retries=2, exceptions=(ValueError,), delay=0)

    @retry_decorator
    async def always_fails():
        raise ValueError("Test")

    with pytest.raises(ValueError):
        await always_fails()
239+
240+
241+
# Mocking the actual model to avoid GPU/CPU intensive operations during test
242+
@pytest.fixture
def mocked_model(monkeypatch):
    """Replace the Hugging Face ``from_pretrained`` loaders with mocks.

    Returns:
        A ``(model_mock, processor_mock)`` tuple. Each mock stands in for the
        corresponding ``from_pretrained`` callable, so ``mock.return_value``
        is the object the code under test receives as its model / processor.
    """
    # ``from_pretrained`` is called synchronously. An ``AsyncMock`` would hand
    # back a coroutine when called, so the configured ``return_value`` would
    # never reach the code under test; a plain ``MagicMock`` returns it directly.
    model_mock = MagicMock(AutoModelForSpeechSeq2Seq)
    processor_mock = MagicMock(AutoProcessor)
    monkeypatch.setattr(
        "swarms.models.distilled_whisperx.AutoModelForSpeechSeq2Seq.from_pretrained",
        model_mock,
    )
    monkeypatch.setattr(
        "swarms.models.distilled_whisperx.AutoProcessor.from_pretrained",
        processor_mock,
    )
    return model_mock, processor_mock
254+
255+
256+
@pytest.mark.asyncio
async def test_async_transcribe_with_mocked_model(mocked_model, audio_file_path):
    """With mocked loaders, ``async_transcribe`` should return the mocked text."""
    fake_model_loader, fake_processor_loader = mocked_model

    # Configure what the objects produced by the mocked loaders return.
    fake_model_loader.return_value.generate.return_value = torch.tensor([[0]])
    fake_processor_loader.return_value.batch_decode.return_value = [
        "mocked transcription"
    ]

    wrapper = DistilWhisperModel()
    text = await wrapper.async_transcribe(audio_file_path)

    assert text == "mocked transcription"

0 commit comments

Comments
 (0)