Skip to content

Commit

Permalink
🐛 version 0.3.0
Browse files Browse the repository at this point in the history
fix array response
  • Loading branch information
RF-Tar-Railt committed Oct 10, 2023
1 parent 9b8dde6 commit 01f4d09
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 38 deletions.
2 changes: 1 addition & 1 deletion nonebot/adapters/satori/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,5 +284,5 @@ async def _call_api(self, bot: Bot, api: str, **data: Any) -> Any:
log("DEBUG", f"Bot {bot.self_id} calling API <y>{api}</y>")
api_handler: Optional[API] = getattr(bot.__class__, api, None)
if api_handler is None:
raise ApiNotAvailable
raise ApiNotAvailable(api)
return await api_handler(bot, **data)
92 changes: 83 additions & 9 deletions nonebot/adapters/satori/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .event import Event, MessageEvent
from .message import Message, MessageSegment
from .models import InnerMessage as SatoriMessage
from .models import Role, User, Guild, Login, Channel, OuterMember
from .models import Role, User, Guild, Login, Channel, PageResult, OuterMember
from .exception import (
ActionFailed,
NetworkError,
Expand All @@ -22,6 +22,7 @@
BadRequestException,
UnauthorizedException,
MethodNotAllowedException,
ApiNotImplementedException,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -194,6 +195,8 @@ def _handle_response(self, response: Response) -> Any:
raise NotFoundException(response)
elif response.status_code == 405:
raise MethodNotAllowedException(response)
elif response.status_code == 500:
raise ApiNotImplementedException(response)
else:
raise ActionFailed(response)

Expand Down Expand Up @@ -308,8 +311,7 @@ async def message_list(self, *, channel_id: str, next_token: Optional[str] = Non
self.info.api_base / "message.list",
json={"channel_id": channel_id, "next": next_token},
)
res = await self._request(request)
return [SatoriMessage.parse_obj(i) for i in res]
return PageResult[SatoriMessage].parse_obj(await self._request(request))

@API
async def channel_get(self, *, channel_id: str):
Expand All @@ -328,8 +330,7 @@ async def channel_list(self, *, guild_id: str, next_token: Optional[str] = None)
self.info.api_base / "channel.list",
json={"guild_id": guild_id, "next": next_token},
)
res = await self._request(request)
return [Channel.parse_obj(i) for i in res]
return PageResult[Channel].parse_obj(await self._request(request))

@API
async def channel_create(self, *, guild_id: str, data: Channel):
Expand Down Expand Up @@ -388,7 +389,7 @@ async def guild_list(self, *, next_token: Optional[str] = None):
self.info.api_base / "guild.list",
json={"next": next_token},
)
return [Guild.parse_obj(i) for i in await self._request(request)]
return PageResult[Guild].parse_obj(await self._request(request))

@API
async def guild_approve(self, *, request_id: str, approve: bool, comment: str):
Expand All @@ -408,7 +409,7 @@ async def guild_member_list(
self.info.api_base / "guild.member.list",
json={"guild_id": guild_id, "next": next_token},
)
return [OuterMember.parse_obj(i) for i in await self._request(request)]
return PageResult[OuterMember].parse_obj(await self._request(request))

@API
async def guild_member_get(self, *, guild_id: str, user_id: str):
Expand Down Expand Up @@ -468,7 +469,7 @@ async def guild_role_list(self, guild_id: str, next_token: Optional[str] = None)
self.info.api_base / "guild.role.list",
json={"guild_id": guild_id, "next": next_token},
)
return [Role.parse_obj(i) for i in await self._request(request)]
return PageResult[Role].parse_obj(await self._request(request))

@API
async def guild_role_create(
Expand Down Expand Up @@ -508,6 +509,79 @@ async def guild_role_delete(self, *, guild_id: str, role_id: str):
)
await self._request(request)

@API
async def reaction_create(
self,
*,
channel_id: str,
message_id: str,
emoji: str,
):
request = Request(
"POST",
self.info.api_base / "reaction.create",
json={"channel_id": channel_id, "message_id": message_id, "emoji": emoji},
)
await self._request(request)

@API
async def reaction_delete(
self,
*,
channel_id: str,
message_id: str,
emoji: str,
user_id: Optional[str] = None,
):
data = {"channel_id": channel_id, "message_id": message_id, "emoji": emoji}
if user_id is not None:
data["user_id"] = user_id
request = Request(
"POST",
self.info.api_base / "reaction.delete",
json=data,
)
await self._request(request)

@API
async def reaction_clear(
self,
*,
channel_id: str,
message_id: str,
emoji: Optional[str] = None,
):
data = {"channel_id": channel_id, "message_id": message_id}
if emoji is not None:
data["emoji"] = emoji
request = Request(
"POST",
self.info.api_base / "reaction.clear",
json=data,
)
await self._request(request)

@API
async def reaction_list(
self,
*,
channel_id: str,
message_id: str,
emoji: str,
next_token: Optional[str] = None,
):
request = Request(
"POST",
self.info.api_base / "reaction.list",
json={
"channel_id": channel_id,
"message_id": message_id,
"emoji": emoji,
"next": next_token,
},
)
return PageResult[User].parse_obj(await self._request(request))

@API
async def login_get(self):
request = Request(
Expand All @@ -532,7 +606,7 @@ async def friend_list(self, *, next_token: Optional[str] = None):
self.info.api_base / "friend.list",
json={"next": next_token},
)
return [User.parse_obj(i) for i in await self._request(request)]
return PageResult[User].parse_obj(await self._request(request))

@API
async def friend_approve(self, *, request_id: str, approve: bool, comment: str):
Expand Down
34 changes: 12 additions & 22 deletions nonebot/adapters/satori/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ def get_type(self) -> str:


class FriendEvent(NoticeEvent):
channel: Channel
user: User

@override
Expand All @@ -105,12 +104,11 @@ class FriendRequestEvent(FriendEvent):


class GuildEvent(NoticeEvent):
channel: Channel
guild: Guild

@override
def get_session_id(self) -> str:
return f"{self.guild.id}/{self.channel.id}"
return self.guild.id


@register_event_class
Expand All @@ -134,15 +132,16 @@ class GuildUpdatedEvent(GuildEvent):


class GuildInnerMemberEvent(GuildEvent):
member: InnerMember
user: User

@override
def get_user_id(self) -> str:
return self.member.user.id if self.member else self.user.id
return self.user.id

@override
def get_session_id(self) -> str:
return f"{self.guild.id}/{self.channel.id}/{self.get_user_id()}"
return f"{self.guild.id}/{self.get_user_id()}"


@register_event_class
Expand All @@ -152,7 +151,6 @@ class GuildInnerMemberAddedEvent(GuildInnerMemberEvent):

@register_event_class
class GuildInnerMemberRemovedEvent(GuildInnerMemberEvent):
member: InnerMember
__type__ = EventType.GUILD_MEMBER_REMOVED


Expand All @@ -163,7 +161,6 @@ class GuildInnerMemberRequestEvent(GuildInnerMemberEvent):

@register_event_class
class GuildInnerMemberUpdatedEvent(GuildInnerMemberEvent):
member: InnerMember
__type__ = EventType.GUILD_MEMBER_UPDATED


Expand All @@ -172,7 +169,7 @@ class GuildRoleEvent(GuildEvent):

@override
def get_session_id(self) -> str:
return f"{self.guild.id}/{self.channel.id}/{self.role.id}"
return f"{self.guild.id}/{self.role.id}"


@register_event_class
Expand All @@ -192,15 +189,6 @@ class GuildRoleUpdatedEvent(GuildRoleEvent):

class LoginEvent(NoticeEvent):
login: Login
user: User

@override
def get_user_id(self) -> str:
return self.user.id

@override
def get_session_id(self) -> str:
return self.user.id


@register_event_class
Expand Down Expand Up @@ -303,12 +291,14 @@ def get_user_id(self) -> str:


class PublicMessageEvent(MessageEvent):
guild: Guild
member: InnerMember

@override
def get_session_id(self) -> str:
return f"{self.guild.id}/{self.channel.id}/{self.user.id}"
s = f"{self.channel.id}/{self.user.id}"
if self.guild:
s = f"{self.guild.id}/{s}"
return s

@override
def get_user_id(self) -> str:
Expand All @@ -330,7 +320,7 @@ def get_event_description(self) -> str:
return escape_tag(
f"Message {self.msg_id} from "
f"{self.user.name or ''}({self.channel.id})"
f"@[{self.channel.name or ''}:{self.guild.id}/{self.channel.id}]"
f"@[{self.channel.name or ''}:{self.channel.id}]"
f": {self.get_message()!r}"
)

Expand All @@ -350,7 +340,7 @@ def get_event_description(self) -> str:
return escape_tag(
f"Message {self.msg_id} from "
f"{self.user.name or ''}({self.channel.id})"
f"@[{self.channel.name or ''}:{self.guild.id}/{self.channel.id}] deleted"
f"@[{self.channel.name or ''}:{self.channel.id}] deleted"
)


Expand All @@ -370,7 +360,7 @@ def get_event_description(self) -> str:
return escape_tag(
f"Message {self.msg_id} from "
f"{self.user.name or ''}({self.channel.id})"
f"@[{self.channel.name or ''}:{self.guild.id}/{self.channel.id}] updated"
f"@[{self.channel.name or ''}:{self.channel.id}] updated"
f": {self.get_message()!r}"
)

Expand Down
15 changes: 11 additions & 4 deletions nonebot/adapters/satori/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ def __str__(self):
return self.__repr__()

def _prepare_body(self, body: dict):
self.code = body.get("code", None)
self.message = body.get("message", None)
self.data = body.get("data", None)
self.code = body.get("code")
self.message = body.get("message")
self.data = body.get("data")


class BadRequestException(ActionFailed):
Expand All @@ -58,6 +58,10 @@ class MethodNotAllowedException(ActionFailed):
pass


class ApiNotImplementedException(ActionFailed):
pass


class NetworkError(BaseNetworkError, SatoriAdapterException):
def __init__(self, msg: Optional[str] = None):
super().__init__()
Expand All @@ -72,4 +76,7 @@ def __str__(self):


class ApiNotAvailable(BaseApiNotAvailable, SatoriAdapterException):
pass
def __init__(self, msg: Optional[str] = None):
super().__init__()
self.msg: Optional[str] = msg
"""错误原因"""
11 changes: 10 additions & 1 deletion nonebot/adapters/satori/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from enum import IntEnum
from datetime import datetime
from typing import Any, Dict, List, Union, Literal, Optional
from typing import Any, Dict, List, Union, Generic, Literal, TypeVar, Optional

from pydantic.generics import GenericModel
from pydantic import Extra, Field, BaseModel, validator, root_validator

from .utils import Element, log, parse
Expand Down Expand Up @@ -220,3 +221,11 @@ class EventPayload(Payload):
Union[EventPayload, PingPayload, PongPayload, IdentifyPayload, ReadyPayload],
Payload,
]


T = TypeVar("T")


class PageResult(GenericModel, Generic[T], extra=Extra.allow):
data: List[T]
next: Optional[str] = None
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "nonebot-adapter-satori"
version = "0.2.0"
version = "0.3.0"
description = "Satori Protocol Adapter for Nonebot2"
authors = [
{name = "RF-Tar-Railt",email = "rf_tar_railt@qq.com"},
Expand Down

0 comments on commit 01f4d09

Please sign in to comment.