1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
|
package com.replaymod.gradle.remap
import org.jetbrains.kotlin.backend.common.push
import org.jetbrains.kotlin.com.intellij.openapi.util.TextRange
import org.jetbrains.kotlin.com.intellij.psi.*
import org.jetbrains.kotlin.psi.psiUtil.endOffset
import org.jetbrains.kotlin.psi.psiUtil.startOffset
/**
 * A single Java code pattern together with its replacement.
 *
 * The pattern is compared structurally against Java PSI trees. Every reference to one of [parameters] inside the
 * pattern acts as a placeholder: it matches any expression whose type is assignable to the parameter's type.
 * Matches are reported as [Matcher]s, which can be turned into text edits via [Matcher.toChanges].
 *
 * @param parameters the placeholder parameters of the pattern
 * @param varArgs whether the last of [parameters] is a varargs parameter that may absorb any number of trailing
 *   call arguments
 * @param pattern the statement to match; for a `return` statement its returned expression is matched instead
 * @param replacement the replacement text, split into the segments that go before, between and after the matched
 *   arguments (so it needs at least one more entry than there are matched arguments)
 * @param replacementCanBeAssigned whether the replacement is a valid assignment target, i.e. whether matches on the
 *   left-hand side of an assignment are allowed
 */
internal class PsiPattern(
    private val parameters: Set<PsiParameter>,
    private val varArgs: Boolean,
    private val pattern: PsiStatement,
    private val replacement: List<String>,
    private val replacementCanBeAssigned: Boolean,
) {
    /**
     * The element that is actually compared against candidate trees. A `return` pattern is matched by its returned
     * expression; any other statement is matched as-is. Evaluated on each access, just like the original inline
     * lookups it replaces.
     */
    private val patternRoot: PsiElement
        get() = when (pattern) {
            is PsiReturnStatement -> checkNotNull(pattern.returnValue) { "Return pattern must return a value" }
            else -> pattern
        }

    /**
     * Adds a [Matcher] to [result] for every element of [tree] (including [tree] itself) that matches [target].
     * Traversal continues into the children of a matched element, so nested matches are reported as well.
     */
    private fun find(target: PsiElement, tree: PsiElement, result: MutableList<Matcher>) {
        // JavaRecursiveElementVisitor dispatches a PsiReferenceExpression to both visitExpression and
        // visitReferenceElement, and each of those ends up in visitElement. Try every element only once so the same
        // match is not reported (and later applied) twice.
        val tried = hashSetOf<PsiElement>()
        tree.accept(object : JavaRecursiveElementVisitor() {
            override fun visitElement(element: PsiElement) {
                if (tried.add(element)) {
                    val matcher = Matcher(element)
                    if (matcher.match(target)) {
                        result.add(matcher)
                    }
                }
                // Always delegate: the base implementation performs the actual child traversal.
                super.visitElement(element)
            }
        })
    }

    /** Searches each of [statements] (and everything nested inside them) for matches, appending them to [result]. */
    fun find(statements: Array<PsiStatement>, result: MutableList<Matcher>) {
        for (statement in statements) {
            find(patternRoot, statement, result)
        }
    }

    /** Searches [expr] (and everything nested inside it) for matches, appending them to [result]. */
    fun find(expr: PsiExpression, result: MutableList<Matcher>) {
        find(patternRoot, expr, result)
    }

    /**
     * A candidate match of the pattern rooted at [root].
     *
     * @param root the element the pattern is compared against
     * @param arguments collects, while matching, the source ranges of the expressions bound to pattern parameters
     */
    inner class Matcher(private val root: PsiElement, private val arguments: MutableList<TextRange> = mutableListOf()) {
        /** Maps parameters declared inside the pattern (e.g. lambda parameters) to their counterparts under [root]. */
        private val localVariables = mutableMapOf<PsiElement, PsiElement>()

        /**
         * Converts this (successful) match into text edits on the matched source.
         *
         * The text of [root] around the matched arguments is replaced, segment by segment, with the entries of the
         * pattern's replacement. The argument text itself is left in place. Edits that would replace an empty range
         * with an empty string are dropped.
         *
         * @throws IllegalStateException if the replacement has too few segments for the matched arguments
         */
        fun toChanges(): List<Pair<TextRange, String>> {
            check(replacement.size > arguments.size) {
                "Replacement has ${replacement.size} segments but ${arguments.size} arguments were matched"
            }
            val segments = replacement.iterator()
            val changes = mutableListOf<Pair<TextRange, String>>()
            var start = root.startOffset
            for (argRange in arguments.sortedBy { it.startOffset }) {
                var segment = segments.next()
                if (argRange.isEmpty) {
                    // An argument range which is empty should only happen when we're matching a varargs method but the
                    // match has zero varargs. This is a hack (cause it depends on exact spacing) to avoid the trailing
                    // comma we'd otherwise have if there are other leading arguments in the call.
                    segment = segment.removeSuffix(", ")
                }
                changes.push(Pair(TextRange(start, argRange.startOffset), segment))
                start = argRange.endOffset
            }
            changes.push(Pair(TextRange(start, root.endOffset), segments.next()))
            return changes.filterNot { it.first.isEmpty && it.second.isEmpty() }
        }

        /**
         * Returns whether [pattern] matches [root]. If [root] is the left-hand side of an assignment, it only matches
         * when the replacement can be assigned to as well.
         */
        fun match(pattern: PsiElement): Boolean {
            val parent = root.parent
            if (parent is PsiAssignmentExpression && parent.lExpression == root && !replacementCanBeAssigned) {
                return false
            }
            return match(pattern, root)
        }

        /** Structurally compares [pattern] with [expr]. Node kinds not listed here never match. */
        private fun match(pattern: PsiElement?, expr: PsiElement?): Boolean = when (pattern) {
            null -> expr == null
            // rExpression is null in incomplete code, so it is compared null-safely rather than asserted to exist.
            is PsiAssignmentExpression -> expr is PsiAssignmentExpression
                && match(pattern.lExpression, expr.lExpression)
                && match(pattern.rExpression, expr.rExpression)
            is PsiBlockStatement -> expr is PsiBlockStatement
                && pattern.codeBlock.statementCount == expr.codeBlock.statementCount
                && pattern.codeBlock.statements.asSequence().zip(expr.codeBlock.statements.asSequence())
                    .all { (p, e) -> match(p, e) }
            is PsiReferenceExpression -> expr is PsiExpression
                && match(pattern, expr)
            is PsiMethodCallExpression -> expr is PsiMethodCallExpression
                && match(pattern.methodExpression, expr.methodExpression)
                && match(pattern.argumentList, expr.argumentList)
            is PsiExpressionList -> expr is PsiExpressionList
                && match(pattern, expr)
            is PsiExpressionStatement -> expr is PsiExpressionStatement
                && match(pattern.expression, expr.expression)
            // The target type must agree too; otherwise e.g. `(int) x` would match `(long) x`.
            is PsiTypeCastExpression -> expr is PsiTypeCastExpression
                && pattern.castType?.type == expr.castType?.type
                && match(pattern.operand, expr.operand)
            is PsiParenthesizedExpression -> expr is PsiParenthesizedExpression
                && match(pattern.expression, expr.expression)
            is PsiBinaryExpression -> expr is PsiBinaryExpression
                && pattern.operationTokenType == expr.operationTokenType
                && match(pattern.lOperand, expr.lOperand)
                && match(pattern.rOperand, expr.rOperand)
            is PsiNewExpression -> expr is PsiNewExpression
                && pattern.classReference?.resolve() == expr.classReference?.resolve()
                && match(pattern.qualifier, expr.qualifier)
                && match(pattern.argumentList, expr.argumentList)
            is PsiLambdaExpression -> expr is PsiLambdaExpression
                && match(pattern.parameterList, expr.parameterList)
                && match(pattern.body, expr.body)
            is PsiParameterList -> expr is PsiParameterList
                && pattern.parametersCount == expr.parametersCount
                && pattern.parameters.zip(expr.parameters).all { (p, e) -> match(p, e) }
            is PsiLiteralExpression -> expr is PsiLiteralExpression
                && pattern.text == expr.text
            else -> false
        }

        /**
         * Pairs a parameter declared inside the pattern with the corresponding parameter in the matched code, so
         * later references to it can be checked. Only the varargs-ness has to agree.
         */
        private fun match(pattern: PsiParameter, expr: PsiParameter): Boolean {
            if (pattern.isVarArgs != expr.isVarArgs) {
                return false
            }
            localVariables[pattern] = expr
            return true
        }

        /**
         * Matches a reference in the pattern against [expr].
         *
         * - A reference to one of the pattern [parameters] matches any expression whose type is assignable to the
         *   parameter's type, and records that expression's range as an argument.
         * - A reference to a variable declared within the pattern must resolve to the variable it was paired with.
         * - Any other reference must have the same name and a matching qualifier.
         */
        private fun match(pattern: PsiReferenceExpression, expr: PsiExpression): Boolean {
            val resolvedPattern = pattern.resolve()
            if (resolvedPattern in parameters) {
                val patternType = pattern.type ?: return false
                val exprType = expr.type ?: return false
                if (!patternType.isAssignableFrom(exprType)) {
                    return false
                }
                arguments.add(expr.textRange)
                return true
            }
            // If the pattern is not a free variable, the expression must match it structurally
            if (expr !is PsiReferenceExpression) {
                return false
            }
            // If the pattern refers to a specific local variable, so must the expression
            val localVariable = localVariables[resolvedPattern]
            if (localVariable != null) {
                return expr.resolve() == localVariable
            }
            return pattern.referenceName == expr.referenceName && match(pattern.qualifierExpression, expr.qualifierExpression)
        }

        /**
         * Matches call arguments. If the last pattern argument is the varargs parameter, it absorbs any number of
         * trailing arguments: either a single expression assignable to the array type itself, or zero or more
         * expressions assignable to its component type. The combined range of those arguments is recorded; it is an
         * empty range when there are none.
         */
        private fun match(pattern: PsiExpressionList, expr: PsiExpressionList): Boolean {
            val argsPattern = pattern.expressions
            val argsExpr = expr.expressions
            val varArgPattern = argsPattern.lastOrNull() as? PsiReferenceExpression
            if (varArgPattern == null || !isVarArgsParameter(varArgPattern)) {
                if (argsPattern.size != argsExpr.size) {
                    return false
                }
                return argsPattern.asSequence().zip(argsExpr.asSequence())
                    .all { (p, e) -> match(p, e) }
            }
            // Every regular (non-varargs) pattern argument needs a counterpart.
            if (argsPattern.size - 1 > argsExpr.size) {
                return false
            }
            val regularArgsMatch = argsPattern.dropLast(1).asSequence().zip(argsExpr.asSequence())
                .all { (p, e) -> match(p, e) }
            if (!regularArgsMatch) {
                return false
            }
            val varArgsExpr = argsExpr.drop(argsPattern.size - 1)
            // A single argument may itself be the varargs array, which is then passed through unchanged.
            if (varArgsExpr.size == 1 && match(varArgPattern, varArgsExpr.single())) {
                return true
            }
            // Otherwise every remaining argument must fit the component type of the varargs array.
            val componentType = (varArgPattern.type as PsiArrayType).componentType
            val allAssignable = varArgsExpr.all { arg ->
                val argType = arg.type
                argType != null && componentType.isAssignableFrom(argType)
            }
            if (!allAssignable) {
                return false
            }
            val varArgsRange = if (varArgsExpr.isEmpty()) {
                // Zero varargs: record an empty range at the start of the list's last child.
                TextRange.from(expr.lastChild.startOffset, 0)
            } else {
                TextRange(varArgsExpr.first().startOffset, varArgsExpr.last().endOffset)
            }
            arguments.add(varArgsRange)
            return true
        }

        /**
         * Returns whether [expr] refers to the pattern's trailing varargs parameter. It must resolve to the last of
         * [parameters], and its first child must be its type-argument list (presumably meaning it has no qualifier).
         */
        private fun isVarArgsParameter(expr: PsiReferenceExpression): Boolean =
            varArgs && expr.firstChild is PsiReferenceParameterList && expr.resolve() == parameters.last()
    }
}
|