aboutsummaryrefslogtreecommitdiff
path: root/src/main/kotlin/moe/nea/ledger/database/DBSchema.kt
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/kotlin/moe/nea/ledger/database/DBSchema.kt')
-rw-r--r--src/main/kotlin/moe/nea/ledger/database/DBSchema.kt122
1 files changed, 116 insertions, 6 deletions
diff --git a/src/main/kotlin/moe/nea/ledger/database/DBSchema.kt b/src/main/kotlin/moe/nea/ledger/database/DBSchema.kt
index 5c9099c..dee99e4 100644
--- a/src/main/kotlin/moe/nea/ledger/database/DBSchema.kt
+++ b/src/main/kotlin/moe/nea/ledger/database/DBSchema.kt
@@ -1,10 +1,11 @@
package moe.nea.ledger.database
+import moe.nea.ledger.UUIDUtil
import java.sql.Connection
import java.sql.PreparedStatement
import java.sql.ResultSet
-import java.sql.Timestamp
import java.time.Instant
+import java.util.UUID
interface DBSchema {
val tables: List<Table>
@@ -16,6 +17,54 @@ interface DBType<T> {
fun get(result: ResultSet, index: Int): T
fun set(stmt: PreparedStatement, index: Int, value: T)
fun getName(): String = javaClass.simpleName
+ fun <R> mapped(
+ from: (R) -> T,
+ to: (T) -> R,
+ ): DBType<R> {
+ return object : DBType<R> {
+ override fun getName(): String {
+ return "Mapped(${this@DBType.getName()})"
+ }
+
+ override val dbType: String
+ get() = this@DBType.dbType
+
+ override fun get(result: ResultSet, index: Int): R {
+ return to(this@DBType.get(result, index))
+ }
+
+ override fun set(stmt: PreparedStatement, index: Int, value: R) {
+ this@DBType.set(stmt, index, from(value))
+ }
+ }
+ }
+}
+
+object DBUuid : DBType<UUID> {
+ override val dbType: String
+ get() = "TEXT"
+
+ override fun get(result: ResultSet, index: Int): UUID {
+ return UUIDUtil.parseDashlessUuid(result.getString(index))
+ }
+
+ override fun set(stmt: PreparedStatement, index: Int, value: UUID) {
+ stmt.setString(index, value.toString())
+ }
+}
+
+object DBUlid : DBType<UUIDUtil.ULIDWrapper> {
+ override val dbType: String
+ get() = "TEXT"
+
+ override fun get(result: ResultSet, index: Int): UUIDUtil.ULIDWrapper {
+ val text = result.getString(index)
+ return UUIDUtil.ULIDWrapper(text)
+ }
+
+ override fun set(stmt: PreparedStatement, index: Int, value: UUIDUtil.ULIDWrapper) {
+ stmt.setString(index, value.wrapped)
+ }
}
object DBString : DBType<String> {
@@ -31,6 +80,45 @@ object DBString : DBType<String> {
}
}
+class DBEnum<T : Enum<T>>(
+ val type: Class<T>,
+) : DBType<T> {
+ companion object {
+ inline operator fun <reified T : Enum<T>> invoke(): DBEnum<T> {
+ return DBEnum(T::class.java)
+ }
+ }
+
+ override val dbType: String
+ get() = "TEXT"
+
+ override fun getName(): String {
+ return "DBEnum(${type.simpleName})"
+ }
+
+ override fun set(stmt: PreparedStatement, index: Int, value: T) {
+ stmt.setString(index, value.name)
+ }
+
+ override fun get(result: ResultSet, index: Int): T {
+ val name = result.getString(index)
+ return java.lang.Enum.valueOf(type, name)
+ }
+}
+
+object DBDouble : DBType<Double> {
+ override val dbType: String
+ get() = "DOUBLE"
+
+ override fun get(result: ResultSet, index: Int): Double {
+ return result.getDouble(index)
+ }
+
+ override fun set(stmt: PreparedStatement, index: Int, value: Double) {
+ stmt.setDouble(index, value)
+ }
+}
+
object DBInt : DBType<Long> {
override val dbType: String
get() = "INTEGER"
@@ -57,12 +145,12 @@ object DBInstant : DBType<Instant> {
}
}
-// TODO: add table
class Column<T> @Deprecated("Use Table.column instead") constructor(val name: String, val type: DBType<T>) {
val sqlName get() = "`$name`"
}
interface Constraint {
+ val affectedColumns: Collection<Column<*>>
fun asSQL(): String
}
@@ -71,6 +159,9 @@ class UniqueConstraint(val columns: List<Column<*>>) : Constraint {
require(columns.isNotEmpty())
}
+ override val affectedColumns: Collection<Column<*>>
+ get() = columns
+
override fun asSQL(): String {
return "UNIQUE (${columns.joinToString() { it.sqlName }})"
}
@@ -111,18 +202,37 @@ abstract class Table(val name: String) {
println(string)
}
- fun createIfNotExists(connection: Connection) {
+ fun createIfNotExists(
+ connection: Connection,
+ filteredColumns: List<Column<*>> = columns
+ ) {
val properties = mutableListOf<String>()
- for (column in columns) {
+ for (column in filteredColumns) {
properties.add("${column.sqlName} ${column.type.dbType}")
}
+ val columnSet = filteredColumns.toSet()
for (constraint in constraints) {
- properties.add(constraint.asSQL())
+ if (columnSet.containsAll(constraint.affectedColumns)) {
+ properties.add(constraint.asSQL())
+ }
}
- connection.prepareAndLog("CREATE TABLE IF NOT EXISTS `$name` (" + properties.joinToString() + ")")
+ connection.prepareAndLog("CREATE TABLE IF NOT EXISTS $sqlName (" + properties.joinToString() + ")")
.execute()
}
+ fun alterTableAddColumns(
+ connection: Connection,
+ newColumns: List<Column<*>>
+ ) {
+ for (column in newColumns) {
+ connection.prepareAndLog("ALTER TABLE $sqlName ADD ${column.sqlName} ${column.type.dbType}")
+ .execute()
+ }
+ for (constraint in constraints) {
+ // TODO: automatically add constraints, maybe (or maybe move constraints into the upgrade schema)
+ }
+ }
+
enum class OnConflict {
FAIL,
IGNORE,