-
-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
289 additions
and
133 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,7 +14,6 @@ docker-compose.yml | |
flake.* | ||
|
||
docs | ||
downloads | ||
infra | ||
venv | ||
templates | ||
|
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,27 @@ | ||
# SPDX-FileCopyrightText: 2023 Joonas Rautiola <[email protected]> | ||
# SPDX-FileCopyrightText: 2024 Joonas Rautiola <[email protected]> | ||
# SPDX-License-Identifier: MPL-2.0 | ||
# https://git.joinemm.dev/miso-bot | ||
|
||
import asyncio | ||
import io | ||
import subprocess | ||
from typing import Any | ||
from typing import TYPE_CHECKING, Any | ||
|
||
import arrow | ||
import discord | ||
import regex | ||
import yarl | ||
from aiohttp import BasicAuth, ClientConnectorError | ||
from aiohttp import ClientConnectorError | ||
from attr import dataclass | ||
from discord.ext import commands | ||
from discord.ui import View | ||
from loguru import logger | ||
|
||
from modules import emojis, exceptions, instagram, util | ||
from modules.misobot import MisoBot | ||
from modules.tiktok import TikTok | ||
|
||
if TYPE_CHECKING: | ||
from modules.misobot import MisoBot | ||
|
||
|
||
@dataclass | ||
class InstagramPost: | ||
|
@@ -57,7 +58,7 @@ class BaseEmbedder: | |
EMOJI = "..." | ||
|
||
def __init__(self, bot) -> None: | ||
self.bot: MisoBot = bot | ||
self.bot: "MisoBot" = bot | ||
|
||
@staticmethod | ||
def get_options(text: str) -> Options: | ||
|
@@ -123,9 +124,11 @@ async def download_media( | |
# The url params are unescaped by aiohttp's built-in yarl | ||
# This causes problems with the hash-based request signing that instagram uses | ||
# Thankfully you can plug your own yarl.URL with encoded=True so it wont get encoded twice | ||
headers = {"User-Agent": util.random_user_agent()} | ||
async with self.bot.session.get( | ||
yarl.URL(media_url, encoded=True), headers=headers | ||
yarl.URL(media_url, encoded=True), | ||
headers={ | ||
"User-Agent": util.random_user_agent(), | ||
}, | ||
) as response: | ||
if not response.ok: | ||
if response.headers.get("Content-Type") == "text/plain": | ||
|
@@ -140,22 +143,28 @@ async def download_media( | |
content_length = response.headers.get( | ||
"Content-Length" | ||
) or response.headers.get("x-full-image-content-length") | ||
if content_length and int(content_length) < max_filesize: | ||
try: | ||
buffer = io.BytesIO(await response.read()) | ||
return discord.File(fp=buffer, filename=filename, spoiler=spoiler) | ||
except asyncio.TimeoutError: | ||
pass | ||
try: | ||
# try to stream until we hit our limit | ||
buffer = b"" | ||
async for chunk in response.content.iter_chunked(1024): | ||
buffer += chunk | ||
if len(buffer) > max_filesize: | ||
if content_length: | ||
if int(content_length) < max_filesize: | ||
try: | ||
file = io.BytesIO(await response.read()) | ||
return discord.File( | ||
fp=file, filename=filename, spoiler=spoiler | ||
) | ||
except asyncio.TimeoutError: | ||
pass | ||
else: | ||
raise ValueError | ||
return discord.File( | ||
fp=io.BytesIO(buffer), filename=filename, spoiler=spoiler | ||
) | ||
else: | ||
# try to stream until we hit our limit | ||
buffer: bytes = b"" | ||
async for chunk in response.content.iter_chunked(1024): | ||
buffer += chunk | ||
if len(buffer) > max_filesize: | ||
raise ValueError | ||
return discord.File( | ||
fp=io.BytesIO(buffer), filename=filename, spoiler=spoiler | ||
) | ||
except ValueError: | ||
pass | ||
|
||
|
@@ -179,8 +188,7 @@ async def send( | |
ctx.channel, media, options=options | ||
) | ||
msg = await ctx.send(**message_contents) | ||
message_contents["view"].message_ref = msg | ||
message_contents["view"].approved_deletors.append(ctx.author) | ||
await self.msg_post_process(msg, message_contents["view"], ctx.author) | ||
|
||
async def send_contextless( | ||
self, | ||
|
@@ -192,8 +200,7 @@ async def send_contextless( | |
"""Send the media without relying on command context, for example in a message event""" | ||
message_contents = await self.create_message(channel, media, options=options) | ||
msg = await channel.send(**message_contents) | ||
message_contents["view"].message_ref = msg | ||
message_contents["view"].approved_deletors.append(author) | ||
await self.msg_post_process(msg, message_contents["view"], author) | ||
|
||
async def send_reply( | ||
self, | ||
|
@@ -211,8 +218,16 @@ async def send_reply( | |
# the original message was deleted, so we can't reply | ||
msg = await message.channel.send(**message_contents) | ||
|
||
message_contents["view"].message_ref = msg | ||
message_contents["view"].approved_deletors.append(message.author) | ||
await self.msg_post_process(msg, message_contents["view"], message.author) | ||
|
||
async def msg_post_process( | ||
self, | ||
msg: discord.Message, | ||
view: discord.ui, | ||
author: discord.User, | ||
): | ||
view.message_ref = msg | ||
view.approved_deletors.append(author) | ||
|
||
|
||
class RedditEmbedder(BaseEmbedder): | ||
|
@@ -235,118 +250,47 @@ async def create_message( | |
reddit_post_id: str, | ||
options: Options | None = None, | ||
): | ||
user_agent = "Miso Bot (by Joinemm)" | ||
token = self.bot.reddit_access_token | ||
now = arrow.utcnow().timestamp() | ||
if token["expiry"] < now: | ||
async with self.bot.session.post( | ||
"https://www.reddit.com/api/v1/access_token", | ||
headers={"User-Agent": user_agent}, | ||
data={"grant_type": "client_credentials"}, | ||
auth=BasicAuth( | ||
self.bot.keychain.REDDIT_CLIENT_ID, | ||
self.bot.keychain.REDDIT_CLIENT_SECRET, | ||
), | ||
) as response: | ||
data = await response.json() | ||
self.bot.reddit_access_token = { | ||
"expiry": now + data["expires_in"], | ||
"token": data["access_token"], | ||
} | ||
api_url = f"https://oauth.reddit.com/api/info/?id=t3_{reddit_post_id}" | ||
headers = { | ||
"User-Agent": user_agent, | ||
"Authorization": f"Bearer {self.bot.reddit_access_token['token']}", | ||
} | ||
async with self.bot.session.get(api_url, headers=headers) as response: | ||
data = await response.json() | ||
post = data["data"]["children"][0]["data"] | ||
|
||
timestamp = int(post["created"]) | ||
dateformat = arrow.get(timestamp).format("YYMMDD") | ||
post = await self.bot.reddit_client.get_post(reddit_post_id) | ||
videos = post.videos | ||
caption = post.caption | ||
|
||
caption = f"{self.EMOJI} `{post['subreddit_name_prefixed']}` <t:{timestamp}:d>" | ||
if options and options.captions: | ||
caption += f"\n>>> {post['title']}" | ||
|
||
media = [] | ||
files = [] | ||
if post.get("is_gallery"): | ||
media = [ | ||
{ | ||
"url": f"https://i.redd.it/{m['id']}.{m['m'].split('/')[-1]}", | ||
} | ||
for m in post["media_metadata"].values() | ||
] | ||
elif post["is_reddit_media_domain"]: | ||
hint = post["post_hint"] | ||
if hint == "image": | ||
media = [{"url": post["url_overridden_by_dest"]}] | ||
elif hint == "hosted:video": | ||
video_url = post["media"]["reddit_video"]["dash_url"] | ||
video_path = f"downloads/{reddit_post_id}.mp4" | ||
ffmpeg = subprocess.call( | ||
[ | ||
"/usr/bin/ffmpeg", | ||
"-y", | ||
"-hide_banner", | ||
"-loglevel", | ||
"error", | ||
"-i", | ||
video_url, | ||
"-c", | ||
"copy", | ||
video_path, | ||
] | ||
) | ||
if ffmpeg != 0: | ||
raise exceptions.CommandError( | ||
"There was an error encoding your video!" | ||
) | ||
files.append( | ||
discord.File( | ||
video_path, | ||
spoiler=options.spoiler if options else False, | ||
) | ||
) | ||
|
||
elif post["is_self"]: | ||
raise exceptions.CommandWarning( | ||
f"This is a text post! [`{reddit_post_id}`]" | ||
) | ||
elif post["post_hint"] == "link": | ||
caption += "\n" + post["url"] | ||
else: | ||
raise exceptions.CommandWarning( | ||
f"I don't know what to do with this post! [`{reddit_post_id}`]" | ||
) | ||
dateformat = arrow.get(post.timestamp).format("YYMMDD") | ||
|
||
tasks = [] | ||
for n, media in enumerate(media, start=1): | ||
filename = f"{dateformat}-{post['subreddit']}-{reddit_post_id}-{n}.jpg" | ||
for n, media in enumerate(post.media, start=1): | ||
filename = f"{dateformat}-{post.subreddit}-{reddit_post_id}-{n}.jpg" | ||
tasks.append( | ||
self.download_media( | ||
media["url"], | ||
media, | ||
filename, | ||
filesize_limit(channel.guild), | ||
url_tags=["reddit"], | ||
spoiler=options.spoiler if options else False, | ||
) | ||
) | ||
|
||
files = [] | ||
results = await asyncio.gather(*tasks) | ||
for result in results: | ||
if isinstance(result, discord.File): | ||
files.append(result) | ||
else: | ||
caption += "\n" + result | ||
|
||
for video in videos: | ||
files.append( | ||
discord.File( | ||
video, | ||
spoiler=options.spoiler if options else False, | ||
) | ||
) | ||
|
||
return { | ||
"content": caption, | ||
"files": files, | ||
"view": MediaUI( | ||
"View on Reddit", | ||
"https://reddit.com" + post["permalink"], | ||
"https://reddit.com" + post.url, | ||
should_suppress=False, | ||
), | ||
"suppress_embeds": False, | ||
|
@@ -368,7 +312,7 @@ def extract_links( | |
r"?([a-zA-Z0-9\.\_\-]+)?\/([p]+)?([reel]+)?([tv]+)?([stories]+)?\/" | ||
r"([a-zA-Z0-9\-\_\.]+)\/?([0-9]+)?" | ||
) | ||
results = [] | ||
results: list[InstagramPost | InstagramStory] = [] | ||
for match in regex.finditer(instagram_regex, text): | ||
# group 1 for username | ||
# group 2 for p | ||
|
@@ -378,18 +322,18 @@ def extract_links( | |
# group 6 for shortcode and username stories | ||
# group 7 for stories pk | ||
if match.group(5) == "stories": | ||
username = match.group(6) | ||
story_id = match.group(7) | ||
username: str = match.group(6) | ||
story_id: str = match.group(7) | ||
if username and story_id: | ||
results.append(InstagramStory(username, story_id)) | ||
results.append(InstagramStory(username=username, story_pk=story_id)) | ||
|
||
elif match.group(6): | ||
results.append(InstagramPost(match.group(6))) | ||
results.append(InstagramPost(shortcode=match.group(6))) | ||
|
||
if include_shortcodes: | ||
shortcode_regex = r"(?:\s|^)([^-][a-zA-Z0-9\-\_\.]{9,})(?=\s|$)" | ||
for match in regex.finditer(shortcode_regex, text): | ||
results.append(InstagramPost(match.group(1))) | ||
results.append(InstagramPost(shortcode=match.group(1))) | ||
|
||
return results | ||
|
||
|
@@ -629,7 +573,7 @@ def __init__(self, label: str, url: str, should_suppress: bool = True): | |
linkbutton = discord.ui.Button(label=label, url=url) | ||
self.add_item(linkbutton) | ||
self.message_ref: discord.Message | None = None | ||
self.approved_deletors = [] | ||
self.approved_deletors: list[discord.User] = [] | ||
self.should_suppress = should_suppress | ||
self._children.reverse() | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.