aboutsummaryrefslogtreecommitdiff
path: root/src/main/kotlin/com/replaymod/gradle/remap/PsiPattern.kt
blob: 9bf1e901d9e7d65ca88249f4709f0a8978ef488f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
package com.replaymod.gradle.remap

import org.jetbrains.kotlin.backend.common.push
import org.jetbrains.kotlin.com.intellij.openapi.util.TextRange
import org.jetbrains.kotlin.com.intellij.psi.*
import org.jetbrains.kotlin.psi.psiUtil.endOffset
import org.jetbrains.kotlin.psi.psiUtil.startOffset

/**
 * A structural search pattern over Java PSI trees, together with the text fragments used to rewrite
 * each match.
 *
 * References inside [pattern] whose first child is a [PsiReferenceParameterList] and whose name is listed
 * in [parameters] act as placeholders: they match any expression of an assignable type and capture it.
 * [replacement] supplies the text that is emitted around those captures, see [Matcher.toChanges].
 *
 * If [pattern] is a `return` statement, the returned expression is searched for instead of the
 * statement itself.
 */
internal class PsiPattern(
        private val parameters: List<String>,
        private val pattern: PsiStatement,
        private val replacement: List<String>
) {
    /**
     * The element that is actually searched for: the returned expression when [pattern] is a `return`
     * statement, otherwise [pattern] itself.
     *
     * Evaluated on access (not at construction) so that an unusable pattern only fails once it is used.
     *
     * @throws IllegalStateException if [pattern] is a `return` statement without a value
     */
    private val searchTarget: PsiElement
        get() = when (pattern) {
            is PsiReturnStatement -> checkNotNull(pattern.returnValue) {
                "Pattern is a return statement without a value: ${pattern.text}"
            }
            else -> pattern
        }

    /**
     * Walks [tree] and appends a [Matcher] to [result] for every element that matches [target].
     *
     * Children of a matched element are not visited, so matches never nest inside each other.
     */
    private fun find(target: PsiElement, tree: PsiElement, result: MutableList<Matcher>) {
        tree.accept(object : JavaRecursiveElementVisitor() {
            override fun visitElement(element: PsiElement) {
                // A fresh matcher per candidate: captures recorded by a failed attempt are simply dropped.
                val matcher = Matcher(element)
                if (matcher.match(target)) {
                    result.add(matcher)
                } else {
                    super.visitElement(element)
                }
            }
        })
    }

    /**
     * Searches every statement in [statements] for occurrences of this pattern and appends one
     * [Matcher] per occurrence to [result].
     */
    fun find(statements: Array<PsiStatement>, result: MutableList<Matcher>) {
        for (statement in statements) {
            find(searchTarget, statement, result)
        }
    }

    /**
     * Searches [expr] for occurrences of this pattern and appends one [Matcher] per occurrence to
     * [result].
     */
    fun find(expr: PsiExpression, result: MutableList<Matcher>) {
        find(searchTarget, expr, result)
    }

    /**
     * Matches a single candidate element ([root]) against a pattern and records the expressions
     * captured by placeholders in [arguments].
     *
     * Captures are recorded while matching is still in progress, so a matcher whose [match] returned
     * `false` may hold partial captures and should be discarded.
     */
    inner class Matcher(private val root: PsiElement, private val arguments: MutableList<PsiElement> = mutableListOf()) {

        /**
         * Converts a successful match into text edits.
         *
         * The captured arguments are kept as they are; every stretch of [root] between them (before the
         * first, between consecutive ones, after the last) is paired with the next fragment of
         * [replacement], in source order. Pairs where both the range and the fragment are empty are
         * omitted.
         *
         * NOTE(review): this consumes `arguments.size + 1` fragments and throws
         * [NoSuchElementException] if [replacement] has fewer - presumably the caller guarantees the
         * sizes line up; confirm.
         *
         * @return the ranges to replace together with their replacement text
         */
        fun toChanges(): List<Pair<TextRange, String>> {
            val fragments = replacement.iterator()
            val changes = mutableListOf<Pair<TextRange, String>>()

            var gapStart = root.startOffset
            for (argument in arguments.sortedBy { it.startOffset }) {
                changes.add(TextRange(gapStart, argument.startOffset) to fragments.next())
                gapStart = argument.endOffset
            }
            changes.add(TextRange(gapStart, root.endOffset) to fragments.next())

            return changes.filterNot { (range, text) -> range.isEmpty && text.isEmpty() }
        }

        /**
         * Tries to match [root] against [pattern], recording placeholder captures along the way.
         *
         * @return `true` if [root] structurally matches [pattern]
         */
        fun match(pattern: PsiElement): Boolean = match(pattern, root)

        /**
         * Structural comparison of [pattern] and [expr]. Both sides must be the same kind of element and
         * their relevant children must match recursively; element kinds not listed here never match.
         * A `null` pattern only matches a `null` candidate.
         */
        private fun match(pattern: PsiElement?, expr: PsiElement?): Boolean = when (pattern) {
            null -> expr == null
            // The right-hand side is null for incomplete assignments; the nullable overload handles
            // that (null only matches null) instead of crashing.
            is PsiAssignmentExpression -> expr is PsiAssignmentExpression
                    && match(pattern.lExpression, expr.lExpression)
                    && match(pattern.rExpression, expr.rExpression)
            is PsiBlockStatement -> expr is PsiBlockStatement
                    && pattern.codeBlock.statementCount == expr.codeBlock.statementCount
                    && pattern.codeBlock.statements.asSequence().zip(expr.codeBlock.statements.asSequence())
                    .all { (patternStatement, exprStatement) -> match(patternStatement, exprStatement) }
            // Resolves to the reference-specific overload, which also handles placeholders.
            is PsiReferenceExpression -> expr is PsiExpression
                    && match(pattern, expr)
            is PsiMethodCallExpression -> expr is PsiMethodCallExpression
                    && match(pattern.methodExpression, expr.methodExpression)
                    && match(pattern.argumentList, expr.argumentList)
            is PsiExpressionList -> expr is PsiExpressionList
                    && pattern.expressionCount == expr.expressionCount
                    && pattern.expressions.asSequence().zip(expr.expressions.asSequence())
                    .all { (patternArgument, exprArgument) -> match(patternArgument, exprArgument) }
            is PsiExpressionStatement -> expr is PsiExpressionStatement
                    && match(pattern.expression, expr.expression)
            // NOTE(review): only the operand is compared, the cast type is not - confirm this is intended.
            is PsiTypeCastExpression -> expr is PsiTypeCastExpression
                    && match(pattern.operand, expr.operand)
            is PsiParenthesizedExpression -> expr is PsiParenthesizedExpression
                    && match(pattern.expression, expr.expression)
            is PsiNewExpression -> expr is PsiNewExpression
                    && pattern.classReference?.resolve() == expr.classReference?.resolve()
                    && match(pattern.qualifier, expr.qualifier)
                    && match(pattern.argumentList, expr.argumentList)
            // Literals are compared by their source text.
            is PsiLiteralExpression -> expr is PsiLiteralExpression
                    && pattern.text == expr.text
            else -> false
        }

        /**
         * Matches a reference in the pattern against [expr].
         *
         * A placeholder reference (see [PsiPattern]) matches any expression whose type is assignable to
         * the placeholder's type and captures it into [arguments]; if either type is unknown it does not
         * match. Any other reference matches only a reference with the same name and a matching
         * qualifier.
         */
        private fun match(pattern: PsiReferenceExpression, expr: PsiExpression): Boolean {
            val isPlaceholder = pattern.firstChild is PsiReferenceParameterList && pattern.referenceName in parameters
            if (!isPlaceholder) {
                return expr is PsiReferenceExpression
                        && pattern.referenceName == expr.referenceName
                        && match(pattern.qualifierExpression, expr.qualifierExpression)
            }

            val patternType = pattern.type ?: return false
            val exprType = expr.type ?: return false
            if (!patternType.isAssignableFrom(exprType)) {
                return false
            }
            arguments.add(expr)
            return true
        }
    }
}