Skip to content

Commit

Permalink
Add first version of reddit embedder
Browse files Browse the repository at this point in the history
  • Loading branch information
joinemm committed Dec 13, 2023
1 parent 4dcac4a commit 55e4eec
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 14 deletions.
14 changes: 14 additions & 0 deletions cogs/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from modules.media_embedders import (
BaseEmbedder,
InstagramEmbedder,
RedditEmbedder,
TikTokEmbedder,
TwitterEmbedder,
)
Expand Down Expand Up @@ -217,6 +218,19 @@ async def twitter(self, ctx: commands.Context, *, links: str):
"""
await TwitterEmbedder(self.bot).process(ctx, links)

@commands.command(
usage="[OPTIONS] <links...>",
)
async def reddit(self, ctx: commands.Context, *, links: str):
"""Retrieve media from a reddit post
OPTIONS
`-c`, `--caption` : also include the caption/text of the media
`-s`, `--spoiler` : spoiler the uploaded images and text
`-d`, `--delete` : delete your message when the media is done embedding
"""
await RedditEmbedder(self.bot).process(ctx, links)

@commands.command(
aliases=["tik", "tok", "tt"],
usage="[OPTIONS] <links...>",
Expand Down
142 changes: 128 additions & 14 deletions modules/media_embedders.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import asyncio
import io
import subprocess
from typing import Any

import arrow
Expand Down Expand Up @@ -122,7 +123,10 @@ async def download_media(
# The url params are unescaped by aiohttp's built-in yarl
# This causes problems with the hash-based request signing that instagram uses
# Thankfully you can plug your own yarl.URL with encoded=True so it wont get encoded twice
async with self.bot.session.get(yarl.URL(media_url, encoded=True)) as response:
headers = {"User-Agent": util.random_user_agent()}
async with self.bot.session.get(
yarl.URL(media_url, encoded=True), headers=headers
) as response:
if not response.ok:
if response.headers.get("Content-Type") == "text/plain":
content = await response.text()
Expand Down Expand Up @@ -174,7 +178,7 @@ async def send(
message_contents = await self.create_message(
ctx.channel, media, options=options
)
msg = await ctx.send(**message_contents, suppress_embeds=True)
msg = await ctx.send(**message_contents)
message_contents["view"].message_ref = msg
message_contents["view"].approved_deletors.append(ctx.author)

Expand All @@ -187,7 +191,7 @@ async def send_contextless(
):
"""Send the media without relying on command context, for example in a message event"""
message_contents = await self.create_message(channel, media, options=options)
msg = await channel.send(**message_contents, suppress_embeds=True)
msg = await channel.send(**message_contents)
message_contents["view"].message_ref = msg
message_contents["view"].approved_deletors.append(author)

Expand All @@ -202,17 +206,128 @@ async def send_reply(
message.channel, media, options=options
)
try:
msg = await message.reply(
**message_contents, mention_author=False, suppress_embeds=True
)
msg = await message.reply(**message_contents, mention_author=False)
except discord.errors.HTTPException:
# the original message was deleted, so we can't reply
msg = await message.channel.send(**message_contents, suppress_embeds=True)
msg = await message.channel.send(**message_contents)

message_contents["view"].message_ref = msg
message_contents["view"].approved_deletors.append(message.author)


class RedditEmbedder(BaseEmbedder):
NAME = "reddit"
EMOJI = "<:reddit:1184484866520264724>"
NO_RESULTS_ERROR = "Found no Reddit links to embed!"

@staticmethod
def extract_links(text: str) -> list[str]:
text = "\n".join(text.split())
reddit_regex = r"(?:.+?)(?:reddit\.com/r)(?:/[\w\d]+){2}(?:/)([\w\d]*)"
gallery_regex = r"(?:.+?)reddit\.com/gallery/([\w\d]*)"
posts = regex.findall(reddit_regex, text)
galleries = regex.findall(gallery_regex, text)
return posts + galleries

async def create_message(
self,
channel: "discord.abc.MessageableChannel",
reddit_post_id: str,
options: Options | None = None,
):
api_url = f"https://api.reddit.com/api/info/?id=t3_{reddit_post_id}"
async with self.bot.session.get(api_url) as response:
data = await response.json()
post = data["data"]["children"][0]["data"]

timestamp = int(post["created"])
dateformat = arrow.get(timestamp).format("YYMMDD")

caption = f"{self.EMOJI} `{post['subreddit_name_prefixed']}` <t:{timestamp}:d>"
if options and options.captions:
caption += f"\n>>> {post['title']}"

media = []
files = []
if post.get("is_gallery"):
media = post["media_metadata"].values()
elif post["is_reddit_media_domain"]:
hint = post["post_hint"]
if hint == "image":
media = [
{"id": post["url_overridden_by_dest"].split("/")[-1].split(".")[0]}
]
elif hint == "hosted:video":
video_url = post["media"]["reddit_video"]["dash_url"]
video_path = f"downloads/{reddit_post_id}.mp4"
ffmpeg = subprocess.call(
[
"ffmpeg",
"-y",
"-hide_banner",
"-loglevel",
"error",
"-i",
video_url,
"-c",
"copy",
video_path,
]
)
if ffmpeg != 0:
raise exceptions.CommandError(
"There was an error encoding your video!"
)
files.append(
discord.File(
video_path,
spoiler=options.spoiler if options else False,
)
)

elif post["is_self"]:
raise exceptions.CommandWarning(
f"This is a text post! [`{reddit_post_id}`]"
)
elif post["post_hint"] == "link":
caption += "\n" + post["url"]
else:
raise exceptions.CommandWarning(
f"I don't know what to do with this post! [`{reddit_post_id}`]"
)

tasks = []
for n, media in enumerate(media, start=1):
filename = f"{dateformat}-{post['subreddit']}-{reddit_post_id}-{n}.jpg"
tasks.append(
self.download_media(
f"https://i.redd.it/{media['id']}.jpg",
filename,
filesize_limit(channel.guild),
url_tags=["reddit"],
spoiler=options.spoiler if options else False,
)
)

results = await asyncio.gather(*tasks)
for result in results:
if isinstance(result, discord.File):
files.append(result)
else:
caption += "\n" + result

return {
"content": caption,
"files": files,
"view": MediaUI(
"View on Reddit",
"https://reddit.com" + post["permalink"],
should_suppress=False,
),
"suppress_embeds": False,
}


class InstagramEmbedder(BaseEmbedder):
NAME = "instagram"
EMOJI = "<:ig:937425165162262528>"
Expand Down Expand Up @@ -299,6 +414,7 @@ async def create_message(
"content": caption,
"files": files,
"view": MediaUI("View on Instagram", post.url),
"suppress_embeds": True,
}


Expand Down Expand Up @@ -359,11 +475,7 @@ async def create_message(
"view": ui,
}

return {
"content": caption,
"file": file,
"view": ui,
}
return {"content": caption, "file": file, "view": ui, "suppress_embeds": True}


class TwitterEmbedder(BaseEmbedder):
Expand Down Expand Up @@ -474,16 +586,18 @@ async def create_message(
"view": MediaUI(
"View on X", f"https://twitter.com/{screen_name}/status/{tweet_id}"
),
"suppress_embeds": True,
}


class MediaUI(View):
def __init__(self, label: str, url: str):
def __init__(self, label: str, url: str, should_suppress: bool = True):
super().__init__(timeout=60)
linkbutton = discord.ui.Button(label=label, url=url)
self.add_item(linkbutton)
self.message_ref: discord.Message | None = None
self.approved_deletors = []
self.should_suppress = should_suppress
self._children.reverse()

@discord.ui.button(emoji=emojis.REMOVE, style=discord.ButtonStyle.danger)
Expand All @@ -499,6 +613,6 @@ async def on_timeout(self):
self.remove_item(self.delete_button)
if self.message_ref:
try:
await self.message_ref.edit(view=self, suppress=True)
await self.message_ref.edit(view=self, suppress=self.should_suppress)
except discord.NotFound:
pass

0 comments on commit 55e4eec

Please sign in to comment.