Skip to content

Commit d8224a9

Browse files
Merge pull request #220 from zeya30/za/progress-bar
update unit tests
2 parents 96e482e + 9554cff commit d8224a9

File tree

7 files changed

+231
-34
lines changed

7 files changed

+231
-34
lines changed

.github/workflows/ci.yaml

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,18 @@ jobs:
4040
with:
4141
persist-credentials: false
4242

43+
- name: Free Disk Space (Ubuntu)
44+
if: matrix.os == 'ubuntu-latest'
45+
uses: jlumbroso/free-disk-space@main
46+
with:
47+
tool-cache: false
48+
android: true
49+
dotnet: true
50+
haskell: true
51+
large-packages: true
52+
docker-images: true
53+
swap-storage: true
54+
4355
- name: Set up Python
4456
uses: actions/setup-python@v6.0.0
4557
with:
@@ -117,4 +129,4 @@ jobs:
117129
uses: github/codeql-action/upload-sarif@v3
118130
with:
119131
sarif_file: semgrep.sarif
120-
if: always()
132+
if: always()

examples/adversarial/adversarial_toxicity.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@
215215
" system_style=\"benign\",\n",
216216
" prompt_style=\"nontoxic\",\n",
217217
" sample_size=10, # 1000 is the recommended sample_size\n",
218-
" show_progress_bars=False\n",
218+
" show_progress_bars=False,\n",
219219
")"
220220
]
221221
},

langfair/generator/redteaming.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import pkgutil
2121
import random
2222
from typing import Any, Dict, List, Optional, Tuple, Union
23-
from rich.progress import Progress
2423

2524
from langfair.constants.cost_data import FAILURE_MESSAGE
2625
from langfair.generator import ResponseGenerator
@@ -110,7 +109,7 @@ async def counterfactual(
110109
111110
count : int, default=25
112111
Specifies number of responses to generate for each prompt.
113-
112+
114113
show_progress_bars : bool, default=True
115114
If True, displays progress bars while generating responses
116115
@@ -123,7 +122,6 @@ async def counterfactual(
123122
dataset = await self._generate_from_template(
124123
prompt_templates=prompt_templates, system_styles=system_styles, count=count
125124
)
126-
print("Responses successfully generated!")
127125
return self._format_result(
128126
dataset=dataset,
129127
prompt_templates=prompt_templates,
@@ -166,7 +164,7 @@ async def toxicity(
166164
167165
custom_system_prompt : str or None, default=None
168166
Optional argument for user to provide custom system prompt for toxicity generation.
169-
167+
170168
show_progress_bars : bool, default=True
171169
If True, displays progress bars while generating responses
172170
@@ -187,7 +185,10 @@ async def toxicity(
187185
else SYSTEM_PROMPT_DICT[system_style]
188186
)
189187
result = await self.generate_responses(
190-
prompts=prompts, system_prompt=system_prompt, count=count, show_progress_bars=show_progress_bars
188+
prompts=prompts,
189+
system_prompt=system_prompt,
190+
count=count,
191+
show_progress_bars=show_progress_bars,
191192
)
192193
responses = result["data"]["response"]
193194
duplicated_prompts = [
@@ -211,7 +212,7 @@ async def _generate_from_template(
211212
prompt_templates: Dict[str, List[str]],
212213
system_styles: List[str],
213214
count: int,
214-
show_progress_bars: bool = True,
215+
show_progress_bars: bool = True,
215216
) -> Dict[str, Any]:
216217
"""
217218
Used for generating responses from template-based prompt. This method is
@@ -230,7 +231,7 @@ async def _generate_from_template(
230231
prompts=prompt_templates["text"],
231232
system_prompt=system_prompt,
232233
count=count,
233-
show_progress_bars=show_progress_bars
234+
show_progress_bars=show_progress_bars,
234235
)
235236
dataset[system_style + "_response"] = tmp["data"]["response"]
236237
return dataset

tests/conftest.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Copyright 2025 CVS Health and/or one of its affiliates
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from types import SimpleNamespace
16+
17+
import pytest
18+
19+
20+
class FakeTask:
    """Minimal record mirroring the fields of a rich ``Task`` that tests read."""

    def __init__(self, task_id, description, total):
        # Attribute names intentionally match rich's Task surface.
        self.id = task_id
        self.description = description
        self.total = total
        # Freshly created tasks have done no work yet.
        self.completed = 0
26+
27+
28+
class FakeProgress:
    """
    Drop-in test double for ``rich.progress.Progress``.

    Emulates only the surface the code under test relies on:
    - ``add_task(description, total)`` -> integer task id
    - ``update(task_id, completed=...)``
    - ``tasks[task_id]`` -> FakeTask
    - ``start()`` / ``stop()`` toggling ``live.is_started``
    """

    def __init__(self):
        # Monotonically increasing ids handed out by add_task, starting at 0.
        self._next_id = 0
        self.tasks = {}
        # Only the is_started flag of rich's Live object is emulated.
        self.live = SimpleNamespace(is_started=False)

    def add_task(self, description, total):
        """Register a new task and return its id."""
        new_id, self._next_id = self._next_id, self._next_id + 1
        self.tasks[new_id] = FakeTask(new_id, description, total)
        return new_id

    def update(self, task_id, completed=None):
        """Set the completion counter of an existing task; no-op when None."""
        task = self.tasks[task_id]
        if completed is not None:
            task.completed = completed

    def start(self):
        """Mark the fake live display as running."""
        self.live.is_started = True

    def stop(self):
        """Mark the fake live display as stopped."""
        self.live.is_started = False
58+
59+
60+
@pytest.fixture(autouse=True)
def mock_display_progress(request, monkeypatch):
    """
    Mock progress helpers globally so tests never touch Rich's Live display.

    Tests marked ``@pytest.mark.real_progress`` opt out and keep the real
    rich-based helpers (the marker is registered in ``pytest_configure``).
    """
    # Fix: the real_progress marker was registered but never consulted, so
    # marked tests were still mocked. Honor the documented opt-out here.
    if request.node.get_closest_marker("real_progress") is not None:
        return

    import langfair.utils.display as display_module

    def _start_progress_bar(existing_progress_bar=None):
        # Reuse a caller-supplied fake; otherwise hand back a fresh, started one.
        if isinstance(existing_progress_bar, FakeProgress):
            existing_progress_bar.start()
            return existing_progress_bar
        fake = FakeProgress()
        fake.start()
        return fake

    def _stop_progress_bar(progress_bar):
        # Only stop objects this mock created; ignore anything else.
        if isinstance(progress_bar, FakeProgress):
            progress_bar.stop()

    monkeypatch.setattr(display_module, "start_progress_bar", _start_progress_bar)
    monkeypatch.setattr(display_module, "stop_progress_bar", _stop_progress_bar)
81+
82+
83+
def pytest_configure(config):
    """Register custom markers so ``pytest --strict-markers`` accepts them."""
    marker_spec = (
        "real_progress: Opt-out of the FakeProgress mock (use real rich.Progress)."
    )
    config.addinivalue_line("markers", marker_spec)

tests/test_counterfactual_metrics.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,26 +63,31 @@ def test_rougel():
6363
assert rougel.evaluate(data["text1"], data["text2"]) == actual_results["test3"]
6464

6565

66-
def test_senitement1():
66+
def test_sentiment1():
6767
sentiment = SentimentBias()
6868
assert sentiment.evaluate(data["text1"], data["text2"]) == actual_results["test4"]
6969

7070

71-
def test_senitement2():
71+
def test_sentiment2():
7272
sentiment = SentimentBias(parity="weak")
7373
assert sentiment.evaluate(data["text1"], data["text2"]) == pytest.approx(
7474
actual_results["test5"], rel=1e-02
7575
)
7676

7777

78-
def test_senitement3(monkeypatch):
79-
MOCKED_CLASSIFIER_RESULT = [
80-
actual_results["classifier_result1"],
81-
actual_results["classifier_result2"],
82-
]
78+
def test_sentiment3(monkeypatch):
79+
group1 = actual_results["classifier_result1"]
80+
group2 = actual_results["classifier_result2"]
8381

84-
def mock_get_classifier(*args, **kwargs):
85-
return MOCKED_CLASSIFIER_RESULT.pop()
82+
def mock_get_classifier(texts, return_all_scores=True):
83+
if texts in [[t] for t in data["text1"]]:
84+
idx = data["text1"].index(texts[0])
85+
return [group1[idx]]
86+
elif texts in [[t] for t in data["text2"]]:
87+
idx = data["text2"].index(texts[0])
88+
return [group2[idx]]
89+
else:
90+
return [[]]
8691

8792
sentiment = SentimentBias(classifier="roberta")
8893
monkeypatch.setattr(sentiment, "classifier_instance", mock_get_classifier)

tests/test_display.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Copyright 2025 CVS Health and/or one of its affiliates
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import time
16+
17+
import pytest
18+
19+
import langfair.utils.display as display_module
20+
from langfair.utils.display import (
21+
ConditionalBarColumn,
22+
ConditionalSpinnerColumn,
23+
ConditionalTextColumn,
24+
ConditionalTextPercentageColumn,
25+
ConditionalTimeElapsedColumn,
26+
)
27+
28+
29+
@pytest.fixture(autouse=True)
def fast_sleep(monkeypatch):
    """Neutralize time.sleep so progress-bar tests never actually wait."""

    def _no_sleep(x):
        return None

    monkeypatch.setattr(time, "sleep", _no_sleep)
32+
33+
34+
def test_start_progress_bar_without_existing():
    """A bar created from scratch starts live, tracks updates, and stops."""
    bar = display_module.start_progress_bar()
    assert bar.live.is_started

    tid = bar.add_task("[Task]Test", total=10)
    bar.update(tid, completed=5)
    assert bar.tasks[tid].completed == 5

    display_module.stop_progress_bar(bar)
    assert not bar.live.is_started
43+
44+
45+
def test_start_progress_bar_with_existing():
    """Passing an existing bar returns that same (started) object."""
    first = display_module.start_progress_bar()
    second = display_module.start_progress_bar(first)
    assert second is first
    assert second.live.is_started
    display_module.stop_progress_bar(second)
    assert not second.live.is_started
52+
53+
54+
def test_stop_progress_bar_stops():
    """stop_progress_bar flips the live flag off."""
    bar = display_module.start_progress_bar()
    display_module.stop_progress_bar(bar)
    assert not bar.live.is_started
58+
59+
60+
def test_task_creation_and_update():
    """add_task records description/total; update advances completed."""
    bar = display_module.start_progress_bar()
    tid = bar.add_task("[Task]Downloading", total=100)
    bar.update(tid, completed=40)

    recorded = bar.tasks[tid]
    assert recorded.total == 100
    assert recorded.completed == 40
    assert recorded.description == "[Task]Downloading"

    display_module.stop_progress_bar(bar)
69+
70+
71+
def test_conditional_columns_render_normal_task():
    """Columns render real content for tasks tagged with the [Task] prefix."""
    bar = display_module.start_progress_bar()
    tid = bar.add_task("[Task]Processing", total=80)
    bar.update(tid, completed=20)
    task = bar.tasks[tid]

    # Description prefixes drive the Conditional* columns' behavior.
    text_col = ConditionalTextColumn("[progress.description]{task.description}")
    pct_col = ConditionalTextPercentageColumn(
        "[progress.percentage]{task.completed}/{task.total}"
    )
    assert "[progress.description]Processing" in text_col.render(task)
    assert "[progress.percentage]20/80" in pct_col.render(task)

    display_module.stop_progress_bar(bar)
87+
88+
89+
def test_conditional_columns_render_no_progress_bar():
    """[No Progress Bar] tasks suppress every column except the description."""
    bar = display_module.start_progress_bar()
    tid = bar.add_task("[No Progress Bar]Hidden", total=50)
    bar.update(tid, completed=10)
    task = bar.tasks[tid]

    suppressed = [
        ConditionalBarColumn(),
        ConditionalTimeElapsedColumn(),
        ConditionalTextPercentageColumn(
            "[progress.percentage]{task.completed}/{task.total}"
        ),
        ConditionalSpinnerColumn(),
    ]
    for column in suppressed:
        assert column.render(task) == ""

    desc = ConditionalTextColumn("[progress.description]{task.description}")
    assert desc.render(task) == "[progress.description]Hidden"

tests/test_responsegenerator.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,7 @@
2424
async def test_generator(monkeypatch):
2525
count = 3
2626
MOCKED_PROMPTS = ["Prompt 1", "Prompt 2", "Prompt 3"]
27-
MOCKED_DUPLICATE_PROMPTS = [
28-
prompt for prompt, i in itertools.product(MOCKED_PROMPTS, range(count))
29-
]
27+
3028
MOCKED_RESPONSES = [
3129
"Mocked response 1",
3230
"Mocked response 2",
@@ -54,18 +52,5 @@ async def mock_async_api_call(prompt, count, *args, **kwargs):
5452
data = await generator_object.generate_responses(
5553
prompts=MOCKED_PROMPTS, count=count
5654
)
57-
58-
cost = await generator_object.estimate_token_cost(
59-
tiktoken_model_name="gpt-3.5-turbo-16k-0613", # gitleaks:allow
60-
prompts=MOCKED_DUPLICATE_PROMPTS,
61-
example_responses=MOCKED_RESPONSES[:3],
62-
count=count,
63-
)
64-
6555
assert data["data"]["response"] == MOCKED_DUPLICATED_RESPONSES
6656
assert data["metadata"]["non_completion_rate"] == 1 / 3
67-
assert cost == {
68-
"Estimated Prompt Token Cost (USD)": 0.001539,
69-
"Estimated Completion Token Cost (USD)": 0.000504,
70-
"Estimated Total Token Cost (USD)": 0.002043,
71-
}

0 commit comments

Comments
 (0)