summaryrefslogtreecommitdiff
path: root/query/parse.py
blob: b509044c338a6ec02e48da982be40bd4514d107a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import ast
from typing import List

from antlr4 import *

from query.QueryLexer import QueryLexer
from query.QueryParser import QueryParser
from query.QueryVisitor import QueryVisitor


def make_function(outer_scope, argument_names: List[str], parser: QueryParser, prog: QueryParser.ProgContext):
    """Build a Python callable that runs ``prog`` as a user-defined function.

    Every call evaluates ``prog`` in a fresh ``MyQueryVisitor``. That
    visitor's memory starts as a shallow copy of ``outer_scope`` taken at
    call time, so bindings added to ``outer_scope`` after this function
    was created are visible inside it. Positional arguments are then bound
    to ``argument_names`` in order.

    Passing more arguments than there are names raises ``IndexError``.
    Names without a matching argument are not rebound.
    """

    def function(*args):
        scope = outer_scope.copy()
        for position, value in enumerate(args):
            scope[argument_names[position]] = value
        runner = MyQueryVisitor(parser)
        runner.memory = scope
        return runner.for_result(prog)

    return function


class Return(Exception):
    """Exception used to unwind evaluation and carry a value out of a program.

    ``MyQueryVisitor.for_result`` catches it and returns ``ret``.
    """

    def __init__(self, thing):
        super().__init__(thing)
        # The value being returned to the caller of for_result.
        self.ret = thing


class MyQueryVisitor(QueryVisitor):
    """Tree-walking evaluator for the query language.

    Each ``visitX`` method evaluates one kind of parse-tree node and returns
    its Python value. Variables are stored in ``self.memory``, a dict that
    maps names to values.
    """

    def __init__(self, parser: QueryParser):
        super().__init__()
        # Variable bindings visible to the program being evaluated.
        self.memory = {}
        # Handed to make_function so that function bodies get their own visitor.
        self.parser: QueryParser = parser
        # Value of the most recent bare expression statement. for_result
        # returns it when evaluation finishes without raising Return.
        self.last_expr = None

    def for_result(self, ctx):
        """Evaluate ``ctx`` and return the program's result.

        If a ``Return`` exception escapes evaluation, its carried value is
        returned. Otherwise the result is the value of the last bare
        expression statement (``None`` if there was none).
        """
        try:
            self.visit(ctx)
        except Return as ret:
            return ret.ret
        return self.last_expr

    def visitProg(self, ctx: QueryParser.ProgContext):
        """Evaluate each statement child of the program in order."""
        for stat in ctx.getChildren(lambda child: isinstance(child, QueryParser.StatContext)):
            self.visit(stat)

    def visitFunc(self, ctx: QueryParser.FuncContext):
        """Turn a function literal into a Python callable over ``self.memory``."""
        # NOTE(review): this drops the last two characters of the argument
        # text before splitting on ','. That is presumably a trailing
        # delimiter from the grammar; confirm against the grammar. An empty
        # argument list yields [''].
        args = ctx.arguments().getText()[:-2].split(',')
        return make_function(self.memory, args, self.parser, ctx.prog())

    def visitAssign(self, ctx):
        """Bind ``ID`` to the value of ``expr`` and return that value."""
        name = ctx.ID().getText()
        value = self.visit(ctx.expr())
        self.memory[name] = value
        return value

    def visitRawExpr(self, ctx: QueryParser.RawExprContext):
        """Evaluate an expression statement and record it as the latest result."""
        value = self.visit(ctx.expr())
        self.last_expr = value
        return value

    def visitInt(self, ctx):
        """Convert an integer-literal token to ``int``."""
        return int(ctx.INT().getText())

    def visitId(self, ctx):
        """Look up a variable; unknown names evaluate to 0."""
        return self.memory.get(ctx.ID().getText(), 0)

    def visitAccess(self, ctx: QueryParser.AccessContext):
        """Evaluate ``expr.ID`` as a Python attribute lookup, defaulting to 0."""
        thing = self.visit(ctx.expr())
        # ctx.ID() is a terminal node, not a str, and getattr requires the
        # attribute name as text.
        return getattr(thing, ctx.ID().getText(), 0)

    def visitMulDiv(self, ctx):
        """Evaluate ``*`` or ``/`` on the two operand expressions."""
        left = self.visit(ctx.expr(0))
        right = self.visit(ctx.expr(1))
        if ctx.op.type == QueryParser.MUL:
            return left * right
        # Python true division: int operands produce a float.
        return left / right

    def visitString(self, ctx: QueryParser.StringContext):
        """Decode a string-literal token into its Python value."""
        # The query text may come from untrusted input. literal_eval only
        # accepts Python literals, so a token cannot execute code the way
        # eval could. For ordinary string literals the result is the same.
        return ast.literal_eval(ctx.getText())

    def visitAddSub(self, ctx):
        """Evaluate ``+`` or ``-`` on the two operand expressions."""
        # NOTE(review): both operands are coerced with int(). A float (for
        # example from '/') is truncated, and non-numeric strings raise
        # ValueError. Confirm this is intended.
        left = int(self.visit(ctx.expr(0)))
        right = int(self.visit(ctx.expr(1)))
        if ctx.op.type == QueryParser.ADD:
            return left + right
        return left - right

    def visitParens(self, ctx):
        """Evaluate a parenthesised expression."""
        return self.visit(ctx.expr())

    def visitCall(self, ctx: QueryParser.CallContext):
        """Evaluate ``callee(args...)``.

        The first expression child is the callee; the remaining expression
        children are the positional arguments.
        """
        to_call = self.visit(ctx.expr(0))
        operands = list(ctx.getChildren(lambda x: isinstance(x, QueryParser.ExprContext)))
        # Slice off the callee BEFORE visiting, so it is evaluated only once
        # and its side effects are not repeated.
        args = [self.visit(arg) for arg in operands[1:]]
        return to_call(*args)

    def visit(self, tree):
        """Dispatch to the matching ``visitX`` method via the base visitor."""
        return super(MyQueryVisitor, self).visit(tree)


def parse(text, **kwargs):
    """Parse and evaluate query ``text`` and return its result.

    Keyword arguments are pre-bound as variables that the program can read.
    The return value is whatever ``MyQueryVisitor.for_result`` yields for
    the parsed ``prog`` tree.
    """
    token_stream = CommonTokenStream(QueryLexer(InputStream(text)))
    parser = QueryParser(token_stream)
    tree = parser.prog()
    evaluator = MyQueryVisitor(parser)
    evaluator.memory.update(kwargs)
    return evaluator.for_result(tree)


if __name__ == '__main__':
    # Debug entry point: evaluate the program stored in debug.txt and print
    # its result.
    with open('debug.txt') as debug_file:
        result = parse(debug_file.read())
    print(result)