Skip to content

Commit

Permalink
Merge: pull request #1404 from interactions-py/unstable
Browse files Browse the repository at this point in the history
5.4.0
  • Loading branch information
LordOfPolls authored May 17, 2023
2 parents 3655951 + 79e9b97 commit 1e85c0f
Show file tree
Hide file tree
Showing 13 changed files with 182 additions and 34 deletions.
2 changes: 1 addition & 1 deletion docs/src/Guides/99 2.x Migration_NAFF.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ from interactions.ext import prefixed_commands
client = Client(..., intents=Intents.GUILD_MESSAGES | ...)
prefixed_commands.setup(client)
```
From here it's more or less the same as before. You can find a guide on how to use prefixed commands [here](/Guides/26 Prefixed Commands.md).
From here it's more or less the same as before. You can find a guide on how to use prefixed commands [here](/Guides/26 Prefixed Commands/).

## Hybrid Commands
For now, hybrid commands are not supported, but they will be in the future.
Expand Down
29 changes: 27 additions & 2 deletions interactions/api/http/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import aiohttp
import discord_typings
from aiohttp import BaseConnector, ClientSession, ClientWebSocketResponse, FormData
from aiohttp import BaseConnector, ClientSession, ClientWebSocketResponse, FormData, BasicAuth
from multidict import CIMultiDictProxy

import interactions.client.const as constants
Expand Down Expand Up @@ -214,7 +214,11 @@ class HTTPClient(
"""A http client for sending requests to the Discord API."""

def __init__(
self, connector: BaseConnector | None = None, logger: Logger = MISSING, show_ratelimit_tracebacks: bool = False
self,
connector: BaseConnector | None = None,
logger: Logger = MISSING,
show_ratelimit_tracebacks: bool = False,
proxy: tuple[str, BasicAuth] | None = None,
) -> None:
self.connector: BaseConnector | None = connector
self.__session: ClientSession | None = None
Expand All @@ -229,6 +233,8 @@ def __init__(
self.user_agent: str = (
f"DiscordBot ({__repo_url__} {__version__} Python/{__py_version__}) aiohttp/{aiohttp.__version__}"
)
self.proxy: tuple[str, BasicAuth] | None = proxy
self.__proxy_validated: bool = False

if logger is MISSING:
logger = constants.get_logger()
Expand Down Expand Up @@ -384,6 +390,10 @@ async def request( # noqa: C901
kwargs["json"] = processed_data # pyright: ignore
await self.global_lock.wait()

if self.proxy:
kwargs["proxy"] = self.proxy[0]
kwargs["proxy_auth"] = self.proxy[1]

async with self.__session.request(route.method, route.url, **kwargs) as response:
result = await response_decode(response)
self.ingest_ratelimit(route, response.headers, lock)
Expand Down Expand Up @@ -505,6 +515,19 @@ async def login(self, token: str) -> dict[str, Any]:
connector=self.connector or aiohttp.TCPConnector(limit=self.global_lock.max_requests),
json_serialize=FastJson.dumps,
)
if not self.__proxy_validated and self.proxy:
try:
self.logger.info(f"Validating Proxy @ {self.proxy[0]}")
async with self.__session.get(
"http://icanhazip.com/", proxy=self.proxy[0], proxy_auth=self.proxy[1]
) as response:
if response.status != 200:
raise RuntimeError("Proxy configuration is invalid")
self.logger.info(f"Proxy Connected @ {(await response.text()).strip()}")
self.__proxy_validated = True
except Exception as e:
raise RuntimeError("Proxy configuration is invalid") from e

self.token = token
try:
result = await self.request(Route("GET", "/users/@me"))
Expand Down Expand Up @@ -556,4 +579,6 @@ async def websocket_connect(self, url: str) -> ClientWebSocketResponse:
autoclose=False,
headers={"User-Agent": self.user_agent},
compress=0,
proxy=self.proxy[0] if self.proxy else None,
proxy_auth=self.proxy[1] if self.proxy else None,
)
4 changes: 2 additions & 2 deletions interactions/api/http/http_requests/guild.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ async def get_guild(self, guild_id: "Snowflake_Type", with_counts: bool = True)
a guild object
"""
params = {"guild_id": guild_id, "with_counts": int(with_counts)}
result = await self.request(Route("GET", "/guilds/{guild_id}"), params=params)
params = {"with_counts": int(with_counts)}
result = await self.request(Route("GET", "/guilds/{guild_id}", guild_id=guild_id), params=params)
return cast(discord_typings.GuildData, result)

async def get_guild_preview(self, guild_id: "Snowflake_Type") -> discord_typings.GuildPreviewData:
Expand Down
2 changes: 1 addition & 1 deletion interactions/api/voice/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def __init__(self, src: Union[str, Path]) -> None:

self.ffmpeg_before_args = ""
self.ffmpeg_args = ""
self.probe: bool = True
self.probe: bool = False

def __repr__(self) -> str:
return f"<{type(self).__name__}: {self.source}>"
Expand Down
86 changes: 74 additions & 12 deletions interactions/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
Tuple,
)

from aiohttp import BasicAuth

import interactions.api.events as events
import interactions.client.const as constants
from interactions.api.events import BaseEvent, RawGatewayEvent, processors
Expand Down Expand Up @@ -88,11 +90,17 @@
Intents,
InteractionType,
Status,
MessageFlags,
)
from interactions.models.discord.file import UPLOADABLE_TYPE
from interactions.models.discord.snowflake import Snowflake, to_snowflake_list
from interactions.models.internal.active_voice_state import ActiveVoiceState
from interactions.models.internal.application_commands import ContextMenu, ModalCommand, GlobalAutoComplete
from interactions.models.internal.application_commands import (
ContextMenu,
ModalCommand,
GlobalAutoComplete,
CallbackType,
)
from interactions.models.internal.auto_defer import AutoDefer
from interactions.models.internal.callback import CallbackObject
from interactions.models.internal.command import BaseCommand
Expand All @@ -111,10 +119,8 @@
if TYPE_CHECKING:
from interactions.models import Snowflake_Type, TYPE_ALL_CHANNEL


__all__ = ("Client",)


# see https://discord.com/developers/docs/topics/gateway#list-of-intents
_INTENT_EVENTS: dict[BaseEvent, list[Intents]] = {
# Intents.GUILDS
Expand Down Expand Up @@ -225,6 +231,7 @@ class Client(
enforce_interaction_perms: Enforce discord application command permissions, locally
fetch_members: Should the client fetch members from guilds upon startup (this will delay the client being ready)
send_command_tracebacks: Automatically send uncaught tracebacks if a command throws an exception
send_not_ready_messages: Send a message to the user if they try to use a command before the client is ready
auto_defer: AutoDefer: A system to automatically defer commands after a set duration
interaction_context: Type[InteractionContext]: InteractionContext: The object to instantiate for Interaction Context
Expand All @@ -241,6 +248,9 @@ class Client(
logging_level: The level of logging to use for basic_logging. Do not use in combination with `Client.logger`
logger: The logger interactions.py should use. Do not use in combination with `Client.basic_logging` and `Client.logging_level`. Note: Different loggers with multiple clients are not supported
proxy: A http/https proxy to use for all requests
proxy_auth: The auth to use for the proxy - must be either a tuple of (username, password) or aiohttp.BasicAuth
Optionally, you can configure the caches here, by specifying the name of the cache, followed by a dict-style object to use.
It is recommended to use `smart_cache.create_cache` to configure the cache here.
as an example, this is a recommended attribute `message_cache=create_cache(250, 50)`,
Expand Down Expand Up @@ -277,12 +287,15 @@ def __init__(
modal_context: Type[BaseContext] = ModalContext,
owner_ids: Iterable["Snowflake_Type"] = (),
send_command_tracebacks: bool = True,
send_not_ready_messages: bool = False,
shard_id: int = 0,
show_ratelimit_tracebacks: bool = False,
slash_context: Type[BaseContext] = SlashContext,
status: Status = Status.ONLINE,
sync_ext: bool = True,
sync_interactions: bool = True,
proxy_url: str | None = None,
proxy_auth: BasicAuth | tuple[str, str] | None = None,
token: str | None = None,
total_shards: int = 1,
**kwargs,
Expand Down Expand Up @@ -312,6 +325,8 @@ def __init__(
"""Sync global commands as guild for quicker command updates during debug"""
self.send_command_tracebacks: bool = send_command_tracebacks
"""Should the traceback of command errors be sent in reply to the command invocation"""
self.send_not_ready_messages: bool = send_not_ready_messages
"""Should the bot send a message when it is not ready yet in response to a command invocation"""
if auto_defer is True:
auto_defer = AutoDefer(enabled=True)
else:
Expand All @@ -321,8 +336,12 @@ def __init__(
self.intents = intents if isinstance(intents, Intents) else Intents(intents)

# resources
if isinstance(proxy_auth, tuple):
proxy_auth = BasicAuth(*proxy_auth)

self.http: HTTPClient = HTTPClient(logger=self.logger, show_ratelimit_tracebacks=show_ratelimit_tracebacks)
self.http: HTTPClient = HTTPClient(
logger=self.logger, show_ratelimit_tracebacks=show_ratelimit_tracebacks, proxy=(proxy_url, proxy_auth)
)
"""The HTTP client to use when interacting with discord endpoints"""

# context factories
Expand Down Expand Up @@ -386,6 +405,7 @@ def __init__(
self._component_callbacks: Dict[str, Callable[..., Coroutine]] = {}
self._regex_component_callbacks: Dict[re.Pattern, Callable[..., Coroutine]] = {}
self._modal_callbacks: Dict[str, Callable[..., Coroutine]] = {}
self._regex_modal_callbacks: Dict[re.Pattern, Callable[..., Coroutine]] = {}
self._global_autocompletes: Dict[str, GlobalAutoComplete] = {}
self.processors: Dict[str, Callable[..., Coroutine]] = {}
self.__modules = {}
Expand Down Expand Up @@ -684,7 +704,7 @@ async def on_command_error(self, event: events.CommandError) -> None:
embeds=Embed(
title=f"Error: {type(event.error).__name__}",
color=BrandColors.RED,
description=f"```\n{out[:EMBED_MAX_DESC_LENGTH-8]}```",
description=f"```\n{out[:EMBED_MAX_DESC_LENGTH - 8]}```",
)
)

Expand Down Expand Up @@ -1305,9 +1325,14 @@ def add_modal_callback(self, command: ModalCommand) -> None:
command: The command to add
"""
for listener in command.listeners:
if listener in self._modal_callbacks.keys():
raise ValueError(f"Duplicate Component! Multiple modal callbacks for `{listener}`")
self._modal_callbacks[listener] = command
if isinstance(listener, re.Pattern):
if listener in self._regex_component_callbacks.keys():
raise ValueError(f"Duplicate Component! Multiple modal callbacks for `{listener}`")
self._regex_modal_callbacks[listener] = command
else:
if listener in self._modal_callbacks.keys():
raise ValueError(f"Duplicate Component! Multiple modal callbacks for `{listener}`")
self._modal_callbacks[listener] = command
continue

def add_global_autocomplete(self, callback: GlobalAutoComplete) -> None:
Expand Down Expand Up @@ -1559,8 +1584,7 @@ def _build_sync_payload(

for local_cmd in self.interactions_by_scope.get(cmd_scope, {}).values():
remote_cmd_json = next(
(v for v in remote_commands if int(v["id"]) == local_cmd.cmd_id.get(cmd_scope)),
None,
(c for c in remote_commands if int(c["id"]) == int(local_cmd.cmd_id.get(cmd_scope, 0))), None
)
local_cmd_json = next((c for c in local_cmds_json[cmd_scope] if c["name"] == str(local_cmd.name)))

Expand Down Expand Up @@ -1696,6 +1720,32 @@ async def get_context(self, data: dict) -> InteractionContext:
self.logger.debug(f"Failed to fetch channel data for {data['channel_id']}")
return cls

async def handle_pre_ready_response(self, data: dict) -> None:
"""
Respond to an interaction that was received before the bot was ready.
Args:
data: The interaction data
"""
if data["type"] == InteractionType.AUTOCOMPLETE:
# we do not want to respond to autocompletes as discord will cache the response,
# so we just ignore them
return

with contextlib.suppress(HTTPException):
await self.http.post_initial_response(
{
"type": CallbackType.CHANNEL_MESSAGE_WITH_SOURCE,
"data": {
"content": f"{self.user.display_name} is starting up. Please try again in a few seconds",
"flags": MessageFlags.EPHEMERAL,
},
},
token=data["token"],
interaction_id=data["id"],
)

async def _run_slash_command(self, command: SlashCommand, ctx: "InteractionContext") -> Any:
"""Overrideable method that executes slash commands, can be used to wrap callback execution"""
return await command(ctx, **ctx.kwargs)
Expand All @@ -1713,6 +1763,8 @@ async def _dispatch_interaction(self, event: RawGatewayEvent) -> None: # noqa:

if not self._startup:
self.logger.warning("Received interaction before startup completed, ignoring")
if self.send_not_ready_messages:
await self.handle_pre_ready_response(interaction_data)
return

if interaction_data["type"] in (
Expand Down Expand Up @@ -1792,8 +1844,18 @@ async def _dispatch_interaction(self, event: RawGatewayEvent) -> None: # noqa:
ctx = await self.get_context(interaction_data)
self.dispatch(events.ModalCompletion(ctx=ctx))

if callback := self._modal_callbacks.get(ctx.custom_id):
await self.__dispatch_interaction(ctx=ctx, callback=callback(ctx), error_callback=events.ModalError)
modal_callback = self._modal_callbacks.get(ctx.custom_id)
if not modal_callback:
# evaluate regex component callbacks
for regex, callback in self._regex_modal_callbacks.items():
if regex.match(ctx.custom_id):
modal_callback = callback
break

if modal_callback:
await self.__dispatch_interaction(
ctx=ctx, callback=modal_callback(ctx), error_callback=events.ModalError
)

else:
raise NotImplementedError(f"Unknown Interaction Received: {interaction_data['type']}")
Expand Down
48 changes: 46 additions & 2 deletions interactions/client/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,5 +220,49 @@ class MentionPrefix(Sentinel):
LIB_PATH = os.sep.join(__file__.split(os.sep)[:-2])
"""The path to the library folder."""

RECOVERABLE_WEBSOCKET_CLOSE_CODES = (4000, 4001, 4002, 4003, 4005, 4007, 4008, 4009)
NON_RESUMABLE_WEBSOCKET_CLOSE_CODES = (1000, 4007)
# fmt: off
RECOVERABLE_WEBSOCKET_CLOSE_CODES = ( # Codes that are recoverable, and the bot will reconnect
1000, # Normal closure
1001, # Server going away
1003, # Unsupported Data
1005, # No status code
1006, # Abnormal closure
1008, # Policy Violation
1009, # Message too big
1011, # Server error
1012, # Server is restarting
1014, # Handshake failed
1015, # TLS error
4000, # Unknown error
4001, # Unknown opcode
4002, # Decode error
4003, # Not authenticated
4005, # Already authenticated
4007, # Invalid seq
4008, # Rate limited
4009, # Session timed out
)
NON_RESUMABLE_WEBSOCKET_CLOSE_CODES = ( # Codes that are recoverable, but the bot won't be able to resume the session
1000, # Normal closure
1003, # Unsupported Data
1008, # Policy Violation
1009, # Message too big
1011, # Server error
1012, # Server is restarting
1014, # Handshake failed
1015, # TLS error
4007, # Invalid seq
)
# Any close code not in the above two tuples is a non-recoverable close code, and will result in the bot shutting down
# fmt: on


# Sanity check the above constants - only useful during development, but doesn't hurt to leave in
try:
assert set(NON_RESUMABLE_WEBSOCKET_CLOSE_CODES).issubset(set(RECOVERABLE_WEBSOCKET_CLOSE_CODES))
except AssertionError as e:
# find the difference between the two sets
diff = set(NON_RESUMABLE_WEBSOCKET_CLOSE_CODES) - set(RECOVERABLE_WEBSOCKET_CLOSE_CODES)
raise RuntimeError(
f"NON_RESUMABLE_WEBSOCKET_CLOSE_CODES contains codes that are not in RECOVERABLE_WEBSOCKET_CLOSE_CODES: {diff}"
) from e
12 changes: 7 additions & 5 deletions interactions/models/discord/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,11 +404,12 @@ async def purge(
search_limit: int = 100,
predicate: Callable[["models.Message"], bool] = MISSING,
avoid_loading_msg: bool = True,
return_messages: bool = False,
before: Optional[Snowflake_Type] = MISSING,
after: Optional[Snowflake_Type] = MISSING,
around: Optional[Snowflake_Type] = MISSING,
reason: Absent[Optional[str]] = MISSING,
) -> int:
) -> int | List["models.Message"]:
"""
Bulk delete messages within a channel. If a `predicate` is provided, it will be used to determine which messages to delete, otherwise all messages will be deleted within the `deletion_limit`.
Expand All @@ -424,6 +425,7 @@ async def purge(
search_limit: How many messages to search through
predicate: A function that returns True or False, and takes a message as an argument
avoid_loading_msg: Should the bot attempt to avoid deleting its own loading messages (recommended enabled)
return_messages: Should the bot return the messages that were deleted
before: Search messages before this ID
after: Search messages after this ID
around: Search messages around this ID
Expand Down Expand Up @@ -461,13 +463,13 @@ def predicate(m) -> bool:
# message is too old to be purged
continue

to_delete.append(message.id)
to_delete.append(message)

count = len(to_delete)
out = to_delete.copy()
while len(to_delete):
iteration = [to_delete.pop() for i in range(min(100, len(to_delete)))]
iteration = [to_delete.pop().id for i in range(min(100, len(to_delete)))]
await self.delete_messages(iteration, reason=reason)
return count
return out if return_messages else len(out)

async def trigger_typing(self) -> None:
"""Trigger a typing animation in this channel."""
Expand Down
Loading

0 comments on commit 1e85c0f

Please sign in to comment.