diff options
-rw-r--r-- | commands/__init__.py | 3 | ||||
-rw-r--r-- | commands/handling.py | 35 | ||||
-rw-r--r-- | lib/__init__.py | 64 | ||||
-rw-r--r-- | lib/commands.py | 48 | ||||
-rw-r--r-- | lib/common.py | 16 | ||||
-rw-r--r-- | lib/match_script.py | 42 | ||||
-rw-r--r-- | main.py | 5 | ||||
-rw-r--r-- | modules/todo/todo.py | 52 | ||||
-rw-r--r-- | util.py | 11 |
9 files changed, 207 insertions, 69 deletions
diff --git a/commands/__init__.py b/commands/__init__.py index dc26d49..a7b0091 100644 --- a/commands/__init__.py +++ b/commands/__init__.py @@ -1,2 +1 @@ -from commands.handling import handle_commands, load_commands - +from commands.handling import handle_commands, load_modules, handle_match_scripts diff --git a/commands/handling.py b/commands/handling.py index 46cdb66..9c9e864 100644 --- a/commands/handling.py +++ b/commands/handling.py @@ -19,9 +19,11 @@ def load_module(module): for func in functions: if lib.is_command(func): lib.register_command(func) + if lib.is_match_script(func): + lib.register_match_script(func) -def load_commands(folder='modules'): +def load_modules(folder='modules'): for dirname, dirnames, filenames in os.walk(folder): for filename in filenames: filename: str @@ -32,10 +34,14 @@ def load_commands(folder='modules'): load_module(module) -def handle_commands(client: pyrogram.Client, update, users, chats): - if not (isinstance(update, tgtypes.UpdateNewMessage) +def is_message_update(update): + return (isinstance(update, tgtypes.UpdateNewMessage) or isinstance(update, tgtypes.UpdateNewChannelMessage) - or isinstance(update, tgtypes.UpdateNewEncryptedMessage)): + or isinstance(update, tgtypes.UpdateNewEncryptedMessage)) + + +def handle_commands(client: pyrogram.Client, update, users, chats): + if not is_message_update(update): return update: tgtypes.UpdateNewMessage message: tgtypes.Message = update.message @@ -51,7 +57,7 @@ def handle_commands(client: pyrogram.Client, update, users, chats): return command = parts[0][1:] args = parts[1:] - cmd_func = lib.commands[command.lower()] + cmd_func = lib.get_command_by_name(command.lower()) ctx = lib.CommandContext(client=client, channel=message.to_id, args=args, message=message) try: cmd_func(ctx) @@ -60,3 +66,22 @@ def handle_commands(client: pyrogram.Client, update, users, chats): except Exception as e: ctx.respond("unknown exception during execution. Error will be DM'd" + str(e)) print(traceback.format_exc(), file=sys.stderr) + + +def handle_match_scripts(client: pyrogram.Client, update, users, chats): + if not is_message_update(update): + return + update: tgtypes.UpdateNewMessage + message: tgtypes.Message = update.message + author_id = message.from_id + if author_id != client.user_id: + # do not react to other people + return + text: str = message.message + for regex, func in lib.get_match_scripts().items(): + match = regex.match(text) + if match is None: + continue + ctx = lib.MatchContext(client=client, channel=message.to_id, message=message, + match=match, groups=match.groups(), named_groups=match.groupdict()) + func(ctx) diff --git a/lib/__init__.py b/lib/__init__.py index edb1b42..0cec7cf 100644 --- a/lib/__init__.py +++ b/lib/__init__.py @@ -1,61 +1,5 @@ -from typing import List +from .commands import CommandContext, get_all_commands, get_command_name, get_command_description, description, name, \ + is_command, register_command, get_command_by_name -import pyrogram -from pyrogram.api import types as tgtypes - - -def property_decorator(key): - def decorator(value): - def wrapper(func): - setattr(func, key, value) - return func - - return wrapper - - return decorator - - -name = property_decorator('name') -description = property_decorator('description') - - -def is_command(func): - return hasattr(func, 'name') and func.name is not None - - -def get_command_name(func): - return func.name - - -def get_command_description(func): - return func.description - - -commands = {} - - -def get_all_commands(): - return commands.values() - - -class CommandContext(object): - def __init__(self, client: pyrogram.Client, channel, args: List[str], message: tgtypes.Message): - import re - self.args = args - self.client = client - self.channel = channel - self.message = message - self.rest_content = re.sub('^.*? ', '', message.message) - self.author = message.from_id - - def respond(self, text): - self.client.send_message(self.channel, text=text) - - def edit(self, text): - self.client.edit_message_text(chat_id=self.channel, message_id=self.message.id, text=text) - - -def register_command(func): - if not is_command(func): - return - commands[get_command_name(func)] = func +from .match_script import get_match_script_matcher, is_match_script, match_text, register_match_script, \ + get_match_scripts, MatchContext diff --git a/lib/commands.py b/lib/commands.py new file mode 100644 index 0000000..638ab9b --- /dev/null +++ b/lib/commands.py @@ -0,0 +1,48 @@ +import typing + +import pyrogram +from pyrogram.api import types as tgtypes + +from lib.common import CommonContext +from util import property_decorator + +commands = {} + + +def get_all_commands(): + return commands.values() + + +class CommandContext(CommonContext): + def __init__(self, client: pyrogram.Client, channel, args: typing.List[str], message: tgtypes.Message): + super().__init__(client, channel, message) + import re + self.args = args + self.rest_content = re.sub('^.*? ', '', message.message) + + + +def register_command(func): + if not is_command(func): + return + commands[get_command_name(func)] = func + + +def is_command(func): + return hasattr(func, 'name') and func.name is not None + + +def get_command_name(func): + return func.name + + +def get_command_description(func): + return func.description + + +def get_command_by_name(command_name: str): + return commands[command_name.lower()] + + +name = property_decorator('name') +description = property_decorator('description') diff --git a/lib/common.py b/lib/common.py new file mode 100644 index 0000000..ade2891 --- /dev/null +++ b/lib/common.py @@ -0,0 +1,16 @@ +import pyrogram +from pyrogram.api import types as tgtypes + + +class CommonContext: + def __init__(self, client: pyrogram.Client, channel, message: tgtypes.Message): + self.client = client + self.channel = channel + self.message = message + self.author = message.from_id + + def respond(self, text): + self.client.send_message(self.channel, text=text) + + def edit(self, text): + self.client.edit_message_text(chat_id=self.channel, message_id=self.message.id, text=text) diff --git a/lib/match_script.py b/lib/match_script.py new file mode 100644 index 0000000..124c23b --- /dev/null +++ b/lib/match_script.py @@ -0,0 +1,42 @@ +import pyrogram +from pyrogram.api import types as tgtypes + +from lib.common import CommonContext + + +def match_text(regex): + import re + regex = re.compile(regex) + + def wrapper(func): + func.regex = regex + return func + + return wrapper + + +match_scripts = {} + + +def is_match_script(func): + return hasattr(func, 'regex') and func.regex is not None + + +def get_match_script_matcher(func): + return func.regex + + +def register_match_script(func): + match_scripts[func.regex] = func + + +class MatchContext(CommonContext): + def __init__(self, client: pyrogram.Client, channel, message: tgtypes.Message, match, groups, named_groups): + super().__init__(client, channel, message) + self.match = match + self.groups = groups + self.named_groups = named_groups + + +def get_match_scripts(): + return match_scripts @@ -1,14 +1,15 @@ from pyrogram import Client -from commands import handle_commands, load_commands +from commands import handle_commands, load_modules, handle_match_scripts def update_handler(client, update, users, chats): handle_commands(client, update, users, chats) + handle_match_scripts(client, update, users, chats) def main(): - load_commands() + load_modules() client = Client(session_name="userbot") client.set_update_handler(update_handler) client.start() diff --git a/modules/todo/todo.py b/modules/todo/todo.py new file mode 100644 index 0000000..89a2b78 --- /dev/null +++ b/modules/todo/todo.py @@ -0,0 +1,52 @@ +import sqlite3 + +import lib + + +def init_todo_db(conn: sqlite3.Connection): + c = conn.cursor() + c.execute( + """ + CREATE TABLE IF NOT EXISTS todos ( + id INTEGER PRIMARY KEY, + text VARCHAR(1024) + ); + """) + conn.commit() + + +def add_todo(todo): + conn = sqlite3.connect('todos.db') + init_todo_db(conn) + c = conn.cursor() + c.execute( + """ + INSERT INTO todos (text) VALUES (?); + """, [todo]) + conn.commit() + + +def read_todos() -> sqlite3.Cursor: + conn = sqlite3.connect('todos.db') + init_todo_db(conn) + c = conn.cursor() + return c.execute( + """ + SELECT id, text FROM todos LIMIT 50; + """) + + +@lib.match_text('(?i)^TODO: (?P<todos>.*)$') +def handle_todo(ctx: lib.MatchContext): + for todo in ctx.named_groups['todos'].split(','): + todo = todo.strip() + add_todo(todo) + ctx.respond(todo + ' added') + + +@lib.name('todos') +def list_todos(ctx: lib.CommandContext): + mes = 'Todos: \n' + for row in read_todos(): + mes += '%d - %s\n' % row + ctx.respond(mes) @@ -0,0 +1,11 @@ + +def property_decorator(key): + def decorator(value): + def wrapper(func): + setattr(func, key, value) + return func + + return wrapper + + return decorator + |