aboutsummaryrefslogtreecommitdiff
path: root/src/main/kotlin/com/replaymod/gradle/remap/PsiPatterns.kt
blob: 14eee0773242ce3a89ecd7a2cd1ad038ece8c8cd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
package com.replaymod.gradle.remap

import com.replaymod.gradle.remap.PsiPattern.Matcher
import org.jetbrains.kotlin.backend.common.push
import org.jetbrains.kotlin.com.intellij.openapi.util.text.StringUtil.offsetToLineNumber
import org.jetbrains.kotlin.com.intellij.psi.*
import org.jetbrains.kotlin.psi.psiUtil.endOffset
import org.jetbrains.kotlin.psi.psiUtil.startOffset

/**
 * Collects "pattern" methods — Java methods annotated with [annotationFQN] — and pairs each one with its
 * counterpart in a replacement version of the same source file.
 *
 * Each registered [PsiPattern] holds the method's parameter names, whether the last parameter is varargs, the
 * last statement of the original body, and the text of the replacement's last statement (or its `return` value)
 * split into fragments around references to the parameters.
 */
internal class PsiPatterns(private val annotationFQN: String) {
    private val patterns = mutableListOf<PsiPattern>()

    /**
     * Registers a pattern for every method in [file] that carries the [annotationFQN] annotation.
     *
     * @param replacementFile full source text of the replacement version of [file]; it is parsed with the same
     *   language as [file] and must yield a Java file.
     */
    fun read(file: PsiFile, replacementFile: String) {
        file.accept(object : JavaRecursiveElementVisitor() {
            override fun visitMethod(method: PsiMethod) {
                // super.visitMethod is not called, so method bodies are not descended into.
                method.getAnnotation(annotationFQN) ?: return
                addPattern(file, method, replacementFile)
            }
        })
    }

    /**
     * Builds a pattern from [method] and its counterpart in [replacementFile] and adds it to [patterns].
     *
     * Nothing is added when both bodies have identical text, or when either body has no statements (an empty
     * body disables the pattern).
     *
     * @throws RuntimeException if [method] or its counterpart has no body, the replacement contains no class,
     *   no counterpart is found, or the counterpart ends in a `return` without a value.
     */
    private fun addPattern(file: PsiFile, method: PsiMethod, replacementFile: String) {
        val body = method.body ?: throw RuntimeException(
            "Pattern method \"${method.name}\" (line ${offsetToLineNumber(file.text, method.textOffset) + 1}) has no body"
        )
        val methodLine = offsetToLineNumber(file.text, body.startOffset)

        val parameters = method.parameterList.parameters.map { it.name }
        val varArgs = method.parameterList.parameters.lastOrNull()?.isVarArgs ?: false

        val psiFileFactory = PsiFileFactory.getInstance(file.project)
        val replacementPsi = psiFileFactory.createFileFromText(file.language, replacementFile) as PsiJavaFile
        // Offsets of replacement PSI elements are resolved against the parsed file's own text.
        val replacementText = replacementPsi.text

        val replacementMethod = findReplacementMethod(method, methodLine, replacementPsi, replacementText)
        val replacementBody = replacementMethod.body
            ?: throw RuntimeException("Updated method \"${method.name}\" (line ${methodLine + 1}) has no body")

        // If the body does not change, then there is no point in applying this pattern
        if (body.text == replacementBody.text) return

        // If either body is empty, then consider the pattern to be disabled
        if (body.statements.isEmpty() || replacementBody.statements.isEmpty()) return

        // Only the last statement of each body is used; for a trailing `return`, only its value is substituted.
        val replacementExpression: PsiElement = when (val statement = replacementBody.statements.last()) {
            is PsiReturnStatement -> statement.returnValue ?: throw RuntimeException(
                "Updated method \"${method.name}\" (line ${methodLine + 1}) ends in a return without a value"
            )
            else -> statement
        }

        val replacement = splitAroundParameters(replacementExpression, parameters, replacementText)
        patterns.add(PsiPattern(parameters, varArgs, body.statements.last(), replacement))
    }

    /**
     * Finds the counterpart of [method] among the methods of the first top-level class in [replacementPsi].
     *
     * When several methods share the name, the one whose body starts on [methodLine] (0-based) is chosen. This
     * assumes the replacement keeps the original's line layout. Candidates without a body are skipped in that case.
     */
    private fun findReplacementMethod(
        method: PsiMethod,
        methodLine: Int,
        replacementPsi: PsiJavaFile,
        replacementText: String
    ): PsiMethod {
        val replacementClass = replacementPsi.classes.firstOrNull() ?: throw RuntimeException(
            "Replacement file contains no class (looking for \"${method.name}\", line ${methodLine + 1})"
        )
        val candidates = replacementClass.findMethodsByName(method.name, false)
        val match = if (candidates.size > 1) {
            candidates.find { candidate ->
                val candidateBody = candidate.body ?: return@find false
                offsetToLineNumber(replacementText, candidateBody.startOffset) == methodLine
            }
        } else {
            candidates.firstOrNull()
        }
        return match ?: throw RuntimeException("Failed to find updated method \"${method.name}\" (line ${methodLine + 1})")
    }

    /**
     * Splits the text of [expression] into the fragments that surround each reference to one of [parameters],
     * in source order. The result always has exactly one more element than the number of references found.
     *
     * @param sourceText text of the file containing [expression], against which its offsets are resolved.
     */
    private fun splitAroundParameters(
        expression: PsiElement,
        parameters: List<String?>,
        sourceText: String
    ): MutableList<String> {
        val arguments = mutableListOf<PsiReferenceExpression>()
        expression.accept(object : JavaRecursiveElementVisitor() {
            override fun visitReferenceExpression(expr: PsiReferenceExpression) {
                // Presumably selects unqualified references: in Java PSI these start with their (possibly empty)
                // PsiReferenceParameterList, whereas qualified ones start with the qualifier, which is visited
                // recursively instead.
                if (expr.firstChild is PsiReferenceParameterList && expr.referenceName in parameters) {
                    arguments.add(expr)
                } else {
                    super.visitReferenceExpression(expr)
                }
            }
        })

        val fragments = mutableListOf<String>()
        var start = expression.startOffset
        for (argument in arguments.sortedBy { it.startOffset }) {
            fragments.add(sourceText.slice(start until argument.startOffset))
            start = argument.endOffset
        }
        fragments.add(sourceText.slice(start until expression.endOffset))
        return fragments
    }

    /** Runs every registered pattern over the statements of [block] and returns the matchers they collect. */
    fun find(block: PsiCodeBlock): MutableList<Matcher> {
        val results = mutableListOf<Matcher>()
        for (pattern in patterns) {
            pattern.find(block.statements, results)
        }
        return results
    }

    /** Runs every registered pattern against [expr] and returns the matchers they collect. */
    fun find(expr: PsiExpression): MutableList<Matcher> {
        val results = mutableListOf<Matcher>()
        for (pattern in patterns) {
            pattern.find(expr, results)
        }
        return results
    }
}