aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorromangraef <romangraef@gmail.com>2019-09-08 19:35:35 +0200
committerromangraef <romangraef@gmail.com>2019-09-08 19:35:35 +0200
commit15a7505738abcaa0d3d82b5cc21f59b653fca3e0 (patch)
tree30c11352fc7e6d9ff5231f9378d70e9cbca627ee
parentf222ac827089aa72c82a3d49dd0df57d190262a3 (diff)
downloaddrutils-15a7505738abcaa0d3d82b5cc21f59b653fca3e0.tar.gz
drutils-15a7505738abcaa0d3d82b5cc21f59b653fca3e0.tar.bz2
drutils-15a7505738abcaa0d3d82b5cc21f59b653fca3e0.zip
augmented converters
-rw-r--r--drutils/__init__.py2
-rw-r--r--drutils/converters.py52
-rw-r--r--drutils/eval.py9
3 files changed, 62 insertions, 1 deletions
diff --git a/drutils/__init__.py b/drutils/__init__.py
index 2aaa5c6..a67a7dd 100644
--- a/drutils/__init__.py
+++ b/drutils/__init__.py
@@ -1,9 +1,11 @@
from .awaiter import AdvancedAwaiter, AwaitException, AwaitCanceled, AwaitTimedOut
+from .converters import FilteredConverter, non_nullable
from .eval import handle_eval
from .version import VERSION, VersionInfo
__all__ = (
'AdvancedAwaiter', 'AwaitException', 'AwaitCanceled', 'AwaitTimedOut',
+ 'FilteredConverter', 'non_nullable',
'handle_eval',
'VERSION', 'VersionInfo',
)
diff --git a/drutils/converters.py b/drutils/converters.py
new file mode 100644
index 0000000..640a7a6
--- /dev/null
+++ b/drutils/converters.py
@@ -0,0 +1,52 @@
+from typing import Callable, Union, Any, TypeVar, Type
+
+from discord.ext.commands import Converter, UserInputError
+from discord.utils import maybe_coroutine
+
+C = TypeVar('C', bound=Converter)
+
+
+class AugmentedConverter(Converter):
+ """
+ Augment an existing converter by mapping its result or filtering based on its output.
+
+ Mapping is applied before filtering.
+
+ """
+
+ async def convert(self, ctx, argument):
+ converter = self.original_converter
+ if issubclass(converter, Converter):
+ converter = converter()
+ if isinstance(converter, Converter):
+ converter = converter.convert
+ result = await converter(ctx, argument)
+ if await maybe_coroutine(self.filter_function(result)):
+ return result
+ raise UserInputError()
+
+ def __init__(self, original_converter: Union[C, Type[C], Callable], mapping_function: Callable[[Any], Any] = None,
+ filter_function: Callable[[Any], bool] = None, error_message: str = None):
+ """
+
+ :param original_converter: the converter augment. can be a function, converter or converter class. must be async
+ :param mapping_function: maps the result of the converter before any filtering takes place. optional
+ :param filter_function: filters the result. if this function returns a falsey value an `UserInputError` is
+ risen and then caught by discord.ext.commands
+ :param error_message: the error message that is thrown. this can be accessed by a custom error handler
+ """
+ self.original_converter = original_converter
+ self.mapping_function = mapping_function or (lambda x: x)
+ self.filter_function = filter_function or (lambda _: True)
+ self.error_message = error_message
+
+
+def non_nullable(converter: Union[C, Type[C], Callable], error_message: str = None):
+ """
+ augments a converter to raise a UserInputError when None is returned. Useful for Greedy or Optional matchers.
+
+ :param converter: the converter to augment
+ :param error_message: the error message to be associated with the UserInputError
+ :return: the augmented converter
+ """
+ return AugmentedConverter(converter, filter_function=lambda res: res is not None, error_message=error_message)
diff --git a/drutils/eval.py b/drutils/eval.py
index fd105b5..645d069 100644
--- a/drutils/eval.py
+++ b/drutils/eval.py
@@ -13,7 +13,8 @@ REPLACEMENTS = {
}
-async def handle_eval(message: discord.Message, client: discord.Client, to_eval: str, **kwargs):
+async def handle_eval(message: discord.Message, client: discord.Client, to_eval: str,
+ strip_codeblock: bool = False, **kwargs):
channel = message.channel
author = message.author
@@ -36,7 +37,13 @@ async def handle_eval(message: discord.Message, client: discord.Client, to_eval:
'guild': channel.guild if hasattr(channel, 'guild') else None,
}
variables.update(kwargs)
+
lines = to_eval.strip().split('\n')
+ if strip_codeblock:
+ if lines[0].startswith("```"):
+ lines = lines[1:]
+ lines[-1] = ''.join(lines[-1].rsplit('```', 1))
+
block = '\n'.join(' ' + line for line in lines)
code = ("async def code({variables}):\n"
"{block}").format(variables=', '.join(variables.keys()), block=block)