Skip to content

Commit 4de1582

Browse files
committed
Misc
1 parent 17e5617 commit 4de1582

27 files changed

+1106
-69
lines changed

aisploit/classifiers/amazon/comprehend.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,34 @@
1010

1111
@dataclass
1212
class BaseComprehendClassifier(BaseTextClassifier[T], Generic[T]):
13+
"""An abstract base class for Comprehend classifiers."""
14+
1315
session: boto3.Session = field(default_factory=lambda: boto3.Session())
1416
region_name: str = "us-east-1"
1517

1618
def __post_init__(self):
19+
"""Initialize the Comprehend client."""
1720
self._client = self.session.client("comprehend", region_name=self.region_name)
1821

1922

2023
@dataclass
2124
class ComprehendPIIClassifier(BaseComprehendClassifier[List[Any]]):
25+
"""A classifier that uses Amazon Comprehend to detect personally identifiable information (PII)."""
26+
2227
language: str = "en"
2328
threshold: float = 0.7
29+
tags: List[str] = field(default_factory=lambda: ["leakage"], init=False)
30+
31+
def score(self, input: str, _: List[str] | None = None) -> Score[List[Any]]:
32+
"""Score the input for PII using Amazon Comprehend.
33+
34+
Args:
35+
input (str): The input text to be scored.
36+
_: List of reference inputs (ignored).
2437
25-
def score(self, input: str) -> Score[List[Any]]:
38+
Returns:
39+
Score[List[Any]]: A Score object representing the PII entities found in the input.
40+
"""
2641
response = self._client.detect_pii_entities(Text=input, LanguageCode=self.language)
2742

2843
entities = [entity for entity in response["Entities"] if entity["Score"] >= self.threshold]
@@ -39,10 +54,22 @@ def score(self, input: str) -> Score[List[Any]]:
3954

4055
@dataclass
4156
class ComprehendToxicityClassifier(BaseComprehendClassifier[Dict[str, Any]]):
57+
"""A classifier that uses Amazon Comprehend to detect toxicity in text."""
58+
4259
language: str = "en"
4360
threshold: float = 0.7
61+
tags: List[str] = field(default_factory=lambda: ["toxicity"], init=False)
62+
63+
def score(self, input: str, _: List[str] | None = None) -> Score[Dict[str, Any]]:
64+
"""Score the input for toxicity using Amazon Comprehend.
65+
66+
Args:
67+
input (str): The input text to be scored.
68+
_: List of reference inputs (ignored).
4469
45-
def score(self, input: str) -> Score[Dict[str, Any]]:
70+
Returns:
71+
Score[Dict[str, Any]]: A Score object representing the toxicity score of the input.
72+
"""
4673
response = self._client.detect_toxic_content(
4774
TextSegments=[
4875
{'Text': input},
Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1-
from .pipeline_prompt_injection_identifier import PipelinePromptInjectionIdentifier
1+
from .bert_score import BertScoreClassifier
2+
from .bleu import BleuClassifier
3+
from .pipeline_prompt_injection import PipelinePromptInjectionClassifier
24

3-
__all__ = ["PipelinePromptInjectionIdentifier"]
5+
__all__ = [
6+
"BertScoreClassifier",
7+
"BleuClassifier",
8+
"PipelinePromptInjectionClassifier",
9+
]
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from dataclasses import dataclass, field
2+
from typing import Any, Dict, List
3+
4+
import evaluate
5+
6+
from ...core import BaseTextClassifier, Score
7+
8+
9+
@dataclass
10+
class BertScoreClassifier(BaseTextClassifier[Dict[str, Any]]):
11+
"""A classifier that computes BERTScore for text inputs."""
12+
13+
threshold: float = 0.8
14+
model_type: str = "distilbert-base-uncased"
15+
bertscore: evaluate.EvaluationModule = field(default_factory=lambda: evaluate.load("bertscore"), init=False)
16+
17+
def score(self, input: str, references: List[str] | None = None) -> Score[Dict[str, Any]]:
18+
"""Score the input using BERTScore computed by the evaluate module.
19+
20+
Args:
21+
input (str): The input text to be scored.
22+
references (List[str], optional): List of reference texts. Defaults to None.
23+
24+
Raises:
25+
ValueError: If references is None or if the number of references is not equal to 1.
26+
27+
Returns:
28+
Score[Dict[str, Any]]: A Score object representing the BERTScore of the input.
29+
"""
30+
if not references or not len(references) == 1:
31+
raise ValueError("The number of references must be exactly 1.")
32+
33+
score = self.bertscore.compute(
34+
predictions=[input],
35+
references=[references[0]],
36+
model_type=self.model_type,
37+
)
38+
39+
f1_score = score["f1"][0]
40+
41+
return Score[Dict[str, Any]](
42+
flagged=f1_score < self.threshold,
43+
value=score,
44+
description="Returns True if the f1 score is less than the threshold",
45+
explanation=f"The f1 score for the input and reference is {f1_score}",
46+
)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from dataclasses import dataclass, field
2+
from typing import Any, Dict, List
3+
4+
import evaluate
5+
6+
from ...core import BaseTextClassifier, Score
7+
8+
9+
@dataclass
10+
class BleuClassifier(BaseTextClassifier[Dict[str, Any]]):
11+
"""A classifier that computes BLEU score for text inputs."""
12+
13+
threshold: float = 0.2
14+
bleu: evaluate.EvaluationModule = field(default_factory=lambda: evaluate.load("bleu"), init=False)
15+
16+
def score(self, input: str, references: List[str] | None = None) -> Score[Dict[str, Any]]:
17+
"""Score the input using BLEU score computed by the evaluate module.
18+
19+
Args:
20+
input (str): The input text to be scored.
21+
references (List[str], optional): List of reference texts. Defaults to None.
22+
23+
Raises:
24+
ValueError: If the number of references is not equal to 1.
25+
26+
Returns:
27+
Score[Dict[str, Any]]: A Score object representing the BLEU score of the input.
28+
"""
29+
if not references or not len(references) == 1:
30+
raise ValueError("The number of references must be exactly 1.")
31+
32+
score = self.bleu.compute(
33+
predictions=[input],
34+
references=[references[0]],
35+
max_order=2,
36+
)
37+
38+
bleu_score = score["bleu"]
39+
40+
return Score[Dict[str, Any]](
41+
flagged=bleu_score < self.threshold,
42+
value=score,
43+
description="Returns True if the bleu score is less than the threshold",
44+
explanation=f"The bleu score for the input and reference is {bleu_score}",
45+
)

aisploit/classifiers/huggingface/pipeline_prompt_injection_identifier.py renamed to aisploit/classifiers/huggingface/pipeline_prompt_injection.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import List
2+
13
from transformers import (
24
AutoModelForSequenceClassification,
35
AutoTokenizer,
@@ -7,7 +9,7 @@
79
from ...core import BaseTextClassifier, Score
810

911

10-
class PipelinePromptInjectionIdentifier(BaseTextClassifier[float]):
12+
class PipelinePromptInjectionClassifier(BaseTextClassifier[float]):
1113
def __init__(
1214
self,
1315
*,
@@ -29,7 +31,7 @@ def __init__(
2931
self._injection_label = injection_label
3032
self._threshold = threshold
3133

32-
def score(self, input: str) -> Score[float]:
34+
def score(self, input: str, references: List[str] | None = None) -> Score[float]:
3335
result = self._model(input)
3436

3537
score = result[0]["score"] if result[0]["label"] == self._injection_label else 1 - result[0]["score"]

aisploit/classifiers/markdown.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
class MarkdownInjectionClassifier(BaseTextClassifier[List[Any]]):
88
"""A text classifier to detect Markdown injection in input text."""
99

10-
def score(self, input: str) -> Score[List[Any]]:
10+
def score(self, input: str, references: List[str] | None = None) -> Score[List[Any]]:
1111
# !\[.*?\]\((.*?)\) - This is for the inline image format in Markdown, which is ![alt_text](url).
1212
# !\[.*?\]\[(.*?)\] - This is for the reference-style image format in Markdown, which is ![alt_text][image_reference].
1313
pattern = r"!\s*\[.*?\]\((.*?)\)|!\s*\[.*?\]\[(.*?)\]"

aisploit/classifiers/openai/moderation.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from typing import Optional
2+
from typing import List, Optional
33

44
from openai import OpenAI
55
from openai.types.moderation import Moderation
@@ -20,8 +20,16 @@ def __init__(
2020

2121
self._client = OpenAI(api_key=api_key)
2222

23-
def score(self, input: str) -> Score[Moderation]:
24-
"""Score the input using the OpenAI Moderations API."""
23+
def score(self, input: str, _: List[str] | None = None) -> Score[Moderation]:
24+
"""Score the input using the OpenAI Moderations API.
25+
26+
Args:
27+
input (str): The input text to be scored.
28+
_: List of references (ignored).
29+
30+
Returns:
31+
Score[Moderation]: A Score object representing the moderation score of the input.
32+
"""
2533
response = self._client.moderations.create(input=input)
2634
output = response.results[0]
2735

aisploit/classifiers/package_hallucination.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import re
2-
from dataclasses import dataclass
2+
from dataclasses import dataclass, field
33
from typing import List
44

55
import requests
@@ -15,11 +15,12 @@ class PythonPackageHallucinationClassifier(BaseTextClassifier[List[str]]):
1515
"""
1616

1717
python_version: str = "3.12"
18+
tags: List[str] = field(default_factory=lambda: ["hallucination"], init=False)
1819

1920
def __post_init__(self) -> None:
2021
self.libraries = stdlib_list(self.python_version)
2122

22-
def score(self, input: str) -> Score[List[str]]:
23+
def score(self, input: str, references: List[str] | None = None) -> Score[List[str]]:
2324
"""
2425
Scores the input based on the presence of hallucinated Python package names.
2526

aisploit/classifiers/presidio/presidio_analyser.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from dataclasses import dataclass
1+
from dataclasses import dataclass, field
22
from typing import List
33

44
from presidio_analyzer import AnalyzerEngine, RecognizerResult
@@ -8,15 +8,17 @@
88

99
@dataclass
1010
class PresidioAnalyserClassifier(BaseTextClassifier[List[RecognizerResult]]):
11+
1112
language: str = "en"
1213
entities: List[str] | None = None
1314
threshold: float = 0.7
15+
tags: List[str] = field(default_factory=lambda: ["leakage"], init=False)
1416

1517
def __post_init__(self) -> None:
1618
# Set up the engine, loads the NLP module (spaCy model by default) and other PII recognizers
1719
self._analyzer = AnalyzerEngine(default_score_threshold=self.threshold)
1820

19-
def score(self, input: str) -> Score[List[RecognizerResult]]:
21+
def score(self, input: str, references: List[str] | None = None) -> Score[List[RecognizerResult]]:
2022
# Call analyzer to get results
2123
results = self._analyzer.analyze(text=input, entities=self.entities, language=self.language)
2224

aisploit/classifiers/text.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import re
22
from dataclasses import dataclass
3+
from typing import List
34

45
from ..core import BaseTextClassifier, Score
56

@@ -17,7 +18,7 @@ def __init__(self, *, pattern: re.Pattern, flag_matches=True) -> None:
1718
self._pattern = pattern
1819
self._flag_matches = flag_matches
1920

20-
def score(self, input: str) -> Score[bool]:
21+
def score(self, input: str, references: List[str] | None = None) -> Score[bool]:
2122
"""Score the input based on the regular expression pattern.
2223
2324
Args:
@@ -61,7 +62,7 @@ def __init__(self, *, substring: str, ignore_case=True, flag_matches=True) -> No
6162
class TextTokenClassifier(BaseTextClassifier[bool]):
6263
token: str
6364

64-
def score(self, input: str) -> Score[bool]:
65+
def score(self, input: str, references: List[str] | None = None) -> Score[bool]:
6566
return Score[bool](
6667
flagged=self.token in input,
6768
value=self.token in input,

aisploit/core/classifier.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from abc import ABC, abstractmethod
2-
from dataclasses import dataclass
3-
from typing import Generic, TypeVar
2+
from dataclasses import dataclass, field
3+
from typing import Generic, List, TypeVar
44

55
T = TypeVar("T")
66
Input = TypeVar("Input")
@@ -23,15 +23,19 @@ class Score(Generic[T]):
2323
explanation: str = ""
2424

2525

26+
@dataclass
2627
class BaseClassifier(ABC, Generic[T, Input]):
2728
"""An abstract base class for classifiers."""
2829

30+
tags: List[str] = field(default_factory=list, init=False)
31+
2932
@abstractmethod
30-
def score(self, input: Input) -> Score[T]:
33+
def score(self, input: Input, references: List[Input] | None = None) -> Score[T]:
3134
"""Score the input and return a Score object.
3235
3336
Args:
3437
input (Input): The input to be scored.
38+
references (List[Input], optional): List of reference inputs. Defaults to None.
3539
3640
Returns:
3741
Score[T]: A Score object representing the score of the input.

aisploit/core/job.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
1-
from dataclasses import dataclass
1+
import sys
2+
from dataclasses import dataclass, field
23
from uuid import uuid4
34

5+
from ..utils import is_running_in_jupyter_notebook
6+
47

58
@dataclass(kw_only=True)
69
class BaseJob:
10+
disable_progressbar: bool = field(
11+
default_factory=lambda: False if (is_running_in_jupyter_notebook() or sys.stdout.isatty()) else True
12+
)
713
verbose: bool = False
814

915
def _create_run_id(self) -> str:

aisploit/demo/gandalf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def __init__(self, level: GandalfLevel, chat_model: BaseChatModel) -> None:
105105
self._endpoint = f"{GANDALF_API_ENDPOINT}/guess-password"
106106
self._chain = chat_model | StrOutputParser()
107107

108-
def score(self, input: str) -> Score[str]:
108+
def score(self, input: str, refereces: List[str] | None = None) -> Score[str]:
109109
score_description = (
110110
"When password is found in input, the score is the password itself. " "If nothing found, the score is empty"
111111
)

aisploit/red_team/job.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
GetSessionHistoryCallable,
1010
RunnableWithMessageHistory,
1111
)
12+
from tqdm.auto import tqdm
1213

1314
from .report import RedTeamReport, RedTeamReportEntry
1415
from .task import RedTeamTask
@@ -66,7 +67,7 @@ def execute(
6667

6768
current_prompt_text = initial_prompt_text
6869

69-
for attempt in range(1, max_attempt + 1):
70+
for attempt in tqdm(range(1, max_attempt + 1), desc="Attacking", disable=self.disable_progressbar):
7071
current_prompt_text = chain.invoke(
7172
input={self.task.input_messages_key: current_prompt_text},
7273
config={"configurable": {"session_id": run_id}},

aisploit/scanner/job.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
from dataclasses import dataclass, field
22
from typing import List, Optional, Sequence
33

4+
from tqdm.auto import tqdm
5+
46
from .plugin import Plugin
5-
from .plugins import PromptInjectionPlugin
67
from .report import Issue, ScanReport
78
from ..core import BaseJob, BaseTarget, CallbackManager, Callbacks
89

910

1011
@dataclass
1112
class ScannerJob(BaseJob):
1213
target: BaseTarget
13-
plugins: Sequence[Plugin] = field(default_factory=lambda: [PromptInjectionPlugin()])
14+
plugins: Sequence[Plugin]
1415
callbacks: Callbacks = field(default_factory=list)
1516

1617
def execute(self, *, run_id: Optional[str] = None, tags: Optional[Sequence[str]] = None) -> ScanReport:
@@ -22,7 +23,7 @@ def execute(self, *, run_id: Optional[str] = None, tags: Optional[Sequence[str]]
2223
)
2324

2425
issues: List[Issue] = []
25-
for plugin in self.plugins:
26+
for plugin in tqdm(self.plugins, desc="Scanning", disable=self.disable_progressbar):
2627
callback_manager.on_scanner_plugin_start(plugin.name)
2728
plugin_issues = plugin.run(run_id=run_id, target=self.target)
2829
callback_manager.on_scanner_plugin_end(plugin.name)

0 commit comments

Comments
 (0)