package moe.nea.ledger.database

import moe.nea.ledger.UUIDUtil
import java.sql.Connection
import java.sql.PreparedStatement
import java.sql.ResultSet
import java.time.Instant
import java.util.UUID

/**
 * A database schema: the list of [Table]s that make it up.
 */
interface DBSchema {
    val tables: List<Table>
}

/**
 * Converts values of type [T] to and from a single SQL column.
 *
 * The `index` passed to [get] and [set] is handed straight to JDBC, so it is 1-based.
 */
interface DBType<T> {
    /** The SQL type name used when declaring a column of this type (for example `TEXT` or `INTEGER`). */
    val dbType: String

    /** Reads the value at column [index] of the current row of [result]. */
    fun get(result: ResultSet, index: Int): T

    /** Binds [value] to the parameter at [index] of [stmt]. */
    fun set(stmt: PreparedStatement, index: Int, value: T)

    /** A human-readable name for this type; used by [Table.debugSchema]. Defaults to the class's simple name. */
    fun getName(): String = javaClass.simpleName

    /**
     * Derives a [DBType] for [R] that is stored using this type's column representation.
     *
     * @param from converts an [R] into the stored [T] when writing.
     * @param to converts a stored [T] back into an [R] when reading.
     * @return a type with the same [dbType] as this one, named `Mapped(<this name>)`.
     */
    fun <R> mapped(
        from: (R) -> T,
        to: (T) -> R,
    ): DBType<R> {
        return object : DBType<R> {
            override fun getName(): String {
                return "Mapped(${this@DBType.getName()})"
            }

            override val dbType: String
                get() = this@DBType.dbType

            override fun get(result: ResultSet, index: Int): R {
                return to(this@DBType.get(result, index))
            }

            override fun set(stmt: PreparedStatement, index: Int, value: R) {
                this@DBType.set(stmt, index, from(value))
            }
        }
    }
}

/**
 * Stores a [UUID] in a `TEXT` column.
 *
 * NOTE(review): values are written with [UUID.toString] (dashed form) but read back with
 * [UUIDUtil.parseDashlessUuid]; confirm that parser accepts the dashed form, otherwise values
 * written here will not round-trip.
 */
object DBUuid : DBType<UUID> {
    override val dbType: String
        get() = "TEXT"

    override fun get(result: ResultSet, index: Int): UUID {
        return UUIDUtil.parseDashlessUuid(result.getString(index))
    }

    override fun set(stmt: PreparedStatement, index: Int, value: UUID) {
        stmt.setString(index, value.toString())
    }
}

/** Stores a [UUIDUtil.ULIDWrapper] in a `TEXT` column, using its wrapped string. */
object DBUlid : DBType<UUIDUtil.ULIDWrapper> {
    override val dbType: String
        get() = "TEXT"

    override fun get(result: ResultSet, index: Int): UUIDUtil.ULIDWrapper {
        val text = result.getString(index)
        return UUIDUtil.ULIDWrapper(text)
    }

    override fun set(stmt: PreparedStatement, index: Int, value: UUIDUtil.ULIDWrapper) {
        stmt.setString(index, value.wrapped)
    }
}

/** Stores a [String] in a `TEXT` column. */
object DBString : DBType<String> {
    override val dbType: String
        get() = "TEXT"

    override fun get(result: ResultSet, index: Int): String {
        return result.getString(index)
    }

    override fun set(stmt: PreparedStatement, index: Int, value: String) {
        stmt.setString(index, value)
    }
}

/**
 * Stores an enum constant of [type] in a `TEXT` column, using its [Enum.name].
 *
 * Reading a name that is not a constant of [type] throws, via `java.lang.Enum.valueOf`.
 */
class DBEnum<T : Enum<T>>(
    val type: Class<T>,
) : DBType<T> {
    override val dbType: String
        get() = "TEXT"

    override fun getName(): String {
        return "DBEnum(${type.simpleName})"
    }

    override fun set(stmt: PreparedStatement, index: Int, value: T) {
        stmt.setString(index, value.name)
    }

    override fun get(result: ResultSet, index: Int): T {
        val name = result.getString(index)
        return java.lang.Enum.valueOf(type, name)
    }

    companion object {
        /** Creates a [DBEnum] for the reified enum type [T]. */
        inline operator fun <reified T : Enum<T>> invoke(): DBEnum<T> {
            return DBEnum(T::class.java)
        }
    }
}

/** Stores a [Double] in a `DOUBLE` column. */
object DBDouble : DBType<Double> {
    override val dbType: String
        get() = "DOUBLE"

    override fun get(result: ResultSet, index: Int): Double {
        return result.getDouble(index)
    }

    override fun set(stmt: PreparedStatement, index: Int, value: Double) {
        stmt.setDouble(index, value)
    }
}

/** Stores a [Long] in an `INTEGER` column (note that the Kotlin type is [Long], not [Int]). */
object DBInt : DBType<Long> {
    override val dbType: String
        get() = "INTEGER"

    override fun get(result: ResultSet, index: Int): Long {
        return result.getLong(index)
    }

    override fun set(stmt: PreparedStatement, index: Int, value: Long) {
        stmt.setLong(index, value)
    }
}

/** Stores an [Instant] in an `INTEGER` column as milliseconds since the epoch (sub-millisecond precision is lost). */
object DBInstant : DBType<Instant> {
    override val dbType: String
        get() = "INTEGER"

    override fun set(stmt: PreparedStatement, index: Int, value: Instant) {
        stmt.setLong(index, value.toEpochMilli())
    }

    override fun get(result: ResultSet, index: Int): Instant {
        return Instant.ofEpochMilli(result.getLong(index))
    }
}

/**
 * A named, typed column belonging to [table].
 *
 * Columns are compared by identity (this is not a data class), which is what the column-keyed maps
 * in [InsertStatement] and [ResultRow] rely on.
 */
class Column<T> @Deprecated("Use Table.column instead") constructor(
    val table: Table,
    val name: String,
    val type: DBType<T>,
) {
    /** The column name quoted with backticks. */
    val sqlName get() = "`$name`"

    /** The quoted column name prefixed with the quoted table name, e.g. `` `table`.`column` ``. */
    val qualifiedSqlName get() = table.sqlName + "." + sqlName
}

/** A table-level constraint emitted into `CREATE TABLE`. */
interface Constraint {
    /** The columns this constraint refers to; it is only emitted if all of them are being created. */
    val affectedColumns: Collection<Column<*>>

    /** The SQL fragment for this constraint inside a `CREATE TABLE (...)` column list. */
    fun asSQL(): String
}

/** A `UNIQUE (...)` constraint over one or more [columns]. */
class UniqueConstraint(val columns: List<Column<*>>) : Constraint {
    init {
        require(columns.isNotEmpty()) { "A UNIQUE constraint needs at least one column" }
    }

    override val affectedColumns: Collection<Column<*>>
        get() = columns

    override fun asSQL(): String {
        return "UNIQUE (${columns.joinToString { it.sqlName }})"
    }
}

/**
 * Base class for table definitions.
 *
 * Subclasses declare their columns with [column] and their constraints with [unique].
 */
abstract class Table(val name: String) {
    /** The table name quoted with backticks. */
    val sqlName get() = "`$name`"

    protected val _mutable_columns: MutableList<Column<*>> = mutableListOf()
    protected val _mutable_constraints: MutableList<Constraint> = mutableListOf()

    /** The declared columns, in declaration order. */
    val columns: List<Column<*>> get() = _mutable_columns

    /** The declared constraints, in declaration order. */
    val constraints get() = _mutable_constraints

    /** Declares a [UniqueConstraint] over [columns]. */
    protected fun unique(vararg columns: Column<*>) {
        _mutable_constraints.add(UniqueConstraint(columns.toList()))
    }

    /** Declares and registers a new column of this table. */
    protected fun <T> column(name: String, type: DBType<T>): Column<T> {
        @Suppress("DEPRECATION")
        val column = Column(this, name, type)
        _mutable_columns.add(column)
        return column
    }

    /** Prints an ASCII-box rendering of this table's name, columns and column types to stdout. */
    fun debugSchema() {
        // maxOfOrNull: a table without columns would otherwise throw NoSuchElementException.
        val nameWidth = columns.maxOfOrNull { it.name.length } ?: 0
        val typeWidth = columns.maxOfOrNull { it.type.getName().length } ?: 0
        // Row layout: "| " + name + " |" + " " + type + " |"; the header must fit "| " + table name + " |".
        val totalWidth = maxOf(2 + nameWidth + 3 + typeWidth + 2, name.length + 4)
        val adjustedTypeWidth = totalWidth - nameWidth - 2 - 3 - 2
        val separator = "+" + "-".repeat(totalWidth - 2) + "+"
        val rendered = buildString {
            append("\n")
            append(separator).append("\n")
            append("| ").append(name.padEnd(totalWidth - 4)).append(" |\n")
            append(separator).append("\n")
            for (column in columns) {
                append("| ").append(column.name.padEnd(nameWidth)).append(" |")
                append(" ").append(column.type.getName().padEnd(adjustedTypeWidth)).append(" |\n")
            }
            append(separator)
        }
        println(rendered)
    }

    /**
     * Runs `CREATE TABLE IF NOT EXISTS` for this table.
     *
     * @param filteredColumns the columns to create; defaults to all [columns]. Constraints are only
     *   included if every column they affect is in this list.
     */
    fun createIfNotExists(
        connection: Connection,
        filteredColumns: List<Column<*>> = columns,
    ) {
        val columnSet = filteredColumns.toSet()
        val columnDefinitions = filteredColumns.map { "${it.sqlName} ${it.type.dbType}" }
        val constraintDefinitions = constraints
            .filter { columnSet.containsAll(it.affectedColumns) }
            .map { it.asSQL() }
        val properties = columnDefinitions + constraintDefinitions
        connection.prepareAndLog("CREATE TABLE IF NOT EXISTS $sqlName (" + properties.joinToString() + ")")
            .use { it.execute() }
    }

    /**
     * Adds each of [newColumns] to the existing table with one `ALTER TABLE ... ADD` statement per column.
     *
     * Constraints are not added by this method.
     */
    fun alterTableAddColumns(
        connection: Connection,
        newColumns: List<Column<*>>,
    ) {
        for (column in newColumns) {
            connection.prepareAndLog("ALTER TABLE $sqlName ADD ${column.sqlName} ${column.type.dbType}")
                .use { it.execute() }
        }
        // TODO: automatically add constraints, maybe (or maybe move constraints into the upgrade schema)
    }

    /** Conflict resolution keyword used in `INSERT OR <keyword>`. */
    enum class OnConflict {
        FAIL,
        IGNORE,
        REPLACE,
        ;

        fun asSql(): String {
            return name
        }
    }

    /**
     * Inserts a single row whose values are supplied by [block].
     *
     * @param onConflict the `INSERT OR ...` conflict resolution to use.
     * @param block must set a value for every column of this table, and only for columns of this table.
     * @throws IllegalArgumentException if [block] did not set exactly this table's columns.
     */
    fun insert(connection: Connection, onConflict: OnConflict = OnConflict.FAIL, block: (InsertStatement) -> Unit) {
        val insert = InsertStatement(HashMap())
        block(insert)
        require(insert.properties.keys == columns.toSet()) {
            "Insert into $name must set exactly the columns ${columns.joinToString { it.name }}, " +
                "but set ${insert.properties.keys.joinToString { it.name }}"
        }
        val columnNames = columns.joinToString { it.sqlName }
        val valueNames = columns.joinToString { "?" }
        connection.prepareAndLog("INSERT OR ${onConflict.asSql()} INTO $sqlName ($columnNames) VALUES ($valueNames)")
            .use { statement ->
                for ((index, column) in columns.withIndex()) {
                    // Safe: InsertStatement.set only accepts a value of the column's own type.
                    @Suppress("UNCHECKED_CAST")
                    val typedColumn = column as Column<Any>
                    typedColumn.type.set(statement, index + 1, insert.properties.getValue(column))
                }
                statement.execute()
            }
    }

    /** Starts a [Query] on this table with no columns selected yet. */
    fun from(connection: Connection): Query {
        return Query(connection, mutableListOf(), this)
    }

    /** Starts a [Query] on this table that selects all of its [columns]. */
    fun selectAll(connection: Connection): Query {
        return Query(connection, columns.toMutableList(), this)
    }
}

/** Collects the column values for a single [Table.insert]. */
class InsertStatement(val properties: MutableMap<Column<*>, Any>) {
    operator fun <T : Any> set(key: Column<T>, value: T) {
        properties[key] = value
    }
}

/** Prints [statement] to stdout, then prepares it on this connection. */
fun Connection.prepareAndLog(statement: String): PreparedStatement {
    println("Preparing to execute $statement")
    return prepareStatement(statement)
}

/** A fragment of SQL together with the parameter values it binds. */
interface SQLQueryComponent {
    fun asSql(): String

    /**
     * @return the next writable index (should equal to the amount of `?` in [asSql] + [startIndex])
     */
    fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int

    companion object {
        /** A component consisting of literal SQL that binds no parameters. */
        fun standalone(sql: String): SQLQueryComponent {
            return object : SQLQueryComponent {
                override fun asSql(): String {
                    return sql
                }

                override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int {
                    return startIndex
                }
            }
        }
    }
}

/** Binds the parameters of [components] in order, starting at [startIndex]; returns the next free index. */
private fun appendAllToStatement(
    components: List<SQLQueryComponent>,
    stmt: PreparedStatement,
    startIndex: Int,
): Int = components.fold(startIndex) { index, component -> component.appendToStatement(stmt, index) }

/** A component that evaluates to a boolean, usable in `WHERE` / `ON`. */
interface BooleanExpression : SQLQueryComponent

/** Disjunction of one or more [elements], rendered as `(a OR b OR ... OR FALSE)`. */
data class ORExpression(
    val elements: List<BooleanExpression>,
) : BooleanExpression {
    init {
        require(elements.isNotEmpty()) { "ORExpression needs at least one element" }
    }

    override fun asSql(): String {
        // The trailing FALSE is the identity for OR and does not change the result.
        return (elements + SQLQueryComponent.standalone("FALSE")).joinToString(" OR ", "(", ")") { it.asSql() }
    }

    override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int {
        return appendAllToStatement(elements, stmt, startIndex)
    }
}

/** Conjunction of one or more [elements], rendered as `(a AND b AND ... AND TRUE)`. */
data class ANDExpression(
    val elements: List<BooleanExpression>,
) : BooleanExpression {
    init {
        require(elements.isNotEmpty()) { "ANDExpression needs at least one element" }
    }

    override fun asSql(): String {
        // The trailing TRUE is the identity for AND and does not change the result.
        return (elements + SQLQueryComponent.standalone("TRUE")).joinToString(" AND ", "(", ")") { it.asSql() }
    }

    override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int {
        return appendAllToStatement(elements, stmt, startIndex)
    }
}

/** Receiver for the [Clause] builder DSL. */
class ClauseBuilder {
    fun <T> column(column: Column<T>): Operand<T> = Operand.ColumnOperand(column)
    fun string(string: String): Operand.StringOperand = Operand.StringOperand(string)
    infix fun Operand<*>.eq(operand: Operand<*>) = Clause.EqualsClause(this, operand)
    infix fun Operand<*>.like(op: Operand.StringOperand) = Clause.LikeClause(this, op)
    infix fun Operand<*>.like(op: String) = Clause.LikeClause(this, string(op))
}

/** A boolean condition built from [Operand]s. */
interface Clause : BooleanExpression {
    companion object {
        /** Builds a clause with the [ClauseBuilder] DSL, e.g. `Clause { column(a) eq column(b) }`. */
        operator fun invoke(builder: ClauseBuilder.() -> Clause): Clause {
            return builder(ClauseBuilder())
        }
    }

    /** `left = right`. */
    data class EqualsClause(val left: Operand<*>, val right: Operand<*>) : Clause {
        // TODO: typecheck this somehow
        override fun asSql(): String {
            return left.asSql() + " = " + right.asSql()
        }

        override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int {
            return appendAllToStatement(listOf(left, right), stmt, startIndex)
        }
    }

    /** `(left LIKE right)`, where [right] is a bound string pattern. */
    data class LikeClause<T>(val left: Operand<T>, val right: Operand.StringOperand) : Clause {
        // TODO: check type safety with this one
        override fun asSql(): String {
            return "(" + left.asSql() + " LIKE " + right.asSql() + ")"
        }

        override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int {
            return appendAllToStatement(listOf(left, right), stmt, startIndex)
        }
    }
}

/** A value usable inside a [Clause]: either a column reference or a bound parameter. */
interface Operand<T> : SQLQueryComponent {
    /** References [column] by its qualified name; binds no parameters. */
    data class ColumnOperand<T>(val column: Column<T>) : Operand<T> {
        override fun asSql(): String {
            return column.qualifiedSqlName
        }

        override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int {
            return startIndex
        }
    }

    /** A string literal passed as a `?` parameter rather than spliced into the SQL. */
    data class StringOperand(val value: String) : Operand<String> {
        override fun asSql(): String {
            return "?"
        }

        override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int {
            stmt.setString(startIndex, value)
            return 1 + startIndex
        }
    }
}

/** A `JOIN <table> ON <filter>` component. */
data class Join(
    val table: Table,
    // TODO: aliased columns (e.g. a `val tableAlias: String` parameter)
    val filter: Clause,
) : SQLQueryComponent {
    // e.g. JOIN ItemEntry ON LogEntry.transactionId = ItemEntry.transactionId
    override fun asSql(): String {
        return "JOIN ${table.sqlName} ON ${filter.asSql()}"
    }

    override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int {
        return filter.appendToStatement(stmt, startIndex)
    }
}

/**
 * Joins the SQL of these components with spaces, prepares it on [connection], and binds every
 * component's parameters in order starting at index 1.
 *
 * The caller owns the returned statement and must close it.
 *
 * @throws IllegalStateException if a component reports a next index lower than the one it was given.
 */
fun List<SQLQueryComponent>.concatToFilledPreparedStatement(connection: Connection): PreparedStatement {
    val query = joinToString(" ") { it.asSql() }
    val statement = connection.prepareAndLog(query)
    try {
        var index = 1
        for (element in this) {
            val nextIndex = element.appendToStatement(statement, index)
            if (nextIndex < index) error("$element went back in time")
            index = nextIndex
        }
    } catch (e: Throwable) {
        // The caller never receives the statement in this case, so close it here to avoid a leak.
        statement.close()
        throw e
    }
    return statement
}

/**
 * A mutable `SELECT` query builder on [table].
 *
 * Each call to [iterator] prepares and executes the query anew.
 */
class Query(
    val connection: Connection,
    val selectedColumns: MutableList<Column<*>>,
    var table: Table,
    var limit: UInt? = null,
    var skip: UInt? = null,
    val joins: MutableList<Join> = mutableListOf(),
    val conditions: MutableList<BooleanExpression> = mutableListOf(),
    // var order: OrderClause?= null,
) : Iterable<ResultRow> {
    /** Adds `JOIN [table] ON [on]`. */
    fun join(table: Table, on: Clause): Query {
        joins.add(Join(table, on))
        return this
    }

    /** Adds a condition; all conditions are combined with AND. */
    fun where(binOp: BooleanExpression): Query {
        conditions.add(binOp)
        return this
    }

    /** Appends [columns] to the selected columns. */
    fun select(vararg columns: Column<*>): Query {
        selectedColumns.addAll(columns)
        return this
    }

    /**
     * Sets the row `OFFSET`.
     *
     * @throws IllegalArgumentException if [limit] has not been set; OFFSET is only emitted together with LIMIT.
     */
    fun skip(skip: UInt): Query {
        require(limit != null) { "skip() requires limit() to be set first" }
        this.skip = skip
        return this
    }

    /** Sets the row `LIMIT`. */
    fun limit(limit: UInt): Query {
        this.limit = limit
        return this
    }

    /**
     * Executes the query and iterates over its rows.
     *
     * The result set and statement are closed once the iterator is exhausted. An iterator that is
     * abandoned early keeps them open.
     */
    override fun iterator(): Iterator<ResultRow> {
        val prepared = buildComponents().concatToFilledPreparedStatement(connection)
        val results = try {
            prepared.executeQuery()
        } catch (e: Throwable) {
            prepared.close()
            throw e
        }
        return object : Iterator<ResultRow> {
            var hasAdvanced = false
            var hasEnded = false

            override fun hasNext(): Boolean {
                if (hasEnded) return false
                if (hasAdvanced) return true
                if (results.next()) {
                    hasAdvanced = true
                    return true
                }
                // TODO: somehow enforce closing this when iteration stops early
                results.close()
                prepared.close()
                hasEnded = true
                return false
            }

            override fun next(): ResultRow {
                if (!hasNext()) {
                    throw NoSuchElementException()
                }
                hasAdvanced = false
                return ResultRow(
                    selectedColumns.withIndex().associate { (index, column) ->
                        column to column.type.get(results, index + 1)
                    },
                )
            }
        }
    }

    /** Assembles `SELECT ... FROM ... [JOIN ...] [WHERE ...] [LIMIT ... [OFFSET ...]]`. */
    private fun buildComponents(): List<SQLQueryComponent> {
        val columnSelections = selectedColumns.joinToString { it.qualifiedSqlName }
        val elements = mutableListOf<SQLQueryComponent>(
            SQLQueryComponent.standalone("SELECT $columnSelections FROM ${table.sqlName}"),
        )
        elements.addAll(joins)
        if (conditions.any()) {
            elements.add(SQLQueryComponent.standalone("WHERE"))
            elements.add(ANDExpression(conditions))
        }
        if (limit != null) {
            elements.add(SQLQueryComponent.standalone("LIMIT $limit"))
            if (skip != null) {
                elements.add(SQLQueryComponent.standalone("OFFSET $skip"))
            }
        }
        return elements
    }
}

/** One row of a [Query] result, keyed by the selected [Column]s. */
class ResultRow(val columnValues: Map<Column<*>, *>) {
    /**
     * Returns the value of [column] in this row.
     *
     * @throws IllegalStateException if [column] was not selected by the query.
     */
    operator fun <T> get(column: Column<T>): T {
        // Check key presence rather than using `?:`, so that a selected column whose decoded value is
        // null (e.g. from a DBType.mapped with a nullable result) is not misreported as missing.
        if (!columnValues.containsKey(column)) {
            error("Invalid column ${column.name}. Only ${columnValues.keys.joinToString { it.name }} are available.")
        }
        @Suppress("UNCHECKED_CAST")
        return columnValues[column] as T
    }
}