package com.replaymod.gradle.remap

import com.replaymod.gradle.remap.PsiPattern.Matcher
import org.jetbrains.kotlin.backend.common.push
import org.jetbrains.kotlin.com.intellij.openapi.util.text.StringUtil.offsetToLineNumber
import org.jetbrains.kotlin.com.intellij.psi.*
import org.jetbrains.kotlin.psi.psiUtil.endOffset
import org.jetbrains.kotlin.psi.psiUtil.startOffset

/**
 * Collects [PsiPattern]s from Java methods carrying the annotation [annotationFQN] and
 * searches code for occurrences of those patterns.
 *
 * A pattern is registered only when the annotated method differs textually from the method
 * of the same name in the replacement version of the file.
 */
internal class PsiPatterns(private val annotationFQN: String) {
    private val patterns = mutableListOf<PsiPattern>()

    /**
     * Visits every method in [file] and registers a pattern for each one annotated with
     * [annotationFQN].
     *
     * @param file the parsed file containing the pattern methods.
     * @param replacementFile the full text of the updated version of [file].
     * @throws RuntimeException if an annotated method cannot be turned into a pattern
     *         (see [addPattern]).
     */
    fun read(file: PsiFile, replacementFile: String) {
        file.accept(object : JavaRecursiveElementVisitor() {
            override fun visitMethod(method: PsiMethod) {
                // Methods without the pattern annotation are skipped (and not descended into).
                method.getAnnotation(annotationFQN) ?: return
                addPattern(file, method, replacementFile)
            }
        })
    }

    /**
     * Builds a [PsiPattern] from [method] and its counterpart in [replacementFile].
     *
     * The counterpart is looked up by name in the first class of the re-parsed replacement
     * text; when several overloads share the name, the one whose body starts on the same line
     * as [method]'s body is chosen. If both methods have identical text, nothing is registered.
     *
     * The replacement is stored as the text fragments of the replacement expression that lie
     * between references to the method's parameters.
     *
     * @throws RuntimeException if either method has no body or no statements, if no
     *         counterpart can be found, or if the replacement ends in a `return` without value.
     */
    private fun addPattern(file: PsiFile, method: PsiMethod, replacementFile: String) {
        val body = method.body
            ?: throw RuntimeException("Pattern method \"${method.name}\" has no body")
        val methodLine = offsetToLineNumber(file.text, body.startOffset)
        val parameters = method.parameterList.parameters.map { it.name }

        val replacementPsi = PsiFileFactory.getInstance(file.project)
            .createFileFromText(file.language, replacementFile) as PsiJavaFile
        val replacementClass = replacementPsi.classes.first()

        val candidates = replacementClass.findMethodsByName(method.name, false)
        val matchingCandidate = if (candidates.size > 1) {
            // Overloads: disambiguate by the line the body starts on. A candidate without a
            // body can never match, so it is skipped rather than dereferenced.
            candidates.find { candidate ->
                candidate.body?.let { offsetToLineNumber(replacementFile, it.startOffset) } == methodLine
            }
        } else {
            candidates.firstOrNull()
        }
        val replacementMethod = matchingCandidate
            ?: throw RuntimeException("Failed to find updated method \"${method.name}\" (line ${methodLine + 1})")

        // Nothing changed between the two versions, so there is nothing to rewrite.
        if (method.text == replacementMethod.text) return

        val replacementBody = replacementMethod.body
            ?: throw RuntimeException("Updated method \"${method.name}\" has no body")
        val lastReplacementStatement = replacementBody.statements.lastOrNull()
            ?: throw RuntimeException("Updated method \"${method.name}\" has an empty body")
        val replacementExpression: PsiElement = when (lastReplacementStatement) {
            is PsiReturnStatement -> lastReplacementStatement.returnValue
                ?: throw RuntimeException("Updated method \"${method.name}\" ends in a return without a value")
            else -> lastReplacementStatement
        }

        val arguments = mutableListOf<PsiExpression>()
        replacementExpression.accept(object : JavaRecursiveElementVisitor() {
            override fun visitReferenceExpression(expr: PsiReferenceExpression) {
                // NOTE(review): the firstChild check presumably restricts this to unqualified
                // references (no `qualifier.` in front of the name) - confirm against the PSI tree.
                if (expr.firstChild is PsiReferenceParameterList && expr.referenceName in parameters) {
                    arguments.add(expr)
                } else {
                    super.visitReferenceExpression(expr)
                }
            }
        })

        // Cut the replacement text at every parameter reference; the pieces in between are kept.
        // Offsets index into replacementFile because replacementPsi was parsed from that text.
        val replacement = mutableListOf<String>()
        var start = replacementExpression.startOffset
        for (argPsi in arguments.sortedBy { it.startOffset }) {
            replacement.add(replacementFile.substring(start, argPsi.startOffset))
            start = argPsi.endOffset
        }
        replacement.add(replacementFile.substring(start, replacementExpression.endOffset))

        val patternStatement = body.statements.lastOrNull()
            ?: throw RuntimeException("Pattern method \"${method.name}\" has an empty body")
        patterns.add(PsiPattern(parameters, patternStatement, replacement))
    }

    /**
     * Runs every registered pattern against the statements of [block].
     *
     * @return a fresh list holding whatever the patterns added to it.
     */
    fun find(block: PsiCodeBlock): MutableList<Matcher> {
        val results = mutableListOf<Matcher>()
        patterns.forEach { pattern -> pattern.find(block.statements, results) }
        return results
    }

    /**
     * Runs every registered pattern against [expr].
     *
     * @return a fresh list holding whatever the patterns added to it.
     */
    fun find(expr: PsiExpression): MutableList<Matcher> {
        val results = mutableListOf<Matcher>()
        patterns.forEach { pattern -> pattern.find(expr, results) }
        return results
    }
}