diff options
author | romangraef <romangraef@gmail.com> | 2019-09-08 19:35:35 +0200 |
---|---|---|
committer | romangraef <romangraef@gmail.com> | 2019-09-08 19:35:35 +0200 |
commit | 15a7505738abcaa0d3d82b5cc21f59b653fca3e0 (patch) | |
tree | 30c11352fc7e6d9ff5231f9378d70e9cbca627ee | |
parent | f222ac827089aa72c82a3d49dd0df57d190262a3 (diff) | |
download | drutils-15a7505738abcaa0d3d82b5cc21f59b653fca3e0.tar.gz drutils-15a7505738abcaa0d3d82b5cc21f59b653fca3e0.tar.bz2 drutils-15a7505738abcaa0d3d82b5cc21f59b653fca3e0.zip |
augmented converters
-rw-r--r-- | drutils/__init__.py | 2 | ||||
-rw-r--r-- | drutils/converters.py | 52 | ||||
-rw-r--r-- | drutils/eval.py | 9 |
3 files changed, 62 insertions, 1 deletions
diff --git a/drutils/__init__.py b/drutils/__init__.py index 2aaa5c6..a67a7dd 100644 --- a/drutils/__init__.py +++ b/drutils/__init__.py @@ -1,9 +1,11 @@ from .awaiter import AdvancedAwaiter, AwaitException, AwaitCanceled, AwaitTimedOut +from .converters import FilteredConverter, non_nullable from .eval import handle_eval from .version import VERSION, VersionInfo __all__ = ( 'AdvancedAwaiter', 'AwaitException', 'AwaitCanceled', 'AwaitTimedOut', + 'FilteredConverter', 'non_nullable', 'handle_eval', 'VERSION', 'VersionInfo', ) diff --git a/drutils/converters.py b/drutils/converters.py new file mode 100644 index 0000000..640a7a6 --- /dev/null +++ b/drutils/converters.py @@ -0,0 +1,52 @@ +from typing import Callable, Union, Any, TypeVar, Type + +from discord.ext.commands import Converter, UserInputError +from discord.utils import maybe_coroutine + +C = TypeVar('C', bound=Converter) + + +class AugmentedConverter(Converter): + """ + Augment an existing converter by mapping its result or filtering based on its output. + + Mapping is applied before filtering. + + """ + + async def convert(self, ctx, argument): + converter = self.original_converter + if issubclass(converter, Converter): + converter = converter() + if isinstance(converter, Converter): + converter = converter.convert + result = await converter(ctx, argument) + if await maybe_coroutine(self.filter_function(result)): + return result + raise UserInputError() + + def __init__(self, original_converter: Union[C, Type[C], Callable], mapping_function: Callable[[Any], Any] = None, + filter_function: Callable[[Any], bool] = None, error_message: str = None): + """ + + :param original_converter: the converter augment. can be a function, converter or converter class. must be async + :param mapping_function: maps the result of the converter before any filtering takes place. optional + :param filter_function: filters the result. if this function returns a falsey value an `UserInputError` is + risen and then caught by discord.ext.commands + :param error_message: the error message that is thrown. this can be accessed by a custom error handler + """ + self.original_converter = original_converter + self.mapping_function = mapping_function or (lambda x: x) + self.filter_function = filter_function or (lambda _: True) + self.error_message = error_message + + +def non_nullable(converter: Union[C, Type[C], Callable], error_message: str = None): + """ + augments a converter to raise a UserInputError when None is returned. Useful for Greedy or Optional matchers. + + :param converter: the converter to augment + :param error_message: the error message to be associated with the UserInputError + :return: the augmented converter + """ + return AugmentedConverter(converter, filter_function=lambda res: res is not None, error_message=error_message) diff --git a/drutils/eval.py b/drutils/eval.py index fd105b5..645d069 100644 --- a/drutils/eval.py +++ b/drutils/eval.py @@ -13,7 +13,8 @@ REPLACEMENTS = { } -async def handle_eval(message: discord.Message, client: discord.Client, to_eval: str, **kwargs): +async def handle_eval(message: discord.Message, client: discord.Client, to_eval: str, + strip_codeblock: bool = False, **kwargs): channel = message.channel author = message.author @@ -36,7 +37,13 @@ async def handle_eval(message: discord.Message, client: discord.Client, to_eval: 'guild': channel.guild if hasattr(channel, 'guild') else None, } variables.update(kwargs) + lines = to_eval.strip().split('\n') + if strip_codeblock: + if lines[0].startswith("```"): + lines = lines[1:] + lines[-1] = ''.join(lines[-1].rsplit('```', 1)) + block = '\n'.join(' ' + line for line in lines) code = ("async def code({variables}):\n" "{block}").format(variables=', '.join(variables.keys()), block=block) |