Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for POLL_RESULT messages #9905

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions discord/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ class MessageType(Enum):
guild_incident_report_raid = 38
guild_incident_report_false_alarm = 39
purchase_notification = 44
poll_result = 46


class SpeakingState(Enum):
Expand Down
16 changes: 16 additions & 0 deletions discord/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,7 @@

self.original_response_message_id: Optional[int] = None
try:
self.original_response_message_id = int(data['original_response_message_id']) # type: ignore # EAFP

Check warning on line 868 in discord/message.py

View workflow job for this annotation

GitHub Actions / check 3.x

Unnecessary "# type: ignore" comment
except KeyError:
pass

Expand Down Expand Up @@ -2268,6 +2268,13 @@
# the channel will be the correct type here
ref.resolved = self.__class__(channel=chan, data=resolved, state=state) # type: ignore

if self.type is MessageType.poll_result:
if isinstance(self.reference.resolved, self.__class__):
self._state._update_poll_results(self, self.reference.resolved)
else:
if self.reference.message_id:
self._state._update_poll_results(self, self.reference.message_id)

self.application: Optional[MessageApplication] = None
try:
application = data['application']
Expand Down Expand Up @@ -2634,6 +2641,7 @@
MessageType.chat_input_command,
MessageType.context_menu_command,
MessageType.thread_starter_message,
MessageType.poll_result,
)

@utils.cached_slot_property('_cs_system_content')
Expand Down Expand Up @@ -2810,6 +2818,14 @@
if guild_product_purchase is not None:
return f'{self.author.name} has purchased {guild_product_purchase.product_name}!'

if self.type is MessageType.poll_result:
embed = self.embeds[0] # Will always have 1 embed
poll_title = utils.get(
embed.fields,
name='poll_question_text',
)
return f'{self.author.display_name}\'s poll {poll_title.value} has closed.' # type: ignore

# Fallback for unknown message types
return ''

Expand Down
96 changes: 92 additions & 4 deletions discord/poll.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

import datetime

from .enums import PollLayoutType, try_enum
from .enums import PollLayoutType, try_enum, MessageType
from . import utils
from .emoji import PartialEmoji, Emoji
from .user import User
Expand Down Expand Up @@ -125,7 +125,16 @@ class PollAnswer:
Whether the current user has voted to this answer or not.
"""

__slots__ = ('media', 'id', '_state', '_message', '_vote_count', 'self_voted', '_poll')
__slots__ = (
'media',
'id',
'_state',
'_message',
'_vote_count',
'self_voted',
'_poll',
'_victor',
)

def __init__(
self,
Expand All @@ -141,6 +150,7 @@ def __init__(
self._vote_count: int = 0
self.self_voted: bool = False
self._poll: Poll = poll
self._victor: bool = False

def _handle_vote_event(self, added: bool, self_voted: bool) -> None:
if added:
Expand Down Expand Up @@ -210,6 +220,19 @@ def _to_dict(self) -> PollAnswerPayload:
'poll_media': self.media.to_dict(),
}

@property
def victor(self) -> bool:
""":class:`bool`: Whether the answer is the one that had the most
votes when the poll ended.

.. versionadded:: 2.5

.. note::

If the poll has not ended, this will always return ``False``.
"""
return self._victor

async def voters(
self, *, limit: Optional[int] = None, after: Optional[Snowflake] = None
) -> AsyncIterator[Union[User, Member]]:
Expand Down Expand Up @@ -325,6 +348,8 @@ class Poll:
'_expiry',
'_finalized',
'_state',
'_total_votes',
'_victor_answer_id',
)

def __init__(
Expand All @@ -348,6 +373,8 @@ def __init__(
self._state: Optional[ConnectionState] = None
self._finalized: bool = False
self._expiry: Optional[datetime.datetime] = None
self._total_votes: Optional[int] = None
self._victor_answer_id: Optional[int] = None

def _update(self, message: Message) -> None:
self._state = message._state
Expand All @@ -359,6 +386,33 @@ def _update(self, message: Message) -> None:
# The message's poll contains the more up to date data.
self._expiry = message.poll.expires_at
self._finalized = message.poll._finalized
self._update_results_from_message(message)

def _update_results_from_message(self, message: Message) -> None:
if message.type != MessageType.poll_result or not message.embeds:
return

result_embed = message.embeds[0] # Will always have 1 embed
fields: Dict[str, str] = {field.name: field.value for field in result_embed.fields} # type: ignore

total_votes = fields.get('total_votes')

if total_votes is not None:
self._total_votes = int(total_votes)

victor_answer = fields.get('victor_answer_id')

if victor_answer is None:
return # Can't do anything else without the victor answer

self._victor_answer_id = int(victor_answer)

victor_answer_votes = fields['victor_answer_votes']

answer = self._answers[self._victor_answer_id]
answer._victor = True
answer._vote_count = int(victor_answer_votes)
self._answers[answer.id] = answer # Ensure update

def _update_results(self, data: PollResultPayload) -> None:
self._finalized = data['is_finalized']
Expand Down Expand Up @@ -431,6 +485,32 @@ def answers(self) -> List[PollAnswer]:
"""List[:class:`PollAnswer`]: Returns a read-only copy of the answers."""
return list(self._answers.values())

@property
def victor_answer_id(self) -> Optional[int]:
"""Optional[:class:`int`]: The victor answer ID.

.. versionadded:: 2.5

.. note::

This will **always** be ``None`` for polls that have not yet finished.
"""
return self._victor_answer_id

@property
def victor_answer(self) -> Optional[PollAnswer]:
"""Optional[:class:`PollAnswer`]: The victor answer.

.. versionadded:: 2.5

.. note::

This will **always** be ``None`` for polls that have not yet finished.
"""
if self.victor_answer_id is None:
return None
return self.get_answer(self.victor_answer_id)

@property
def expires_at(self) -> Optional[datetime.datetime]:
"""Optional[:class:`datetime.datetime`]: A datetime object representing the poll expiry.
Expand All @@ -456,12 +536,20 @@ def created_at(self) -> Optional[datetime.datetime]:

@property
def message(self) -> Optional[Message]:
""":class:`Message`: The message this poll is from."""
"""Optional[:class:`Message`]: The message this poll is from."""
return self._message

@property
def total_votes(self) -> int:
""":class:`int`: Returns the sum of all the answer votes."""
""":class:`int`: Returns the sum of all the answer votes.

If the poll has not yet finished, this is an approximate vote count.

.. versionchanged:: 2.5
This now returns an exact vote count when updated from its poll results message.
"""
if self._total_votes is not None:
return self._total_votes
return sum([answer.vote_count for answer in self.answers])

def is_finalised(self) -> bool:
Expand Down
21 changes: 21 additions & 0 deletions discord/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,27 @@ def _update_poll_counts(self, message: Message, answer_id: int, added: bool, sel
poll._handle_vote(answer_id, added, self_voted)
return poll

def _update_poll_results(self, from_: Message, to: Union[Message, int]) -> None:
if isinstance(to, Message):
cached = self._get_message(to.id)
elif isinstance(to, int):
cached = self._get_message(to)

if cached is None:
return

to = cached
else:
return

if to.poll is None:
return

to.poll._update_results_from_message(from_)

if cached is not None and cached.poll:
cached.poll._update_results_from_message(from_)

async def chunker(
self, guild_id: int, query: str = '', limit: int = 0, presences: bool = False, *, nonce: Optional[str] = None
) -> None:
Expand Down
2 changes: 1 addition & 1 deletion discord/types/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class EmbedAuthor(TypedDict, total=False):
proxy_icon_url: str


EmbedType = Literal['rich', 'image', 'video', 'gifv', 'article', 'link']
EmbedType = Literal['rich', 'image', 'video', 'gifv', 'article', 'link', 'poll_result']


class Embed(TypedDict, total=False):
Expand Down
1 change: 1 addition & 0 deletions discord/types/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ class CallMessage(TypedDict):
38,
39,
44,
46,
]


Expand Down
Loading