aboutsummaryrefslogtreecommitdiff
path: root/drutils/converters.py
blob: 4354af187d1b05547506a4750acab3634e6d1832 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from typing import Callable, Union, Any, TypeVar, Type

from discord.ext.commands import Converter, UserInputError
from discord.utils import maybe_coroutine

# Generic placeholder for "some Converter subclass", used in the signatures below
# so that both converter instances (C) and converter classes (Type[C]) type-check.
C = TypeVar("C", bound=Converter)


class AugmentedConverter(Converter):
    """
    Augment an existing converter by mapping its result or filtering based on its output.

    Mapping is applied before filtering, and the mapped value is what gets returned.
    If the filter rejects the mapped value, a ``UserInputError`` carrying
    ``error_message`` is raised.

    """

    def _resolve_converter(self) -> Callable:
        """Turn ``self.original_converter`` into a plain ``(ctx, argument)`` callable."""
        converter = self.original_converter
        # issubclass() raises TypeError for non-class arguments (converter instances,
        # plain functions), so make sure we actually have a class before asking.
        if isinstance(converter, type) and issubclass(converter, Converter):
            converter = converter()
        if isinstance(converter, Converter):
            converter = converter.convert
        return converter

    async def convert(self, ctx, argument):
        """
        Run the wrapped converter, then map and filter its result.

        :param ctx:      the invocation context, forwarded to the wrapped converter
        :param argument: the raw argument, forwarded to the wrapped converter
        :return:         the mapped result, if it passes the filter
        :raises UserInputError: if the filter returns a falsey value; its message is
                                ``self.error_message`` (or no message when that is None)
        """
        converter = self._resolve_converter()
        # maybe_coroutine accepts both sync and async callables for every stage.
        result = await maybe_coroutine(converter, ctx, argument)
        result = await maybe_coroutine(self.mapping_function, result)
        if await maybe_coroutine(self.filter_function, result):
            return result
        raise UserInputError(self.error_message)

    def __init__(self, original_converter: Union[C, Type[C], Callable], mapping_function: Callable[[Any], Any] = None,
                 filter_function: Callable[[Any], bool] = None, error_message: str = None):
        """

        :param original_converter: the converter to augment. can be a function, converter instance or converter
                                   class. functions may be sync or async
        :param mapping_function:   maps the result of the converter before any filtering takes place; the mapped
                                   value is what ``convert`` returns. may be sync or async. optional (identity)
        :param filter_function:    filters the mapped result. if this function returns a falsey value an
                                   `UserInputError` is raised and then caught by discord.ext.commands.
                                   may be sync or async. optional (accept everything)
        :param error_message:      the message attached to the raised `UserInputError`. this can be accessed by a
                                   custom error handler
        """
        self.original_converter = original_converter
        self.mapping_function = mapping_function or (lambda x: x)
        self.filter_function = filter_function or (lambda _: True)
        self.error_message = error_message


def non_nullable(converter: Union[C, Type[C], Callable], error_message: str = None):
    """
    Wrap ``converter`` so that a ``None`` result is rejected with a UserInputError.

    Handy together with Greedy or Optional matchers, which may otherwise let a
    ``None`` value through.

    :param converter:     the converter (function, instance or class) to wrap
    :param error_message: message attached to the UserInputError on rejection
    :return:              an AugmentedConverter that only accepts non-None results
    """
    def _is_present(value):
        return value is not None

    return AugmentedConverter(
        converter,
        error_message=error_message,
        filter_function=_is_present,
    )