package moe.nea.ledger.database

import moe.nea.ledger.UUIDUtil
import java.sql.Connection
import java.sql.PreparedStatement
import java.sql.ResultSet
import java.time.Instant
import java.util.UUID

/** A database schema, described by the [Table]s it contains. */
interface DBSchema {
    val tables: List<Table>
}

/**
 * Converts a Kotlin value of type [T] to and from a single SQL column.
 *
 * The `index` passed to [get] and [set] is a JDBC index and therefore 1-based.
 */
interface DBType<T> {
    /** SQL type name used in `CREATE TABLE` / `ALTER TABLE ... ADD` column definitions. */
    val dbType: String

    /** Reads column [index] of the current row of [result]. */
    fun get(result: ResultSet, index: Int): T

    /** Binds [value] to parameter [index] of [stmt]. */
    fun set(stmt: PreparedStatement, index: Int, value: T)

    /** Display name of this type, used by [Table.debugSchema]. */
    fun getName(): String = javaClass.simpleName

    /**
     * Derives a [DBType] for [R] that is stored using this type's SQL representation.
     *
     * @param from converts an [R] into this type's [T] before it is written.
     * @param to converts a [T] read from the database back into an [R].
     */
    fun <R> mapped(
        from: (R) -> T,
        to: (T) -> R,
    ): DBType<R> {
        return object : DBType<R> {
            override fun getName(): String {
                return "Mapped(${this@DBType.getName()})"
            }

            override val dbType: String
                get() = this@DBType.dbType

            override fun get(result: ResultSet, index: Int): R {
                return to(this@DBType.get(result, index))
            }

            override fun set(stmt: PreparedStatement, index: Int, value: R) {
                this@DBType.set(stmt, index, from(value))
            }
        }
    }
}

/**
 * Stores a [UUID] as TEXT.
 *
 * NOTE(review): values are written with `UUID.toString()` (dashed form) but read back through
 * `UUIDUtil.parseDashlessUuid`; this only round-trips if that helper also accepts dashed input — confirm.
 */
object DBUuid : DBType<UUID> {
    override val dbType: String
        get() = "TEXT"

    override fun get(result: ResultSet, index: Int): UUID {
        return UUIDUtil.parseDashlessUuid(result.getString(index))
    }

    override fun set(stmt: PreparedStatement, index: Int, value: UUID) {
        stmt.setString(index, value.toString())
    }
}

/** Stores a [UUIDUtil.ULIDWrapper] as TEXT using its wrapped string. */
object DBUlid : DBType<UUIDUtil.ULIDWrapper> {
    override val dbType: String
        get() = "TEXT"

    override fun get(result: ResultSet, index: Int): UUIDUtil.ULIDWrapper {
        val text = result.getString(index)
        return UUIDUtil.ULIDWrapper(text)
    }

    override fun set(stmt: PreparedStatement, index: Int, value: UUIDUtil.ULIDWrapper) {
        stmt.setString(index, value.wrapped)
    }
}

/** Stores a [String] as TEXT. */
object DBString : DBType<String> {
    override val dbType: String
        get() = "TEXT"

    override fun get(result: ResultSet, index: Int): String {
        return result.getString(index)
    }

    override fun set(stmt: PreparedStatement, index: Int, value: String) {
        stmt.setString(index, value)
    }
}

/**
 * Stores an enum constant as TEXT using its [Enum.name].
 *
 * Reading a name that is not a constant of [type] throws (via `java.lang.Enum.valueOf`).
 */
class DBEnum<T : Enum<T>>(
    val type: Class<T>,
) : DBType<T> {
    override val dbType: String
        get() = "TEXT"

    override fun getName(): String {
        return "DBEnum(${type.simpleName})"
    }

    override fun set(stmt: PreparedStatement, index: Int, value: T) {
        stmt.setString(index, value.name)
    }

    override fun get(result: ResultSet, index: Int): T {
        val name = result.getString(index)
        return java.lang.Enum.valueOf(type, name)
    }

    companion object {
        /** Creates a [DBEnum] for the reified enum type [T]. */
        inline operator fun <reified T : Enum<T>> invoke(): DBEnum<T> {
            return DBEnum(T::class.java)
        }
    }
}

/** Stores a [Double] as DOUBLE. */
object DBDouble : DBType<Double> {
    override val dbType: String
        get() = "DOUBLE"

    override fun get(result: ResultSet, index: Int): Double {
        return result.getDouble(index)
    }

    override fun set(stmt: PreparedStatement, index: Int, value: Double) {
        stmt.setDouble(index, value)
    }
}

/** Stores a [Long] as INTEGER. */
object DBInt : DBType<Long> {
    override val dbType: String
        get() = "INTEGER"

    override fun get(result: ResultSet, index: Int): Long {
        return result.getLong(index)
    }

    override fun set(stmt: PreparedStatement, index: Int, value: Long) {
        stmt.setLong(index, value)
    }
}

/** Stores an [Instant] as INTEGER milliseconds since the epoch (sub-millisecond precision is dropped). */
object DBInstant : DBType<Instant> {
    override val dbType: String
        get() = "INTEGER"

    override fun set(stmt: PreparedStatement, index: Int, value: Instant) {
        stmt.setLong(index, value.toEpochMilli())
    }

    override fun get(result: ResultSet, index: Int): Instant {
        return Instant.ofEpochMilli(result.getLong(index))
    }
}

/** A typed column of a [Table]; create instances through `Table.column` so they get registered. */
class Column<T> @Deprecated("Use Table.column instead") constructor(val name: String, val type: DBType<T>) {
    /** The column name quoted with backticks for use in SQL. */
    val sqlName get() = "`$name`"
}

/** A table-level SQL constraint. */
interface Constraint {
    /** The columns this constraint refers to. */
    val affectedColumns: Collection<Column<*>>

    /** The constraint as it appears inside a `CREATE TABLE` column list. */
    fun asSQL(): String
}

/** A `UNIQUE (...)` constraint over one or more columns. */
class UniqueConstraint(val columns: List<Column<*>>) : Constraint {
    init {
        require(columns.isNotEmpty()) { "A unique constraint needs at least one column" }
    }

    override val affectedColumns: Collection<Column<*>>
        get() = columns

    override fun asSQL(): String {
        return "UNIQUE (${columns.joinToString { it.sqlName }})"
    }
}

/**
 * Base class for table definitions. Subclasses declare their columns with [column] and
 * constraints with [unique]; columns are kept in declaration order.
 */
abstract class Table(val name: String) {
    /** The table name quoted with backticks for use in SQL. */
    val sqlName get() = "`$name`"
    protected val _mutable_columns: MutableList<Column<*>> = mutableListOf()
    protected val _mutable_constraints: MutableList<Constraint> = mutableListOf()
    val columns: List<Column<*>> get() = _mutable_columns
    val constraints get() = _mutable_constraints

    /** Registers a [UniqueConstraint] over [columns]. */
    protected fun unique(vararg columns: Column<*>) {
        _mutable_constraints.add(UniqueConstraint(columns.toList()))
    }

    /** Declares and registers a new column named [name] stored as [type]. */
    protected fun <T> column(name: String, type: DBType<T>): Column<T> {
        @Suppress("DEPRECATION")
        val column = Column(name, type)
        _mutable_columns.add(column)
        return column
    }

    /** Prints an ASCII box listing this table's columns and their [DBType] names to stdout. */
    fun debugSchema() {
        // maxOfOrNull: a table without columns must not throw here
        val nameWidth = columns.maxOfOrNull { it.name.length } ?: 0
        val typeWidth = columns.maxOfOrNull { it.type.getName().length } ?: 0
        val totalWidth = maxOf(2 + nameWidth + 3 + typeWidth + 2, name.length + 4)
        val adjustedTypeWidth = totalWidth - nameWidth - 2 - 3 - 2
        val border = "+" + "-".repeat(totalWidth - 2) + "+"
        val tableName = name
        val text = buildString {
            append("\n")
            append(border).append("\n")
            append("| ").append(tableName).append(" ".repeat(totalWidth - 4 - tableName.length)).append(" |\n")
            append(border).append("\n")
            for (column in columns) {
                val typeName = column.type.getName()
                append("| ").append(column.name).append(" ".repeat(nameWidth - column.name.length)).append(" |")
                append(" ").append(typeName).append(" ".repeat(adjustedTypeWidth - typeName.length)).append(" |\n")
            }
            append(border)
        }
        println(text)
    }

    /**
     * Runs `CREATE TABLE IF NOT EXISTS` with [filteredColumns] (all columns by default).
     *
     * Only constraints whose affected columns are all part of [filteredColumns] are included.
     */
    fun createIfNotExists(
        connection: Connection,
        filteredColumns: List<Column<*>> = columns
    ) {
        val properties = mutableListOf<String>()
        for (column in filteredColumns) {
            properties.add("${column.sqlName} ${column.type.dbType}")
        }
        val columnSet = filteredColumns.toSet()
        for (constraint in constraints) {
            if (columnSet.containsAll(constraint.affectedColumns)) {
                properties.add(constraint.asSQL())
            }
        }
        connection.prepareAndLog("CREATE TABLE IF NOT EXISTS $sqlName (" + properties.joinToString() + ")")
            .closeAfter { it.execute() }
    }

    /**
     * Adds [newColumns] to the existing table, one `ALTER TABLE ... ADD` statement per column.
     *
     * Constraints are not added here.
     */
    fun alterTableAddColumns(
        connection: Connection,
        newColumns: List<Column<*>>
    ) {
        for (column in newColumns) {
            connection.prepareAndLog("ALTER TABLE $sqlName ADD ${column.sqlName} ${column.type.dbType}")
                .closeAfter { it.execute() }
        }
        // TODO: automatically add constraints, maybe (or maybe move constraints into the upgrade schema)
    }

    /** Conflict-resolution keyword used in `INSERT OR <keyword>`. */
    enum class OnConflict {
        FAIL,
        IGNORE,
        REPLACE,
        ;

        fun asSql(): String {
            return name
        }
    }

    /**
     * Inserts one row. [block] must assign a value to every column of this table, and to nothing
     * else; otherwise an [IllegalArgumentException] is thrown before any SQL is executed.
     */
    fun insert(connection: Connection, onConflict: OnConflict = OnConflict.FAIL, block: (InsertStatement) -> Unit) {
        val insert = InsertStatement(HashMap())
        block(insert)
        require(insert.properties.keys == columns.toSet()) {
            "Insert into $name must set exactly the columns ${columns.map { it.name }}, " +
                "but got ${insert.properties.keys.map { it.name }}"
        }
        val columnNames = columns.joinToString { it.sqlName }
        val valueNames = columns.joinToString { "?" }
        connection.prepareAndLog("INSERT OR ${onConflict.asSql()} INTO $sqlName ($columnNames) VALUES ($valueNames)")
            .closeAfter { statement ->
                for ((index, column) in columns.withIndex()) {
                    // values were stored by InsertStatement.set against the same column, so the types line up
                    @Suppress("UNCHECKED_CAST")
                    val anyColumn = column as Column<Any>
                    // JDBC parameter indices are 1-based
                    anyColumn.type.set(statement, index + 1, insert.properties.getValue(column))
                }
                statement.execute()
            }
    }

    /** Starts a query selecting every column of this table. */
    fun selectAll(connection: Connection): Query {
        return Query(connection, columns, this)
    }
}

/** Collects the column values for a single [Table.insert]. */
class InsertStatement(val properties: MutableMap<Column<*>, Any>) {
    operator fun <T : Any> set(key: Column<T>, value: T) {
        properties[key] = value
    }
}

/** Prints [statement] to stdout and prepares it on this connection. */
fun Connection.prepareAndLog(statement: String): PreparedStatement {
    println("Preparing to execute $statement")
    return prepareStatement(statement)
}

/** Runs [block] with this statement and closes the statement afterwards, even if [block] throws. */
private inline fun <R> PreparedStatement.closeAfter(block: (PreparedStatement) -> R): R {
    try {
        return block(this)
    } finally {
        close()
    }
}

/**
 * A `SELECT` over [selectedColumns] of [table], with optional `LIMIT`/`OFFSET`.
 *
 * Every call to [iterator] prepares and executes a fresh statement. Rows are read lazily. The
 * statement is closed once the rows are exhausted; an iterator abandoned early keeps it open.
 */
class Query(
    val connection: Connection,
    val selectedColumns: List<Column<*>>,
    var table: Table,
    var limit: UInt? = null,
    var skip: UInt? = null,
//    var order: OrderClause?= null,
//    val condition: List<SqlCondition>
) : Iterable<ResultRow> {
    /** Sets the row offset. [limit] must be set first, since `OFFSET` is only emitted together with `LIMIT`. */
    fun skip(skip: UInt): Query {
        require(limit != null) { "skip() requires limit() to be set first" }
        this.skip = skip
        return this
    }

    /** Sets the maximum number of rows returned. */
    fun limit(limit: UInt): Query {
        this.limit = limit
        return this
    }

    override fun iterator(): Iterator<ResultRow> {
        val columnSelections = selectedColumns.joinToString { it.sqlName }
        var query = "SELECT $columnSelections FROM ${table.sqlName} "
        if (limit != null) {
            query += "LIMIT $limit "
            if (skip != null) {
                query += "OFFSET $skip "
            }
        }
        val prepared = connection.prepareAndLog(query.trim())
        val results = try {
            prepared.executeQuery()
        } catch (e: Throwable) {
            // don't leak the statement if the query itself fails
            prepared.close()
            throw e
        }
        return object : Iterator<ResultRow> {
            private var hasAdvanced = false
            private var hasEnded = false

            override fun hasNext(): Boolean {
                if (hasEnded) return false
                if (hasAdvanced) return true
                if (results.next()) {
                    hasAdvanced = true
                    return true
                }
                hasEnded = true
                // all rows consumed: release the JDBC resources
                results.close()
                prepared.close()
                return false
            }

            override fun next(): ResultRow {
                if (!hasNext()) {
                    throw NoSuchElementException()
                }
                hasAdvanced = false
                // JDBC column indices are 1-based
                return ResultRow(selectedColumns.withIndex().associate {
                    it.value to it.value.type.get(results, it.index + 1)
                })
            }
        }
    }
}

/** One row of a [Query] result, keyed by the selected [Column]s. */
class ResultRow(val columnValues: Map<Column<*>, *>) {
    /** Returns the value of [column]; throws [IllegalStateException] if that column has no value in this row. */
    operator fun <T> get(column: Column<T>): T {
        val value = columnValues[column]
            ?: error("Invalid column ${column.name}. Only ${columnValues.keys.joinToString { it.name }} are available.")
        @Suppress("UNCHECKED_CAST")
        return value as T
    }
}