Skip to content

Commit 43a7c88

Browse files
committed
Refactor
1 parent 3a2950b commit 43a7c88

30 files changed

+257
-197
lines changed

README.md

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@ pip install aisploit
2121
```python
2222
from typing import Any
2323
import textwrap
24-
from aisploit.core import BaseCallbackHandler, BasePromptValue, Score
24+
from aisploit.core import BaseCallbackHandler, BasePromptValue, Score, Response
2525
from aisploit.model import ChatOpenAI
2626
from aisploit.redteam import RedTeamJob, RedTeamClassifierTask
2727
from aisploit.target import target
@@ -36,14 +36,18 @@ def play_game(level: GandalfLevel, max_attempt=5) -> None:
3636
gandalf_scorer = GandalfScorer(level=level, chat_model=chat_model)
3737

3838
class GandalfHandler(BaseCallbackHandler):
39-
def on_redteam_attempt_start(self, attempt: int, prompt: BasePromptValue, **kwargs: Any):
39+
def on_redteam_attempt_start(
40+
self, attempt: int, prompt: BasePromptValue, **kwargs: Any
41+
):
4042
print(f"Attempt #{attempt}")
4143
print("Sending the following to Gandalf:")
4244
print(f"{prompt.to_string()}\n")
4345

44-
def on_redteam_attempt_end(self, attempt: int, response: str, score: Score, **kwargs: Any):
46+
def on_redteam_attempt_end(
47+
self, attempt: int, response: Response, score: Score, **kwargs: Any
48+
):
4549
print("Response from Gandalf:")
46-
print(f"{response}\n")
50+
print(f"{response.content}\n")
4751

4852
task = RedTeamClassifierTask(
4953
objective=textwrap.dedent(
@@ -58,7 +62,7 @@ def play_game(level: GandalfLevel, max_attempt=5) -> None:
5862
),
5963
classifier=gandalf_scorer,
6064
)
61-
65+
6266
@target
6367
def send_prompt(prompt: str):
6468
return gandalf_bot.invoke(prompt)

aisploit/classifier/huggingface/pipeline_prompt_injection_identifier.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -4,10 +4,10 @@
44
pipeline,
55
)
66

7-
from ...core import BaseClassifier, Score
7+
from ...core import BaseTextClassifier, Score
88

99

10-
class PipelinePromptInjectionIdentifier(BaseClassifier[float]):
10+
class PipelinePromptInjectionIdentifier(BaseTextClassifier[float]):
1111
def __init__(
1212
self,
1313
*,
@@ -29,8 +29,8 @@ def __init__(
2929
self._injection_label = injection_label
3030
self._threshold = threshold
3131

32-
def score_text(self, text: str) -> Score[float]:
33-
result = self._model(text)
32+
def score(self, input: str) -> Score[float]:
33+
result = self._model(input)
3434

3535
score = (
3636
result[0]["score"]

aisploit/classifier/openai/moderation.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -3,10 +3,10 @@
33
from openai import OpenAI
44
from openai.types.moderation import Moderation
55

6-
from ...core import BaseClassifier, Score
6+
from ...core import BaseTextClassifier, Score
77

88

9-
class ModerationClassifier(BaseClassifier[Moderation]):
9+
class ModerationClassifier(BaseTextClassifier[Moderation]):
1010
"""A classifier that uses the OpenAI Moderations API for scoring."""
1111

1212
def __init__(
@@ -19,14 +19,14 @@ def __init__(
1919

2020
self._client = OpenAI(api_key=api_key)
2121

22-
def score_text(self, text: str) -> Score[Moderation]:
23-
"""Score the text using the OpenAI Moderations API."""
24-
response = self._client.moderations.create(input=text)
22+
def score(self, input: str) -> Score[Moderation]:
23+
"""Score the input using the OpenAI Moderations API."""
24+
response = self._client.moderations.create(input=input)
2525
output = response.results[0]
2626

2727
return Score[Moderation](
2828
flagged=output.flagged,
2929
value=output,
30-
description="Moderation score for the given text",
30+
description="Moderation score for the given input",
3131
explanation="Details about the moderation score",
3232
)

aisploit/classifier/text.py

Lines changed: 33 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,31 +1,56 @@
11
import re
2-
from ..core import BaseClassifier, Score
2+
from ..core import BaseTextClassifier, Score
33

44

5-
class RegexClassifier(BaseClassifier[bool]):
5+
class RegexClassifier(BaseTextClassifier[bool]):
6+
"""A text classifier based on regular expressions."""
7+
68
def __init__(self, *, pattern: re.Pattern, flag_matches=True) -> None:
9+
"""Initialize the RegexClassifier.
10+
11+
Args:
12+
pattern (re.Pattern): The regular expression pattern to match.
13+
flag_matches (bool, optional): Flag indicating whether matches should be flagged. Defaults to True.
14+
"""
715
self._pattern = pattern
816
self._flag_matches = flag_matches
917

10-
def score_text(self, text: str) -> Score[bool]:
11-
if re.search(self._pattern, text):
18+
def score(self, input: str) -> Score[bool]:
19+
"""Score the input based on the regular expression pattern.
20+
21+
Args:
22+
input (str): The input text to be scored.
23+
24+
Returns:
25+
Score[bool]: A Score object representing the result of scoring.
26+
"""
27+
if re.search(self._pattern, input):
1228
return Score[bool](
1329
flagged=True if self._flag_matches else False,
1430
value=True,
15-
description=f"Return True if the pattern {self._pattern.pattern} is found in the text",
16-
explanation=f"Found pattern {self._pattern.pattern} in text",
31+
description=f"Return True if the pattern {self._pattern.pattern} is found in the input",
32+
explanation=f"Found pattern {self._pattern.pattern} in input",
1733
)
1834

1935
return Score[bool](
2036
flagged=False if self._flag_matches else True,
2137
value=False,
22-
description=f"Return True if the pattern {self._pattern.pattern} is found in the text",
23-
explanation=f"Did not find pattern {self._pattern.pattern} in text",
38+
description=f"Return True if the pattern {self._pattern.pattern} is found in the input",
39+
explanation=f"Did not find pattern {self._pattern.pattern} in input",
2440
)
2541

2642

2743
class SubstringClassifier(RegexClassifier):
44+
"""A text classifier based on substring matching."""
45+
2846
def __init__(self, *, substring: str, ignore_case=True, flag_matches=True) -> None:
47+
"""Initialize the SubstringClassifier.
48+
49+
Args:
50+
substring (str): The substring to match.
51+
ignore_case (bool, optional): Flag indicating whether to ignore case when matching substrings. Defaults to True.
52+
flag_matches (bool, optional): Flag indicating whether matches should be flagged. Defaults to True.
53+
"""
2954
compiled_pattern = (
3055
re.compile(substring, re.IGNORECASE)
3156
if ignore_case

aisploit/core/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,20 +1,21 @@
11
from .callbacks import BaseCallbackHandler, Callbacks, CallbackManager
2-
from .classifier import BaseClassifier, Score
2+
from .classifier import BaseClassifier, BaseTextClassifier, Score
33
from .converter import BaseConverter, BaseChatModelConverter
44
from .dataset import BaseDataset, YamlDeserializable
55
from .generator import BaseGenerator
66
from .job import BaseJob
77
from .model import BaseLLM, BaseChatModel, BaseModel, BaseEmbeddings
88
from .prompt import BasePromptValue
99
from .report import BaseReport
10-
from .target import BaseTarget
10+
from .target import BaseTarget, Response
1111
from .vectorstore import BaseVectorStore
1212

1313
__all__ = [
1414
"BaseCallbackHandler",
1515
"Callbacks",
1616
"CallbackManager",
1717
"BaseClassifier",
18+
"BaseTextClassifier",
1819
"Score",
1920
"BaseConverter",
2021
"BaseChatModelConverter",
@@ -29,5 +30,6 @@
2930
"BasePromptValue",
3031
"BaseReport",
3132
"BaseTarget",
33+
"Response",
3234
"BaseVectorStore",
3335
]

aisploit/core/callbacks.py

Lines changed: 19 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22

33
from .prompt import BasePromptValue
44
from .classifier import Score
5+
from .target import Response
56

67

78
class BaseCallbackHandler:
@@ -20,13 +21,13 @@ def on_redteam_attempt_start(
2021
pass
2122

2223
def on_redteam_attempt_end(
23-
self, attempt: int, response: str, score: Score, *, run_id: str
24+
self, attempt: int, response: Response, score: Score, *, run_id: str
2425
):
2526
"""Called when a red team attempt ends.
2627
2728
Args:
2829
attempt (int): The attempt number.
29-
response (str): The response from the attempt.
30+
response (Response): The response from the attempt.
3031
score (Score): The score of the attempt.
3132
run_id (str): The ID of the current run.
3233
"""
@@ -50,6 +51,12 @@ def on_scanner_plugin_end(self, name: str, *, run_id: str):
5051
"""
5152
pass
5253

54+
def on_sender_send_prompt_start(self):
55+
pass
56+
57+
def on_sender_send_prompt_end(self):
58+
pass
59+
5360

5461
Callbacks = Sequence[BaseCallbackHandler]
5562

@@ -84,12 +91,12 @@ def on_redteam_attempt_start(self, attempt: int, prompt: BasePromptValue):
8491
attempt=attempt, prompt=prompt, run_id=self.run_id
8592
)
8693

87-
def on_redteam_attempt_end(self, attempt: int, response: str, score: Score):
94+
def on_redteam_attempt_end(self, attempt: int, response: Response, score: Score):
8895
"""Notify callback handlers when a red team attempt ends.
8996
9097
Args:
9198
attempt (int): The attempt number.
92-
response (str): The response from the attempt.
99+
response (Response): The response from the attempt.
93100
score (Score): The score of the attempt.
94101
"""
95102
for cb in self._callbacks:
@@ -114,3 +121,11 @@ def on_scanner_plugin_end(self, name: str):
114121
"""
115122
for cb in self._callbacks:
116123
cb.on_scanner_plugin_end(name=name, run_id=self.run_id)
124+
125+
def on_sender_send_prompt_start(self):
126+
for cb in self._callbacks:
127+
cb.on_sender_send_prompt_start()
128+
129+
def on_sender_send_prompt_end(self):
130+
for cb in self._callbacks:
131+
cb.on_sender_send_prompt_end()

aisploit/core/classifier.py

Lines changed: 20 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -2,31 +2,44 @@
22
from abc import ABC, abstractmethod
33
from dataclasses import dataclass
44

5-
65
T = TypeVar("T")
6+
Input = TypeVar("Input")
77

88

99
@dataclass(frozen=True)
1010
class Score(Generic[T]):
11-
"""A class representing a score."""
11+
"""A class representing a score.
12+
13+
Attributes:
14+
flagged (bool): Whether the score is flagged.
15+
value (T): The value of the score.
16+
description (str): Optional description of the score.
17+
explanation (str): Optional explanation of the score.
18+
"""
1219

1320
flagged: bool
1421
value: T
1522
description: str = ""
1623
explanation: str = ""
1724

1825

19-
class BaseClassifier(ABC, Generic[T]):
26+
class BaseClassifier(ABC, Generic[T, Input]):
2027
"""An abstract base class for classifiers."""
2128

2229
@abstractmethod
23-
def score_text(self, text: str) -> Score[T]:
24-
"""Score the text and return a Score object.
30+
def score(self, input: Input) -> Score[T]:
31+
"""Score the input and return a Score object.
2532
2633
Args:
27-
text (str): The text to be scored.
34+
input (Input): The input to be scored.
2835
2936
Returns:
30-
Score[T]: A Score object representing the score of the text.
37+
Score[T]: A Score object representing the score of the input.
3138
"""
3239
pass
40+
41+
42+
class BaseTextClassifier(BaseClassifier[T, str], Generic[T]):
43+
"""An abstract base class for text classifiers."""
44+
45+
pass

aisploit/core/target.py

Lines changed: 27 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,35 @@
11
from abc import ABC, abstractmethod
2+
from dataclasses import dataclass
23

34
from .prompt import BasePromptValue
45

56

7+
@dataclass(frozen=True)
8+
class Response:
9+
"""A class representing a response from a target.
10+
11+
Attributes:
12+
content (str): The content of the response.
13+
"""
14+
15+
content: str
16+
17+
def __repr__(self) -> str:
18+
"""Return a string representation of the Response."""
19+
return f"content={repr(self.content)}"
20+
21+
622
class BaseTarget(ABC):
23+
"""An abstract base class for targets."""
24+
725
@abstractmethod
8-
def send_prompt(self, prompt: BasePromptValue) -> str:
26+
def send_prompt(self, prompt: BasePromptValue) -> Response:
27+
"""Send a prompt to the target and return the response.
28+
29+
Args:
30+
prompt (BasePromptValue): The prompt to send.
31+
32+
Returns:
33+
Response: The response from the target.
34+
"""
935
pass

0 commit comments

Comments
 (0)