Skip to content

Commit 9cb6608

Browse files
authored
feat: Support Tool calling for all models (#4)
1 parent 32c4250 commit 9cb6608

7 files changed

Lines changed: 122 additions & 173 deletions

File tree

src/main/scala/com/supercoder/base/Agent.scala

Lines changed: 110 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,49 @@
11
package com.supercoder.base
22

33
import com.openai.client.okhttp.OpenAIOkHttpClient
4+
import com.openai.core.http.Headers
45
import com.openai.models.*
56

67
import java.util
78
import java.util.Optional
89
import scala.collection.mutable.ListBuffer
9-
import com.supercoder.lib.Console.blue
10+
import com.supercoder.lib.Console.{blue, red}
11+
import io.circe.*
12+
import io.circe.generic.auto.*
13+
import io.circe.parser.*
14+
15+
val BasePrompt = s"""
16+
# Tool calling
17+
For each function call, return a json object with function name and arguments within <@TOOL></@TOOL> XML tags:
18+
19+
<@TOOL>
20+
{"name": <function-name>, "arguments": "<json-encoded-string-of-the-arguments>"}
21+
</@TOOL>
22+
23+
The arguments value is ALWAYS a JSON-encoded string, when there is no arguments, use empty string "".
24+
25+
For example:
26+
<@TOOL>
27+
{"name": "file-read", "arguments": "{\"fileName\": \"example.txt\"}"}
28+
</@TOOL>
29+
30+
<@TOOL>
31+
{"name": "project-structure", "arguments": ""}
32+
</@TOOL>
33+
34+
The client will response with <@TOOL-RESULT>[content]</@TOOL-RESULT> XML tags to provide the result of the function call.
35+
Use it to continue the conversation with the user.
36+
37+
# Response format
38+
When responding to the user, use plain text format. NEVER use Markdown's bold or italic formatting.
39+
40+
# Safety
41+
Please refuse to answer any unsafe or unethical requests.
42+
Do not execute any command that could harm the system or access sensitive information.
43+
When you want to execute some potentially unsafe command, please ask for user confirmation first before generating the tool call instruction.
44+
45+
# Agent Instructions
46+
"""
1047

1148
object AgentConfig {
1249
val OpenAIAPIBaseURL: String = sys.env.get("SUPERCODER_BASE_URL")
@@ -20,14 +57,11 @@ object AgentConfig {
2057
val OpenAIAPIKey: String = sys.env.get("SUPERCODER_API_KEY")
2158
.orElse(sys.env.get("OPENAI_API_KEY"))
2259
.getOrElse(throw new RuntimeException("You need to config SUPERCODER_API_KEY or OPENAI_API_KEY variable"))
23-
24-
val IsGeminiMode: String = sys.env.get("SUPERCODER_GEMINI_MODE").getOrElse("false").toLowerCase
2560
}
2661

2762
case class ToolCallDescription(
2863
name: String = "",
2964
arguments: String = "",
30-
id: String = ""
3165
) {
3266

3367
def addName(name: Optional[String]): ToolCallDescription =
@@ -36,15 +70,16 @@ case class ToolCallDescription(
3670
def addArguments(arguments: Optional[String]): ToolCallDescription =
3771
copy(arguments = this.arguments + arguments.orElse(""))
3872

39-
def addId(id: Optional[String]): ToolCallDescription =
40-
copy(id = this.id + id.orElse(""))
41-
4273
}
4374

4475
abstract class BaseChatAgent(prompt: String) {
4576
private val client = OpenAIOkHttpClient.builder()
4677
.baseUrl(AgentConfig.OpenAIAPIBaseURL)
4778
.apiKey(AgentConfig.OpenAIAPIKey)
79+
.headers(Headers.builder()
80+
.put("HTTP-Referer", "https://github.com/huytd/supercoder/")
81+
.put("X-Title", "SuperCoder")
82+
.build())
4883
.build()
4984

5085
private var chatHistory: ListBuffer[ChatCompletionMessageParam] =
@@ -72,49 +107,11 @@ abstract class BaseChatAgent(prompt: String) {
72107
.builder()
73108
.content(content)
74109

75-
private def createAssistantToolCallMessage(
76-
toolCall: ToolCallDescription
77-
): Unit = {
78-
var messageBuilder = createAssistantMessageBuilder("")
79-
messageBuilder.addToolCall(
80-
ChatCompletionMessageToolCall
81-
.builder()
82-
.id(toolCall.id)
83-
.function(
84-
ChatCompletionMessageToolCall.Function
85-
.builder()
86-
.name(toolCall.name)
87-
.arguments(toolCall.arguments)
88-
.build()
89-
)
90-
.build()
91-
)
92-
93-
addMessageToHistory(
94-
ChatCompletionMessageParam.ofAssistant(messageBuilder.build())
95-
)
96-
}
97-
98-
private def createToolResponseMessage(
99-
result: String,
100-
toolCallId: String
101-
): ChatCompletionMessageParam = {
102-
val toolResponse = ChatCompletionMessageParam.ofTool(
103-
ChatCompletionToolMessageParam
104-
.builder()
105-
.content(result)
106-
.toolCallId(toolCallId)
107-
.build()
108-
)
109-
110-
toolResponse
111-
}
112-
113110
// Helper method to build base parameters with system prompt and chat history
114111
private def buildBaseParams(): ChatCompletionCreateParams.Builder = {
115112
val params = ChatCompletionCreateParams
116113
.builder()
117-
.addSystemMessage(prompt)
114+
.addSystemMessage(BasePrompt + prompt)
118115
.model(AgentConfig.OpenAIModel)
119116

120117
// Add all messages from chat history
@@ -132,23 +129,11 @@ abstract class BaseChatAgent(prompt: String) {
132129
)
133130
}
134131

135-
// Build parameters with tool definition
136-
var params = buildBaseParams()
137-
toolDefinitionList.foreach(tool =>
138-
params.addTool(
139-
ChatCompletionTool
140-
.builder()
141-
.function(tool)
142-
.build()
143-
)
144-
)
145-
146-
// Stream the response with support for cancelling using Ctrl+C
147-
val streamResponse = client.chat().completions().createStreaming(params.build())
148-
var currentMessageBuilder = new StringBuilder()
132+
val params = buildBaseParams().build()
133+
val streamResponse = client.chat().completions().createStreaming(params)
134+
val currentMessageBuilder = new StringBuilder()
149135
var currentToolCall = ToolCallDescription()
150136

151-
// Set up a SIGINT handler to cancel the streaming response only after streaming starts
152137
import sun.misc.{Signal, SignalHandler}
153138
var cancelStreaming = false
154139
var streamingStarted = false
@@ -164,25 +149,51 @@ abstract class BaseChatAgent(prompt: String) {
164149
try {
165150
val it = streamResponse.stream().iterator()
166151
streamingStarted = true
152+
val wordBuffer = new StringBuilder()
153+
var isHiddenTokens = false
154+
167155
while(it.hasNext && !cancelStreaming) {
168156
val chunk = it.next()
169157
val delta = chunk.choices.getFirst.delta
170-
if (delta.toolCalls().isPresent && !delta.toolCalls().get().isEmpty) {
171-
val toolCall = delta.toolCalls().get().getFirst
172-
if (toolCall.function().isPresent) {
173-
val toolFunction = toolCall.function().get()
174-
currentToolCall = currentToolCall
175-
.addName(toolFunction.name())
176-
.addArguments(toolFunction.arguments())
177-
.addId(toolCall.id())
178-
}
179-
}
158+
180159
if (delta.content().isPresent) {
181160
val chunkContent = delta.content().get()
182161
currentMessageBuilder.append(chunkContent)
183-
print(blue(chunkContent))
162+
wordBuffer.append(chunkContent)
163+
val bufferContent = wordBuffer.toString()
164+
if (bufferContent.contains(" ")) {
165+
val words = bufferContent.split(" ")
166+
val endsWithSpace = bufferContent.last.isWhitespace
167+
val completeWords = if (endsWithSpace) words else words.dropRight(1)
168+
for (word <- completeWords) {
169+
if (word.contains("<@TOOL>")) {
170+
isHiddenTokens = true
171+
}
172+
if (word.contains("</@TOOL>")) {
173+
isHiddenTokens = false
174+
}
175+
if (!isHiddenTokens) {
176+
print(blue(word + " "))
177+
}
178+
}
179+
wordBuffer.clear()
180+
if (!endsWithSpace && words.nonEmpty) {
181+
wordBuffer.append(words.last)
182+
}
183+
}
184184
}
185185
}
186+
187+
if (wordBuffer.nonEmpty) {
188+
val remainingContent = wordBuffer.toString()
189+
if (remainingContent.nonEmpty) {
190+
if (!isHiddenTokens) {
191+
println(blue(remainingContent))
192+
}
193+
currentMessageBuilder.append(remainingContent)
194+
}
195+
}
196+
186197
if (cancelStreaming) {
187198
println(blue("\nStreaming cancelled by user"))
188199
}
@@ -194,14 +205,29 @@ abstract class BaseChatAgent(prompt: String) {
194205
streamResponse.close()
195206
if (currentMessageBuilder.nonEmpty) {
196207
println()
208+
val messageContent = currentMessageBuilder.toString()
197209
addMessageToHistory(
198210
ChatCompletionMessageParam.ofAssistant(
199-
createAssistantMessageBuilder(currentMessageBuilder.toString())
211+
createAssistantMessageBuilder(messageContent)
200212
.build()
201213
)
202214
)
215+
216+
// Check if the message contains a tool call
217+
val toolCallRegex = """(?s)<@TOOL>(.*?)</@TOOL>""".r
218+
val toolCallMatch = toolCallRegex.findFirstMatchIn(messageContent).map(_.group(1))
219+
if (toolCallMatch.isDefined) {
220+
val toolCallJson = toolCallMatch.get
221+
try {
222+
val parseResult: Either[Error, ToolCallDescription] = decode[ToolCallDescription](toolCallJson)
223+
currentToolCall = parseResult.getOrElse(ToolCallDescription())
224+
} catch {
225+
case e: Exception =>
226+
println(red(s"Error parsing tool call: ${e.getMessage}"))
227+
}
228+
}
203229
}
204-
if (currentToolCall.id.nonEmpty || currentToolCall.name.nonEmpty) {
230+
if (currentToolCall.name.nonEmpty) {
205231
handleToolCall(currentToolCall)
206232
}
207233
}
@@ -210,26 +236,19 @@ abstract class BaseChatAgent(prompt: String) {
210236
private def handleToolCall(toolCall: ToolCallDescription): Unit = {
211237
val toolResult = toolExecution(toolCall)
212238

213-
if (AgentConfig.IsGeminiMode != "true") {
214-
// Add the assistant's tool call message to chat history
215-
createAssistantToolCallMessage(toolCall)
216-
// Add result to chat history
217-
addMessageToHistory(createToolResponseMessage(toolResult, toolCall.id))
218-
} else {
219-
// Add the result as assistant's message
220-
addMessageToHistory(
221-
ChatCompletionMessageParam.ofAssistant(
222-
createAssistantMessageBuilder(s"I will need to use the ${toolCall.name} tool...").build()
223-
)
239+
// Add the result as assistant's message
240+
addMessageToHistory(
241+
ChatCompletionMessageParam.ofAssistant(
242+
createAssistantMessageBuilder(s"Calling ${toolCall.name} tool...").build()
224243
)
225-
addMessageToHistory(
226-
ChatCompletionMessageParam.ofUser(
227-
createUserMessageBuilder(s"Here's the tool call result: ${toolResult}").build()
228-
)
244+
)
245+
addMessageToHistory(
246+
ChatCompletionMessageParam.ofUser(
247+
createUserMessageBuilder(s"<@TOOL-RESULT>${toolResult}</@TOOL-RESULT>").build()
229248
)
230-
}
249+
)
231250

232-
// Trigger follow up response from assistant
251+
// Trigger follow-up response from assistant
233252
chat("")
234253
}
235254

src/main/scala/com/supercoder/tools/CodeEditTool.scala

Lines changed: 2 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
11
package com.supercoder.tools
22

3-
import com.openai.core.JsonValue
4-
import com.openai.models.{FunctionDefinition, FunctionParameters}
3+
import com.openai.models.FunctionDefinition
54
import com.supercoder.base.Tool
65
import com.supercoder.lib.Console.green
76
import io.circe.*
87
import io.circe.generic.auto.*
98
import io.circe.parser.*
109

1110
import java.io.{File, PrintWriter}
12-
import java.nio.file.{Files, Paths}
13-
import java.util
1411

1512
case class CodeEditToolArguments(filepath: String, content: String)
1613

@@ -20,28 +17,7 @@ object CodeEditTool extends Tool {
2017
.builder()
2118
.name("code-edit")
2219
.description(
23-
"Edit a code file in the repository. Provide the file path and the new content for the file."
24-
)
25-
.parameters(
26-
FunctionParameters
27-
.builder()
28-
.putAdditionalProperty("type", JsonValue.from("object"))
29-
.putAdditionalProperty(
30-
"properties",
31-
JsonValue.from(
32-
util.Map.of(
33-
"filepath",
34-
util.Map.of("type", "string"),
35-
"content",
36-
util.Map.of("type", "string")
37-
)
38-
)
39-
)
40-
.putAdditionalProperty(
41-
"required",
42-
JsonValue.from(util.List.of("filepath", "content"))
43-
)
44-
.build()
20+
"Edit a code file in the repository. Provide the file path and the new content for the file. Arguments: {\"filepath\": \"<file-path>\", \"content\": \"<new-content>\"}"
4521
)
4622
.build()
4723

src/main/scala/com/supercoder/tools/CodeSearchTool.scala

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
package com.supercoder.tools
22

3-
import com.openai.core.JsonValue
4-
import com.openai.models.{FunctionDefinition, FunctionParameters}
3+
import com.openai.models.FunctionDefinition
54
import com.supercoder.base.Tool
65
import com.supercoder.lib.Console.green
76
import io.circe.*
87
import io.circe.generic.auto.*
98
import io.circe.parser.*
109

11-
import java.util
1210
import scala.sys.process.*
1311

1412
case class CodeSearchToolArguments(query: String)
@@ -19,17 +17,7 @@ object CodeSearchTool extends Tool {
1917
.builder()
2018
.name("code-search")
2119
.description(
22-
"Search for code in a given repository. The query parameter should be a regular expression."
23-
)
24-
.parameters(
25-
FunctionParameters
26-
.builder()
27-
.putAdditionalProperty("type", JsonValue.from("object"))
28-
.putAdditionalProperty(
29-
"properties",
30-
JsonValue.from(util.Map.of("query", util.Map.of("type", "string")))
31-
)
32-
.build()
20+
"Search for code in a given repository. The query parameter should be a regular expression. Arguments: {\"query\": \"<search-query>\"}"
3321
)
3422
.build()
3523

0 commit comments

Comments
 (0)