Skip to content

Commit ec4c8e4

Browse files
committed
Add translation converter
1 parent 8e94f24 commit ec4c8e4

File tree

5 files changed

+91
-13
lines changed

5 files changed

+91
-13
lines changed

aisploit/converters/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .remove_punctuation import RemovePunctuationConverter
1212
from .sequence import SequenceConverter
1313
from .stemming import StemmingConverter
14+
from .translation import TranslationConverter
1415
from .unicode_confusable import UnicodeConfusableConverter
1516

1617
__all__ = [
@@ -27,5 +28,6 @@
2728
"RemovePunctuationConverter",
2829
"SequenceConverter",
2930
"StemmingConverter",
31+
"TranslationConverter",
3032
"UnicodeConfusableConverter",
3133
]

aisploit/converters/gender.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import textwrap
2+
from dataclasses import dataclass, field
3+
from typing import Dict
24

35
from langchain_core.output_parsers import StrOutputParser
46
from langchain_core.prompts import ChatPromptTemplate
7+
from langchain_core.runnables import RunnableSerializable
58

6-
from ..core import BaseChatModel, BaseChatModelConverter
9+
from ..core import BaseChatModelConverter
710

811
_template = ChatPromptTemplate.from_template(
912
textwrap.dedent(
@@ -24,10 +27,13 @@
2427
)
2528

2629

30+
@dataclass
2731
class GenderConverter(BaseChatModelConverter):
28-
def __init__(self, *, chat_model: BaseChatModel, prompt=_template) -> None:
29-
super().__init__(chat_model)
30-
self._chain = prompt | chat_model | StrOutputParser()
32+
prompt: ChatPromptTemplate = field(default_factory=lambda: _template)
33+
_chain: RunnableSerializable[Dict, str] = field(init=False)
34+
35+
def __post_init__(self) -> None:
36+
self._chain = self.prompt | self.chat_model | StrOutputParser()
3137

3238
def _convert(self, prompt: str) -> str:
3339
return self._chain.invoke({"input": prompt})

aisploit/converters/translation.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import textwrap
2+
from dataclasses import dataclass, field
3+
from typing import Dict
4+
5+
from langchain_core.output_parsers import StrOutputParser
6+
from langchain_core.prompts import ChatPromptTemplate
7+
from langchain_core.runnables import RunnableSerializable
8+
9+
from ..core import BaseChatModelConverter
10+
11+
_template = ChatPromptTemplate.from_template(
12+
textwrap.dedent(
13+
"""
14+
Translate the following text to {language}.
15+
Please maintain the original meaning and context as closely as possible.
16+
17+
Input text:
18+
{input}
19+
20+
Translated text:
21+
"""
22+
)
23+
)
24+
25+
26+
@dataclass
27+
class TranslationConverter(BaseChatModelConverter):
28+
language: str
29+
prompt: ChatPromptTemplate = field(default_factory=lambda: _template)
30+
_chain: RunnableSerializable[Dict, str] = field(init=False)
31+
32+
def __post_init__(self) -> None:
33+
self._chain = self.prompt | self.chat_model | StrOutputParser()
34+
35+
def _convert(self, prompt: str) -> str:
36+
return self._chain.invoke({"input": prompt, "language": self.language})

aisploit/core/converter.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from abc import ABC, abstractmethod
2+
from dataclasses import dataclass
23
from typing import Union
34

45
from langchain_core.prompt_values import StringPromptValue
@@ -52,9 +53,9 @@ def __repr__(self) -> str:
5253
return f"<{prefix}.{self.__class__.__name__}>"
5354

5455

56+
@dataclass
5557
class BaseChatModelConverter(BaseConverter, ABC):
56-
def __init__(self, chat_model: BaseChatModel) -> None:
57-
self._chat_model = chat_model
58+
chat_model: BaseChatModel
5859

5960
def __repr__(self) -> str:
6061
"""Return a string representation of the converter.
@@ -66,4 +67,4 @@ def __repr__(self) -> str:
6667
if not self.__module__.startswith(prefix):
6768
prefix = "custom"
6869

69-
return f"<{prefix}.{self.__class__.__name__}(chat_model={self._chat_model.get_name()})>"
70+
return f"<{prefix}.{self.__class__.__name__}(chat_model={self.chat_model.get_name()})>"

examples/converter.ipynb

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
" SequenceConverter,\n",
4141
" StemmingConverter,\n",
4242
" UnicodeConfusableConverter,\n",
43+
" TranslationConverter,\n",
4344
")\n",
4445
"from aisploit.models import ChatOpenAI\n",
4546
"\n",
@@ -194,13 +195,45 @@
194195
"cell_type": "markdown",
195196
"metadata": {},
196197
"source": [
197-
"## RemovePunctuationConverter"
198+
"## TranslationConverter"
198199
]
199200
},
200201
{
201202
"cell_type": "code",
202203
"execution_count": 7,
203204
"metadata": {},
205+
"outputs": [
206+
{
207+
"data": {
208+
"text/markdown": [
209+
"H3ll0, w0rld! H0w 4r3 y0u?"
210+
],
211+
"text/plain": [
212+
"<IPython.core.display.Markdown object>"
213+
]
214+
},
215+
"metadata": {},
216+
"output_type": "display_data"
217+
}
218+
],
219+
"source": [
220+
"converter = TranslationConverter(chat_model=chat_model, language=\"leetspeak\")\n",
221+
"converted_prompt = converter.convert(\"Hello, world! How are you?\")\n",
222+
"\n",
223+
"display(Markdown(converted_prompt.to_string()))"
224+
]
225+
},
226+
{
227+
"cell_type": "markdown",
228+
"metadata": {},
229+
"source": [
230+
"## RemovePunctuationConverter"
231+
]
232+
},
233+
{
234+
"cell_type": "code",
235+
"execution_count": 8,
236+
"metadata": {},
204237
"outputs": [
205238
{
206239
"data": {
@@ -231,7 +264,7 @@
231264
},
232265
{
233266
"cell_type": "code",
234-
"execution_count": 8,
267+
"execution_count": 9,
235268
"metadata": {},
236269
"outputs": [
237270
{
@@ -263,7 +296,7 @@
263296
},
264297
{
265298
"cell_type": "code",
266-
"execution_count": 9,
299+
"execution_count": 10,
267300
"metadata": {},
268301
"outputs": [
269302
{
@@ -295,7 +328,7 @@
295328
},
296329
{
297330
"cell_type": "code",
298-
"execution_count": 10,
331+
"execution_count": 11,
299332
"metadata": {},
300333
"outputs": [
301334
{
@@ -327,7 +360,7 @@
327360
},
328361
{
329362
"cell_type": "code",
330-
"execution_count": 11,
363+
"execution_count": 12,
331364
"metadata": {},
332365
"outputs": [
333366
{
@@ -366,7 +399,7 @@
366399
},
367400
{
368401
"cell_type": "code",
369-
"execution_count": 12,
402+
"execution_count": 13,
370403
"metadata": {},
371404
"outputs": [
372405
{

0 commit comments

Comments
 (0)