Skip to content

Commit c665e21

Browse files
fix(client): cancel okhttp call when future cancelled
1 parent b6ae698 commit c665e21

File tree

3 files changed

+65
-14
lines changed

3 files changed

+65
-14
lines changed

openai-java-client-okhttp/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ dependencies {
1111

1212
testImplementation(kotlin("test"))
1313
testImplementation("org.assertj:assertj-core:3.25.3")
14+
testImplementation("com.github.tomakehurst:wiremock-jre8:2.35.2")
1415
}

openai-java-client-okhttp/src/main/kotlin/com/openai/client/okhttp/OkHttpClient.kt

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import java.io.IOException
1313
import java.io.InputStream
1414
import java.net.Proxy
1515
import java.time.Duration
16+
import java.util.concurrent.CancellationException
1617
import java.util.concurrent.CompletableFuture
1718
import javax.net.ssl.HostnameVerifier
1819
import javax.net.ssl.SSLSocketFactory
@@ -29,8 +30,8 @@ import okhttp3.Response
2930
import okhttp3.logging.HttpLoggingInterceptor
3031
import okio.BufferedSink
3132

32-
class OkHttpClient private constructor(private val okHttpClient: okhttp3.OkHttpClient) :
33-
HttpClient {
33+
class OkHttpClient
34+
private constructor(@JvmSynthetic internal val okHttpClient: okhttp3.OkHttpClient) : HttpClient {
3435

3536
override fun execute(request: HttpRequest, requestOptions: RequestOptions): HttpResponse {
3637
val call = newCall(request, requestOptions)
@@ -50,20 +51,25 @@ class OkHttpClient private constructor(private val okHttpClient: okhttp3.OkHttpC
5051
): CompletableFuture<HttpResponse> {
5152
val future = CompletableFuture<HttpResponse>()
5253

53-
request.body?.run { future.whenComplete { _, _ -> close() } }
54-
55-
newCall(request, requestOptions)
56-
.enqueue(
57-
object : Callback {
58-
override fun onResponse(call: Call, response: Response) {
59-
future.complete(response.toResponse())
60-
}
54+
val call = newCall(request, requestOptions)
55+
call.enqueue(
56+
object : Callback {
57+
override fun onResponse(call: Call, response: Response) {
58+
future.complete(response.toResponse())
59+
}
6160

62-
override fun onFailure(call: Call, e: IOException) {
63-
future.completeExceptionally(OpenAIIoException("Request failed", e))
64-
}
61+
override fun onFailure(call: Call, e: IOException) {
62+
future.completeExceptionally(OpenAIIoException("Request failed", e))
6563
}
66-
)
64+
}
65+
)
66+
67+
future.whenComplete { _, e ->
68+
if (e is CancellationException) {
69+
call.cancel()
70+
}
71+
request.body?.close()
72+
}
6773

6874
return future
6975
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package com.openai.client.okhttp
2+
3+
import com.github.tomakehurst.wiremock.client.WireMock.*
4+
import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo
5+
import com.github.tomakehurst.wiremock.junit5.WireMockTest
6+
import com.openai.core.http.HttpMethod
7+
import com.openai.core.http.HttpRequest
8+
import org.assertj.core.api.Assertions.assertThat
9+
import org.junit.jupiter.api.BeforeEach
10+
import org.junit.jupiter.api.Test
11+
import org.junit.jupiter.api.parallel.ResourceLock
12+
13+
@WireMockTest
14+
@ResourceLock("https://github.com/wiremock/wiremock/issues/169")
15+
internal class OkHttpClientTest {
16+
17+
private lateinit var baseUrl: String
18+
private lateinit var httpClient: OkHttpClient
19+
20+
@BeforeEach
21+
fun beforeEach(wmRuntimeInfo: WireMockRuntimeInfo) {
22+
baseUrl = wmRuntimeInfo.httpBaseUrl
23+
httpClient = OkHttpClient.builder().build()
24+
}
25+
26+
@Test
27+
fun executeAsync_whenFutureCancelled_cancelsUnderlyingCall() {
28+
stubFor(post(urlPathEqualTo("/something")).willReturn(ok()))
29+
val responseFuture =
30+
httpClient.executeAsync(
31+
HttpRequest.builder()
32+
.method(HttpMethod.POST)
33+
.baseUrl(baseUrl)
34+
.addPathSegment("something")
35+
.build()
36+
)
37+
val call = httpClient.okHttpClient.dispatcher.runningCalls().single()
38+
39+
responseFuture.cancel(false)
40+
41+
// Should have cancelled the underlying call
42+
assertThat(call.isCanceled()).isTrue()
43+
}
44+
}

0 commit comments

Comments
 (0)