Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/main/scala/com/supercoder/Main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,19 @@ package com.supercoder

import com.supercoder.ui.TerminalChat
import com.supercoder.agents.CoderAgent
import com.supercoder.config.ArgsParser
import com.supercoder.config.{ArgsParser, Config}
import com.supercoder.lib.CursorRulesLoader

object Main {
var AppConfig: Config = Config()

def main(args: Array[String]): Unit = {
ArgsParser.parse(args) match {
case Some(config) =>
val additionalPrompt = if config.useCursorRules then CursorRulesLoader.loadRules() else ""
val agent = new CoderAgent(additionalPrompt)
AppConfig = config
val additionalPrompt = if AppConfig.useCursorRules then CursorRulesLoader.loadRules() else ""
val modelName = AppConfig.model
val agent = new CoderAgent(additionalPrompt, modelName)
TerminalChat.run(agent)
case None =>
// invalid options, usage error message is already printed by scopt
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/com/supercoder/agents/CoderAgent.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ The discussion is about the code of the current project/folder. Always use the r
project if you are unsure before giving answer.
"""

class CoderAgent(additionalPrompt: String = "")
extends BaseChatAgent(coderAgentPrompt + additionalPrompt) {
class CoderAgent(additionalPrompt: String = "", model: String = "")
extends BaseChatAgent(coderAgentPrompt + additionalPrompt, model) {

final val availableTools = List(
CodeSearchTool,
Expand Down
166 changes: 132 additions & 34 deletions src/main/scala/com/supercoder/base/Agent.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,35 @@ package com.supercoder.base
import com.openai.client.okhttp.OpenAIOkHttpClient
import com.openai.core.http.Headers
import com.openai.models.*

import java.util
import java.util.Optional
import scala.collection.mutable.ListBuffer
import com.supercoder.Main
import com.supercoder.Main.AppConfig
import com.supercoder.lib.Console.{blue, red}
import io.circe.*
import io.circe.generic.auto.*
import io.circe.parser.*

import java.util
import java.util.Optional
import scala.collection.mutable.ListBuffer

val BasePrompt = s"""
# Tool calling
For each function call, return a json object with function name and arguments within <@TOOL></@TOOL> XML tags:

<@TOOL>{"name": <function-name>, "arguments": "<json-encoded-string-of-the-arguments>"}</@TOOL>
<@TOOL>
{"name": <function-name>, "arguments": "<json-encoded-string-of-the-arguments>"}
</@TOOL>

The arguments value is ALWAYS a JSON-encoded string, when there is no arguments, use empty string "".

For example:
<@TOOL>{"name": "file-read", "arguments": "{\"fileName\": \"example.txt\"}"}</@TOOL>
<@TOOL>
{"name": "file-read", "arguments": "{\"fileName\": \"example.txt\"}"}
</@TOOL>

<@TOOL>{"name": "project-structure", "arguments": ""}</@TOOL>
<@TOOL>
{"name": "project-structure", "arguments": ""}
</@TOOL>

The client will response with <@TOOL-RESULT>[content]</@TOOL-RESULT> XML tags to provide the result of the function call.
Use it to continue the conversation with the user.
Expand Down Expand Up @@ -66,7 +74,7 @@ case class ToolCallDescription(

}

abstract class BaseChatAgent(prompt: String) {
abstract class BaseChatAgent(prompt: String, model: String = AgentConfig.OpenAIModel) {
private val client = OpenAIOkHttpClient.builder()
.baseUrl(AgentConfig.OpenAIAPIBaseURL)
.apiKey(AgentConfig.OpenAIAPIKey)
Expand All @@ -79,6 +87,8 @@ abstract class BaseChatAgent(prompt: String) {
private var chatHistory: ListBuffer[ChatCompletionMessageParam] =
ListBuffer.empty

def selectedModel: String = if (model.nonEmpty) model else AgentConfig.OpenAIModel

def toolExecution(toolCall: ToolCallDescription): String
def toolDefinitionList: List[FunctionDefinition]

Expand Down Expand Up @@ -106,7 +116,7 @@ abstract class BaseChatAgent(prompt: String) {
val params = ChatCompletionCreateParams
.builder()
.addSystemMessage(BasePrompt + prompt)
.model(AgentConfig.OpenAIModel)
.model(selectedModel)

// Add all messages from chat history
chatHistory.foreach(params.addMessage)
Expand Down Expand Up @@ -144,48 +154,109 @@ abstract class BaseChatAgent(prompt: String) {
val it = streamResponse.stream().iterator()
streamingStarted = true
val wordBuffer = new StringBuilder()
var isHiddenTokens = false
var isInToolTag = false

while(it.hasNext && !cancelStreaming) {
val chunk = it.next()
val delta = chunk.choices.getFirst.delta

if (delta.content().isPresent) {
val chunkContent = delta.content().get()
currentMessageBuilder.append(chunkContent)
wordBuffer.append(chunkContent)
val bufferContent = wordBuffer.toString()
if (bufferContent.contains(" ")) {
val words = bufferContent.split(" ")
val endsWithSpace = bufferContent.last.isWhitespace
val completeWords = if (endsWithSpace) words else words.dropRight(1)
for (word <- completeWords) {
if (word.contains("<@TOOL>")) {
isHiddenTokens = true
val content = delta.content().get()
wordBuffer.append(content)
currentMessageBuilder.append(content)

val toolStart = "<@TOOL>"
val toolEnd = "</@TOOL>"
val toolResultStart = "<@TOOL-RESULT>"
val toolResultEnd = "</@TOOL-RESULT>"

var currentToolTagEndMarker: Option[String] = None

var processedSomething = true
while (processedSomething && wordBuffer.nonEmpty) {
processedSomething = false

if (isInToolTag) {
val endMarker = currentToolTagEndMarker.getOrElse(toolEnd)
val endTagIndex = wordBuffer.indexOf(endMarker)
if (endTagIndex != -1) {
val contentToConsume = wordBuffer.substring(0, endTagIndex + endMarker.length)
if (AppConfig.isDebugMode) {
print(red(contentToConsume))
}
wordBuffer.delete(0, contentToConsume.length)
isInToolTag = false
currentToolTagEndMarker = None
processedSomething = true
} else {
if (AppConfig.isDebugMode) {
print(red(wordBuffer.toString()))
}
wordBuffer.clear()
}
if (word.contains("</@TOOL>")) {
isHiddenTokens = false
} else {
val toolStartIndex = wordBuffer.indexOf(toolStart)
val toolResultStartIndex = wordBuffer.indexOf(toolResultStart)

var startTagIndex = -1
var startMarker = ""
var expectedEndMarker = ""

if (toolStartIndex != -1 && (toolResultStartIndex == -1 || toolStartIndex < toolResultStartIndex)) {
startTagIndex = toolStartIndex
startMarker = toolStart
expectedEndMarker = toolEnd
} else if (toolResultStartIndex != -1) {
startTagIndex = toolResultStartIndex
startMarker = toolResultStart
expectedEndMarker = toolResultEnd
}
if (!isHiddenTokens) {
print(blue(word + " "))

if (startTagIndex != -1) {
val beforeTag = wordBuffer.substring(0, startTagIndex)
if (beforeTag.nonEmpty) {
val (words, remaining) = processWords(beforeTag)
if (words.nonEmpty) {
words.foreach { case (word, ws) => print(blue(word)); print(ws) }
wordBuffer.delete(0, beforeTag.length - remaining.length)
processedSomething = true
}
}

if (wordBuffer.indexOf(startMarker) == 0) {
if (AppConfig.isDebugMode) {
print(red(startMarker))
}
wordBuffer.delete(0, startMarker.length)
isInToolTag = true
currentToolTagEndMarker = Some(expectedEndMarker)
processedSomething = true
}

} else {
val (words, remaining) = processWords(wordBuffer.toString())
if (words.nonEmpty) { // Only process if complete words were found
words.foreach { case (word, ws) => print(blue(word)); print(ws) }
val processedLength = wordBuffer.length() - remaining.length()
wordBuffer.delete(0, processedLength)
processedSomething = true // Buffer content changed
}
}
}
wordBuffer.clear()
if (!endsWithSpace && words.nonEmpty) {
wordBuffer.append(words.last)
}
}
}
}

// Print out the rest of the word buffer if it has any content
if (wordBuffer.nonEmpty) {
val remainingContent = wordBuffer.toString()
if (remainingContent.nonEmpty) {
if (!isHiddenTokens) {
println(blue(remainingContent))
if (isInToolTag) {
if (AppConfig.isDebugMode) {
print(red(wordBuffer.toString()))
}
currentMessageBuilder.append(remainingContent)
} else {
print(blue(wordBuffer.toString()))
}
wordBuffer.clear()
}

if (cancelStreaming) {
Expand Down Expand Up @@ -227,6 +298,33 @@ abstract class BaseChatAgent(prompt: String) {
}
}

// Helper function to process words and whitespace
private def processWords(text: String): (ListBuffer[(String, String)], String) = {
val words = ListBuffer[(String, String)]()
var remainingText = text
var continueProcessing = true

while (continueProcessing) {
val whitespaceIndex = remainingText.indexWhere(_.isWhitespace)
if (whitespaceIndex != -1) {
val word = remainingText.substring(0, whitespaceIndex)
val whitespace = remainingText.substring(whitespaceIndex).takeWhile(_.isWhitespace)
if (word.nonEmpty) {
words += ((word, whitespace))
} else {
// Handle leading whitespace? For now, just consume it with the next word or as trailing.
// If printing just whitespace: print(whitespace)
}
remainingText = remainingText.substring(whitespaceIndex + whitespace.length)
if (remainingText.isEmpty) continueProcessing = false
} else {
// No more whitespace, the rest is a partial word or empty
continueProcessing = false
}
}
(words, remainingText) // Return processed words and any remaining partial word
}

private def handleToolCall(toolCall: ToolCallDescription): Unit = {
val toolResult = toolExecution(toolCall)

Expand Down
8 changes: 7 additions & 1 deletion src/main/scala/com/supercoder/config/ArgsParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package com.supercoder.config

import scopt.OParser

case class Config(useCursorRules: Boolean = false)
case class Config(useCursorRules: Boolean = false, model: String = "", isDebugMode: Boolean = false)

object ArgsParser {
def parse(args: Array[String]): Option[Config] = {
Expand All @@ -14,6 +14,12 @@ object ArgsParser {
opt[String]('c', "use-cursor-rules")
.action((x, c) => c.copy(useCursorRules = (x == "true")))
.text("use Cursor rules for the agent"),
opt[String]('m', "model")
.action((x, c) => c.copy(model = x))
.text("model to use for the agent"),
opt[String]('d', "debug")
.action((x, c) => c.copy(isDebugMode = (x == "true")))
.text("enable debug mode"),
help("help").text("prints this usage text")
)
}
Expand Down
7 changes: 4 additions & 3 deletions src/main/scala/com/supercoder/ui/TerminalChat.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@ object TerminalChat {
print("\u001b[H")
}

def printHeader(): Unit = {
def printHeader(agent: BaseChatAgent): Unit = {
clearScreen()
println(blue("█▀ █░█ █▀█ █▀▀ █▀█ █▀▀ █▀█ █▀▄ █▀▀ █▀█"))
println(blue("▄█ █▄█ █▀▀ ██▄ █▀▄ █▄▄ █▄█ █▄▀ ██▄ █▀▄"))
println(blue(s"v${BuildInfo.version}"))
println()
println(blue(s"Model: ${agent.selectedModel}"))
println(blue("Type '/help' for available commands.\n"))
}

Expand All @@ -34,7 +35,7 @@ object TerminalChat {
}

def run(agent: BaseChatAgent): Unit = {
printHeader()
printHeader(agent)
val terminal: Terminal = TerminalBuilder.builder().system(true).build()
val reader: LineReader = LineReaderBuilder.builder().terminal(terminal).build()

Expand Down Expand Up @@ -69,7 +70,7 @@ object TerminalChat {
case "/help" => showHelp()
case "/clear" =>
clearScreen()
printHeader()
printHeader(agent)
case "exit" | "bye" =>
println(blue("\nChat session terminated. Goodbye!"))
keepRunning = false
Expand Down
Loading