From 40ce4f856718d6b0c7fa15cd40a1f6db2f774e6f Mon Sep 17 00:00:00 2001
From: Linnea Gräf <nea@nea.moe>
Date: Tue, 18 Mar 2025 19:28:15 +0100
Subject: feat: Allow checking out repo PRs

---
 src/main/kotlin/commands/rome.kt            |  10 ++
 src/main/kotlin/repo/RepoDownloadManager.kt | 193 ++++++++++++++--------------
 src/main/kotlin/repo/RepoManager.kt         |   7 +
 3 files changed, 113 insertions(+), 97 deletions(-)

diff --git a/src/main/kotlin/commands/rome.kt b/src/main/kotlin/commands/rome.kt
index 8ae34f6..c3eb03d 100644
--- a/src/main/kotlin/commands/rome.kt
+++ b/src/main/kotlin/commands/rome.kt
@@ -1,6 +1,7 @@
 package moe.nea.firmament.commands
 
 import com.mojang.brigadier.CommandDispatcher
+import com.mojang.brigadier.arguments.IntegerArgumentType
 import com.mojang.brigadier.arguments.StringArgumentType.string
 import io.ktor.client.statement.bodyAsText
 import net.fabricmc.fabric.api.client.command.v2.FabricClientCommandSource
@@ -130,6 +131,15 @@ fun firmamentCommand() = literal("firmament") {
 		}
 	}
 	thenLiteral("repo") {
+		thenLiteral("checkpr") {
+			thenArgument("prnum", IntegerArgumentType.integer(1)) { prnum ->
+				thenExecute {
+					val prnum = this[prnum]
+					source.sendFeedback(tr("firmament.repo.reload.pr", "Temporarily reloading repo from PR #${prnum}."))
+					RepoManager.downloadOverridenBranch("refs/pull/$prnum/head")
+				}
+			}
+		}
 		thenLiteral("reload") {
 			thenLiteral("fetch") {
 				thenExecute {
diff --git a/src/main/kotlin/repo/RepoDownloadManager.kt b/src/main/kotlin/repo/RepoDownloadManager.kt
index 3efd83b..888248d 100644
--- a/src/main/kotlin/repo/RepoDownloadManager.kt
+++ b/src/main/kotlin/repo/RepoDownloadManager.kt
@@ -1,5 +1,3 @@
-
-
 package moe.nea.firmament.repo
 
 import io.ktor.client.call.body
@@ -28,101 +26,102 @@ import moe.nea.firmament.util.iterate
 
 object RepoDownloadManager {
 
-    val repoSavedLocation = Firmament.DATA_DIR.resolve("repo-extracted")
-    val repoMetadataLocation = Firmament.DATA_DIR.resolve("loaded-repo-sha.txt")
-
-    private fun loadSavedVersionHash(): String? =
-        if (repoSavedLocation.exists()) {
-            if (repoMetadataLocation.exists()) {
-                try {
-                    repoMetadataLocation.readText().trim()
-                } catch (e: IOException) {
-                    null
-                }
-            } else {
-                null
-            }
-        } else null
-
-    private fun saveVersionHash(versionHash: String) {
-        latestSavedVersionHash = versionHash
-        repoMetadataLocation.writeText(versionHash)
-    }
-
-    var latestSavedVersionHash: String? = loadSavedVersionHash()
-        private set
-
-    @Serializable
-    private class GithubCommitsResponse(val sha: String)
-
-    private suspend fun requestLatestGithubSha(): String? {
-        if (RepoManager.Config.branch == "prerelease") {
-            RepoManager.Config.branch = "master"
-        }
-        val response =
-            Firmament.httpClient.get("https://api.github.com/repos/${RepoManager.Config.username}/${RepoManager.Config.reponame}/commits/${RepoManager.Config.branch}")
-        if (response.status.value != 200) {
-            return null
-        }
-        return response.body<GithubCommitsResponse>().sha
-    }
-
-    private suspend fun downloadGithubArchive(url: String): Path = withContext(IO) {
-        val response = Firmament.httpClient.get(url)
-        val targetFile = Files.createTempFile("firmament-repo", ".zip")
-        val outputChannel = Files.newByteChannel(targetFile, StandardOpenOption.CREATE, StandardOpenOption.WRITE)
-        response.bodyAsChannel().copyTo(outputChannel)
-        targetFile
-    }
-
-    /**
-     * Downloads the latest repository from github, setting [latestSavedVersionHash].
-     * @return true, if an update was performed, false, otherwise (no update needed, or wasn't able to complete update)
-     */
-    suspend fun downloadUpdate(force: Boolean): Boolean = withContext(CoroutineName("Repo Update Check")) {
-        val latestSha = requestLatestGithubSha()
-        if (latestSha == null) {
-            logger.warn("Could not request github API to retrieve latest REPO sha.")
-            return@withContext false
-        }
-        val currentSha = loadSavedVersionHash()
-        if (latestSha != currentSha || force) {
-            val requestUrl =
-                "https://github.com/${RepoManager.Config.username}/${RepoManager.Config.reponame}/archive/$latestSha.zip"
-            logger.info("Planning to upgrade repository from $currentSha to $latestSha from $requestUrl")
-            val zipFile = downloadGithubArchive(requestUrl)
-            logger.info("Download repository zip file to $zipFile. Deleting old repository")
-            withContext(IO) { repoSavedLocation.toFile().deleteRecursively() }
-            logger.info("Extracting new repository")
-            withContext(IO) { extractNewRepository(zipFile) }
-            logger.info("Repository loaded on disk.")
-            saveVersionHash(latestSha)
-            return@withContext true
-        } else {
-            logger.debug("Repository on latest sha $currentSha. Not performing update")
-            return@withContext false
-        }
-    }
-
-    private fun extractNewRepository(zipFile: Path) {
-        repoSavedLocation.createDirectories()
-        ZipInputStream(zipFile.inputStream()).use { cis ->
-            while (true) {
-                val entry = cis.nextEntry ?: break
-                if (entry.isDirectory) continue
-                val extractedLocation =
-                    repoSavedLocation.resolve(
-                        entry.name.substringAfter('/', missingDelimiterValue = "")
-                    )
-                if (repoSavedLocation !in extractedLocation.iterate { it.parent }) {
-                    logger.error("Firmament detected an invalid zip file. This is a potential security risk, please report this in the Firmament discord.")
-                    throw RuntimeException("Firmament detected an invalid zip file. This is a potential security risk, please report this in the Firmament discord.")
-                }
-                extractedLocation.parent.createDirectories()
-                extractedLocation.outputStream().use { cis.copyTo(it) }
-            }
-        }
-    }
+	val repoSavedLocation = Firmament.DATA_DIR.resolve("repo-extracted")
+	val repoMetadataLocation = Firmament.DATA_DIR.resolve("loaded-repo-sha.txt")
+
+	private fun loadSavedVersionHash(): String? =
+		if (repoSavedLocation.exists()) {
+			if (repoMetadataLocation.exists()) {
+				try {
+					repoMetadataLocation.readText().trim()
+				} catch (e: IOException) {
+					null
+				}
+			} else {
+				null
+			}
+		} else null
+
+	private fun saveVersionHash(versionHash: String) {
+		latestSavedVersionHash = versionHash
+		repoMetadataLocation.writeText(versionHash)
+	}
+
+	var latestSavedVersionHash: String? = loadSavedVersionHash()
+		private set
+
+	@Serializable
+	private class GithubCommitsResponse(val sha: String)
+
+	private suspend fun requestLatestGithubSha(branchOverride: String?): String? {
+		if (RepoManager.Config.branch == "prerelease") {
+			RepoManager.Config.branch = "master"
+		}
+		val response =
+			Firmament.httpClient.get("https://api.github.com/repos/${RepoManager.Config.username}/${RepoManager.Config.reponame}/commits/${branchOverride ?: RepoManager.Config.branch}")
+		if (response.status.value != 200) {
+			return null
+		}
+		return response.body<GithubCommitsResponse>().sha
+	}
+
+	private suspend fun downloadGithubArchive(url: String): Path = withContext(IO) {
+		val response = Firmament.httpClient.get(url)
+		val targetFile = Files.createTempFile("firmament-repo", ".zip")
+		val outputChannel = Files.newByteChannel(targetFile, StandardOpenOption.CREATE, StandardOpenOption.WRITE)
+		response.bodyAsChannel().copyTo(outputChannel)
+		targetFile
+	}
+
+	/**
+	 * Downloads the latest repository from github, setting [latestSavedVersionHash].
+	 * @return true, if an update was performed, false, otherwise (no update needed, or wasn't able to complete update)
+	 */
+	suspend fun downloadUpdate(force: Boolean, branch: String? = null): Boolean =
+		withContext(CoroutineName("Repo Update Check")) {
+			val latestSha = requestLatestGithubSha(branch)
+			if (latestSha == null) {
+				logger.warn("Could not request github API to retrieve latest REPO sha.")
+				return@withContext false
+			}
+			val currentSha = loadSavedVersionHash()
+			if (latestSha != currentSha || force) {
+				val requestUrl =
+					"https://github.com/${RepoManager.Config.username}/${RepoManager.Config.reponame}/archive/$latestSha.zip"
+				logger.info("Planning to upgrade repository from $currentSha to $latestSha from $requestUrl")
+				val zipFile = downloadGithubArchive(requestUrl)
+				logger.info("Download repository zip file to $zipFile. Deleting old repository")
+				withContext(IO) { repoSavedLocation.toFile().deleteRecursively() }
+				logger.info("Extracting new repository")
+				withContext(IO) { extractNewRepository(zipFile) }
+				logger.info("Repository loaded on disk.")
+				saveVersionHash(latestSha)
+				return@withContext true
+			} else {
+				logger.debug("Repository on latest sha $currentSha. Not performing update")
+				return@withContext false
+			}
+		}
+
+	private fun extractNewRepository(zipFile: Path) {
+		repoSavedLocation.createDirectories()
+		ZipInputStream(zipFile.inputStream()).use { cis ->
+			while (true) {
+				val entry = cis.nextEntry ?: break
+				if (entry.isDirectory) continue
+				val extractedLocation =
+					repoSavedLocation.resolve(
+						entry.name.substringAfter('/', missingDelimiterValue = "")
+					)
+				if (repoSavedLocation !in extractedLocation.iterate { it.parent }) {
+					logger.error("Firmament detected an invalid zip file. This is a potential security risk, please report this in the Firmament discord.")
+					throw RuntimeException("Firmament detected an invalid zip file. This is a potential security risk, please report this in the Firmament discord.")
+				}
+				extractedLocation.parent.createDirectories()
+				extractedLocation.outputStream().use { cis.copyTo(it) }
+			}
+		}
+	}
 
 
 }
diff --git a/src/main/kotlin/repo/RepoManager.kt b/src/main/kotlin/repo/RepoManager.kt
index e50a131..cc36fba 100644
--- a/src/main/kotlin/repo/RepoManager.kt
+++ b/src/main/kotlin/repo/RepoManager.kt
@@ -102,6 +102,13 @@ object RepoManager {
 
 	fun getNEUItem(skyblockId: SkyblockId): NEUItem? = neuRepo.items.getItemBySkyblockId(skyblockId.neuItem)
 
+	fun downloadOverridenBranch(branch: String) {
+		Firmament.coroutineScope.launch {
+			RepoDownloadManager.downloadUpdate(true, branch)
+			reload()
+		}
+	}
+
 	fun launchAsyncUpdate(force: Boolean = false) {
 		Firmament.coroutineScope.launch {
 			RepoDownloadManager.downloadUpdate(force)
-- 
cgit