diff options
Diffstat (limited to 'src/main/kotlin/moe/nea/ledger/database/DBSchema.kt')
-rw-r--r-- | src/main/kotlin/moe/nea/ledger/database/DBSchema.kt | 217 |
1 files changed, 207 insertions, 10 deletions
diff --git a/src/main/kotlin/moe/nea/ledger/database/DBSchema.kt b/src/main/kotlin/moe/nea/ledger/database/DBSchema.kt index dee99e4..492f261 100644 --- a/src/main/kotlin/moe/nea/ledger/database/DBSchema.kt +++ b/src/main/kotlin/moe/nea/ledger/database/DBSchema.kt @@ -145,8 +145,13 @@ object DBInstant : DBType<Instant> { } } -class Column<T> @Deprecated("Use Table.column instead") constructor(val name: String, val type: DBType<T>) { +class Column<T> @Deprecated("Use Table.column instead") constructor( + val table: Table, + val name: String, + val type: DBType<T> +) { val sqlName get() = "`$name`" + val qualifiedSqlName get() = table.sqlName + "." + sqlName } interface Constraint { @@ -178,7 +183,7 @@ abstract class Table(val name: String) { } protected fun <T> column(name: String, type: DBType<T>): Column<T> { - @Suppress("DEPRECATION") val column = Column(name, type) + @Suppress("DEPRECATION") val column = Column(this, name, type) _mutable_columns.add(column) return column } @@ -258,9 +263,12 @@ abstract class Table(val name: String) { statement.execute() } + fun from(connection: Connection): Query { + return Query(connection, mutableListOf(), this) + } fun selectAll(connection: Connection): Query { - return Query(connection, columns, this) + return Query(connection, columns.toMutableList(), this) } } @@ -275,15 +283,196 @@ fun Connection.prepareAndLog(statement: String): PreparedStatement { return prepareStatement(statement) } +interface SQLQueryComponent { + fun asSql(): String + + /** + * @return the next writable index (should equal to the amount of `?` in [asSql] + [startIndex]) + */ + fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int + + companion object { + fun standalone(sql: String): SQLQueryComponent { + return object : SQLQueryComponent { + override fun asSql(): String { + return sql + } + + override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int { + return startIndex + } + } + } + } +} + +interface BooleanExpression : SQLQueryComponent + +data class ORExpression( + val elements: List<BooleanExpression> +) : BooleanExpression { + init { + require(elements.isNotEmpty()) + } + + override fun asSql(): String { + return (elements + SQLQueryComponent.standalone("FALSE")).joinToString(" OR ", "(", ")") { it.asSql() } + } + + override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int { + var index = startIndex + for (element in elements) { + index = element.appendToStatement(stmt, index) + } + return index + } +} + + +data class ANDExpression( + val elements: List<BooleanExpression> +) : BooleanExpression { + init { + require(elements.isNotEmpty()) + } + + override fun asSql(): String { + return (elements + SQLQueryComponent.standalone("TRUE")).joinToString(" AND ", "(", ")") { it.asSql() } + } + + override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int { + var index = startIndex + for (element in elements) { + index = element.appendToStatement(stmt, index) + } + return index + } +} + +class ClauseBuilder { + fun <T> column(column: Column<T>): Operand<T> = Operand.ColumnOperand(column) + fun string(string: String): Operand.StringOperand = Operand.StringOperand(string) + infix fun Operand<*>.eq(operand: Operand<*>) = Clause.EqualsClause(this, operand) + infix fun Operand<*>.like(op: Operand.StringOperand) = Clause.LikeClause(this, op) + infix fun Operand<*>.like(op: String) = Clause.LikeClause(this, string(op)) +} + +interface Clause : BooleanExpression { + companion object { + operator fun invoke(builder: ClauseBuilder.() -> Clause): Clause { + return builder(ClauseBuilder()) + } + } + + data class EqualsClause(val left: Operand<*>, val right: Operand<*>) : Clause { // TODO: typecheck this somehow + override fun asSql(): String { + return left.asSql() + " = " + right.asSql() + } + + override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int { + var index = startIndex + index = left.appendToStatement(stmt, index) + index = right.appendToStatement(stmt, index) + return index + } + } + + data class LikeClause<T>(val left: Operand<T>, val right: Operand.StringOperand) : Clause { + //TODO: check type safety with this one + override fun asSql(): String { + return "(" + left.asSql() + " LIKE " + right.asSql() + ")" + } + + override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int { + var index = startIndex + index = left.appendToStatement(stmt, index) + index = right.appendToStatement(stmt, index) + return index + } + } +} + +interface Operand<T> : SQLQueryComponent { + data class ColumnOperand<T>(val column: Column<T>) : Operand<T> { + override fun asSql(): String { + return column.qualifiedSqlName + } + + override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int { + return startIndex + } + } + + data class StringOperand(val value: String) : Operand<String> { + override fun asSql(): String { + return "?" + } + + override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int { + stmt.setString(startIndex, value) + return 1 + startIndex + } + } +} + +data class Join( + val table: Table, +//TODO: aliased columns val tableAlias: String, + val filter: Clause, +) : SQLQueryComponent { + // JOIN ItemEntry on LogEntry.transactionId = ItemEntry.transactionId + override fun asSql(): String { + return "JOIN ${table.sqlName} ON ${filter.asSql()}" + } + + override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int { + return filter.appendToStatement(stmt, startIndex) + } +} + +fun List<SQLQueryComponent>.concatToFilledPreparedStatement(connection: Connection): PreparedStatement { + var query = "" + for (element in this) { + if (query.isNotEmpty()) { + query += " " + } + query += element.asSql() + } + val statement = connection.prepareAndLog(query) + var index = 1 + for (element in this) { + val nextIndex = element.appendToStatement(statement, index) + if (nextIndex < index) error("$element went back in time") + index = nextIndex + } + return statement +} + class Query( val connection: Connection, - val selectedColumns: List<Column<*>>, + val selectedColumns: MutableList<Column<*>>, var table: Table, var limit: UInt? = null, var skip: UInt? = null, + val joins: MutableList<Join> = mutableListOf(), + val conditions: MutableList<BooleanExpression> = mutableListOf(), // var order: OrderClause?= null, -// val condition: List<SqlCondition> ) : Iterable<ResultRow> { + fun join(table: Table, on: Clause): Query { + joins.add(Join(table, on)) + return this + } + + fun where(binOp: BooleanExpression): Query { + conditions.add(binOp) + return this + } + + fun select(vararg columns: Column<*>): Query { + selectedColumns.addAll(columns) + return this + } + fun skip(skip: UInt): Query { require(limit != null) this.skip = skip @@ -296,15 +485,22 @@ class Query( } override fun iterator(): Iterator<ResultRow> { - val columnSelections = selectedColumns.joinToString { it.sqlName } - var query = "SELECT $columnSelections FROM ${table.sqlName} " + val columnSelections = selectedColumns.joinToString { it.qualifiedSqlName } + val elements = mutableListOf( + SQLQueryComponent.standalone("SELECT $columnSelections FROM ${table.sqlName}"), + ) + elements.addAll(joins) + if (conditions.any()) { + elements.add(SQLQueryComponent.standalone("WHERE")) + elements.add(ANDExpression(conditions)) + } if (limit != null) { - query += "LIMIT $limit " + elements.add(SQLQueryComponent.standalone("LIMIT $limit")) if (skip != null) { - query += "OFFSET $skip " + elements.add(SQLQueryComponent.standalone("OFFSET $skip")) } } - val prepared = connection.prepareAndLog(query.trim()) + val prepared = elements.concatToFilledPreparedStatement(connection) val results = prepared.executeQuery() return object : Iterator<ResultRow> { var hasAdvanced = false @@ -316,6 +512,7 @@ class Query( hasAdvanced = true return true } else { + results.close() // TODO: somehow enforce closing this hasEnded = true return false } |