Skip to content

Store app command ids on CommandTree #9924

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions discord/app_commands/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,10 @@ async def delete(self) -> None:
self.id,
)

tree = self._state._command_tree
if tree:
tree._command_ids.get(self.guild_id, {}).pop(self.name, None)

async def edit(
self,
*,
Expand Down Expand Up @@ -392,6 +396,11 @@ async def edit(
self.id,
payload,
)

tree = self._state._command_tree
if tree:
tree._update_command_ids(data)

return AppCommand(data=data, state=state)

async def fetch_permissions(self, guild: Snowflake) -> GuildAppCommandPermissions:
Expand Down
114 changes: 113 additions & 1 deletion discord/app_commands/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@

if TYPE_CHECKING:
from ..types.interactions import ApplicationCommandInteractionData, ApplicationCommandInteractionDataOption
from ..types.command import ApplicationCommand
from ..interactions import Interaction
from ..abc import Snowflake
from .commands import ContextMenuCallback, CommandCallback, P, T
Expand Down Expand Up @@ -132,6 +133,13 @@ class CommandTree(Generic[ClientT]):
Note that you can override this on a per command basis.

.. versionadded:: 2.4
store_app_command_ids: :class:`bool`
Whether to store the application command IDs on the tree. These can be used to mention a command.
Defaults to ``False``.

This must be enabled if you want to use :meth:`get_command_mention` or :meth:`get_command_id`.

.. versionadded:: 2.5
"""

def __init__(
Expand All @@ -141,6 +149,7 @@ def __init__(
fallback_to_global: bool = True,
allowed_contexts: AppCommandContext = MISSING,
allowed_installs: AppInstallationType = MISSING,
store_app_command_ids: bool = False,
):
self.client: ClientT = client
self._http = client.http
Expand All @@ -161,6 +170,9 @@ def __init__(
# it's uncommon and N=5 anyway.
self._context_menus: Dict[Tuple[str, Optional[int], int], ContextMenu] = {}

self.store_app_command_ids: bool = store_app_command_ids
self._command_ids: Dict[Optional[int], Dict[str, int]] = {}

async def fetch_command(self, command_id: int, /, *, guild: Optional[Snowflake] = None) -> AppCommand:
"""|coro|

Expand Down Expand Up @@ -198,6 +210,7 @@ async def fetch_command(self, command_id: int, /, *, guild: Optional[Snowflake]
else:
command = await self._http.get_guild_command(self.client.application_id, guild.id, command_id)

self._update_command_ids(command)
return AppCommand(data=command, state=self._state)

async def fetch_commands(self, *, guild: Optional[Snowflake] = None) -> List[AppCommand]:
Expand Down Expand Up @@ -238,7 +251,91 @@ async def fetch_commands(self, *, guild: Optional[Snowflake] = None) -> List[App
else:
commands = await self._http.get_guild_commands(self.client.application_id, guild.id)

return [AppCommand(data=data, state=self._state) for data in commands]
self._update_command_ids(*commands)
return [AppCommand(data=command, state=self._state) for command in commands]

def get_command_id(
self, command: Union[AppCommand, Command, ContextMenu, Group, str], /, *, guild: Optional[Snowflake] = None
) -> Optional[int]:
"""Gets the command ID for a command.

Parameters
-----------
name: Union[:class:`~discord.app_commands.Command`, :class:`~discord.app_commands.ContextMenu`, :class:`~discord.app_commands.Group`, :class:`str`]
The name of the command to get the ID for.
guild: Optional[:class:`~discord.abc.Snowflake`]
The guild to get the command ID for. If not passed then the global command
ID is fetched instead.

Returns
--------
Optional[:class:`int`]
The command ID if found, otherwise ``None``.

.. note::

Group commands will return the ID of the root command. Subcommands do not have their own IDs.
"""
name: Optional[str] = None

if isinstance(command, AppCommand):
return command.id

if isinstance(command, (Command, Group, ContextMenu)):
name = (command.root_parent or command).name if not isinstance(command, ContextMenu) else command.name
elif isinstance(command, str):
name = command.split()[0]

return self._command_ids.get(guild.id if guild else None, {}).get(name)

def get_command_mention(
self, command: Union[AppCommand, Command, ContextMenu, Group, str], /, *, guild: Optional[Snowflake] = None
) -> Optional[str]:
"""Gets the mention string for a command.

Parameters
-----------
command: Union[:class:`~discord.app_commands.Command`, :class:`~discord.app_commands.ContextMenu`, :class:`~discord.app_commands.Group`, :class:`str`]
The command to get the mention string for.

Returns
--------
Optional[:class:`str`]
The mention string for the command if found, otherwise ``None``.

.. note::

Remember that groups cannot be mentioned, only with a subcommand.
"""
if isinstance(command, AppCommand):
return command.mention

command_id = self.get_command_id(command, guild=guild)
if command_id is None:
return None

if isinstance(command, (Command, Group, ContextMenu)):
full_name = command.qualified_name
elif isinstance(command, str):
full_name = command

return f'</{full_name}:{command_id}>'

def get_command_ids(self, guild: Optional[Snowflake] = None) -> Dict[str, int]:
"""Gets all command IDs for the given guild.

Parameters
-----------
guild: Optional[:class:`~discord.abc.Snowflake`]
The guild to get the command IDs for. If not passed then the global command
IDs are returned instead.

Returns
--------
Dict[:class:`str`, :class:`int`]
A dictionary of command names and their IDs.
"""
return self._command_ids.get(guild.id if guild else None, {})

def copy_global_to(self, *, guild: Snowflake) -> None:
"""Copies all global commands to the specified guild.
Expand Down Expand Up @@ -1134,8 +1231,22 @@ async def sync(self, *, guild: Optional[Snowflake] = None) -> List[AppCommand]:
raise CommandSyncFailure(e, commands) from None
raise

self._update_command_ids(*data)
return [AppCommand(data=d, state=self._state) for d in data]

def _update_command_ids(self, *data: Union[ApplicationCommandInteractionData, ApplicationCommand]) -> None:
if not self.store_app_command_ids:
return

for d in data:
command_id: int = int(d['id'])
name: str = d['name']
guild_id: Optional[int] = _get_as_snowflake(d, 'guild_id')
try:
self._command_ids[guild_id][name] = command_id
except KeyError:
self._command_ids[guild_id] = {name: command_id}

async def _dispatch_error(self, interaction: Interaction[ClientT], error: AppCommandError, /) -> None:
command = interaction.command
interaction.command_failed = True
Expand Down Expand Up @@ -1274,6 +1385,7 @@ async def _call(self, interaction: Interaction[ClientT]) -> None:
return

data: ApplicationCommandInteractionData = interaction.data # type: ignore
self._update_command_ids(data)
type = data.get('type', 1)
if type != 1:
# Context menu command...
Expand Down
Loading