Skip to content

Commit 6dc9bb2

Browse files
committed
Update app.py: remove model.cpython-38.pyc binary diff, add generate_response function from model.py, update model paths, and modify chat route to use generate_response
* Remove binary diff of model
1 parent 0322308 commit 6dc9bb2

File tree

2 files changed

+26
-20
lines changed

2 files changed

+26
-20
lines changed

__pycache__/model.cpython-38.pyc

0 Bytes
Binary file not shown.

app.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,29 @@
22
from transformers import AutoModelForCausalLM, AutoTokenizer
33
from transformers import Conversation as Conversation
44
from flask_cors import CORS
5+
from model import generate_response
56

67
app = Flask(__name__)
78
CORS(app)
89

910
model_paths = {
10-
'meena': 'google/meena-chatbot',
11-
'dia_small': 'microsoft/DialoGPT-small',
1211
'dia_medium': 'microsoft/DialoGPT-medium',
13-
'dia_large': 'microsoft/DialoGPT-large',
14-
'blender_400m': 'facebook/blenderbot-400M',
15-
'blender_90m': 'facebook/blenderbot-90M',
12+
'dia_small': 'microsoft/DialoGPT-small',
13+
'blender_90m': 'facebook/blenderbot-90M'
14+
15+
# 'dia_large': 'microsoft/DialoGPT-large',
16+
# 'blender_400m': 'facebook/blenderbot-400M',
17+
# 'octo': 'NexaAIDev/Octopus-v2',
18+
# 'dia_base': 'microsoft/DialoGPT-base',
1619
}
1720

1821
chat_model = None
1922
chat_tokenizer = None
2023
conversations = []
2124
model_name = None
2225

26+
step = 0
27+
2328
@app.route('/model', methods=['GET', 'POST'])
2429
def post_model():
2530
global chat_model, chat_tokenizer, model_name
@@ -33,6 +38,7 @@ def post_model():
3338

3439
chat_model = AutoModelForCausalLM.from_pretrained(path)
3540
chat_tokenizer = AutoTokenizer.from_pretrained(path)
41+
chat_tokenizer.pad_token = chat_tokenizer.eos_token
3642

3743
# Add CORS headers to the response
3844
response_headers = {
@@ -45,31 +51,31 @@ def post_model():
4551

4652
@app.route('/chat', methods=['GET', 'POST'])
4753
def chat():
48-
global conversations, chat_model, chat_tokenizer
54+
global step, tokenizer, last_output, conversations
55+
4956
data = request.get_json()
5057
message = data['message']
51-
conversation_id = data['conversation_id']
52-
53-
if len(conversations) <= conversation_id:
54-
conversations.append(Conversation())
55-
56-
conversations[conversation_id].add_user_input({'role': 'user', 'content': message})
57-
inputs = chat_tokenizer(conversations[conversation_id].messages, return_tensors='pt')
58-
reply = chat_model.generate(inputs.input_ids, max_length=50)
59-
60-
conversations[conversation_id].add_system_message({'role': 'system', 'content': reply})
61-
58+
# conversation_id = data['conversation_id']
59+
60+
reply = generate_response(chat_model, chat_tokenizer, message, step, max_length=250)
61+
6262
# Add CORS headers to the response
6363
response_headers = {
6464
'Access-Control-Allow-Origin': '*', # Change the '*' to the appropriate origin if needed
6565
'Access-Control-Allow-Headers': 'Content-Type',
6666
'Access-Control-Allow-Methods': 'POST'
6767
}
6868

69+
step += 1
70+
6971
return jsonify({'reply': reply}), 200, response_headers
7072

7173
if __name__ == '__main__':
72-
test = Conversation()
74+
# Add CORS headers to the response
75+
response_headers = {
76+
'Access-Control-Allow-Origin': '*', # Change the '*' to the appropriate origin if needed
77+
'Access-Control-Allow-Headers': 'Content-Type',
78+
'Access-Control-Allow-Methods': 'POST'
79+
}
7380

74-
test.add_message({'':''})
75-
app.run(port=5000)
81+
app.run(port=8000)

0 commit comments

Comments (0)