Skip to content

Commit f9d7660

Browse files
committed
Add image targets
1 parent 6740590 commit f9d7660

File tree

7 files changed

+287
-23
lines changed

7 files changed

+287
-23
lines changed

β€Žaisploit/core/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .model import BaseChatModel, BaseEmbeddings, BaseLLM, BaseModel
88
from .prompt import BasePromptValue
99
from .report import BaseReport
10-
from .target import BaseTarget, Response
10+
from .target import BaseImageTarget, BaseTarget, ContentFilteredException, Response
1111
from .vectorstore import BaseVectorStore
1212

1313
__all__ = [
@@ -30,6 +30,8 @@
3030
"BasePromptValue",
3131
"BaseReport",
3232
"BaseTarget",
33+
"BaseImageTarget",
3334
"Response",
35+
"ContentFilteredException",
3436
"BaseVectorStore",
3537
]

β€Žaisploit/core/target.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
1+
import base64
2+
import io
13
from abc import ABC, abstractmethod
24
from dataclasses import dataclass, field
3-
from typing import Any, Dict
5+
from typing import Any, Dict, Literal
6+
7+
from PIL import Image
48

59
from .prompt import BasePromptValue
610

711

12+
class ContentFilteredException(Exception):
13+
pass
14+
15+
816
@dataclass
917
class Response:
1018
"""A class representing a response from the target."""
@@ -32,3 +40,15 @@ def send_prompt(self, prompt: BasePromptValue) -> Response:
3240
Response: The response from the target.
3341
"""
3442
pass
43+
44+
45+
@dataclass
46+
class BaseImageTarget(ABC):
47+
size: Literal["512x512", "1024x1024"] = "512x512"
48+
show_image: bool = False
49+
50+
def _show_base64_image(self, base64_image: str) -> None:
51+
base64_bytes = base64_image.encode("ascii")
52+
image_bytes = base64.b64decode(base64_bytes)
53+
image = Image.open(io.BytesIO(image_bytes))
54+
image.show()

β€Žaisploit/targets/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .copilot import CopilotClient, CopilotTarget
22
from .email import EmailReceiver, EmailSender, EmailTarget, UserPasswordAuth
3-
from .image import OpenAIImageTarget
3+
from .image import BedrockAmazonImageTarget, BedrockStabilityImageTarget, OpenAIImageTarget
44
from .langchain import LangchainTarget
55
from .stdout import StdOutTarget
66
from .target import WrapperTarget, target
@@ -12,6 +12,8 @@
1212
"EmailSender",
1313
"EmailReceiver",
1414
"UserPasswordAuth",
15+
"BedrockAmazonImageTarget",
16+
"BedrockStabilityImageTarget",
1517
"OpenAIImageTarget",
1618
"LangchainTarget",
1719
"StdOutTarget",

β€Žaisploit/targets/image.py

Lines changed: 119 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,117 @@
1+
import json
12
import os
2-
from dataclasses import dataclass
3+
from abc import ABC
4+
from dataclasses import dataclass, field
35
from typing import Optional
46

7+
import boto3
8+
from botocore.exceptions import ClientError
59
from openai import OpenAI
610

7-
from ..core import BasePromptValue, BaseTarget, Response
11+
from ..core import BaseImageTarget, BasePromptValue, ContentFilteredException, Response
812

913

1014
@dataclass
11-
class OpenAIImageTarget(BaseTarget):
15+
class BaseBedrockImageTarget(BaseImageTarget, ABC):
16+
session: boto3.Session = field(default_factory=lambda: boto3.Session())
17+
region_name: str = "us-east-1"
18+
19+
def __post_init__(self):
20+
self._client = self.session.client("bedrock-runtime", region_name=self.region_name)
21+
22+
23+
@dataclass
24+
class BedrockAmazonImageTarget(BaseBedrockImageTarget):
25+
model: str = "titan-image-generator-v1"
26+
quality: str = "standard"
27+
seed: int = 0
28+
cfg_scale: int = 8
29+
30+
def send_prompt(self, prompt: BasePromptValue) -> Response:
31+
width, height = self.size.split("x")
32+
body = {
33+
"textToImageParams": {
34+
"text": prompt.to_string(),
35+
},
36+
"taskType": "TEXT_IMAGE",
37+
"imageGenerationConfig": {
38+
"seed": self.seed,
39+
"cfgScale": self.cfg_scale,
40+
"quality": self.quality,
41+
"width": int(width),
42+
"height": int(height),
43+
"numberOfImages": 1,
44+
},
45+
}
46+
47+
try:
48+
response = self._client.invoke_model(
49+
body=json.dumps(body),
50+
modelId=f"amazon.{self.model}",
51+
)
52+
53+
response_body = json.loads(response["body"].read())
54+
55+
if response_body["error"]:
56+
raise Exception(response_body["error"])
57+
58+
base64_image = response_body["images"][0]
59+
60+
if self.show_image:
61+
self._show_base64_image(base64_image)
62+
63+
return Response(content=base64_image)
64+
except ClientError as e:
65+
if e.response['Error']['Code'] == 'ValidationException':
66+
if "blocked by our content filters" in e.response['Error']['Message']:
67+
raise ContentFilteredException(e.response['Error']['Message']) from e
68+
69+
raise e
70+
71+
72+
@dataclass
73+
class BedrockStabilityImageTarget(BaseBedrockImageTarget):
74+
model: str = "stable-diffusion-xl-v1"
75+
steps: int = 50
76+
seed: int = 0
77+
cfg_scale: int = 8
78+
79+
def send_prompt(self, prompt: BasePromptValue) -> Response:
80+
width, height = self.size.split("x")
81+
body = {
82+
"text_prompts": [{"text": prompt.to_string(), "weight": 1}],
83+
"seed": self.seed,
84+
"cfg_scale": self.cfg_scale,
85+
"width": int(width),
86+
"height": int(height),
87+
"steps": self.steps,
88+
}
89+
90+
response = self._client.invoke_model(
91+
body=json.dumps(body),
92+
modelId=f"stability.{self.model}",
93+
)
94+
95+
response_body = json.loads(response["body"].read())
96+
97+
finish_reason = response_body.get("artifacts")[0].get("finishReason")
98+
99+
if finish_reason == "CONTENT_FILTERED":
100+
raise ContentFilteredException(f"Image error: {finish_reason}")
101+
102+
if finish_reason == "ERROR":
103+
raise Exception(f"Image error: {finish_reason}")
104+
105+
base64_image = response_body["artifacts"][0]["base64"]
106+
107+
if self.show_image:
108+
self._show_base64_image(base64_image)
109+
110+
return Response(content=base64_image)
111+
112+
113+
@dataclass
114+
class OpenAIImageTarget(BaseImageTarget):
12115
api_key: Optional[str] = None
13116

14117
def __post_init__(self):
@@ -18,6 +121,16 @@ def __post_init__(self):
18121
self._client = OpenAI(api_key=self.api_key)
19122

20123
def send_prompt(self, prompt: BasePromptValue) -> Response:
21-
response = self._client.images.generate(prompt=prompt.to_string(), n=1)
22-
print(response)
23-
return Response(content="")
124+
response = self._client.images.generate(
125+
prompt=prompt.to_string(),
126+
size=self.size,
127+
n=1,
128+
response_format="b64_json",
129+
)
130+
131+
base64_image = response.data[0].b64_json
132+
133+
if self.show_image:
134+
self._show_base64_image(base64_image)
135+
136+
return Response(content=base64_image)

β€Žexamples/target.ipynb

Lines changed: 73 additions & 13 deletions
Large diffs are not rendered by default.

β€Žpoetry.lock

Lines changed: 67 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

β€Žpyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ python-docx = "^1.1.0"
4949
brotli = "^1.1.0"
5050
stdlib-list = "^0.10.0"
5151
presidio-analyzer = "^2.2.354"
52+
boto3 = "^1.34.88"
5253

5354
[tool.poetry.group.dev.dependencies]
5455
chromadb = "^0.4.23"

0 commit comments

Comments
Β (0)