aboutsummaryrefslogtreecommitdiff
path: root/src/main/kotlin/com/replaymod/gradle/remap/AutoImports.kt
blob: eb57486ed7b4d25e6533579d2b5f45fc1358a330 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
package com.replaymod.gradle.remap

import org.jetbrains.kotlin.cli.jvm.compiler.KotlinCoreEnvironment
import org.jetbrains.kotlin.com.intellij.psi.*
import org.jetbrains.kotlin.psi.psiUtil.endOffset
import org.jetbrains.kotlin.psi.psiUtil.startOffset

/**
 * Rebuilds the (non-static) import list of a Java file after remapping.
 *
 * The simple class names referenced by the file are collected. Imports whose simple name is no
 * longer referenced are dropped. Referenced names that are neither imported nor implicitly in
 * scope are looked up by simple name in a [ShortNameIndex] built from [environment]. They are
 * imported only when that lookup yields exactly one class.
 */
internal class AutoImports(private val environment: KotlinCoreEnvironment) {

    private val shortClassNames = ShortNameIndex(environment)

    /**
     * Merges [mappedFile] and [processedFile] line by line and rewrites the imports of the result.
     *
     * For every line, the line from [processedFile] is used when it differs from the corresponding
     * line of [originalFile]; otherwise the line from [mappedFile] is used.
     *
     * @return the merged text with its imports rewritten. If the three inputs do not have the same
     *   number of lines, [mappedFile] is returned (re-joined with `\n`). If the merged text does not
     *   parse as a Java file, it is returned without touching imports.
     */
    fun apply(originalFile: PsiFile, mappedFile: String, processedFile: String): String =
        apply(originalFile, originalFile.text.lines(), mappedFile.lines(), processedFile.lines())

    private fun apply(
        originalFile: PsiFile,
        originalLines: List<String>,
        mappedLines: List<String>,
        processedLines: List<String>,
    ): String {
        // A line-by-line merge is only meaningful when all three versions line up.
        if (originalLines.size != mappedLines.size || originalLines.size != processedLines.size) {
            return mappedLines.joinToString("\n")
        }

        // Keep processed lines that were changed relative to the original; otherwise take the mapped line.
        val inputLines = processedLines.mapIndexed { index, processedLine ->
            if (originalLines[index] == processedLine) mappedLines[index] else processedLine
        }
        val inputText = inputLines.joinToString("\n")

        val psiFileFactory = PsiFileFactory.getInstance(environment.project)
        val psiFile =
            psiFileFactory.createFileFromText(originalFile.language, inputText) as? PsiJavaFile ?: return inputText
        val pkg = psiFile.packageStatement?.packageReference?.resolve() as? PsiPackage

        val references = findOutgoingReferences(psiFile)

        val imports = psiFile.importList?.importStatements ?: emptyArray()
        // For `import a.b.*;` this holds "a.b" (the package or class whose members are imported).
        val onDemandContainers = imports.filter { it.isOnDemand }.mapNotNull { it.qualifiedName }.toSet()
        val existingImports = imports.filter { !it.isOnDemand }.mapNotNull { it.qualifiedName }.toSet()
        val unusedImports = existingImports.filter { simpleNameOf(it) !in references }.toSet()

        // Names in scope without any import: the file's own top-level classes, their member classes,
        // and the classes of the resolved package.
        val implicitReferenceSources = listOfNotNull(
            psiFile.classes.asIterable(),
            psiFile.classes.flatMap { it.allInnerClasses.asIterable() },
            pkg?.classes?.asIterable(),
        )
        val implicitReferences = implicitReferenceSources.flatten().mapNotNull { it.name }.toSet()
        val importedReferences = existingImports.map { simpleNameOf(it) }.toSet()
        val missingReferences = references - importedReferences - implicitReferences
        val newImports = missingReferences
            // Ambiguous (or unknown) simple names are left alone rather than guessed.
            .mapNotNull { shortClassNames[it].singleOrNull()?.qualifiedName }
            // An on-demand import only covers direct members of its container, not sub-packages.
            .filter { containerOf(it) !in onDemandContainers }
            // Only the package java.lang itself is implicitly imported (java.lang.reflect etc. are not).
            .filter { containerOf(it) != "java.lang" }
            // Classes of the file's own package never need an import.
            .filter { containerOf(it) != psiFile.packageName }

        val finalImports = existingImports - unusedImports + newImports + onDemandContainers.map { "$it.*" }

        val textBuilder = StringBuilder(inputText)

        // Remove every existing (non-static) import, back to front so earlier offsets stay valid.
        imports.map { it.textRange }.sortedByDescending { it.startOffset }.forEach { importRange ->
            textBuilder.replace(importRange.startOffset, importRange.endOffset, "")

            // Collapse the doubled whitespace left behind by the removed statement.
            val start = importRange.startOffset
            val whiteSpaceRange = start - 1..start
            if (whiteSpaceRange.first in textBuilder.indices && whiteSpaceRange.last in textBuilder.indices) {
                val whiteSpaceReplacement = when (textBuilder.substring(whiteSpaceRange)) {
                    "\n\n" -> "\n"
                    "\n " -> "\n"
                    " \n" -> "\n"
                    "  " -> " "
                    else -> null
                }
                if (whiteSpaceReplacement != null) {
                    textBuilder.replace(whiteSpaceRange.first, whiteSpaceRange.last + 1, whiteSpaceReplacement)
                }
            }
        }

        val startOfImports = psiFile.importList?.takeIf { it.textLength > 0 }?.startOffset
        val endOfPackage = psiFile.packageStatement?.endOffset ?: 0

        // Put back one newline per removed line so the line count is preserved; the new imports are
        // then written onto these blank lines.
        val removedLineCount = inputLines.size - textBuilder.lineSequence().count()
        var index = startOfImports ?: endOfPackage
        textBuilder.insert(index, "\n".repeat(removedLineCount))

        if (startOfImports == null) {
            // No import list existed: step forward (at most twice) while two newlines follow, so the
            // imports are placed after the blank separator line(s) following the package statement.
            repeat(2) {
                if (textBuilder.isFollowedByTwoNewlines(index)) {
                    index++
                }
            }
        }

        val javaImports = finalImports.filter { it.startsWith("java.") || it.startsWith("javax.") }.toSet()
        val otherImports = finalImports - javaImports
        val importGroups = listOf(otherImports, javaImports).filter { it.isNotEmpty() }

        for ((importGroupIndex, importGroup) in importGroups.withIndex()) {
            val hasMoreGroups = importGroupIndex + 1 in importGroups.indices

            for (qualifiedName in importGroup.sorted()) {
                val hasPrecedingStatement = index > 0 && textBuilder[index - 1] != '\n'
                // Checked before inserting: if a blank line follows, the next import goes onto it.
                val canAdvanceToNextLine = textBuilder.isFollowedByTwoNewlines(index)

                val str = (if (hasPrecedingStatement) " " else "") + "import $qualifiedName;"
                textBuilder.insert(index, str)
                index += str.length + if (canAdvanceToNextLine) 1 else 0
            }

            // Leave a blank line between groups when one is available.
            if (hasMoreGroups && textBuilder.isFollowedByTwoNewlines(index)) {
                index++
            }
        }

        return textBuilder.toString()
    }

    /**
     * Whether the two characters right after [index] both exist and are `\n`.
     * Out-of-range positions count as "not a newline" instead of throwing.
     */
    private fun CharSequence.isFollowedByTwoNewlines(index: Int): Boolean =
        getOrNull(index + 1) == '\n' && getOrNull(index + 2) == '\n'

    /** The part of [qualifiedName] before its last dot (package or outer class), or "" if it has none. */
    private fun containerOf(qualifiedName: String): String = qualifiedName.substringBeforeLast('.', "")

    /** The part of [qualifiedName] after its last dot, or the whole name if it has none. */
    private fun simpleNameOf(qualifiedName: String): String = qualifiedName.substringAfterLast('.')

    /**
     * Collects the simple names of all unqualified code references in [file] that start with an
     * upper-case letter and do not resolve to a type parameter or a variable.
     */
    private fun findOutgoingReferences(file: PsiJavaFile): Set<String> {
        val references = mutableSetOf<String>()

        fun recordReference(reference: PsiJavaCodeReferenceElement) {
            // A qualified reference needs no import for its last segment; its qualifier, when it is a
            // reference itself, is presumably reached separately by the recursive visitor.
            if (reference.isQualified) return
            val name = reference.referenceName ?: return
            // Heuristic: only capitalized names are treated as class names.
            if (name.firstOrNull()?.isUpperCase() != true) return
            val resolved = reference.resolve()
            if (resolved is PsiTypeParameter) return
            if (resolved is PsiVariable) return
            references.add(name)
        }

        file.accept(object : JavaRecursiveElementVisitor() {
            override fun visitReferenceElement(reference: PsiJavaCodeReferenceElement) {
                recordReference(reference)
                super.visitReferenceElement(reference)
            }
        })
        return references
    }
}