2
2
from transformers import AutoModelForCausalLM , AutoTokenizer
3
3
from transformers import Conversation as Conversation
4
4
from flask_cors import CORS
5
+ from model import generate_response
5
6
6
7
# Flask application instance; CORS(app) allows cross-origin requests on
# every route (the handlers below additionally attach explicit CORS headers).
app = Flask(__name__)
CORS(app)
8
9
9
10
# Hugging Face checkpoint identifiers, keyed by the short names the API uses.
model_paths = {
    'dia_medium': 'microsoft/DialoGPT-medium',
    'dia_small': 'microsoft/DialoGPT-small',
    'blender_90m': 'facebook/blenderbot-90M'

    # Heavier / experimental checkpoints, deliberately disabled:
    # 'dia_large': 'microsoft/DialoGPT-large',
    # 'blender_400m': 'facebook/blenderbot-400M',
    # 'octo': 'NexaAIDev/Octopus-v2',
    # 'dia_base': 'microsoft/DialoGPT-base',
}
17
20
18
21
# Mutable module-level state shared by the route handlers.
chat_model = None      # causal-LM model; assigned in post_model()
chat_tokenizer = None  # matching tokenizer; assigned in post_model()
conversations = []     # conversation history (not consumed by /chat currently)
model_name = None      # short key of the currently selected model

step = 0  # count of /chat turns served; incremented once per reply
27
+
23
28
@app .route ('/model' , methods = ['GET' , 'POST' ])
24
29
def post_model ():
25
30
global chat_model , chat_tokenizer , model_name
@@ -33,6 +38,7 @@ def post_model():
33
38
34
39
chat_model = AutoModelForCausalLM .from_pretrained (path )
35
40
chat_tokenizer = AutoTokenizer .from_pretrained (path )
41
+ chat_tokenizer .pad_token = chat_tokenizer .eos_token
36
42
37
43
# Add CORS headers to the response
38
44
response_headers = {
@@ -45,31 +51,31 @@ def post_model():
45
51
46
52
@app.route('/chat', methods=['GET', 'POST'])
def chat():
    """Generate a chatbot reply for the posted message.

    Expects a JSON body with a 'message' key and returns
    ``{'reply': ...}`` with permissive CORS headers and status 200.
    """
    # BUG FIX: the original declared `global step, tokenizer, last_output,
    # conversations` — `tokenizer` and `last_output` are never defined at
    # module level and none of the extras are assigned here. Only `step`
    # is rebound; chat_model/chat_tokenizer are read-only module globals.
    global step

    data = request.get_json()
    message = data['message']
    # conversation_id = data['conversation_id']  # per-conversation routing not wired up yet

    # NOTE(review): generate_response comes from model.py; presumably `step`
    # lets it thread the dialogue history — confirm against that module.
    reply = generate_response(chat_model, chat_tokenizer, message, step,
                              max_length=250)

    # Add CORS headers to the response
    response_headers = {
        'Access-Control-Allow-Origin': '*',  # Change the '*' to the appropriate origin if needed
        'Access-Control-Allow-Headers': 'Content-Type',
        'Access-Control-Allow-Methods': 'POST'
    }

    step += 1

    return jsonify({'reply': reply}), 200, response_headers
70
72
71
73
if __name__ == '__main__':
    # Dead code removed: a `response_headers` dict was built here but never
    # used — CORS headers are attached per-response inside the handlers.
    app.run(port=8000)
0 commit comments