diff --git a/architecture/gateway.md b/architecture/gateway.md index d89706e64..b91db4c13 100644 --- a/architecture/gateway.md +++ b/architecture/gateway.md @@ -75,6 +75,7 @@ The storage schema is intentionally narrow: | `version` | Optional monotonically increasing version for scoped records. | | `status` | Optional workflow state for records such as policy revisions or draft policy chunks. | | `dedup_key` and `hit_count` | Optional policy-advisor fields for coalescing repeated observations. | +| `resource_version` | Monotonically increasing counter for optimistic concurrency control. Incremented atomically on each update. | | `payload` | Prost-encoded protobuf payload for the full domain object. | | `created_at_ms` and `updated_at_ms` | Gateway timestamps used for ordering and list output. | | `labels` | JSON object carrying Kubernetes-style object labels for filtering and organization. | @@ -99,6 +100,44 @@ scope semantics. Persisted state includes sandboxes, providers, SSH sessions, policy revisions, settings, inference configuration, and deployment records. +### Optimistic Concurrency (CAS) + +Every object row carries a `resource_version` that the database increments +atomically on each write. Concurrent mutations use compare-and-swap (CAS): the +writer reads the current version, applies changes, and writes back with a +`WHERE resource_version = ` guard. If another writer updated the row +in between, the guard fails and the caller retries with fresh state. + +This matters for HA deployments where multiple gateway replicas share the same +Postgres database, and for single-node deployments where concurrent gRPC +handlers or the reconciler mutate the same sandbox. + +**When to use CAS** -- any mutation that merges caller-supplied fields into an +existing object: + +- Provider credential and config updates (merge maps). +- Sandbox provider attach/detach (append/remove from a list). +- Policy version bumps and draft operations. +- Compute status updates (sandbox phase transitions and reconciliation). + +**When CAS is not needed** -- create operations that generate a unique ID +(conflicts are caught by the primary key constraint), unconditional deletes, +and idempotent overwrites where the full payload is self-contained. + +The `update_message_cas` helper encapsulates the retry loop: it fetches the +latest object, applies a mutation closure, and attempts the conditional write. +On conflict it re-fetches and retries, up to a bounded limit of 5 attempts. +If the budget is exhausted the persistence layer returns a `Conflict` error, +which gRPC handlers map to `ABORTED` status so clients can retry with current +data. + +Settings updates are an exception: they use a Tokio `Mutex` instead of CAS +because settings operations require multi-step validation that is simpler under +an exclusive lock than within a retry loop. + +The `resource_version` is surfaced to clients through `ObjectMeta` in proto +responses. Database migrations backfill existing rows with version 1. + Policy and runtime settings are delivered together through the effective sandbox config path. A gateway-global policy can override sandbox-scoped policy. The sandbox supervisor polls for config revisions and hot-reloads dynamic policy diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index 165713b6e..0a1a3c2c3 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -2280,6 +2280,11 @@ pub async fn sandbox_get( println!(" {} {}", "Id:".dimmed(), id); println!(" {} {}", "Name:".dimmed(), name); println!(" {} {}", "Phase:".dimmed(), phase_name(sandbox.phase)); + println!( + " {} {}", + "Resource version:".dimmed(), + sandbox.metadata.as_ref().map_or(0, |m| m.resource_version) + ); // Display labels if present if let Some(metadata) = &sandbox.metadata @@ -2974,6 +2979,7 @@ async fn auto_create_provider( name: exact_name.to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: provider_type.to_string(), credentials: discovered.credentials.clone(), @@ -3014,6 +3020,7 @@ async fn auto_create_provider( name: name.clone(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: provider_type.to_string(), credentials: discovered.credentials.clone(), @@ -3196,6 +3203,7 @@ pub async fn provider_create( name: name.to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: provider_type.clone(), credentials: credential_map, @@ -3240,6 +3248,11 @@ pub async fn provider_get(server: &str, name: &str, tls: &TlsOptions) -> Result< println!(" {} {}", "Id:".dimmed(), provider.object_id()); println!(" {} {}", "Name:".dimmed(), provider.object_name()); println!(" {} {}", "Type:".dimmed(), provider.r#type); + println!( + " {} {}", + "Resource version:".dimmed(), + provider.metadata.as_ref().map_or(0, |m| m.resource_version) + ); println!( " {} {}", "Credential keys:".dimmed(), @@ -3696,6 +3709,7 @@ pub async fn provider_update( name: name.to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: String::new(), credentials: credential_map, diff --git a/crates/openshell-cli/tests/ensure_providers_integration.rs b/crates/openshell-cli/tests/ensure_providers_integration.rs index fec161c53..34a5b572b 100644 --- a/crates/openshell-cli/tests/ensure_providers_integration.rs +++ b/crates/openshell-cli/tests/ensure_providers_integration.rs @@ -113,6 +113,7 @@ impl TestOpenShell { name: name.to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: provider_type.to_string(), credentials: HashMap::new(), @@ -347,6 +348,7 @@ impl OpenShell for TestOpenShell { name: provider_metadata.name, created_at_ms: existing_metadata.created_at_ms, labels: existing_metadata.labels, + resource_version: 0, }), r#type: existing.r#type, credentials: merge(existing.credentials, provider.credentials), diff --git a/crates/openshell-cli/tests/provider_commands_integration.rs b/crates/openshell-cli/tests/provider_commands_integration.rs index 3902bda34..9f68a8b7a 100644 --- a/crates/openshell-cli/tests/provider_commands_integration.rs +++ b/crates/openshell-cli/tests/provider_commands_integration.rs @@ -445,6 +445,7 @@ impl OpenShell for TestOpenShell { name: provider_metadata.name, created_at_ms: existing_metadata.created_at_ms, labels: existing_metadata.labels, + resource_version: 0, }), r#type: existing.r#type, credentials: merge(existing.credentials, provider.credentials), diff --git a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs index eb28a18b3..29e6bc873 100644 --- a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs +++ b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs @@ -121,6 +121,7 @@ impl OpenShell for TestOpenShell { name: sandbox_name, created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), phase: SandboxPhase::Provisioning as i32, ..Sandbox::default() @@ -140,6 +141,7 @@ impl OpenShell for TestOpenShell { name, created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), phase: SandboxPhase::Ready as i32, ..Sandbox::default() @@ -325,6 +327,7 @@ impl OpenShell for TestOpenShell { name: sandbox_id.trim_start_matches("id-").to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), phase: SandboxPhase::Provisioning as i32, ..Sandbox::default() diff --git a/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs b/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs index 7e6ea68b8..8f98768e2 100644 --- a/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs +++ b/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs @@ -119,6 +119,7 @@ impl OpenShell for TestOpenShell { name, created_at_ms: 0, labels: std::collections::HashMap::new(), + resource_version: 0, }), ..Default::default() }), diff --git a/crates/openshell-core/src/lib.rs b/crates/openshell-core/src/lib.rs index a4a1ea822..856db6583 100644 --- a/crates/openshell-core/src/lib.rs +++ b/crates/openshell-core/src/lib.rs @@ -22,7 +22,7 @@ pub mod settings; pub use config::{ComputeDriverKind, Config, OidcConfig, TlsConfig}; pub use error::{ComputeDriverError, Error, Result}; -pub use metadata::{ObjectId, ObjectLabels, ObjectName}; +pub use metadata::{GetResourceVersion, ObjectId, ObjectLabels, ObjectName, SetResourceVersion}; /// Build version string derived from git metadata. /// diff --git a/crates/openshell-core/src/metadata.rs b/crates/openshell-core/src/metadata.rs index e7ffea61a..b9fd17aa6 100644 --- a/crates/openshell-core/src/metadata.rs +++ b/crates/openshell-core/src/metadata.rs @@ -25,6 +25,16 @@ pub trait ObjectLabels { fn object_labels(&self) -> Option>; } +/// Provides mutable access to set the object's resource version from persistence. +pub trait SetResourceVersion { + fn set_resource_version(&mut self, version: u64); +} + +/// Provides read access to the object's current resource version. +pub trait GetResourceVersion { + fn get_resource_version(&self) -> u64; +} + // Implementations for Sandbox impl ObjectId for Sandbox { fn object_id(&self) -> &str { @@ -44,6 +54,20 @@ impl ObjectLabels for Sandbox { } } +impl SetResourceVersion for Sandbox { + fn set_resource_version(&mut self, version: u64) { + if let Some(meta) = self.metadata.as_mut() { + meta.resource_version = version; + } + } +} + +impl GetResourceVersion for Sandbox { + fn get_resource_version(&self) -> u64 { + self.metadata.as_ref().map_or(0, |m| m.resource_version) + } +} + // Implementations for Provider impl ObjectId for Provider { fn object_id(&self) -> &str { @@ -63,6 +87,20 @@ impl ObjectLabels for Provider { } } +impl SetResourceVersion for Provider { + fn set_resource_version(&mut self, version: u64) { + if let Some(meta) = self.metadata.as_mut() { + meta.resource_version = version; + } + } +} + +impl GetResourceVersion for Provider { + fn get_resource_version(&self) -> u64 { + self.metadata.as_ref().map_or(0, |m| m.resource_version) + } +} + // Implementations for StoredProviderProfile impl ObjectId for StoredProviderProfile { fn object_id(&self) -> &str { @@ -82,6 +120,20 @@ impl ObjectLabels for StoredProviderProfile { } } +impl SetResourceVersion for StoredProviderProfile { + fn set_resource_version(&mut self, version: u64) { + if let Some(meta) = self.metadata.as_mut() { + meta.resource_version = version; + } + } +} + +impl GetResourceVersion for StoredProviderProfile { + fn get_resource_version(&self) -> u64 { + self.metadata.as_ref().map_or(0, |m| m.resource_version) + } +} + // Implementations for SshSession impl ObjectId for SshSession { fn object_id(&self) -> &str { @@ -101,6 +153,20 @@ impl ObjectLabels for SshSession { } } +impl SetResourceVersion for SshSession { + fn set_resource_version(&mut self, version: u64) { + if let Some(meta) = self.metadata.as_mut() { + meta.resource_version = version; + } + } +} + +impl GetResourceVersion for SshSession { + fn get_resource_version(&self) -> u64 { + self.metadata.as_ref().map_or(0, |m| m.resource_version) + } +} + // Implementations for InferenceRoute impl ObjectId for InferenceRoute { fn object_id(&self) -> &str { @@ -120,6 +186,20 @@ impl ObjectLabels for InferenceRoute { } } +impl SetResourceVersion for InferenceRoute { + fn set_resource_version(&mut self, version: u64) { + if let Some(meta) = self.metadata.as_mut() { + meta.resource_version = version; + } + } +} + +impl GetResourceVersion for InferenceRoute { + fn get_resource_version(&self) -> u64 { + self.metadata.as_ref().map_or(0, |m| m.resource_version) + } +} + // Implementations for ObjectForTest (test-only proto type) impl ObjectId for ObjectForTest { fn object_id(&self) -> &str { @@ -138,3 +218,16 @@ impl ObjectLabels for ObjectForTest { None } } + +impl SetResourceVersion for ObjectForTest { + fn set_resource_version(&mut self, _version: u64) { + // ObjectForTest doesn't have metadata, so this is a no-op + } +} + +impl GetResourceVersion for ObjectForTest { + fn get_resource_version(&self) -> u64 { + // ObjectForTest doesn't have metadata + 0 + } +} diff --git a/crates/openshell-server/migrations/postgres/005_add_resource_version.sql b/crates/openshell-server/migrations/postgres/005_add_resource_version.sql new file mode 100644 index 000000000..e6a294d62 --- /dev/null +++ b/crates/openshell-server/migrations/postgres/005_add_resource_version.sql @@ -0,0 +1,5 @@ +-- Add resource_version column for optimistic concurrency control +ALTER TABLE objects ADD COLUMN resource_version BIGINT NOT NULL DEFAULT 1; + +-- Backfill existing rows with resource_version = 1 +-- (DEFAULT clause handles this automatically for existing rows in PostgreSQL) diff --git a/crates/openshell-server/migrations/sqlite/005_add_resource_version.sql b/crates/openshell-server/migrations/sqlite/005_add_resource_version.sql new file mode 100644 index 000000000..50aacb99d --- /dev/null +++ b/crates/openshell-server/migrations/sqlite/005_add_resource_version.sql @@ -0,0 +1,5 @@ +-- Add resource_version column for optimistic concurrency control +ALTER TABLE objects ADD COLUMN resource_version INTEGER NOT NULL DEFAULT 1; + +-- Backfill existing rows with resource_version = 1 +-- (DEFAULT clause handles this automatically for existing rows in SQLite) diff --git a/crates/openshell-server/src/compute/mod.rs b/crates/openshell-server/src/compute/mod.rs index d2fd34011..a60ec59aa 100644 --- a/crates/openshell-server/src/compute/mod.rs +++ b/crates/openshell-server/src/compute/mod.rs @@ -450,6 +450,13 @@ impl ComputeRuntime { { Ok(_) => { self.sandbox_watch_bus.notify(sandbox.object_id()); + // Read back from DB to get correct resource_version + let sandbox = self + .store + .get_message_by_name::(sandbox.object_name()) + .await + .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? + .ok_or_else(|| Status::internal("sandbox disappeared after creation"))?; Ok(sandbox) } Err(status) if status.code() == Code::AlreadyExists => { @@ -483,22 +490,30 @@ impl ComputeRuntime { } pub async fn delete_sandbox(&self, name: &str) -> Result { + // Resolve sandbox ID from name let sandbox = self .store .get_message_by_name::(name) .await .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))?; - let Some(mut sandbox) = sandbox else { + let Some(sandbox) = sandbox else { return Err(Status::not_found("sandbox not found")); }; let id = sandbox.object_id().to_string(); - sandbox.phase = SandboxPhase::Deleting as i32; - self.store - .put_message(&sandbox) + + // Use CAS to set phase to Deleting with bounded retry + let sandbox = self + .store + .update_message_cas::(&id, 5, |s| { + s.phase = SandboxPhase::Deleting as i32; + }) .await - .map_err(|e| Status::internal(format!("persist sandbox failed: {e}")))?; + .map_err(|e| { + crate::grpc::persistence_error_to_status(e, "set sandbox phase to Deleting") + })?; + self.sandbox_index.update_from_sandbox(&sandbox); self.sandbox_watch_bus.notify(&id); self.cleanup_sandbox_owned_records(&sandbox).await; @@ -811,86 +826,138 @@ impl ComputeRuntime { .as_ref() .map(decode_sandbox_record) .transpose()?; - let previous = existing.clone(); - - let mut status = incoming.status.as_ref().map(public_status_from_driver); - rewrite_user_facing_conditions( - &mut status, - existing.as_ref().and_then(|sandbox| sandbox.spec.as_ref()), - ); - let session_connected = self.supervisor_sessions.has_session(&incoming.id); - let mut phase = derive_phase(incoming.status.as_ref()); - let mut sandbox = existing.unwrap_or_else(|| { - use crate::persistence::current_time_ms; + // If no existing record, create initial sandbox (first watch event for this sandbox) + if existing.is_none() { + use crate::persistence::{WriteCondition, current_time_ms}; let now_ms = current_time_ms().unwrap_or(0); - Sandbox { + + let mut status = incoming.status.as_ref().map(public_status_from_driver); + rewrite_user_facing_conditions(&mut status, None); + + let session_connected = self.supervisor_sessions.has_session(&incoming.id); + let mut phase = derive_phase(incoming.status.as_ref()); + + let sandbox_name = incoming.name.clone(); + if session_connected + && matches!(phase, SandboxPhase::Provisioning | SandboxPhase::Unknown) + { + ensure_supervisor_ready_status(&mut status, &sandbox_name); + phase = SandboxPhase::Ready; + } + + let sandbox = Sandbox { metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { id: incoming.id.clone(), - name: incoming.name.clone(), + name: sandbox_name, created_at_ms: now_ms, labels: std::collections::HashMap::new(), + resource_version: 0, }), spec: None, - status: None, - phase: SandboxPhase::Unknown as i32, + status, + phase: phase as i32, current_policy_version: 0, - } - }); + }; - if session_connected && matches!(phase, SandboxPhase::Provisioning | SandboxPhase::Unknown) - { - ensure_supervisor_ready_status(&mut status, sandbox.object_name()); - phase = SandboxPhase::Ready; - } + self.store + .put_if( + Sandbox::object_type(), + &incoming.id, + sandbox.object_name(), + &sandbox.encode_to_vec(), + None, + WriteCondition::MustCreate, + ) + .await + .map_err(|e| match e { + crate::persistence::PersistenceError::Conflict { + current_resource_version, + } => format!( + "concurrent modification detected during sandbox creation (current resource_version: {})", + current_resource_version + .map_or_else(|| "unknown".to_string(), |v| v.to_string()) + ), + other => other.to_string(), + })?; - let old_phase = SandboxPhase::try_from(sandbox.phase).unwrap_or(SandboxPhase::Unknown); - if old_phase != phase { - info!( - sandbox_id = %incoming.id, - sandbox_name = %incoming.name, - old_phase = ?old_phase, - new_phase = ?phase, - "Sandbox phase changed" - ); + self.sandbox_index.update_from_sandbox(&sandbox); + self.sandbox_watch_bus.notify(sandbox.object_id()); + return Ok(()); } - if phase == SandboxPhase::Error - && let Some(ref status) = status - { - for condition in &status.conditions { - if condition.r#type == "Ready" - && condition.status.eq_ignore_ascii_case("false") - && is_terminal_failure_reason(&condition.reason) + // Use CAS to update existing sandbox (prevents lost updates in HA deployments with concurrent watch events) + // 5 retries = ~5ms max latency under moderate contention from multiple gateway replicas + // Capture external state once to ensure all retries use the same snapshot + let session_connected = self.supervisor_sessions.has_session(&incoming.id); + let sandbox_name = incoming.name.clone(); + + let sandbox = self + .store + .update_message_cas::(&incoming.id, 5, |sandbox| { + let mut status = incoming.status.as_ref().map(public_status_from_driver); + rewrite_user_facing_conditions(&mut status, sandbox.spec.as_ref()); + + let mut phase = derive_phase(incoming.status.as_ref()); + if session_connected + && matches!(phase, SandboxPhase::Provisioning | SandboxPhase::Unknown) { - warn!( + ensure_supervisor_ready_status(&mut status, &sandbox_name); + phase = SandboxPhase::Ready; + } + + let old_phase = + SandboxPhase::try_from(sandbox.phase).unwrap_or(SandboxPhase::Unknown); + if old_phase != phase { + info!( sandbox_id = %incoming.id, - sandbox_name = %incoming.name, - reason = %condition.reason, - message = %condition.message, - "Sandbox failed to become ready" + sandbox_name = %sandbox_name, + old_phase = ?old_phase, + new_phase = ?phase, + "Sandbox phase changed" ); } - } - } - // Update metadata fields - if let Some(metadata) = sandbox.metadata.as_mut() { - metadata.name = incoming.name; - } - // Note: namespace field removed from public Sandbox API - it remains internal to DriverSandbox - sandbox.status = status; - sandbox.phase = phase as i32; + if phase == SandboxPhase::Error + && let Some(ref status) = status + { + for condition in &status.conditions { + if condition.r#type == "Ready" + && condition.status.eq_ignore_ascii_case("false") + && is_terminal_failure_reason(&condition.reason) + { + warn!( + sandbox_id = %incoming.id, + sandbox_name = %sandbox_name, + reason = %condition.reason, + message = %condition.message, + "Sandbox failed to become ready" + ); + } + } + } - if previous.as_ref() == Some(&sandbox) { - return Ok(()); - } + // Update metadata fields + if let Some(metadata) = sandbox.metadata.as_mut() { + metadata.name.clone_from(&sandbox_name); + } + // Note: namespace field removed from public Sandbox API - it remains internal to DriverSandbox + sandbox.status = status; + sandbox.phase = phase as i32; + }) + .await + .map_err(|e| match e { + crate::persistence::PersistenceError::Conflict { + current_resource_version, + } => format!( + "concurrent modification detected during sandbox reconciliation (current resource_version: {})", + current_resource_version + .map_or_else(|| "unknown".to_string(), |v| v.to_string()) + ), + other => other.to_string(), + })?; self.sandbox_index.update_from_sandbox(&sandbox); - self.store - .put_message(&sandbox) - .await - .map_err(|e| e.to_string())?; self.sandbox_watch_bus.notify(sandbox.object_id()); Ok(()) } @@ -909,38 +976,51 @@ impl ComputeRuntime { connected: bool, ) -> Result<(), String> { let _guard = self.sync_lock.lock().await; - let Some(record) = self + + // Use CAS to update sandbox phase based on supervisor session state + let result = self .store - .get(Sandbox::object_type(), sandbox_id) - .await - .map_err(|e| e.to_string())? - else { - return Ok(()); - }; + .update_message_cas::(sandbox_id, 5, |sandbox| { + let current_phase = + SandboxPhase::try_from(sandbox.phase).unwrap_or(SandboxPhase::Unknown); - let mut sandbox = decode_sandbox_record(&record)?; - let current_phase = SandboxPhase::try_from(sandbox.phase).unwrap_or(SandboxPhase::Unknown); + // Skip if sandbox is in terminal state + if current_phase == SandboxPhase::Deleting || current_phase == SandboxPhase::Error { + return; + } - if current_phase == SandboxPhase::Deleting || current_phase == SandboxPhase::Error { - return Ok(()); - } + let sandbox_name = sandbox.object_name().to_string(); + if connected { + ensure_supervisor_ready_status(&mut sandbox.status, &sandbox_name); + sandbox.phase = SandboxPhase::Ready as i32; + } else if current_phase == SandboxPhase::Ready { + ensure_supervisor_not_ready_status(&mut sandbox.status, &sandbox_name); + sandbox.phase = SandboxPhase::Provisioning as i32; + } + }) + .await; - let sandbox_name = sandbox.object_name().to_string(); - if connected { - ensure_supervisor_ready_status(&mut sandbox.status, &sandbox_name); - sandbox.phase = SandboxPhase::Ready as i32; - } else if current_phase == SandboxPhase::Ready { - ensure_supervisor_not_ready_status(&mut sandbox.status, &sandbox_name); - sandbox.phase = SandboxPhase::Provisioning as i32; - } else { - return Ok(()); - } + // Handle not found gracefully (sandbox may have been deleted) + let sandbox = match result { + Ok(s) => s, + Err(crate::persistence::PersistenceError::Database(ref msg)) + if msg.contains("not found") => + { + return Ok(()); + } + Err(crate::persistence::PersistenceError::Conflict { + current_resource_version, + }) => { + return Err(format!( + "concurrent modification detected (current resource_version: {})", + current_resource_version + .map_or_else(|| "unknown".to_string(), |v| v.to_string()) + )); + } + Err(e) => return Err(e.to_string()), + }; self.sandbox_index.update_from_sandbox(&sandbox); - self.store - .put_message(&sandbox) - .await - .map_err(|e| e.to_string())?; self.sandbox_watch_bus.notify(sandbox_id); Ok(()) } @@ -1830,6 +1910,7 @@ mod tests { name: name.to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), phase: phase as i32, ..Default::default() @@ -1843,6 +1924,7 @@ mod tests { name: format!("session-{id}"), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), sandbox_id: sandbox_id.to_string(), token: format!("token-{id}"), @@ -2757,4 +2839,40 @@ mod tests { "unset user_namespaces must not produce host_users" ); } + + #[tokio::test] + async fn create_sandbox_returns_resource_version_one() { + let runtime = test_runtime(Arc::new(TestDriver::default())).await; + + let mut sandbox = sandbox_record("sb-new", "test-sandbox", SandboxPhase::Provisioning); + // Clear metadata to simulate incoming request + sandbox.metadata = Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "sb-new".to_string(), + name: "test-sandbox".to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + resource_version: 0, + }); + + let created = runtime.create_sandbox(sandbox).await.unwrap(); + + assert_eq!( + created.metadata.as_ref().unwrap().resource_version, + 1, + "create_sandbox should return resource_version: 1 after insert" + ); + + // Verify database also has resource_version: 1 + let stored = runtime + .store + .get_message::("sb-new") + .await + .unwrap() + .unwrap(); + assert_eq!( + stored.metadata.as_ref().unwrap().resource_version, + 1, + "database should have resource_version: 1 after create" + ); + } } diff --git a/crates/openshell-server/src/grpc/mod.rs b/crates/openshell-server/src/grpc/mod.rs index ebb8b1021..a3e94a573 100644 --- a/crates/openshell-server/src/grpc/mod.rs +++ b/crates/openshell-server/src/grpc/mod.rs @@ -61,6 +61,29 @@ pub fn clamp_limit(raw: u32, default: u32, max: u32) -> u32 { if raw == 0 { default } else { raw.min(max) } } +/// Map a `PersistenceError` to an appropriate gRPC `Status`. +/// +/// CAS conflicts (optimistic concurrency failures) are mapped to `ABORTED` +/// to signal that the client should retry with fresh data. Other persistence +/// errors are mapped to `INTERNAL`. +pub fn persistence_error_to_status( + err: crate::persistence::PersistenceError, + operation: &str, +) -> Status { + use crate::persistence::PersistenceError; + + match err { + PersistenceError::Conflict { + current_resource_version, + } => Status::aborted(format!( + "{} failed due to concurrent modification (current resource_version: {})", + operation, + current_resource_version.map_or_else(|| "unknown".to_string(), |v| v.to_string()) + )), + other => Status::internal(format!("{operation} failed: {other}")), + } +} + // --------------------------------------------------------------------------- // Field-level size limits (shared across submodules) // --------------------------------------------------------------------------- diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index d5a47bcba..b441f3565 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -1000,32 +1000,25 @@ pub(super) async fn handle_update_config( validate_static_fields_unchanged(baseline_policy, &new_policy)?; validate_policy_safety(&new_policy)?; } else { + // Backfill spec.policy using CAS (first-time policy discovery) let _sandbox_sync_guard = state.compute.sandbox_sync_guard().await; - let mut sandbox = state + let sandbox_id = sandbox.object_id().to_string(); + let new_policy_clone = new_policy.clone(); + state .store - .get_message::(&sandbox_id) + .update_message_cas::(&sandbox_id, 5, |sandbox| { + if let Some(ref mut spec) = sandbox.spec + && spec.policy.is_none() + { + spec.policy = Some(new_policy_clone.clone()); + } + }) .await - .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? - .ok_or_else(|| Status::not_found("sandbox not found"))?; - let spec = sandbox - .spec - .as_mut() - .ok_or_else(|| Status::internal("sandbox has no spec"))?; - if let Some(baseline_policy) = spec.policy.as_ref() { - validate_static_fields_unchanged(baseline_policy, &new_policy)?; - validate_policy_safety(&new_policy)?; - } else { - spec.policy = Some(new_policy.clone()); - state - .store - .put_message(&sandbox) - .await - .map_err(|e| Status::internal(format!("backfill spec.policy failed: {e}")))?; - info!( - sandbox_id = %sandbox_id, - "UpdateConfig: backfilled spec.policy from sandbox-discovered policy" - ); - } + .map_err(|e| super::persistence_error_to_status(e, "backfill spec.policy"))?; + info!( + sandbox_id = %sandbox_id, + "UpdateConfig: backfilled spec.policy from sandbox-discovered policy" + ); } let latest = state @@ -1222,11 +1215,18 @@ pub(super) async fn handle_report_policy_status( .store .supersede_older_policies(&req.sandbox_id, version) .await; + + // Update current_policy_version using CAS with bounded retry let _sandbox_sync_guard = state.compute.sandbox_sync_guard().await; - if let Ok(Some(mut sandbox)) = state.store.get_message::(&req.sandbox_id).await { - sandbox.current_policy_version = req.version; - let _ = state.store.put_message(&sandbox).await; - } + let version_to_set = req.version; + state + .store + .update_message_cas::(&req.sandbox_id, 5, |sandbox| { + sandbox.current_policy_version = version_to_set; + }) + .await + .map_err(|e| super::persistence_error_to_status(e, "update current_policy_version"))?; + state.sandbox_watch_bus.notify(&req.sandbox_id); } @@ -2650,19 +2650,85 @@ async fn save_settings_record( name: &str, settings: &StoredSettings, ) -> Result<(), Status> { + use crate::persistence::WriteCondition; + let payload = serde_json::to_vec(settings) .map_err(|e| Status::internal(format!("encode settings payload failed: {e}")))?; - store - .put( - object_type, - &uuid::Uuid::new_v4().to_string(), - name, - &payload, - None, - ) + + // Resolve stable ID once before retry loop to prevent lost writes from ID reassignment + let existing = store + .get_by_name(object_type, name) .await - .map_err(|e| Status::internal(format!("persist settings failed: {e}")))?; - Ok(()) + .map_err(|e| Status::internal(format!("fetch settings for CAS failed: {e}")))?; + + let (mut id, mut condition) = if let Some(record) = existing { + // Use stable ID and CAS with expected version + ( + record.id, + WriteCondition::MatchResourceVersion(record.resource_version), + ) + } else { + // Create new with fresh UUID + (uuid::Uuid::new_v4().to_string(), WriteCondition::MustCreate) + }; + + // Track if we've resolved the ID from a concurrent creation to ensure it only happens once + let mut id_resolved_from_unique_violation = false; + + // Retry loop for CAS conflicts (bounded to 20 attempts) + // Settings have higher contention than other objects (global settings accessed by all sandboxes, + // sandbox settings updated by both supervisor and control plane) so we allow more retries than + // the standard 5, but 100 was excessive. 20 retries = ~20ms max latency under heavy contention. + for attempt in 0..20 { + match store + .put_if(object_type, &id, name, &payload, None, condition) + .await + { + Ok(_) => return Ok(()), + Err(crate::persistence::PersistenceError::Conflict { .. }) if attempt + 1 < 20 => { + // Conflict - re-fetch to get latest resource_version, but keep stable ID + tokio::task::yield_now().await; + let latest = store + .get(object_type, &id) + .await + .map_err(|e| Status::internal(format!("re-fetch settings failed: {e}")))? + .ok_or_else(|| Status::internal("settings disappeared mid-update"))?; + // Update condition with latest version, but NEVER reassign ID + condition = WriteCondition::MatchResourceVersion(latest.resource_version); + } + Err(crate::persistence::PersistenceError::UniqueViolation { constraint, .. }) + if constraint.as_deref() == Some("objects_name_uq") && attempt + 1 < 20 => + { + // Unique violation on name - object was created concurrently + // This should only happen once (on first attempt with MustCreate) + if id_resolved_from_unique_violation { + return Err(Status::internal( + "UniqueViolation after ID already resolved - concurrent creation race", + )); + } + + tokio::task::yield_now().await; + let latest = store + .get_by_name(object_type, name) + .await + .map_err(|e| Status::internal(format!("re-fetch settings failed: {e}")))? + .ok_or_else(|| Status::internal("settings disappeared mid-update"))?; + + // Switch to the discovered ID (only once) + id = latest.id; + id_resolved_from_unique_violation = true; + condition = WriteCondition::MatchResourceVersion(latest.resource_version); + } + Err(e) => { + return Err(super::persistence_error_to_status(e, "persist settings")); + } + } + } + + // If we've exhausted all 20 retries, it means we hit continuous conflicts + Err(Status::aborted( + "settings save failed: retry budget exhausted due to continuous concurrent modifications", + )) } fn decode_policy_from_global_settings( @@ -2827,6 +2893,7 @@ mod tests { name: "no-policy-sandbox".to_string(), created_at_ms: 1_000_000, labels: std::collections::HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { policy: None, @@ -2852,6 +2919,7 @@ mod tests { name: name.to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), r#type: provider_type.to_string(), credentials: std::iter::once(("GITHUB_TOKEN".to_string(), "ghp-test".to_string())) @@ -2893,6 +2961,7 @@ mod tests { name: name.to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { policy: Some(policy), @@ -2961,6 +3030,7 @@ mod tests { name: "generic".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), profile: Some(openshell_core::proto::ProviderProfile { id: "generic".to_string(), @@ -3002,6 +3072,7 @@ mod tests { name: "custom-api".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), profile: Some(openshell_core::proto::ProviderProfile { id: "custom-api".to_string(), @@ -3064,6 +3135,7 @@ mod tests { name: "custom-api".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), profile: Some(openshell_core::proto::ProviderProfile { id: "custom-api".to_string(), @@ -3749,6 +3821,7 @@ mod tests { name: "global-profile-sandbox".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { policy: Some(sandbox_policy), @@ -3836,6 +3909,7 @@ mod tests { name: "backfill-sandbox".to_string(), created_at_ms: 1_000_000, labels: std::collections::HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { policy: None, @@ -3918,6 +3992,7 @@ mod tests { name: "draft-flow".to_string(), created_at_ms: 1_000_000, labels: std::collections::HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { policy: None, @@ -4121,6 +4196,7 @@ mod tests { name: "draft-owner".to_string(), created_at_ms: 1_000_000, labels: std::collections::HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { policy: None, @@ -4135,6 +4211,7 @@ mod tests { name: "draft-other".to_string(), created_at_ms: 1_000_001, labels: std::collections::HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { policy: None, @@ -5260,24 +5337,43 @@ mod tests { .settings .insert(format!("key_{i}"), StoredSettingValue::Int(i as i64)); settings.revision = settings.revision.wrapping_add(1); - save_global_settings(&store, &settings).await.unwrap(); + save_global_settings(&store, &settings).await })); } + let mut succeeded = 0; + let mut cas_conflicts = 0; for h in handles { - h.await.unwrap(); + match h.await.unwrap() { + Ok(()) => succeeded += 1, + Err(e) if e.code() == Code::Aborted => cas_conflicts += 1, + Err(e) => panic!("unexpected error: {e}"), + } } let final_settings = load_global_settings(&store).await.unwrap(); - let lost = (n as u64).saturating_sub(final_settings.revision); - if lost == 0 { - eprintln!( - "note: no lost writes detected in unlocked test (sequential scheduling); \ - the locked test is the authoritative correctness check" - ); - } else { - eprintln!("unlocked test: {lost} lost writes out of {n} (expected behavior)"); - } + + // With CAS, conflicts are detected (not silent), but without a proper retry loop that + // re-loads and re-applies the mutation, logical writes can still be lost: + // - All tasks read initial state (revision=0) + // - All increment to revision=1 + // - First write succeeds, sets revision=1 + // - Subsequent writes that succeed also write revision=1 (stale payload) + // - Result: revision=1 even though multiple "succeeded" + // + // This demonstrates that CAS on resource_version prevents database-level corruption, + // but the caller must implement proper retry logic (re-load on conflict) to prevent + // application-level lost writes. The locked test shows the correct pattern. + assert!( + final_settings.revision < succeeded as u64, + "without retry logic, some successful writes had stale payloads (expected behavior)" + ); + + eprintln!( + "unlocked CAS test: {succeeded} succeeded, {cas_conflicts} CAS conflicts, \ + final revision={} (< {succeeded} demonstrates lost application-level writes without retry logic)", + final_settings.revision + ); } // ---- Conflict guard tests ---- diff --git a/crates/openshell-server/src/grpc/provider.rs b/crates/openshell-server/src/grpc/provider.rs index 2ed4d439d..b61d53945 100644 --- a/crates/openshell-server/src/grpc/provider.rs +++ b/crates/openshell-server/src/grpc/provider.rs @@ -44,6 +44,7 @@ pub(super) async fn create_provider_record( name: generate_name(), created_at_ms: now_ms, labels: std::collections::HashMap::new(), + resource_version: 0, }); } @@ -86,6 +87,13 @@ pub(super) async fn create_provider_record( .await .map_err(|e| Status::internal(format!("persist provider failed: {e}")))?; + // Read back from DB to get correct resource_version + let provider = store + .get_message_by_name::(provider.object_name()) + .await + .map_err(|e| Status::internal(format!("fetch provider failed: {e}")))? + .ok_or_else(|| Status::internal("provider disappeared after creation"))?; + Ok(redact_provider_credentials(provider)) } @@ -126,12 +134,13 @@ pub(super) async fn update_provider_record( store: &Store, provider: Provider, ) -> Result { - use crate::persistence::ObjectName; + use crate::persistence::{ObjectId, ObjectName}; if provider.object_name().is_empty() { return Err(Status::invalid_argument("provider.name is required")); } + // Resolve provider ID from name for CAS update let existing = store .get_message_by_name::(provider.object_name()) .await @@ -141,6 +150,8 @@ pub(super) async fn update_provider_record( return Err(Status::not_found("provider not found")); }; + let provider_id = existing.object_id().to_string(); + // Provider type is immutable after creation. Reject if the caller // sends a non-empty type that differs from the existing one. let incoming_type = provider.r#type.trim(); @@ -150,23 +161,28 @@ pub(super) async fn update_provider_record( )); } - let updated = Provider { - metadata: existing.metadata, - r#type: existing.r#type, - credentials: merge_map(existing.credentials, provider.credentials), - config: merge_map(existing.config, provider.config), - }; + // Capture incoming maps for use in CAS closure + let incoming_credentials = provider.credentials; + let incoming_config = provider.config; + + // Use CAS to merge and update provider with bounded retry + let updated = store + .update_message_cas::(&provider_id, 5, |existing_provider| { + existing_provider.credentials = merge_map( + existing_provider.credentials.clone(), + incoming_credentials.clone(), + ); + existing_provider.config = + merge_map(existing_provider.config.clone(), incoming_config.clone()); + }) + .await + .map_err(|e| super::persistence_error_to_status(e, "update provider"))?; // Ensure metadata is valid (defense in depth - existing.metadata should always be valid) super::validation::validate_object_metadata(updated.metadata.as_ref(), "provider")?; validate_provider_fields(&updated)?; - store - .put_message(&updated) - .await - .map_err(|e| Status::internal(format!("persist provider failed: {e}")))?; - Ok(redact_provider_credentials(updated)) } @@ -656,6 +672,7 @@ fn stored_provider_profile(profile: ProviderProfile) -> StoredProviderProfile { name: profile.id.clone(), created_at_ms: now_ms, labels: std::collections::HashMap::new(), + resource_version: 0, }), profile: Some(profile), } @@ -796,6 +813,7 @@ mod tests { name: name.to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: provider_type.to_string(), credentials: [ @@ -1308,6 +1326,7 @@ mod tests { name: "sandbox-using-custom".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { providers: vec!["custom-provider".to_string()], @@ -1400,6 +1419,7 @@ mod tests { name: "gitlab-local".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), r#type: "gitlab".to_string(), credentials: std::iter::once(( @@ -1475,6 +1495,7 @@ mod tests { name: "attached-sandbox".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { providers: vec!["gitlab-local".to_string()], @@ -1496,6 +1517,78 @@ mod tests { ); } + #[tokio::test] + async fn provider_create_and_update_return_correct_resource_version() { + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(); + + // Create provider and verify resource_version: 1 in response + let created = provider_with_values("test-provider", "openai"); + let persisted = create_provider_record(&store, created).await.unwrap(); + assert_eq!( + persisted.metadata.as_ref().unwrap().resource_version, + 1, + "create_provider_record should return resource_version: 1 after insert" + ); + + // Update provider and verify resource_version: 2 in response + let updated = update_provider_record( + &store, + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "test-provider".to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: "openai".to_string(), + credentials: std::iter::once(( + "OPENAI_API_KEY".to_string(), + "updated-key".to_string(), + )) + .collect(), + config: HashMap::new(), + }, + ) + .await + .unwrap(); + assert_eq!( + updated.metadata.as_ref().unwrap().resource_version, + 2, + "update_provider_record should return resource_version: 2 after first update" + ); + + // Update again and verify resource_version: 3 + let updated_again = update_provider_record( + &store, + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: "test-provider".to_string(), + created_at_ms: 1_000_000, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: "openai".to_string(), + credentials: std::iter::once(( + "OPENAI_API_KEY".to_string(), + "third-key".to_string(), + )) + .collect(), + config: HashMap::new(), + }, + ) + .await + .unwrap(); + assert_eq!( + updated_again.metadata.as_ref().unwrap().resource_version, + 3, + "update_provider_record should return resource_version: 3 after second update" + ); + } + #[tokio::test] async fn provider_validation_errors() { let store = Store::connect("sqlite::memory:?cache=shared") @@ -1510,6 +1603,7 @@ mod tests { name: "bad-provider".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), r#type: String::new(), credentials: HashMap::new(), @@ -1534,6 +1628,7 @@ mod tests { name: "missing".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), r#type: String::new(), credentials: HashMap::new(), @@ -1562,6 +1657,7 @@ mod tests { name: "noop-test".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), r#type: String::new(), credentials: HashMap::new(), @@ -1609,6 +1705,7 @@ mod tests { name: "delete-key-test".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: String::new(), credentials: std::iter::once(("SECONDARY".to_string(), String::new())).collect(), @@ -1660,6 +1757,7 @@ mod tests { name: "type-preserve-test".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: String::new(), credentials: HashMap::new(), @@ -1689,6 +1787,7 @@ mod tests { name: "type-change-test".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: "openai".to_string(), credentials: HashMap::new(), @@ -1720,6 +1819,7 @@ mod tests { name: "validate-merge-test".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: String::new(), credentials: std::iter::once((oversized_key, "value".to_string())).collect(), @@ -1748,6 +1848,7 @@ mod tests { name: "claude-local".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: "claude".to_string(), credentials: [ @@ -1791,6 +1892,7 @@ mod tests { name: "test-provider".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: "test".to_string(), credentials: [ @@ -1823,6 +1925,7 @@ mod tests { name: "claude-local".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: "claude".to_string(), credentials: std::iter::once(( @@ -1843,6 +1946,7 @@ mod tests { name: "gitlab-local".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: "gitlab".to_string(), credentials: std::iter::once(("GITLAB_TOKEN".to_string(), "glpat-xyz".to_string())) @@ -1874,6 +1978,7 @@ mod tests { name: "provider-a".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: "claude".to_string(), credentials: std::iter::once(("SHARED_KEY".to_string(), "first-value".to_string())) @@ -1891,6 +1996,7 @@ mod tests { name: "provider-b".to_string(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: "gitlab".to_string(), credentials: std::iter::once(( @@ -1927,6 +2033,7 @@ mod tests { name: "my-claude".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), r#type: "claude".to_string(), credentials: std::iter::once(( @@ -1946,6 +2053,7 @@ mod tests { name: "test-sandbox".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec { providers: vec!["my-claude".to_string()], @@ -1982,6 +2090,7 @@ mod tests { name: "empty-sandbox".to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), spec: Some(SandboxSpec::default()), status: None, diff --git a/crates/openshell-server/src/grpc/sandbox.rs b/crates/openshell-server/src/grpc/sandbox.rs index 65ac69acb..06aeedb29 100644 --- a/crates/openshell-server/src/grpc/sandbox.rs +++ b/crates/openshell-server/src/grpc/sandbox.rs @@ -24,6 +24,7 @@ use openshell_core::proto::{ use openshell_core::proto::{Sandbox, SandboxPhase, SandboxTemplate, SshSession}; use prost::Message; use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; @@ -104,6 +105,7 @@ pub(super) async fn handle_create_sandbox( name: name.clone(), created_at_ms: now_ms, labels: request.labels.clone(), + resource_version: 0, }), spec: Some(spec), status: None, @@ -212,6 +214,16 @@ pub(super) async fn handle_attach_sandbox_provider( return Err(Status::invalid_argument("provider_name is required")); } + // Validate provider name would not violate sandbox spec constraints if added + // (pre-validation ensures CAS mutations preserve invariants) + if request.provider_name.len() > super::MAX_NAME_LEN { + return Err(Status::invalid_argument(format!( + "provider_name exceeds maximum length ({} > {})", + request.provider_name.len(), + super::MAX_NAME_LEN + ))); + } + get_provider_record(state.store.as_ref(), &request.provider_name) .await .map_err(|err| { @@ -226,39 +238,59 @@ pub(super) async fn handle_attach_sandbox_provider( })?; let _sandbox_sync_guard = state.compute.sandbox_sync_guard().await; - let mut sandbox = sandbox_by_name(state, &request.sandbox_name).await?; - let sandbox_name = sandbox + let sandbox = sandbox_by_name(state, &request.sandbox_name).await?; + let sandbox_id = sandbox .metadata .as_ref() - .map_or_else(String::new, |metadata| metadata.name.clone()); + .ok_or_else(|| Status::internal("sandbox metadata is missing"))? + .id + .clone(); + + // Pre-check: fail fast if sandbox spec is missing (invariant violation) let spec = sandbox .spec - .as_mut() - .ok_or_else(|| Status::failed_precondition("sandbox spec is missing"))?; - - dedupe_provider_names(&mut spec.providers); - let attached = if spec - .providers - .iter() - .any(|name| name == &request.provider_name) + .as_ref() + .ok_or_else(|| Status::internal("sandbox spec is missing"))?; + + // Pre-check: fail fast if already at MAX_PROVIDERS limit (avoid spurious CAS conflicts) + // Note: This is an optimization; the CAS closure rechecks after dedupe in case of races + if spec.providers.len() >= MAX_PROVIDERS + && !spec + .providers + .iter() + .any(|name| name == &request.provider_name) { - false - } else { - if spec.providers.len() >= MAX_PROVIDERS { - return Err(Status::invalid_argument(format!( - "providers list exceeds maximum ({MAX_PROVIDERS})" - ))); - } - spec.providers.push(request.provider_name.clone()); - true - }; - validate_sandbox_spec(&sandbox_name, spec)?; + return Err(Status::invalid_argument(format!( + "providers list exceeds maximum ({MAX_PROVIDERS})" + ))); + } - state + let provider_name = request.provider_name.clone(); + let attached = Arc::new(AtomicBool::new(false)); + let attached_clone = attached.clone(); + + let sandbox = state .store - .put_message(&sandbox) + .update_message_cas::(&sandbox_id, 5, |sandbox| { + let Some(ref mut spec) = sandbox.spec else { + // Spec should always exist post-creation; if missing, fail CAS to surface error + return; + }; + + dedupe_provider_names(&mut spec.providers); + if !spec.providers.iter().any(|name| name == &provider_name) + && spec.providers.len() < MAX_PROVIDERS + { + spec.providers.push(provider_name.clone()); + attached_clone.store(true, Ordering::Relaxed); + } + // If limit hit during retry due to concurrent operations, CAS will retry + // Pre-check above prevents spurious failures in the common case + }) .await - .map_err(|e| Status::internal(format!("persist sandbox failed: {e}")))?; + .map_err(|e| super::persistence_error_to_status(e, "attach sandbox provider"))?; + + let attached = attached.load(Ordering::Relaxed); info!( sandbox_name = %request.sandbox_name, @@ -282,28 +314,54 @@ pub(super) async fn handle_detach_sandbox_provider( return Err(Status::invalid_argument("provider_name is required")); } + // Validate provider name (pre-validation ensures CAS mutations preserve invariants) + if request.provider_name.len() > super::MAX_NAME_LEN { + return Err(Status::invalid_argument(format!( + "provider_name exceeds maximum length ({} > {})", + request.provider_name.len(), + super::MAX_NAME_LEN + ))); + } + let _sandbox_sync_guard = state.compute.sandbox_sync_guard().await; - let mut sandbox = sandbox_by_name(state, &request.sandbox_name).await?; - let sandbox_name = sandbox + let sandbox = sandbox_by_name(state, &request.sandbox_name).await?; + let sandbox_id = sandbox .metadata .as_ref() - .map_or_else(String::new, |metadata| metadata.name.clone()); - let spec = sandbox + .ok_or_else(|| Status::internal("sandbox metadata is missing"))? + .id + .clone(); + + // Pre-check: fail fast if sandbox spec is missing (invariant violation) + let _spec = sandbox .spec - .as_mut() - .ok_or_else(|| Status::failed_precondition("sandbox spec is missing"))?; + .as_ref() + .ok_or_else(|| Status::internal("sandbox spec is missing"))?; - let before_len = spec.providers.len(); - spec.providers.retain(|name| name != &request.provider_name); - let detached = spec.providers.len() != before_len; - dedupe_provider_names(&mut spec.providers); - validate_sandbox_spec(&sandbox_name, spec)?; + let provider_name = request.provider_name.clone(); + let detached = Arc::new(AtomicBool::new(false)); + let detached_clone = detached.clone(); - state + let sandbox = state .store - .put_message(&sandbox) + .update_message_cas::(&sandbox_id, 5, |sandbox| { + let Some(ref mut spec) = sandbox.spec else { + // Spec should always exist post-creation; if missing, fail CAS to surface error + return; + }; + + let before_len = spec.providers.len(); + spec.providers.retain(|name| name != &provider_name); + if spec.providers.len() != before_len { + detached_clone.store(true, Ordering::Relaxed); + // Only dedupe after making a change + dedupe_provider_names(&mut spec.providers); + } + }) .await - .map_err(|e| Status::internal(format!("persist sandbox failed: {e}")))?; + .map_err(|e| super::persistence_error_to_status(e, "detach sandbox provider"))?; + + let detached = detached.load(Ordering::Relaxed); info!( sandbox_name = %request.sandbox_name, @@ -744,6 +802,7 @@ pub(super) async fn handle_create_ssh_session( name: generate_name(), created_at_ms: now_ms, labels: std::collections::HashMap::new(), + resource_version: 0, }), sandbox_id: req.sandbox_id.clone(), token: token.clone(), @@ -1276,6 +1335,7 @@ mod tests { name: name.to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), r#type: provider_type.to_string(), credentials: std::iter::once(("TOKEN".to_string(), "secret".to_string())).collect(), @@ -1290,6 +1350,7 @@ mod tests { name: name.to_string(), created_at_ms: 1_000_000, labels: std::iter::once(("team".to_string(), "agents".to_string())).collect(), + resource_version: 0, }), spec: Some(openshell_core::proto::SandboxSpec { log_level: "debug".to_string(), @@ -1488,4 +1549,161 @@ mod tests { assert_eq!(err.code(), tonic::Code::FailedPrecondition); } + + #[tokio::test] + async fn attach_sandbox_provider_accepts_at_max_providers_limit() { + let state = test_server_state().await; + + // Create MAX_PROVIDERS (32) providers + for i in 0..MAX_PROVIDERS { + state + .store + .put_message(&test_provider(&format!("provider-{i}"), "generic")) + .await + .unwrap(); + } + + // Create sandbox with 31 providers already attached + let mut existing_providers = Vec::new(); + for i in 0..(MAX_PROVIDERS - 1) { + existing_providers.push(format!("provider-{i}")); + } + state + .store + .put_message(&test_sandbox("work", existing_providers)) + .await + .unwrap(); + + // Attaching the 32nd provider should succeed + let response = handle_attach_sandbox_provider( + &state, + Request::new(AttachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: "provider-31".to_string(), + }), + ) + .await + .unwrap() + .into_inner(); + + assert!(response.attached); + let providers = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap() + .spec + .unwrap() + .providers; + assert_eq!(providers.len(), MAX_PROVIDERS); + } + + #[tokio::test] + async fn attach_sandbox_provider_rejects_beyond_max_providers_limit() { + let state = test_server_state().await; + + // Create MAX_PROVIDERS + 1 providers + for i in 0..=MAX_PROVIDERS { + state + .store + .put_message(&test_provider(&format!("provider-{i}"), "generic")) + .await + .unwrap(); + } + + // Create sandbox with MAX_PROVIDERS already attached + let mut existing_providers = Vec::new(); + for i in 0..MAX_PROVIDERS { + existing_providers.push(format!("provider-{i}")); + } + state + .store + .put_message(&test_sandbox("work", existing_providers)) + .await + .unwrap(); + + // Attempting to attach the 33rd provider should fail + let err = handle_attach_sandbox_provider( + &state, + Request::new(AttachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: "provider-32".to_string(), + }), + ) + .await + .unwrap_err(); + + assert_eq!(err.code(), tonic::Code::InvalidArgument); + assert!(err.message().contains("exceeds maximum")); + + // Verify sandbox was not modified + let providers = state + .store + .get_message_by_name::("work") + .await + .unwrap() + .unwrap() + .spec + .unwrap() + .providers; + assert_eq!(providers.len(), MAX_PROVIDERS); + } + + #[tokio::test] + async fn attach_sandbox_provider_pre_validation_fails_fast() { + let state = test_server_state().await; + + // Provider name that exceeds validation limits + let long_name = "a".repeat(1000); + state + .store + .put_message(&test_provider(&long_name, "generic")) + .await + .unwrap(); + + state + .store + .put_message(&test_sandbox("work", Vec::new())) + .await + .unwrap(); + + // Should fail validation before attempting CAS + let err = handle_attach_sandbox_provider( + &state, + Request::new(AttachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: long_name, + }), + ) + .await + .unwrap_err(); + + assert_eq!(err.code(), tonic::Code::InvalidArgument); + } + + #[tokio::test] + async fn detach_sandbox_provider_pre_validation_rejects_invalid_names() { + let state = test_server_state().await; + state + .store + .put_message(&test_sandbox("work", vec!["valid".to_string()])) + .await + .unwrap(); + + // Provider name that exceeds validation limits + let long_name = "a".repeat(1000); + + let err = handle_detach_sandbox_provider( + &state, + Request::new(DetachSandboxProviderRequest { + sandbox_name: "work".to_string(), + provider_name: long_name, + }), + ) + .await + .unwrap_err(); + + assert_eq!(err.code(), tonic::Code::InvalidArgument); + } } diff --git a/crates/openshell-server/src/grpc/validation.rs b/crates/openshell-server/src/grpc/validation.rs index 160b7e031..dbc380d82 100644 --- a/crates/openshell-server/src/grpc/validation.rs +++ b/crates/openshell-server/src/grpc/validation.rs @@ -874,6 +874,7 @@ mod tests { name: name.to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), r#type: provider_type.to_string(), credentials, diff --git a/crates/openshell-server/src/inference.rs b/crates/openshell-server/src/inference.rs index b52700f0d..892d2eee5 100644 --- a/crates/openshell-server/src/inference.rs +++ b/crates/openshell-server/src/inference.rs @@ -190,6 +190,7 @@ async fn upsert_cluster_inference_route( name: route_name.to_string(), created_at_ms: now_ms, labels: std::collections::HashMap::new(), + resource_version: 0, }), config: Some(config), version: 1, @@ -490,6 +491,7 @@ mod tests { name: name.to_string(), created_at_ms: 1_000_000, labels: std::collections::HashMap::new(), + resource_version: 0, }), config: Some(ClusterInferenceConfig { provider_name: provider_name.to_string(), @@ -507,6 +509,7 @@ mod tests { name: name.to_string(), created_at_ms: 1_000_000, labels: std::collections::HashMap::new(), + resource_version: 0, }), r#type: provider_type.to_string(), credentials: std::iter::once((key_name.to_string(), key_value.to_string())).collect(), @@ -666,6 +669,7 @@ mod tests { name: "openai-dev".to_string(), created_at_ms: 1_000_000, labels: std::collections::HashMap::new(), + resource_version: 0, }), r#type: "openai".to_string(), credentials: std::iter::once(("OPENAI_API_KEY".to_string(), "sk-test".to_string())) @@ -687,6 +691,7 @@ mod tests { name: CLUSTER_INFERENCE_ROUTE_NAME.to_string(), created_at_ms: 1_000_000, labels: std::collections::HashMap::new(), + resource_version: 0, }), config: Some(ClusterInferenceConfig { provider_name: "openai-dev".to_string(), diff --git a/crates/openshell-server/src/persistence/mod.rs b/crates/openshell-server/src/persistence/mod.rs index 1c926bd4a..4225b14cb 100644 --- a/crates/openshell-server/src/persistence/mod.rs +++ b/crates/openshell-server/src/persistence/mod.rs @@ -41,6 +41,10 @@ pub enum PersistenceError { detail: Option, constraint_msg: String, }, + #[error("resource version conflict: expected version does not match current")] + Conflict { + current_resource_version: Option, + }, } impl PersistenceError { @@ -78,6 +82,28 @@ pub struct ObjectRecord { pub updated_at_ms: i64, /// JSON-serialized labels (key-value pairs). pub labels: Option, + /// Optimistic concurrency control version. + /// Incremented on each update for compare-and-swap operations. + pub resource_version: u64, +} + +/// Write condition for compare-and-swap operations. +#[derive(Debug, Clone, Copy)] +pub enum WriteCondition { + /// Object must not exist (insert only). + MustCreate, + /// Object must exist with the specified resource version (update only). + MatchResourceVersion(u64), + /// Unconditional write (insert or update). + Unconditional, +} + +/// Result of a successful write operation. +#[derive(Debug, Clone)] +pub struct WriteResult { + pub resource_version: u64, + pub created_at_ms: i64, + pub updated_at_ms: i64, } /// Persistence store implementations. @@ -94,7 +120,9 @@ pub trait ObjectType { // Import object metadata accessor traits from openshell-core // (implementations for all proto types are in openshell-core::metadata) -pub use openshell_core::{ObjectId, ObjectLabels, ObjectName}; +pub use openshell_core::{ + GetResourceVersion, ObjectId, ObjectLabels, ObjectName, SetResourceVersion, +}; /// Generate a random 6-character lowercase alphabetic name. pub fn generate_name() -> String { @@ -147,6 +175,74 @@ impl Store { } } + /// Insert or update a generic object with compare-and-swap support. + /// + /// # Arguments + /// * `object_type` - Type discriminator for the object + /// * `id` - Stable object identifier + /// * `name` - Human-readable object name + /// * `payload` - Serialized object data + /// * `labels` - Optional JSON-serialized labels + /// * `condition` - Write precondition (`MustCreate`, `MatchResourceVersion`, or `Unconditional`) + /// + /// # Returns + /// * `Ok(WriteResult)` - Write succeeded with new `resource_version` and timestamps + /// * `Err(Conflict)` - Resource version mismatch (for `MatchResourceVersion`) + /// * `Err(UniqueViolation)` - Object already exists (for `MustCreate`) or name conflict + pub async fn put_if( + &self, + object_type: &str, + id: &str, + name: &str, + payload: &[u8], + labels: Option<&str>, + condition: WriteCondition, + ) -> PersistenceResult { + match self { + Self::Postgres(store) => { + store + .put_if(object_type, id, name, payload, labels, condition) + .await + } + Self::Sqlite(store) => { + store + .put_if(object_type, id, name, payload, labels, condition) + .await + } + } + } + + /// Delete an object by id with compare-and-swap support. + /// + /// # Arguments + /// * `object_type` - Type discriminator for the object + /// * `id` - Stable object identifier + /// * `expected_resource_version` - Required resource version for the delete to proceed + /// + /// # Returns + /// * `Ok(true)` - Object was deleted + /// * `Ok(false)` - Object not found + /// * `Err(Conflict)` - Resource version mismatch + pub async fn delete_if( + &self, + object_type: &str, + id: &str, + expected_resource_version: u64, + ) -> PersistenceResult { + match self { + Self::Postgres(store) => { + store + .delete_if(object_type, id, expected_resource_version) + .await + } + Self::Sqlite(store) => { + store + .delete_if(object_type, id, expected_resource_version) + .await + } + } + } + /// Fetch an object by id. pub async fn get( &self, @@ -253,7 +349,7 @@ impl Store { } /// Fetch and decode a protobuf message by id. - pub async fn get_message( + pub async fn get_message( &self, id: &str, ) -> PersistenceResult> { @@ -262,13 +358,17 @@ impl Store { return Ok(None); }; - T::decode(record.payload.as_slice()) - .map(Some) - .map_err(|e| PersistenceError::Decode(format!("protobuf decode error: {e}"))) + let mut message = T::decode(record.payload.as_slice()) + .map_err(|e| PersistenceError::Decode(format!("protobuf decode error: {e}")))?; + + // Hydrate resource_version from DB row (authoritative source) + message.set_resource_version(record.resource_version); + + Ok(Some(message)) } /// Fetch and decode a protobuf message by name. - pub async fn get_message_by_name( + pub async fn get_message_by_name( &self, name: &str, ) -> PersistenceResult> { @@ -277,9 +377,106 @@ impl Store { return Ok(None); }; - T::decode(record.payload.as_slice()) - .map(Some) - .map_err(|e| PersistenceError::Decode(format!("protobuf decode error: {e}"))) + let mut message = T::decode(record.payload.as_slice()) + .map_err(|e| PersistenceError::Decode(format!("protobuf decode error: {e}")))?; + + // Hydrate resource_version from DB row (authoritative source) + message.set_resource_version(record.resource_version); + + Ok(Some(message)) + } + + /// Update a protobuf message using CAS with bounded retry. + /// + /// Repeatedly fetches the latest object, applies the mutation function, + /// and attempts a CAS write until successful or max retries exhausted. + /// + /// # Arguments + /// * `id` - Object ID to update + /// * `max_retries` - Maximum number of retry attempts (typically 3-5) + /// * `mutate` - Function that modifies the object in place + /// + /// # Returns + /// * `Ok(T)` - Successfully updated object with new `resource_version` + /// * `Err(Conflict)` - Retry budget exhausted + /// * `Err(Database)` - Object not found or other DB error + pub async fn update_message_cas( + &self, + id: &str, + max_retries: u32, + mut mutate: F, + ) -> PersistenceResult + where + T: Message + + Default + + ObjectType + + ObjectId + + ObjectName + + ObjectLabels + + SetResourceVersion + + GetResourceVersion + + Clone, + F: FnMut(&mut T), + { + for attempt in 0..max_retries { + // Fetch latest object with authoritative resource_version + // Treat "not found" as a CAS conflict to make it retryable + // (object may have been deleted between caller's read and this update) + let current = self + .get_message::(id) + .await? + .ok_or(PersistenceError::Conflict { + current_resource_version: None, + })?; + + let expected_version = current.get_resource_version(); + + // Apply mutation + let mut updated = current.clone(); + mutate(&mut updated); + + // Serialize labels + let labels_map = updated.object_labels(); + let labels_json = if labels_map.as_ref().is_none_or(HashMap::is_empty) { + None + } else { + Some(serde_json::to_string(&labels_map).map_err(|e| { + PersistenceError::Encode(format!("failed to serialize labels: {e}")) + })?) + }; + + // Attempt CAS write + match self + .put_if( + T::object_type(), + updated.object_id(), + updated.object_name(), + &updated.encode_to_vec(), + labels_json.as_deref(), + WriteCondition::MatchResourceVersion(expected_version), + ) + .await + { + Ok(result) => { + // Success - hydrate the new resource_version and return + updated.set_resource_version(result.resource_version); + return Ok(updated); + } + Err(PersistenceError::Conflict { .. }) if attempt + 1 < max_retries => { + // Conflict - retry with latest state + tokio::task::yield_now().await; + } + Err(e) => { + // Non-retryable error or retry budget exhausted + return Err(e); + } + } + } + + // Should never reach here due to loop structure, but for safety + Err(PersistenceError::Conflict { + current_resource_version: None, + }) } } diff --git a/crates/openshell-server/src/persistence/postgres.rs b/crates/openshell-server/src/persistence/postgres.rs index 2cd6a046f..78a398d5b 100644 --- a/crates/openshell-server/src/persistence/postgres.rs +++ b/crates/openshell-server/src/persistence/postgres.rs @@ -2,8 +2,8 @@ // SPDX-License-Identifier: Apache-2.0 use super::{ - DraftChunkRecord, ObjectRecord, PersistenceResult, PolicyRecord, current_time_ms, map_db_error, - map_migrate_error, + DraftChunkRecord, ObjectRecord, PersistenceError, PersistenceResult, PolicyRecord, + WriteCondition, WriteResult, current_time_ms, map_db_error, map_migrate_error, }; use crate::policy_store::{ draft_chunk_payload_from_record, draft_chunk_record_from_parts, policy_payload_from_record, @@ -52,7 +52,7 @@ impl PostgresStore { let labels_jsonb: Option = labels .map(serde_json::from_str) .transpose() - .map_err(|e| super::PersistenceError::Encode(format!("invalid labels JSON: {e}")))?; + .map_err(|e| PersistenceError::Encode(format!("invalid labels JSON: {e}")))?; sqlx::query( r" @@ -76,6 +76,157 @@ ON CONFLICT (object_type, name) WHERE name IS NOT NULL DO UPDATE SET Ok(()) } + pub async fn put_if( + &self, + object_type: &str, + id: &str, + name: &str, + payload: &[u8], + labels: Option<&str>, + condition: WriteCondition, + ) -> PersistenceResult { + let now_ms = current_time_ms()?; + let labels_jsonb: Option = labels + .map(serde_json::from_str) + .transpose() + .map_err(|e| PersistenceError::Encode(format!("invalid labels JSON: {e}")))?; + + match condition { + WriteCondition::MustCreate => { + // Insert only - fail if object exists + let row = sqlx::query( + r" +INSERT INTO objects (object_type, id, name, payload, created_at_ms, updated_at_ms, labels, resource_version) +VALUES ($1, $2, $3, $4, $5, $5, COALESCE($6, '{}'::jsonb), 1) +RETURNING resource_version, created_at_ms, updated_at_ms +", + ) + .bind(object_type) + .bind(id) + .bind(name) + .bind(payload) + .bind(now_ms) + .bind(labels_jsonb) + .fetch_one(&self.pool) + .await + .map_err(|e| map_db_error(&e))?; + + let resource_version_i64: i64 = row.try_get("resource_version").unwrap_or(1); + Ok(WriteResult { + resource_version: resource_version_i64.max(1).cast_unsigned(), + created_at_ms: row.get("created_at_ms"), + updated_at_ms: row.get("updated_at_ms"), + }) + } + WriteCondition::MatchResourceVersion(expected_version) => { + // Update with version check using RETURNING + let row_result = sqlx::query( + r" +UPDATE objects +SET payload = $4, labels = COALESCE($5, '{}'::jsonb), updated_at_ms = $6, resource_version = resource_version + 1 +WHERE object_type = $1 AND id = $2 AND resource_version = $3 +RETURNING resource_version, created_at_ms, updated_at_ms +", + ) + .bind(object_type) + .bind(id) + .bind(i64::try_from(expected_version).unwrap_or(i64::MAX)) + .bind(payload) + .bind(labels_jsonb) + .bind(now_ms) + .fetch_optional(&self.pool) + .await + .map_err(|e| map_db_error(&e))?; + + if let Some(row) = row_result { + let resource_version_i64: i64 = row.try_get("resource_version").unwrap_or(1); + Ok(WriteResult { + resource_version: resource_version_i64.max(1).cast_unsigned(), + created_at_ms: row.get("created_at_ms"), + updated_at_ms: row.get("updated_at_ms"), + }) + } else { + // Check if object exists to distinguish NotFound from Conflict + let existing = self.get(object_type, id).await?; + if let Some(record) = existing { + Err(PersistenceError::Conflict { + current_resource_version: Some(record.resource_version), + }) + } else { + Err(PersistenceError::Database(format!( + "object not found: {object_type}/{id}" + ))) + } + } + } + WriteCondition::Unconditional => { + // Unconditional upsert by name + let row = sqlx::query( + r" +INSERT INTO objects (object_type, id, name, payload, created_at_ms, updated_at_ms, labels, resource_version) +VALUES ($1, $2, $3, $4, $5, $5, COALESCE($6, '{}'::jsonb), 1) +ON CONFLICT (object_type, name) WHERE name IS NOT NULL DO UPDATE SET + payload = EXCLUDED.payload, + updated_at_ms = EXCLUDED.updated_at_ms, + labels = EXCLUDED.labels, + resource_version = objects.resource_version + 1 +RETURNING resource_version, created_at_ms, updated_at_ms +", + ) + .bind(object_type) + .bind(id) + .bind(name) + .bind(payload) + .bind(now_ms) + .bind(labels_jsonb) + .fetch_one(&self.pool) + .await + .map_err(|e| map_db_error(&e))?; + + let resource_version_i64: i64 = row.try_get("resource_version").unwrap_or(1); + Ok(WriteResult { + resource_version: resource_version_i64.max(1).cast_unsigned(), + created_at_ms: row.get("created_at_ms"), + updated_at_ms: row.get("updated_at_ms"), + }) + } + } + } + + pub async fn delete_if( + &self, + object_type: &str, + id: &str, + expected_resource_version: u64, + ) -> PersistenceResult { + let result = sqlx::query( + r" +DELETE FROM objects +WHERE object_type = $1 AND id = $2 AND resource_version = $3 +", + ) + .bind(object_type) + .bind(id) + .bind(i64::try_from(expected_resource_version).unwrap_or(i64::MAX)) + .execute(&self.pool) + .await + .map_err(|e| map_db_error(&e))?; + + if result.rows_affected() > 0 { + Ok(true) + } else { + // Check if object exists to distinguish NotFound from Conflict + let existing = self.get(object_type, id).await?; + if let Some(record) = existing { + Err(PersistenceError::Conflict { + current_resource_version: Some(record.resource_version), + }) + } else { + Ok(false) + } + } + } + pub async fn get( &self, object_type: &str, @@ -83,7 +234,7 @@ ON CONFLICT (object_type, name) WHERE name IS NOT NULL DO UPDATE SET ) -> PersistenceResult> { let row = sqlx::query( r" -SELECT object_type, id, name, payload, created_at_ms, updated_at_ms, labels +SELECT object_type, id, name, payload, created_at_ms, updated_at_ms, labels, resource_version FROM objects WHERE object_type = $1 AND id = $2 ", @@ -104,7 +255,7 @@ WHERE object_type = $1 AND id = $2 ) -> PersistenceResult> { let row = sqlx::query( r" -SELECT object_type, id, name, payload, created_at_ms, updated_at_ms, labels +SELECT object_type, id, name, payload, created_at_ms, updated_at_ms, labels, resource_version FROM objects WHERE object_type = $1 AND name = $2 ", @@ -146,7 +297,7 @@ WHERE object_type = $1 AND name = $2 ) -> PersistenceResult> { let rows = sqlx::query( r" -SELECT object_type, id, name, payload, created_at_ms, updated_at_ms, labels +SELECT object_type, id, name, payload, created_at_ms, updated_at_ms, labels, resource_version FROM objects WHERE object_type = $1 ORDER BY created_at_ms ASC, name ASC @@ -173,13 +324,12 @@ LIMIT $2 OFFSET $3 use super::parse_label_selector; let required_labels = parse_label_selector(label_selector)?; - let labels_jsonb = serde_json::to_value(&required_labels).map_err(|e| { - super::PersistenceError::Encode(format!("failed to serialize labels: {e}")) - })?; + let labels_jsonb = serde_json::to_value(&required_labels) + .map_err(|e| PersistenceError::Encode(format!("failed to serialize labels: {e}")))?; let rows = sqlx::query( r" -SELECT object_type, id, name, payload, created_at_ms, updated_at_ms, labels +SELECT object_type, id, name, payload, created_at_ms, updated_at_ms, labels, resource_version FROM objects WHERE object_type = $1 AND labels @> $2 ORDER BY created_at_ms ASC, name ASC @@ -603,6 +753,7 @@ fn draft_chunk_dedup_key(chunk: &DraftChunkRecord) -> String { fn row_to_object_record(row: sqlx::postgres::PgRow) -> ObjectRecord { let labels_jsonb: Option = row.get("labels"); + let resource_version_i64: i64 = row.try_get("resource_version").unwrap_or(1); ObjectRecord { object_type: row.get("object_type"), id: row.get("id"), @@ -611,6 +762,7 @@ fn row_to_object_record(row: sqlx::postgres::PgRow) -> ObjectRecord { created_at_ms: row.get("created_at_ms"), updated_at_ms: row.get("updated_at_ms"), labels: labels_jsonb.map(|value| value.to_string()), + resource_version: resource_version_i64.max(1).cast_unsigned(), } } diff --git a/crates/openshell-server/src/persistence/sqlite.rs b/crates/openshell-server/src/persistence/sqlite.rs index fafb07597..8c8325704 100644 --- a/crates/openshell-server/src/persistence/sqlite.rs +++ b/crates/openshell-server/src/persistence/sqlite.rs @@ -2,8 +2,8 @@ // SPDX-License-Identifier: Apache-2.0 use super::{ - DraftChunkRecord, ObjectRecord, PersistenceResult, PolicyRecord, current_time_ms, map_db_error, - map_migrate_error, + DraftChunkRecord, ObjectRecord, PersistenceError, PersistenceResult, PolicyRecord, + WriteCondition, WriteResult, current_time_ms, map_db_error, map_migrate_error, }; use crate::policy_store::{ draft_chunk_payload_from_record, draft_chunk_record_from_parts, policy_payload_from_record, @@ -87,6 +87,155 @@ ON CONFLICT ("object_type", "name") WHERE "name" IS NOT NULL DO UPDATE SET Ok(()) } + pub async fn put_if( + &self, + object_type: &str, + id: &str, + name: &str, + payload: &[u8], + labels: Option<&str>, + condition: WriteCondition, + ) -> PersistenceResult { + let now_ms = current_time_ms()?; + + match condition { + WriteCondition::MustCreate => { + // Insert only - fail if object exists + sqlx::query( + r#" +INSERT INTO "objects" ("object_type", "id", "name", "payload", "created_at_ms", "updated_at_ms", "labels", "resource_version") +VALUES (?1, ?2, ?3, ?4, ?5, ?5, ?6, 1) +"#, + ) + .bind(object_type) + .bind(id) + .bind(name) + .bind(payload) + .bind(now_ms) + .bind(labels.unwrap_or("{}")) + .execute(&self.pool) + .await + .map_err(|e| map_db_error(&e))?; + + Ok(WriteResult { + resource_version: 1, + created_at_ms: now_ms, + updated_at_ms: now_ms, + }) + } + WriteCondition::MatchResourceVersion(expected_version) => { + // Update with version check + let result = sqlx::query( + r#" +UPDATE "objects" +SET "payload" = ?4, "labels" = ?5, "updated_at_ms" = ?6, "resource_version" = "resource_version" + 1 +WHERE "object_type" = ?1 AND "id" = ?2 AND "resource_version" = ?3 +"#, + ) + .bind(object_type) + .bind(id) + .bind(i64::try_from(expected_version).unwrap_or(i64::MAX)) + .bind(payload) + .bind(labels.unwrap_or("{}")) + .bind(now_ms) + .execute(&self.pool) + .await + .map_err(|e| map_db_error(&e))?; + + if result.rows_affected() == 0 { + // Check if object exists to distinguish NotFound from Conflict + let existing = self.get(object_type, id).await?; + if let Some(record) = existing { + return Err(PersistenceError::Conflict { + current_resource_version: Some(record.resource_version), + }); + } + return Err(PersistenceError::Database(format!( + "object not found: {object_type}/{id}" + ))); + } + + // Fetch the updated record to get the new resource_version + let updated = self.get(object_type, id).await?.ok_or_else(|| { + PersistenceError::Database("object disappeared after update".to_string()) + })?; + + Ok(WriteResult { + resource_version: updated.resource_version, + created_at_ms: updated.created_at_ms, + updated_at_ms: updated.updated_at_ms, + }) + } + WriteCondition::Unconditional => { + // Unconditional upsert by name + sqlx::query( + r#" +INSERT INTO "objects" ("object_type", "id", "name", "payload", "created_at_ms", "updated_at_ms", "labels", "resource_version") +VALUES (?1, ?2, ?3, ?4, ?5, ?5, ?6, 1) +ON CONFLICT ("object_type", "name") WHERE "name" IS NOT NULL DO UPDATE SET + "payload" = excluded."payload", + "updated_at_ms" = excluded."updated_at_ms", + "labels" = excluded."labels", + "resource_version" = "objects"."resource_version" + 1 +"#, + ) + .bind(object_type) + .bind(id) + .bind(name) + .bind(payload) + .bind(now_ms) + .bind(labels.unwrap_or("{}")) + .execute(&self.pool) + .await + .map_err(|e| map_db_error(&e))?; + + // Fetch the result to get the resource_version + let record = self.get_by_name(object_type, name).await?.ok_or_else(|| { + PersistenceError::Database("object disappeared after upsert".to_string()) + })?; + + Ok(WriteResult { + resource_version: record.resource_version, + created_at_ms: record.created_at_ms, + updated_at_ms: record.updated_at_ms, + }) + } + } + } + + pub async fn delete_if( + &self, + object_type: &str, + id: &str, + expected_resource_version: u64, + ) -> PersistenceResult { + let result = sqlx::query( + r#" +DELETE FROM "objects" +WHERE "object_type" = ?1 AND "id" = ?2 AND "resource_version" = ?3 +"#, + ) + .bind(object_type) + .bind(id) + .bind(i64::try_from(expected_resource_version).unwrap_or(i64::MAX)) + .execute(&self.pool) + .await + .map_err(|e| map_db_error(&e))?; + + if result.rows_affected() > 0 { + Ok(true) + } else { + // Check if object exists to distinguish NotFound from Conflict + let existing = self.get(object_type, id).await?; + if let Some(record) = existing { + return Err(PersistenceError::Conflict { + current_resource_version: Some(record.resource_version), + }); + } + Ok(false) + } + } + pub async fn get( &self, object_type: &str, @@ -94,7 +243,7 @@ ON CONFLICT ("object_type", "name") WHERE "name" IS NOT NULL DO UPDATE SET ) -> PersistenceResult> { let row = sqlx::query( r#" -SELECT "object_type", "id", "name", "payload", "created_at_ms", "updated_at_ms", "labels" +SELECT "object_type", "id", "name", "payload", "created_at_ms", "updated_at_ms", "labels", "resource_version" FROM "objects" WHERE "object_type" = ?1 AND "id" = ?2 "#, @@ -115,7 +264,7 @@ WHERE "object_type" = ?1 AND "id" = ?2 ) -> PersistenceResult> { let row = sqlx::query( r#" -SELECT "object_type", "id", "name", "payload", "created_at_ms", "updated_at_ms", "labels" +SELECT "object_type", "id", "name", "payload", "created_at_ms", "updated_at_ms", "labels", "resource_version" FROM "objects" WHERE "object_type" = ?1 AND "name" = ?2 "#, @@ -167,7 +316,7 @@ WHERE "object_type" = ?1 AND "name" = ?2 ) -> PersistenceResult> { let rows = sqlx::query( r#" -SELECT "object_type", "id", "name", "payload", "created_at_ms", "updated_at_ms", "labels" +SELECT "object_type", "id", "name", "payload", "created_at_ms", "updated_at_ms", "labels", "resource_version" FROM "objects" WHERE "object_type" = ?1 ORDER BY "created_at_ms" ASC, "name" ASC @@ -617,6 +766,7 @@ fn draft_chunk_dedup_key(chunk: &DraftChunkRecord) -> String { } fn row_to_object_record(row: sqlx::sqlite::SqliteRow) -> ObjectRecord { + let resource_version_i64: i64 = row.try_get("resource_version").unwrap_or(1); ObjectRecord { object_type: row.get("object_type"), id: row.get("id"), @@ -625,6 +775,7 @@ fn row_to_object_record(row: sqlx::sqlite::SqliteRow) -> ObjectRecord { created_at_ms: row.get("created_at_ms"), updated_at_ms: row.get("updated_at_ms"), labels: row.get("labels"), + resource_version: resource_version_i64.max(1).cast_unsigned(), } } diff --git a/crates/openshell-server/src/persistence/tests.rs b/crates/openshell-server/src/persistence/tests.rs index bef95d4b6..b2c0e3b63 100644 --- a/crates/openshell-server/src/persistence/tests.rs +++ b/crates/openshell-server/src/persistence/tests.rs @@ -785,3 +785,454 @@ fn parse_label_selector_handles_whitespace() { assert_eq!(result.get("env"), Some(&"prod".to_string())); assert_eq!(result.get("tier"), Some(&"frontend".to_string())); } + +// --------------------------------------------------------------------------- +// CAS (compare-and-swap) tests +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn cas_put_if_must_create_succeeds() { + use super::WriteCondition; + + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(); + + let result = store + .put_if( + "sandbox", + "id-1", + "new-sandbox", + b"payload", + None, + WriteCondition::MustCreate, + ) + .await + .unwrap(); + + assert_eq!(result.resource_version, 1); + + let record = store.get("sandbox", "id-1").await.unwrap().unwrap(); + assert_eq!(record.resource_version, 1); + assert_eq!(record.payload, b"payload"); +} + +#[tokio::test] +async fn cas_put_if_must_create_fails_on_duplicate() { + use super::{PersistenceError, WriteCondition}; + + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(); + + // First insert succeeds + store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"payload1", + None, + WriteCondition::MustCreate, + ) + .await + .unwrap(); + + // Second insert with same ID fails + let result = store + .put_if( + "sandbox", + "id-1", + "sandbox-2", + b"payload2", + None, + WriteCondition::MustCreate, + ) + .await; + + assert!(matches!( + result, + Err(PersistenceError::UniqueViolation { .. }) + )); +} + +#[tokio::test] +async fn cas_put_if_match_version_succeeds() { + use super::WriteCondition; + + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(); + + // Create initial object + store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"v1", + None, + WriteCondition::MustCreate, + ) + .await + .unwrap(); + + // Update with correct version + let result = store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"v2", + None, + WriteCondition::MatchResourceVersion(1), + ) + .await + .unwrap(); + + assert_eq!(result.resource_version, 2); + + let record = store.get("sandbox", "id-1").await.unwrap().unwrap(); + assert_eq!(record.resource_version, 2); + assert_eq!(record.payload, b"v2"); +} + +#[tokio::test] +async fn cas_put_if_match_version_fails_on_mismatch() { + use super::{PersistenceError, WriteCondition}; + + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(); + + // Create initial object + store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"v1", + None, + WriteCondition::MustCreate, + ) + .await + .unwrap(); + + // Update with wrong version + let result = store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"v2", + None, + WriteCondition::MatchResourceVersion(99), + ) + .await; + + assert!(matches!( + result, + Err(PersistenceError::Conflict { + current_resource_version: Some(1) + }) + )); + + // Original payload unchanged + let record = store.get("sandbox", "id-1").await.unwrap().unwrap(); + assert_eq!(record.resource_version, 1); + assert_eq!(record.payload, b"v1"); +} + +#[tokio::test] +async fn cas_delete_if_succeeds_with_correct_version() { + use super::WriteCondition; + + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(); + + store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"payload", + None, + WriteCondition::MustCreate, + ) + .await + .unwrap(); + + let deleted = store.delete_if("sandbox", "id-1", 1).await.unwrap(); + assert!(deleted); + + let record = store.get("sandbox", "id-1").await.unwrap(); + assert!(record.is_none()); +} + +#[tokio::test] +async fn cas_delete_if_fails_with_wrong_version() { + use super::{PersistenceError, WriteCondition}; + + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(); + + store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"payload", + None, + WriteCondition::MustCreate, + ) + .await + .unwrap(); + + let result = store.delete_if("sandbox", "id-1", 99).await; + assert!(matches!( + result, + Err(PersistenceError::Conflict { + current_resource_version: Some(1) + }) + )); + + // Object still exists + let record = store.get("sandbox", "id-1").await.unwrap().unwrap(); + assert_eq!(record.resource_version, 1); +} + +#[tokio::test] +async fn cas_resource_version_increments() { + use super::WriteCondition; + + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(); + + // Create + let r1 = store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"v1", + None, + WriteCondition::MustCreate, + ) + .await + .unwrap(); + assert_eq!(r1.resource_version, 1); + + // Update 1 + let r2 = store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"v2", + None, + WriteCondition::MatchResourceVersion(1), + ) + .await + .unwrap(); + assert_eq!(r2.resource_version, 2); + + // Update 2 + let r3 = store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"v3", + None, + WriteCondition::MatchResourceVersion(2), + ) + .await + .unwrap(); + assert_eq!(r3.resource_version, 3); + + let record = store.get("sandbox", "id-1").await.unwrap().unwrap(); + assert_eq!(record.resource_version, 3); +} + +#[tokio::test] +async fn cas_concurrent_updates_one_succeeds() { + use super::WriteCondition; + use std::sync::Arc; + + let store = Arc::new( + Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(), + ); + + // Create initial object + store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + b"initial", + None, + WriteCondition::MustCreate, + ) + .await + .unwrap(); + + // Spawn 10 concurrent updates trying to update from version 1 + let mut handles = vec![]; + for i in 0..10 { + let store = Arc::clone(&store); + let handle = tokio::spawn(async move { + store + .put_if( + "sandbox", + "id-1", + "sandbox-1", + format!("update-{i}").as_bytes(), + None, + WriteCondition::MatchResourceVersion(1), + ) + .await + }); + handles.push(handle); + } + + let results: Vec<_> = futures::future::join_all(handles) + .await + .into_iter() + .map(|r| r.unwrap()) + .collect(); + + // Exactly one should succeed, rest should conflict + let successes = results.iter().filter(|r| r.is_ok()).count(); + let conflicts = results.iter().filter(|r| r.is_err()).count(); + + assert_eq!(successes, 1); + assert_eq!(conflicts, 9); + + // Final version should be 2 + let record = store.get("sandbox", "id-1").await.unwrap().unwrap(); + assert_eq!(record.resource_version, 2); +} + +#[tokio::test] +async fn cas_update_message_cas_succeeds() { + use openshell_core::proto::Sandbox; + + let store = Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(); + + // Create a sandbox + let sandbox = Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "test-id".to_string(), + name: "test-sandbox".to_string(), + created_at_ms: 1000, + labels: std::collections::HashMap::new(), + resource_version: 0, + }), + spec: None, + status: None, + phase: 0, + current_policy_version: 0, + }; + + store.put_message(&sandbox).await.unwrap(); + + // Update using CAS + let updated = store + .update_message_cas::("test-id", 5, |s| { + s.phase = 2; // Set to Ready + s.current_policy_version = 42; + }) + .await + .unwrap(); + + assert_eq!(updated.phase, 2); + assert_eq!(updated.current_policy_version, 42); + assert_eq!( + updated.metadata.as_ref().map_or(0, |m| m.resource_version), + 2 + ); +} + +#[tokio::test] +async fn cas_update_message_cas_retries_on_conflict() { + use openshell_core::proto::Sandbox; + use std::sync::Arc; + use std::sync::atomic::{AtomicU32, Ordering}; + + let store = Arc::new( + Store::connect("sqlite::memory:?cache=shared") + .await + .unwrap(), + ); + + // Create a sandbox + let sandbox = Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "test-id".to_string(), + name: "test-sandbox".to_string(), + created_at_ms: 1000, + labels: std::collections::HashMap::new(), + resource_version: 0, + }), + spec: None, + status: None, + phase: 0, + current_policy_version: 0, + }; + + store.put_message(&sandbox).await.unwrap(); + + // Track how many updates succeed + let success_count = Arc::new(AtomicU32::new(0)); + + // Spawn 5 concurrent CAS updates + let mut handles = vec![]; + for i in 0..5 { + let store = Arc::clone(&store); + let success_count = Arc::clone(&success_count); + let handle = tokio::spawn(async move { + let result = store + .update_message_cas::("test-id", 5, |s| { + s.current_policy_version = i; + }) + .await; + if result.is_ok() { + success_count.fetch_add(1, Ordering::SeqCst); + } + result + }); + handles.push(handle); + } + + let results: Vec<_> = futures::future::join_all(handles) + .await + .into_iter() + .map(|r| r.unwrap()) + .collect(); + + // All should succeed due to retry + let successes = results.iter().filter(|r| r.is_ok()).count(); + assert_eq!(successes, 5); + assert_eq!(success_count.load(Ordering::SeqCst), 5); + + // Final version should be 6 (initial 1 + 5 updates) + let final_sandbox = store + .get_message::("test-id") + .await + .unwrap() + .unwrap(); + assert_eq!( + final_sandbox + .metadata + .as_ref() + .map_or(0, |m| m.resource_version), + 6 + ); +} diff --git a/crates/openshell-server/src/ssh_tunnel.rs b/crates/openshell-server/src/ssh_tunnel.rs index bd317d53f..dd5e0f820 100644 --- a/crates/openshell-server/src/ssh_tunnel.rs +++ b/crates/openshell-server/src/ssh_tunnel.rs @@ -330,6 +330,7 @@ mod tests { name: format!("session-{id}"), created_at_ms: 1000, labels: HashMap::new(), + resource_version: 0, }), sandbox_id: sandbox_id.to_string(), token: id.to_string(), diff --git a/crates/openshell-server/src/supervisor_session.rs b/crates/openshell-server/src/supervisor_session.rs index 94c352ba5..5a3b7fe54 100644 --- a/crates/openshell-server/src/supervisor_session.rs +++ b/crates/openshell-server/src/supervisor_session.rs @@ -740,6 +740,7 @@ mod tests { name: name.to_string(), created_at_ms: 1_000_000, labels: HashMap::new(), + resource_version: 0, }), ..Default::default() } diff --git a/crates/openshell-tui/src/lib.rs b/crates/openshell-tui/src/lib.rs index 8571ebbe1..69ee206da 100644 --- a/crates/openshell-tui/src/lib.rs +++ b/crates/openshell-tui/src/lib.rs @@ -1565,6 +1565,7 @@ fn spawn_create_provider(app: &App, tx: mpsc::UnboundedSender) { name: provider_name.clone(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: ptype.clone(), credentials: credentials.clone(), @@ -1655,6 +1656,7 @@ fn spawn_update_provider(app: &App, tx: mpsc::UnboundedSender) { name: name.clone(), created_at_ms: 0, labels: HashMap::new(), + resource_version: 0, }), r#type: ptype, credentials, diff --git a/proto/datamodel.proto b/proto/datamodel.proto index 534b043ae..462088124 100644 --- a/proto/datamodel.proto +++ b/proto/datamodel.proto @@ -8,7 +8,7 @@ package openshell.datamodel.v1; // Kubernetes-style metadata shared by all top-level OpenShell domain objects. // // This structure provides consistent metadata (identity, labels, timestamps, -// versioning) across Sandbox, Provider, SshSession, and other resources. +// resource versioning) across Sandbox, Provider, SshSession, and other resources. message ObjectMeta { // Stable object ID generated by the gateway. string id = 1; @@ -22,6 +22,10 @@ message ObjectMeta { // Key-value labels for filtering and organization. // Labels must follow Kubernetes conventions: alphanumeric + `-._/`, max 63 chars per segment. map labels = 4; + + // Optimistic concurrency control version. + // Incremented by the gateway on each update. Clients can use this for compare-and-swap operations. + uint64 resource_version = 5; } // Provider model stored by OpenShell.