From c8808ec947590b0ab2cce69484e71aa90fecabf0 Mon Sep 17 00:00:00 2001 From: romangraef Date: Mon, 26 Nov 2018 21:05:44 +0100 Subject: Initial commit --- .gitignore | 136 +++++++++++++++++++++++++++++++ README.md | 3 + drutils/__init__.py | 7 ++ drutils/awaiter.py | 217 ++++++++++++++++++++++++++++++++++++++++++++++++++ drutils/version.py | 20 +++++ pylintrc.cfg | 3 + requirements.txt | 1 + setup.cfg | 2 + setup.py | 34 ++++++++ tests/__init__.py | 0 tests/test_awaiter.py | 40 ++++++++++ tests/utils.py | 8 ++ tox.ini | 16 ++++ 13 files changed, 487 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 drutils/__init__.py create mode 100644 drutils/awaiter.py create mode 100644 drutils/version.py create mode 100644 pylintrc.cfg create mode 100644 requirements.txt create mode 100644 setup.cfg create mode 100644 setup.py create mode 100644 tests/__init__.py create mode 100644 tests/test_awaiter.py create mode 100644 tests/utils.py create mode 100644 tox.ini diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3291c57 --- /dev/null +++ b/.gitignore @@ -0,0 +1,136 @@ + +# Created by https://www.gitignore.io/api/python +# Edit at https://www.gitignore.io/?templates=python + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +### Python Patch ### +.venv/ + +### Python.VirtualEnv Stack ### +# Virtualenv +# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ +[Bb]in +[Ii]nclude +[Ll]ib +[Ll]ib64 +[Ll]ocal +[Ss]cripts +pyvenv.cfg +pip-selfcheck.json + +# End of https://www.gitignore.io/api/python diff --git a/README.md b/README.md new file mode 100644 index 0000000..8f0b56d --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +## Romans discord.py utilities + +Yes thats it. diff --git a/drutils/__init__.py b/drutils/__init__.py new file mode 100644 index 0000000..7f362ee --- /dev/null +++ b/drutils/__init__.py @@ -0,0 +1,7 @@ +from .awaiter import AdvancedAwaiter, AwaitException, AwaitCanceled, AwaitTimedOut +from .version import VERSION, VersionInfo + +__all__ = ( + 'AdvancedAwaiter', 'AwaitException', 'AwaitCanceled', 'AwaitTimedOut', + 'VERSION', 'VersionInfo' +) diff --git a/drutils/awaiter.py b/drutils/awaiter.py new file mode 100644 index 0000000..bad7c79 --- /dev/null +++ b/drutils/awaiter.py @@ -0,0 +1,217 @@ +import asyncio +import inspect +import string +from collections import defaultdict +from typing import Any, Optional, List + +import discord +from discord import Message, Guild, PartialEmoji, Embed, Color, Reaction, User, Role, TextChannel +from discord.abc import Messageable +from discord.ext.commands import Bot, Paginator, Context + +_NoneType = type(None) + + +async def await_if(func, *args, **kwargs): + if inspect.iscoroutinefunction(func): + return await func(*args, **kwargs) + else: + return func(*args, **kwargs) + + +def keeper(keep): + table = defaultdict(_NoneType) + table.update({ord(c): c for c in keep}) + return table + + +NUMBER_EMOJIS = [str(num) + '\U000020e3' for num in range(1, 10)] + ['\U0001f51f'] +digit_keeper = keeper(string.digits) +YES_ANSWERS = ['yes', 'y'] +NO_ANSWERS = ['n', 'no'] +YES_REACTION = '\N{WHITE HEAVY CHECK MARK}' +NO_REACTION = '\N{CROSS MARK}' + + +class AwaitException(BaseException): + pass + + +class AwaitCanceled(AwaitException): + pass + + +class AwaitTimedOut(AwaitException): + pass + + +class AdvancedAwaiter: + def __init__(self, bot: Bot, channel: Messageable, guild: Optional[Guild], user: User, timeout: Optional[int]): + self.bot = bot + self.channel = channel + self.guild = guild + self.author = user + self.timeout = timeout + + @classmethod + def from_context(cls, ctx: Context, timeout: Optional[int] = None): + return cls(ctx.bot, ctx.channel, ctx.guild, ctx.author, timeout) + + @classmethod + def from_message(cls, bot: Bot, mes: Message, timeout: Optional[int] = None): + return cls(bot, mes.channel, mes.guild, mes.author, timeout) + + @classmethod + def in_direct_message(cls, bot: Bot, user: User, timeout: Optional[int] = None): + return cls(bot, user, None, user, timeout) + + def check_author(self, mes: Message): + if mes.author.id != self.author.id: + return False + return mes.channel == self.channel or (self.guild is None and mes.guild is None) + + async def by_converter(self, text, check, converter) -> Any: + obj = None + while obj is None or not await await_if(check, obj): + try: + res = await self(text) + obj = await await_if(converter, res) + except AwaitException: + raise + except BaseException as e: + print(e) + return obj + + async def __call__(self, text, check=lambda mes: True) -> Message: + await self.channel.send( + embed=Embed( + color=Color.blurple(), + description=text)) + + try: + mes = await self.bot.wait_for('message', check=lambda mes: self.check_author(mes) and check(mes), + timeout=self.timeout) + if mes.content.lower() == "@cancel@": + raise AwaitCanceled + return mes + except asyncio.TimeoutError: + raise AwaitTimedOut + + async def converted_emoji(self, text: str, converter=lambda x: x, check=lambda x: True): + thing = None + while thing is None or not check(thing): + try: + ret = await self.emoji_reaction(text) + thing = await await_if(converter, ret) + except AwaitException: + raise + except BaseException as e: + print(e) + return thing + + async def emoji_reaction(self, text: str) -> PartialEmoji: + mes = await self.channel.send( + embed=Embed( + color=Color.blurple(), + description=text)) + + def check(reaction: Reaction, user: User): + message = reaction.message + if self.author.id != user.id: + return False + return mes.id == message.id + + try: + reaction, user = await self.bot.wait_for('reaction_add', check=check, timeout=self.timeout) + return reaction + except asyncio.TimeoutError: + raise AwaitTimedOut + + async def guild_role(self, text: str, check=lambda role: True, list_ids=False) -> Role: + async def converter(mes: Message): + return discord.utils.get(self.guild.roles, + id=int(mes.content.translate(digit_keeper))) + + if list_ids: + guild = self.guild + paginator = Paginator() + for role in guild.roles: + paginator.add_line(role.name + ' ' + str(role.id)) + + for page in paginator.pages: + await self.channel.send( + embed=Embed( + color=Color.blurple(), + description=page)) + return await self.by_converter( + text, + check=check, + converter=converter) + + async def emoji_choice(self, text: str, choices: List[str]): + emoji = '' + while emoji not in choices: + mes = await self.channel.send( + embed=Embed( + color=Color.blurple(), + description=text)) + for choice in choices: + await mes.add_reaction(choice) + + def check(reaction: Reaction, user: User): + message = reaction.message + if user.id != self.author.id: + return False + return message.id == mes.id + + try: + reaction, user = await self.bot.wait_for('reaction_add', check=check, timeout=self.timeout) + emoji = str(reaction.emoji) + except asyncio.TimeoutError: + raise AwaitTimedOut + return emoji + + async def choice(self, text: str, choices: List[Any]): + emojis = NUMBER_EMOJIS[:len(choices)] + emoji = await self.emoji_choice(text + '\n' + '\n'.join( + map(' '.join, zip(choices, emojis)) + ), emojis) + return choices[emojis.index(emoji)] + + async def emoji_yes_no(self, text: str) -> bool: + emoji = await self.emoji_choice(text, [YES_REACTION, NO_REACTION]) + return emoji == YES_REACTION + + async def text(self, text: str): + return (await self(text)).content + + async def guild_channel(self, text: str, check=lambda channel: True, writable=False) -> object: + async def converter(mes: Message): + return discord.utils.get(self.guild.channels, + id=int(mes.content.translate(digit_keeper))) + + async def all_checks(channel: TextChannel): + if writable and not channel.permissions_for(self.bot.user).send_messages: + return False + return await await_if(check, channel) + + return await self.by_converter( + text, + check=all_checks, + converter=converter) + + async def as_message(self, text: str, check=lambda mes: True, in_channel: TextChannel = None) -> Message: + if in_channel is None: + in_channel = self.channel + + async def converter(mes: Message): + return await in_channel.get_message(mes.content) + + return await self.by_converter(text, check=check, converter=converter) + + async def yes_no_question(self, text: str) -> bool: + response = '' + while response not in (YES_ANSWERS + NO_ANSWERS): + response = (await self.text(text)).lower() + pass + return response in YES_ANSWERS diff --git a/drutils/version.py b/drutils/version.py new file mode 100644 index 0000000..5e43541 --- /dev/null +++ b/drutils/version.py @@ -0,0 +1,20 @@ +"""versioninfo""" + + +# pylint: disable=too-few-public-methods +class VersionInfo: + """Version info dataclass""" + + # pylint: disable=too-many-arguments + def __init__(self, major: int, minor: int, build: int, level: str, serial: int): + self.major = major + self.minor = minor + self.build = build + self.level = level + self.serial = serial + + def __str__(self): + return '{major}.{minor}.{build}{level}{serial}'.format(**self.__dict__) + + +VERSION = VersionInfo(1, 0, 0, 'a', 0) diff --git a/pylintrc.cfg b/pylintrc.cfg new file mode 100644 index 0000000..d8301e6 --- /dev/null +++ b/pylintrc.cfg @@ -0,0 +1,3 @@ +[MASTER] +ignore=tests, setup.py + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..43fb479 --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +http://github.com/Rapptz/discord.py/tarball/rewrite diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..31ad82b --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[aliases] +test = pytest diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..5065fcb --- /dev/null +++ b/setup.py @@ -0,0 +1,34 @@ +from setuptools import setup + +with open('drutils/version.py') as f: + _loc = {} + exec(f.read(), _loc, _loc) + version = _loc['VERSION'] + +dev_requirements = [ + "pylint", "aiounittest", "tox", "pytest" +] + +with open('README.md') as f: + readme = f.read() + +if not version: + raise RuntimeError('version is not set in drutils/version.py') + +setup( + name="drutils", + author="romangraef", + url="https://github.com/romangraef/drutils", + version=str(version), + install_requires=[], + long_description=readme, + setup_requires=[], + tests_require=dev_requirements, + dependency_links=[''], + license="MIT", + packages=['drutils'], + description="discord.py utils i found myself using often", + classifiers=[ + 'Topic :: Utilities', + ] +) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_awaiter.py b/tests/test_awaiter.py new file mode 100644 index 0000000..78d25a8 --- /dev/null +++ b/tests/test_awaiter.py @@ -0,0 +1,40 @@ +from unittest.mock import ANY + +from aiounittest import AsyncTestCase +from discord import Object + +from drutils.awaiter import AdvancedAwaiter +from .utils import get_mock_coro + +TEST_TEXT = "ABCDEF" +TEST_RESPONSE = "DEFABC" +TEST_MESSAGE_ID = 3 +TEST_USER_ID = 2 +TEST_CHANNEL_ID = 1 +TEST_TIMEOUT = 100 + + +class AwaiterTest(AsyncTestCase): + + def setUp(self): + super().setUp() + self.ctx = Object(id=0) + self.bot = Object(id=0) + self.message = Object(id=TEST_MESSAGE_ID) + self.message.content = TEST_RESPONSE + self.user = Object(id=TEST_USER_ID) + self.message.author = self.user + self.bot.wait_for = get_mock_coro(self.message) + self.channel = Object(id=TEST_CHANNEL_ID) + self.channel.send = get_mock_coro(None) + self.ctx.bot = self.bot + self.ctx.channel = self.channel + self.ctx.guild = None + self.ctx.author = self.user + + async def test_text(self): + awaiter = AdvancedAwaiter.from_context(self.ctx, timeout=TEST_TIMEOUT) + text = await awaiter.text(TEST_TEXT) + self.assertEqual(text, TEST_RESPONSE) + self.bot.wait_for.assert_called_once_with('message', check=ANY, timeout=TEST_TIMEOUT) + self.assertEqual(self.channel.send.call_args[1]['embed'].description, TEST_TEXT) diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..20b76c8 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,8 @@ +from unittest.mock import Mock + + +def get_mock_coro(return_value): + async def mock_coro(*args, **kwargs): + return return_value + + return Mock(wraps=mock_coro) diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..e6299aa --- /dev/null +++ b/tox.ini @@ -0,0 +1,16 @@ +# tox (https://tox.readthedocs.io/) is a tool for running tests +# in multiple virtualenvs. This configuration file will run the +# test suite on all supported python versions. To use it, "pip install tox" +# and then run "tox" from this directory. + +[tox] +envlist = py35, py36 + +[testenv] +deps = aiounittest + pytest + -rrequirements.txt + + +commands = + py.test -- cgit