diff --git a/samples/kotlin-mcp-server/src/main/kotlin/io/modelcontextprotocol/sample/server/server.kt b/samples/kotlin-mcp-server/src/main/kotlin/io/modelcontextprotocol/sample/server/server.kt index e707743d..2a85c74a 100644 --- a/samples/kotlin-mcp-server/src/main/kotlin/io/modelcontextprotocol/sample/server/server.kt +++ b/samples/kotlin-mcp-server/src/main/kotlin/io/modelcontextprotocol/sample/server/server.kt @@ -31,6 +31,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.types.TextContent import io.modelcontextprotocol.kotlin.sdk.types.TextResourceContents import kotlinx.coroutines.Job +import kotlinx.coroutines.awaitCancellation import kotlinx.coroutines.runBlocking import kotlinx.io.asSink import kotlinx.io.asSource @@ -102,13 +103,13 @@ fun configureServer(): Server { return server } -fun runSseMcpServerWithPlainConfiguration(port: Int, wait: Boolean = true) { +fun runSseMcpServerWithPlainConfiguration(port: Int, wait: Boolean = true): EmbeddedServer<*, *> { printBanner(port = port, path = "/sse") val serverSessions = ConcurrentMap() val server = configureServer() - embeddedServer(CIO, host = "127.0.0.1", port = port) { + val ktorServer = embeddedServer(CIO, host = "127.0.0.1", port = port) { installCors() install(SSE) routing { @@ -121,6 +122,7 @@ fun runSseMcpServerWithPlainConfiguration(port: Int, wait: Boolean = true) { println("Server session closed for: ${transport.sessionId}") serverSessions.remove(transport.sessionId) } + awaitCancellation() } post("/message") { val sessionId: String? = call.request.queryParameters["sessionId"] @@ -139,6 +141,8 @@ fun runSseMcpServerWithPlainConfiguration(port: Int, wait: Boolean = true) { } } }.start(wait = wait) + + return ktorServer } /** diff --git a/samples/kotlin-mcp-server/src/test/kotlin/McpServerType.kt b/samples/kotlin-mcp-server/src/test/kotlin/McpServerType.kt new file mode 100644 index 00000000..b39dc43c --- /dev/null +++ b/samples/kotlin-mcp-server/src/test/kotlin/McpServerType.kt @@ -0,0 +1,17 @@ +import io.ktor.server.engine.EmbeddedServer +import io.modelcontextprotocol.sample.server.runSseMcpServerUsingKtorPlugin +import io.modelcontextprotocol.sample.server.runSseMcpServerWithPlainConfiguration + +enum class McpServerType( + val sseEndpoint: String, + val serverFactory: (port: Int) -> EmbeddedServer<*, *> +) { + KTOR_PLUGIN( + sseEndpoint = "", + serverFactory = { port -> runSseMcpServerUsingKtorPlugin(port, wait = false) } + ), + PLAIN_CONFIGURATION( + sseEndpoint = "/sse", + serverFactory = { port -> runSseMcpServerWithPlainConfiguration(port, wait = false) } + ) +} diff --git a/samples/kotlin-mcp-server/src/test/kotlin/SseServerIntegrationTest.kt b/samples/kotlin-mcp-server/src/test/kotlin/SseServerIntegrationTest.kt index 9cbeddac..2e091ba5 100644 --- a/samples/kotlin-mcp-server/src/test/kotlin/SseServerIntegrationTest.kt +++ b/samples/kotlin-mcp-server/src/test/kotlin/SseServerIntegrationTest.kt @@ -11,9 +11,9 @@ import kotlin.test.assertNotNull import kotlin.test.assertNull import kotlin.test.assertTrue -class SseServerIntegrationTest { +abstract class SseServerIntegrationTestBase { - private val client: Client = TestEnvironment.client + abstract val client: Client @Test fun `should list tools`(): Unit = runBlocking { @@ -88,3 +88,13 @@ class SseServerIntegrationTest { assertEquals(expected = "text", actual = "${content.type}".lowercase()) } } + +class SseServerKtorPluginIntegrationTest : SseServerIntegrationTestBase() { + private val testEnvironment = TestEnvironment(McpServerType.KTOR_PLUGIN) + override val client: Client = testEnvironment.client +} + +class SseServerPlainConfigurationIntegrationTest : SseServerIntegrationTestBase() { + private val testEnvironment = TestEnvironment(McpServerType.PLAIN_CONFIGURATION) + override val client: Client = testEnvironment.client +} diff --git a/samples/kotlin-mcp-server/src/test/kotlin/TestEnvironment.kt b/samples/kotlin-mcp-server/src/test/kotlin/TestEnvironment.kt index 6d457ddc..71932c39 100644 --- a/samples/kotlin-mcp-server/src/test/kotlin/TestEnvironment.kt +++ b/samples/kotlin-mcp-server/src/test/kotlin/TestEnvironment.kt @@ -1,34 +1,34 @@ import io.ktor.client.HttpClient import io.ktor.client.engine.cio.CIO import io.ktor.client.plugins.sse.SSE -import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.ktor.server.engine.EmbeddedServer import io.modelcontextprotocol.kotlin.sdk.client.Client import io.modelcontextprotocol.kotlin.sdk.client.mcpSseTransport -import io.modelcontextprotocol.sample.server.runSseMcpServerUsingKtorPlugin +import io.modelcontextprotocol.kotlin.sdk.types.Implementation import kotlinx.coroutines.runBlocking import java.util.concurrent.TimeUnit -object TestEnvironment { +class TestEnvironment(private val serverConfig: McpServerType) { - val server = runSseMcpServerUsingKtorPlugin(0, wait = false) + val server: EmbeddedServer<*, *> = serverConfig.serverFactory(0) val client: Client init { client = runBlocking { val port = server.engine.resolvedConnectors().single().port - initClient(port) + initClient(port, serverConfig) } Runtime.getRuntime().addShutdownHook( Thread { - println("🏁 Shutting down server") + println("🏁 Shutting down server (${serverConfig.name})") server.stop(500, 700, TimeUnit.MILLISECONDS) println("☑️ Shutdown complete") }, ) } - private suspend fun initClient(port: Int): Client { + private suspend fun initClient(port: Int, config: McpServerType): Client { val client = Client( Implementation(name = "test-client", version = "0.1.0"), ) @@ -37,13 +37,7 @@ object TestEnvironment { install(SSE) } - // Create a transport wrapper that captures the session ID and received messages - val transport = httpClient.mcpSseTransport { - url { - this.host = "127.0.0.1" - this.port = port - } - } + val transport = httpClient.mcpSseTransport("http://127.0.0.1:$port/${config.sseEndpoint}") client.connect(transport) return client }