aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJonas Herzig <me@johni0702.de>2022-05-25 07:35:00 +0200
committerJonas Herzig <me@johni0702.de>2022-05-27 08:46:01 +0200
commit4019ebe20786059fdce5b25c7cf6d746a083eef3 (patch)
treee2aee9d92f7287d735a82f21debb91ad8d8b2fe4
parent3104e9fdb3c2df7528813e03e4a3e08a3e1a8c2a (diff)
downloadRemap-4019ebe20786059fdce5b25c7cf6d746a083eef3.tar.gz
Remap-4019ebe20786059fdce5b25c7cf6d746a083eef3.tar.bz2
Remap-4019ebe20786059fdce5b25c7cf6d746a083eef3.zip
Support matching lambda expressions with @Pattern
-rw-r--r--src/main/kotlin/com/replaymod/gradle/remap/PsiPattern.kt42
-rw-r--r--src/main/kotlin/com/replaymod/gradle/remap/PsiPatterns.kt5
-rw-r--r--src/test/kotlin/com/replaymod/gradle/remap/pattern/TestLambdaExpression.kt90
-rw-r--r--src/testA/java/a/pkg/A.java9
4 files changed, 137 insertions, 9 deletions
diff --git a/src/main/kotlin/com/replaymod/gradle/remap/PsiPattern.kt b/src/main/kotlin/com/replaymod/gradle/remap/PsiPattern.kt
index fcba558..c1cdf0f 100644
--- a/src/main/kotlin/com/replaymod/gradle/remap/PsiPattern.kt
+++ b/src/main/kotlin/com/replaymod/gradle/remap/PsiPattern.kt
@@ -7,7 +7,7 @@ import org.jetbrains.kotlin.psi.psiUtil.endOffset
import org.jetbrains.kotlin.psi.psiUtil.startOffset
internal class PsiPattern(
- private val parameters: List<String>,
+ private val parameters: Set<PsiParameter>,
private val varArgs: Boolean,
private val pattern: PsiStatement,
private val replacement: List<String>
@@ -43,6 +43,8 @@ internal class PsiPattern(
inner class Matcher(private val root: PsiElement, private val arguments: MutableList<TextRange> = mutableListOf()) {
+ private val localVariables = mutableMapOf<PsiElement, PsiElement>()
+
fun toChanges(): List<Pair<TextRange, String>> {
val sortedArgs = arguments.toList().sortedBy { it.startOffset }
val changes = mutableListOf<Pair<TextRange, String>>()
@@ -97,25 +99,51 @@ internal class PsiPattern(
&& pattern.classReference?.resolve() == expr.classReference?.resolve()
&& match(pattern.qualifier, expr.qualifier)
&& match(pattern.argumentList, expr.argumentList)
+ is PsiLambdaExpression -> expr is PsiLambdaExpression
+ && match(pattern.parameterList, expr.parameterList)
+ && match(pattern.body, expr.body)
+ is PsiParameterList -> expr is PsiParameterList
+ && pattern.parametersCount == expr.parametersCount
+ && pattern.parameters.zip(expr.parameters).all { (p, e) -> match(p, e) }
is PsiLiteralExpression -> expr is PsiLiteralExpression
&& pattern.text == expr.text
else -> false
}
+ private fun match(pattern: PsiParameter, expr: PsiParameter): Boolean {
+ if (pattern.isVarArgs != expr.isVarArgs) {
+ return false
+ }
+
+ localVariables[pattern] = expr
+ return true
+ }
+
private fun match(pattern: PsiReferenceExpression, expr: PsiExpression): Boolean {
- return if (pattern.firstChild is PsiReferenceParameterList && pattern.referenceName in parameters) {
+ val resolvedPattern = pattern.resolve()
+ if (resolvedPattern in parameters) {
val patternType = pattern.type ?: return false
val exprType = expr.type ?: return false
- if (patternType.isAssignableFrom(exprType)) {
+ return if (patternType.isAssignableFrom(exprType)) {
arguments.add(expr.textRange)
true
} else {
false
}
}
- else expr is PsiReferenceExpression
- && pattern.referenceName == expr.referenceName
- && match(pattern.qualifierExpression, expr.qualifierExpression)
+
+ // If the pattern is not a free variable, the expression must match it structurally
+ if (expr !is PsiReferenceExpression) {
+ return false
+ }
+
+ // If the pattern refers to a specific local variable, so must the expression
+ val localVariable = localVariables[resolvedPattern]
+ if (localVariable != null) {
+ return expr.resolve() == localVariable
+ }
+
+ return pattern.referenceName == expr.referenceName && match(pattern.qualifierExpression, expr.qualifierExpression)
}
private fun match(pattern: PsiExpressionList, expr: PsiExpressionList): Boolean {
@@ -167,6 +195,6 @@ internal class PsiPattern(
}
private fun isVarArgsParameter(expr: PsiReferenceExpression): Boolean =
- varArgs && expr.firstChild is PsiReferenceParameterList && expr.referenceName == parameters.last()
+ varArgs && expr.firstChild is PsiReferenceParameterList && expr.resolve() == parameters.last()
}
} \ No newline at end of file
diff --git a/src/main/kotlin/com/replaymod/gradle/remap/PsiPatterns.kt b/src/main/kotlin/com/replaymod/gradle/remap/PsiPatterns.kt
index 14eee07..fe6040c 100644
--- a/src/main/kotlin/com/replaymod/gradle/remap/PsiPatterns.kt
+++ b/src/main/kotlin/com/replaymod/gradle/remap/PsiPatterns.kt
@@ -23,7 +23,8 @@ internal class PsiPatterns(private val annotationFQN: String) {
val body = method.body!!
val methodLine = offsetToLineNumber(file.text, body.startOffset)
- val parameters = method.parameterList.parameters.map { it.name }
+ val parameters = method.parameterList.parameters.toSet()
+ val parameterNames = parameters.map { it.name }.toSet()
val varArgs = method.parameterList.parameters.lastOrNull()?.isVarArgs ?: false
val project = file.project
@@ -55,7 +56,7 @@ internal class PsiPatterns(private val annotationFQN: String) {
val arguments = mutableListOf<PsiExpression>()
replacementExpression.accept(object : JavaRecursiveElementVisitor() {
override fun visitReferenceExpression(expr: PsiReferenceExpression) {
- if (expr.firstChild is PsiReferenceParameterList && expr.referenceName in parameters) {
+ if (expr.firstChild is PsiReferenceParameterList && expr.referenceName in parameterNames) {
arguments.add(expr)
} else {
super.visitReferenceExpression(expr)
diff --git a/src/test/kotlin/com/replaymod/gradle/remap/pattern/TestLambdaExpression.kt b/src/test/kotlin/com/replaymod/gradle/remap/pattern/TestLambdaExpression.kt
new file mode 100644
index 0000000..dc45492
--- /dev/null
+++ b/src/test/kotlin/com/replaymod/gradle/remap/pattern/TestLambdaExpression.kt
@@ -0,0 +1,90 @@
+package com.replaymod.gradle.remap.pattern
+
+import com.replaymod.gradle.remap.util.TestData
+import io.kotest.matchers.shouldBe
+import org.junit.jupiter.api.Disabled
+import org.junit.jupiter.api.Test
+
+class TestLambdaExpression {
+ @Test
+ fun `should find simply lambda expression`() {
+ TestData.remap("""
+ class Test {
+ private void test() {
+ a.pkg.A.supplier(() -> "test");
+ }
+ }
+ """.trimIndent(), """
+ @remap.Pattern
+ private void pattern(String str) {
+ return a.pkg.A.supplier(() -> str);
+ }
+ """.trimIndent(), """
+ @remap.Pattern
+ private void pattern(String str) {
+ return matched(str);
+ }
+ """.trimIndent()) shouldBe """
+ class Test {
+ private void test() {
+ matched("test");
+ }
+ }
+ """.trimIndent()
+ }
+
+ @Test
+ fun `should find lambda expression with bound arguments`() {
+ TestData.remap("""
+ class Test {
+ private void test() {
+ a.pkg.A.function(str -> str + "test");
+ }
+ }
+ """.trimIndent(), """
+ @remap.Pattern
+ private void pattern(String str) {
+ return a.pkg.A.function(s -> s + str);
+ }
+ """.trimIndent(), """
+ @remap.Pattern
+ private void pattern(String str) {
+ return matched(str);
+ }
+ """.trimIndent()) shouldBe """
+ class Test {
+ private void test() {
+ matched("test");
+ }
+ }
+ """.trimIndent()
+ }
+
+ @Test
+ @Disabled("Not yet implemented. Requires more complex replacement scheme.")
+ fun `should preserve bound lambda argument names`() {
+ TestData.remap("""
+ class Test {
+ private void test() {
+ a.pkg.A.function(str -> str + "test");
+ }
+ }
+ """.trimIndent(), """
+ @remap.Pattern
+ private void pattern(String str) {
+ return a.pkg.A.function(s -> s + str);
+ }
+ """.trimIndent(), """
+ @remap.Pattern
+ private void pattern(String str) {
+ return matched(s -> s + str);
+ }
+ """.trimIndent()) shouldBe """
+ class Test {
+ private void test() {
+ matched(str -> str + "test");
+ }
+ }
+ """.trimIndent()
+ }
+} \ No newline at end of file
diff --git a/src/testA/java/a/pkg/A.java b/src/testA/java/a/pkg/A.java
index 1a0dcac..be9bc1b 100644
--- a/src/testA/java/a/pkg/A.java
+++ b/src/testA/java/a/pkg/A.java
@@ -1,5 +1,8 @@
package a.pkg;
+import java.util.function.Function;
+import java.util.function.Supplier;
+
public class A extends AParent implements AInterface {
private A a;
private int aField;
@@ -98,6 +101,12 @@ public class A extends AParent implements AInterface {
new A() {};
}
+ public static void supplier(Supplier<String> supplier) {
+ }
+
+ public static void function(Function<String, String> func) {
+ }
+
public class Inner {
private int aField;
}