package moe.nea.ledger.database
import moe.nea.ledger.UUIDUtil
import java.sql.Connection
import java.sql.PreparedStatement
import java.sql.ResultSet
import java.time.Instant
import java.util.UUID
/**
 * A database schema: the set of [Table] definitions it is made of.
 */
interface DBSchema {
    /** Every table belonging to this schema. */
    val tables: List<Table>
}
/**
 * Converts between a Kotlin value of type [T] and one SQL column value.
 */
interface DBType<T> {
    /** SQL type name emitted in column definitions (for example `TEXT` or `INTEGER`). */
    val dbType: String

    /** Reads the value at the 1-based column [index] of the current row of [result]. */
    fun get(result: ResultSet, index: Int): T

    /** Binds [value] to the 1-based parameter [index] of [stmt]. */
    fun set(stmt: PreparedStatement, index: Int, value: T)

    /** Display name used by debug output; defaults to the implementing class's simple name. */
    fun getName(): String = javaClass.simpleName

    /**
     * Builds a [DBType] for [R] that reuses this type's SQL representation.
     *
     * @param from converts an [R] into this type's [T] before binding it.
     * @param to converts a [T] read from the database into an [R].
     * @return a type whose [dbType] equals this one's and whose name is `Mapped(<this name>)`.
     */
    fun <R> mapped(
        from: (R) -> T,
        to: (T) -> R,
    ): DBType<R> {
        val underlying: DBType<T> = this
        return object : DBType<R> {
            override val dbType: String
                get() = underlying.dbType

            override fun getName(): String = "Mapped(${underlying.getName()})"

            override fun get(result: ResultSet, index: Int): R =
                underlying.get(result, index).let(to)

            override fun set(stmt: PreparedStatement, index: Int, value: R) {
                val converted = from(value)
                underlying.set(stmt, index, converted)
            }
        }
    }
}
/**
 * Stores a [UUID] in a `TEXT` column.
 *
 * NOTE(review): values are written with [UUID.toString] (the dashed form) but read back through
 * `UUIDUtil.parseDashlessUuid`. Presumably that parser also accepts the dashed form — confirm.
 */
object DBUuid : DBType<UUID> {
    override val dbType: String get() = "TEXT"

    override fun get(result: ResultSet, index: Int): UUID =
        result.getString(index).let { text -> UUIDUtil.parseDashlessUuid(text) }

    override fun set(stmt: PreparedStatement, index: Int, value: UUID) {
        val text = value.toString()
        stmt.setString(index, text)
    }
}
/**
 * Stores a `UUIDUtil.ULIDWrapper` in a `TEXT` column, using its `wrapped` string as the stored form.
 */
object DBUlid : DBType<UUIDUtil.ULIDWrapper> {
    override val dbType: String get() = "TEXT"

    override fun get(result: ResultSet, index: Int): UUIDUtil.ULIDWrapper =
        UUIDUtil.ULIDWrapper(result.getString(index))

    override fun set(stmt: PreparedStatement, index: Int, value: UUIDUtil.ULIDWrapper) =
        stmt.setString(index, value.wrapped)
}
/**
 * Stores a [String] in a `TEXT` column.
 *
 * NOTE(review): JDBC `ResultSet.getString` returns null for SQL NULL. Reading a NULL through this
 * non-null type will presumably fail at the return boundary — confirm nullable columns are not expected.
 */
object DBString : DBType<String> {
    override val dbType: String get() = "TEXT"

    override fun get(result: ResultSet, index: Int): String = result.getString(index)

    override fun set(stmt: PreparedStatement, index: Int, value: String) = stmt.setString(index, value)
}
/**
 * Stores an enum constant of [type] in a `TEXT` column, by its [Enum.name].
 *
 * Reading an unknown name fails in `java.lang.Enum.valueOf`.
 */
class DBEnum<T : Enum<T>>(
    val type: Class<T>,
) : DBType<T> {
    override val dbType: String get() = "TEXT"

    override fun getName(): String = "DBEnum(${type.simpleName})"

    override fun set(stmt: PreparedStatement, index: Int, value: T) = stmt.setString(index, value.name)

    override fun get(result: ResultSet, index: Int): T {
        val storedName = result.getString(index)
        return java.lang.Enum.valueOf(type, storedName)
    }

    companion object {
        /** Creates a [DBEnum] for the reified enum type [T]. */
        inline operator fun <reified T : Enum<T>> invoke(): DBEnum<T> = DBEnum(T::class.java)
    }
}
/**
 * Stores a [Double] in a column declared as `DOUBLE`.
 *
 * Per the JDBC `ResultSet.getDouble` contract, SQL NULL reads back as `0.0`.
 */
object DBDouble : DBType<Double> {
    override val dbType: String get() = "DOUBLE"

    override fun get(result: ResultSet, index: Int): Double = result.getDouble(index)

    override fun set(stmt: PreparedStatement, index: Int, value: Double) = stmt.setDouble(index, value)
}
/**
 * Stores a [Long] in an `INTEGER` column.
 *
 * Per the JDBC `ResultSet.getLong` contract, SQL NULL reads back as `0`.
 */
object DBInt : DBType<Long> {
    override val dbType: String get() = "INTEGER"

    override fun get(result: ResultSet, index: Int): Long = result.getLong(index)

    override fun set(stmt: PreparedStatement, index: Int, value: Long) = stmt.setLong(index, value)
}
/**
 * Stores an [Instant] in an `INTEGER` column as milliseconds since the Unix epoch.
 *
 * Sub-millisecond precision is dropped by [Instant.toEpochMilli].
 */
object DBInstant : DBType<Instant> {
    override val dbType: String get() = "INTEGER"

    override fun set(stmt: PreparedStatement, index: Int, value: Instant) {
        val epochMillis = value.toEpochMilli()
        stmt.setLong(index, epochMillis)
    }

    override fun get(result: ResultSet, index: Int): Instant =
        result.getLong(index).let { millis -> Instant.ofEpochMilli(millis) }
}
/**
 * A typed column named [name] in [table], whose values are converted by [type].
 *
 * Instances are meant to be created via `Table.column`, which also registers them on the table.
 */
class Column<T> @Deprecated("Use Table.column instead") constructor(
    val table: Table,
    val name: String,
    val type: DBType<T>
) {
    /** The column name quoted with backticks. */
    val sqlName: String
        get() = "`$name`"

    /** The column name prefixed with its table's quoted name, e.g. `` `table`.`column` ``. */
    val qualifiedSqlName: String
        get() = "${table.sqlName}.$sqlName"
}
/**
 * A table-level SQL constraint.
 *
 * `Table.createIfNotExists` only emits it when all of [affectedColumns] are being created.
 */
interface Constraint {
    /** Columns this constraint refers to. */
    val affectedColumns: Collection<Column<*>>

    /** The SQL fragment placed inside a `CREATE TABLE (...)` column list. */
    fun asSQL(): String
}
/**
 * A `UNIQUE (...)` constraint over [columns], which must be non-empty.
 */
class UniqueConstraint(val columns: List<Column<*>>) : Constraint {
    init {
        require(columns.isNotEmpty())
    }

    override val affectedColumns: Collection<Column<*>>
        get() = columns

    override fun asSQL(): String =
        columns.joinToString(separator = ", ", prefix = "UNIQUE (", postfix = ")") { column -> column.sqlName }
}
/**
 * Base class for a table definition.
 *
 * Subclasses declare their columns via [column] and constraints via [unique]. Both are recorded in
 * declaration order, and that order drives the generated `CREATE TABLE` and `INSERT` statements.
 */
abstract class Table(val name: String) {
    /** The table name quoted with backticks. */
    val sqlName get() = "`$name`"
    protected val _mutable_columns: MutableList<Column<*>> = mutableListOf()
    protected val _mutable_constraints: MutableList<Constraint> = mutableListOf()
    val columns: List<Column<*>> get() = _mutable_columns
    val constraints get() = _mutable_constraints

    /** Registers a `UNIQUE` constraint over [columns]. */
    protected fun unique(vararg columns: Column<*>) {
        _mutable_constraints.add(UniqueConstraint(columns.toList()))
    }

    /** Creates a column on this table, registers it in [columns], and returns it. */
    protected fun <T> column(name: String, type: DBType<T>): Column<T> {
        @Suppress("DEPRECATION") val column = Column(this, name, type)
        _mutable_columns.add(column)
        return column
    }

    /** Prints an ASCII box listing this table's columns and their [DBType] names to stdout. */
    fun debugSchema() {
        // maxOfOrNull: a table without columns prints just its header instead of throwing.
        val nameWidth = columns.maxOfOrNull { it.name.length } ?: 0
        val typeWidth = columns.maxOfOrNull { it.type.getName().length } ?: 0
        val totalWidth = maxOf(2 + nameWidth + 3 + typeWidth + 2, name.length + 4)
        val adjustedTypeWidth = totalWidth - nameWidth - 2 - 3 - 2
        val separator = "+" + "-".repeat(totalWidth - 2) + "+"
        val string = buildString {
            append("\n")
            append(separator).append("\n")
            append("| $name${" ".repeat(totalWidth - 4 - name.length)} |\n")
            append(separator).append("\n")
            for (column in columns) {
                val typeName = column.type.getName()
                append("| ${column.name}${" ".repeat(nameWidth - column.name.length)} |")
                append(" $typeName${" ".repeat(adjustedTypeWidth - typeName.length)} |\n")
            }
            append(separator)
        }
        println(string)
    }

    /**
     * Runs `CREATE TABLE IF NOT EXISTS` with [filteredColumns].
     *
     * Also includes every constraint whose affected columns all lie within [filteredColumns].
     * The prepared statement is closed after execution.
     */
    fun createIfNotExists(
        connection: Connection,
        filteredColumns: List<Column<*>> = columns
    ) {
        val columnSet = filteredColumns.toSet()
        val properties = filteredColumns.map { "${it.sqlName} ${it.type.dbType}" } +
            constraints.filter { columnSet.containsAll(it.affectedColumns) }.map { it.asSQL() }
        connection.prepareAndLog("CREATE TABLE IF NOT EXISTS $sqlName (" + properties.joinToString() + ")")
            .use { it.execute() }
    }

    /**
     * Issues one `ALTER TABLE ... ADD` per entry of [newColumns].
     *
     * Each statement is closed after execution. Constraints are not applied here.
     */
    fun alterTableAddColumns(
        connection: Connection,
        newColumns: List<Column<*>>
    ) {
        for (column in newColumns) {
            connection.prepareAndLog("ALTER TABLE $sqlName ADD ${column.sqlName} ${column.type.dbType}")
                .use { it.execute() }
        }
        // TODO: automatically add constraints, maybe (or maybe move constraints into the upgrade schema)
    }

    /** Conflict-resolution keyword emitted verbatim after `INSERT OR`. */
    enum class OnConflict {
        FAIL,
        IGNORE,
        REPLACE,
        ;

        fun asSql(): String {
            return name
        }
    }

    /**
     * Inserts one row using `INSERT OR <onConflict>`.
     *
     * @param block must assign a value for every column of this table, and only those columns.
     * @throws IllegalArgumentException if the assigned columns differ from [columns].
     */
    fun insert(connection: Connection, onConflict: OnConflict = OnConflict.FAIL, block: (InsertStatement) -> Unit) {
        val insert = InsertStatement(HashMap())
        block(insert)
        require(insert.properties.keys == columns.toSet()) {
            "Insert into $name must set exactly the columns [${columns.joinToString { it.name }}], " +
                "but got [${insert.properties.keys.joinToString { it.name }}]"
        }
        val columnNames = columns.joinToString { it.sqlName }
        val valueNames = columns.joinToString { "?" }
        connection.prepareAndLog("INSERT OR ${onConflict.asSql()} INTO $sqlName ($columnNames) VALUES ($valueNames)")
            .use { statement ->
                for ((index, column) in columns.withIndex()) {
                    // Safe: InsertStatement.set pairs each Column<T> with a value of type T.
                    @Suppress("UNCHECKED_CAST")
                    (column as Column<Any>).type.set(statement, index + 1, insert.properties.getValue(column))
                }
                statement.execute()
            }
    }

    /** Starts a [Query] on this table with no columns selected yet. */
    fun from(connection: Connection): Query {
        return Query(connection, mutableListOf(), this)
    }

    /** Starts a [Query] on this table selecting all of its [columns]. */
    fun selectAll(connection: Connection): Query {
        return Query(connection, columns.toMutableList(), this)
    }
}
/**
 * Accumulates column values for `Table.insert`.
 *
 * Setting the same column twice keeps the last value.
 */
class InsertStatement(val properties: MutableMap<Column<*>, Any>) {
    /** Records [value] for the column [key]. */
    operator fun <T : Any> set(key: Column<T>, value: T) {
        properties.put(key, value)
    }
}
/**
 * Prints [statement] to stdout, then prepares it on this connection.
 *
 * @return the prepared, not-yet-executed statement. The caller is responsible for closing it.
 */
fun Connection.prepareAndLog(statement: String): PreparedStatement =
    statement
        .also { sql -> println("Preparing to execute $sql") }
        .let { sql -> prepareStatement(sql) }
/**
 * One fragment of an SQL statement.
 *
 * A fragment has SQL text (possibly with `?` placeholders) and knows how to bind values to those placeholders.
 */
interface SQLQueryComponent {
    /** The SQL text of this fragment. */
    fun asSql(): String

    /**
     * Binds this fragment's placeholders on [stmt], starting at the 1-based [startIndex].
     *
     * @return the next writable index (should equal to the amount of `?` in [asSql] + [startIndex])
     */
    fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int

    companion object {
        /** Wraps literal [sql] that contains no placeholders, so binding is a no-op. */
        fun standalone(sql: String): SQLQueryComponent = object : SQLQueryComponent {
            override fun asSql(): String = sql

            override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int = startIndex
        }
    }
}
/** An [SQLQueryComponent] that evaluates to an SQL boolean. */
interface BooleanExpression : SQLQueryComponent

/**
 * Disjunction of the non-empty [elements], rendered as `(e1 OR ... OR eN OR FALSE)`.
 *
 * The trailing `FALSE` is the identity of OR, so it does not change the result.
 */
data class ORExpression(
    val elements: List<BooleanExpression>
) : BooleanExpression {
    init {
        require(elements.isNotEmpty())
    }

    override fun asSql(): String =
        elements.joinToString(separator = " OR ", prefix = "(", postfix = " OR FALSE)") { it.asSql() }

    /** Binds each element's placeholders in order. */
    override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int =
        elements.fold(startIndex) { nextIndex, element -> element.appendToStatement(stmt, nextIndex) }
}

/**
 * Conjunction of the non-empty [elements], rendered as `(e1 AND ... AND eN AND TRUE)`.
 *
 * The trailing `TRUE` is the identity of AND, so it does not change the result.
 */
data class ANDExpression(
    val elements: List<BooleanExpression>
) : BooleanExpression {
    init {
        require(elements.isNotEmpty())
    }

    override fun asSql(): String =
        elements.joinToString(separator = " AND ", prefix = "(", postfix = " AND TRUE)") { it.asSql() }

    /** Binds each element's placeholders in order. */
    override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int =
        elements.fold(startIndex) { nextIndex, element -> element.appendToStatement(stmt, nextIndex) }
}
/**
 * Receiver for `Clause { ... }` blocks.
 *
 * Provides operand constructors and infix comparison operators.
 */
class ClauseBuilder {
    /** References [column] (rendered table-qualified). */
    fun <T> column(column: Column<T>): Operand<T> = Operand.ColumnOperand(column)

    /** A bound string parameter. */
    fun string(string: String): Operand.StringOperand = Operand.StringOperand(string)

    /** `this = operand`; operand types are not checked against each other. */
    infix fun Operand<*>.eq(operand: Operand<*>) = Clause.EqualsClause(this, operand)

    /** `(this LIKE op)` with [op] bound as a parameter. */
    infix fun Operand<*>.like(op: Operand.StringOperand) = Clause.LikeClause(this, op)

    /** Shorthand for `like(string(op))`. */
    infix fun Operand<*>.like(op: String) = Clause.LikeClause(this, Operand.StringOperand(op))
}
/** A boolean comparison between operands, typically built through `Clause { ... }`. */
interface Clause : BooleanExpression {
    /** `left = right`, rendered without surrounding parentheses. */
    data class EqualsClause(val left: Operand<*>, val right: Operand<*>) : Clause { // TODO: typecheck this somehow
        override fun asSql(): String = "${left.asSql()} = ${right.asSql()}"

        /** Binds left's placeholders first, then right's. */
        override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int {
            val afterLeft = left.appendToStatement(stmt, startIndex)
            return right.appendToStatement(stmt, afterLeft)
        }
    }

    /** `(left LIKE right)`, with the pattern bound as a parameter. */
    data class LikeClause<T>(val left: Operand<T>, val right: Operand.StringOperand) : Clause {
        //TODO: check type safety with this one
        override fun asSql(): String = "(${left.asSql()} LIKE ${right.asSql()})"

        /** Binds left's placeholders first, then right's. */
        override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int {
            val afterLeft = left.appendToStatement(stmt, startIndex)
            return right.appendToStatement(stmt, afterLeft)
        }
    }

    companion object {
        /** Evaluates [builder] against a fresh [ClauseBuilder] and returns the clause it produces. */
        operator fun invoke(builder: ClauseBuilder.() -> Clause): Clause = ClauseBuilder().builder()
    }
}
/** A value usable on either side of a [Clause]. */
interface Operand<T> : SQLQueryComponent {
    /** A column reference. It renders as the column's qualified name and binds nothing. */
    data class ColumnOperand<T>(val column: Column<T>) : Operand<T> {
        override fun asSql(): String = column.qualifiedSqlName

        override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int = startIndex
    }

    /** A string literal passed as one `?` parameter. */
    data class StringOperand(val value: String) : Operand<String> {
        override fun asSql(): String = "?"

        override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int {
            stmt.setString(startIndex, value)
            return startIndex + 1
        }
    }
}
/**
 * A `JOIN <table> ON <filter>` fragment.
 *
 * Example: JOIN ItemEntry ON LogEntry.transactionId = ItemEntry.transactionId
 */
data class Join(
    val table: Table,
    //TODO: aliased columns val tableAlias: String,
    val filter: Clause,
) : SQLQueryComponent {
    override fun asSql(): String = listOf("JOIN", table.sqlName, "ON", filter.asSql()).joinToString(" ")

    /** Only the filter carries placeholders. */
    override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int =
        filter.appendToStatement(stmt, startIndex)
}
/**
 * Joins these fragments' SQL with single spaces, prepares the result on [connection], and binds
 * every fragment's placeholders in order starting at index 1.
 *
 * @throws IllegalStateException if a fragment reports a next index lower than the one it was given.
 */
fun List<SQLQueryComponent>.concatToFilledPreparedStatement(connection: Connection): PreparedStatement {
    val sql = buildString {
        for (component in this@concatToFilledPreparedStatement) {
            // A separator is only added once some text has been emitted.
            if (isNotEmpty()) append(" ")
            append(component.asSql())
        }
    }
    val statement = connection.prepareAndLog(sql)
    var cursor = 1
    forEach { component ->
        val advanced = component.appendToStatement(statement, cursor)
        check(advanced >= cursor) { "$component went back in time" }
        cursor = advanced
    }
    return statement
}
/**
 * Fluent builder for a `SELECT` over [table].
 *
 * Builder methods mutate this instance and return it for chaining. Each call to [iterator]
 * prepares and executes a fresh statement. That statement is closed once iteration is exhausted
 * or execution fails.
 */
class Query(
    val connection: Connection,
    val selectedColumns: MutableList<Column<*>>,
    var table: Table,
    var limit: UInt? = null,
    var skip: UInt? = null,
    val joins: MutableList<Join> = mutableListOf(),
    val conditions: MutableList<BooleanExpression> = mutableListOf(),
//	var order: OrderClause?= null,
) : Iterable<ResultRow> {
    /** Adds `JOIN [table] ON [on]`. */
    fun join(table: Table, on: Clause): Query {
        joins.add(Join(table, on))
        return this
    }

    /** Adds a condition; all conditions are combined with AND. */
    fun where(binOp: BooleanExpression): Query {
        conditions.add(binOp)
        return this
    }

    /** Appends [columns] to the selection. */
    fun select(vararg columns: Column<*>): Query {
        selectedColumns.addAll(columns)
        return this
    }

    /** Sets `OFFSET`. This must come after [limit], because `OFFSET` is only emitted together with `LIMIT`. */
    fun skip(skip: UInt): Query {
        require(limit != null) { "skip() requires limit() to be set first" }
        this.skip = skip
        return this
    }

    /** Sets `LIMIT`. */
    fun limit(limit: UInt): Query {
        this.limit = limit
        return this
    }

    override fun iterator(): Iterator<ResultRow> {
        // Snapshot the selection: rows are decoded against exactly the columns that went into the SQL,
        // even if select() is called again while iterating.
        val columnsSnapshot = selectedColumns.toList()
        val columnSelections = columnsSnapshot.joinToString { it.qualifiedSqlName }
        val elements = mutableListOf<SQLQueryComponent>(
            SQLQueryComponent.standalone("SELECT $columnSelections FROM ${table.sqlName}"),
        )
        elements.addAll(joins)
        if (conditions.any()) {
            elements.add(SQLQueryComponent.standalone("WHERE"))
            elements.add(ANDExpression(conditions.toList()))
        }
        if (limit != null) {
            elements.add(SQLQueryComponent.standalone("LIMIT $limit"))
            if (skip != null) {
                elements.add(SQLQueryComponent.standalone("OFFSET $skip"))
            }
        }
        val prepared = elements.concatToFilledPreparedStatement(connection)
        val results = try {
            prepared.executeQuery()
        } catch (e: Throwable) {
            // Don't leak the statement when execution fails.
            prepared.close()
            throw e
        }
        return object : Iterator<ResultRow> {
            var hasAdvanced = false
            var hasEnded = false

            override fun hasNext(): Boolean {
                if (hasEnded) return false
                if (hasAdvanced) return true
                if (results.next()) {
                    hasAdvanced = true
                    return true
                }
                // Exhausted: release the result set and the statement that produced it.
                // TODO: somehow enforce closing when iteration is abandoned early
                results.close()
                prepared.close()
                hasEnded = true
                return false
            }

            override fun next(): ResultRow {
                if (!hasNext()) {
                    throw NoSuchElementException()
                }
                hasAdvanced = false
                return ResultRow(columnsSnapshot.withIndex().associate {
                    it.value to it.value.type.get(results, it.index + 1)
                })
            }
        }
    }
}
/**
 * One decoded row of a [Query], keyed by the selected [Column]s.
 */
class ResultRow(val columnValues: Map<Column<*>, *>) {
    /**
     * Returns the value decoded for [column]. The value may be null if the column's [DBType] produces nulls.
     *
     * @throws IllegalStateException if [column] was not part of the query's selection.
     */
    operator fun <T> get(column: Column<T>): T {
        // Check membership explicitly: a selected column whose decoded value is null must not be
        // reported as missing, which is what an elvis on the lookup would do.
        if (!columnValues.containsKey(column)) {
            error("Invalid column ${column.name}. Only ${columnValues.keys.joinToString { it.name }} are available.")
        }
        // Safe: Query stores each column's value as decoded by that same Column<T>'s DBType<T>.
        @Suppress("UNCHECKED_CAST")
        return columnValues[column] as T
    }
}