aboutsummaryrefslogtreecommitdiff
path: root/drutils
diff options
context:
space:
mode:
Diffstat (limited to 'drutils')
-rw-r--r--drutils/__init__.py7
-rw-r--r--drutils/awaiter.py217
-rw-r--r--drutils/version.py20
3 files changed, 244 insertions, 0 deletions
diff --git a/drutils/__init__.py b/drutils/__init__.py
new file mode 100644
index 0000000..7f362ee
--- /dev/null
+++ b/drutils/__init__.py
@@ -0,0 +1,7 @@
+from .awaiter import AdvancedAwaiter, AwaitException, AwaitCanceled, AwaitTimedOut
+from .version import VERSION, VersionInfo
+
+__all__ = (
+ 'AdvancedAwaiter', 'AwaitException', 'AwaitCanceled', 'AwaitTimedOut',
+ 'VERSION', 'VersionInfo'
+)
diff --git a/drutils/awaiter.py b/drutils/awaiter.py
new file mode 100644
index 0000000..bad7c79
--- /dev/null
+++ b/drutils/awaiter.py
@@ -0,0 +1,217 @@
+import asyncio
+import inspect
+import string
+from collections import defaultdict
+from typing import Any, Optional, List
+
+import discord
+from discord import Message, Guild, PartialEmoji, Embed, Color, Reaction, User, Role, TextChannel
+from discord.abc import Messageable
+from discord.ext.commands import Bot, Paginator, Context
+
+_NoneType = type(None)
+
+
+async def await_if(func, *args, **kwargs):
+ if inspect.iscoroutinefunction(func):
+ return await func(*args, **kwargs)
+ else:
+ return func(*args, **kwargs)
+
+
+def keeper(keep):
+ table = defaultdict(_NoneType)
+ table.update({ord(c): c for c in keep})
+ return table
+
+
+NUMBER_EMOJIS = [str(num) + '\U000020e3' for num in range(1, 10)] + ['\U0001f51f']
+digit_keeper = keeper(string.digits)
+YES_ANSWERS = ['yes', 'y']
+NO_ANSWERS = ['n', 'no']
+YES_REACTION = '\N{WHITE HEAVY CHECK MARK}'
+NO_REACTION = '\N{CROSS MARK}'
+
+
+class AwaitException(BaseException):
+ pass
+
+
+class AwaitCanceled(AwaitException):
+ pass
+
+
+class AwaitTimedOut(AwaitException):
+ pass
+
+
+class AdvancedAwaiter:
+ def __init__(self, bot: Bot, channel: Messageable, guild: Optional[Guild], user: User, timeout: Optional[int]):
+ self.bot = bot
+ self.channel = channel
+ self.guild = guild
+ self.author = user
+ self.timeout = timeout
+
+ @classmethod
+ def from_context(cls, ctx: Context, timeout: Optional[int] = None):
+ return cls(ctx.bot, ctx.channel, ctx.guild, ctx.author, timeout)
+
+ @classmethod
+ def from_message(cls, bot: Bot, mes: Message, timeout: Optional[int] = None):
+ return cls(bot, mes.channel, mes.guild, mes.author, timeout)
+
+ @classmethod
+ def in_direct_message(cls, bot: Bot, user: User, timeout: Optional[int] = None):
+ return cls(bot, user, None, user, timeout)
+
+ def check_author(self, mes: Message):
+ if mes.author.id != self.author.id:
+ return False
+ return mes.channel == self.channel or (self.guild is None and mes.guild is None)
+
+ async def by_converter(self, text, check, converter) -> Any:
+ obj = None
+ while obj is None or not await await_if(check, obj):
+ try:
+ res = await self(text)
+ obj = await await_if(converter, res)
+ except AwaitException:
+ raise
+ except BaseException as e:
+ print(e)
+ return obj
+
+ async def __call__(self, text, check=lambda mes: True) -> Message:
+ await self.channel.send(
+ embed=Embed(
+ color=Color.blurple(),
+ description=text))
+
+ try:
+ mes = await self.bot.wait_for('message', check=lambda mes: self.check_author(mes) and check(mes),
+ timeout=self.timeout)
+ if mes.content.lower() == "@cancel@":
+ raise AwaitCanceled
+ return mes
+ except asyncio.TimeoutError:
+ raise AwaitTimedOut
+
+ async def converted_emoji(self, text: str, converter=lambda x: x, check=lambda x: True):
+ thing = None
+ while thing is None or not check(thing):
+ try:
+ ret = await self.emoji_reaction(text)
+ thing = await await_if(converter, ret)
+ except AwaitException:
+ raise
+ except BaseException as e:
+ print(e)
+ return thing
+
+ async def emoji_reaction(self, text: str) -> PartialEmoji:
+ mes = await self.channel.send(
+ embed=Embed(
+ color=Color.blurple(),
+ description=text))
+
+ def check(reaction: Reaction, user: User):
+ message = reaction.message
+ if self.author.id != user.id:
+ return False
+ return mes.id == message.id
+
+ try:
+ reaction, user = await self.bot.wait_for('reaction_add', check=check, timeout=self.timeout)
+ return reaction
+ except asyncio.TimeoutError:
+ raise AwaitTimedOut
+
+ async def guild_role(self, text: str, check=lambda role: True, list_ids=False) -> Role:
+ async def converter(mes: Message):
+ return discord.utils.get(self.guild.roles,
+ id=int(mes.content.translate(digit_keeper)))
+
+ if list_ids:
+ guild = self.guild
+ paginator = Paginator()
+ for role in guild.roles:
+ paginator.add_line(role.name + ' ' + str(role.id))
+
+ for page in paginator.pages:
+ await self.channel.send(
+ embed=Embed(
+ color=Color.blurple(),
+ description=page))
+ return await self.by_converter(
+ text,
+ check=check,
+ converter=converter)
+
+ async def emoji_choice(self, text: str, choices: List[str]):
+ emoji = ''
+ while emoji not in choices:
+ mes = await self.channel.send(
+ embed=Embed(
+ color=Color.blurple(),
+ description=text))
+ for choice in choices:
+ await mes.add_reaction(choice)
+
+ def check(reaction: Reaction, user: User):
+ message = reaction.message
+ if user.id != self.author.id:
+ return False
+ return message.id == mes.id
+
+ try:
+ reaction, user = await self.bot.wait_for('reaction_add', check=check, timeout=self.timeout)
+ emoji = str(reaction.emoji)
+ except asyncio.TimeoutError:
+ raise AwaitTimedOut
+ return emoji
+
+ async def choice(self, text: str, choices: List[Any]):
+ emojis = NUMBER_EMOJIS[:len(choices)]
+ emoji = await self.emoji_choice(text + '\n' + '\n'.join(
+ map(' '.join, zip(choices, emojis))
+ ), emojis)
+ return choices[emojis.index(emoji)]
+
+ async def emoji_yes_no(self, text: str) -> bool:
+ emoji = await self.emoji_choice(text, [YES_REACTION, NO_REACTION])
+ return emoji == YES_REACTION
+
+ async def text(self, text: str):
+ return (await self(text)).content
+
+ async def guild_channel(self, text: str, check=lambda channel: True, writable=False) -> object:
+ async def converter(mes: Message):
+ return discord.utils.get(self.guild.channels,
+ id=int(mes.content.translate(digit_keeper)))
+
+ async def all_checks(channel: TextChannel):
+ if writable and not channel.permissions_for(self.bot.user).send_messages:
+ return False
+ return await await_if(check, channel)
+
+ return await self.by_converter(
+ text,
+ check=all_checks,
+ converter=converter)
+
+ async def as_message(self, text: str, check=lambda mes: True, in_channel: TextChannel = None) -> Message:
+ if in_channel is None:
+ in_channel = self.channel
+
+ async def converter(mes: Message):
+ return await in_channel.get_message(mes.content)
+
+ return await self.by_converter(text, check=check, converter=converter)
+
+ async def yes_no_question(self, text: str) -> bool:
+ response = ''
+ while response not in (YES_ANSWERS + NO_ANSWERS):
+ response = (await self.text(text)).lower()
+ pass
+ return response in YES_ANSWERS
diff --git a/drutils/version.py b/drutils/version.py
new file mode 100644
index 0000000..5e43541
--- /dev/null
+++ b/drutils/version.py
@@ -0,0 +1,20 @@
+"""versioninfo"""
+
+
+# pylint: disable=too-few-public-methods
+class VersionInfo:
+ """Version info dataclass"""
+
+ # pylint: disable=too-many-arguments
+ def __init__(self, major: int, minor: int, build: int, level: str, serial: int):
+ self.major = major
+ self.minor = minor
+ self.build = build
+ self.level = level
+ self.serial = serial
+
+ def __str__(self):
+ return '{major}.{minor}.{build}{level}{serial}'.format(**self.__dict__)
+
+
+VERSION = VersionInfo(1, 0, 0, 'a', 0)