aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/main/kotlin/com/replaymod/gradle/remap/PsiPattern.kt69
-rw-r--r--src/main/kotlin/com/replaymod/gradle/remap/PsiPatterns.kt3
-rw-r--r--src/test/kotlin/com/replaymod/gradle/remap/pattern/TestVarArgs.kt192
3 files changed, 257 insertions, 7 deletions
diff --git a/src/main/kotlin/com/replaymod/gradle/remap/PsiPattern.kt b/src/main/kotlin/com/replaymod/gradle/remap/PsiPattern.kt
index 9bf1e90..c1809f7 100644
--- a/src/main/kotlin/com/replaymod/gradle/remap/PsiPattern.kt
+++ b/src/main/kotlin/com/replaymod/gradle/remap/PsiPattern.kt
@@ -8,6 +8,7 @@ import org.jetbrains.kotlin.psi.psiUtil.startOffset
internal class PsiPattern(
private val parameters: List<String>,
+ private val varArgs: Boolean,
private val pattern: PsiStatement,
private val replacement: List<String>
) {
@@ -40,7 +41,7 @@ internal class PsiPattern(
}
}
- inner class Matcher(private val root: PsiElement, private val arguments: MutableList<PsiElement> = mutableListOf()) {
+ inner class Matcher(private val root: PsiElement, private val arguments: MutableList<TextRange> = mutableListOf()) {
fun toChanges(): List<Pair<TextRange, String>> {
val sortedArgs = arguments.toList().sortedBy { it.startOffset }
@@ -49,7 +50,14 @@ internal class PsiPattern(
val replacementIter = replacement.iterator()
var start = root.startOffset
for (argPsi in sortedArgs) {
- changes.push(Pair(TextRange(start, argPsi.startOffset), replacementIter.next()))
+ var replacement = replacementIter.next()
+ if (argPsi.isEmpty) {
+ // An argument range which is empty should only happen when we're matching a varargs method but the
+ // match has zero varargs. This is a hack (cause it depends on exact spacing) to avoid the trailing
+ // comma we'd otherwise have if there are other leading arguments in the call.
+ replacement = replacement.removeSuffix(", ")
+ }
+ changes.push(Pair(TextRange(start, argPsi.startOffset), replacement))
start = argPsi.endOffset
}
changes.push(Pair(TextRange(start, root.endOffset), replacementIter.next()))
@@ -74,9 +82,7 @@ internal class PsiPattern(
&& match(pattern.methodExpression, expr.methodExpression)
&& match(pattern.argumentList, expr.argumentList)
is PsiExpressionList -> expr is PsiExpressionList
- && pattern.expressionCount == expr.expressionCount
- && pattern.expressions.asSequence().zip(expr.expressions.asSequence())
- .all { (pattern, expr) -> match(pattern, expr) }
+ && match(pattern, expr)
is PsiExpressionStatement -> expr is PsiExpressionStatement
&& match(pattern.expression, expr.expression)
is PsiTypeCastExpression -> expr is PsiTypeCastExpression
@@ -97,7 +103,7 @@ internal class PsiPattern(
val patternType = pattern.type ?: return false
val exprType = expr.type ?: return false
if (patternType.isAssignableFrom(exprType)) {
- arguments.add(expr)
+ arguments.add(expr.textRange)
true
} else {
false
@@ -107,5 +113,56 @@ internal class PsiPattern(
&& pattern.referenceName == expr.referenceName
&& match(pattern.qualifierExpression, expr.qualifierExpression)
}
+
+ private fun match(pattern: PsiExpressionList, expr: PsiExpressionList): Boolean {
+ val argsPattern = pattern.expressions
+ val argsExpr = expr.expressions
+
+ val varArgPattern = (argsPattern.lastOrNull() as? PsiReferenceExpression?)
+ if (varArgPattern == null || !isVarArgsParameter(varArgPattern)) {
+ if (argsPattern.size != argsExpr.size) {
+ return false
+ }
+ return argsPattern.asSequence().zip(argsExpr.asSequence())
+ .all { (pattern, expr) -> match(pattern, expr) }
+ }
+
+ if (argsPattern.size - 1 > argsExpr.size) {
+ return false
+ }
+
+ val regularArgsMatch = argsPattern.dropLast(1).asSequence().zip(argsExpr.asSequence())
+ .all { (pattern, expr) -> match(pattern, expr) }
+ if (!regularArgsMatch) {
+ return false
+ }
+
+ val regularArgsExpr = argsExpr.take(argsPattern.size - 1)
+ val varArgsExpr = argsExpr.drop(regularArgsExpr.size)
+
+ val argArrayTypePattern = varArgPattern.type as PsiArrayType
+ if (varArgsExpr.size == 1 && match(varArgPattern, varArgsExpr.single())) {
+ return true
+ }
+
+ val argTypePattern = argArrayTypePattern.componentType
+ for (argExpr in varArgsExpr) {
+ val argTypeExpr = argExpr.type ?: return false
+ if (!argTypePattern.isAssignableFrom(argTypeExpr)) {
+ return false
+ }
+ }
+
+ arguments.add(if (varArgsExpr.isEmpty()) {
+ TextRange.from(expr.lastChild.startOffset, 0)
+ } else {
+ TextRange(varArgsExpr.first().startOffset, varArgsExpr.last().endOffset)
+ })
+
+ return true
+ }
+
+ private fun isVarArgsParameter(expr: PsiReferenceExpression): Boolean =
+ varArgs && expr.firstChild is PsiReferenceParameterList && expr.referenceName == parameters.last()
}
} \ No newline at end of file
diff --git a/src/main/kotlin/com/replaymod/gradle/remap/PsiPatterns.kt b/src/main/kotlin/com/replaymod/gradle/remap/PsiPatterns.kt
index a1f6e8b..14eee07 100644
--- a/src/main/kotlin/com/replaymod/gradle/remap/PsiPatterns.kt
+++ b/src/main/kotlin/com/replaymod/gradle/remap/PsiPatterns.kt
@@ -24,6 +24,7 @@ internal class PsiPatterns(private val annotationFQN: String) {
val methodLine = offsetToLineNumber(file.text, body.startOffset)
val parameters = method.parameterList.parameters.map { it.name }
+ val varArgs = method.parameterList.parameters.lastOrNull()?.isVarArgs ?: false
val project = file.project
val psiFileFactory = PsiFileFactory.getInstance(project)
@@ -70,7 +71,7 @@ internal class PsiPatterns(private val annotationFQN: String) {
replacement.push(replacementFile.slice(start until replacementExpression.endOffset))
}
- patterns.add(PsiPattern(parameters, body.statements.last(), replacement))
+ patterns.add(PsiPattern(parameters, varArgs, body.statements.last(), replacement))
}
fun find(block: PsiCodeBlock): MutableList<Matcher> {
diff --git a/src/test/kotlin/com/replaymod/gradle/remap/pattern/TestVarArgs.kt b/src/test/kotlin/com/replaymod/gradle/remap/pattern/TestVarArgs.kt
new file mode 100644
index 0000000..a9caa15
--- /dev/null
+++ b/src/test/kotlin/com/replaymod/gradle/remap/pattern/TestVarArgs.kt
@@ -0,0 +1,192 @@
+package com.replaymod.gradle.remap.pattern
+
+import com.replaymod.gradle.remap.util.TestData
+import io.kotest.matchers.shouldBe
+import org.junit.jupiter.api.Test
+
+class TestVarArgs {
+ @Test
+ fun `should find varargs method`() {
+ TestData.remap("""
+ class Test {
+ private void test() {
+ method();
+ method("1");
+ method("1", "2");
+ method("1", "2", null);
+ method(new String[0]);
+ method("1", "2", 3);
+ }
+ }
+ """.trimIndent(), """
+ @remap.Pattern
+ private String pattern(String...args) {
+ return method(args);
+ }
+ """.trimIndent(), """
+ @remap.Pattern
+ private String pattern(String...args) {
+ return matched(args);
+ }
+ """.trimIndent()) shouldBe """
+ class Test {
+ private void test() {
+ matched();
+ matched("1");
+ matched("1", "2");
+ matched("1", "2", null);
+ matched(new String[0]);
+ method("1", "2", 3);
+ }
+ }
+ """.trimIndent()
+ }
+
+ @Test
+ fun `should find varargs method with fixed leading argument`() {
+ TestData.remap("""
+ class Test {
+ private void test() {
+ method(42);
+ method(42, "1");
+ method(42, "1", "2");
+ method(42, "1", "2", null);
+ method(42, new String[0]);
+ method(42, "1", "2", 3);
+ }
+ }
+ """.trimIndent(), """
+ @remap.Pattern
+ private String pattern(String...args) {
+ return method(42, args);
+ }
+ """.trimIndent(), """
+ @remap.Pattern
+ private String pattern(String...args) {
+ return matched(42, args);
+ }
+ """.trimIndent()) shouldBe """
+ class Test {
+ private void test() {
+ matched(42);
+ matched(42, "1");
+ matched(42, "1", "2");
+ matched(42, "1", "2", null);
+ matched(42, new String[0]);
+ method(42, "1", "2", 3);
+ }
+ }
+ """.trimIndent()
+ }
+
+ @Test
+ fun `should find varargs method with variable leading argument`() {
+ TestData.remap("""
+ class Test {
+ private void test() {
+ method(42);
+ method(43, "1");
+ method(44, "1", "2");
+ method(45, "1", "2", null);
+ method(46, new String[0]);
+ method(47, "1", "2", 3);
+ }
+ }
+ """.trimIndent(), """
+ @remap.Pattern
+ private String pattern(int i, String...args) {
+ return method(i, args);
+ }
+ """.trimIndent(), """
+ @remap.Pattern
+ private String pattern(int i, String...args) {
+ return matched(i, args);
+ }
+ """.trimIndent()) shouldBe """
+ class Test {
+ private void test() {
+ matched(42);
+ matched(43, "1");
+ matched(44, "1", "2");
+ matched(45, "1", "2", null);
+ matched(46, new String[0]);
+ method(47, "1", "2", 3);
+ }
+ }
+ """.trimIndent()
+ }
+
+ @Test
+ fun `should allow leading argument to be removed`() {
+ TestData.remap("""
+ class Test {
+ private void test() {
+ method(42);
+ method(42, "1");
+ method(42, "1", "2");
+ method(42, "1", "2", null);
+ method(42, new String[0]);
+ method(42, "1", "2", 3);
+ }
+ }
+ """.trimIndent(), """
+ @remap.Pattern
+ private String pattern(String...args) {
+ return method(42, args);
+ }
+ """.trimIndent(), """
+ @remap.Pattern
+ private String pattern(String...args) {
+ return matched(args);
+ }
+ """.trimIndent()) shouldBe """
+ class Test {
+ private void test() {
+ matched();
+ matched("1");
+ matched("1", "2");
+ matched("1", "2", null);
+ matched(new String[0]);
+ method(42, "1", "2", 3);
+ }
+ }
+ """.trimIndent()
+ }
+
+ @Test
+ fun `should allow leading argument to be added`() {
+ TestData.remap("""
+ class Test {
+ private void test() {
+ method();
+ method("1");
+ method("1", "2");
+ method("1", "2", null);
+ method(new String[0]);
+ method("1", "2", 3);
+ }
+ }
+ """.trimIndent(), """
+ @remap.Pattern
+ private String pattern(String...args) {
+ return method(args);
+ }
+ """.trimIndent(), """
+ @remap.Pattern
+ private String pattern(String...args) {
+ return matched(42, args);
+ }
+ """.trimIndent()) shouldBe """
+ class Test {
+ private void test() {
+ matched(42);
+ matched(42, "1");
+ matched(42, "1", "2");
+ matched(42, "1", "2", null);
+ matched(42, new String[0]);
+ method("1", "2", 3);
+ }
+ }
+ """.trimIndent()
+ }
+} \ No newline at end of file