Skip to content

Commit

Permalink
Improvements to reddit embedder
Browse files Browse the repository at this point in the history
  • Loading branch information
joinemm committed Mar 11, 2024
1 parent 883f139 commit 10d5a91
Show file tree
Hide file tree
Showing 8 changed files with 289 additions and 133 deletions.
1 change: 0 additions & 1 deletion .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ docker-compose.yml
flake.*

docs
downloads
infra
venv
templates
Expand Down
8 changes: 4 additions & 4 deletions flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# https://git.joinemm.dev/miso-bot
{
inputs = {
nixpkgs.url = "github:NixOS/nixpkgs/nixos-23.05";
nixpkgs.url = "github:NixOS/nixpkgs/nixos-23.11";
devenv.url = "github:cachix/devenv";
};

Expand Down
190 changes: 67 additions & 123 deletions modules/media_embedders.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
# SPDX-FileCopyrightText: 2023 Joonas Rautiola <[email protected]>
# SPDX-FileCopyrightText: 2024 Joonas Rautiola <[email protected]>
# SPDX-License-Identifier: MPL-2.0
# https://git.joinemm.dev/miso-bot

import asyncio
import io
import subprocess
from typing import Any
from typing import TYPE_CHECKING, Any

import arrow
import discord
import regex
import yarl
from aiohttp import BasicAuth, ClientConnectorError
from aiohttp import ClientConnectorError
from attr import dataclass
from discord.ext import commands
from discord.ui import View
from loguru import logger

from modules import emojis, exceptions, instagram, util
from modules.misobot import MisoBot
from modules.tiktok import TikTok

if TYPE_CHECKING:
from modules.misobot import MisoBot


@dataclass
class InstagramPost:
Expand Down Expand Up @@ -57,7 +58,7 @@ class BaseEmbedder:
EMOJI = "..."

def __init__(self, bot) -> None:
self.bot: MisoBot = bot
self.bot: "MisoBot" = bot

@staticmethod
def get_options(text: str) -> Options:
Expand Down Expand Up @@ -123,9 +124,11 @@ async def download_media(
# The url params are unescaped by aiohttp's built-in yarl
# This causes problems with the hash-based request signing that instagram uses
# Thankfully you can plug your own yarl.URL with encoded=True so it wont get encoded twice
headers = {"User-Agent": util.random_user_agent()}
async with self.bot.session.get(
yarl.URL(media_url, encoded=True), headers=headers
yarl.URL(media_url, encoded=True),
headers={
"User-Agent": util.random_user_agent(),
},
) as response:
if not response.ok:
if response.headers.get("Content-Type") == "text/plain":
Expand All @@ -140,22 +143,28 @@ async def download_media(
content_length = response.headers.get(
"Content-Length"
) or response.headers.get("x-full-image-content-length")
if content_length and int(content_length) < max_filesize:
try:
buffer = io.BytesIO(await response.read())
return discord.File(fp=buffer, filename=filename, spoiler=spoiler)
except asyncio.TimeoutError:
pass
try:
# try to stream until we hit our limit
buffer = b""
async for chunk in response.content.iter_chunked(1024):
buffer += chunk
if len(buffer) > max_filesize:
if content_length:
if int(content_length) < max_filesize:
try:
file = io.BytesIO(await response.read())
return discord.File(
fp=file, filename=filename, spoiler=spoiler
)
except asyncio.TimeoutError:
pass
else:
raise ValueError
return discord.File(
fp=io.BytesIO(buffer), filename=filename, spoiler=spoiler
)
else:
# try to stream until we hit our limit
buffer: bytes = b""
async for chunk in response.content.iter_chunked(1024):
buffer += chunk
if len(buffer) > max_filesize:
raise ValueError
return discord.File(
fp=io.BytesIO(buffer), filename=filename, spoiler=spoiler
)
except ValueError:
pass

Expand All @@ -179,8 +188,7 @@ async def send(
ctx.channel, media, options=options
)
msg = await ctx.send(**message_contents)
message_contents["view"].message_ref = msg
message_contents["view"].approved_deletors.append(ctx.author)
await self.msg_post_process(msg, message_contents["view"], ctx.author)

async def send_contextless(
self,
Expand All @@ -192,8 +200,7 @@ async def send_contextless(
"""Send the media without relying on command context, for example in a message event"""
message_contents = await self.create_message(channel, media, options=options)
msg = await channel.send(**message_contents)
message_contents["view"].message_ref = msg
message_contents["view"].approved_deletors.append(author)
await self.msg_post_process(msg, message_contents["view"], author)

async def send_reply(
self,
Expand All @@ -211,8 +218,16 @@ async def send_reply(
# the original message was deleted, so we can't reply
msg = await message.channel.send(**message_contents)

message_contents["view"].message_ref = msg
message_contents["view"].approved_deletors.append(message.author)
await self.msg_post_process(msg, message_contents["view"], message.author)

async def msg_post_process(
self,
msg: discord.Message,
view: discord.ui,
author: discord.User,
):
view.message_ref = msg
view.approved_deletors.append(author)


class RedditEmbedder(BaseEmbedder):
Expand All @@ -235,118 +250,47 @@ async def create_message(
reddit_post_id: str,
options: Options | None = None,
):
user_agent = "Miso Bot (by Joinemm)"
token = self.bot.reddit_access_token
now = arrow.utcnow().timestamp()
if token["expiry"] < now:
async with self.bot.session.post(
"https://www.reddit.com/api/v1/access_token",
headers={"User-Agent": user_agent},
data={"grant_type": "client_credentials"},
auth=BasicAuth(
self.bot.keychain.REDDIT_CLIENT_ID,
self.bot.keychain.REDDIT_CLIENT_SECRET,
),
) as response:
data = await response.json()
self.bot.reddit_access_token = {
"expiry": now + data["expires_in"],
"token": data["access_token"],
}
api_url = f"https://oauth.reddit.com/api/info/?id=t3_{reddit_post_id}"
headers = {
"User-Agent": user_agent,
"Authorization": f"Bearer {self.bot.reddit_access_token['token']}",
}
async with self.bot.session.get(api_url, headers=headers) as response:
data = await response.json()
post = data["data"]["children"][0]["data"]

timestamp = int(post["created"])
dateformat = arrow.get(timestamp).format("YYMMDD")
post = await self.bot.reddit_client.get_post(reddit_post_id)
videos = post.videos
caption = post.caption

caption = f"{self.EMOJI} `{post['subreddit_name_prefixed']}` <t:{timestamp}:d>"
if options and options.captions:
caption += f"\n>>> {post['title']}"

media = []
files = []
if post.get("is_gallery"):
media = [
{
"url": f"https://i.redd.it/{m['id']}.{m['m'].split('/')[-1]}",
}
for m in post["media_metadata"].values()
]
elif post["is_reddit_media_domain"]:
hint = post["post_hint"]
if hint == "image":
media = [{"url": post["url_overridden_by_dest"]}]
elif hint == "hosted:video":
video_url = post["media"]["reddit_video"]["dash_url"]
video_path = f"downloads/{reddit_post_id}.mp4"
ffmpeg = subprocess.call(
[
"/usr/bin/ffmpeg",
"-y",
"-hide_banner",
"-loglevel",
"error",
"-i",
video_url,
"-c",
"copy",
video_path,
]
)
if ffmpeg != 0:
raise exceptions.CommandError(
"There was an error encoding your video!"
)
files.append(
discord.File(
video_path,
spoiler=options.spoiler if options else False,
)
)

elif post["is_self"]:
raise exceptions.CommandWarning(
f"This is a text post! [`{reddit_post_id}`]"
)
elif post["post_hint"] == "link":
caption += "\n" + post["url"]
else:
raise exceptions.CommandWarning(
f"I don't know what to do with this post! [`{reddit_post_id}`]"
)
dateformat = arrow.get(post.timestamp).format("YYMMDD")

tasks = []
for n, media in enumerate(media, start=1):
filename = f"{dateformat}-{post['subreddit']}-{reddit_post_id}-{n}.jpg"
for n, media in enumerate(post.media, start=1):
filename = f"{dateformat}-{post.subreddit}-{reddit_post_id}-{n}.jpg"
tasks.append(
self.download_media(
media["url"],
media,
filename,
filesize_limit(channel.guild),
url_tags=["reddit"],
spoiler=options.spoiler if options else False,
)
)

files = []
results = await asyncio.gather(*tasks)
for result in results:
if isinstance(result, discord.File):
files.append(result)
else:
caption += "\n" + result

for video in videos:
files.append(
discord.File(
video,
spoiler=options.spoiler if options else False,
)
)

return {
"content": caption,
"files": files,
"view": MediaUI(
"View on Reddit",
"https://reddit.com" + post["permalink"],
"https://reddit.com" + post.url,
should_suppress=False,
),
"suppress_embeds": False,
Expand All @@ -368,7 +312,7 @@ def extract_links(
r"?([a-zA-Z0-9\.\_\-]+)?\/([p]+)?([reel]+)?([tv]+)?([stories]+)?\/"
r"([a-zA-Z0-9\-\_\.]+)\/?([0-9]+)?"
)
results = []
results: list[InstagramPost | InstagramStory] = []
for match in regex.finditer(instagram_regex, text):
# group 1 for username
# group 2 for p
Expand All @@ -378,18 +322,18 @@ def extract_links(
# group 6 for shortcode and username stories
# group 7 for stories pk
if match.group(5) == "stories":
username = match.group(6)
story_id = match.group(7)
username: str = match.group(6)
story_id: str = match.group(7)
if username and story_id:
results.append(InstagramStory(username, story_id))
results.append(InstagramStory(username=username, story_pk=story_id))

elif match.group(6):
results.append(InstagramPost(match.group(6)))
results.append(InstagramPost(shortcode=match.group(6)))

if include_shortcodes:
shortcode_regex = r"(?:\s|^)([^-][a-zA-Z0-9\-\_\.]{9,})(?=\s|$)"
for match in regex.finditer(shortcode_regex, text):
results.append(InstagramPost(match.group(1)))
results.append(InstagramPost(shortcode=match.group(1)))

return results

Expand Down Expand Up @@ -629,7 +573,7 @@ def __init__(self, label: str, url: str, should_suppress: bool = True):
linkbutton = discord.ui.Button(label=label, url=url)
self.add_item(linkbutton)
self.message_ref: discord.Message | None = None
self.approved_deletors = []
self.approved_deletors: list[discord.User] = []
self.should_suppress = should_suppress
self._children.reverse()

Expand Down
7 changes: 4 additions & 3 deletions modules/misobot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
from discord.errors import Forbidden
from discord.ext import commands
from loguru import logger

from modules import cache, maria, util
from modules.help import EmbedHelpCommand
from modules.instagram import Datalama
from modules.keychain import Keychain
from modules.reddit import Reddit
from modules.redis import Redis

from modules import cache, maria, util


@dataclass
class LastFmContext:
Expand Down Expand Up @@ -103,7 +104,7 @@ def __init__(
self.datalama = Datalama(self)
self.boot_up_time: float | None = None
self.session: aiohttp.ClientSession
self.reddit_access_token = {"expiry": 0, "token": None}
self.reddit_client = Reddit(self)
self.register_hooks()

async def get_context(self, message: discord.Message):
Expand Down
Loading

0 comments on commit 10d5a91

Please sign in to comment.