File tree 5 files changed +91
-13
lines changed
5 files changed +91
-13
lines changed Original file line number Diff line number Diff line change 11
11
from .remove_punctuation import RemovePunctuationConverter
12
12
from .sequence import SequenceConverter
13
13
from .stemming import StemmingConverter
14
+ from .translation import TranslationConverter
14
15
from .unicode_confusable import UnicodeConfusableConverter
15
16
16
17
__all__ = [
27
28
"RemovePunctuationConverter" ,
28
29
"SequenceConverter" ,
29
30
"StemmingConverter" ,
31
+ "TranslationConverter" ,
30
32
"UnicodeConfusableConverter" ,
31
33
]
Original file line number Diff line number Diff line change 1
1
import textwrap
2
+ from dataclasses import dataclass , field
3
+ from typing import Dict
2
4
3
5
from langchain_core .output_parsers import StrOutputParser
4
6
from langchain_core .prompts import ChatPromptTemplate
7
+ from langchain_core .runnables import RunnableSerializable
5
8
6
- from ..core import BaseChatModel , BaseChatModelConverter
9
+ from ..core import BaseChatModelConverter
7
10
8
11
_template = ChatPromptTemplate .from_template (
9
12
textwrap .dedent (
24
27
)
25
28
26
29
30
+ @dataclass
27
31
class GenderConverter (BaseChatModelConverter ):
28
- def __init__ (self , * , chat_model : BaseChatModel , prompt = _template ) -> None :
29
- super ().__init__ (chat_model )
30
- self ._chain = prompt | chat_model | StrOutputParser ()
32
+ prompt : ChatPromptTemplate = field (default_factory = lambda : _template )
33
+ _chain : RunnableSerializable [Dict , str ] = field (init = False )
34
+
35
+ def __post_init__ (self ) -> None :
36
+ self ._chain = self .prompt | self .chat_model | StrOutputParser ()
31
37
32
38
def _convert (self , prompt : str ) -> str :
33
39
return self ._chain .invoke ({"input" : prompt })
Original file line number Diff line number Diff line change
1
+ import textwrap
2
+ from dataclasses import dataclass , field
3
+ from typing import Dict
4
+
5
+ from langchain_core .output_parsers import StrOutputParser
6
+ from langchain_core .prompts import ChatPromptTemplate
7
+ from langchain_core .runnables import RunnableSerializable
8
+
9
+ from ..core import BaseChatModelConverter
10
+
11
+ _template = ChatPromptTemplate .from_template (
12
+ textwrap .dedent (
13
+ """
14
+ Translate the following text to {language}.
15
+ Please maintain the original meaning and context as closely as possible.
16
+
17
+ Input text:
18
+ {input}
19
+
20
+ Translated text:
21
+ """
22
+ )
23
+ )
24
+
25
+
26
+ @dataclass
27
+ class TranslationConverter (BaseChatModelConverter ):
28
+ language : str
29
+ prompt : ChatPromptTemplate = field (default_factory = lambda : _template )
30
+ _chain : RunnableSerializable [Dict , str ] = field (init = False )
31
+
32
+ def __post_init__ (self ) -> None :
33
+ self ._chain = self .prompt | self .chat_model | StrOutputParser ()
34
+
35
+ def _convert (self , prompt : str ) -> str :
36
+ return self ._chain .invoke ({"input" : prompt , "language" : self .language })
Original file line number Diff line number Diff line change 1
1
from abc import ABC , abstractmethod
2
+ from dataclasses import dataclass
2
3
from typing import Union
3
4
4
5
from langchain_core .prompt_values import StringPromptValue
@@ -52,9 +53,9 @@ def __repr__(self) -> str:
52
53
return f"<{ prefix } .{ self .__class__ .__name__ } >"
53
54
54
55
56
+ @dataclass
55
57
class BaseChatModelConverter (BaseConverter , ABC ):
56
- def __init__ (self , chat_model : BaseChatModel ) -> None :
57
- self ._chat_model = chat_model
58
+ chat_model : BaseChatModel
58
59
59
60
def __repr__ (self ) -> str :
60
61
"""Return a string representation of the converter.
@@ -66,4 +67,4 @@ def __repr__(self) -> str:
66
67
if not self .__module__ .startswith (prefix ):
67
68
prefix = "custom"
68
69
69
- return f"<{ prefix } .{ self .__class__ .__name__ } (chat_model={ self ._chat_model .get_name ()} )>"
70
+ return f"<{ prefix } .{ self .__class__ .__name__ } (chat_model={ self .chat_model .get_name ()} )>"
Original file line number Diff line number Diff line change 40
40
" SequenceConverter,\n " ,
41
41
" StemmingConverter,\n " ,
42
42
" UnicodeConfusableConverter,\n " ,
43
+ " TranslationConverter,\n " ,
43
44
" )\n " ,
44
45
" from aisploit.models import ChatOpenAI\n " ,
45
46
" \n " ,
194
195
"cell_type" : " markdown" ,
195
196
"metadata" : {},
196
197
"source" : [
197
- " ## RemovePunctuationConverter "
198
+ " ## TranslationConverter "
198
199
]
199
200
},
200
201
{
201
202
"cell_type" : " code" ,
202
203
"execution_count" : 7 ,
203
204
"metadata" : {},
205
+ "outputs" : [
206
+ {
207
+ "data" : {
208
+ "text/markdown" : [
209
+ " H3ll0, w0rld! H0w 4r3 y0u?"
210
+ ],
211
+ "text/plain" : [
212
+ " <IPython.core.display.Markdown object>"
213
+ ]
214
+ },
215
+ "metadata" : {},
216
+ "output_type" : " display_data"
217
+ }
218
+ ],
219
+ "source" : [
220
+ " converter = TranslationConverter(chat_model=chat_model, language=\" leetspeak\" )\n " ,
221
+ " converted_prompt = converter.convert(\" Hello, world! How are you?\" )\n " ,
222
+ " \n " ,
223
+ " display(Markdown(converted_prompt.to_string()))"
224
+ ]
225
+ },
226
+ {
227
+ "cell_type" : " markdown" ,
228
+ "metadata" : {},
229
+ "source" : [
230
+ " ## RemovePunctuationConverter"
231
+ ]
232
+ },
233
+ {
234
+ "cell_type" : " code" ,
235
+ "execution_count" : 8 ,
236
+ "metadata" : {},
204
237
"outputs" : [
205
238
{
206
239
"data" : {
231
264
},
232
265
{
233
266
"cell_type" : " code" ,
234
- "execution_count" : 8 ,
267
+ "execution_count" : 9 ,
235
268
"metadata" : {},
236
269
"outputs" : [
237
270
{
263
296
},
264
297
{
265
298
"cell_type" : " code" ,
266
- "execution_count" : 9 ,
299
+ "execution_count" : 10 ,
267
300
"metadata" : {},
268
301
"outputs" : [
269
302
{
295
328
},
296
329
{
297
330
"cell_type" : " code" ,
298
- "execution_count" : 10 ,
331
+ "execution_count" : 11 ,
299
332
"metadata" : {},
300
333
"outputs" : [
301
334
{
327
360
},
328
361
{
329
362
"cell_type" : " code" ,
330
- "execution_count" : 11 ,
363
+ "execution_count" : 12 ,
331
364
"metadata" : {},
332
365
"outputs" : [
333
366
{
366
399
},
367
400
{
368
401
"cell_type" : " code" ,
369
- "execution_count" : 12 ,
402
+ "execution_count" : 13 ,
370
403
"metadata" : {},
371
404
"outputs" : [
372
405
{
You can’t perform that action at this time.
0 commit comments