Skip to content

Commit 97e789f

Browse files
committed
Misc
1 parent 4a16bef commit 97e789f

File tree

8 files changed

+428
-248
lines changed

8 files changed

+428
-248
lines changed

aisploit/classifiers/self_similarity.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,22 @@
11
from dataclasses import dataclass, field
22
from typing import Any, Dict, List, Literal
33

4-
from sentence_transformers import SentenceTransformer
5-
from sentence_transformers.util import cos_sim
4+
import torch
5+
import torch.utils
66

7-
from ..core import BaseTextClassifier, Score
7+
from ..core import BaseEmbeddings, BaseTextClassifier, Score
8+
from ..embeddings import HuggingFaceEmbeddings
89

910

10-
@dataclass
11+
@dataclass(kw_only=True)
1112
class SelfSimilarityClassifier(BaseTextClassifier[Dict[str, Any]]):
1213
"""A text classifier based on self-similarity using cosine similarity scores."""
1314

14-
model_name_or_path: str = "all-MiniLM-L6-v2"
15+
embeddings: BaseEmbeddings = field(default_factory=lambda: HuggingFaceEmbeddings())
1516
threshold: float = 0.7
1617
aggregation: Literal["mean", "min"] = "mean"
1718
tags: List[str] = field(default_factory=lambda: ["hallucination"], init=False)
1819

19-
def __post_init__(self) -> None:
20-
"""Initialize the SentenceTransformer model."""
21-
self._model = SentenceTransformer(self.model_name_or_path)
22-
2320
def score(self, input: str, references: List[str] | None = None) -> Score[Dict[str, Any]]:
2421
"""Score the input text based on its self-similarity to reference texts.
2522
@@ -36,15 +33,17 @@ def score(self, input: str, references: List[str] | None = None) -> Score[Dict[s
3633
if not references or not len(references) >= 1:
3734
raise ValueError("The number of references must be at least 1.")
3835

39-
input_embeddings = self._model.encode(input, convert_to_tensor=True)
40-
references_embeddings = self._model.encode(references, convert_to_tensor=True)
36+
input_embeddings = torch.tensor(self.embeddings.embed_query(input))
37+
38+
references_embeddings = torch.tensor(self.embeddings.embed_documents(references))
4139

42-
cos_scores = cos_sim(input_embeddings, references_embeddings)[0]
40+
# Calculate cosine similarity
41+
cos_scores = torch.nn.functional.cosine_similarity(input_embeddings.unsqueeze(0), references_embeddings, dim=1)
4342

4443
score = cos_scores.mean() if self.aggregation == "mean" else cos_scores.min()
4544

4645
return Score[Dict[str, Any]](
47-
flagged=(score < self.threshold).item(),
46+
flagged=bool(score < self.threshold),
4847
value={
4948
"aggregated_score": score.item(),
5049
"scores": cos_scores.tolist(),

aisploit/embeddings/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from .bedrock import BedrockEmbeddings
22
from .google import GoogleGenerativeAIEmbeddings
3+
from .huggingface import HuggingFaceEmbeddings
34
from .ollama import OllamaEmbeddings
45
from .openai import OpenAIEmbeddings
56

67
__all__ = [
78
"BedrockEmbeddings",
89
"GoogleGenerativeAIEmbeddings",
10+
"HuggingFaceEmbeddings",
911
"OllamaEmbeddings",
1012
"OpenAIEmbeddings",
1113
]

aisploit/embeddings/huggingface.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from langchain_community.embeddings import (
2+
HuggingFaceEmbeddings as LangchainHuggingFaceEmbeddings,
3+
)
4+
5+
from ..core import BaseEmbeddings
6+
7+
8+
class HuggingFaceEmbeddings(LangchainHuggingFaceEmbeddings, BaseEmbeddings):
9+
def __init__(
10+
self,
11+
*,
12+
model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
13+
**kwargs,
14+
) -> None:
15+
super().__init__(
16+
model_name=model_name,
17+
**kwargs,
18+
)

aisploit/scanner/plugins/self_similarity.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,30 @@
11
from dataclasses import dataclass, field
2-
from typing import List, Sequence
2+
from typing import List, Literal, Sequence
33

44
from ..plugin import Plugin
55
from ..report import Issue, IssueCategory
66
from ...classifiers import SelfSimilarityClassifier
77
from ...converters import NoOpConverter
8-
from ...core import BaseConverter, BaseTarget
8+
from ...core import BaseConverter, BaseEmbeddings, BaseTarget
9+
from ...embeddings import HuggingFaceEmbeddings
910
from ...sender import SenderJob
1011

1112

1213
@dataclass(kw_only=True)
1314
class SelfSimilarityPlugin(Plugin):
1415
questions: List[str] # TODO dataset
1516
num_samples: int = 3
16-
model_name_or_path: str = "all-MiniLM-L6-v2"
17+
embeddings: BaseEmbeddings = field(default_factory=lambda: HuggingFaceEmbeddings())
1718
threshold: float = 0.7
19+
aggregation: Literal['mean', 'min'] = "mean"
1820
converters: List[BaseConverter] = field(default_factory=lambda: [NoOpConverter()])
1921
name: str = field(default="self_similarity", init=False)
2022

2123
def __post_init__(self) -> None:
2224
self._classifier = SelfSimilarityClassifier(
23-
model_name_or_path=self.model_name_or_path,
25+
embeddings=self.embeddings,
2426
threshold=self.threshold,
27+
aggregation=self.aggregation,
2528
)
2629

2730
def run(self, *, run_id: str, target: BaseTarget) -> Sequence[Issue]:

docs/scanner.ipynb

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@
3333
" domain=\"cxd47vgx2z2qyzr637trlgzogfm6ayyn.oastify.com\"\n",
3434
" ),\n",
3535
" ],\n",
36-
")"
36+
")\n",
37+
"\n",
38+
"# job.execute()"
3739
]
3840
},
3941
{

examples/classifier.ipynb

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,23 @@
4444
},
4545
{
4646
"cell_type": "code",
47-
"execution_count": 2,
47+
"execution_count": 3,
4848
"metadata": {},
4949
"outputs": [
50+
{
51+
"name": "stdout",
52+
"output_type": "stream",
53+
"text": [
54+
"torch.Size([384]) torch.Size([1, 384])\n"
55+
]
56+
},
5057
{
5158
"data": {
5259
"text/plain": [
5360
"Score(flagged=True, value={'aggregated_score': 0.6721476912498474, 'scores': [0.6721476912498474]}, description='Returns True if the aggregated cosine similarity score is less than the threshold', explanation='The aggregated cosine similarity score for the input is 0.6721476912498474')"
5461
]
5562
},
56-
"execution_count": 2,
63+
"execution_count": 3,
5764
"metadata": {},
5865
"output_type": "execute_result"
5966
}

0 commit comments

Comments
 (0)