-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvideo_aug_server.py
159 lines (127 loc) · 4.72 KB
/
video_aug_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import os
import io
import time
from PIL import Image
#Flask
from flask import Flask, request, jsonify
#NanoLLM
from nano_llm import NanoLLM, ChatHistory
from nano_llm.utils import ArgParser
# HuggingFace model id for the LLaVA v1.5 13B checkpoint loaded in setup().
model_name = 'liuhaotian/llava-v1.5-13b'
# Populated by setup(): the NanoLLM model instance (None until setup succeeds).
model = None
# Populated by setup(): ChatHistory holding the conversation and its KV cache.
chat_history = None
# Populated by setup(): parsed ArgParser namespace with generation settings.
args = None
# Passed as `streaming=` to model.generate(); False -> generate returns the full reply.
isStreaming = False
# System prompt applied at setup and re-applied on every chat_history.reset() in query().
system_prompt = """
You are LLaVA, a multimodal AI designed to analyze, describe, and interpret visual data.
When given an image, provide clear, descriptive, and insightful observations.
Combine your visual understanding with textual context to answer questions or provide
assistance.
Ensure responses are user-friendly, relevant, and concise.
Only respond in English!
"""
app = Flask(__name__)
def setup():
    """Parse args, load the LLaVA model, create the chat history, and warm up.

    Side effects: assigns the module globals ``model``, ``chat_history`` and
    ``args``. Runs one throwaway generation so the first real request does not
    pay model warm-up cost.

    Returns:
        str: 'Setup Completed!' on success, 'Setup failed - See Error Above'
        on any exception (the full traceback is printed to the server log).
    """
    try:
        global model, chat_history, model_name, args, isStreaming, system_prompt
        # Parse NanoLLM's standard CLI args, then override the ones this
        # server fixes regardless of the command line.
        parser = ArgParser(extras=ArgParser.Defaults)
        args = parser.parse_args()
        args.vision_api = 'hf'
        args.max_content_len = 100
        args.system_prompt = system_prompt
        args.min_new_tokens = 100
        args.max_new_tokens = 300
        print(args)
        # Load the quantized model (MLC backend, 4-bit float16 quantization).
        print('Loading LLAVA...')
        model = NanoLLM.from_pretrained(
            model=model_name,
            api='mlc',
            quantization='q4f16_ft',
            max_content_len = args.max_content_len,
            vision_api = args.vision_api,
            vision_model = args.vision_model,
            vision_scaling = args.vision_scaling,
            print_stats=False
        )
        print('LLAVA Loaded')
        print(f"DOES THIS MODEL HAVE VISION: {model.has_vision}")
        print(f"VISION API : {args.vision_api} >> Vision Model : {args.vision_model} >> {args.chat_template}")
        # Seed the chat history with a throwaway user turn for the warm-up pass.
        print('Creating Chat History...')
        chat_history = ChatHistory(model, args.chat_template, args.system_prompt)
        chat_history.append(role="user", text="What is machine learning? -explain in english")
        print('Chat History created')
        # Warm-up generation: compiles kernels / fills caches so the first
        # real /query request responds at normal latency.
        print("Making Embeddings and generating warm up reply...")
        embedding, _ = chat_history.embed_chat()
        reply = model.generate(
            embedding,
            kv_cache= chat_history.kv_cache,
            max_new_tokens= args.max_new_tokens,
            min_new_tokens= args.min_new_tokens,
            do_sample= args.do_sample,
            repetition_penalty= args.repetition_penalty,
            temperature= args.temperature,
            top_p= args.top_p,
            streaming= isStreaming
        )
        print(f"Warm up reply: {reply}")
        return 'Setup Completed!'
    except Exception as error:
        # Print the full traceback, not just str(error): model-load and
        # argument failures are undiagnosable from the message alone.
        import traceback
        traceback.print_exc()
        print(f'ERROR: {error}')
        return 'Setup failed - See Error Above'
#Flask Route: /query POST
@app.route('/query', methods=['POST'])
def query():
    """Handle a multimodal query: multipart 'image' file + 'text' form field.

    Resets the chat history each request (stateless endpoint), embeds the
    image and prompt, and returns the model's reply as JSON.

    Returns:
        200 {'reply': ...} on success;
        400 on missing/empty/invalid input;
        500 {'error': ...} if generation fails;
        503 if the model never loaded (setup() failed).
    """
    global model, chat_history, args, isStreaming, system_prompt
    # If setup() failed, model/chat_history are still None; fail clearly
    # instead of crashing with AttributeError below.
    if model is None or chat_history is None:
        return jsonify({'error': 'Model is not loaded - server setup failed.'}), 503
    #Check for user input
    if 'image' not in request.files or 'text' not in request.form:
        return jsonify({'error': 'Image and text are required.'}), 400
    #Gets Image and Text data from request
    image_file = request.files['image']
    text_data = request.form['text']
    #Get prompt
    prompt = text_data.strip()
    print(f'Prompt: {prompt}')
    #Sanitize query input
    if len(prompt) == 0:
        return jsonify({'reply': 'Query cannot be empty!'}), 400
    #See if image is in right format
    image = None
    try:
        image = Image.open(image_file)
    except Exception as e:
        return jsonify({'error': 'Invalid image format.', 'details': str(e)}), 400
    # Stateless per request: drop prior turns, re-apply the system prompt,
    # then add the image and the user's text as this turn's content.
    chat_history.reset(system_prompt=system_prompt)
    chat_history.append('user', image)
    chat_history.append('user', prompt)
    #Generate chat embeddings
    embedding, _= chat_history.embed_chat()
    #Generate chat bot reply
    try:
        reply = model.generate(
            embedding,
            kv_cache=chat_history.kv_cache,
            max_new_tokens=args.max_new_tokens,
            min_new_tokens=args.min_new_tokens,
            do_sample=args.do_sample,
            repetition_penalty=args.repetition_penalty,
            temperature=args.temperature,
            top_p=args.top_p,
            streaming=isStreaming
        )
        # Strip newlines and sentinel tokens before returning to the client.
        response = reply
        response = response.replace("\n", "").replace("</s>", "").replace("<s>", "")
        print(f'Response: {response}')
        return jsonify({'reply': f'{response}'}), 200
    except Exception as error:
        print(f'ERROR: {error}')
        # BUG FIX: the original fell off the end here, returning None, which
        # makes Flask raise "view function did not return a valid response".
        return jsonify({'error': 'Generation failed.', 'details': str(error)}), 500
#Runs setup and starts app
if __name__ == '__main__':
    # NOTE: app.run() starts even if setup() failed; the printed status line
    # is the only indication the model never loaded.
    print(setup())
    # SECURITY NOTE(review): debug=True enables the Werkzeug interactive
    # debugger (arbitrary code execution) and is combined here with
    # host='0.0.0.0' — do not expose this configuration beyond a trusted LAN.
    app.run(host='0.0.0.0', debug=True, port=5000)