4 changes: 2 additions & 2 deletions baseten-performance-client/Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion baseten-performance-client/core/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "baseten_performance_client_core"
version = "0.0.11"
version = "0.0.12-rc.1"
edition = "2021"
description = "High performance HTTP client for Baseten.co and other APIs"
license = "MIT"
104 changes: 89 additions & 15 deletions baseten-performance-client/core/src/client.rs
@@ -251,6 +251,7 @@ impl PerformanceClientCore {
let mut indexed_results: Vec<(usize, R, Duration, usize)> =
Vec::with_capacity(total_requests);

let max_retries = config.max_retries();
let mut current_absolute_index = 0;
for (batch_index, batch) in batches.into_iter().enumerate() {
let current_batch_absolute_start_index = current_absolute_index;
@@ -275,7 +276,7 @@

let request_time_start = Instant::now();
let config = SendRequestConfig {
max_retries: MAX_HTTP_RETRIES,
max_retries,
initial_backoff: Duration::from_millis(INITIAL_BACKOFF_MS),
retry_budget: retry_budget,
cancel_token: cancel_token.clone(),
@@ -305,18 +306,39 @@ impl PerformanceClientCore {
}

// Process results as they complete with fast-fail cancellation
while let Some(task_result) = join_set.join_next().await {
    match process_joinset_outcome(task_result, &cancel_token) {
        Ok((response, duration, start_index, batch_index)) => {
            indexed_results.push((batch_index, response, duration, start_index));
        }
        Err(e) => {
            // Cancel all remaining tasks immediately
            cancel_token.store(true, Ordering::SeqCst);
            join_set.abort_all();
            return Err(e);
        }
    }
}
let process_results = async {
    while let Some(task_result) = join_set.join_next().await {
        match process_joinset_outcome(task_result, &cancel_token) {
            Ok((response, duration, start_index, batch_index)) => {
                indexed_results.push((batch_index, response, duration, start_index));
            }
            Err(e) => {
                // Cancel all remaining tasks immediately
                cancel_token.store(true, Ordering::SeqCst);
                join_set.abort_all();
                return Err(e);
            }
        }
    }
    Ok(())
};

// Apply total timeout if configured
if let Some(total_timeout) = config.total_timeout_duration() {
    match tokio::time::timeout(total_timeout, process_results).await {
        Ok(Ok(())) => {}
        Ok(Err(e)) => return Err(e),
        Err(_) => {
            cancel_token.store(true, Ordering::SeqCst);
            join_set.abort_all();
            return Err(ClientError::Timeout(format!(
                "Batch operation timed out after {:.3}s",
                total_timeout.as_secs_f64()
            )));
        }
    }
} else {
    process_results.await?;
}

// Sort results by original batch order to preserve ordering
@@ -367,6 +389,8 @@ impl PerformanceClientCore {
max_chars_per_request: Option<usize>,
timeout_s: f64,
hedge_delay: Option<f64>,
total_timeout_s: Option<f64>,
max_retries: Option<i64>,
Comment on lines 367 to +393


P1: Update node bindings for added timeout/retry params

The core API now requires total_timeout_s and max_retries for process_embeddings_requests (and the other batched helpers), but the Node bindings still invoke these methods with the old argument list (e.g., node_bindings/src/lib.rs passes arguments only up through hedge_delay). Building the workspace or running cargo build -p node_bindings therefore fails with a mismatched-argument-count compile error, because the new parameters are never supplied.

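As a reply-style illustration of the shape such a fix would take (every name below is a stand-in, not the actual node_bindings code): the bindings-facing function gains the two optional parameters and forwards them unchanged, so callers that omit them keep today's behavior via None.

```rust
// All names are illustrative; the real signatures live in core/src/client.rs
// and node_bindings/src/lib.rs.
fn core_process(total_timeout_s: Option<f64>, max_retries: Option<i64>) -> String {
    format!("total_timeout_s={total_timeout_s:?}, max_retries={max_retries:?}")
}

// Bindings-facing wrapper: the two new arguments are accepted as Options
// and forwarded as-is, restoring a matching argument count.
fn binding_process(total_timeout_s: Option<f64>, max_retries: Option<i64>) -> String {
    core_process(total_timeout_s, max_retries)
}

fn main() {
    // Callers that don't opt in keep the core defaults.
    println!("{}", binding_process(None, None));
    // New callers can set a total deadline and a retry cap.
    println!("{}", binding_process(Some(30.0), Some(2)));
}
```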

) -> Result<(CoreOpenAIEmbeddingsResponse, Vec<Duration>, Duration), ClientError> {
// Create and validate config
let config = RequestProcessingConfig::new(
@@ -376,6 +400,8 @@
self.base_url.to_string(),
hedge_delay,
max_chars_per_request,
total_timeout_s,
max_retries,
)?;

// Create batches
@@ -426,6 +452,8 @@ impl PerformanceClientCore {
max_chars_per_request: Option<usize>,
timeout_s: f64,
hedge_delay: Option<f64>,
total_timeout_s: Option<f64>,
max_retries: Option<i64>,
) -> Result<(CoreRerankResponse, Vec<Duration>, Duration), ClientError> {
// Create and validate config
let config = RequestProcessingConfig::new(
@@ -435,6 +463,8 @@
self.base_url.to_string(),
hedge_delay,
max_chars_per_request,
total_timeout_s,
max_retries,
)?;

// Create batches
@@ -487,6 +517,8 @@ impl PerformanceClientCore {
max_chars_per_request: Option<usize>,
timeout_s: f64,
hedge_delay: Option<f64>,
total_timeout_s: Option<f64>,
max_retries: Option<i64>,
) -> Result<(CoreClassificationResponse, Vec<Duration>, Duration), ClientError> {
// Create and validate config
let config = RequestProcessingConfig::new(
@@ -496,6 +528,8 @@
self.base_url.to_string(),
hedge_delay,
max_chars_per_request,
total_timeout_s,
max_retries,
)?;

// Create batches
@@ -544,6 +578,7 @@ impl PerformanceClientCore {
max_concurrent_requests: usize,
timeout_s: f64,
hedge_delay: Option<f64>,
total_timeout_s: Option<f64>,
) -> Result<
(
Vec<(
@@ -560,6 +595,23 @@
// Validate parameters internally (using batch_size of 128 for validation)
let (validated_concurrency, request_timeout_duration) =
self.validate_request_parameters(max_concurrent_requests, 128, timeout_s)?;

// Validate total_timeout_s if provided
if let Some(total_timeout) = total_timeout_s {
if !(MIN_TOTAL_TIMEOUT_S..=MAX_TOTAL_TIMEOUT_S).contains(&total_timeout) {
return Err(ClientError::InvalidParameter(format!(
"Total timeout {:.3}s is outside the allowed range [{:.3}s, {:.3}s].",
total_timeout, MIN_TOTAL_TIMEOUT_S, MAX_TOTAL_TIMEOUT_S
)));
}
if total_timeout < timeout_s {
return Err(ClientError::InvalidParameter(format!(
"Total timeout {:.3}s must be greater than or equal to per-request timeout {:.3}s.",
total_timeout, timeout_s
)));
}
}

let semaphore = Arc::new(Semaphore::new(validated_concurrency));
let cancel_token = Arc::new(AtomicBool::new(false));
let total_payloads = payloads_json.len();
@@ -653,18 +705,40 @@ impl PerformanceClientCore {
}

// Process results as they complete with fast-fail cancellation
while let Some(task_result) = join_set.join_next().await {
    match process_joinset_outcome(task_result, &cancel_token) {
        Ok(indexed_data) => {
            indexed_results.push(indexed_data);
        }
        Err(e) => {
            // Cancel all remaining tasks immediately
            cancel_token.store(true, Ordering::SeqCst);
            join_set.abort_all();
            return Err(e);
        }
    }
}
let process_results = async {
    while let Some(task_result) = join_set.join_next().await {
        match process_joinset_outcome(task_result, &cancel_token) {
            Ok(indexed_data) => {
                indexed_results.push(indexed_data);
            }
            Err(e) => {
                // Cancel all remaining tasks immediately
                cancel_token.store(true, Ordering::SeqCst);
                join_set.abort_all();
                return Err(e);
            }
        }
    }
    Ok(())
};

// Apply total timeout if configured
if let Some(total_timeout) = total_timeout_s {
    let total_timeout_duration = Duration::from_secs_f64(total_timeout);
    match tokio::time::timeout(total_timeout_duration, process_results).await {
        Ok(Ok(())) => {}
        Ok(Err(e)) => return Err(e),
        Err(_) => {
            cancel_token.store(true, Ordering::SeqCst);
            join_set.abort_all();
            return Err(ClientError::Timeout(format!(
                "Batch post operation timed out after {:.3}s",
                total_timeout
            )));
        }
    }
} else {
    process_results.await?;
}

indexed_results.sort_by_key(|&(original_index, _, _, _)| original_index);
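For readers skimming the client.rs hunks above: the change lifts the result-collection loop into a future and wraps it with tokio::time::timeout only when a total deadline is configured, setting the cancel flag and aborting in-flight tasks on expiry. A runnable, stripped-down sketch of that pattern (assumes the tokio crate with default features; the task bodies and error type are stand-ins for the client's real ones):

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::task::JoinSet;

#[tokio::main]
async fn main() {
    // In the real client, workers poll this flag to fast-fail.
    let cancel_token = Arc::new(AtomicBool::new(false));
    let mut join_set: JoinSet<u64> = JoinSet::new();
    for i in 0..4u64 {
        join_set.spawn(async move {
            tokio::time::sleep(Duration::from_millis(40 * i)).await;
            i
        });
    }

    let mut results: Vec<u64> = Vec::new();
    // The collection loop as a future, so it can run bare or under a deadline.
    let process_results = async {
        while let Some(task_result) = join_set.join_next().await {
            match task_result {
                Ok(value) => results.push(value),
                Err(e) => return Err(format!("task failed: {e}")),
            }
        }
        Ok(())
    };

    // Mirror the diff: apply the total deadline only when configured.
    let total_timeout: Option<Duration> = Some(Duration::from_millis(500));
    let outcome = match total_timeout {
        Some(limit) => match tokio::time::timeout(limit, process_results).await {
            Ok(inner) => inner,
            Err(_) => {
                // Deadline hit: flag cancellation and abort in-flight tasks.
                cancel_token.store(true, Ordering::SeqCst);
                join_set.abort_all();
                Err(format!("timed out after {:.3}s", limit.as_secs_f64()))
            }
        },
        None => process_results.await,
    };

    println!("outcome={outcome:?} results={results:?}");
}
```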
8 changes: 6 additions & 2 deletions baseten-performance-client/core/src/constants.rs
@@ -2,9 +2,13 @@ use std::time::Duration;

// Request timeout constants
pub const DEFAULT_REQUEST_TIMEOUT_S: f64 = 3600.0;
pub const MIN_REQUEST_TIMEOUT_S: f64 = 1.0;
pub const MIN_REQUEST_TIMEOUT_S: f64 = 0.1;
pub const MAX_REQUEST_TIMEOUT_S: f64 = 3600.0;

// Total timeout constants
pub const MIN_TOTAL_TIMEOUT_S: f64 = 0.1;
pub const MAX_TOTAL_TIMEOUT_S: f64 = 3600.0;

// Concurrency constants
pub const MAX_CONCURRENCY_HIGH_BATCH: usize = 1024;
pub const MAX_CONCURRENCY_LOW_BATCH: usize = 512;
@@ -14,7 +18,7 @@ pub const MIN_CHARACTERS_PER_REQUEST: usize = 50;
pub const MAX_CHARACTERS_PER_REQUEST: usize = 256000;

// hedging settings:
pub const MIN_HEDGE_DELAY_S: f64 = 0.2;
pub const MIN_HEDGE_DELAY_S: f64 = 0.1;
pub const HEDGE_BUDGET_PERCENTAGE: f64 = 0.10;

// Batch size constants
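The two lowered minimums are the behavioral part of this file: sub-second request timeouts and 100 ms hedge delays are now accepted. A tiny runnable check of the new bounds (constants copied from the hunks above; the hedge-delay check assumes only the minimum is enforced, as no upper bound appears here):

```rust
// Constants copied from constants.rs after this change.
const MIN_REQUEST_TIMEOUT_S: f64 = 0.1; // was 1.0
const MAX_REQUEST_TIMEOUT_S: f64 = 3600.0;
const MIN_HEDGE_DELAY_S: f64 = 0.1; // was 0.2

fn request_timeout_ok(timeout_s: f64) -> bool {
    (MIN_REQUEST_TIMEOUT_S..=MAX_REQUEST_TIMEOUT_S).contains(&timeout_s)
}

fn hedge_delay_ok(delay_s: f64) -> bool {
    delay_s >= MIN_HEDGE_DELAY_S
}

fn main() {
    assert!(request_timeout_ok(0.1)); // rejected before this change
    assert!(!request_timeout_ok(0.05)); // still below the floor
    assert!(hedge_delay_ok(0.1)); // 100 ms hedging now allowed
    println!("bounds ok");
}
```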
49 changes: 49 additions & 0 deletions baseten-performance-client/core/src/split_policy.rs
@@ -22,6 +22,8 @@ pub struct RequestProcessingConfig {
pub base_url: String,
pub hedge_delay: Option<f64>,
pub max_chars_per_request: Option<usize>,
pub total_timeout_s: Option<f64>,
pub max_retries: Option<u32>,
}

impl RequestProcessingConfig {
@@ -33,6 +35,8 @@
base_url: String,
hedge_delay: Option<f64>,
max_chars_per_request: Option<usize>,
total_timeout_s: Option<f64>,
max_retries: Option<i64>,
) -> Result<Self, crate::errors::ClientError> {
// Validate timeout
if !(MIN_REQUEST_TIMEOUT_S..=MAX_REQUEST_TIMEOUT_S).contains(&timeout_s) {
@@ -65,6 +69,39 @@ impl RequestProcessingConfig {
)));
}
}
if let Some(total_timeout) = total_timeout_s {
if !(MIN_TOTAL_TIMEOUT_S..=MAX_TOTAL_TIMEOUT_S).contains(&total_timeout) {
return Err(crate::errors::ClientError::InvalidParameter(format!(
"Total timeout {:.3}s is outside the allowed range [{:.3}s, {:.3}s].",
total_timeout, MIN_TOTAL_TIMEOUT_S, MAX_TOTAL_TIMEOUT_S
)));
}
if total_timeout < timeout_s {
return Err(crate::errors::ClientError::InvalidParameter(format!(
"Total timeout {:.3}s must be greater than or equal to per-request timeout {:.3}s.",
total_timeout, timeout_s
)));
}
}
// Validate and convert max_retries from i64 to u32
let max_retries_u32 = if let Some(retries) = max_retries {
if retries < 0 {
return Err(crate::errors::ClientError::InvalidParameter(format!(
"max_retries must be non-negative, got {}",
retries
)));
}
if retries > MAX_HTTP_RETRIES as i64 {
return Err(crate::errors::ClientError::InvalidParameter(format!(
"max_retries {} exceeds maximum allowed retries {}",
retries, MAX_HTTP_RETRIES
)));
}
Some(retries as u32)
} else {
None
};

// Validate concurrency parameters
if max_concurrent_requests == 0 || max_concurrent_requests > MAX_CONCURRENCY_HIGH_BATCH {
@@ -93,13 +130,25 @@ impl RequestProcessingConfig {
base_url,
hedge_delay,
max_chars_per_request,
total_timeout_s,
max_retries: max_retries_u32,
})
}

/// Get timeout duration
pub fn timeout_duration(&self) -> std::time::Duration {
std::time::Duration::from_secs_f64(self.timeout_s)
}

/// Get total timeout duration if set
pub fn total_timeout_duration(&self) -> Option<std::time::Duration> {
self.total_timeout_s.map(std::time::Duration::from_secs_f64)
}

/// Get max retries, defaulting to MAX_HTTP_RETRIES if not set
pub fn max_retries(&self) -> u32 {
self.max_retries.unwrap_or(MAX_HTTP_RETRIES)
}
}

impl SplitPolicy {
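Putting the split_policy.rs additions together: the new constructor arguments are validated before use — the total timeout must fall inside [MIN_TOTAL_TIMEOUT_S, MAX_TOTAL_TIMEOUT_S] and be at least the per-request timeout, and max_retries must be a non-negative i64 no larger than MAX_HTTP_RETRIES before it is narrowed to u32. A self-contained sketch of those rules (error type simplified to String; the MAX_HTTP_RETRIES value here is illustrative, the real one lives in constants.rs):

```rust
const MIN_TOTAL_TIMEOUT_S: f64 = 0.1;
const MAX_TOTAL_TIMEOUT_S: f64 = 3600.0;
const MAX_HTTP_RETRIES: u32 = 4; // illustrative; see constants.rs for the real cap

fn validate(
    timeout_s: f64,
    total_timeout_s: Option<f64>,
    max_retries: Option<i64>,
) -> Result<Option<u32>, String> {
    if let Some(total) = total_timeout_s {
        if !(MIN_TOTAL_TIMEOUT_S..=MAX_TOTAL_TIMEOUT_S).contains(&total) {
            return Err(format!("total timeout {total:.3}s out of range"));
        }
        if total < timeout_s {
            return Err(format!(
                "total timeout {total:.3}s < per-request timeout {timeout_s:.3}s"
            ));
        }
    }
    // Narrow i64 to u32, rejecting negatives and values above the cap.
    max_retries
        .map(|r| {
            u32::try_from(r)
                .ok()
                .filter(|&narrowed| narrowed <= MAX_HTTP_RETRIES)
                .ok_or_else(|| format!("max_retries {r} invalid"))
        })
        .transpose()
}

fn main() {
    assert_eq!(validate(10.0, Some(30.0), Some(2)), Ok(Some(2)));
    assert!(validate(10.0, Some(5.0), None).is_err()); // total < per-request
    assert!(validate(10.0, None, Some(-1)).is_err()); // negative retries
    println!("validation ok");
}
```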