aboutsummaryrefslogtreecommitdiff
path: root/drutils/converters.py
diff options
context:
space:
mode:
Diffstat (limited to 'drutils/converters.py')
-rw-r--r--drutils/converters.py52
1 files changed, 52 insertions, 0 deletions
diff --git a/drutils/converters.py b/drutils/converters.py
new file mode 100644
index 0000000..640a7a6
--- /dev/null
+++ b/drutils/converters.py
@@ -0,0 +1,52 @@
+from typing import Callable, Union, Any, TypeVar, Type
+
+from discord.ext.commands import Converter, UserInputError
+from discord.utils import maybe_coroutine
+
+C = TypeVar('C', bound=Converter)
+
+
+class AugmentedConverter(Converter):
+ """
+ Augment an existing converter by mapping its result or filtering based on its output.
+
+ Mapping is applied before filtering.
+
+ """
+
+ async def convert(self, ctx, argument):
+ converter = self.original_converter
+ if issubclass(converter, Converter):
+ converter = converter()
+ if isinstance(converter, Converter):
+ converter = converter.convert
+ result = await converter(ctx, argument)
+ if await maybe_coroutine(self.filter_function(result)):
+ return result
+ raise UserInputError()
+
+ def __init__(self, original_converter: Union[C, Type[C], Callable], mapping_function: Callable[[Any], Any] = None,
+ filter_function: Callable[[Any], bool] = None, error_message: str = None):
+ """
+
+ :param original_converter: the converter augment. can be a function, converter or converter class. must be async
+ :param mapping_function: maps the result of the converter before any filtering takes place. optional
+ :param filter_function: filters the result. if this function returns a falsey value an `UserInputError` is
+ risen and then caught by discord.ext.commands
+ :param error_message: the error message that is thrown. this can be accessed by a custom error handler
+ """
+ self.original_converter = original_converter
+ self.mapping_function = mapping_function or (lambda x: x)
+ self.filter_function = filter_function or (lambda _: True)
+ self.error_message = error_message
+
+
+def non_nullable(converter: Union[C, Type[C], Callable], error_message: str = None):
+ """
+ augments a converter to raise a UserInputError when None is returned. Useful for Greedy or Optional matchers.
+
+ :param converter: the converter to augment
+ :param error_message: the error message to be associated with the UserInputError
+ :return: the augmented converter
+ """
+ return AugmentedConverter(converter, filter_function=lambda res: res is not None, error_message=error_message)