aboutsummaryrefslogtreecommitdiff
path: root/symbols/src/main/kotlin/process/SubscribeAnnotationProcessor.kt
blob: 6d88b69875dd48a86cd8d3e182b06337d8dccb4f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
/*
 * SPDX-FileCopyrightText: 2024 Linnea Gräf <nea@nea.moe>
 *
 * SPDX-License-Identifier: GPL-3.0-or-later
 */

package moe.nea.firmament.annotations.process

import com.google.auto.service.AutoService
import com.google.devtools.ksp.processing.CodeGenerator
import com.google.devtools.ksp.processing.Dependencies
import com.google.devtools.ksp.processing.KSPLogger
import com.google.devtools.ksp.processing.Resolver
import com.google.devtools.ksp.processing.SymbolProcessor
import com.google.devtools.ksp.processing.SymbolProcessorEnvironment
import com.google.devtools.ksp.processing.SymbolProcessorProvider
import com.google.devtools.ksp.symbol.ClassKind
import com.google.devtools.ksp.symbol.KSAnnotated
import com.google.devtools.ksp.symbol.KSClassDeclaration
import com.google.devtools.ksp.symbol.KSFunctionDeclaration
import com.google.devtools.ksp.symbol.KSType
import com.google.devtools.ksp.symbol.Nullability
import com.google.devtools.ksp.validate
import java.text.SimpleDateFormat
import java.util.Date
import moe.nea.firmament.annotations.Subscribe

/**
 * KSP processor that collects functions annotated with [Subscribe] and generates the
 * `moe.nea.firmament.annotations.generated.AllSubscriptions` object, whose
 * `provideSubscriptions` function hands one `Subscription` per collected function to a callback.
 *
 * A function is collected only if it is a non-abstract function declared directly inside an
 * `object` and takes exactly one non-nullable parameter. Everything else is reported as an
 * error through [logger].
 *
 * @property logger used to report misuse of the annotation
 * @property codeGenerator used to create the generated source file
 */
class SubscribeAnnotationProcessor(
    val logger: KSPLogger,
    val codeGenerator: CodeGenerator,
) : SymbolProcessor {
    /**
     * Writes the `AllSubscriptions` file from everything gathered in [subscriptions].
     *
     * The subscriptions are sorted first so that the generated output is deterministic.
     */
    override fun finish() {
        subscriptions.sort()
        // Insertion-ordered set: the files appear in the order of the sorted subscriptions.
        val subscriptionSet = subscriptions.mapTo(mutableSetOf()) { it.parent.containingFile!! }
        val dependencies = Dependencies(
            aggregating = true,
            *subscriptionSet.toTypedArray())
        // `use` guarantees the writer is closed even if emitting a line throws.
        codeGenerator
            .createNewFile(dependencies, "moe.nea.firmament.annotations.generated", "AllSubscriptions")
            .bufferedWriter()
            .use { writer ->
                writer.appendLine("// This file is @generated by SubscribeAnnotationProcessor")
                writer.appendLine("// Do not edit")
                for (file in subscriptionSet) {
                    writer.appendLine("// Dependency: ${file.filePath}")
                }
                writer.appendLine("package moe.nea.firmament.annotations.generated")
                writer.appendLine()
                writer.appendLine("import moe.nea.firmament.events.subscription.*")
                writer.appendLine()
                writer.appendLine("object AllSubscriptions {")
                writer.appendLine("  fun provideSubscriptions(addSubscription: (Subscription<*>) -> Unit) {")
                for (subscription in subscriptions) {
                    val owner = subscription.parent.qualifiedName!!.asString()
                    val method = subscription.child.simpleName.asString()
                    val type = subscription.type.declaration.qualifiedName!!.asString()
                    writer.appendLine("    addSubscription(Subscription<$type>(")
                    writer.appendLine("        ${owner},")
                    writer.appendLine("        ${owner}::${method},")
                    writer.appendLine("        ${type}))")
                }
                writer.appendLine("  }")
                writer.appendLine("}")
            }
    }

    /**
     * One collected `@Subscribe` function.
     *
     * @property parent the `object` declaration that contains the function
     * @property child the annotated function itself
     * @property type the resolved type of the function's single parameter
     */
    data class Subscription(
        val parent: KSClassDeclaration,
        val child: KSFunctionDeclaration,
        val type: KSType,
    ) : Comparable<Subscription> {
        /**
         * Orders by the owner's qualified name (ascending), then by function name and by the
         * parameter type's qualified name.
         *
         * The two tie-breakers compare `other` against `this`, i.e. they sort in descending
         * order. That is kept as-is so the order of the generated output does not change.
         */
        override fun compareTo(other: Subscription): Int {
            val byOwner = parent.qualifiedName!!.asString().compareTo(other.parent.qualifiedName!!.asString())
            if (byOwner != 0) return byOwner
            val byMethod = other.child.simpleName.asString().compareTo(child.simpleName.asString())
            if (byMethod != 0) return byMethod
            return other.type.declaration.qualifiedName!!.asString()
                .compareTo(type.declaration.qualifiedName!!.asString())
        }
    }

    /** Every subscription accepted so far; filled by [processCandidates], consumed by [finish]. */
    val subscriptions = mutableListOf<Subscription>()

    /**
     * Checks each annotated element and records the acceptable ones in [subscriptions].
     *
     * Elements that fail a check are reported via [logger] and skipped.
     *
     * @param list the annotated symbols to inspect
     */
    fun processCandidates(list: List<KSAnnotated>) {
        for (element in list) {
            if (element !is KSFunctionDeclaration) {
                logger.error("@Subscribe annotation on a not-function", element)
                continue
            }
            if (element.isAbstract) {
                logger.error("@Subscribe annotation on an abstract function", element)
                continue
            }
            val parent = element.parentDeclaration
            if (parent !is KSClassDeclaration || parent.classKind != ClassKind.OBJECT) {
                logger.error("@Subscribe on a non-object", element)
                continue
            }
            val param = element.parameters.singleOrNull()
            if (param == null) {
                logger.error("@Subscribe annotated functions need to take exactly one parameter", element)
                continue
            }
            val type = param.type.resolve()
            if (type.nullability != Nullability.NOT_NULL) {
                logger.error("@Subscribe annotated functions cannot take a nullable event", element)
                continue
            }
            subscriptions.add(Subscription(parent, element, type))
        }
    }

    /**
     * Looks up all symbols annotated with [Subscribe], processes those that validate and returns
     * the rest.
     *
     * @return the symbols for which `validate()` returned `false`
     */
    override fun process(resolver: Resolver): List<KSAnnotated> {
        // Split in a single pass so validate() runs only once per symbol.
        val (valid, invalid) = resolver
            .getSymbolsWithAnnotation(Subscribe::class.qualifiedName!!)
            .partition { it.validate() }
        processCandidates(valid)
        return invalid
    }
}

/**
 * Provider that builds a [SubscribeAnnotationProcessor] from the logger and code generator of
 * the supplied environment.
 */
@AutoService(SymbolProcessorProvider::class)
class SubscribeAnnotationProcessorProvider : SymbolProcessorProvider {
    override fun create(environment: SymbolProcessorEnvironment): SymbolProcessor =
        SubscribeAnnotationProcessor(
            logger = environment.logger,
            codeGenerator = environment.codeGenerator,
        )
}