aboutsummaryrefslogtreecommitdiff
path: root/src/main/kotlin/moe/nea/ledger/database
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/kotlin/moe/nea/ledger/database')
-rw-r--r--src/main/kotlin/moe/nea/ledger/database/DBSchema.kt217
1 files changed, 207 insertions, 10 deletions
diff --git a/src/main/kotlin/moe/nea/ledger/database/DBSchema.kt b/src/main/kotlin/moe/nea/ledger/database/DBSchema.kt
index dee99e4..492f261 100644
--- a/src/main/kotlin/moe/nea/ledger/database/DBSchema.kt
+++ b/src/main/kotlin/moe/nea/ledger/database/DBSchema.kt
@@ -145,8 +145,13 @@ object DBInstant : DBType<Instant> {
}
}
-class Column<T> @Deprecated("Use Table.column instead") constructor(val name: String, val type: DBType<T>) {
+class Column<T> @Deprecated("Use Table.column instead") constructor(
+ val table: Table,
+ val name: String,
+ val type: DBType<T>
+) {
val sqlName get() = "`$name`"
+ val qualifiedSqlName get() = table.sqlName + "." + sqlName
}
interface Constraint {
@@ -178,7 +183,7 @@ abstract class Table(val name: String) {
}
protected fun <T> column(name: String, type: DBType<T>): Column<T> {
- @Suppress("DEPRECATION") val column = Column(name, type)
+ @Suppress("DEPRECATION") val column = Column(this, name, type)
_mutable_columns.add(column)
return column
}
@@ -258,9 +263,12 @@ abstract class Table(val name: String) {
statement.execute()
}
+ fun from(connection: Connection): Query {
+ return Query(connection, mutableListOf(), this)
+ }
fun selectAll(connection: Connection): Query {
- return Query(connection, columns, this)
+ return Query(connection, columns.toMutableList(), this)
}
}
@@ -275,15 +283,196 @@ fun Connection.prepareAndLog(statement: String): PreparedStatement {
return prepareStatement(statement)
}
+interface SQLQueryComponent {
+ fun asSql(): String
+
+ /**
+ * @return the next writable index (should equal to the amount of `?` in [asSql] + [startIndex])
+ */
+ fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int
+
+ companion object {
+ fun standalone(sql: String): SQLQueryComponent {
+ return object : SQLQueryComponent {
+ override fun asSql(): String {
+ return sql
+ }
+
+ override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int {
+ return startIndex
+ }
+ }
+ }
+ }
+}
+
+interface BooleanExpression : SQLQueryComponent
+
+data class ORExpression(
+ val elements: List<BooleanExpression>
+) : BooleanExpression {
+ init {
+ require(elements.isNotEmpty())
+ }
+
+ override fun asSql(): String {
+ return (elements + SQLQueryComponent.standalone("FALSE")).joinToString(" OR ", "(", ")") { it.asSql() }
+ }
+
+ override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int {
+ var index = startIndex
+ for (element in elements) {
+ index = element.appendToStatement(stmt, index)
+ }
+ return index
+ }
+}
+
+
+data class ANDExpression(
+ val elements: List<BooleanExpression>
+) : BooleanExpression {
+ init {
+ require(elements.isNotEmpty())
+ }
+
+ override fun asSql(): String {
+ return (elements + SQLQueryComponent.standalone("TRUE")).joinToString(" AND ", "(", ")") { it.asSql() }
+ }
+
+ override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int {
+ var index = startIndex
+ for (element in elements) {
+ index = element.appendToStatement(stmt, index)
+ }
+ return index
+ }
+}
+
+class ClauseBuilder {
+ fun <T> column(column: Column<T>): Operand<T> = Operand.ColumnOperand(column)
+ fun string(string: String): Operand.StringOperand = Operand.StringOperand(string)
+ infix fun Operand<*>.eq(operand: Operand<*>) = Clause.EqualsClause(this, operand)
+ infix fun Operand<*>.like(op: Operand.StringOperand) = Clause.LikeClause(this, op)
+ infix fun Operand<*>.like(op: String) = Clause.LikeClause(this, string(op))
+}
+
+interface Clause : BooleanExpression {
+ companion object {
+ operator fun invoke(builder: ClauseBuilder.() -> Clause): Clause {
+ return builder(ClauseBuilder())
+ }
+ }
+
+ data class EqualsClause(val left: Operand<*>, val right: Operand<*>) : Clause { // TODO: typecheck this somehow
+ override fun asSql(): String {
+ return left.asSql() + " = " + right.asSql()
+ }
+
+ override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int {
+ var index = startIndex
+ index = left.appendToStatement(stmt, index)
+ index = right.appendToStatement(stmt, index)
+ return index
+ }
+ }
+
+ data class LikeClause<T>(val left: Operand<T>, val right: Operand.StringOperand) : Clause {
+ //TODO: check type safety with this one
+ override fun asSql(): String {
+ return "(" + left.asSql() + " LIKE " + right.asSql() + ")"
+ }
+
+ override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int {
+ var index = startIndex
+ index = left.appendToStatement(stmt, index)
+ index = right.appendToStatement(stmt, index)
+ return index
+ }
+ }
+}
+
+interface Operand<T> : SQLQueryComponent {
+ data class ColumnOperand<T>(val column: Column<T>) : Operand<T> {
+ override fun asSql(): String {
+ return column.qualifiedSqlName
+ }
+
+ override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int {
+ return startIndex
+ }
+ }
+
+ data class StringOperand(val value: String) : Operand<String> {
+ override fun asSql(): String {
+ return "?"
+ }
+
+ override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int {
+ stmt.setString(startIndex, value)
+ return 1 + startIndex
+ }
+ }
+}
+
+data class Join(
+ val table: Table,
+//TODO: aliased columns val tableAlias: String,
+ val filter: Clause,
+) : SQLQueryComponent {
+ // JOIN ItemEntry on LogEntry.transactionId = ItemEntry.transactionId
+ override fun asSql(): String {
+ return "JOIN ${table.sqlName} ON ${filter.asSql()}"
+ }
+
+ override fun appendToStatement(stmt: PreparedStatement, startIndex: Int): Int {
+ return filter.appendToStatement(stmt, startIndex)
+ }
+}
+
+fun List<SQLQueryComponent>.concatToFilledPreparedStatement(connection: Connection): PreparedStatement {
+ var query = ""
+ for (element in this) {
+ if (query.isNotEmpty()) {
+ query += " "
+ }
+ query += element.asSql()
+ }
+ val statement = connection.prepareAndLog(query)
+ var index = 1
+ for (element in this) {
+ val nextIndex = element.appendToStatement(statement, index)
+ if (nextIndex < index) error("$element went back in time")
+ index = nextIndex
+ }
+ return statement
+}
+
class Query(
val connection: Connection,
- val selectedColumns: List<Column<*>>,
+ val selectedColumns: MutableList<Column<*>>,
var table: Table,
var limit: UInt? = null,
var skip: UInt? = null,
+ val joins: MutableList<Join> = mutableListOf(),
+ val conditions: MutableList<BooleanExpression> = mutableListOf(),
// var order: OrderClause?= null,
-// val condition: List<SqlCondition>
) : Iterable<ResultRow> {
+ fun join(table: Table, on: Clause): Query {
+ joins.add(Join(table, on))
+ return this
+ }
+
+ fun where(binOp: BooleanExpression): Query {
+ conditions.add(binOp)
+ return this
+ }
+
+ fun select(vararg columns: Column<*>): Query {
+ selectedColumns.addAll(columns)
+ return this
+ }
+
fun skip(skip: UInt): Query {
require(limit != null)
this.skip = skip
@@ -296,15 +485,22 @@ class Query(
}
override fun iterator(): Iterator<ResultRow> {
- val columnSelections = selectedColumns.joinToString { it.sqlName }
- var query = "SELECT $columnSelections FROM ${table.sqlName} "
+ val columnSelections = selectedColumns.joinToString { it.qualifiedSqlName }
+ val elements = mutableListOf(
+ SQLQueryComponent.standalone("SELECT $columnSelections FROM ${table.sqlName}"),
+ )
+ elements.addAll(joins)
+ if (conditions.any()) {
+ elements.add(SQLQueryComponent.standalone("WHERE"))
+ elements.add(ANDExpression(conditions))
+ }
if (limit != null) {
- query += "LIMIT $limit "
+ elements.add(SQLQueryComponent.standalone("LIMIT $limit"))
if (skip != null) {
- query += "OFFSET $skip "
+ elements.add(SQLQueryComponent.standalone("OFFSET $skip"))
}
}
- val prepared = connection.prepareAndLog(query.trim())
+ val prepared = elements.concatToFilledPreparedStatement(connection)
val results = prepared.executeQuery()
return object : Iterator<ResultRow> {
var hasAdvanced = false
@@ -316,6 +512,7 @@ class Query(
hasAdvanced = true
return true
} else {
+ results.close() // TODO: somehow enforce closing this
hasEnded = true
return false
}