Skip to content

Commit 0ff4354

Browse files
committed
🚧 WIP
1 parent f4170eb commit 0ff4354

File tree

9 files changed

+86
-42
lines changed

9 files changed

+86
-42
lines changed

discord/http.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@
3333
from urllib.parse import quote as _uriquote
3434

3535
import aiohttp
36-
from typing_extensions import overload
36+
from pydantic import BaseModel, TypeAdapter
37+
from typing_extensions import overload, reveal_type
3738

3839
from . import __version__, models, utils
3940
from .errors import (
@@ -53,8 +54,6 @@
5354
if TYPE_CHECKING:
5455
from types import TracebackType
5556

56-
from pydantic import BaseModel
57-
5857
from .enums import AuditLogAction, InteractionResponseType
5958
from .file import File
6059
from .types import (
@@ -94,7 +93,8 @@
9493
Response = Coroutine[Any, Any, T] # pyright: ignore [reportExplicitAny]
9594

9695
API_VERSION: int = 10
97-
BM = TypeVar("BM", bound="BaseModel")
96+
TP = TypeVar("TP")
97+
BM = TypeVar("BM", bound=BaseModel)
9898

9999

100100
async def json_or_text(response: aiohttp.ClientResponse) -> dict[str, Any] | str:
@@ -242,15 +242,26 @@ async def request(
242242
**kwargs: Any,
243243
) -> BM: ...
244244

245+
@overload
246+
async def request(
247+
self,
248+
route: Route,
249+
*,
250+
files: None = ...,
251+
form: None = ...,
252+
model: TypeAdapter[TP],
253+
**kwargs: Any,
254+
) -> TP: ...
255+
245256
async def request(
246257
self,
247258
route: Route,
248259
*,
249260
files: Sequence[File] | None = None,
250261
form: Iterable[dict[str, Any]] | None = None,
251-
model: type[BM] | None = None,
262+
model: type[BM] | TypeAdapter[TP] | None = None,
252263
**kwargs: Any,
253-
) -> Any | BM:
264+
) -> Any | BM | TP:
254265
bucket = route.bucket
255266
method = route.method
256267
url = route.url
@@ -347,9 +358,12 @@ async def request(
347358
if 300 > response.status >= 200:
348359
_log.debug("%s %s has received %s", method, url, data)
349360
if model:
350-
return model(
351-
**data # pyright: ignore [reportCallIssue]
352-
)
361+
if isinstance(model, TypeAdapter):
362+
return model.validate_python(
363+
data
364+
) # pyright: ignore [reportUnknownVariableType]
365+
return model.model_validate(data)
366+
353367
return data
354368

355369
# we are being rate limited
@@ -1630,11 +1644,11 @@ def create_from_template(
16301644

16311645
def get_bans(
16321646
self,
1633-
guild_id: Snowflake,
1647+
guild_id: models.Snowflake,
16341648
limit: int | None = None,
1635-
before: Snowflake | None = None,
1636-
after: Snowflake | None = None,
1637-
) -> Response[list[guild.Ban]]:
1649+
before: models.Snowflake | None = None,
1650+
after: models.Snowflake | None = None,
1651+
) -> Response[list[models.Ban]]:
16381652
params: dict[str, int | Snowflake] = {}
16391653

16401654
if limit is not None:
@@ -1645,10 +1659,14 @@ def get_bans(
16451659
params["after"] = after
16461660

16471661
return self.request(
1648-
Route("GET", "/guilds/{guild_id}/bans", guild_id=guild_id), params=params
1662+
Route("GET", "/guilds/{guild_id}/bans", guild_id=guild_id),
1663+
params=params,
1664+
model=TypeAdapter(list[models.Ban]),
16491665
)
16501666

1651-
def get_ban(self, user_id: Snowflake, guild_id: Snowflake) -> Response[models.Ban]:
1667+
def get_ban(
1668+
self, user_id: models.Snowflake, guild_id: models.Snowflake
1669+
) -> Response[models.Ban]:
16521670
return self.request(
16531671
Route(
16541672
"GET",

discord/iterators.py

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@
3838
Union,
3939
)
4040

41+
from typing_extensions import Final, override, reveal_type
42+
43+
from . import models
4144
from .audit_logs import AuditLogEntry
4245
from .errors import NoMoreItems
4346
from .monetization import Entitlement
@@ -57,9 +60,11 @@
5760
if TYPE_CHECKING:
5861
from .abc import Snowflake
5962
from .guild import BanEntry, Guild
63+
from .http import HTTPClient
6064
from .member import Member
6165
from .message import Message
6266
from .scheduled_events import ScheduledEvent
67+
from .state import ConnectionState
6368
from .threads import Thread
6469
from .types.audit_log import AuditLog as AuditLogPayload
6570
from .types.guild import Guild as GuildPayload
@@ -737,16 +742,26 @@ def create_member(self, data):
737742

738743

739744
class BanIterator(_AsyncIterator["BanEntry"]):
740-
def __init__(self, guild, limit=None, before=None, after=None):
741-
self.guild = guild
742-
self.limit = limit
743-
self.after = after
744-
self.before = before
745-
746-
self.state = self.guild._state
747-
self.get_bans = self.state.http.get_bans
748-
self.bans = asyncio.Queue()
749-
745+
def __init__(
746+
self,
747+
guild: Guild,
748+
limit: int | None = None,
749+
before: models.Snowflake | None = None,
750+
after: models.Snowflake | None = None,
751+
):
752+
self.guild: Guild = guild
753+
self.limit: int | None = limit
754+
self.after: models.Snowflake | None = after
755+
self.before: models.Snowflake | None = before
756+
self.retrieve: int = 0
757+
758+
self.state: ConnectionState = (
759+
self.guild._state
760+
) # pyright: ignore [reportPrivateUsage]
761+
self.get_bans: Final = self.state.http.get_bans
762+
self.bans: asyncio.Queue[BanEntry] = asyncio.Queue()
763+
764+
@override
750765
async def next(self) -> BanEntry:
751766
if self.bans.empty():
752767
await self.fill_bans()
@@ -757,20 +772,20 @@ async def next(self) -> BanEntry:
757772
raise NoMoreItems()
758773

759774
def _get_retrieve(self):
760-
l = self.limit
761-
if l is None or l > 1000:
762-
r = 1000
775+
if self.limit is None or self.limit > 1000:
776+
self.retrieve = 1000
763777
else:
764-
r = l
765-
self.retrieve = r
766-
return r > 0
778+
self.retrieve = self.limit
779+
return self.retrieve > 0
767780

768781
async def fill_bans(self):
769782
if not self._get_retrieve():
770783
return
771-
before = self.before.id if self.before else None
772-
after = self.after.id if self.after else None
773-
data = await self.get_bans(self.guild.id, self.retrieve, before, after)
784+
before: models.Snowflake | None = self.before if self.before else None
785+
after: models.Snowflake | None = self.after if self.after else None
786+
data = await self.get_bans(
787+
models.Snowflake(self.guild.id), self.retrieve, before, after
788+
)
774789
if not data:
775790
# no data, terminate
776791
return
@@ -780,18 +795,16 @@ async def fill_bans(self):
780795
if len(data) < 1000:
781796
self.limit = 0 # terminate loop
782797

783-
self.after = Object(id=int(data[-1]["user"]["id"]))
798+
self.after = data[-1].user.id
784799

785800
for element in reversed(data):
786801
await self.bans.put(self.create_ban(element))
787802

788-
def create_ban(self, data):
803+
def create_ban(self, data: models.Ban) -> BanEntry:
789804
from .guild import BanEntry
790805
from .user import User
791806

792-
return BanEntry(
793-
reason=data["reason"], user=User(state=self.state, data=data["user"])
794-
)
807+
return BanEntry(reason=data.reason, user=User(state=self.state, data=data.user))
795808

796809

797810
class ArchivedThreadIterator(_AsyncIterator["Thread"]):

discord/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
UnavailableGuild,
1111
User,
1212
)
13+
from .types import Snowflake
1314
from .types.utils import MISSING
1415

1516
__all__ = (
@@ -24,4 +25,5 @@
2425
"AvatarDecorationData",
2526
"gateway",
2627
"Ban",
28+
"Snowflake",
2729
)

discord/models/api/__init__.py

Whitespace-only changes.

discord/models/api/ban.py

Whitespace-only changes.

discord/models/types/flags/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313

1414
class flag_value:
1515
def __init__(
16-
self, func: Callable[[Any], int]
17-
): # pyright: ignore [reportExplicitAny]
16+
self, func: Callable[[Any], int] # pyright: ignore [reportExplicitAny]
17+
):
1818
self.flag: int = func(None)
1919
self.__doc__ = func.__doc__
2020

discord/models/types/snowflake.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ def increment(self) -> int:
4343
"""Returns the increment count."""
4444
return self & 0xFFF
4545

46+
@property
47+
def id(self) -> int:
48+
return int(self)
49+
4650
@classmethod
4751
def __get_pydantic_core_schema__(
4852
cls,

discord/models/types/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing_extensions import TypeAlias, final, override
1+
from typing_extensions import TypeAlias, final
22

33

44
@final

discord/user.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from __future__ import annotations
2727

28+
import warnings
2829
from typing import TYPE_CHECKING, Any, TypeVar
2930

3031
from typing_extensions import override
@@ -574,6 +575,12 @@ class User(BaseUser, discord.abc.Messageable):
574575
def __init__(self, *, state: ConnectionState, data: models.User) -> None:
575576
if isinstance(data, dict):
576577
data = models.User(**data)
578+
warnings.warn(
579+
"Passing a dict to User is deprecated and will be removed in a future version.",
580+
DeprecationWarning,
581+
stacklevel=2,
582+
)
583+
577584
super().__init__(state=state, data=data)
578585
self._stored: bool = False
579586

0 commit comments

Comments
 (0)