Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP feat: migrate to openai>=1.0.0 #98

Draft
wants to merge 6 commits into
base: next
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
57 changes: 33 additions & 24 deletions BotHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
from JSONReaderWriter import load_json
from main import __version__

import base64

# User commands
BOT_COMMAND_START = "start"
BOT_COMMAND_HELP = "help"
Expand Down Expand Up @@ -303,36 +305,43 @@ def build_markup(
return InlineKeyboardMarkup(build_menu(buttons, n_cols=2))


async def parse_img(img_source: str):
async def parse_img(img_source: str | (str, str)):
"""
Test if an image source is valid
:param img_source:
:return:
"""
try:
res = requests.head(
img_source,
timeout=10,
headers={
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; WOW64) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/91.4472.114 Safari/537.36"
},
allow_redirects=True,
)
content_type = res.headers.get("content-type")
if not content_type.startswith("image"):
raise Exception("Not Image")
if content_type == "image/svg+xml":
raise Exception("SVG Image")
except Exception as e:
logging.warning(
"Invalid image from {}: {}, You can ignore this message".format(
img_source, str(e)
if isinstance(img_source, str):
try:
res = requests.head(
img_source,
timeout=10,
headers={
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; WOW64) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/91.4472.114 Safari/537.36"
},
allow_redirects=True,
)
)
return None
return img_source
content_type = res.headers.get("content-type")
if not content_type.startswith("image"):
raise Exception("Not Image")
if content_type == "image/svg+xml":
raise Exception("SVG Image")
except Exception as e:
logging.warning(
"Invalid image from {}: {}, You can ignore this message".format(
img_source, str(e)
)
)
return None
return img_source

img_type, img = img_source
if img_type == "base64":
return base64.b64decode(img)

raise Exception("Unknown image type {}".format(img_type))


async def _split_and_send_message_async(
Expand Down
41 changes: 33 additions & 8 deletions DALLEModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
import logging
from typing import List, Dict

import openai
import httpx
from openai import OpenAI

import BotHandler
import UsersHandler
Expand All @@ -30,6 +31,7 @@ def __init__(self, config: dict, messages: List[Dict], users_handler: UsersHandl
self.config = config
self.messages = messages
self.users_handler = users_handler
self.client = None

def initialize(self, proxy=None) -> None:
"""
Expand All @@ -53,11 +55,14 @@ def initialize(self, proxy=None) -> None:
raise Exception("DALL-E module disabled in config file!")

# Set Key
openai.api_key = self.config["dalle"]["open_ai_api_key"]
api_key = self.config["dalle"]["open_ai_api_key"]

http_client = None
# Set proxy
if proxy:
openai.proxy = proxy
http_client = httpx.Client(proxies=proxy)

self.client = OpenAI(api_key=api_key, http_client=http_client)

# Done?
logging.info("DALL-E module initialized")
Expand Down Expand Up @@ -90,19 +95,39 @@ def process_request(self, request_response: RequestResponseContainer) -> None:

# Generate image
logging.info("Requesting image from DALL-E")
image_response = openai.Image.create(prompt=request_response.request,

request = request_response.request
params = {
"style": "vivid",
"quality": "standard",
"size": self.config["dalle"]["image_size"],
}
if request_response.request.startswith("?"):
space_index = request.index(" ")
param_str = request[1:space_index]
request = request[space_index + 1:]

for p in param_str.split(","):
[name, value] = p.split("=")
params[name] = value

image_response = self.client.images.generate(prompt=request,
n=1,
size=self.config["dalle"]["image_size"])
response_url = image_response["data"][0]["url"]
model="dall-e-3",
size=params["size"],
style=params["style"],
quality=params["quality"],
response_format="b64_json")
response_b64 = image_response.data[0].b64_json

# Check response
if not response_url or len(response_url) < 1:
if not response_b64 or len(response_b64) < 1:
raise Exception("Wrong DALL-E response!")

# OK?
logging.info("Response successfully processed for user {0} ({1})"
.format(request_response.user["user_name"], request_response.user["user_id"]))
request_response.response = response_url
request_response.response_images = [("base64", response_b64)]

# Exit requested
except KeyboardInterrupt:
Expand Down
10 changes: 9 additions & 1 deletion QueueHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,15 @@ def _collect_data(self, request_response: RequestResponseContainer, log_request=
# Images
for image_url in request_response.response_images:
try:
response = base64.b64encode(requests.get(image_url, timeout=120).content).decode("utf-8")
response = None
if isinstance(image_url, str):
response = base64.b64encode(requests.get(image_url, timeout=120).content).decode("utf-8")
else:
img_type, img = image_url
if img_type == "base64":
response = img
else:
raise Exception(f"Unknown image type: {img_type}")
log_file.write(response_str_to_format.format(request_response.response_timestamp,
request_response.id,
request_response.user["user_name"],
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ git+https://github.com/F33RNI/EdgeGPT@main#egg=EdgeGPT
git+https://github.com/dsdanielpark/Bard-API@main
git+https://github.com/handsome0hell/md2tgmd.git@main
python-telegram-bot==20.3
openai>=0.26.4
openai>=1.0.0
tiktoken>=0.2.0
OpenAIAuth>=0.3.2
asyncio>=3.4.3
Expand Down