Skip to content

Commit 7a4957a

Browse files
committed
Add docstrings
1 parent 0bb5849 commit 7a4957a

File tree

1 file changed

+54
-1
lines changed

1 file changed

+54
-1
lines changed

aisploit/redteam/task.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from typing import List
21
from abc import ABC, abstractmethod
32

43
from langchain_core.prompts import PromptTemplate
@@ -8,6 +7,10 @@
87

98

109
class RedTeamTask(ABC):
10+
"""
11+
Abstract base class for defining red team tasks in a conversation.
12+
"""
13+
1114
def __init__(
1215
self,
1316
*,
@@ -16,6 +19,15 @@ def __init__(
1619
input_messages_key="input",
1720
history_messages_key="chat_history",
1821
) -> None:
22+
"""
23+
Initialize a RedTeamTask instance.
24+
25+
Args:
26+
objective (str): The objective of the task.
27+
system_template (PromptTemplate): The system prompt template.
28+
input_messages_key (str): The key for input messages.
29+
history_messages_key (str): The key for chat history messages.
30+
"""
1931
if len(objective) == 0:
2032
raise ValueError("Objective cannot be empty.")
2133

@@ -25,6 +37,12 @@ def __init__(
2537

2638
@property
2739
def prompt(self) -> ChatPromptTemplate:
40+
"""
41+
Get the chat prompt template.
42+
43+
Returns:
44+
ChatPromptTemplate: The chat prompt template.
45+
"""
2846
return ChatPromptTemplate.from_messages(
2947
[
3048
("system", self.system_prompt),
@@ -68,6 +86,10 @@ def evaluate_task_completion(
6886

6987

7088
class RedTeamEndTokenTask(RedTeamTask):
89+
"""
90+
Red team task with an end token to mark task completion.
91+
"""
92+
7193
def __init__(
7294
self,
7395
*,
@@ -77,16 +99,33 @@ def __init__(
7799
history_messages_key="chat_history",
78100
end_token=RED_TEAM_END_TOKEN,
79101
) -> None:
102+
"""
103+
Initialize a RedTeamEndTokenTask instance.
104+
105+
Args:
106+
objective (str): The objective of the task.
107+
system_template (PromptTemplate): The system prompt template.
108+
input_messages_key (str): The key for input messages.
109+
history_messages_key (str): The key for chat history messages.
110+
end_token (str): The token to mark task completion.
111+
"""
80112
super().__init__(
81113
objective=objective,
82114
system_template=system_template.partial(end_token=end_token),
83115
input_messages_key=input_messages_key,
84116
history_messages_key=history_messages_key,
85117
)
118+
86119
self._end_token = end_token
87120

88121
@property
89122
def end_token(self) -> str:
123+
"""
124+
Get the end token.
125+
126+
Returns:
127+
str: The end token.
128+
"""
90129
return self._end_token
91130

92131
def evaluate_task_completion(
@@ -137,6 +176,10 @@ def evaluate_task_completion(
137176

138177

139178
class RedTeamClassifierTask(RedTeamTask):
179+
"""
180+
Red team task using a classifier to evaluate completion.
181+
"""
182+
140183
def __init__(
141184
self,
142185
*,
@@ -146,6 +189,16 @@ def __init__(
146189
input_messages_key="input",
147190
history_messages_key="chat_history",
148191
) -> None:
192+
"""
193+
Initialize a RedTeamClassifierTask instance.
194+
195+
Args:
196+
objective (str): The objective of the task.
197+
classifier (BaseClassifier): The classifier used to evaluate completion.
198+
system_template (PromptTemplate): The system prompt template.
199+
input_messages_key (str): The key for input messages.
200+
history_messages_key (str): The key for chat history messages.
201+
"""
149202
super().__init__(
150203
objective=objective,
151204
system_template=system_template,

0 commit comments

Comments
 (0)