Skip to content

Commit

Permalink
🐛 修复无法正确发送 heartbeat 的问题 #18
Browse files Browse the repository at this point in the history
  • Loading branch information
CMHopeSunshine committed Mar 24, 2024
1 parent 387f583 commit 890b262
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 77 deletions.
19 changes: 11 additions & 8 deletions nonebot/adapters/discord/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing_extensions import override

from nonebot.adapters import Adapter as BaseAdapter
from nonebot.compat import model_dump, type_validate_python
from nonebot.compat import type_validate_json, type_validate_python
from nonebot.drivers import URL, Driver, ForwardDriver, Request, WebSocket
from nonebot.exception import WebSocketClosed
from nonebot.plugin import get_plugin_config
Expand All @@ -31,7 +31,7 @@
Reconnect,
Resume,
)
from .utils import decompress_data, log
from .utils import decompress_data, log, model_dump

RECONNECT_INTERVAL = 3.0

Expand Down Expand Up @@ -131,7 +131,7 @@ async def _get_gateway_bot(self, bot_info: BotInfo) -> GatewayBot:
resp = await self.request(request)
if not resp.content:
raise ValueError("Failed to get gateway info")
return type_validate_python(GatewayBot, json.loads(resp.content))
return type_validate_json(GatewayBot, resp.content)

async def _get_bot_user(self, bot_info: BotInfo) -> User:
headers = {"Authorization": self.get_authorization(bot_info)}
Expand All @@ -145,7 +145,7 @@ async def _get_bot_user(self, bot_info: BotInfo) -> User:
resp = await self.request(request)
if not resp.content:
raise ValueError("Failed to get bot user info")
return type_validate_python(User, json.loads(resp.content))
return type_validate_json(User, resp.content)

async def _forward_ws(
self,
Expand Down Expand Up @@ -229,7 +229,8 @@ async def _forward_ws(
if heartbeat_task:
heartbeat_task.cancel()
heartbeat_task = None
self.bot_disconnect(bot)
if bot.self_id in self.bots:
self.bot_disconnect(bot)

except Exception as e:
log(
Expand Down Expand Up @@ -279,7 +280,7 @@ async def _heartbeat(ws: WebSocket, bot: Bot):
{"data": bot.sequence if bot.has_sequence else None},
)
with contextlib.suppress(Exception):
await ws.send(json.dumps(model_dump(payload)))
await ws.send(json.dumps(model_dump(payload, by_alias=True)))

async def _heartbeat_task(self, ws: WebSocket, bot: Bot, heartbeat_interval: int):
"""心跳任务"""
Expand Down Expand Up @@ -319,7 +320,9 @@ async def _authenticate(self, bot: Bot, ws: WebSocket, shard: Tuple[int, int]):
)

try:
await ws.send(json.dumps(model_dump(payload, exclude_none=True)))
await ws.send(
json.dumps(model_dump(payload, by_alias=True, exclude_none=True))
)
except Exception as e:
log(
"ERROR",
Expand Down Expand Up @@ -429,7 +432,7 @@ def get_authorization(bot_info: BotInfo) -> str:
async def receive_payload(self, ws: WebSocket) -> Payload:
data = await ws.receive()
data = decompress_data(data, self.discord_config.discord_compress)
return type_validate_python(PayloadType, json.loads(data)) # type: ignore
return type_validate_json(PayloadType, data) # type: ignore

@classmethod
def payload_to_event(cls, payload: Dispatch) -> Event:
Expand Down
7 changes: 4 additions & 3 deletions nonebot/adapters/discord/api/handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@
)
from urllib.parse import quote

from nonebot.compat import model_dump, type_validate_python
from nonebot.compat import type_validate_python
from nonebot.drivers import Request

from .model import *
from .request import _request
from .utils import parse_data, parse_forum_thread_message, parse_interaction_response
from ..utils import model_dump

if TYPE_CHECKING:
from ..adapter import Adapter
Expand Down Expand Up @@ -2508,8 +2509,8 @@ async def _get_stage_instance(
url=adapter.base_url / f"stage-instances/{channel_id}",
)
return type_validate_python(
Optional[StageInstance],
await _request(adapter, bot, request), # type: ignore
Optional[StageInstance], # type: ignore
await _request(adapter, bot, request),
)


Expand Down
75 changes: 13 additions & 62 deletions nonebot/adapters/discord/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,75 +18,23 @@
from nonebot.compat import PYDANTIC_V2

from pydantic import (
BaseModel as PydanticBaseModel,
BaseModel,
Field,
)
from pydantic.generics import GenericModel

if PYDANTIC_V2:
GenericModel = BaseModel
from pydantic_core import CoreSchema, core_schema

from .types import *

if TYPE_CHECKING:
if PYDANTIC_V2:
from pydantic.main import IncEx
else:
if TYPE_CHECKING:
GenericModel = BaseModel
else:
from pydantic.typing import AbstractSetIntStr, DictStrAny, MappingIntStrAny

T = TypeVar("T", str, int, float)
from pydantic.generics import GenericModel


class BaseModel(PydanticBaseModel):
if PYDANTIC_V2:

def model_dump(
self,
*,
include: "IncEx" = None,
exclude: "IncEx" = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> Dict[str, Any]:
data = super().model_dump(
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
# exclude UNSET
if exclude_unset or exclude_none:
data = {key: value for key, value in data.items() if value is not UNSET}
return data

else:
from .types import *

def dict(
self,
*,
include: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None,
exclude: Optional[Union["AbstractSetIntStr", "MappingIntStrAny"]] = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> "DictStrAny":
data = super().dict(
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
# exclude UNSET
if exclude_unset or exclude_none:
data = {key: value for key, value in data.items() if value is not UNSET}
return data
T = TypeVar("T", str, int, float)


@final
Expand Down Expand Up @@ -3380,8 +3328,11 @@ class AuthorizationResponse(BaseModel):


for name, obj in inspect.getmembers(sys.modules[__name__]):
if inspect.isclass(obj) and issubclass(obj, BaseModel):
obj.update_forward_refs()
if inspect.isclass(obj) and issubclass(obj, BaseModel) and obj is not BaseModel:
if PYDANTIC_V2:
obj.model_rebuild()
else:
obj.update_forward_refs()

__all__ = [
"BaseModel",
Expand Down
3 changes: 2 additions & 1 deletion nonebot/adapters/discord/api/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import json
from typing import Any, Dict, List, Literal, Type, Union

from nonebot.compat import model_dump, type_validate_python
from nonebot.compat import type_validate_python

from .model import (
ExecuteWebhookParams,
InteractionCallbackMessage,
InteractionResponse,
MessageSend,
)
from ..utils import model_dump


def parse_data(
Expand Down
3 changes: 1 addition & 2 deletions nonebot/adapters/discord/commands/storage.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Literal

from nonebot.compat import model_dump

from ..api import ApplicationCommandCreate, Snowflake
from ..bot import Bot
from ..utils import model_dump

if TYPE_CHECKING:
from .matcher import ApplicationCommandConfig
Expand Down
43 changes: 42 additions & 1 deletion nonebot/adapters/discord/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,52 @@
from typing import Union
from typing import Any, Dict, Optional, Set, Union
import zlib

from nonebot.compat import model_dump as model_dump_
from nonebot.utils import logger_wrapper

from pydantic import BaseModel

from .api.types import UNSET

log = logger_wrapper("Discord")


def exclude_unset_data(data: Any) -> Any:
if isinstance(data, dict):
return data.__class__(
(k, exclude_unset_data(v)) for k, v in data.items() if v is not UNSET
)
elif isinstance(data, list):
return data.__class__(exclude_unset_data(i) for i in data)
elif data is UNSET:
return None
return data


def model_dump(
model: BaseModel,
include: Optional[Set[str]] = None,
exclude: Optional[Set[str]] = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> Dict[str, Any]:
data = model_dump_(
model,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
if exclude_none or exclude_unset:
return exclude_unset_data(data)
else:
return data


def escape(s: str) -> str:
return s.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")

Expand Down

0 comments on commit 890b262

Please sign in to comment.