1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
|
package com.replaymod.gradle.remap
import org.jetbrains.kotlin.cli.jvm.compiler.KotlinCoreEnvironment
import org.jetbrains.kotlin.com.intellij.psi.*
import org.jetbrains.kotlin.psi.psiUtil.endOffset
import org.jetbrains.kotlin.psi.psiUtil.startOffset
/**
 * Rewrites the import statements of a Java source text after it has been processed:
 * single-class imports whose simple name is no longer referenced are dropped, and imports are added for
 * unqualified, upper-case references that [shortClassNames] resolves to exactly one class.
 *
 * Imports are re-inserted without adding line breaks of their own; line feeds are inserted to make up for any
 * lines lost while removing the old import statements.
 */
internal class AutoImports(private val environment: KotlinCoreEnvironment) {
    // NOTE(review): ShortNameIndex is declared elsewhere; from its use below it maps a simple class name to the
    // candidate classes carrying that name.
    private val shortClassNames = ShortNameIndex(environment)

    /**
     * Fixes up the imports of a file.
     *
     * @param originalFile the file before any processing; supplies the original text and the language to parse with
     * @param mappedFile the text used for every line which [processedFile] left identical to the original
     * @param processedFile the text used for every line which differs from the original
     * @return the merged text with updated imports. If the three texts do not have the same number of lines,
     *   [mappedFile] is returned (lines joined with `\n`); if the merged text does not parse as a Java file, the
     *   merged text is returned without import changes.
     */
    fun apply(originalFile: PsiFile, mappedFile: String, processedFile: String): String =
        apply(originalFile, originalFile.text.lines(), mappedFile.lines(), processedFile.lines())

    private fun apply(
        originalFile: PsiFile,
        originalLines: List<String>,
        mappedLines: List<String>,
        processedLines: List<String>,
    ): String {
        // Lines are merged by index below, which only works if all three texts line up.
        if (originalLines.size != mappedLines.size || originalLines.size != processedLines.size) {
            return mappedLines.joinToString("\n")
        }

        // Take the mapped line wherever processing left the original untouched, the processed line otherwise.
        val inputLines = processedLines.mapIndexed { index, processedLine ->
            if (originalLines[index] == processedLine) mappedLines[index] else processedLine
        }
        val inputText = inputLines.joinToString("\n")

        val psiFileFactory = PsiFileFactory.getInstance(environment.project)
        val psiFile =
            psiFileFactory.createFileFromText(originalFile.language, inputText) as? PsiJavaFile ?: return inputText
        val pkg = psiFile.packageStatement?.packageReference?.resolve() as? PsiPackage

        val references = findOutgoingReferences(psiFile)

        val imports = psiFile.importList?.importStatements ?: emptyArray()
        // On-demand ("star") imports are stored as "qualified.name." so they can be used as prefixes below.
        val onDemandImports = imports.filter { it.isOnDemand }.mapNotNull { it.qualifiedName }.map { "$it." }.toSet()
        val existingImports = imports.filter { !it.isOnDemand }.mapNotNull { it.qualifiedName }.toSet()
        val unusedImports = existingImports.filter { it.substringAfterLast(".") !in references }.toSet()

        // Names which need no import: inner classes of the classes in this file and classes of the file's package.
        val implicitReferenceSources = listOfNotNull(
            psiFile.classes.flatMap { it.allInnerClasses.asIterable() },
            pkg?.classes?.asIterable(),
        )
        val implicitReferences = implicitReferenceSources.flatten().mapNotNull { it.name }.toSet()
        val importedReferences = existingImports.map { it.substringAfterLast(".") }.toSet()

        val missingReferences = references.asSequence() - importedReferences - implicitReferences
        // Only names with exactly one candidate are imported; anything matched by an on-demand import prefix or
        // living in java.lang is skipped.
        val newImports = missingReferences.mapNotNull { shortClassNames[it].singleOrNull()?.qualifiedName }
            .filter { ref -> onDemandImports.none { ref.startsWith(it) } }
            .filter { !it.startsWith("java.lang.") }

        val finalImports = existingImports - unusedImports + newImports + onDemandImports.map { "$it*" }

        val textBuilder = StringBuilder(inputText)
        // Remove the existing import statements back to front so the offsets of earlier ones stay valid.
        imports.map { it.textRange }.sortedByDescending { it.startOffset }.forEach { importRange ->
            textBuilder.replace(importRange.startOffset, importRange.endOffset, "")

            // Collapse the pair of whitespace characters now adjacent at the removal point into a single one.
            val start = importRange.startOffset
            val whiteSpaceRange = start - 1..start
            if (whiteSpaceRange.first in textBuilder.indices && whiteSpaceRange.last in textBuilder.indices) {
                // The range always spans two characters, so every pattern here must be two characters long.
                val whiteSpaceReplacement = when (textBuilder.substring(whiteSpaceRange)) {
                    "\n\n" -> "\n"
                    "\n " -> "\n"
                    " \n" -> "\n"
                    "  " -> " "
                    else -> null
                }
                if (whiteSpaceReplacement != null) {
                    textBuilder.replace(whiteSpaceRange.first, whiteSpaceRange.last + 1, whiteSpaceReplacement)
                }
            }
        }

        // Offsets are taken from the parsed file, i.e. from before the import statements were removed.
        val startOfImports = psiFile.importList?.takeIf { it.textLength > 0 }?.startOffset
        val endOfPackage = psiFile.packageStatement?.endOffset ?: 0

        // Re-insert as many line feeds as were lost by removing the import statements.
        val removedLineCount = inputLines.size - textBuilder.lineSequence().count()
        textBuilder.insert(startOfImports ?: endOfPackage, "\n".repeat(removedLineCount))

        var index = startOfImports ?: endOfPackage
        if (startOfImports == null) {
            // There was no import list: step forward (at most twice) while two line feeds follow.
            repeat(2) {
                if (textBuilder.hasTwoLineFeedsAfter(index)) {
                    index++
                }
            }
        }

        // Imports are written in two groups: everything else first, then java.* / javax.*.
        val javaImports = finalImports.filter { it.startsWith("java.") || it.startsWith("javax.") }.toSet()
        val otherImports = finalImports - javaImports
        val importGroups = listOf(otherImports, javaImports).filter { it.isNotEmpty() }
        for ((importGroupIndex, importGroup) in importGroups.withIndex()) {
            val hasMoreGroups = importGroupIndex + 1 in importGroups.indices
            for (import in importGroup.sorted()) {
                // If something already occupies this line, separate the import from it with a space.
                val hasPrecedingStatement = index > 0 && textBuilder[index - 1] != '\n'
                // Must be evaluated before the insertion shifts the text.
                val canAdvanceToNextLine = textBuilder.hasTwoLineFeedsAfter(index)
                val str = (if (hasPrecedingStatement) " " else "") + "import $import;"
                textBuilder.insert(index, str)
                index += str.length + if (canAdvanceToNextLine) 1 else 0
            }
            if (hasMoreGroups && textBuilder.hasTwoLineFeedsAfter(index)) {
                index++
            }
        }

        return textBuilder.toString()
    }

    /**
     * Whether the characters at `index + 1` and `index + 2` are both line feeds.
     * Positions beyond the end of the text are treated as "not a line feed" instead of throwing.
     */
    private fun CharSequence.hasTwoLineFeedsAfter(index: Int): Boolean =
        getOrNull(index + 1) == '\n' && getOrNull(index + 2) == '\n'

    /**
     * Collects the names of all unqualified references in [file] which start with an upper-case character and do
     * not resolve to a type parameter or a variable.
     */
    private fun findOutgoingReferences(file: PsiJavaFile): Set<String> {
        val references = mutableSetOf<String>()

        fun recordReference(reference: PsiJavaCodeReferenceElement) {
            if (reference.isQualified) return
            val name = reference.referenceName ?: return
            // firstOrNull so that an empty name is skipped rather than throwing.
            if (name.firstOrNull()?.isUpperCase() != true) return
            val resolved = reference.resolve()
            if (resolved is PsiTypeParameter) return
            if (resolved is PsiVariable) return
            references.add(name)
        }

        file.accept(object : JavaRecursiveElementVisitor() {
            override fun visitReferenceElement(reference: PsiJavaCodeReferenceElement) {
                recordReference(reference)
                super.visitReferenceElement(reference)
            }
        })

        return references
    }
}
|