diff --git a/.dockerignore b/.dockerignore index 8cda483..e7d2fb2 100644 --- a/.dockerignore +++ b/.dockerignore @@ -14,7 +14,6 @@ docker-compose.yml flake.* docs -downloads infra venv templates diff --git a/flake.lock b/flake.lock index 862af81..58631bb 100644 --- a/flake.lock +++ b/flake.lock @@ -167,16 +167,16 @@ }, "nixpkgs_2": { "locked": { - "lastModified": 1704290814, - "narHash": "sha256-LWvKHp7kGxk/GEtlrGYV68qIvPHkU9iToomNFGagixU=", + "lastModified": 1705916986, + "narHash": "sha256-iBpfltu6QvN4xMpen6jGGEb6jOqmmVQKUrXdOJ32u8w=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "70bdadeb94ffc8806c0570eb5c2695ad29f0e421", + "rev": "d7f206b723e42edb09d9d753020a84b3061a79d8", "type": "github" }, "original": { "owner": "NixOS", - "ref": "nixos-23.05", + "ref": "nixos-23.11", "repo": "nixpkgs", "type": "github" } diff --git a/flake.nix b/flake.nix index afe09c7..669136c 100644 --- a/flake.nix +++ b/flake.nix @@ -3,7 +3,7 @@ # https://git.joinemm.dev/miso-bot { inputs = { - nixpkgs.url = "github:NixOS/nixpkgs/nixos-23.05"; + nixpkgs.url = "github:NixOS/nixpkgs/nixos-23.11"; devenv.url = "github:cachix/devenv"; }; diff --git a/modules/media_embedders.py b/modules/media_embedders.py index 712da95..c2a467f 100644 --- a/modules/media_embedders.py +++ b/modules/media_embedders.py @@ -1,26 +1,27 @@ -# SPDX-FileCopyrightText: 2023 Joonas Rautiola +# SPDX-FileCopyrightText: 2024 Joonas Rautiola # SPDX-License-Identifier: MPL-2.0 # https://git.joinemm.dev/miso-bot import asyncio import io -import subprocess -from typing import Any +from typing import TYPE_CHECKING, Any import arrow import discord import regex import yarl -from aiohttp import BasicAuth, ClientConnectorError +from aiohttp import ClientConnectorError from attr import dataclass from discord.ext import commands from discord.ui import View from loguru import logger from modules import emojis, exceptions, instagram, util -from modules.misobot import MisoBot from modules.tiktok import TikTok +if TYPE_CHECKING: + from modules.misobot import MisoBot + @dataclass class InstagramPost: @@ -57,7 +58,7 @@ class BaseEmbedder: EMOJI = "..." def __init__(self, bot) -> None: - self.bot: MisoBot = bot + self.bot: "MisoBot" = bot @staticmethod def get_options(text: str) -> Options: @@ -123,9 +124,11 @@ async def download_media( # The url params are unescaped by aiohttp's built-in yarl # This causes problems with the hash-based request signing that instagram uses # Thankfully you can plug your own yarl.URL with encoded=True so it wont get encoded twice - headers = {"User-Agent": util.random_user_agent()} async with self.bot.session.get( - yarl.URL(media_url, encoded=True), headers=headers + yarl.URL(media_url, encoded=True), + headers={ + "User-Agent": util.random_user_agent(), + }, ) as response: if not response.ok: if response.headers.get("Content-Type") == "text/plain": @@ -140,22 +143,28 @@ async def download_media( content_length = response.headers.get( "Content-Length" ) or response.headers.get("x-full-image-content-length") - if content_length and int(content_length) < max_filesize: - try: - buffer = io.BytesIO(await response.read()) - return discord.File(fp=buffer, filename=filename, spoiler=spoiler) - except asyncio.TimeoutError: - pass try: - # try to stream until we hit our limit - buffer = b"" - async for chunk in response.content.iter_chunked(1024): - buffer += chunk - if len(buffer) > max_filesize: + if content_length: + if int(content_length) < max_filesize: + try: + file = io.BytesIO(await response.read()) + return discord.File( + fp=file, filename=filename, spoiler=spoiler + ) + except asyncio.TimeoutError: + pass + else: raise ValueError - return discord.File( - fp=io.BytesIO(buffer), filename=filename, spoiler=spoiler - ) + else: + # try to stream until we hit our limit + buffer: bytes = b"" + async for chunk in response.content.iter_chunked(1024): + buffer += chunk + if len(buffer) > max_filesize: + raise ValueError + return discord.File( + fp=io.BytesIO(buffer), filename=filename, spoiler=spoiler + ) except ValueError: pass @@ -179,8 +188,7 @@ async def send( ctx.channel, media, options=options ) msg = await ctx.send(**message_contents) - message_contents["view"].message_ref = msg - message_contents["view"].approved_deletors.append(ctx.author) + await self.msg_post_process(msg, message_contents["view"], ctx.author) async def send_contextless( self, @@ -192,8 +200,7 @@ async def send_contextless( """Send the media without relying on command context, for example in a message event""" message_contents = await self.create_message(channel, media, options=options) msg = await channel.send(**message_contents) - message_contents["view"].message_ref = msg - message_contents["view"].approved_deletors.append(author) + await self.msg_post_process(msg, message_contents["view"], author) async def send_reply( self, @@ -211,8 +218,16 @@ async def send_reply( # the original message was deleted, so we can't reply msg = await message.channel.send(**message_contents) - message_contents["view"].message_ref = msg - message_contents["view"].approved_deletors.append(message.author) + await self.msg_post_process(msg, message_contents["view"], message.author) + + async def msg_post_process( + self, + msg: discord.Message, + view: discord.ui, + author: discord.User, + ): + view.message_ref = msg + view.approved_deletors.append(author) class RedditEmbedder(BaseEmbedder): @@ -235,98 +250,18 @@ async def create_message( reddit_post_id: str, options: Options | None = None, ): - user_agent = "Miso Bot (by Joinemm)" - token = self.bot.reddit_access_token - now = arrow.utcnow().timestamp() - if token["expiry"] < now: - async with self.bot.session.post( - "https://www.reddit.com/api/v1/access_token", - headers={"User-Agent": user_agent}, - data={"grant_type": "client_credentials"}, - auth=BasicAuth( - self.bot.keychain.REDDIT_CLIENT_ID, - self.bot.keychain.REDDIT_CLIENT_SECRET, - ), - ) as response: - data = await response.json() - self.bot.reddit_access_token = { - "expiry": now + data["expires_in"], - "token": data["access_token"], - } - api_url = f"https://oauth.reddit.com/api/info/?id=t3_{reddit_post_id}" - headers = { - "User-Agent": user_agent, - "Authorization": f"Bearer {self.bot.reddit_access_token['token']}", - } - async with self.bot.session.get(api_url, headers=headers) as response: - data = await response.json() - post = data["data"]["children"][0]["data"] - - timestamp = int(post["created"]) - dateformat = arrow.get(timestamp).format("YYMMDD") + post = await self.bot.reddit_client.get_post(reddit_post_id) + videos = post.videos + caption = post.caption - caption = f"{self.EMOJI} `{post['subreddit_name_prefixed']}` " - if options and options.captions: - caption += f"\n>>> {post['title']}" - - media = [] - files = [] - if post.get("is_gallery"): - media = [ - { - "url": f"https://i.redd.it/{m['id']}.{m['m'].split('/')[-1]}", - } - for m in post["media_metadata"].values() - ] - elif post["is_reddit_media_domain"]: - hint = post["post_hint"] - if hint == "image": - media = [{"url": post["url_overridden_by_dest"]}] - elif hint == "hosted:video": - video_url = post["media"]["reddit_video"]["dash_url"] - video_path = f"downloads/{reddit_post_id}.mp4" - ffmpeg = subprocess.call( - [ - "/usr/bin/ffmpeg", - "-y", - "-hide_banner", - "-loglevel", - "error", - "-i", - video_url, - "-c", - "copy", - video_path, - ] - ) - if ffmpeg != 0: - raise exceptions.CommandError( - "There was an error encoding your video!" - ) - files.append( - discord.File( - video_path, - spoiler=options.spoiler if options else False, - ) - ) - - elif post["is_self"]: - raise exceptions.CommandWarning( - f"This is a text post! [`{reddit_post_id}`]" - ) - elif post["post_hint"] == "link": - caption += "\n" + post["url"] - else: - raise exceptions.CommandWarning( - f"I don't know what to do with this post! [`{reddit_post_id}`]" - ) + dateformat = arrow.get(post.timestamp).format("YYMMDD") tasks = [] - for n, media in enumerate(media, start=1): - filename = f"{dateformat}-{post['subreddit']}-{reddit_post_id}-{n}.jpg" + for n, media in enumerate(post.media, start=1): + filename = f"{dateformat}-{post.subreddit}-{reddit_post_id}-{n}.jpg" tasks.append( self.download_media( - media["url"], + media, filename, filesize_limit(channel.guild), url_tags=["reddit"], @@ -334,6 +269,7 @@ async def create_message( ) ) + files = [] results = await asyncio.gather(*tasks) for result in results: if isinstance(result, discord.File): @@ -341,12 +277,20 @@ async def create_message( else: caption += "\n" + result + for video in videos: + files.append( + discord.File( + video, + spoiler=options.spoiler if options else False, + ) + ) + return { "content": caption, "files": files, "view": MediaUI( "View on Reddit", - "https://reddit.com" + post["permalink"], + "https://reddit.com" + post.url, should_suppress=False, ), "suppress_embeds": False, @@ -368,7 +312,7 @@ def extract_links( r"?([a-zA-Z0-9\.\_\-]+)?\/([p]+)?([reel]+)?([tv]+)?([stories]+)?\/" r"([a-zA-Z0-9\-\_\.]+)\/?([0-9]+)?" ) - results = [] + results: list[InstagramPost | InstagramStory] = [] for match in regex.finditer(instagram_regex, text): # group 1 for username # group 2 for p @@ -378,18 +322,18 @@ def extract_links( # group 6 for shortcode and username stories # group 7 for stories pk if match.group(5) == "stories": - username = match.group(6) - story_id = match.group(7) + username: str = match.group(6) + story_id: str = match.group(7) if username and story_id: - results.append(InstagramStory(username, story_id)) + results.append(InstagramStory(username=username, story_pk=story_id)) elif match.group(6): - results.append(InstagramPost(match.group(6))) + results.append(InstagramPost(shortcode=match.group(6))) if include_shortcodes: shortcode_regex = r"(?:\s|^)([^-][a-zA-Z0-9\-\_\.]{9,})(?=\s|$)" for match in regex.finditer(shortcode_regex, text): - results.append(InstagramPost(match.group(1))) + results.append(InstagramPost(shortcode=match.group(1))) return results @@ -629,7 +573,7 @@ def __init__(self, label: str, url: str, should_suppress: bool = True): linkbutton = discord.ui.Button(label=label, url=url) self.add_item(linkbutton) self.message_ref: discord.Message | None = None - self.approved_deletors = [] + self.approved_deletors: list[discord.User] = [] self.should_suppress = should_suppress self._children.reverse() diff --git a/modules/misobot.py b/modules/misobot.py index 1eacdb8..643fdd4 100644 --- a/modules/misobot.py +++ b/modules/misobot.py @@ -15,13 +15,14 @@ from discord.errors import Forbidden from discord.ext import commands from loguru import logger + +from modules import cache, maria, util from modules.help import EmbedHelpCommand from modules.instagram import Datalama from modules.keychain import Keychain +from modules.reddit import Reddit from modules.redis import Redis -from modules import cache, maria, util - @dataclass class LastFmContext: @@ -103,7 +104,7 @@ def __init__( self.datalama = Datalama(self) self.boot_up_time: float | None = None self.session: aiohttp.ClientSession - self.reddit_access_token = {"expiry": 0, "token": None} + self.reddit_client = Reddit(self) self.register_hooks() async def get_context(self, message: discord.Message): diff --git a/modules/reddit.py b/modules/reddit.py new file mode 100644 index 0000000..eaced7f --- /dev/null +++ b/modules/reddit.py @@ -0,0 +1,151 @@ +# SPDX-FileCopyrightText: 2024 Joonas Rautiola +# SPDX-License-Identifier: MPL-2.0 +# https://git.joinemm.dev/miso-bot + +import subprocess +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import arrow +from aiohttp import BasicAuth + +from modules import exceptions +from modules.media_embedders import Options + +if TYPE_CHECKING: + from modules.misobot import MisoBot + + +class RedditError(Exception): + def __init__(self, message): + self.message = message + super().__init__(message) + + +@dataclass +class RedditPost: + videos: list[str] + caption: str + media: list[str] + timestamp: int + url: str + subreddit: str + + +class Reddit: + API_V1_URL: str = "https://www.reddit.com/api/v1" + API_OAUTH_URL: str = "https://oauth.reddit.com/api" + EMOJI = "<:reddit:1184484866520264724>" + USER_AGENT = "Miso Bot (by Joinemm)" + + def __init__(self, bot: "MisoBot"): + self.bot = bot + self.access_token = { + "expiry": 0, + "token": None, + } + + async def autheticate(self): + now = arrow.utcnow().timestamp() + async with self.bot.session.post( + self.API_V1_URL + "/access_token", + headers={ + "User-Agent": self.USER_AGENT, + }, + data={ + "grant_type": "client_credentials", + }, + auth=BasicAuth( + self.bot.keychain.REDDIT_CLIENT_ID, + self.bot.keychain.REDDIT_CLIENT_SECRET, + ), + ) as response: + data = await response.json() + self.access_token = { + "expiry": now + data["expires_in"], + "token": data["access_token"], + } + + async def api_request(self, path: str): + if self.access_token["expiry"] < arrow.utcnow().timestamp(): + await self.autheticate() + + async with self.bot.session.get( + self.API_OAUTH_URL + path, + headers={ + "User-Agent": self.USER_AGENT, + "Authorization": f"Bearer {self.access_token['token']}", + }, + ) as response: + data = await response.json() + post = data["data"]["children"][0]["data"] + + return post + + async def get_post( + self, + reddit_post_id: str, + options: Options | None = None, + ): + post = await self.api_request(f"/info/?id=t3_{reddit_post_id}") + + timestamp = int(post["created"]) + caption = f"{self.EMOJI} `{post['subreddit_name_prefixed']}` " + if options and options.captions: + caption += f"\n>>> {post['title']}" + + pictures = [] + videos = [] + if post.get("is_gallery"): + for item in post["gallery_data"]["items"]: + meta = post["media_metadata"][item["media_id"]] + pictures.append( + f"https://i.redd.it/{meta['id']}.{meta['m'].split('/')[-1]}" + ) + elif post["is_reddit_media_domain"]: + hint = post["post_hint"] + if hint == "image": + pictures = [post["url_overridden_by_dest"]] + + elif hint == "hosted:video": + video_url = post["media"]["reddit_video"]["dash_url"] + video_path = f"downloads/{reddit_post_id}.mp4" + ffmpeg = subprocess.call( + [ + "/usr/bin/ffmpeg", + "-y", + "-hide_banner", + "-loglevel", + "error", + "-i", + video_url, + "-c", + "copy", + video_path, + ] + ) + if ffmpeg != 0: + raise exceptions.CommandError( + "There was an error encoding your video!" + ) + videos.append(video_path) + + elif post["is_self"]: + raise exceptions.CommandWarning( + f"This is a text post! [`{reddit_post_id}`]" + ) + elif post["post_hint"] == "link": + caption += "\n" + post["url"] + else: + raise exceptions.CommandWarning( + f"I don't know what to do with this post! [`{reddit_post_id}`]" + ) + + return RedditPost( + media=pictures, + videos=videos, + caption=caption, + timestamp=timestamp, + url=post["permalink"], + subreddit=post["subreddit"], + ) diff --git a/poetry.lock b/poetry.lock index 5a1b41b..649ccae 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1369,6 +1369,66 @@ files = [ {file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"}, ] +[[package]] +name = "mypy" +version = "1.8.0" +description = "Optional static typing for Python" +category = "dev" +optional = false +python-versions = ">=3.8" +files = [ + {file = "mypy-1.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:485a8942f671120f76afffff70f259e1cd0f0cfe08f81c05d8816d958d4577d3"}, + {file = "mypy-1.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:df9824ac11deaf007443e7ed2a4a26bebff98d2bc43c6da21b2b64185da011c4"}, + {file = "mypy-1.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2afecd6354bbfb6e0160f4e4ad9ba6e4e003b767dd80d85516e71f2e955ab50d"}, + {file = "mypy-1.8.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8963b83d53ee733a6e4196954502b33567ad07dfd74851f32be18eb932fb1cb9"}, + {file = "mypy-1.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:e46f44b54ebddbeedbd3d5b289a893219065ef805d95094d16a0af6630f5d410"}, + {file = "mypy-1.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:855fe27b80375e5c5878492f0729540db47b186509c98dae341254c8f45f42ae"}, + {file = "mypy-1.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4c886c6cce2d070bd7df4ec4a05a13ee20c0aa60cb587e8d1265b6c03cf91da3"}, + {file = "mypy-1.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d19c413b3c07cbecf1f991e2221746b0d2a9410b59cb3f4fb9557f0365a1a817"}, + {file = "mypy-1.8.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9261ed810972061388918c83c3f5cd46079d875026ba97380f3e3978a72f503d"}, + {file = "mypy-1.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:51720c776d148bad2372ca21ca29256ed483aa9a4cdefefcef49006dff2a6835"}, + {file = "mypy-1.8.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:52825b01f5c4c1c4eb0db253ec09c7aa17e1a7304d247c48b6f3599ef40db8bd"}, + {file = "mypy-1.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f5ac9a4eeb1ec0f1ccdc6f326bcdb464de5f80eb07fb38b5ddd7b0de6bc61e55"}, + {file = "mypy-1.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afe3fe972c645b4632c563d3f3eff1cdca2fa058f730df2b93a35e3b0c538218"}, + {file = "mypy-1.8.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:42c6680d256ab35637ef88891c6bd02514ccb7e1122133ac96055ff458f93fc3"}, + {file = "mypy-1.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:720a5ca70e136b675af3af63db533c1c8c9181314d207568bbe79051f122669e"}, + {file = "mypy-1.8.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:028cf9f2cae89e202d7b6593cd98db6759379f17a319b5faf4f9978d7084cdc6"}, + {file = "mypy-1.8.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4e6d97288757e1ddba10dd9549ac27982e3e74a49d8d0179fc14d4365c7add66"}, + {file = "mypy-1.8.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f1478736fcebb90f97e40aff11a5f253af890c845ee0c850fe80aa060a267c6"}, + {file = "mypy-1.8.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:42419861b43e6962a649068a61f4a4839205a3ef525b858377a960b9e2de6e0d"}, + {file = "mypy-1.8.0-cp38-cp38-win_amd64.whl", hash = "sha256:2b5b6c721bd4aabaadead3a5e6fa85c11c6c795e0c81a7215776ef8afc66de02"}, + {file = "mypy-1.8.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5c1538c38584029352878a0466f03a8ee7547d7bd9f641f57a0f3017a7c905b8"}, + {file = "mypy-1.8.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4ef4be7baf08a203170f29e89d79064463b7fc7a0908b9d0d5114e8009c3a259"}, + {file = "mypy-1.8.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7178def594014aa6c35a8ff411cf37d682f428b3b5617ca79029d8ae72f5402b"}, + {file = "mypy-1.8.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ab3c84fa13c04aeeeabb2a7f67a25ef5d77ac9d6486ff33ded762ef353aa5592"}, + {file = "mypy-1.8.0-cp39-cp39-win_amd64.whl", hash = "sha256:99b00bc72855812a60d253420d8a2eae839b0afa4938f09f4d2aa9bb4654263a"}, + {file = "mypy-1.8.0-py3-none-any.whl", hash = "sha256:538fd81bb5e430cc1381a443971c0475582ff9f434c16cd46d2c66763ce85d9d"}, + {file = "mypy-1.8.0.tar.gz", hash = "sha256:6ff8b244d7085a0b425b56d327b480c3b29cafbd2eff27316a004f9a7391ae07"}, +] + +[package.dependencies] +mypy-extensions = ">=1.0.0" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing-extensions = ">=4.1.0" + +[package.extras] +dmypy = ["psutil (>=4.0)"] +install-types = ["pip"] +mypyc = ["setuptools (>=50)"] +reports = ["lxml"] + +[[package]] +name = "mypy-extensions" +version = "1.0.0" +description = "Type system extensions for programs checked with the mypy type checker." +category = "dev" +optional = false +python-versions = ">=3.5" +files = [ + {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, + {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, +] + [[package]] name = "nodeenv" version = "1.8.0" @@ -2598,4 +2658,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "5a87bbb2efce16f7194f684b95d60f841df669a37fadee520c5f14286ac92589" +content-hash = "d340e27aa840f9d3289f1066b29b5e5d0e22d160d71423f46e993d38a3f2dd10" diff --git a/pyproject.toml b/pyproject.toml index 203991e..ca80a10 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ async-cse = "^0.3.0" [tool.poetry.group.dev.dependencies] pre-commit = "^3.6.0" +mypy = "^1.8.0" [build-system] requires = ["poetry-core"]