Skip to content

feat: pydantic #2675

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ These changes are available on the `master` branch, but have not yet been releas
- Replaced audioop (deprecated module) implementation of `PCMVolumeTransformer.read`
method with a pure Python equivalent.
([#2176](https://github.com/Pycord-Development/pycord/pull/2176))
- Updated `Guild.filesize_limit` to 10 Mb instead of 25 Mb following Discord's API
changes. ([#2671](https://github.com/Pycord-Development/pycord/pull/2671))

### Deprecated

Expand Down
75 changes: 52 additions & 23 deletions discord/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from typing import TYPE_CHECKING, Any, Literal

import yarl
from typing_extensions import Final, override

from . import utils
from .errors import DiscordException, InvalidArgument
Expand All @@ -39,6 +40,7 @@
if TYPE_CHECKING:
ValidStaticFormatTypes = Literal["webp", "jpeg", "jpg", "png"]
ValidAssetFormatTypes = Literal["webp", "jpeg", "jpg", "png", "gif"]
from .state import ConnectionState

VALID_STATIC_FORMATS = frozenset({"jpeg", "jpg", "webp", "png"})
VALID_ASSET_FORMATS = VALID_STATIC_FORMATS | {"gif"}
Expand All @@ -49,7 +51,7 @@

class AssetMixin:
url: str
_state: Any | None
_state: ConnectionState | None

async def read(self) -> bytes:
"""|coro|
Expand Down Expand Up @@ -77,7 +79,9 @@ async def read(self) -> bytes:

async def save(
self,
fp: str | bytes | os.PathLike | io.BufferedIOBase,
fp: (
str | bytes | os.PathLike | io.BufferedIOBase
), # pyright: ignore [reportMissingTypeArgument]
*,
seek_begin: bool = True,
) -> int:
Expand Down Expand Up @@ -117,7 +121,7 @@ async def save(
fp.seek(0)
return written
else:
with open(fp, "wb") as f:
with open(fp, "wb") as f: # pyright: ignore [reportUnknownArgumentType]
return f.write(data)


Expand Down Expand Up @@ -154,16 +158,23 @@ class Asset(AssetMixin):
"_key",
)

BASE = "https://cdn.discordapp.com"
BASE: Final = "https://cdn.discordapp.com"

def __init__(self, state, *, url: str, key: str, animated: bool = False):
self._state = state
self._url = url
self._animated = animated
self._key = key
def __init__(
self,
state: ConnectionState | None,
*,
url: str,
key: str,
animated: bool = False,
):
self._state: ConnectionState | None = state
self._url: str = url
self._animated: bool = animated
self._key: str = key

@classmethod
def _from_default_avatar(cls, state, index: int) -> Asset:
def _from_default_avatar(cls, state: ConnectionState, index: int) -> Asset:
return cls(
state,
url=f"{cls.BASE}/embed/avatars/{index}.png",
Expand All @@ -172,7 +183,7 @@ def _from_default_avatar(cls, state, index: int) -> Asset:
)

@classmethod
def _from_avatar(cls, state, user_id: int, avatar: str) -> Asset:
def _from_avatar(cls, state: ConnectionState, user_id: int, avatar: str) -> Asset:
animated = avatar.startswith("a_")
format = "gif" if animated else "png"
return cls(
Expand All @@ -184,7 +195,10 @@ def _from_avatar(cls, state, user_id: int, avatar: str) -> Asset:

@classmethod
def _from_avatar_decoration(
cls, state, user_id: int, avatar_decoration: str
cls,
state: ConnectionState,
user_id: int,
avatar_decoration: str, # pyright: ignore [reportUnusedParameter]
) -> Asset:
animated = avatar_decoration.startswith("a_")
endpoint = (
Expand All @@ -201,7 +215,7 @@ def _from_avatar_decoration(

@classmethod
def _from_guild_avatar(
cls, state, guild_id: int, member_id: int, avatar: str
cls, state: ConnectionState, guild_id: int, member_id: int, avatar: str
) -> Asset:
animated = avatar.startswith("a_")
format = "gif" if animated else "png"
Expand All @@ -214,7 +228,7 @@ def _from_guild_avatar(

@classmethod
def _from_guild_banner(
cls, state, guild_id: int, member_id: int, banner: str
cls, state: ConnectionState, guild_id: int, member_id: int, banner: str
) -> Asset:
animated = banner.startswith("a_")
format = "gif" if animated else "png"
Expand All @@ -226,7 +240,9 @@ def _from_guild_banner(
)

@classmethod
def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset:
def _from_icon(
cls, state: ConnectionState, object_id: int, icon_hash: str, path: str
) -> Asset:
return cls(
state,
url=f"{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024",
Expand All @@ -235,7 +251,9 @@ def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset:
)

@classmethod
def _from_cover_image(cls, state, object_id: int, cover_image_hash: str) -> Asset:
def _from_cover_image(
cls, state: ConnectionState, object_id: int, cover_image_hash: str
) -> Asset:
return cls(
state,
url=f"{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024",
Expand All @@ -244,7 +262,9 @@ def _from_cover_image(cls, state, object_id: int, cover_image_hash: str) -> Asse
)

@classmethod
def _from_guild_image(cls, state, guild_id: int, image: str, path: str) -> Asset:
def _from_guild_image(
cls, state: ConnectionState, guild_id: int, image: str, path: str
) -> Asset:
animated = False
format = "png"
if path == "banners":
Expand All @@ -259,7 +279,9 @@ def _from_guild_image(cls, state, guild_id: int, image: str, path: str) -> Asset
)

@classmethod
def _from_guild_icon(cls, state, guild_id: int, icon_hash: str) -> Asset:
def _from_guild_icon(
cls, state: ConnectionState, guild_id: int, icon_hash: str
) -> Asset:
animated = icon_hash.startswith("a_")
format = "gif" if animated else "png"
return cls(
Expand All @@ -270,7 +292,7 @@ def _from_guild_icon(cls, state, guild_id: int, icon_hash: str) -> Asset:
)

@classmethod
def _from_sticker_banner(cls, state, banner: int) -> Asset:
def _from_sticker_banner(cls, state: ConnectionState, banner: int) -> Asset:
return cls(
state,
url=f"{cls.BASE}/app-assets/710982414301790216/store/{banner}.png",
Expand All @@ -279,7 +301,9 @@ def _from_sticker_banner(cls, state, banner: int) -> Asset:
)

@classmethod
def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset:
def _from_user_banner(
cls, state: ConnectionState, user_id: int, banner_hash: str
) -> Asset:
animated = banner_hash.startswith("a_")
format = "gif" if animated else "png"
return cls(
Expand All @@ -291,7 +315,7 @@ def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset:

@classmethod
def _from_scheduled_event_image(
cls, state, event_id: int, cover_hash: str
cls, state: ConnectionState, event_id: int, cover_hash: str
) -> Asset:
return cls(
state,
Expand All @@ -300,24 +324,29 @@ def _from_scheduled_event_image(
animated=False,
)

@override
def __str__(self) -> str:
return self._url

def __len__(self) -> int:
return len(self._url)

@override
def __repr__(self):
shorten = self._url.replace(self.BASE, "")
return f"<Asset url={shorten!r}>"

def __eq__(self, other):
@override
def __eq__(self, other: Any): # pyright: ignore [reportExplicitAny]
return isinstance(other, Asset) and self._url == other._url

@override
def __hash__(self):
return hash(self._url)

@property
def url(self) -> str:
@override
def url(self) -> str: # pyright: ignore [reportIncompatibleVariableOverride]
"""Returns the underlying URL of the asset."""
return self._url

Expand Down
4 changes: 2 additions & 2 deletions discord/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

import aiohttp

from . import utils
from . import models, utils
from .activity import ActivityTypes, BaseActivity, create_activity
from .appinfo import AppInfo, PartialAppInfo
from .application_role_connection import ApplicationRoleConnectionMetadata
Expand Down Expand Up @@ -1840,7 +1840,7 @@ async def fetch_user(self, user_id: int, /) -> User:
:exc:`HTTPException`
Fetching the user failed.
"""
data = await self.http.get_user(user_id)
data: models.User = await self.http.get_user(user_id)
return User(state=self._connection, data=data)

async def fetch_channel(
Expand Down
46 changes: 34 additions & 12 deletions discord/emoji.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Iterator

from . import models
from .asset import Asset, AssetMixin
from .partial_emoji import PartialEmoji, _EmojiTag
from .user import User
Expand Down Expand Up @@ -61,19 +63,31 @@ class BaseEmoji(_EmojiTag, AssetMixin):
"available",
)

def __init__(self, *, state: ConnectionState, data: EmojiPayload):
def __init__(self, *, state: ConnectionState, data: models.Emoji):
if isinstance(data, dict):
data = models.Emoji(**data)
warnings.warn(
"Passing a dict to Emoji is deprecated and will be removed in a future version.",
DeprecationWarning,
stacklevel=2,
)
self._state: ConnectionState = state
self._from_data(data)

def _from_data(self, emoji: EmojiPayload):
self.require_colons: bool = emoji.get("require_colons", False)
self.managed: bool = emoji.get("managed", False)
self.id: int = int(emoji["id"]) # type: ignore
self.name: str = emoji["name"] # type: ignore
self.animated: bool = emoji.get("animated", False)
self.available: bool = emoji.get("available", True)
user = emoji.get("user")
self.user: User | None = User(state=self._state, data=user) if user else None
def _from_data(self, emoji: models.Emoji):
self.require_colons: bool = (
emoji.require_colons if emoji.require_colons is not MISSING else False
)
self.managed: bool = False if emoji.managed is MISSING else bool(emoji.managed)
self.id: models.types.EmojiID = emoji.id
self.name: str = emoji.name
self.animated: bool = emoji.animated if emoji.animated is not MISSING else False
self.available: bool = (
emoji.available if emoji.available is not MISSING else True
)
self.user: User | None = (
User(state=self._state, data=emoji.user) if emoji.user else None
)

def _to_partial(self) -> PartialEmoji:
return PartialEmoji(name=self.name, animated=self.animated, id=self.id)
Expand Down Expand Up @@ -166,9 +180,17 @@ class GuildEmoji(BaseEmoji):
"guild_id",
)

def __init__(self, *, guild: Guild, state: ConnectionState, data: EmojiPayload):
def __init__(self, *, guild: Guild, state: ConnectionState, data: models.Emoji):
if isinstance(data, dict):
data = models.Emoji(**data)
warnings.warn(
"Passing a dict to Emoji is deprecated and will be removed in a future version.",
DeprecationWarning,
stacklevel=2,
)
self.guild_id: int = guild.id
self._roles: SnowflakeList = SnowflakeList(map(int, data.get("roles", [])))
if data.roles is not models.MISSING:
self._roles: SnowflakeList = SnowflakeList(map(int, data.roles))
super().__init__(state=state, data=data)

def __repr__(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion discord/ext/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1695,7 +1695,7 @@ def decorator(
func: (
Callable[Concatenate[ContextT, P], Coro[Any]]
| Callable[Concatenate[CogT, ContextT, P], Coro[Any]]
)
),
) -> CommandT:
if isinstance(func, Command):
raise TypeError("Callback is already a command.")
Expand Down
17 changes: 15 additions & 2 deletions discord/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,12 @@
import traceback
import zlib
from collections import deque, namedtuple
from typing import Any

import aiohttp
from pydantic import BaseModel

from discord import models

from . import utils
from .activity import BaseActivity
Expand Down Expand Up @@ -548,11 +552,20 @@ async def received_message(self, msg, /):
)

try:
func = self._discord_parsers[event]
func: Any = self._discord_parsers[event]
except KeyError:
_log.debug("Unknown event %s.", event)
else:
func(data)
if hasattr(func, "_supports_model") and issubclass(
func._supports_model, models.gateway.GatewayEvent
):
func(
func._supports_model(
**msg
).d # pyright: ignore [reportUnknownMemberType, reportAttributeAccessIssue]
)
else:
func(data)

# remove the dispatched listeners
removed = []
Expand Down
Loading