bugfix: Improvements on GPT4V
GaiZhenbiao committed Mar 5, 2024
1 parent e2d069a commit 414815f
Showing 3 changed files with 32 additions and 74 deletions.
101 changes: 28 additions & 73 deletions modules/models/OpenAIVision.py
@@ -43,7 +43,6 @@ def __init__(
self.api_key = api_key
self.need_api_key = True
self.max_generation_token = 4096
self.images = []
self._refresh_header()

def get_answer_stream_iter(self):
@@ -64,68 +63,6 @@ def get_answer_at_once(self):
total_token_count = response["usage"]["total_tokens"]
return content, total_token_count

def try_read_image(self, filepath):
def is_image_file(filepath):
# Check whether the file is an image
valid_image_extensions = [
".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"]
file_extension = os.path.splitext(filepath)[1].lower()
return file_extension in valid_image_extensions
def image_to_base64(image_path):
# Open and load the image
img = Image.open(image_path)

# Get the image's width and height
width, height = img.size

# Compute the downscale ratio so the longest side does not exceed max_dimension
max_dimension = 2048
scale_ratio = min(max_dimension / width, max_dimension / height)

if scale_ratio < 1:
# Resize the image by the downscale ratio
width = int(width * scale_ratio)
height = int(height * scale_ratio)
img = img.resize((width, height), Image.LANCZOS)
# Recompute the image token count from the new width and height
self.image_token = self.count_image_tokens(width, height)

# Convert the image to JPEG binary data
buffer = BytesIO()
if img.mode == "RGBA":
img = img.convert("RGB")
img.save(buffer, format='JPEG')
binary_image = buffer.getvalue()

# Base64-encode the binary data
base64_image = base64.b64encode(binary_image).decode('utf-8')

return base64_image

if is_image_file(filepath):
logging.info(f"读取图片文件: {filepath}")
base64_image = image_to_base64(filepath)
self.images.append({
"path": filepath,
"base64": base64_image,
})

def handle_file_upload(self, files, chatbot, language):
"""if the model accepts multi modal input, implement this function"""
if files:
for file in files:
if file.name:
self.try_read_image(file.name)
if self.images is not None:
chatbot = chatbot + [([image["path"] for image in self.images], None)]
return None, chatbot, None

def prepare_inputs(self, real_inputs, use_websearch, files, reply_language, chatbot):
fake_inputs = real_inputs
display_append = ""
limited_context = False
return limited_context, fake_inputs, display_append, real_inputs, chatbot


def count_token(self, user_input):
input_token_count = count_token(construct_user(user_input))
@@ -185,20 +122,38 @@ def billing_info(self):
logging.error(i18n("获取API使用情况失败:") + str(e))
return STANDARD_ERROR_MSG + ERROR_RETRIEVE_MSG

def _get_gpt4v_style_history(self):
history = []
image_buffer = []
for message in self.history:
if message["role"] == "user":
content = []
if image_buffer:
for image in image_buffer:
content.append(
{
"type": "image_url",
"image_url": f"data:image/{self.get_image_type(image)};base64,{self.get_base64_image(image)}"
},
)
if content:
content.insert(0, {"type": "text", "text": message["content"]})
history.append(construct_user(content))
image_buffer = []
else:
history.append(message)
elif message["role"] == "assistant":
history.append(message)
elif message["role"] == "image":
image_buffer.append(message["content"])
return history


@shared.state.switching_api_key # This decorator has no effect unless multi-account mode is enabled
def _get_response(self, stream=False):
openai_api_key = self.api_key
system_prompt = self.system_prompt
history = self.history
if self.images:
self.history[-1]["content"] = [
{"type": "text", "text": self.history[-1]["content"]},
*[{"type": "image_url", "image_url": "data:image/jpeg;base64,"+image["base64"]} for image in self.images]
]
self.images = []
# Add the image tokens to the total count
self.all_token_counts[-1] += self.image_token
self.image_token = 0
history = self._get_gpt4v_style_history()

logging.debug(colorama.Fore.YELLOW +
f"{history}" + colorama.Fore.RESET)
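For reference, the new _get_gpt4v_style_history method buffers {"role": "image"} entries and folds them into the next user turn as image_url content parts, so uploaded images now live in self.history rather than the removed self.images list. The sketch below is a standalone illustration of that reshaping, not code from the repository: encode_image_as_data_url is a hypothetical stand-in for the model's get_image_type / get_base64_image helpers, and the message format is taken from the diff above.

import base64
import os

def encode_image_as_data_url(path):
    # Hypothetical helper: build a data URL from an image file on disk.
    ext = os.path.splitext(path)[1].lstrip(".").lower() or "jpeg"
    with open(path, "rb") as f:
        return f"data:image/{ext};base64,{base64.b64encode(f.read()).decode('utf-8')}"

def to_gpt4v_history(history):
    # Mirror of the reshaping logic: images are buffered until the next user turn,
    # then emitted as image_url parts after the user's text part.
    out, image_buffer = [], []
    for message in history:
        if message["role"] == "image":
            image_buffer.append(message["content"])
        elif message["role"] == "user" and image_buffer:
            content = [{"type": "text", "text": message["content"]}]
            content += [{"type": "image_url", "image_url": encode_image_as_data_url(p)}
                        for p in image_buffer]
            out.append({"role": "user", "content": content})
            image_buffer = []
        else:
            # Plain user turns and assistant turns pass through unchanged.
            out.append(message)
    return out
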
4 changes: 3 additions & 1 deletion modules/models/base_model.py
@@ -423,7 +423,9 @@ def handle_file_upload(self, files, chatbot, language):
import traceback
traceback.print_exc()
status = i18n("索引构建失败!") + str(e)
if not other_files:
if other_files:
other_files = [f.name for f in other_files]
else:
other_files = None
return gr.File.update(value=other_files), chatbot, status

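The base_model.py fix converts the leftover upload objects to their on-disk paths before handing them back to the Gradio File component. A minimal sketch of that conversion, assuming Gradio 3.x-style upload objects that expose the temp-file path via .name:

def uploads_to_paths(other_files):
    # gr.File.update(value=...) takes plain path strings (or None), not upload wrappers.
    return [f.name for f in other_files] if other_files else None
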
1 change: 1 addition & 0 deletions modules/presets.py
@@ -160,6 +160,7 @@
"GPT4 Vision": {
"model_name": "gpt-4-vision-preview",
"token_limit": 128000,
"multimodal": True
},
"Claude": {
"model_name": "Claude",
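The presets.py entry now marks GPT4 Vision as multimodal. This diff does not show where the flag is read; a hypothetical consumer might look like the sketch below, where MODEL_METADATA stands in for the presets dict that holds this entry (its actual name is not shown here).

# Hypothetical consumer of the new "multimodal" flag; MODEL_METADATA is assumed, not shown in this diff.
MODEL_METADATA = {
    "GPT4 Vision": {
        "model_name": "gpt-4-vision-preview",
        "token_limit": 128000,
        "multimodal": True,
    },
}

def is_multimodal(display_name: str) -> bool:
    # Models whose metadata omits the flag are treated as text-only.
    return MODEL_METADATA.get(display_name, {}).get("multimodal", False)

assert is_multimodal("GPT4 Vision")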
