Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 110 additions & 91 deletions src/main/scala/com/supercoder/base/Agent.scala
Original file line number Diff line number Diff line change
@@ -1,12 +1,49 @@
package com.supercoder.base

import com.openai.client.okhttp.OpenAIOkHttpClient
import com.openai.core.http.Headers
import com.openai.models.*

import java.util
import java.util.Optional
import scala.collection.mutable.ListBuffer
import com.supercoder.lib.Console.blue
import com.supercoder.lib.Console.{blue, red}
import io.circe.*
import io.circe.generic.auto.*
import io.circe.parser.*

val BasePrompt = s"""
# Tool calling
For each function call, return a json object with function name and arguments within <@TOOL></@TOOL> XML tags:

<@TOOL>
{"name": <function-name>, "arguments": "<json-encoded-string-of-the-arguments>"}
</@TOOL>

The arguments value is ALWAYS a JSON-encoded string, when there is no arguments, use empty string "".

For example:
<@TOOL>
{"name": "file-read", "arguments": "{\"fileName\": \"example.txt\"}"}
</@TOOL>

<@TOOL>
{"name": "project-structure", "arguments": ""}
</@TOOL>

The client will response with <@TOOL-RESULT>[content]</@TOOL-RESULT> XML tags to provide the result of the function call.
Use it to continue the conversation with the user.

# Response format
When responding to the user, use plain text format. NEVER use Markdown's bold or italic formatting.

# Safety
Please refuse to answer any unsafe or unethical requests.
Do not execute any command that could harm the system or access sensitive information.
When you want to execute some potentially unsafe command, please ask for user confirmation first before generating the tool call instruction.

# Agent Instructions
"""

object AgentConfig {
val OpenAIAPIBaseURL: String = sys.env.get("SUPERCODER_BASE_URL")
Expand All @@ -20,14 +57,11 @@ object AgentConfig {
val OpenAIAPIKey: String = sys.env.get("SUPERCODER_API_KEY")
.orElse(sys.env.get("OPENAI_API_KEY"))
.getOrElse(throw new RuntimeException("You need to config SUPERCODER_API_KEY or OPENAI_API_KEY variable"))

val IsGeminiMode: String = sys.env.get("SUPERCODER_GEMINI_MODE").getOrElse("false").toLowerCase
}

case class ToolCallDescription(
name: String = "",
arguments: String = "",
id: String = ""
) {

def addName(name: Optional[String]): ToolCallDescription =
Expand All @@ -36,15 +70,16 @@ case class ToolCallDescription(
def addArguments(arguments: Optional[String]): ToolCallDescription =
copy(arguments = this.arguments + arguments.orElse(""))

def addId(id: Optional[String]): ToolCallDescription =
copy(id = this.id + id.orElse(""))

}

abstract class BaseChatAgent(prompt: String) {
private val client = OpenAIOkHttpClient.builder()
.baseUrl(AgentConfig.OpenAIAPIBaseURL)
.apiKey(AgentConfig.OpenAIAPIKey)
.headers(Headers.builder()
.put("HTTP-Referer", "https://github.com/huytd/supercoder/")
.put("X-Title", "SuperCoder")
.build())
.build()

private var chatHistory: ListBuffer[ChatCompletionMessageParam] =
Expand Down Expand Up @@ -72,49 +107,11 @@ abstract class BaseChatAgent(prompt: String) {
.builder()
.content(content)

private def createAssistantToolCallMessage(
toolCall: ToolCallDescription
): Unit = {
var messageBuilder = createAssistantMessageBuilder("")
messageBuilder.addToolCall(
ChatCompletionMessageToolCall
.builder()
.id(toolCall.id)
.function(
ChatCompletionMessageToolCall.Function
.builder()
.name(toolCall.name)
.arguments(toolCall.arguments)
.build()
)
.build()
)

addMessageToHistory(
ChatCompletionMessageParam.ofAssistant(messageBuilder.build())
)
}

private def createToolResponseMessage(
result: String,
toolCallId: String
): ChatCompletionMessageParam = {
val toolResponse = ChatCompletionMessageParam.ofTool(
ChatCompletionToolMessageParam
.builder()
.content(result)
.toolCallId(toolCallId)
.build()
)

toolResponse
}

// Helper method to build base parameters with system prompt and chat history
private def buildBaseParams(): ChatCompletionCreateParams.Builder = {
val params = ChatCompletionCreateParams
.builder()
.addSystemMessage(prompt)
.addSystemMessage(BasePrompt + prompt)
.model(AgentConfig.OpenAIModel)

// Add all messages from chat history
Expand All @@ -132,23 +129,11 @@ abstract class BaseChatAgent(prompt: String) {
)
}

// Build parameters with tool definition
var params = buildBaseParams()
toolDefinitionList.foreach(tool =>
params.addTool(
ChatCompletionTool
.builder()
.function(tool)
.build()
)
)

// Stream the response with support for cancelling using Ctrl+C
val streamResponse = client.chat().completions().createStreaming(params.build())
var currentMessageBuilder = new StringBuilder()
val params = buildBaseParams().build()
val streamResponse = client.chat().completions().createStreaming(params)
val currentMessageBuilder = new StringBuilder()
var currentToolCall = ToolCallDescription()

// Set up a SIGINT handler to cancel the streaming response only after streaming starts
import sun.misc.{Signal, SignalHandler}
var cancelStreaming = false
var streamingStarted = false
Expand All @@ -164,25 +149,51 @@ abstract class BaseChatAgent(prompt: String) {
try {
val it = streamResponse.stream().iterator()
streamingStarted = true
val wordBuffer = new StringBuilder()
var isHiddenTokens = false

while(it.hasNext && !cancelStreaming) {
val chunk = it.next()
val delta = chunk.choices.getFirst.delta
if (delta.toolCalls().isPresent && !delta.toolCalls().get().isEmpty) {
val toolCall = delta.toolCalls().get().getFirst
if (toolCall.function().isPresent) {
val toolFunction = toolCall.function().get()
currentToolCall = currentToolCall
.addName(toolFunction.name())
.addArguments(toolFunction.arguments())
.addId(toolCall.id())
}
}

if (delta.content().isPresent) {
val chunkContent = delta.content().get()
currentMessageBuilder.append(chunkContent)
print(blue(chunkContent))
wordBuffer.append(chunkContent)
val bufferContent = wordBuffer.toString()
if (bufferContent.contains(" ")) {
val words = bufferContent.split(" ")
val endsWithSpace = bufferContent.last.isWhitespace
val completeWords = if (endsWithSpace) words else words.dropRight(1)
for (word <- completeWords) {
if (word.contains("<@TOOL>")) {
isHiddenTokens = true
}
if (word.contains("</@TOOL>")) {
isHiddenTokens = false
}
if (!isHiddenTokens) {
print(blue(word + " "))
}
}
wordBuffer.clear()
if (!endsWithSpace && words.nonEmpty) {
wordBuffer.append(words.last)
}
}
}
}

if (wordBuffer.nonEmpty) {
val remainingContent = wordBuffer.toString()
if (remainingContent.nonEmpty) {
if (!isHiddenTokens) {
println(blue(remainingContent))
}
currentMessageBuilder.append(remainingContent)
}
}

if (cancelStreaming) {
println(blue("\nStreaming cancelled by user"))
}
Expand All @@ -194,14 +205,29 @@ abstract class BaseChatAgent(prompt: String) {
streamResponse.close()
if (currentMessageBuilder.nonEmpty) {
println()
val messageContent = currentMessageBuilder.toString()
addMessageToHistory(
ChatCompletionMessageParam.ofAssistant(
createAssistantMessageBuilder(currentMessageBuilder.toString())
createAssistantMessageBuilder(messageContent)
.build()
)
)

// Check if the message contains a tool call
val toolCallRegex = """(?s)<@TOOL>(.*?)</@TOOL>""".r
val toolCallMatch = toolCallRegex.findFirstMatchIn(messageContent).map(_.group(1))
if (toolCallMatch.isDefined) {
val toolCallJson = toolCallMatch.get
try {
val parseResult: Either[Error, ToolCallDescription] = decode[ToolCallDescription](toolCallJson)
currentToolCall = parseResult.getOrElse(ToolCallDescription())
} catch {
case e: Exception =>
println(red(s"Error parsing tool call: ${e.getMessage}"))
}
}
}
if (currentToolCall.id.nonEmpty || currentToolCall.name.nonEmpty) {
if (currentToolCall.name.nonEmpty) {
handleToolCall(currentToolCall)
}
}
Expand All @@ -210,26 +236,19 @@ abstract class BaseChatAgent(prompt: String) {
private def handleToolCall(toolCall: ToolCallDescription): Unit = {
val toolResult = toolExecution(toolCall)

if (AgentConfig.IsGeminiMode != "true") {
// Add the assistant's tool call message to chat history
createAssistantToolCallMessage(toolCall)
// Add result to chat history
addMessageToHistory(createToolResponseMessage(toolResult, toolCall.id))
} else {
// Add the result as assistant's message
addMessageToHistory(
ChatCompletionMessageParam.ofAssistant(
createAssistantMessageBuilder(s"I will need to use the ${toolCall.name} tool...").build()
)
// Add the result as assistant's message
addMessageToHistory(
ChatCompletionMessageParam.ofAssistant(
createAssistantMessageBuilder(s"Calling ${toolCall.name} tool...").build()
)
addMessageToHistory(
ChatCompletionMessageParam.ofUser(
createUserMessageBuilder(s"Here's the tool call result: ${toolResult}").build()
)
)
addMessageToHistory(
ChatCompletionMessageParam.ofUser(
createUserMessageBuilder(s"<@TOOL-RESULT>${toolResult}</@TOOL-RESULT>").build()
)
}
)

// Trigger follow up response from assistant
// Trigger follow-up response from assistant
chat("")
}

Expand Down
28 changes: 2 additions & 26 deletions src/main/scala/com/supercoder/tools/CodeEditTool.scala
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
package com.supercoder.tools

import com.openai.core.JsonValue
import com.openai.models.{FunctionDefinition, FunctionParameters}
import com.openai.models.FunctionDefinition
import com.supercoder.base.Tool
import com.supercoder.lib.Console.green
import io.circe.*
import io.circe.generic.auto.*
import io.circe.parser.*

import java.io.{File, PrintWriter}
import java.nio.file.{Files, Paths}
import java.util

case class CodeEditToolArguments(filepath: String, content: String)

Expand All @@ -20,28 +17,7 @@ object CodeEditTool extends Tool {
.builder()
.name("code-edit")
.description(
"Edit a code file in the repository. Provide the file path and the new content for the file."
)
.parameters(
FunctionParameters
.builder()
.putAdditionalProperty("type", JsonValue.from("object"))
.putAdditionalProperty(
"properties",
JsonValue.from(
util.Map.of(
"filepath",
util.Map.of("type", "string"),
"content",
util.Map.of("type", "string")
)
)
)
.putAdditionalProperty(
"required",
JsonValue.from(util.List.of("filepath", "content"))
)
.build()
"Edit a code file in the repository. Provide the file path and the new content for the file. Arguments: {\"filepath\": \"<file-path>\", \"content\": \"<new-content>\"}"
)
.build()

Expand Down
16 changes: 2 additions & 14 deletions src/main/scala/com/supercoder/tools/CodeSearchTool.scala
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
package com.supercoder.tools

import com.openai.core.JsonValue
import com.openai.models.{FunctionDefinition, FunctionParameters}
import com.openai.models.FunctionDefinition
import com.supercoder.base.Tool
import com.supercoder.lib.Console.green
import io.circe.*
import io.circe.generic.auto.*
import io.circe.parser.*

import java.util
import scala.sys.process.*

case class CodeSearchToolArguments(query: String)
Expand All @@ -19,17 +17,7 @@ object CodeSearchTool extends Tool {
.builder()
.name("code-search")
.description(
"Search for code in a given repository. The query parameter should be a regular expression."
)
.parameters(
FunctionParameters
.builder()
.putAdditionalProperty("type", JsonValue.from("object"))
.putAdditionalProperty(
"properties",
JsonValue.from(util.Map.of("query", util.Map.of("type", "string")))
)
.build()
"Search for code in a given repository. The query parameter should be a regular expression. Arguments: {\"query\": \"<search-query>\"}"
)
.build()

Expand Down
Loading
Loading