diff --git a/src/main/scala/com/supercoder/Main.scala b/src/main/scala/com/supercoder/Main.scala
index 46924ea..36eedff 100644
--- a/src/main/scala/com/supercoder/Main.scala
+++ b/src/main/scala/com/supercoder/Main.scala
@@ -2,15 +2,20 @@ package com.supercoder
 
 import com.supercoder.ui.TerminalChat
 import com.supercoder.agents.CoderAgent
-import com.supercoder.config.ArgsParser
+import com.supercoder.config.{ArgsParser, Config}
 import com.supercoder.lib.CursorRulesLoader
 
 object Main {
+  // Parsed command-line options; replaced in main() before the agent is created.
+  var AppConfig: Config = Config()
+
   def main(args: Array[String]): Unit = {
     ArgsParser.parse(args) match {
       case Some(config) =>
-        val additionalPrompt = if config.useCursorRules then CursorRulesLoader.loadRules() else ""
-        val agent = new CoderAgent(additionalPrompt)
+        AppConfig = config
+        val additionalPrompt = if AppConfig.useCursorRules then CursorRulesLoader.loadRules() else ""
+        val modelName = AppConfig.model
+        val agent = new CoderAgent(additionalPrompt, modelName)
         TerminalChat.run(agent)
       case None =>
         // invalid options, usage error message is already printed by scopt
diff --git a/src/main/scala/com/supercoder/agents/CoderAgent.scala b/src/main/scala/com/supercoder/agents/CoderAgent.scala
index a11a333..a65d8ba 100644
--- a/src/main/scala/com/supercoder/agents/CoderAgent.scala
+++ b/src/main/scala/com/supercoder/agents/CoderAgent.scala
@@ -21,8 +21,8 @@ The discussion is about the code of the current project/folder. Always use the r
 project if you are unsure before giving answer.
 """
 
-class CoderAgent(additionalPrompt: String = "")
-    extends BaseChatAgent(coderAgentPrompt + additionalPrompt) {
+class CoderAgent(additionalPrompt: String = "", model: String = "")
+    extends BaseChatAgent(coderAgentPrompt + additionalPrompt, model) {
 
   final val availableTools = List(
     CodeSearchTool,
diff --git a/src/main/scala/com/supercoder/base/Agent.scala b/src/main/scala/com/supercoder/base/Agent.scala
index 9e99d0f..b27c38b 100644
--- a/src/main/scala/com/supercoder/base/Agent.scala
+++ b/src/main/scala/com/supercoder/base/Agent.scala
@@ -3,27 +3,35 @@ package com.supercoder.base
 import com.openai.client.okhttp.OpenAIOkHttpClient
 import com.openai.core.http.Headers
 import com.openai.models.*
-
-import java.util
-import java.util.Optional
-import scala.collection.mutable.ListBuffer
+import com.supercoder.Main
+import com.supercoder.Main.AppConfig
 import com.supercoder.lib.Console.{blue, red}
 import io.circe.*
 import io.circe.generic.auto.*
 import io.circe.parser.*
 
+import java.util
+import java.util.Optional
+import scala.collection.mutable.ListBuffer
+
 val BasePrompt = s"""
 # Tool calling
 
 For each function call, return a json object with function name and arguments within <@TOOL> XML tags:
-<@TOOL>{"name": , "arguments": ""}
+<@TOOL>
+{"name": , "arguments": ""}
+</@TOOL>
 
 The arguments value is ALWAYS a JSON-encoded string, when there is no arguments, use empty string "".
 
 For example:
-<@TOOL>{"name": "file-read", "arguments": "{\"fileName\": \"example.txt\"}"}
+<@TOOL>
+{"name": "file-read", "arguments": "{\"fileName\": \"example.txt\"}"}
+</@TOOL>
 
-<@TOOL>{"name": "project-structure", "arguments": ""}
+<@TOOL>
+{"name": "project-structure", "arguments": ""}
+</@TOOL>
 
 The client will response with <@TOOL-RESULT>[content] XML tags to provide the result of the function call.
 Use it to continue the conversation with the user.
@@ -66,7 +74,7 @@ case class ToolCallDescription(
 
 }
 
-abstract class BaseChatAgent(prompt: String) {
+abstract class BaseChatAgent(prompt: String, model: String = AgentConfig.OpenAIModel) {
   private val client = OpenAIOkHttpClient.builder()
     .baseUrl(AgentConfig.OpenAIAPIBaseURL)
     .apiKey(AgentConfig.OpenAIAPIKey)
@@ -79,6 +87,9 @@ abstract class BaseChatAgent(prompt: String) {
 
   private var chatHistory: ListBuffer[ChatCompletionMessageParam] = ListBuffer.empty
 
+  /** Model sent with each request: `model` when it is non-empty, otherwise AgentConfig.OpenAIModel. */
+  def selectedModel: String = if (model.nonEmpty) model else AgentConfig.OpenAIModel
+
   def toolExecution(toolCall: ToolCallDescription): String
 
   def toolDefinitionList: List[FunctionDefinition]
@@ -106,7 +117,7 @@ abstract class BaseChatAgent(prompt: String) {
     val params = ChatCompletionCreateParams
       .builder()
       .addSystemMessage(BasePrompt + prompt)
-      .model(AgentConfig.OpenAIModel)
+      .model(selectedModel)
 
     // Add all messages from chat history
     chatHistory.foreach(params.addMessage)
@@ -144,48 +155,90 @@ abstract class BaseChatAgent(prompt: String) {
       val it = streamResponse.stream().iterator()
       streamingStarted = true
       val wordBuffer = new StringBuilder()
-      var isHiddenTokens = false
+      // End marker of the tag whose content is currently hidden; None while visible text is
+      // streamed. Kept outside the chunk loop so a tag spanning several chunks is still
+      // closed by its own marker.
+      var hiddenUntil: Option[String] = None
+      // Start marker -> end marker of every tag that is hidden from the user.
+      val hiddenTags = List("<@TOOL>" -> "</@TOOL>", "<@TOOL-RESULT>" -> "</@TOOL-RESULT>")
 
       while(it.hasNext && !cancelStreaming) {
         val chunk = it.next()
         val delta = chunk.choices.getFirst.delta
 
         if (delta.content().isPresent) {
-          val chunkContent = delta.content().get()
-          currentMessageBuilder.append(chunkContent)
-          wordBuffer.append(chunkContent)
-          val bufferContent = wordBuffer.toString()
-          if (bufferContent.contains(" ")) {
-            val words = bufferContent.split(" ")
-            val endsWithSpace = bufferContent.last.isWhitespace
-            val completeWords = if (endsWithSpace) words else words.dropRight(1)
-            for (word <- completeWords) {
-              if (word.contains("<@TOOL>")) {
-                isHiddenTokens = true
-              }
-              if (word.contains("")) {
-                isHiddenTokens = false
-              }
-              if (!isHiddenTokens) {
-                print(blue(word + " "))
-              }
-            }
-            wordBuffer.clear()
-            if (!endsWithSpace && words.nonEmpty) {
-              wordBuffer.append(words.last)
-            }
-          }
+          val content = delta.content().get()
+          wordBuffer.append(content)
+          currentMessageBuilder.append(content)
+
+          var processedSomething = true
+          while (processedSomething && wordBuffer.nonEmpty) {
+            processedSomething = false
+
+            hiddenUntil match {
+              case Some(endMarker) =>
+                val endTagIndex = wordBuffer.indexOf(endMarker)
+                // Until the end marker is complete, keep back a tail that may be its
+                // beginning, so a marker split across two chunks is still recognised.
+                val consumeLength =
+                  if (endTagIndex != -1) endTagIndex + endMarker.length
+                  else math.max(0, wordBuffer.length - (endMarker.length - 1))
+                if (consumeLength > 0) {
+                  if (AppConfig.isDebugMode) {
+                    print(red(wordBuffer.substring(0, consumeLength)))
+                  }
+                  wordBuffer.delete(0, consumeLength)
+                }
+                if (endTagIndex != -1) {
+                  hiddenUntil = None
+                  processedSomething = true
+                }
+
+              case None =>
+                // Earliest start tag present in the buffer, if any.
+                val nextTag = hiddenTags
+                  .map { case (startTag, endTag) => (wordBuffer.indexOf(startTag), startTag, endTag) }
+                  .filter(_._1 != -1)
+                  .minByOption(_._1)
+
+                nextTag match {
+                  case Some((tagIndex, startTag, endTag)) =>
+                    // Text in front of a complete start tag is final: print all of it,
+                    // even when no whitespace separates it from the tag.
+                    if (tagIndex > 0) {
+                      print(blue(wordBuffer.substring(0, tagIndex)))
+                    }
+                    if (AppConfig.isDebugMode) {
+                      print(red(startTag))
+                    }
+                    wordBuffer.delete(0, tagIndex + startTag.length)
+                    hiddenUntil = Some(endTag)
+                    processedSomething = true
+
+                  case None =>
+                    // No tag yet: print the complete words, keep a trailing partial word.
+                    val (words, remaining) = processWords(wordBuffer.toString())
+                    if (words.nonEmpty) {
+                      words.foreach { case (word, ws) => print(blue(word)); print(ws) }
+                      wordBuffer.delete(0, wordBuffer.length - remaining.length)
+                      processedSomething = true
+                    }
+                }
+            }
+          }
         }
       }
 
+      // Print out the rest of the word buffer if it has any content
       if (wordBuffer.nonEmpty) {
-        val remainingContent = wordBuffer.toString()
-        if (remainingContent.nonEmpty) {
-          if (!isHiddenTokens) {
-            println(blue(remainingContent))
-          }
-          currentMessageBuilder.append(remainingContent)
-        }
+        if (hiddenUntil.isDefined) {
+          if (AppConfig.isDebugMode) {
+            print(red(wordBuffer.toString()))
+          }
+        } else {
+          print(blue(wordBuffer.toString()))
+        }
+        wordBuffer.clear()
       }
 
       if (cancelStreaming) {
@@ -227,6 +280,32 @@ abstract class BaseChatAgent(prompt: String) {
     }
   }
 
+  /**
+   * Splits `text` into complete words, each paired with the whitespace run that follows it.
+   *
+   * Whitespace at the very start of `text` is returned as a pair with an empty word, so
+   * line breaks and indentation in the streamed answer are not lost.
+   *
+   * @param text buffered text that has not been printed yet
+   * @return the (word, whitespace) pairs in order, and the trailing text that is not yet
+   *         followed by whitespace (a possibly incomplete word, or "")
+   */
+  private def processWords(text: String): (ListBuffer[(String, String)], String) = {
+    val words = ListBuffer[(String, String)]()
+    var remainingText = text
+    var whitespaceIndex = remainingText.indexWhere(_.isWhitespace)
+
+    while (whitespaceIndex != -1) {
+      val word = remainingText.substring(0, whitespaceIndex)
+      val whitespace = remainingText.substring(whitespaceIndex).takeWhile(_.isWhitespace)
+      words += ((word, whitespace))
+      remainingText = remainingText.substring(whitespaceIndex + whitespace.length)
+      whitespaceIndex = remainingText.indexWhere(_.isWhitespace)
+    }
+
+    (words, remainingText)
+  }
+
   private def handleToolCall(toolCall: ToolCallDescription): Unit = {
     val toolResult = toolExecution(toolCall)
 
diff --git a/src/main/scala/com/supercoder/config/ArgsParser.scala b/src/main/scala/com/supercoder/config/ArgsParser.scala
index 6b12568..d22377e 100644
--- a/src/main/scala/com/supercoder/config/ArgsParser.scala
+++ b/src/main/scala/com/supercoder/config/ArgsParser.scala
@@ -2,7 +2,7 @@ package com.supercoder.config
 
 import scopt.OParser
 
-case class Config(useCursorRules: Boolean = false)
+case class Config(useCursorRules: Boolean = false, model: String = "", isDebugMode: Boolean = false)
 
 object ArgsParser {
   def parse(args: Array[String]): Option[Config] = {
@@ -14,6 +14,12 @@ object ArgsParser {
         opt[String]('c', "use-cursor-rules")
           .action((x, c) => c.copy(useCursorRules = (x == "true")))
           .text("use Cursor rules for the agent"),
+        opt[String]('m', "model")
+          .action((x, c) => c.copy(model = x))
+          .text("model to use for the agent"),
+        opt[String]('d', "debug")
+          .action((x, c) => c.copy(isDebugMode = (x == "true")))
+          .text("enable debug mode"),
         help("help").text("prints this usage text")
       )
     }
diff --git a/src/main/scala/com/supercoder/ui/TerminalChat.scala b/src/main/scala/com/supercoder/ui/TerminalChat.scala
index 1bf2023..df79bf3 100644
--- a/src/main/scala/com/supercoder/ui/TerminalChat.scala
+++ b/src/main/scala/com/supercoder/ui/TerminalChat.scala
@@ -14,12 +14,13 @@ object TerminalChat {
     print("\u001b[H")
   }
 
-  def printHeader(): Unit = {
+  def printHeader(agent: BaseChatAgent): Unit = {
     clearScreen()
     println(blue("█▀ █░█ █▀█ █▀▀ █▀█ █▀▀ █▀█ █▀▄ █▀▀ █▀█"))
     println(blue("▄█ █▄█ █▀▀ ██▄ █▀▄ █▄▄ █▄█ █▄▀ ██▄ █▀▄"))
     println(blue(s"v${BuildInfo.version}"))
     println()
+    println(blue(s"Model: ${agent.selectedModel}"))
     println(blue("Type '/help' for available commands.\n"))
   }
 
@@ -34,7 +35,7 @@ object TerminalChat {
   }
 
   def run(agent: BaseChatAgent): Unit = {
-    printHeader()
+    printHeader(agent)
 
     val terminal: Terminal = TerminalBuilder.builder().system(true).build()
     val reader: LineReader = LineReaderBuilder.builder().terminal(terminal).build()
@@ -69,7 +70,7 @@ object TerminalChat {
           case "/help" => showHelp()
           case "/clear" =>
             clearScreen()
-            printHeader()
+            printHeader(agent)
           case "exit" | "bye" =>
             println(blue("\nChat session terminated. Goodbye!"))
             keepRunning = false