aboutsummaryrefslogtreecommitdiff
path: root/sbdata/task.py
diff options
context:
space:
mode:
authornea <romangraef@gmail.com>2022-03-12 01:57:57 +0100
committernea <romangraef@gmail.com>2022-03-12 01:57:57 +0100
commite7caa7a9ba0202c44ad02ea9fd37c27bd4336c26 (patch)
tree03bce4042c47e475a99dfe90bc1e7a31afe25358 /sbdata/task.py
downloadsbdata-e7caa7a9ba0202c44ad02ea9fd37c27bd4336c26.tar.gz
sbdata-e7caa7a9ba0202c44ad02ea9fd37c27bd4336c26.tar.bz2
sbdata-e7caa7a9ba0202c44ad02ea9fd37c27bd4336c26.zip
Initial commit
Diffstat (limited to 'sbdata/task.py')
-rw-r--r--sbdata/task.py64
1 files changed, 64 insertions, 0 deletions
diff --git a/sbdata/task.py b/sbdata/task.py
new file mode 100644
index 0000000..7011ca2
--- /dev/null
+++ b/sbdata/task.py
@@ -0,0 +1,64 @@
+import dataclasses
+import os
+import sys
+import typing
+
+import questionary
+
+_T = typing.TypeVar('_T')
+
+
+class Arguments:
+
+ def __init__(self, args: list[str]):
+ self.prog = args[0]
+ self.args: typing.Dict[str, str] = {}
+ self.flags: list[str] = []
+ self.no_prompt = os.environ.get('PROMPT') == 'NO_PROMPT'
+ self.task: typing.Optional[str]
+ self.task = None
+ last_arg = None
+ for arg in args:
+ if last_arg is None:
+ if arg.startswith('--'):
+ last_arg = arg[2:]
+ elif arg.startswith('-'):
+ self.flags.append(arg[1:])
+ elif arg.startswith(':'):
+ self.task = arg[1:]
+ else:
+ print("Unknown arg: " + arg)
+ else:
+ self.args[last_arg] = arg
+ last_arg = None
+
+ def get_value(self, label: str, value: _T, question: questionary.Question) -> _T:
+ if value is None:
+ if self.no_prompt:
+ print('No argument present for ' + label)
+ sys.exit(1)
+ return question.ask()
+ return value
+
+ def has_flag(self, param: str) -> bool:
+ return param in self.flags
+
+
+@dataclasses.dataclass
+class Task:
+ label: str
+ name: str
+ run: typing.Callable
+
+
+tasks = {}
+
+TASK_TYPE = typing.Callable[[Arguments], None]
+
+
+def register_task(label: str) -> typing.Callable[[TASK_TYPE], TASK_TYPE]:
+ def d(func: TASK_TYPE) -> TASK_TYPE:
+ tasks[func.__name__] = Task(label, func.__name__, func)
+ return func
+
+ return d