Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 165 additions & 0 deletions openai-java-core/src/main/kotlin/com/openai/core/Futures.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
package com.openai.core

import java.util.concurrent.CancellationException
import java.util.concurrent.CompletableFuture
import java.util.concurrent.atomic.AtomicReference

@JvmSynthetic
internal fun <T, U> CompletableFuture<T>.thenApplyPropagatingCancellation(
fn: (T) -> U
): CompletableFuture<U> {
val upstream = this
val result = CompletableFuture<U>()

result.whenComplete { _, error ->
if (result.isCancelled || error is CancellationException) {
upstream.cancel(false)
}
}

upstream.whenComplete { value, error ->
if (result.isCancelled) {
return@whenComplete
}
if (error != null) {
result.completeExceptionally(error)
return@whenComplete
}

try {
result.complete(fn(value))
} catch (e: Throwable) {
result.completeExceptionally(e)
}
}

return result
}

@JvmSynthetic
internal fun <T, U> CompletableFuture<T>.handlePropagatingCancellation(
fn: (T?, Throwable?) -> U
): CompletableFuture<U> {
val upstream = this
val result = CompletableFuture<U>()

result.whenComplete { _, error ->
if (result.isCancelled || error is CancellationException) {
upstream.cancel(false)
}
}

upstream.whenComplete { value, error ->
if (result.isCancelled) {
return@whenComplete
}

try {
result.complete(fn(value, error))
} catch (e: Throwable) {
result.completeExceptionally(e)
}
}

return result
}

@JvmSynthetic
internal fun <T, U> CompletableFuture<T>.thenComposePropagatingCancellation(
fn: (T) -> CompletableFuture<U>
): CompletableFuture<U> {
val upstream = this
val innerFuture = AtomicReference<CompletableFuture<U>?>()
val result = CompletableFuture<U>()

result.whenComplete { _, error ->
if (result.isCancelled || error is CancellationException) {
upstream.cancel(false)
innerFuture.get()?.cancel(false)
}
}

upstream.whenComplete { value, error ->
if (result.isCancelled) {
return@whenComplete
}
if (error != null) {
result.completeExceptionally(error)
return@whenComplete
}

val inner =
try {
fn(value)
} catch (e: Throwable) {
result.completeExceptionally(e)
return@whenComplete
}
innerFuture.set(inner)

if (result.isCancelled) {
inner.cancel(false)
return@whenComplete
}

inner.whenComplete { innerValue, innerError ->
if (innerError != null) {
result.completeExceptionally(innerError)
} else {
result.complete(innerValue)
}
}
}

return result
}

@JvmSynthetic
internal fun <T, U> CompletableFuture<T>.thenComposeAsyncPropagatingCancellation(
fn: (T) -> CompletableFuture<U>
): CompletableFuture<U> {
val upstream = this
val innerFuture = AtomicReference<CompletableFuture<U>?>()
val result = CompletableFuture<U>()

result.whenComplete { _, error ->
if (result.isCancelled || error is CancellationException) {
upstream.cancel(false)
innerFuture.get()?.cancel(false)
}
}

upstream.whenCompleteAsync { value, error ->
if (result.isCancelled) {
return@whenCompleteAsync
}
if (error != null) {
result.completeExceptionally(error)
return@whenCompleteAsync
}

val inner =
try {
fn(value)
} catch (e: Throwable) {
result.completeExceptionally(e)
return@whenCompleteAsync
}
innerFuture.set(inner)

if (result.isCancelled) {
inner.cancel(false)
return@whenCompleteAsync
}

inner.whenComplete { innerValue, innerError ->
if (innerError != null) {
result.completeExceptionally(innerError)
} else {
result.complete(innerValue)
}
}
}

return result
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package com.openai.core.http
import com.openai.core.LogLevel
import com.openai.core.RequestOptions
import com.openai.core.checkRequired
import com.openai.core.handlePropagatingCancellation
import com.openai.core.toImmutable
import java.io.ByteArrayOutputStream
import java.io.InputStream
Expand Down Expand Up @@ -80,13 +81,13 @@ private constructor(
logFailure(e, Duration.between(before, OffsetDateTime.now(clock)))
throw e
}
return future.handle { response, error ->
return future.handlePropagatingCancellation { response, error ->
val took = Duration.between(before, OffsetDateTime.now(clock))
if (error != null) {
logFailure(unwrapCompletionException(error), took)
throw error
}
logResponse(response, took)
logResponse(response!!, took)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import com.openai.core.DefaultSleeper
import com.openai.core.RequestOptions
import com.openai.core.Sleeper
import com.openai.core.checkRequired
import com.openai.core.handlePropagatingCancellation
import com.openai.core.thenComposePropagatingCancellation
import com.openai.errors.OpenAIIoException
import com.openai.errors.OpenAIRetryableException
import java.io.IOException
Expand All @@ -19,7 +21,6 @@ import java.util.UUID
import java.util.concurrent.CompletableFuture
import java.util.concurrent.ThreadLocalRandom
import java.util.concurrent.TimeUnit
import java.util.function.Function
import kotlin.math.min
import kotlin.math.pow

Expand Down Expand Up @@ -98,35 +99,30 @@ private constructor(
}

return responseFuture
.handleAsync(
fun(
response: HttpResponse?,
throwable: Throwable?,
): CompletableFuture<HttpResponse> {
if (response != null) {
if (++retries > maxRetries || !shouldRetry(response)) {
return CompletableFuture.completedFuture(response)
}
} else {
if (++retries > maxRetries || !shouldRetry(throwable!!)) {
val failedFuture = CompletableFuture<HttpResponse>()
failedFuture.completeExceptionally(throwable)
return failedFuture
}
.handlePropagatingCancellation { response, throwable ->
if (response != null) {
if (++retries > maxRetries || !shouldRetry(response)) {
return@handlePropagatingCancellation CompletableFuture.completedFuture(
response
)
}

val backoffDuration = getRetryBackoffDuration(retries, response)
// All responses must be closed, so close the failed one before retrying.
response?.close()
return sleeper.sleepAsync(backoffDuration).thenCompose {
executeWithRetries(requestWithRetryCount, requestOptions)
} else {
if (++retries > maxRetries || !shouldRetry(throwable!!)) {
val failedFuture = CompletableFuture<HttpResponse>()
failedFuture.completeExceptionally(throwable)
return@handlePropagatingCancellation failedFuture
}
}
) {
// Run in the same thread.
it.run()

val backoffDuration = getRetryBackoffDuration(retries, response)
// All responses must be closed, so close the failed one before retrying.
response?.close()
sleeper.sleepAsync(backoffDuration).thenComposePropagatingCancellation {
_: Void? ->
executeWithRetries(requestWithRetryCount, requestOptions)
}
}
.thenCompose(Function.identity())
.thenComposePropagatingCancellation { it }
}

return executeWithRetries(modifiedRequest, requestOptions)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import com.openai.core.http.HttpResponseFor
import com.openai.core.http.json
import com.openai.core.http.parseable
import com.openai.core.prepareAsync
import com.openai.core.thenApplyPropagatingCancellation
import com.openai.core.thenComposeAsyncPropagatingCancellation
import com.openai.models.batches.Batch
import com.openai.models.batches.BatchCancelParams
import com.openai.models.batches.BatchCreateParams
Expand Down Expand Up @@ -46,28 +48,36 @@ class BatchServiceAsyncImpl internal constructor(private val clientOptions: Clie
requestOptions: RequestOptions,
): CompletableFuture<Batch> =
// post /batches
withRawResponse().create(params, requestOptions).thenApply { it.parse() }
withRawResponse().create(params, requestOptions).thenApplyPropagatingCancellation {
it.parse()
}

override fun retrieve(
params: BatchRetrieveParams,
requestOptions: RequestOptions,
): CompletableFuture<Batch> =
// get /batches/{batch_id}
withRawResponse().retrieve(params, requestOptions).thenApply { it.parse() }
withRawResponse().retrieve(params, requestOptions).thenApplyPropagatingCancellation {
it.parse()
}

override fun list(
params: BatchListParams,
requestOptions: RequestOptions,
): CompletableFuture<BatchListPageAsync> =
// get /batches
withRawResponse().list(params, requestOptions).thenApply { it.parse() }
withRawResponse().list(params, requestOptions).thenApplyPropagatingCancellation {
it.parse()
}

override fun cancel(
params: BatchCancelParams,
requestOptions: RequestOptions,
): CompletableFuture<Batch> =
// post /batches/{batch_id}/cancel
withRawResponse().cancel(params, requestOptions).thenApply { it.parse() }
withRawResponse().cancel(params, requestOptions).thenApplyPropagatingCancellation {
it.parse()
}

class WithRawResponseImpl internal constructor(private val clientOptions: ClientOptions) :
BatchServiceAsync.WithRawResponse {
Expand Down Expand Up @@ -102,8 +112,10 @@ class BatchServiceAsyncImpl internal constructor(private val clientOptions: Clie
)
val requestOptions = requestOptions.applyDefaults(RequestOptions.from(clientOptions))
return request
.thenComposeAsync { clientOptions.httpClient.executeAsync(it, requestOptions) }
.thenApply { response ->
.thenComposeAsyncPropagatingCancellation {
clientOptions.httpClient.executeAsync(it, requestOptions)
}
.thenApplyPropagatingCancellation { response ->
errorHandler.handle(response).parseable {
response
.use { createHandler.handle(it) }
Expand Down Expand Up @@ -138,8 +150,10 @@ class BatchServiceAsyncImpl internal constructor(private val clientOptions: Clie
)
val requestOptions = requestOptions.applyDefaults(RequestOptions.from(clientOptions))
return request
.thenComposeAsync { clientOptions.httpClient.executeAsync(it, requestOptions) }
.thenApply { response ->
.thenComposeAsyncPropagatingCancellation {
clientOptions.httpClient.executeAsync(it, requestOptions)
}
.thenApplyPropagatingCancellation { response ->
errorHandler.handle(response).parseable {
response
.use { retrieveHandler.handle(it) }
Expand Down Expand Up @@ -172,8 +186,10 @@ class BatchServiceAsyncImpl internal constructor(private val clientOptions: Clie
)
val requestOptions = requestOptions.applyDefaults(RequestOptions.from(clientOptions))
return request
.thenComposeAsync { clientOptions.httpClient.executeAsync(it, requestOptions) }
.thenApply { response ->
.thenComposeAsyncPropagatingCancellation {
clientOptions.httpClient.executeAsync(it, requestOptions)
}
.thenApplyPropagatingCancellation { response ->
errorHandler.handle(response).parseable {
response
.use { listHandler.handle(it) }
Expand Down Expand Up @@ -217,8 +233,10 @@ class BatchServiceAsyncImpl internal constructor(private val clientOptions: Clie
)
val requestOptions = requestOptions.applyDefaults(RequestOptions.from(clientOptions))
return request
.thenComposeAsync { clientOptions.httpClient.executeAsync(it, requestOptions) }
.thenApply { response ->
.thenComposeAsyncPropagatingCancellation {
clientOptions.httpClient.executeAsync(it, requestOptions)
}
.thenApplyPropagatingCancellation { response ->
errorHandler.handle(response).parseable {
response
.use { cancelHandler.handle(it) }
Expand Down
Loading