diff options
Diffstat (limited to 'drutils/converters.py')
-rw-r--r-- | drutils/converters.py | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/drutils/converters.py b/drutils/converters.py new file mode 100644 index 0000000..640a7a6 --- /dev/null +++ b/drutils/converters.py @@ -0,0 +1,52 @@ +from typing import Callable, Union, Any, TypeVar, Type + +from discord.ext.commands import Converter, UserInputError +from discord.utils import maybe_coroutine + +C = TypeVar('C', bound=Converter) + + +class AugmentedConverter(Converter): + """ + Augment an existing converter by mapping its result or filtering based on its output. + + Mapping is applied before filtering. + + """ + + async def convert(self, ctx, argument): + converter = self.original_converter + if issubclass(converter, Converter): + converter = converter() + if isinstance(converter, Converter): + converter = converter.convert + result = await converter(ctx, argument) + if await maybe_coroutine(self.filter_function(result)): + return result + raise UserInputError() + + def __init__(self, original_converter: Union[C, Type[C], Callable], mapping_function: Callable[[Any], Any] = None, + filter_function: Callable[[Any], bool] = None, error_message: str = None): + """ + + :param original_converter: the converter augment. can be a function, converter or converter class. must be async + :param mapping_function: maps the result of the converter before any filtering takes place. optional + :param filter_function: filters the result. if this function returns a falsey value an `UserInputError` is + risen and then caught by discord.ext.commands + :param error_message: the error message that is thrown. this can be accessed by a custom error handler + """ + self.original_converter = original_converter + self.mapping_function = mapping_function or (lambda x: x) + self.filter_function = filter_function or (lambda _: True) + self.error_message = error_message + + +def non_nullable(converter: Union[C, Type[C], Callable], error_message: str = None): + """ + augments a converter to raise a UserInputError when None is returned. Useful for Greedy or Optional matchers. + + :param converter: the converter to augment + :param error_message: the error message to be associated with the UserInputError + :return: the augmented converter + """ + return AugmentedConverter(converter, filter_function=lambda res: res is not None, error_message=error_message) |