diff options
author | Jonas Herzig <me@johni0702.de> | 2022-05-25 07:35:00 +0200 |
---|---|---|
committer | Jonas Herzig <me@johni0702.de> | 2022-05-27 08:46:01 +0200 |
commit | 4019ebe20786059fdce5b25c7cf6d746a083eef3 (patch) | |
tree | e2aee9d92f7287d735a82f21debb91ad8d8b2fe4 /src | |
parent | 3104e9fdb3c2df7528813e03e4a3e08a3e1a8c2a (diff) | |
download | Remap-4019ebe20786059fdce5b25c7cf6d746a083eef3.tar.gz Remap-4019ebe20786059fdce5b25c7cf6d746a083eef3.tar.bz2 Remap-4019ebe20786059fdce5b25c7cf6d746a083eef3.zip |
Support matching lambda expressions with @Pattern
Diffstat (limited to 'src')
4 files changed, 137 insertions, 9 deletions
diff --git a/src/main/kotlin/com/replaymod/gradle/remap/PsiPattern.kt b/src/main/kotlin/com/replaymod/gradle/remap/PsiPattern.kt index fcba558..c1cdf0f 100644 --- a/src/main/kotlin/com/replaymod/gradle/remap/PsiPattern.kt +++ b/src/main/kotlin/com/replaymod/gradle/remap/PsiPattern.kt @@ -7,7 +7,7 @@ import org.jetbrains.kotlin.psi.psiUtil.endOffset import org.jetbrains.kotlin.psi.psiUtil.startOffset internal class PsiPattern( - private val parameters: List<String>, + private val parameters: Set<PsiParameter>, private val varArgs: Boolean, private val pattern: PsiStatement, private val replacement: List<String> @@ -43,6 +43,8 @@ internal class PsiPattern( inner class Matcher(private val root: PsiElement, private val arguments: MutableList<TextRange> = mutableListOf()) { + private val localVariables = mutableMapOf<PsiElement, PsiElement>() + fun toChanges(): List<Pair<TextRange, String>> { val sortedArgs = arguments.toList().sortedBy { it.startOffset } val changes = mutableListOf<Pair<TextRange, String>>() @@ -97,25 +99,51 @@ internal class PsiPattern( && pattern.classReference?.resolve() == expr.classReference?.resolve() && match(pattern.qualifier, expr.qualifier) && match(pattern.argumentList, expr.argumentList) + is PsiLambdaExpression -> expr is PsiLambdaExpression + && match(pattern.parameterList, expr.parameterList) + && match(pattern.body, expr.body) + is PsiParameterList -> expr is PsiParameterList + && pattern.parametersCount == expr.parametersCount + && pattern.parameters.zip(expr.parameters).all { (p, e) -> match(p, e) } is PsiLiteralExpression -> expr is PsiLiteralExpression && pattern.text == expr.text else -> false } + private fun match(pattern: PsiParameter, expr: PsiParameter): Boolean { + if (pattern.isVarArgs != expr.isVarArgs) { + return false + } + + localVariables[pattern] = expr + return true + } + private fun match(pattern: PsiReferenceExpression, expr: PsiExpression): Boolean { - return if (pattern.firstChild is PsiReferenceParameterList && pattern.referenceName in parameters) { + val resolvedPattern = pattern.resolve() + if (resolvedPattern in parameters) { val patternType = pattern.type ?: return false val exprType = expr.type ?: return false - if (patternType.isAssignableFrom(exprType)) { + return if (patternType.isAssignableFrom(exprType)) { arguments.add(expr.textRange) true } else { false } } - else expr is PsiReferenceExpression - && pattern.referenceName == expr.referenceName - && match(pattern.qualifierExpression, expr.qualifierExpression) + + // If the pattern is not a free variable, the expression must match it structurally + if (expr !is PsiReferenceExpression) { + return false + } + + // If the pattern refers to a specific local variable, so must the expression + val localVariable = localVariables[resolvedPattern] + if (localVariable != null) { + return expr.resolve() == localVariable + } + + return pattern.referenceName == expr.referenceName && match(pattern.qualifierExpression, expr.qualifierExpression) } private fun match(pattern: PsiExpressionList, expr: PsiExpressionList): Boolean { @@ -167,6 +195,6 @@ internal class PsiPattern( } private fun isVarArgsParameter(expr: PsiReferenceExpression): Boolean = - varArgs && expr.firstChild is PsiReferenceParameterList && expr.referenceName == parameters.last() + varArgs && expr.firstChild is PsiReferenceParameterList && expr.resolve() == parameters.last() } }
\ No newline at end of file diff --git a/src/main/kotlin/com/replaymod/gradle/remap/PsiPatterns.kt b/src/main/kotlin/com/replaymod/gradle/remap/PsiPatterns.kt index 14eee07..fe6040c 100644 --- a/src/main/kotlin/com/replaymod/gradle/remap/PsiPatterns.kt +++ b/src/main/kotlin/com/replaymod/gradle/remap/PsiPatterns.kt @@ -23,7 +23,8 @@ internal class PsiPatterns(private val annotationFQN: String) { val body = method.body!! val methodLine = offsetToLineNumber(file.text, body.startOffset) - val parameters = method.parameterList.parameters.map { it.name } + val parameters = method.parameterList.parameters.toSet() + val parameterNames = parameters.map { it.name }.toSet() val varArgs = method.parameterList.parameters.lastOrNull()?.isVarArgs ?: false val project = file.project @@ -55,7 +56,7 @@ internal class PsiPatterns(private val annotationFQN: String) { val arguments = mutableListOf<PsiExpression>() replacementExpression.accept(object : JavaRecursiveElementVisitor() { override fun visitReferenceExpression(expr: PsiReferenceExpression) { - if (expr.firstChild is PsiReferenceParameterList && expr.referenceName in parameters) { + if (expr.firstChild is PsiReferenceParameterList && expr.referenceName in parameterNames) { arguments.add(expr) } else { super.visitReferenceExpression(expr) diff --git a/src/test/kotlin/com/replaymod/gradle/remap/pattern/TestLambdaExpression.kt b/src/test/kotlin/com/replaymod/gradle/remap/pattern/TestLambdaExpression.kt new file mode 100644 index 0000000..dc45492 --- /dev/null +++ b/src/test/kotlin/com/replaymod/gradle/remap/pattern/TestLambdaExpression.kt @@ -0,0 +1,90 @@ +package com.replaymod.gradle.remap.pattern + +import com.replaymod.gradle.remap.util.TestData +import io.kotest.matchers.shouldBe +import org.junit.jupiter.api.Disabled +import org.junit.jupiter.api.Test + +class TestLambdaExpression { + @Test + fun `should find simply lambda expression`() { + TestData.remap(""" + class Test { + private void test() { + a.pkg.A.supplier(() -> "test"); + } + } + """.trimIndent(), """ + @remap.Pattern + private void pattern(String str) { + return a.pkg.A.supplier(() -> str); + } + """.trimIndent(), """ + @remap.Pattern + private void pattern(String str) { + return matched(str); + } + """.trimIndent()) shouldBe """ + class Test { + private void test() { + matched("test"); + } + } + """.trimIndent() + } + + @Test + fun `should find lambda expression with bound arguments`() { + TestData.remap(""" + class Test { + private void test() { + a.pkg.A.function(str -> str + "test"); + } + } + """.trimIndent(), """ + @remap.Pattern + private void pattern(String str) { + return a.pkg.A.function(s -> s + str); + } + """.trimIndent(), """ + @remap.Pattern + private void pattern(String str) { + return matched(str); + } + """.trimIndent()) shouldBe """ + class Test { + private void test() { + matched("test"); + } + } + """.trimIndent() + } + + @Test + @Disabled("Not yet implemented. Requires more complex replacement scheme.") + fun `should preserve bound lambda argument names`() { + TestData.remap(""" + class Test { + private void test() { + a.pkg.A.function(str -> str + "test"); + } + } + """.trimIndent(), """ + @remap.Pattern + private void pattern(String str) { + return a.pkg.A.function(s -> s + str); + } + """.trimIndent(), """ + @remap.Pattern + private void pattern(String str) { + return matched(s -> s + str); + } + """.trimIndent()) shouldBe """ + class Test { + private void test() { + matched(str -> str + "test"); + } + } + """.trimIndent() + } +}
\ No newline at end of file diff --git a/src/testA/java/a/pkg/A.java b/src/testA/java/a/pkg/A.java index 1a0dcac..be9bc1b 100644 --- a/src/testA/java/a/pkg/A.java +++ b/src/testA/java/a/pkg/A.java @@ -1,5 +1,8 @@ package a.pkg; +import java.util.function.Function; +import java.util.function.Supplier; + public class A extends AParent implements AInterface { private A a; private int aField; @@ -98,6 +101,12 @@ public class A extends AParent implements AInterface { new A() {}; } + public static void supplier(Supplier<String> supplier) { + } + + public static void function(Function<String, String> func) { + } + public class Inner { private int aField; } |