From 71bd5162905d24dd2242ce8cee3f94f4f4f3c98b Mon Sep 17 00:00:00 2001 From: Aaron Erickson Date: Fri, 8 May 2026 19:42:02 -0700 Subject: [PATCH 01/17] fix(sandbox): rewrite credential placeholders in websocket text frames Signed-off-by: Aaron Erickson --- architecture/security-policy.md | 10 +- crates/openshell-policy/src/lib.rs | 53 ++ crates/openshell-policy/src/merge.rs | 2 + crates/openshell-providers/src/profiles.rs | 4 + crates/openshell-sandbox/src/l7/mod.rs | 61 ++ crates/openshell-sandbox/src/l7/relay.rs | 113 ++- crates/openshell-sandbox/src/l7/rest.rs | 136 +++- crates/openshell-sandbox/src/l7/websocket.rs | 784 +++++++++++++++++++ crates/openshell-sandbox/src/opa.rs | 60 ++ crates/openshell-sandbox/src/policy_local.rs | 1 + crates/openshell-sandbox/src/proxy.rs | 40 +- crates/openshell-sandbox/src/secrets.rs | 115 +++ docs/reference/policy-schema.mdx | 1 + docs/sandboxes/policies.mdx | 6 +- docs/security/best-practices.mdx | 2 +- proto/sandbox.proto | 4 + 16 files changed, 1361 insertions(+), 31 deletions(-) create mode 100644 crates/openshell-sandbox/src/l7/websocket.rs diff --git a/architecture/security-policy.md b/architecture/security-policy.md index e5f179dc1..5c04bebf5 100644 --- a/architecture/security-policy.md +++ b/architecture/security-policy.md @@ -43,9 +43,13 @@ with the sandbox's ephemeral CA and inspect method/path or protocol-specific metadata before forwarding. The proxy also supports credential injection on terminated HTTP streams when policy allows the endpoint. -Raw streams, HTTP upgrades, and long-lived response bodies are connection -scoped. Policy reloads affect the next connection or the next parsed HTTP -request; they do not rewrite bytes already being relayed. +Raw streams and long-lived response bodies are connection scoped. Policy +reloads affect the next connection or the next parsed HTTP request; they do not +rewrite bytes already being relayed. HTTP upgrades switch to raw relay by +default. 
A `protocol: rest` endpoint can opt in to +`websocket_credential_rewrite` for client-to-server WebSocket text messages +after an allowed `101` upgrade; server-to-client traffic and all other upgraded +protocols remain raw passthrough. ## Live Updates diff --git a/crates/openshell-policy/src/lib.rs b/crates/openshell-policy/src/lib.rs index 61df0aadb..0eb42b647 100644 --- a/crates/openshell-policy/src/lib.rs +++ b/crates/openshell-policy/src/lib.rs @@ -120,6 +120,11 @@ struct NetworkEndpointDef { /// Defaults to false (strict). #[serde(default, skip_serializing_if = "std::ops::Not::not")] allow_encoded_slash: bool, + /// When true, client-to-server WebSocket text messages on this REST + /// endpoint rewrite credential placeholders after an allowed 101 upgrade. + /// Defaults to false. + #[serde(default, skip_serializing_if = "std::ops::Not::not")] + websocket_credential_rewrite: bool, #[serde(default, skip_serializing_if = "String::is_empty")] persisted_queries: String, #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] @@ -317,6 +322,7 @@ fn to_proto(raw: PolicyFile) -> SandboxPolicy { }) .collect(), allow_encoded_slash: e.allow_encoded_slash, + websocket_credential_rewrite: e.websocket_credential_rewrite, persisted_queries: e.persisted_queries, graphql_persisted_queries: e .graphql_persisted_queries @@ -480,6 +486,7 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile { }) .collect(), allow_encoded_slash: e.allow_encoded_slash, + websocket_credential_rewrite: e.websocket_credential_rewrite, persisted_queries: e.persisted_queries.clone(), graphql_persisted_queries: e .graphql_persisted_queries @@ -1656,6 +1663,52 @@ network_policies: assert_eq!(ep.deny_rules[0].fields, vec!["deleteRepository"]); } + #[test] + fn round_trip_preserves_websocket_credential_rewrite() { + let yaml = r" +version: 1 +network_policies: + discord_gateway: + name: discord_gateway + endpoints: + - host: gateway.example.com + port: 443 + protocol: rest + enforcement: enforce + 
access: full + websocket_credential_rewrite: true + binaries: + - path: /usr/bin/node +"; + let proto1 = parse_sandbox_policy(yaml).expect("parse failed"); + let yaml_out = serialize_sandbox_policy(&proto1).expect("serialize failed"); + let proto2 = parse_sandbox_policy(&yaml_out).expect("re-parse failed"); + + let ep = &proto2.network_policies["discord_gateway"].endpoints[0]; + assert_eq!(ep.protocol, "rest"); + assert!(ep.websocket_credential_rewrite); + assert!(yaml_out.contains("websocket_credential_rewrite: true")); + } + + #[test] + fn websocket_credential_rewrite_defaults_false() { + let yaml = r" +version: 1 +network_policies: + gateway: + endpoints: + - host: gateway.example.com + port: 443 + protocol: rest + access: full + binaries: + - path: /usr/bin/node +"; + let proto = parse_sandbox_policy(yaml).expect("parse failed"); + let ep = &proto.network_policies["gateway"].endpoints[0]; + assert!(!ep.websocket_credential_rewrite); + } + #[test] fn parse_rejects_unknown_fields_in_deny_rule() { let yaml = r" diff --git a/crates/openshell-policy/src/merge.rs b/crates/openshell-policy/src/merge.rs index 7a5dec916..ca4748b5b 100644 --- a/crates/openshell-policy/src/merge.rs +++ b/crates/openshell-policy/src/merge.rs @@ -462,6 +462,8 @@ fn merge_endpoint( append_unique_deny_rules(&mut existing.deny_rules, &incoming.deny_rules); append_unique_strings(&mut existing.allowed_ips, &incoming.allowed_ips); + existing.allow_encoded_slash |= incoming.allow_encoded_slash; + existing.websocket_credential_rewrite |= incoming.websocket_credential_rewrite; normalize_endpoint(existing); Ok(()) } diff --git a/crates/openshell-providers/src/profiles.rs b/crates/openshell-providers/src/profiles.rs index 8c3f247cf..c15fd0dac 100644 --- a/crates/openshell-providers/src/profiles.rs +++ b/crates/openshell-providers/src/profiles.rs @@ -114,6 +114,8 @@ pub struct EndpointProfile { pub deny_rules: Vec, #[serde(default, skip_serializing_if = "is_false")] pub allow_encoded_slash: bool, + 
#[serde(default, skip_serializing_if = "is_false")] + pub websocket_credential_rewrite: bool, #[serde(default, skip_serializing_if = "String::is_empty")] pub persisted_queries: String, #[serde(default, skip_serializing_if = "HashMap::is_empty")] @@ -414,6 +416,7 @@ fn endpoint_to_proto(endpoint: &EndpointProfile) -> NetworkEndpoint { ports: endpoint.ports.clone(), deny_rules: endpoint.deny_rules.iter().map(deny_rule_to_proto).collect(), allow_encoded_slash: endpoint.allow_encoded_slash, + websocket_credential_rewrite: endpoint.websocket_credential_rewrite, persisted_queries: endpoint.persisted_queries.clone(), graphql_persisted_queries: endpoint .graphql_persisted_queries @@ -442,6 +445,7 @@ fn endpoint_from_proto(endpoint: &NetworkEndpoint) -> EndpointProfile { .map(deny_rule_from_proto) .collect(), allow_encoded_slash: endpoint.allow_encoded_slash, + websocket_credential_rewrite: endpoint.websocket_credential_rewrite, persisted_queries: endpoint.persisted_queries.clone(), graphql_persisted_queries: endpoint .graphql_persisted_queries diff --git a/crates/openshell-sandbox/src/l7/mod.rs b/crates/openshell-sandbox/src/l7/mod.rs index 5301ac4d5..153611726 100644 --- a/crates/openshell-sandbox/src/l7/mod.rs +++ b/crates/openshell-sandbox/src/l7/mod.rs @@ -15,6 +15,7 @@ pub mod provider; pub mod relay; pub mod rest; pub mod tls; +pub(crate) mod websocket; /// Application-layer protocol for L7 inspection. #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -72,6 +73,9 @@ pub struct L7EndpointConfig { /// rather than rejected at the parser. Needed by upstreams like GitLab /// that embed `%2F` in namespaced project paths. Defaults to false. pub allow_encoded_slash: bool, + /// Opt-in rewrite of credential placeholders in client-to-server + /// WebSocket text messages after an allowed HTTP 101 upgrade. + pub websocket_credential_rewrite: bool, } /// Result of an L7 policy decision for a single request. 
@@ -138,6 +142,8 @@ pub fn parse_l7_config(val: ®orus::Value) -> Option { }; let allow_encoded_slash = get_object_bool(val, "allow_encoded_slash").unwrap_or(false); + let websocket_credential_rewrite = + get_object_bool(val, "websocket_credential_rewrite").unwrap_or(false); let graphql_max_body_bytes = get_object_u64(val, "graphql_max_body_bytes") .and_then(|v| usize::try_from(v).ok()) .filter(|v| *v > 0) @@ -150,6 +156,7 @@ pub fn parse_l7_config(val: ®orus::Value) -> Option { enforcement, graphql_max_body_bytes, allow_encoded_slash, + websocket_credential_rewrite, }) } @@ -498,6 +505,17 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< )); } + if ep + .get("websocket_credential_rewrite") + .and_then(serde_json::Value::as_bool) + .unwrap_or(false) + && protocol != "rest" + { + warnings.push(format!( + "{loc}: websocket_credential_rewrite is ignored unless protocol is rest" + )); + } + if let Some(registry_value) = ep.get("graphql_persisted_queries") { let Some(registry) = registry_value.as_object() else { errors.push(format!( @@ -1031,6 +1049,49 @@ mod tests { assert!(config.allow_encoded_slash); } + #[test] + fn parse_l7_config_websocket_credential_rewrite_defaults_false() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "rest", "host": "gateway.example.com", "port": 443}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(!config.websocket_credential_rewrite); + } + + #[test] + fn parse_l7_config_websocket_credential_rewrite_opt_in() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "rest", "host": "gateway.example.com", "port": 443, "websocket_credential_rewrite": true}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(config.websocket_credential_rewrite); + } + + #[test] + fn validate_websocket_credential_rewrite_warns_unless_rest() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + 
"port": 443, + "websocket_credential_rewrite": true + }], + "binaries": [] + } + } + }); + let (_errors, warnings) = validate_l7_policies(&data); + assert!( + warnings + .iter() + .any(|w| w.contains("websocket_credential_rewrite is ignored")), + "expected websocket_credential_rewrite warning: {warnings:?}" + ); + } + #[test] fn validate_rules_and_access_mutual_exclusion() { let data = serde_json::json!({ diff --git a/crates/openshell-sandbox/src/l7/relay.rs b/crates/openshell-sandbox/src/l7/relay.rs index f099c3558..244f3cf99 100644 --- a/crates/openshell-sandbox/src/l7/relay.rs +++ b/crates/openshell-sandbox/src/l7/relay.rs @@ -38,6 +38,14 @@ pub struct L7EvalContext { pub(crate) secret_resolver: Option>, } +#[derive(Clone, Default)] +pub(crate) struct UpgradeRelayOptions { + pub(crate) websocket_request: bool, + pub(crate) websocket_credential_rewrite: bool, + pub(crate) secret_resolver: Option>, + pub(crate) policy_name: String, +} + #[derive(Debug, Clone, Copy)] enum ParseRejectionMode { L7Endpoint, @@ -282,19 +290,32 @@ where let _ = &eval_target; if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { - let outcome = crate::l7::rest::relay_http_request_with_resolver_guarded( + let websocket_request = crate::l7::rest::request_is_websocket_upgrade(&req.raw_header); + let outcome = crate::l7::rest::relay_http_request_with_options_guarded( &req, client, upstream, - ctx.secret_resolver.as_deref(), - Some(engine.generation_guard()), + crate::l7::rest::RelayRequestOptions { + resolver: ctx.secret_resolver.as_deref(), + generation_guard: Some(engine.generation_guard()), + strip_websocket_extensions: config.protocol == L7Protocol::Rest + && config.websocket_credential_rewrite, + }, ) .await?; match outcome { RelayOutcome::Reusable => {} RelayOutcome::Consumed => return Ok(()), RelayOutcome::Upgraded { overflow } => { - return handle_upgrade(client, upstream, overflow, &ctx.host, ctx.port).await; + return handle_upgrade( + client, + upstream, + 
overflow, + &ctx.host, + ctx.port, + upgrade_options(config, ctx, websocket_request), + ) + .await; } } } else { @@ -383,11 +404,20 @@ pub(crate) async fn handle_upgrade( overflow: Vec, host: &str, port: u16, + options: UpgradeRelayOptions, ) -> Result<()> where C: AsyncRead + AsyncWrite + Unpin + Send, U: AsyncRead + AsyncWrite + Unpin + Send, { + let use_websocket_rewrite = options.websocket_request + && options.websocket_credential_rewrite + && options.secret_resolver.is_some(); + let relay_mode = if use_websocket_rewrite { + "websocket credential rewrite relay" + } else { + "raw bidirectional relay (L7 enforcement no longer active)" + }; ocsf_emit!( NetworkActivityBuilder::new(crate::ocsf_ctx()) .activity(ActivityId::Other) @@ -395,12 +425,23 @@ where .severity(SeverityId::Informational) .dst_endpoint(Endpoint::from_domain(host, port)) .message(format!( - "101 Switching Protocols โ€” raw bidirectional relay (L7 enforcement no longer active) \ - [host:{host} port:{port} overflow_bytes:{}]", + "101 Switching Protocols โ€” {relay_mode} [host:{host} port:{port} overflow_bytes:{}]", overflow.len() )) .build() ); + if use_websocket_rewrite && let Some(resolver) = options.secret_resolver.as_deref() { + return crate::l7::websocket::relay_with_credential_rewrite( + client, + upstream, + overflow, + host, + port, + &options.policy_name, + resolver, + ) + .await; + } if !overflow.is_empty() { client.write_all(&overflow).await.into_diagnostic()?; client.flush().await.into_diagnostic()?; @@ -411,6 +452,25 @@ where Ok(()) } +fn upgrade_options( + config: &L7EndpointConfig, + ctx: &L7EvalContext, + websocket_request: bool, +) -> UpgradeRelayOptions { + let websocket_credential_rewrite = + config.protocol == L7Protocol::Rest && config.websocket_credential_rewrite; + UpgradeRelayOptions { + websocket_request, + websocket_credential_rewrite, + secret_resolver: if websocket_credential_rewrite { + ctx.secret_resolver.clone() + } else { + None + }, + policy_name: 
ctx.policy_name.clone(), + } +} + /// REST relay loop: parse request -> evaluate -> allow/deny -> relay response -> repeat. async fn relay_rest( config: &L7EndpointConfig, @@ -558,12 +618,17 @@ where if allowed || config.enforcement == EnforcementMode::Audit { // Forward request to upstream and relay response - let outcome = crate::l7::rest::relay_http_request_with_resolver_guarded( + let websocket_request = crate::l7::rest::request_is_websocket_upgrade(&req.raw_header); + let outcome = crate::l7::rest::relay_http_request_with_options_guarded( &req, client, upstream, - ctx.secret_resolver.as_deref(), - Some(engine.generation_guard()), + crate::l7::rest::RelayRequestOptions { + resolver: ctx.secret_resolver.as_deref(), + generation_guard: Some(engine.generation_guard()), + strip_websocket_extensions: config.protocol == L7Protocol::Rest + && config.websocket_credential_rewrite, + }, ) .await?; match outcome { @@ -577,7 +642,15 @@ where return Ok(()); } RelayOutcome::Upgraded { overflow } => { - return handle_upgrade(client, upstream, overflow, &ctx.host, ctx.port).await; + return handle_upgrade( + client, + upstream, + overflow, + &ctx.host, + ctx.port, + upgrade_options(config, ctx, websocket_request), + ) + .await; } } } else { @@ -788,7 +861,15 @@ where return Ok(()); } RelayOutcome::Upgraded { overflow } => { - return handle_upgrade(client, upstream, overflow, &ctx.host, ctx.port).await; + return handle_upgrade( + client, + upstream, + overflow, + &ctx.host, + ctx.port, + UpgradeRelayOptions::default(), + ) + .await; } } } else { @@ -1029,7 +1110,15 @@ where RelayOutcome::Reusable => {} // continue loop RelayOutcome::Consumed => break, RelayOutcome::Upgraded { overflow } => { - return handle_upgrade(client, upstream, overflow, &ctx.host, ctx.port).await; + return handle_upgrade( + client, + upstream, + overflow, + &ctx.host, + ctx.port, + UpgradeRelayOptions::default(), + ) + .await; } } } diff --git a/crates/openshell-sandbox/src/l7/rest.rs 
b/crates/openshell-sandbox/src/l7/rest.rs index 85ae01290..53fc60597 100644 --- a/crates/openshell-sandbox/src/l7/rest.rs +++ b/crates/openshell-sandbox/src/l7/rest.rs @@ -9,7 +9,7 @@ use crate::l7::provider::{BodyLength, L7Provider, L7Request, RelayOutcome}; use crate::opa::PolicyGenerationGuard; -use crate::secrets::rewrite_http_header_block; +use crate::secrets::{SecretResolver, rewrite_http_header_block}; use miette::{IntoDiagnostic, Result, miette}; use std::collections::HashMap; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; @@ -343,7 +343,7 @@ pub(crate) async fn relay_http_request_with_resolver( req: &L7Request, client: &mut C, upstream: &mut U, - resolver: Option<&crate::secrets::SecretResolver>, + resolver: Option<&SecretResolver>, ) -> Result where C: AsyncRead + AsyncWrite + Unpin, @@ -356,9 +356,39 @@ pub(crate) async fn relay_http_request_with_resolver_guarded( req: &L7Request, client: &mut C, upstream: &mut U, - resolver: Option<&crate::secrets::SecretResolver>, + resolver: Option<&SecretResolver>, generation_guard: Option<&PolicyGenerationGuard>, ) -> Result +where + C: AsyncRead + AsyncWrite + Unpin, + U: AsyncRead + AsyncWrite + Unpin, +{ + relay_http_request_with_options_guarded( + req, + client, + upstream, + RelayRequestOptions { + resolver, + generation_guard, + strip_websocket_extensions: false, + }, + ) + .await +} + +#[derive(Clone, Copy, Default)] +pub(crate) struct RelayRequestOptions<'a> { + pub(crate) resolver: Option<&'a SecretResolver>, + pub(crate) generation_guard: Option<&'a PolicyGenerationGuard>, + pub(crate) strip_websocket_extensions: bool, +} + +pub(crate) async fn relay_http_request_with_options_guarded( + req: &L7Request, + client: &mut C, + upstream: &mut U, + options: RelayRequestOptions<'_>, +) -> Result where C: AsyncRead + AsyncWrite + Unpin, U: AsyncRead + AsyncWrite + Unpin, @@ -369,10 +399,16 @@ where .position(|w| w == b"\r\n\r\n") .map_or(req.raw_header.len(), |p| p + 4); - let rewrite_result 
= rewrite_http_header_block(&req.raw_header[..header_end], resolver) + let header_bytes = if options.strip_websocket_extensions { + strip_websocket_extensions_if_requested(&req.raw_header[..header_end])? + } else { + req.raw_header[..header_end].to_vec() + }; + + let rewrite_result = rewrite_http_header_block(&header_bytes, options.resolver) .map_err(|e| miette!("credential injection failed: {e}"))?; - if let Some(guard) = generation_guard { + if let Some(guard) = options.generation_guard { guard.ensure_current()?; } @@ -383,7 +419,7 @@ where let overflow = &req.raw_header[header_end..]; if !overflow.is_empty() { - if let Some(guard) = generation_guard { + if let Some(guard) = options.generation_guard { guard.ensure_current()?; } upstream.write_all(overflow).await.into_diagnostic()?; @@ -394,7 +430,7 @@ where BodyLength::ContentLength(len) => { let remaining = len.saturating_sub(overflow_len); if remaining > 0 { - relay_fixed(client, upstream, remaining, generation_guard).await?; + relay_fixed(client, upstream, remaining, options.generation_guard).await?; } } BodyLength::Chunked => { @@ -402,7 +438,7 @@ where client, upstream, &req.raw_header[header_end..], - generation_guard, + options.generation_guard, ) .await?; } @@ -452,6 +488,35 @@ where Ok(outcome) } +pub(crate) fn request_is_websocket_upgrade(raw_header: &[u8]) -> bool { + let header_end = raw_header + .windows(4) + .position(|w| w == b"\r\n\r\n") + .map_or(raw_header.len(), |p| p + 4); + std::str::from_utf8(&raw_header[..header_end]).is_ok_and(client_requested_websocket_upgrade) +} + +fn strip_websocket_extensions_if_requested(raw_header: &[u8]) -> Result> { + let header_str = std::str::from_utf8(raw_header) + .map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + if !client_requested_websocket_upgrade(header_str) { + return Ok(raw_header.to_vec()); + } + + let mut out = Vec::with_capacity(raw_header.len()); + for line in header_str.split_inclusive("\r\n") { + let bare = 
line.strip_suffix("\r\n").unwrap_or(line); + if bare + .to_ascii_lowercase() + .starts_with("sec-websocket-extensions:") + { + continue; + } + out.extend_from_slice(line.as_bytes()); + } + Ok(out) +} + /// Send a 403 Forbidden JSON deny response. /// /// When `redacted_target` is provided, it is used instead of `req.target` @@ -965,6 +1030,29 @@ fn client_requested_upgrade(headers: &str) -> bool { has_upgrade_header && connection_contains_upgrade } +fn client_requested_websocket_upgrade(headers: &str) -> bool { + let mut upgrade_is_websocket = false; + let mut connection_contains_upgrade = false; + + for line in headers.lines().skip(1) { + let lower = line.to_ascii_lowercase(); + if lower.starts_with("upgrade:") { + let val = lower.split_once(':').map_or("", |(_, v)| v.trim()); + if val == "websocket" { + upgrade_is_websocket = true; + } + } + if lower.starts_with("connection:") { + let val = lower.split_once(':').map_or("", |(_, v)| v.trim()); + if val.split(',').any(|tok| tok.trim() == "upgrade") { + connection_contains_upgrade = true; + } + } + } + + upgrade_is_websocket && connection_contains_upgrade +} + /// Returns true for responses that MUST NOT contain a message body per RFC 7230 ยง3.3.3: /// HEAD responses, 1xx informational, 204 No Content, 304 Not Modified. 
fn is_bodiless_response(request_method: &str, status_code: u16) -> bool { @@ -2243,6 +2331,38 @@ mod tests { assert!(client_requested_upgrade(headers)); } + #[test] + fn request_is_websocket_upgrade_detects_websocket_upgrade() { + let raw = b"GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: keep-alive, Upgrade\r\n\r\n"; + assert!(request_is_websocket_upgrade(raw)); + } + + #[test] + fn strip_websocket_extensions_removes_extension_negotiation() { + let raw = b"GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\nSec-WebSocket-Version: 13\r\n\r\n"; + + let stripped = strip_websocket_extensions_if_requested(raw).expect("strip should succeed"); + let stripped = String::from_utf8(stripped).unwrap(); + + assert!(stripped.contains("Upgrade: websocket\r\n")); + assert!(stripped.contains("Sec-WebSocket-Version: 13\r\n")); + assert!( + !stripped + .to_ascii_lowercase() + .contains("sec-websocket-extensions") + ); + assert!(stripped.ends_with("\r\n\r\n")); + } + + #[test] + fn strip_websocket_extensions_leaves_non_websocket_request_unchanged() { + let raw = b"GET /api HTTP/1.1\r\nHost: example.com\r\nSec-WebSocket-Extensions: permessage-deflate\r\n\r\n"; + + let stripped = strip_websocket_extensions_if_requested(raw).expect("strip should succeed"); + + assert_eq!(stripped, raw); + } + #[test] fn rewrite_header_block_resolves_placeholder_auth_headers() { let (_, resolver) = SecretResolver::from_provider_env( diff --git a/crates/openshell-sandbox/src/l7/websocket.rs b/crates/openshell-sandbox/src/l7/websocket.rs new file mode 100644 index 000000000..0c04f9ea1 --- /dev/null +++ b/crates/openshell-sandbox/src/l7/websocket.rs @@ -0,0 +1,784 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! 
Minimal WebSocket relay for opt-in credential placeholder rewriting. +//! +//! The relay parses only client-to-server frames. Server-to-client bytes stay +//! raw passthrough so this remains a narrow post-upgrade credential boundary, +//! not a general WebSocket inspection engine. + +use crate::secrets::SecretResolver; +use miette::{IntoDiagnostic, Result, miette}; +use openshell_ocsf::{ + ActionId, ActivityId, DispositionId, Endpoint, NetworkActivityBuilder, SeverityId, StatusId, + ocsf_emit, +}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +const MAX_TEXT_MESSAGE_BYTES: usize = 1024 * 1024; +const COPY_BUF_SIZE: usize = 8192; +const OPCODE_CONTINUATION: u8 = 0x0; +const OPCODE_TEXT: u8 = 0x1; +const OPCODE_BINARY: u8 = 0x2; +const OPCODE_CLOSE: u8 = 0x8; +const OPCODE_PING: u8 = 0x9; +const OPCODE_PONG: u8 = 0xA; + +#[derive(Debug)] +struct FrameHeader { + fin: bool, + rsv: u8, + opcode: u8, + masked: bool, + payload_len: u64, + mask_key: Option<[u8; 4]>, + raw_header: Vec, +} + +#[derive(Debug)] +enum FragmentState { + None, + Text { payload: Vec }, + Binary, +} + +/// Relay an upgraded WebSocket connection, rewriting credential placeholders +/// in client-to-server UTF-8 text messages. 
+pub(super) async fn relay_with_credential_rewrite( + client: &mut C, + upstream: &mut U, + overflow: Vec, + host: &str, + port: u16, + policy_name: &str, + resolver: &SecretResolver, +) -> Result<()> +where + C: AsyncRead + AsyncWrite + Unpin + Send, + U: AsyncRead + AsyncWrite + Unpin + Send, +{ + let (mut client_read, mut client_write) = tokio::io::split(client); + let (mut upstream_read, mut upstream_write) = tokio::io::split(upstream); + + if !overflow.is_empty() { + client_write.write_all(&overflow).await.into_diagnostic()?; + client_write.flush().await.into_diagnostic()?; + } + + let client_to_server = relay_client_to_server( + &mut client_read, + &mut upstream_write, + host, + port, + policy_name, + resolver, + ); + let server_to_client = async { + tokio::io::copy(&mut upstream_read, &mut client_write) + .await + .into_diagnostic()?; + client_write.flush().await.into_diagnostic()?; + Ok::<(), miette::Report>(()) + }; + + tokio::select! { + result = client_to_server => result, + result = server_to_client => result, + } +} + +async fn relay_client_to_server( + reader: &mut R, + writer: &mut W, + host: &str, + port: u16, + policy_name: &str, + resolver: &SecretResolver, +) -> Result<()> +where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, +{ + let mut fragments = FragmentState::None; + + loop { + let Some(frame) = read_frame_header(reader).await.inspect_err(|e| { + emit_protocol_failure(host, port, policy_name, &e.to_string()); + })? 
+ else { + writer.shutdown().await.into_diagnostic()?; + return Ok(()); + }; + + if let Err(e) = validate_frame_header(&frame, &fragments) { + emit_protocol_failure(host, port, policy_name, &e.to_string()); + return Err(e); + } + + match frame.opcode { + OPCODE_TEXT => { + let payload = read_masked_payload(reader, &frame).await.inspect_err(|e| { + emit_protocol_failure(host, port, policy_name, &e.to_string()); + })?; + if frame.fin { + relay_text_payload( + writer, + &frame, + payload, + false, + host, + port, + policy_name, + resolver, + ) + .await + .inspect_err(|e| { + emit_protocol_failure(host, port, policy_name, &e.to_string()); + })?; + } else { + fragments = FragmentState::Text { payload }; + } + } + OPCODE_CONTINUATION => match &mut fragments { + FragmentState::Text { payload } => { + let next = read_masked_payload(reader, &frame).await.inspect_err(|e| { + emit_protocol_failure(host, port, policy_name, &e.to_string()); + })?; + if let Err(e) = append_text_fragment(payload, next) { + emit_protocol_failure(host, port, policy_name, &e.to_string()); + return Err(e); + } + if frame.fin { + let complete = std::mem::take(payload); + fragments = FragmentState::None; + relay_text_payload( + writer, + &frame, + complete, + true, + host, + port, + policy_name, + resolver, + ) + .await + .inspect_err(|e| { + emit_protocol_failure(host, port, policy_name, &e.to_string()); + })?; + } + } + FragmentState::Binary => { + copy_raw_frame_payload(reader, writer, &frame).await?; + if frame.fin { + fragments = FragmentState::None; + } + } + FragmentState::None => { + let e = + miette!("websocket continuation frame without active fragmented message"); + emit_protocol_failure(host, port, policy_name, &e.to_string()); + return Err(e); + } + }, + OPCODE_BINARY => { + if !frame.fin { + fragments = FragmentState::Binary; + } + copy_raw_frame_payload(reader, writer, &frame).await?; + } + OPCODE_CLOSE | OPCODE_PING | OPCODE_PONG => { + copy_raw_frame_payload(reader, writer, 
&frame).await?; + } + _ => unreachable!("validated opcode"), + } + } +} + +async fn read_frame_header(reader: &mut R) -> Result> { + let first = match reader.read_u8().await { + Ok(byte) => byte, + Err(e) + if matches!( + e.kind(), + std::io::ErrorKind::UnexpectedEof + | std::io::ErrorKind::ConnectionReset + | std::io::ErrorKind::BrokenPipe + ) => + { + return Ok(None); + } + Err(e) => return Err(miette!("{e}")), + }; + let second = reader + .read_u8() + .await + .map_err(|e| miette!("malformed websocket frame header: {e}"))?; + + let mut raw_header = vec![first, second]; + let len_code = second & 0x7F; + let payload_len = match len_code { + 0..=125 => u64::from(len_code), + 126 => { + let mut bytes = [0u8; 2]; + reader + .read_exact(&mut bytes) + .await + .map_err(|e| miette!("malformed websocket extended length: {e}"))?; + raw_header.extend_from_slice(&bytes); + u64::from(u16::from_be_bytes(bytes)) + } + 127 => { + let mut bytes = [0u8; 8]; + reader + .read_exact(&mut bytes) + .await + .map_err(|e| miette!("malformed websocket extended length: {e}"))?; + if bytes[0] & 0x80 != 0 { + return Err(miette!("websocket frame uses non-canonical 64-bit length")); + } + raw_header.extend_from_slice(&bytes); + u64::from_be_bytes(bytes) + } + _ => unreachable!("7-bit length code"), + }; + + let masked = second & 0x80 != 0; + let mask_key = if masked { + let mut key = [0u8; 4]; + reader + .read_exact(&mut key) + .await + .map_err(|e| miette!("malformed websocket mask key: {e}"))?; + raw_header.extend_from_slice(&key); + Some(key) + } else { + None + }; + + Ok(Some(FrameHeader { + fin: first & 0x80 != 0, + rsv: first & 0x70, + opcode: first & 0x0F, + masked, + payload_len, + mask_key, + raw_header, + })) +} + +fn validate_frame_header(frame: &FrameHeader, fragments: &FragmentState) -> Result<()> { + if frame.rsv != 0 { + return Err(miette!( + "websocket frame has RSV bits set; compression/extensions are not supported" + )); + } + if !frame.masked { + return 
Err(miette!("websocket client frame is not masked")); + } + if !matches!( + frame.opcode, + OPCODE_CONTINUATION + | OPCODE_TEXT + | OPCODE_BINARY + | OPCODE_CLOSE + | OPCODE_PING + | OPCODE_PONG + ) { + return Err(miette!("websocket frame uses reserved opcode")); + } + if matches!(frame.opcode, OPCODE_CLOSE | OPCODE_PING | OPCODE_PONG) { + if !frame.fin { + return Err(miette!("websocket control frame is fragmented")); + } + if frame.payload_len > 125 { + return Err(miette!("websocket control frame exceeds 125 bytes")); + } + } + if matches!(frame.opcode, OPCODE_TEXT | OPCODE_BINARY) + && !matches!(fragments, FragmentState::None) + { + return Err(miette!( + "websocket data frame started before previous fragmented message completed" + )); + } + if matches!(frame.opcode, OPCODE_CONTINUATION) && matches!(fragments, FragmentState::None) { + return Err(miette!( + "websocket continuation frame without active fragmented message" + )); + } + Ok(()) +} + +async fn read_masked_payload( + reader: &mut R, + frame: &FrameHeader, +) -> Result> { + let payload_len = usize::try_from(frame.payload_len) + .map_err(|_| miette!("websocket text frame is too large to buffer"))?; + if payload_len > MAX_TEXT_MESSAGE_BYTES { + return Err(miette!( + "websocket text message exceeds {MAX_TEXT_MESSAGE_BYTES} byte limit" + )); + } + let mut payload = vec![0u8; payload_len]; + reader + .read_exact(&mut payload) + .await + .map_err(|e| miette!("malformed websocket payload: {e}"))?; + let mask_key = frame + .mask_key + .ok_or_else(|| miette!("websocket client frame is not masked"))?; + apply_mask(&mut payload, mask_key); + Ok(payload) +} + +fn append_text_fragment(buffer: &mut Vec, next: Vec) -> Result<()> { + let new_len = buffer + .len() + .checked_add(next.len()) + .ok_or_else(|| miette!("websocket text message length overflow"))?; + if new_len > MAX_TEXT_MESSAGE_BYTES { + return Err(miette!( + "websocket text message exceeds {MAX_TEXT_MESSAGE_BYTES} byte limit" + )); + } + 
buffer.extend_from_slice(&next); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +async fn relay_text_payload( + writer: &mut W, + frame: &FrameHeader, + payload: Vec, + force_reframe: bool, + host: &str, + port: u16, + policy_name: &str, + resolver: &SecretResolver, +) -> Result<()> { + let mut text = String::from_utf8(payload) + .map_err(|_| miette!("websocket text message is not valid UTF-8"))?; + let replacements = resolver + .rewrite_websocket_text_placeholders(&mut text) + .map_err(|e| miette!("{e}"))?; + + if replacements == 0 && !force_reframe { + writer + .write_all(&frame.raw_header) + .await + .into_diagnostic()?; + let mut payload = text.into_bytes(); + let mask_key = frame + .mask_key + .ok_or_else(|| miette!("websocket client frame is not masked"))?; + apply_mask(&mut payload, mask_key); + writer.write_all(&payload).await.into_diagnostic()?; + writer.flush().await.into_diagnostic()?; + return Ok(()); + } + + if replacements > 0 { + emit_rewrite_event(host, port, policy_name, replacements); + } + write_masked_frame(writer, OPCODE_TEXT, text.as_bytes()).await +} + +async fn copy_raw_frame_payload( + reader: &mut R, + writer: &mut W, + frame: &FrameHeader, +) -> Result<()> +where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, +{ + writer + .write_all(&frame.raw_header) + .await + .into_diagnostic()?; + let mut remaining = frame.payload_len; + let mut buf = [0u8; COPY_BUF_SIZE]; + while remaining > 0 { + let to_read = usize::try_from(remaining) + .unwrap_or(buf.len()) + .min(buf.len()); + let n = reader.read(&mut buf[..to_read]).await.into_diagnostic()?; + if n == 0 { + return Err(miette!("websocket payload ended before declared length")); + } + writer.write_all(&buf[..n]).await.into_diagnostic()?; + remaining -= n as u64; + } + writer.flush().await.into_diagnostic()?; + Ok(()) +} + +async fn write_masked_frame( + writer: &mut W, + opcode: u8, + payload: &[u8], +) -> Result<()> { + let mut header = Vec::with_capacity(14); + header.push(0x80 | 
opcode); + match payload.len() { + 0..=125 => header.push(0x80 | u8::try_from(payload.len()).expect("payload <= 125")), + 126..=65_535 => { + header.push(0x80 | 0x7e); + header.extend_from_slice( + &u16::try_from(payload.len()) + .expect("payload <= 65535") + .to_be_bytes(), + ); + } + _ => { + header.push(0x80 | 127); + header.extend_from_slice(&(payload.len() as u64).to_be_bytes()); + } + } + let mask_key = new_mask_key(); + header.extend_from_slice(&mask_key); + + let mut masked = payload.to_vec(); + apply_mask(&mut masked, mask_key); + writer.write_all(&header).await.into_diagnostic()?; + writer.write_all(&masked).await.into_diagnostic()?; + writer.flush().await.into_diagnostic()?; + Ok(()) +} + +fn new_mask_key() -> [u8; 4] { + let bytes = uuid::Uuid::new_v4().into_bytes(); + [bytes[0], bytes[1], bytes[2], bytes[3]] +} + +fn apply_mask(payload: &mut [u8], mask_key: [u8; 4]) { + for (i, byte) in payload.iter_mut().enumerate() { + *byte ^= mask_key[i % 4]; + } +} + +fn emit_rewrite_event(host: &str, port: u16, policy_name: &str, replacements: usize) { + let policy_name = if policy_name.is_empty() { + "-" + } else { + policy_name + }; + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(ActionId::Allowed) + .disposition(DispositionId::Allowed) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .dst_endpoint(Endpoint::from_domain(host, port)) + .firewall_rule(policy_name, "l7-websocket") + .message(format!( + "WEBSOCKET_CREDENTIAL_REWRITE rewrote client text message [host:{host} port:{port} replacements:{replacements}]" + )) + .build(); + ocsf_emit!(event); +} + +fn emit_protocol_failure(host: &str, port: u16, policy_name: &str, detail: &str) { + let policy_name = if policy_name.is_empty() { + "-" + } else { + policy_name + }; + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + 
.severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(host, port)) + .firewall_rule(policy_name, "l7-websocket") + .message(format!( + "WEBSOCKET_CREDENTIAL_REWRITE closed ambiguous client frame [host:{host} port:{port}]" + )) + .status_detail(detail) + .build(); + ocsf_emit!(event); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::secrets::SecretResolver; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + fn resolver() -> (std::collections::HashMap, SecretResolver) { + let (child_env, resolver) = SecretResolver::from_provider_env( + std::iter::once(("DISCORD_BOT_TOKEN".to_string(), "real-token".to_string())).collect(), + ); + (child_env, resolver.expect("resolver")) + } + + fn masked_frame(fin: bool, opcode: u8, payload: &[u8]) -> Vec { + let mask_key = [0x37, 0xfa, 0x21, 0x3d]; + let mut frame = Vec::new(); + frame.push(if fin { 0x80 | opcode } else { opcode }); + match payload.len() { + 0..=125 => frame.push(0x80 | u8::try_from(payload.len()).expect("payload <= 125")), + 126..=65_535 => { + frame.push(0x80 | 0x7e); + frame.extend_from_slice( + &u16::try_from(payload.len()) + .expect("payload <= 65535") + .to_be_bytes(), + ); + } + _ => { + frame.push(0x80 | 127); + frame.extend_from_slice(&(payload.len() as u64).to_be_bytes()); + } + } + frame.extend_from_slice(&mask_key); + for (i, byte) in payload.iter().enumerate() { + frame.push(byte ^ mask_key[i % 4]); + } + frame + } + + fn unmasked_frame(opcode: u8, payload: &[u8]) -> Vec { + let mut frame = Vec::new(); + frame.push(0x80 | opcode); + frame.push(u8::try_from(payload.len()).expect("test payload fits in one byte")); + frame.extend_from_slice(payload); + frame + } + + async fn run_client_to_server(input: Vec) -> Result> { + let (_, resolver) = resolver(); + let (mut client_write, mut relay_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); + let (mut relay_write, mut upstream_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); + + 
client_write.write_all(&input).await.unwrap(); + drop(client_write); + + let result = relay_client_to_server( + &mut relay_read, + &mut relay_write, + "gateway.example.test", + 443, + "test-policy", + &resolver, + ) + .await; + drop(relay_write); + + let mut output = Vec::new(); + upstream_read.read_to_end(&mut output).await.unwrap(); + result.map(|()| output) + } + + fn decode_masked_text_frame(frame: &[u8]) -> String { + assert_eq!(frame[0] & 0x0F, OPCODE_TEXT); + assert_ne!(frame[1] & 0x80, 0); + let len_code = frame[1] & 0x7F; + let (payload_len, mask_offset) = match len_code { + 0..=125 => (usize::from(len_code), 2), + 126 => (usize::from(u16::from_be_bytes([frame[2], frame[3]])), 4), + 127 => { + let len = u64::from_be_bytes(frame[2..10].try_into().unwrap()); + (usize::try_from(len).unwrap(), 10) + } + _ => unreachable!(), + }; + let mask_key: [u8; 4] = frame[mask_offset..mask_offset + 4].try_into().unwrap(); + let mut payload = frame[mask_offset + 4..mask_offset + 4 + payload_len].to_vec(); + apply_mask(&mut payload, mask_key); + String::from_utf8(payload).unwrap() + } + + async fn read_one_frame(reader: &mut R) -> Vec { + let mut header = [0u8; 2]; + reader.read_exact(&mut header).await.unwrap(); + let len_code = header[1] & 0x7F; + let extended_len = match len_code { + 0..=125 => Vec::new(), + 126 => { + let mut bytes = vec![0u8; 2]; + reader.read_exact(&mut bytes).await.unwrap(); + bytes + } + 127 => { + let mut bytes = vec![0u8; 8]; + reader.read_exact(&mut bytes).await.unwrap(); + bytes + } + _ => unreachable!(), + }; + let payload_len = match len_code { + 0..=125 => usize::from(len_code), + 126 => usize::from(u16::from_be_bytes( + extended_len.as_slice().try_into().unwrap(), + )), + 127 => usize::try_from(u64::from_be_bytes( + extended_len.as_slice().try_into().unwrap(), + )) + .unwrap(), + _ => unreachable!(), + }; + let mask_len = if header[1] & 0x80 != 0 { 4 } else { 0 }; + let mut rest = vec![0u8; extended_len.len() + mask_len + payload_len]; + 
rest[..extended_len.len()].copy_from_slice(&extended_len); + reader + .read_exact(&mut rest[extended_len.len()..]) + .await + .unwrap(); + + let mut frame = header.to_vec(); + frame.extend_from_slice(&rest); + frame + } + + #[tokio::test] + async fn rewrites_discord_like_identify_text_payload() { + let (child_env, _) = resolver(); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); + + let output = run_client_to_server(masked_frame(true, OPCODE_TEXT, payload.as_bytes())) + .await + .expect("relay should succeed"); + + assert_eq!( + decode_masked_text_frame(&output), + r#"{"op":2,"d":{"token":"real-token"}}"# + ); + } + + #[tokio::test] + async fn upgraded_relay_rewrites_client_text_before_upstream_receives_it() { + let (child_env, resolver) = resolver(); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); + let client_frame = masked_frame(true, OPCODE_TEXT, payload.as_bytes()); + assert!( + !String::from_utf8_lossy(&client_frame).contains("real-token"), + "client-side fixture must not contain the real token" + ); + + let (mut client_app, mut relay_client) = tokio::io::duplex(4096); + let (mut relay_upstream, mut upstream_app) = tokio::io::duplex(4096); + let relay = tokio::spawn(async move { + relay_with_credential_rewrite( + &mut relay_client, + &mut relay_upstream, + Vec::new(), + "gateway.example.test", + 443, + "test-policy", + &resolver, + ) + .await + }); + + client_app.write_all(&client_frame).await.unwrap(); + client_app.flush().await.unwrap(); + + let upstream_frame = tokio::time::timeout( + std::time::Duration::from_secs(2), + read_one_frame(&mut upstream_app), + ) + .await + .expect("upstream should receive rewritten frame"); + assert_eq!( + decode_masked_text_frame(&upstream_frame), + r#"{"op":2,"d":{"token":"real-token"}}"# + ); + + drop(client_app); + drop(upstream_app); + let _ = 
tokio::time::timeout(std::time::Duration::from_secs(2), relay).await; + } + + #[tokio::test] + async fn text_without_placeholder_passes_semantically_unchanged() { + let frame = masked_frame(true, OPCODE_TEXT, br#"{"op":1,"d":42}"#); + let output = run_client_to_server(frame.clone()) + .await + .expect("relay should succeed"); + + assert_eq!(output, frame); + assert_eq!(decode_masked_text_frame(&output), r#"{"op":1,"d":42}"#); + } + + #[tokio::test] + async fn unknown_placeholder_fails_closed() { + let frame = masked_frame( + true, + OPCODE_TEXT, + br#"{"token":"openshell:resolve:env:UNKNOWN"}"#, + ); + + let err = run_client_to_server(frame) + .await + .expect_err("unknown placeholder should fail"); + + assert!( + err.to_string() + .contains("unresolved credential placeholder") + ); + } + + #[tokio::test] + async fn fragmented_text_rewrites_after_final_continuation() { + let (child_env, _) = resolver(); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let first = format!(r#"{{"token":"{placeholder}"#); + let second = r#""}"#; + let mut input = masked_frame(false, OPCODE_TEXT, first.as_bytes()); + input.extend(masked_frame(true, OPCODE_CONTINUATION, second.as_bytes())); + + let output = run_client_to_server(input) + .await + .expect("relay should succeed"); + + assert_eq!( + decode_masked_text_frame(&output), + r#"{"token":"real-token"}"# + ); + } + + #[tokio::test] + async fn rejects_rsv_bits() { + let mut frame = masked_frame(true, OPCODE_TEXT, b"hello"); + frame[0] |= 0x40; + + let err = run_client_to_server(frame) + .await + .expect_err("RSV frame should fail"); + + assert!(err.to_string().contains("RSV bits")); + } + + #[tokio::test] + async fn rejects_unmasked_client_frame() { + let err = run_client_to_server(unmasked_frame(OPCODE_TEXT, b"hello")) + .await + .expect_err("unmasked frame should fail"); + + assert!(err.to_string().contains("not masked")); + } + + #[tokio::test] + async fn rejects_invalid_utf8_text() { + let err = 
run_client_to_server(masked_frame(true, OPCODE_TEXT, &[0xff])) + .await + .expect_err("invalid UTF-8 should fail"); + + assert!(err.to_string().contains("valid UTF-8")); + } + + #[tokio::test] + async fn rejects_oversize_text_message() { + let payload = vec![b'a'; MAX_TEXT_MESSAGE_BYTES + 1]; + let err = run_client_to_server(masked_frame(true, OPCODE_TEXT, &payload)) + .await + .expect_err("oversize text should fail"); + + assert!(err.to_string().contains("exceeds")); + } +} diff --git a/crates/openshell-sandbox/src/opa.rs b/crates/openshell-sandbox/src/opa.rs index 5897679a0..803dc2ad1 100644 --- a/crates/openshell-sandbox/src/opa.rs +++ b/crates/openshell-sandbox/src/opa.rs @@ -1061,6 +1061,9 @@ fn proto_to_opa_data_json(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> St if e.allow_encoded_slash { ep["allow_encoded_slash"] = true.into(); } + if e.websocket_credential_rewrite { + ep["websocket_credential_rewrite"] = true.into(); + } if !e.persisted_queries.is_empty() { ep["persisted_queries"] = e.persisted_queries.clone().into(); } @@ -2463,6 +2466,63 @@ network_policies: assert!(l7.allow_encoded_slash); } + #[test] + fn l7_endpoint_config_preserves_proto_websocket_credential_rewrite() { + let mut network_policies = std::collections::HashMap::new(); + network_policies.insert( + "gateway".to_string(), + NetworkPolicyRule { + name: "gateway".to_string(), + endpoints: vec![NetworkEndpoint { + host: "gateway.example.com".to_string(), + port: 443, + protocol: "rest".to_string(), + enforcement: "enforce".to_string(), + access: "full".to_string(), + websocket_credential_rewrite: true, + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/node".to_string(), + ..Default::default() + }], + }, + ); + let proto = ProtoSandboxPolicy { + version: 1, + filesystem: Some(ProtoFs { + include_workdir: true, + read_only: vec![], + read_write: vec![], + }), + landlock: Some(openshell_core::proto::LandlockPolicy { + compatibility: 
"best_effort".to_string(), + }), + process: Some(ProtoProc { + run_as_user: "sandbox".to_string(), + run_as_group: "sandbox".to_string(), + }), + network_policies, + }; + + let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); + let input = NetworkInput { + host: "gateway.example.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + + let config = engine + .query_endpoint_config(&input) + .unwrap() + .expect("endpoint config"); + let l7 = crate::l7::parse_l7_config(&config).unwrap(); + assert!(l7.websocket_credential_rewrite); + } + #[test] fn l7_endpoint_config_none_for_l4_only() { let engine = l7_engine(); diff --git a/crates/openshell-sandbox/src/policy_local.rs b/crates/openshell-sandbox/src/policy_local.rs index 21556ec6a..d40bc31c9 100644 --- a/crates/openshell-sandbox/src/policy_local.rs +++ b/crates/openshell-sandbox/src/policy_local.rs @@ -619,6 +619,7 @@ fn network_endpoint_from_json( ports, deny_rules, allow_encoded_slash: endpoint.allow_encoded_slash, + websocket_credential_rewrite: false, // GraphQL persisted-query knobs and path scoping default empty โ€” // agent proposals don't author them today. 
persisted_queries: String::new(), diff --git a/crates/openshell-sandbox/src/proxy.rs b/crates/openshell-sandbox/src/proxy.rs index f20e51655..44b9e8c45 100644 --- a/crates/openshell-sandbox/src/proxy.rs +++ b/crates/openshell-sandbox/src/proxy.rs @@ -2376,6 +2376,7 @@ async fn relay_rewritten_forward_request( client: &mut C, upstream: &mut U, generation_guard: &PolicyGenerationGuard, + strip_websocket_extensions: bool, ) -> Result where C: TokioAsyncRead + TokioAsyncWrite + Unpin, @@ -2396,12 +2397,15 @@ where body_length, }; - crate::l7::rest::relay_http_request_with_resolver_guarded( + crate::l7::rest::relay_http_request_with_options_guarded( &req, client, upstream, - None, - Some(generation_guard), + crate::l7::rest::RelayRequestOptions { + resolver: None, + generation_guard: Some(generation_guard), + strip_websocket_extensions, + }, ) .await } @@ -2623,6 +2627,8 @@ async fn handle_forward_proxy( }; let mut forward_request_bytes = buf[..used].to_vec(); let mut upstream_target = path.clone(); + let mut strip_websocket_extensions = false; + let mut upgrade_options = crate::l7::relay::UpgradeRelayOptions::default(); // 4b. If the endpoint has L7 config, evaluate the request against // L7 policy. 
The forward proxy handles exactly one request per @@ -2760,6 +2766,19 @@ async fn handle_forward_proxy( .await?; return Ok(()); }; + let websocket_request = + crate::l7::rest::request_is_websocket_upgrade(&forward_request_bytes); + if l7_config.config.protocol == crate::l7::L7Protocol::Rest + && l7_config.config.websocket_credential_rewrite + { + strip_websocket_extensions = true; + upgrade_options = crate::l7::relay::UpgradeRelayOptions { + websocket_request, + websocket_credential_rewrite: true, + secret_resolver: secret_resolver.clone(), + policy_name: matched_policy.clone().unwrap_or_default(), + }; + } let graphql = if l7_config.config.protocol == crate::l7::L7Protocol::Graphql { let header_end = forward_request_bytes .windows(4) @@ -3223,10 +3242,19 @@ async fn handle_forward_proxy( client, &mut upstream, &forward_generation_guard, + strip_websocket_extensions, ) .await?; if let crate::l7::provider::RelayOutcome::Upgraded { overflow } = outcome { - crate::l7::relay::handle_upgrade(client, &mut upstream, overflow, &host_lc, port).await?; + crate::l7::relay::handle_upgrade( + client, + &mut upstream, + overflow, + &host_lc, + port, + upgrade_options, + ) + .await?; } Ok(()) @@ -3310,6 +3338,7 @@ mod tests { enforcement: crate::l7::EnforcementMode::Enforce, graphql_max_body_bytes: crate::l7::graphql::DEFAULT_MAX_BODY_BYTES, allow_encoded_slash: false, + websocket_credential_rewrite: false, }, }, L7ConfigSnapshot { @@ -3320,6 +3349,7 @@ mod tests { enforcement: crate::l7::EnforcementMode::Enforce, graphql_max_body_bytes: crate::l7::graphql::DEFAULT_MAX_BODY_BYTES, allow_encoded_slash: false, + websocket_credential_rewrite: false, }, }, ]; @@ -4387,6 +4417,7 @@ mod tests { &mut proxy_to_client, &mut proxy_to_upstream, &guard, + false, ) .await; assert!( @@ -4425,6 +4456,7 @@ mod tests { &mut proxy_to_client, &mut proxy_to_upstream, &guard, + false, ) .await; assert!(result.is_err(), "forward relay must reject CL/TE ambiguity"); diff --git 
a/crates/openshell-sandbox/src/secrets.rs b/crates/openshell-sandbox/src/secrets.rs index d645e1482..bc8ac7c6e 100644 --- a/crates/openshell-sandbox/src/secrets.rs +++ b/crates/openshell-sandbox/src/secrets.rs @@ -155,6 +155,58 @@ impl SecretResolver { Some(format!("{prefix} {secret}")) } + /// Rewrite credential placeholders inside a WebSocket text message. + /// + /// The message is mutated only after all placeholders resolve + /// successfully. The return value is the number of replacements; callers + /// must not log the rewritten text. + pub(crate) fn rewrite_websocket_text_placeholders( + &self, + text: &mut String, + ) -> Result { + if !text.contains(PLACEHOLDER_PREFIX) { + return Ok(0); + } + + let mut rewritten = String::with_capacity(text.len()); + let mut pos = 0; + let mut replacements = 0; + + while pos < text.len() { + let Some(start) = text[pos..].find(PLACEHOLDER_PREFIX) else { + rewritten.push_str(&text[pos..]); + break; + }; + let abs_start = pos + start; + rewritten.push_str(&text[pos..abs_start]); + + let key_start = abs_start + PLACEHOLDER_PREFIX.len(); + let key_end = text[key_start..] + .bytes() + .position(|b| !is_env_key_char(b)) + .map_or(text.len(), |p| key_start + p); + + if key_end == key_start { + return Err(UnresolvedPlaceholderError { + location: "websocket", + }); + } + + let full_placeholder = &text[abs_start..key_end]; + let Some(secret) = self.resolve_placeholder(full_placeholder) else { + return Err(UnresolvedPlaceholderError { + location: "websocket", + }); + }; + rewritten.push_str(secret); + replacements += 1; + pos = key_end; + } + + *text = rewritten; + Ok(replacements) + } + /// Decode a Base64-encoded Basic auth token, resolve any placeholders in /// the decoded `username:password` string, and re-encode. 
/// @@ -1444,6 +1496,69 @@ mod tests { assert_eq!(raw.as_slice(), result.rewritten.as_slice()); } + #[test] + fn rewrite_websocket_text_replaces_placeholders_and_returns_count() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [ + ("DISCORD_BOT_TOKEN".to_string(), "real-token".to_string()), + ("APP_ID".to_string(), "app-123".to_string()), + ] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let token = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let app_id = child_env.get("APP_ID").unwrap(); + let mut payload = + format!(r#"{{"op":2,"d":{{"token":"{token}","properties":{{"app":"{app_id}"}}}}}}"#); + + let count = resolver + .rewrite_websocket_text_placeholders(&mut payload) + .expect("rewrite should succeed"); + + assert_eq!(count, 2); + assert!(payload.contains(r#""token":"real-token""#)); + assert!(payload.contains(r#""app":"app-123""#)); + assert!(!payload.contains(PLACEHOLDER_PREFIX)); + } + + #[test] + fn rewrite_websocket_text_without_placeholder_is_unchanged() { + let (_, resolver) = SecretResolver::from_provider_env( + [("KEY".to_string(), "secret".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let mut payload = r#"{"op":1,"d":42}"#.to_string(); + + let count = resolver + .rewrite_websocket_text_placeholders(&mut payload) + .expect("rewrite should succeed"); + + assert_eq!(count, 0); + assert_eq!(payload, r#"{"op":1,"d":42}"#); + } + + #[test] + fn rewrite_websocket_text_unknown_placeholder_fails_closed_without_mutating() { + let (_, resolver) = SecretResolver::from_provider_env( + [("KEY".to_string(), "secret".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let original = r#"{"token":"openshell:resolve:env:UNKNOWN"}"#.to_string(); + let mut payload = original.clone(); + + let err = resolver + .rewrite_websocket_text_placeholders(&mut payload) + .expect_err("unknown placeholder should fail"); + + 
assert_eq!(err.location, "websocket"); + assert_eq!(payload, original); + } + // === Redaction tests === #[test] diff --git a/docs/reference/policy-schema.mdx b/docs/reference/policy-schema.mdx index a98e8087c..35967e802 100644 --- a/docs/reference/policy-schema.mdx +++ b/docs/reference/policy-schema.mdx @@ -163,6 +163,7 @@ Each endpoint defines a reachable destination and optional inspection rules. | `deny_rules` | list of deny rule objects | No | L7 deny rules that block specific requests even when allowed by `access` or `rules`. Deny rules take precedence over allow rules. | | `allowed_ips` | list of string | No | CIDR or IP allowlist for SSRF override. Entries overlapping loopback (`127.0.0.0/8`), link-local (`169.254.0.0/16`), or unspecified (`0.0.0.0`) are rejected at load time. | | `allow_encoded_slash` | bool | No | When `true`, L7 request parsing preserves `%2F` inside path segments instead of rejecting it. Use this for registries and APIs such as npm scoped packages (`/@scope%2Fname`). Defaults to `false`. | +| `websocket_credential_rewrite` | bool | No | When `true` on a `protocol: rest` endpoint, OpenShell rewrites `openshell:resolve:env:*` placeholders in client-to-server WebSocket text messages after an allowed HTTP `101` upgrade. Defaults to `false`. | | `persisted_queries` | string | No | GraphQL hash-only behavior. Default is `deny`; use `allow_registered` only with `graphql_persisted_queries`. | | `graphql_persisted_queries` | map | No | Trusted GraphQL persisted-query registry keyed by hash or saved-query ID. Values contain `operation_type`, optional `operation_name`, and optional root `fields`. | | `graphql_max_body_bytes` | integer | No | Maximum GraphQL request body bytes buffered for inspection. Defaults to `65536`. 
| diff --git a/docs/sandboxes/policies.mdx b/docs/sandboxes/policies.mdx index 4e0aa4357..1939cf651 100644 --- a/docs/sandboxes/policies.mdx +++ b/docs/sandboxes/policies.mdx @@ -49,14 +49,14 @@ network_policies: Static sections are locked at sandbox creation. Changing them requires destroying and recreating the sandbox. Dynamic sections can be updated on a running sandbox with `openshell policy update` for incremental merges or `openshell policy set` for full replacement, and take effect without restarting. When a hot reload changes rules on an active HTTP L7 endpoint, existing keep-alive tunnels are closed before forwarding another parsed request. Credential-injection-only HTTP passthrough tunnels use the same reload boundary. Most HTTP clients reconnect automatically, and the next request is evaluated against the current policy. -Raw streams are connection-scoped and outside L7 live-reload guarantees. This includes `tls: skip`, non-HTTP TCP payloads, HTTP upgrades such as WebSocket, and long-lived response streams such as SSE. A reload applies to the next connection or next parsed HTTP request; it does not interrupt an already-forwarded raw stream. +Raw streams are connection-scoped and outside L7 live-reload guarantees. This includes `tls: skip`, non-HTTP TCP payloads, HTTP upgrades such as WebSocket, and long-lived response streams such as SSE. A reload applies to the next connection or next parsed HTTP request; it does not interrupt an already-forwarded raw stream. The exception is `websocket_credential_rewrite: true` on `protocol: rest` endpoints, which keeps policy evaluation on the HTTP upgrade request and rewrites credential placeholders only in client-to-server WebSocket text messages after the allowed upgrade. | Section | Type | Description | |---|---|---| | `filesystem_policy` | Static | Controls which directories the agent can access on disk. Paths are split into `read_only` and `read_write` lists. Any path not listed in either list is inaccessible. 
Set `include_workdir: true` to automatically add the agent's working directory to `read_write`. [Landlock LSM](https://docs.kernel.org/security/landlock.html) enforces these restrictions at the kernel level. | | `landlock` | Static | Configures Landlock LSM enforcement behavior. Set `compatibility` to `best_effort` (skip individual inaccessible paths while applying remaining rules) or `hard_requirement` (fail if any path is inaccessible or the required kernel ABI is unavailable). See the [Policy Schema Reference](/reference/policy-schema#landlock) for the full behavior table. | | `process` | Static | Sets the OS-level identity for the agent process. `run_as_user` and `run_as_group` default to `sandbox`. Root (`root` or `0`) is rejected. The agent also runs with seccomp filters that block dangerous system calls. | -| `network_policies` | Dynamic | Controls network access for ordinary outbound traffic from the sandbox. Each block has a name, a list of endpoints (host, port, protocol, and optional rules), and a list of binaries allowed to use those endpoints.
Every outbound connection except `https://inference.local` goes through the proxy, which queries the [policy engine](/about/how-it-works#core-components) with the destination and calling binary. A connection is allowed only when both match an entry in the same policy block.
For endpoints with `protocol: rest`, the proxy auto-detects TLS and terminates it so each HTTP request can be checked against that endpoint's `rules` (method and path).
Endpoints without `protocol` allow the TCP stream through without inspecting payloads.
If no endpoint matches, the connection is denied. Configure managed inference separately through [Inference Routing](/sandboxes/inference-routing). | +| `network_policies` | Dynamic | Controls network access for ordinary outbound traffic from the sandbox. Each block has a name, a list of endpoints (host, port, protocol, and optional rules), and a list of binaries allowed to use those endpoints.
Every outbound connection except `https://inference.local` goes through the proxy, which queries the [policy engine](/about/how-it-works#core-components) with the destination and calling binary. A connection is allowed only when both match an entry in the same policy block.
For endpoints with `protocol: rest`, the proxy auto-detects TLS and terminates it so each HTTP request can be checked against that endpoint's `rules` (method and path). Set `websocket_credential_rewrite: true` only when a REST-shaped WebSocket upgrade must keep placeholder credentials in sandbox-owned payloads and resolve them at the OpenShell relay boundary.
Endpoints without `protocol` allow the TCP stream through without inspecting payloads.
If no endpoint matches, the connection is denied. Configure managed inference separately through [Inference Routing](/sandboxes/inference-routing). | ## Baseline Filesystem Paths @@ -447,7 +447,7 @@ Allow `pip install` and `uv pip install` to reach PyPI: - { path: /usr/local/bin/uv } ``` -Endpoints without `protocol` use TCP passthrough, where the proxy allows the stream without inspecting payloads. If the stream is HTTP and TLS is auto-terminated, the proxy can still rewrite configured credential placeholders and closes keep-alive passthrough tunnels on policy reload before forwarding another request. +Endpoints without `protocol` use TCP passthrough, where the proxy allows the stream without inspecting payloads. If the stream is HTTP and TLS is auto-terminated, the proxy can still rewrite configured credential placeholders and closes keep-alive passthrough tunnels on policy reload before forwarding another request. WebSocket payload credential rewrite requires an explicit `protocol: rest` endpoint with `websocket_credential_rewrite: true`. diff --git a/docs/security/best-practices.mdx b/docs/security/best-practices.mdx index 227bb4567..11afd41c8 100644 --- a/docs/security/best-practices.mdx +++ b/docs/security/best-practices.mdx @@ -99,7 +99,7 @@ The `protocol` field on an endpoint controls whether the proxy inspects individu | Default | Endpoints without a `protocol` field use L4-only enforcement: the proxy checks host, port, and binary, then relays the TCP stream without inspecting payloads. | | What you can change | Add `protocol: rest` to enable per-request HTTP method/path inspection, or `protocol: graphql` to inspect GraphQL operation type, operation name, and root fields. Pair either protocol with `rules` or access presets (`full`, `read-only`, `read-write`). | | Risk if relaxed | L4-only endpoints allow the agent to send any data through the tunnel after the initial connection is permitted. The proxy cannot see HTTP methods, paths, or GraphQL operations. 
Adding `access: full` with L7 inspection enables observability but permits all inspected actions. | -| Recommendation | Use `protocol: rest` with specific `rules` for APIs where intent is encoded in method and path. Use `protocol: graphql` for GraphQL APIs where destructive operations are body-encoded. Prefer `access: read-only` or explicit allowlists, and deny hash-only persisted queries unless you maintain a trusted registry. Omit `protocol` for non-HTTP protocols (WebSocket, gRPC streaming). | +| Recommendation | Use `protocol: rest` with specific `rules` for APIs where intent is encoded in method and path. Use `protocol: graphql` for GraphQL APIs where destructive operations are body-encoded. Prefer `access: read-only` or explicit allowlists, and deny hash-only persisted queries unless you maintain a trusted registry. Omit `protocol` for non-HTTP protocols. For WebSocket endpoints that begin with an HTTP upgrade and must carry placeholder credentials in client text frames, use `protocol: rest` with `websocket_credential_rewrite: true`. | ### Enforcement Mode (`audit` vs `enforce`) diff --git a/proto/sandbox.proto b/proto/sandbox.proto index f7df5945e..5ff88f6cf 100644 --- a/proto/sandbox.proto +++ b/proto/sandbox.proto @@ -116,6 +116,10 @@ message NetworkEndpoint { // protocol "rest" when both surfaces live under api.example.com:443. // Empty means all paths. string path = 15; + // When true on a "rest" endpoint, OpenShell rewrites credential placeholders + // inside client-to-server WebSocket text messages after an allowed HTTP 101 + // upgrade. Defaults to false. + bool websocket_credential_rewrite = 16; } // Trusted GraphQL operation classification. 
From 726efa3ac02d7ed0f76b4781316c712cf9d48d7c Mon Sep 17 00:00:00 2001 From: Aaron Erickson Date: Fri, 8 May 2026 19:58:33 -0700 Subject: [PATCH 02/17] fix(sandbox): harden websocket credential rewrite Signed-off-by: Aaron Erickson --- crates/openshell-sandbox/src/l7/rest.rs | 391 +++++++++++++++++-- crates/openshell-sandbox/src/l7/websocket.rs | 333 +++++++++++++++- crates/openshell-sandbox/src/secrets.rs | 77 ++++ 3 files changed, 740 insertions(+), 61 deletions(-) diff --git a/crates/openshell-sandbox/src/l7/rest.rs b/crates/openshell-sandbox/src/l7/rest.rs index 53fc60597..4a82b95c9 100644 --- a/crates/openshell-sandbox/src/l7/rest.rs +++ b/crates/openshell-sandbox/src/l7/rest.rs @@ -10,6 +10,7 @@ use crate::l7::provider::{BodyLength, L7Provider, L7Request, RelayOutcome}; use crate::opa::PolicyGenerationGuard; use crate::secrets::{SecretResolver, rewrite_http_header_block}; +use base64::Engine as _; use miette::{IntoDiagnostic, Result, miette}; use std::collections::HashMap; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; @@ -398,8 +399,13 @@ where .windows(4) .position(|w| w == b"\r\n\r\n") .map_or(req.raw_header.len(), |p| p + 4); + let websocket_rewrite_request = if options.strip_websocket_extensions { + validate_websocket_upgrade_request(&req.raw_header[..header_end])? + } else { + false + }; - let header_bytes = if options.strip_websocket_extensions { + let header_bytes = if websocket_rewrite_request { strip_websocket_extensions_if_requested(&req.raw_header[..header_end])? } else { req.raw_header[..header_end].to_vec() @@ -446,7 +452,15 @@ where } upstream.flush().await.into_diagnostic()?; - let outcome = relay_response(&req.action, upstream, client).await?; + let outcome = relay_response( + &req.action, + upstream, + client, + RelayResponseOptions { + reject_websocket_extensions: websocket_rewrite_request, + }, + ) + .await?; // Validate that the client actually requested an upgrade before accepting // a 101 from upstream. 
Per RFC 9110 Section 7.8, the server MUST NOT send @@ -493,13 +507,13 @@ pub(crate) fn request_is_websocket_upgrade(raw_header: &[u8]) -> bool { .windows(4) .position(|w| w == b"\r\n\r\n") .map_or(raw_header.len(), |p| p + 4); - std::str::from_utf8(&raw_header[..header_end]).is_ok_and(client_requested_websocket_upgrade) + validate_websocket_upgrade_request(&raw_header[..header_end]).unwrap_or(false) } fn strip_websocket_extensions_if_requested(raw_header: &[u8]) -> Result> { let header_str = std::str::from_utf8(raw_header) .map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; - if !client_requested_websocket_upgrade(header_str) { + if !validate_websocket_upgrade_request(raw_header)? { return Ok(raw_header.to_vec()); } @@ -517,6 +531,104 @@ fn strip_websocket_extensions_if_requested(raw_header: &[u8]) -> Result> Ok(out) } +fn validate_websocket_upgrade_request(raw_header: &[u8]) -> Result { + let header_str = std::str::from_utf8(raw_header) + .map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + let mut lines = header_str.lines(); + let Some(request_line) = lines.next() else { + return Ok(false); + }; + let method = request_line.split_whitespace().next().unwrap_or_default(); + let mut headers = WebSocketUpgradeHeaders::default(); + + for line in lines { + if line.is_empty() { + break; + } + let Some((name, value)) = line.split_once(':') else { + continue; + }; + let name = name.trim().to_ascii_lowercase(); + let value = value.trim(); + match name.as_str() { + "upgrade" if header_value_contains_token(value, "websocket") => { + headers.upgrade_websocket = true; + } + "connection" if header_value_contains_token(value, "upgrade") => { + headers.connection_upgrade = true; + } + "sec-websocket-key" => { + headers.sec_key_count += 1; + headers.sec_key = Some(value.to_string()); + } + "sec-websocket-version" => { + headers.version_count += 1; + headers.version = Some(value.to_string()); + } + _ => {} + } + } + + if !headers.is_attempt() { + return 
Ok(false); + } + if !method.eq_ignore_ascii_case("GET") { + return Err(miette!("websocket upgrade request must use GET")); + } + if !headers.upgrade_websocket { + return Err(miette!( + "websocket upgrade request missing Upgrade: websocket" + )); + } + if !headers.connection_upgrade { + return Err(miette!( + "websocket upgrade request missing Connection: Upgrade" + )); + } + if headers.sec_key_count != 1 { + return Err(miette!( + "websocket upgrade request must include exactly one Sec-WebSocket-Key" + )); + } + let key = headers.sec_key.as_deref().unwrap_or_default(); + let decoded_key = base64::engine::general_purpose::STANDARD + .decode(key.as_bytes()) + .map_err(|_| miette!("websocket upgrade request has invalid Sec-WebSocket-Key"))?; + if decoded_key.len() != 16 { + return Err(miette!( + "websocket upgrade request has invalid Sec-WebSocket-Key length" + )); + } + if headers.version_count != 1 || headers.version.as_deref() != Some("13") { + return Err(miette!( + "websocket upgrade request must use Sec-WebSocket-Version: 13" + )); + } + Ok(true) +} + +fn header_value_contains_token(value: &str, expected: &str) -> bool { + value + .split(',') + .any(|token| token.trim().eq_ignore_ascii_case(expected)) +} + +#[derive(Default)] +struct WebSocketUpgradeHeaders { + upgrade_websocket: bool, + connection_upgrade: bool, + sec_key: Option, + sec_key_count: usize, + version: Option, + version_count: usize, +} + +impl WebSocketUpgradeHeaders { + fn is_attempt(&self) -> bool { + self.upgrade_websocket || self.sec_key.is_some() || self.version.is_some() + } +} + /// Send a 403 Forbidden JSON deny response. 
/// /// When `redacted_target` is provided, it is used instead of `req.target` @@ -833,10 +945,16 @@ fn find_crlf(buf: &[u8], start: usize) -> Option { .map(|offset| start + offset) } +#[derive(Clone, Copy, Default)] +struct RelayResponseOptions { + reject_websocket_extensions: bool, +} + async fn relay_response( request_method: &str, upstream: &mut U, client: &mut C, + options: RelayResponseOptions, ) -> Result where U: AsyncRead + Unpin, @@ -890,6 +1008,17 @@ where // from upstream beyond the headers are overflow that belong to the // upgraded protocol and must be forwarded before switching. if status_code == 101 { + if options.reject_websocket_extensions + && header_str.lines().skip(1).any(|line| { + line.split_once(':').is_some_and(|(name, _)| { + name.trim().eq_ignore_ascii_case("sec-websocket-extensions") + }) + }) + { + return Err(miette!( + "upstream negotiated unsupported WebSocket extensions" + )); + } client .write_all(&buf[..header_end]) .await @@ -1030,29 +1159,6 @@ fn client_requested_upgrade(headers: &str) -> bool { has_upgrade_header && connection_contains_upgrade } -fn client_requested_websocket_upgrade(headers: &str) -> bool { - let mut upgrade_is_websocket = false; - let mut connection_contains_upgrade = false; - - for line in headers.lines().skip(1) { - let lower = line.to_ascii_lowercase(); - if lower.starts_with("upgrade:") { - let val = lower.split_once(':').map_or("", |(_, v)| v.trim()); - if val == "websocket" { - upgrade_is_websocket = true; - } - } - if lower.starts_with("connection:") { - let val = lower.split_once(':').map_or("", |(_, v)| v.trim()); - if val.split(',').any(|tok| tok.trim() == "upgrade") { - connection_contains_upgrade = true; - } - } - } - - upgrade_is_websocket && connection_contains_upgrade -} - /// Returns true for responses that MUST NOT contain a message body per RFC 7230 ยง3.3.3: /// HEAD responses, 1xx informational, 204 No Content, 304 Not Modified. 
fn is_bodiless_response(request_method: &str, status_code: u16) -> bool { @@ -1134,9 +1240,9 @@ mod tests { use super::*; use crate::opa::OpaEngine; use crate::secrets::SecretResolver; - use base64::Engine as _; const TEST_POLICY: &str = include_str!("../../data/sandbox-policy.rego"); + const VALID_WS_KEY: &str = "dGhlIHNhbXBsZSBub25jZQ=="; #[test] fn deny_response_body_is_agent_readable_and_redacted() { @@ -1799,7 +1905,12 @@ mod tests { let result = tokio::time::timeout( std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), ) .await .expect("relay_response should not deadlock"); @@ -1840,7 +1951,12 @@ mod tests { let result = tokio::time::timeout( std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), ) .await .expect("must not block when no Connection: close"); @@ -1876,7 +1992,12 @@ mod tests { let result = tokio::time::timeout( std::time::Duration::from_secs(2), - relay_response("HEAD", &mut upstream_read, &mut client_write), + relay_response( + "HEAD", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), ) .await .expect("HEAD relay must not deadlock waiting for body"); @@ -1909,7 +2030,12 @@ mod tests { let result = tokio::time::timeout( std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), ) .await .expect("204 relay must not deadlock"); @@ -1944,7 +2070,12 @@ mod tests { let result = tokio::time::timeout( std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, 
+ RelayResponseOptions::default(), + ), ) .await .expect("must not block when chunked body is complete in overflow"); @@ -1983,7 +2114,12 @@ mod tests { let result = tokio::time::timeout( std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), ) .await .expect("must not block when chunked response has trailers"); @@ -2021,7 +2157,12 @@ mod tests { let result = tokio::time::timeout( std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), ) .await .expect("normal relay must not deadlock"); @@ -2052,7 +2193,12 @@ mod tests { let result = tokio::time::timeout( std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), ) .await .expect("relay must not deadlock"); @@ -2093,7 +2239,12 @@ mod tests { let result = tokio::time::timeout( std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), ) .await .expect("relay_response should not deadlock"); @@ -2135,7 +2286,12 @@ mod tests { let result = tokio::time::timeout( std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), ) .await .expect("relay_response should not deadlock"); @@ -2266,6 +2422,112 @@ mod tests { upstream_task.await.expect("upstream task should complete"); } + #[tokio::test] + async fn opted_in_websocket_relay_rejects_invalid_upgrade_before_upstream_write() { + 
let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let req = L7Request { + action: "GET".to_string(), + target: "/ws".to_string(), + query_params: HashMap::new(), + raw_header: b"GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: 13\r\n\r\n".to_vec(), + body_length: BodyLength::None, + }; + + let result = relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + strip_websocket_extensions: true, + ..Default::default() + }, + ) + .await; + + assert!( + result.is_err(), + "missing Sec-WebSocket-Key must fail closed" + ); + drop(proxy_to_upstream); + let mut forwarded = Vec::new(); + upstream_side.read_to_end(&mut forwarded).await.unwrap(); + assert!( + forwarded.is_empty(), + "invalid opted-in upgrade must not reach upstream" + ); + } + + #[tokio::test] + async fn opted_in_websocket_relay_strips_request_extensions_and_rejects_response_extensions() { + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let req = L7Request { + action: "GET".to_string(), + target: "/ws".to_string(), + query_params: HashMap::new(), + raw_header: format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Extensions: permessage-deflate\r\nSec-WebSocket-Version: 13\r\n\r\n" + ) + .into_bytes(), + body_length: BodyLength::None, + }; + + let upstream_task = tokio::spawn(async move { + let mut buf = vec![0u8; 4096]; + let mut total = 0; + loop { + let n = upstream_side.read(&mut buf[total..]).await.unwrap(); + if n == 0 { + break; + } + total += n; + if buf[..total].windows(4).any(|w| w == b"\r\n\r\n") { + break; + } + } + let forwarded = String::from_utf8_lossy(&buf[..total]); + assert!( + 
!forwarded + .to_ascii_lowercase() + .contains("sec-websocket-extensions"), + "opted-in request must strip extension negotiation" + ); + upstream_side + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Extensions: permessage-deflate\r\n\r\n", + ) + .await + .unwrap(); + upstream_side.flush().await.unwrap(); + }); + + let result = relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + strip_websocket_extensions: true, + ..Default::default() + }, + ) + .await; + + let err = result.expect_err("upstream extension negotiation must fail closed"); + assert!(err.to_string().contains("unsupported WebSocket extensions")); + upstream_task.await.expect("upstream task should complete"); + + drop(proxy_to_client); + let mut received = Vec::new(); + app_side.read_to_end(&mut received).await.unwrap(); + assert!( + received.is_empty(), + "rejected extension negotiation must not forward 101 headers" + ); + } + #[tokio::test] async fn relay_request_guard_blocks_stale_generation_before_upstream_write() { let policy_data = "network_policies: {}\n"; @@ -2333,18 +2595,63 @@ mod tests { #[test] fn request_is_websocket_upgrade_detects_websocket_upgrade() { - let raw = b"GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: keep-alive, Upgrade\r\n\r\n"; - assert!(request_is_websocket_upgrade(raw)); + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: keep-alive, Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Version: 13\r\n\r\n" + ); + assert!(request_is_websocket_upgrade(raw.as_bytes())); + } + + #[test] + fn request_is_websocket_upgrade_rejects_missing_key() { + let raw = b"GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: 13\r\n\r\n"; + assert!(!request_is_websocket_upgrade(raw)); + 
assert!(validate_websocket_upgrade_request(raw).is_err()); + } + + #[test] + fn request_is_websocket_upgrade_rejects_wrong_method() { + let raw = format!( + "POST /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Version: 13\r\n\r\n" + ); + assert!(!request_is_websocket_upgrade(raw.as_bytes())); + assert!(validate_websocket_upgrade_request(raw.as_bytes()).is_err()); + } + + #[test] + fn request_is_websocket_upgrade_rejects_wrong_version() { + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Version: 12\r\n\r\n" + ); + assert!(!request_is_websocket_upgrade(raw.as_bytes())); + assert!(validate_websocket_upgrade_request(raw.as_bytes()).is_err()); + } + + #[test] + fn validate_websocket_upgrade_ignores_plain_rest_request() { + let raw = b"GET /api HTTP/1.1\r\nHost: example.com\r\n\r\n"; + assert!(!request_is_websocket_upgrade(raw)); + assert!(!validate_websocket_upgrade_request(raw).expect("plain request should parse")); + } + + #[test] + fn validate_websocket_upgrade_ignores_non_websocket_upgrade() { + let raw = b"GET /h2c HTTP/1.1\r\nHost: example.com\r\nUpgrade: h2c\r\nConnection: Upgrade\r\n\r\n"; + assert!(!request_is_websocket_upgrade(raw)); + assert!(!validate_websocket_upgrade_request(raw).expect("h2c request should parse")); } #[test] fn strip_websocket_extensions_removes_extension_negotiation() { - let raw = b"GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\nSec-WebSocket-Version: 13\r\n\r\n"; + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\nSec-WebSocket-Version: 13\r\n\r\n" + ); - let stripped 
= strip_websocket_extensions_if_requested(raw).expect("strip should succeed"); + let stripped = + strip_websocket_extensions_if_requested(raw.as_bytes()).expect("strip should succeed"); let stripped = String::from_utf8(stripped).unwrap(); assert!(stripped.contains("Upgrade: websocket\r\n")); + assert!(stripped.contains("Sec-WebSocket-Key: ")); assert!(stripped.contains("Sec-WebSocket-Version: 13\r\n")); assert!( !stripped diff --git a/crates/openshell-sandbox/src/l7/websocket.rs b/crates/openshell-sandbox/src/l7/websocket.rs index 0c04f9ea1..e2649f811 100644 --- a/crates/openshell-sandbox/src/l7/websocket.rs +++ b/crates/openshell-sandbox/src/l7/websocket.rs @@ -16,6 +16,7 @@ use openshell_ocsf::{ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; const MAX_TEXT_MESSAGE_BYTES: usize = 1024 * 1024; +const MAX_RAW_FRAME_PAYLOAD_BYTES: u64 = 16 * 1024 * 1024; const COPY_BUF_SIZE: usize = 8192; const OPCODE_CONTINUATION: u8 = 0x0; const OPCODE_TEXT: u8 = 0x1; @@ -81,10 +82,13 @@ where Ok::<(), miette::Report>(()) }; - tokio::select! { + let result = tokio::select! { result = client_to_server => result, result = server_to_client => result, - } + }; + let _ = upstream_write.shutdown().await; + let _ = client_write.shutdown().await; + result } async fn relay_client_to_server( @@ -100,25 +104,32 @@ where W: AsyncWrite + Unpin, { let mut fragments = FragmentState::None; + let mut close_seen = false; loop { let Some(frame) = read_frame_header(reader).await.inspect_err(|e| { - emit_protocol_failure(host, port, policy_name, &e.to_string()); + emit_protocol_failure(host, port, policy_name, protocol_failure_class(e)); })? 
else { writer.shutdown().await.into_diagnostic()?; return Ok(()); }; + if close_seen { + let e = miette!("websocket frame received after close frame"); + emit_protocol_failure(host, port, policy_name, protocol_failure_class(&e)); + return Err(e); + } + if let Err(e) = validate_frame_header(&frame, &fragments) { - emit_protocol_failure(host, port, policy_name, &e.to_string()); + emit_protocol_failure(host, port, policy_name, protocol_failure_class(&e)); return Err(e); } match frame.opcode { OPCODE_TEXT => { let payload = read_masked_payload(reader, &frame).await.inspect_err(|e| { - emit_protocol_failure(host, port, policy_name, &e.to_string()); + emit_protocol_failure(host, port, policy_name, protocol_failure_class(e)); })?; if frame.fin { relay_text_payload( @@ -133,7 +144,7 @@ where ) .await .inspect_err(|e| { - emit_protocol_failure(host, port, policy_name, &e.to_string()); + emit_protocol_failure(host, port, policy_name, protocol_failure_class(e)); })?; } else { fragments = FragmentState::Text { payload }; @@ -142,10 +153,10 @@ where OPCODE_CONTINUATION => match &mut fragments { FragmentState::Text { payload } => { let next = read_masked_payload(reader, &frame).await.inspect_err(|e| { - emit_protocol_failure(host, port, policy_name, &e.to_string()); + emit_protocol_failure(host, port, policy_name, protocol_failure_class(e)); })?; if let Err(e) = append_text_fragment(payload, next) { - emit_protocol_failure(host, port, policy_name, &e.to_string()); + emit_protocol_failure(host, port, policy_name, protocol_failure_class(&e)); return Err(e); } if frame.fin { @@ -163,12 +174,26 @@ where ) .await .inspect_err(|e| { - emit_protocol_failure(host, port, policy_name, &e.to_string()); + emit_protocol_failure( + host, + port, + policy_name, + protocol_failure_class(e), + ); })?; } } FragmentState::Binary => { - copy_raw_frame_payload(reader, writer, &frame).await?; + copy_raw_frame_payload(reader, writer, &frame) + .await + .inspect_err(|e| { + emit_protocol_failure( + 
host, + port, + policy_name, + protocol_failure_class(e), + ); + })?; if frame.fin { fragments = FragmentState::None; } @@ -176,7 +201,7 @@ where FragmentState::None => { let e = miette!("websocket continuation frame without active fragmented message"); - emit_protocol_failure(host, port, policy_name, &e.to_string()); + emit_protocol_failure(host, port, policy_name, protocol_failure_class(&e)); return Err(e); } }, @@ -184,10 +209,21 @@ where if !frame.fin { fragments = FragmentState::Binary; } - copy_raw_frame_payload(reader, writer, &frame).await?; + copy_raw_frame_payload(reader, writer, &frame) + .await + .inspect_err(|e| { + emit_protocol_failure(host, port, policy_name, protocol_failure_class(e)); + })?; } OPCODE_CLOSE | OPCODE_PING | OPCODE_PONG => { - copy_raw_frame_payload(reader, writer, &frame).await?; + relay_control_frame(reader, writer, &frame) + .await + .inspect_err(|e| { + emit_protocol_failure(host, port, policy_name, protocol_failure_class(e)); + })?; + if frame.opcode == OPCODE_CLOSE { + close_seen = true; + } } _ => unreachable!("validated opcode"), } @@ -225,7 +261,13 @@ async fn read_frame_header(reader: &mut R) -> Result { let mut bytes = [0u8; 8]; @@ -237,7 +279,13 @@ async fn read_frame_header(reader: &mut R) -> Result unreachable!("7-bit length code"), }; @@ -306,6 +354,14 @@ fn validate_frame_header(frame: &FrameHeader, fragments: &FragmentState) -> Resu "websocket continuation frame without active fragmented message" )); } + if (frame.opcode == OPCODE_BINARY + || (frame.opcode == OPCODE_CONTINUATION && matches!(fragments, FragmentState::Binary))) + && frame.payload_len > MAX_RAW_FRAME_PAYLOAD_BYTES + { + return Err(miette!( + "websocket binary frame exceeds {MAX_RAW_FRAME_PAYLOAD_BYTES} byte relay limit" + )); + } Ok(()) } @@ -361,7 +417,7 @@ async fn relay_text_payload( .map_err(|_| miette!("websocket text message is not valid UTF-8"))?; let replacements = resolver .rewrite_websocket_text_placeholders(&mut text) - .map_err(|e| 
miette!("{e}"))?; + .map_err(|_| miette!("websocket credential placeholder resolution failed"))?; if replacements == 0 && !force_reframe { writer @@ -384,6 +440,65 @@ async fn relay_text_payload( write_masked_frame(writer, OPCODE_TEXT, text.as_bytes()).await } +async fn relay_control_frame( + reader: &mut R, + writer: &mut W, + frame: &FrameHeader, +) -> Result<()> +where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, +{ + let raw_payload_len = usize::try_from(frame.payload_len) + .map_err(|_| miette!("websocket control frame payload length overflow"))?; + let mut raw_payload = vec![0u8; raw_payload_len]; + reader + .read_exact(&mut raw_payload) + .await + .map_err(|e| miette!("malformed websocket control payload: {e}"))?; + + if frame.opcode == OPCODE_CLOSE { + let mut payload = raw_payload.clone(); + let mask_key = frame + .mask_key + .ok_or_else(|| miette!("websocket client frame is not masked"))?; + apply_mask(&mut payload, mask_key); + validate_close_payload(&payload)?; + } + + writer + .write_all(&frame.raw_header) + .await + .into_diagnostic()?; + writer.write_all(&raw_payload).await.into_diagnostic()?; + writer.flush().await.into_diagnostic()?; + Ok(()) +} + +fn validate_close_payload(payload: &[u8]) -> Result<()> { + if payload.len() == 1 { + return Err(miette!( + "websocket close frame payload cannot be exactly one byte" + )); + } + if payload.len() < 2 { + return Ok(()); + } + + let code = u16::from_be_bytes([payload[0], payload[1]]); + if !valid_close_code(code) { + return Err(miette!("websocket close frame uses invalid close code")); + } + if std::str::from_utf8(&payload[2..]).is_err() { + return Err(miette!("websocket close frame reason is not valid UTF-8")); + } + Ok(()) +} + +fn valid_close_code(code: u16) -> bool { + (matches!(code, 1000..=1014) && !matches!(code, 1004..=1006)) || (3000..=4999).contains(&code) +} + async fn copy_raw_frame_payload( reader: &mut R, writer: &mut W, @@ -479,7 +594,38 @@ fn emit_rewrite_event(host: &str, port: u16, 
policy_name: &str, replacements: us ocsf_emit!(event); } -fn emit_protocol_failure(host: &str, port: u16, policy_name: &str, detail: &str) { +fn protocol_failure_class(error: &miette::Report) -> &'static str { + let msg = error.to_string().to_ascii_lowercase(); + if msg.contains("credential") { + "credential_resolution_failed" + } else if msg.contains("utf-8") { + "invalid_utf8" + } else if msg.contains("close frame") || msg.contains("after close") { + "invalid_close_frame" + } else if msg.contains("control frame") { + "invalid_control_frame" + } else if msg.contains("length") + || msg.contains("too large") + || msg.contains("exceeds") + || msg.contains("overflow") + { + "invalid_length" + } else if msg.contains("continuation") || msg.contains("fragmented") { + "invalid_fragmentation" + } else if msg.contains("reserved opcode") { + "reserved_opcode" + } else if msg.contains("not masked") { + "unmasked_client_frame" + } else if msg.contains("rsv") { + "rsv_bits" + } else if msg.contains("malformed") { + "malformed_frame" + } else { + "protocol_error" + } +} + +fn emit_protocol_failure(host: &str, port: u16, policy_name: &str, failure_class: &str) { let policy_name = if policy_name.is_empty() { "-" } else { @@ -496,7 +642,7 @@ fn emit_protocol_failure(host: &str, port: u16, policy_name: &str, detail: &str) .message(format!( "WEBSOCKET_CREDENTIAL_REWRITE closed ambiguous client frame [host:{host} port:{port}]" )) - .status_detail(detail) + .status_detail(failure_class) .build(); ocsf_emit!(event); } @@ -548,6 +694,39 @@ mod tests { frame } + fn masked_frame_with_declared_len(opcode: u8, declared_len: u64) -> Vec { + let mut frame = Vec::new(); + frame.push(0x80 | opcode); + frame.push(0x80 | 127); + frame.extend_from_slice(&declared_len.to_be_bytes()); + frame.extend_from_slice(&[0x37, 0xfa, 0x21, 0x3d]); + frame + } + + fn masked_frame_with_non_minimal_16_bit_len(opcode: u8, payload: &[u8]) -> Vec { + let mask_key = [0x37, 0xfa, 0x21, 0x3d]; + let mut frame = 
Vec::new(); + frame.push(0x80 | opcode); + frame.push(0x80 | 0x7e); + frame.extend_from_slice( + &u16::try_from(payload.len()) + .expect("test payload fits u16") + .to_be_bytes(), + ); + frame.extend_from_slice(&mask_key); + for (i, byte) in payload.iter().enumerate() { + frame.push(byte ^ mask_key[i % 4]); + } + frame + } + + fn close_payload(code: u16, reason: &[u8]) -> Vec { + let mut payload = Vec::with_capacity(2 + reason.len()); + payload.extend_from_slice(&code.to_be_bytes()); + payload.extend_from_slice(reason); + payload + } + async fn run_client_to_server(input: Vec) -> Result> { let (_, resolver) = resolver(); let (mut client_write, mut relay_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); @@ -719,7 +898,7 @@ mod tests { assert!( err.to_string() - .contains("unresolved credential placeholder") + .contains("credential placeholder resolution") ); } @@ -781,4 +960,120 @@ mod tests { assert!(err.to_string().contains("exceeds")); } + + #[tokio::test] + async fn fragmented_text_allows_interleaved_ping_pong_and_rewrites_at_completion() { + let (child_env, _) = resolver(); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let first = format!(r#"{{"token":"{placeholder}"#); + let first_control_frame = masked_frame(true, OPCODE_PING, b"p"); + let second_control_frame = masked_frame(true, OPCODE_PONG, b"q"); + let mut input = masked_frame(false, OPCODE_TEXT, first.as_bytes()); + input.extend_from_slice(&first_control_frame); + input.extend_from_slice(&second_control_frame); + input.extend(masked_frame(true, OPCODE_CONTINUATION, br#""}"#)); + + let output = run_client_to_server(input) + .await + .expect("relay should allow interleaved control frames"); + + assert!(output.starts_with(&first_control_frame)); + assert_eq!( + &output + [first_control_frame.len()..first_control_frame.len() + second_control_frame.len()], + second_control_frame.as_slice() + ); + assert_eq!( + decode_masked_text_frame( + &output[first_control_frame.len() + 
second_control_frame.len()..] + ), + r#"{"token":"real-token"}"# + ); + } + + #[tokio::test] + async fn binary_frame_passes_through_unchanged() { + let frame = masked_frame(true, OPCODE_BINARY, &[0, 1, 2, 3, 255]); + + let output = run_client_to_server(frame.clone()) + .await + .expect("binary frame should pass through"); + + assert_eq!(output, frame); + } + + #[tokio::test] + async fn rejects_reserved_opcode() { + let err = run_client_to_server(masked_frame(true, 0x3, b"reserved")) + .await + .expect_err("reserved opcode should fail"); + + assert!(err.to_string().contains("reserved opcode")); + } + + #[tokio::test] + async fn rejects_non_minimal_extended_length() { + let err = run_client_to_server(masked_frame_with_non_minimal_16_bit_len( + OPCODE_TEXT, + b"hello", + )) + .await + .expect_err("non-minimal length should fail"); + + assert!(err.to_string().contains("non-minimal")); + } + + #[tokio::test] + async fn rejects_oversize_binary_frame_before_payload_buffering() { + let err = run_client_to_server(masked_frame_with_declared_len( + OPCODE_BINARY, + MAX_RAW_FRAME_PAYLOAD_BYTES + 1, + )) + .await + .expect_err("oversize binary frame should fail"); + + assert!(err.to_string().contains("binary frame exceeds")); + } + + #[tokio::test] + async fn validates_close_frame_payloads() { + let frame = masked_frame(true, OPCODE_CLOSE, &close_payload(1000, b"done")); + + let output = run_client_to_server(frame.clone()) + .await + .expect("valid close frame should pass through"); + + assert_eq!(output, frame); + } + + #[tokio::test] + async fn rejects_close_frame_with_one_byte_payload() { + let err = run_client_to_server(masked_frame(true, OPCODE_CLOSE, &[0x03])) + .await + .expect_err("one-byte close frame should fail"); + + assert!(err.to_string().contains("exactly one byte")); + } + + #[tokio::test] + async fn rejects_reserved_close_code() { + let err = run_client_to_server(masked_frame(true, OPCODE_CLOSE, &close_payload(1005, b""))) + .await + .expect_err("reserved close 
code should fail"); + + assert!(err.to_string().contains("invalid close code")); + } + + #[tokio::test] + async fn rejects_close_reason_with_invalid_utf8() { + let err = run_client_to_server(masked_frame( + true, + OPCODE_CLOSE, + &close_payload(1000, &[0xff]), + )) + .await + .expect_err("invalid close reason should fail"); + + assert!(err.to_string().contains("valid UTF-8")); + } } diff --git a/crates/openshell-sandbox/src/secrets.rs b/crates/openshell-sandbox/src/secrets.rs index bc8ac7c6e..d9b5af2c1 100644 --- a/crates/openshell-sandbox/src/secrets.rs +++ b/crates/openshell-sandbox/src/secrets.rs @@ -181,6 +181,20 @@ impl SecretResolver { rewritten.push_str(&text[pos..abs_start]); let key_start = abs_start + PLACEHOLDER_PREFIX.len(); + if let Some((key_end, full_placeholder)) = + self.longest_known_placeholder_match(text, abs_start) + { + let Some(secret) = self.resolve_placeholder(full_placeholder) else { + return Err(UnresolvedPlaceholderError { + location: "websocket", + }); + }; + rewritten.push_str(secret); + replacements += 1; + pos = key_end; + continue; + } + let key_end = text[key_start..] .bytes() .position(|b| !is_env_key_char(b)) @@ -207,6 +221,27 @@ impl SecretResolver { Ok(replacements) } + fn longest_known_placeholder_match<'a>( + &'a self, + text: &str, + abs_start: usize, + ) -> Option<(usize, &'a str)> { + let suffix = &text[abs_start..]; + self.by_placeholder + .keys() + .filter_map(|placeholder| { + if !suffix.starts_with(placeholder) { + return None; + } + let key_end = abs_start + placeholder.len(); + let boundary_ok = key_end == text.len() + || !is_env_key_char(text.as_bytes()[key_end]) + || text[key_end..].starts_with(PLACEHOLDER_PREFIX); + boundary_ok.then_some((key_end, placeholder.as_str())) + }) + .max_by_key(|(_, placeholder)| placeholder.len()) + } + /// Decode a Base64-encoded Basic auth token, resolve any placeholders in /// the decoded `username:password` string, and re-encode. 
/// @@ -1559,6 +1594,48 @@ mod tests { assert_eq!(payload, original); } + #[test] + fn rewrite_websocket_text_handles_repeated_adjacent_and_unicode_placeholders() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [ + ("TOKEN".to_string(), "tok".to_string()), + ("APP".to_string(), "app".to_string()), + ] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let token = child_env.get("TOKEN").unwrap(); + let app = child_env.get("APP").unwrap(); + let mut payload = format!("prefix-โ˜ƒ-{token}{app}-{token}-suffix"); + + let count = resolver + .rewrite_websocket_text_placeholders(&mut payload) + .expect("rewrite should succeed"); + + assert_eq!(count, 3); + assert_eq!(payload, "prefix-โ˜ƒ-tokapp-tok-suffix"); + } + + #[test] + fn rewrite_websocket_text_placeholder_like_prefix_fails_without_mutating() { + let (_, resolver) = SecretResolver::from_provider_env( + [("KEY".to_string(), "secret".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let original = "openshell:resolve:env:-not-a-key".to_string(); + let mut payload = original.clone(); + + let err = resolver + .rewrite_websocket_text_placeholders(&mut payload) + .expect_err("placeholder-like prefix should fail closed"); + + assert_eq!(err.location, "websocket"); + assert_eq!(payload, original); + } + // === Redaction tests === #[test] From 6fd6d4e2ba92c7717e5459a496c2c5069c808e2b Mon Sep 17 00:00:00 2001 From: Aaron Erickson Date: Fri, 8 May 2026 20:30:50 -0700 Subject: [PATCH 03/17] feat(sandbox): add websocket l7 inspection and compression Signed-off-by: Aaron Erickson --- Cargo.lock | 1 + crates/openshell-sandbox/Cargo.toml | 1 + crates/openshell-sandbox/src/l7/mod.rs | 81 ++- crates/openshell-sandbox/src/l7/provider.rs | 5 +- crates/openshell-sandbox/src/l7/relay.rs | 284 +++++++++-- crates/openshell-sandbox/src/l7/rest.rs | 303 ++++++++++-- crates/openshell-sandbox/src/l7/websocket.rs | 462 +++++++++++++++--- 
crates/openshell-sandbox/src/proxy.rs | 27 +- .../tests/websocket_upgrade.rs | 2 +- proto/sandbox.proto | 2 +- 10 files changed, 987 insertions(+), 181 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 808956cd9..63ba19ced 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3571,6 +3571,7 @@ dependencies = [ "base64 0.22.1", "bytes", "clap", + "flate2", "futures", "glob", "hex", diff --git a/crates/openshell-sandbox/Cargo.toml b/crates/openshell-sandbox/Cargo.toml index 4e07521ce..6d8310b26 100644 --- a/crates/openshell-sandbox/Cargo.toml +++ b/crates/openshell-sandbox/Cargo.toml @@ -58,6 +58,7 @@ uuid = { workspace = true } # Encoding base64 = { workspace = true } +flate2 = "1" # IP network / CIDR parsing ipnet = "2" diff --git a/crates/openshell-sandbox/src/l7/mod.rs b/crates/openshell-sandbox/src/l7/mod.rs index 153611726..e553a9e05 100644 --- a/crates/openshell-sandbox/src/l7/mod.rs +++ b/crates/openshell-sandbox/src/l7/mod.rs @@ -21,6 +21,7 @@ pub(crate) mod websocket; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum L7Protocol { Rest, + Websocket, Graphql, Sql, } @@ -29,6 +30,7 @@ impl L7Protocol { pub fn parse(s: &str) -> Option { match s.to_ascii_lowercase().as_str() { "rest" => Some(Self::Rest), + "websocket" => Some(Self::Websocket), "graphql" => Some(Self::Graphql), "sql" => Some(Self::Sql), _ => None, @@ -469,7 +471,7 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< if !protocol.is_empty() && L7Protocol::parse(protocol).is_none() { errors.push(format!( - "{loc}: unknown protocol '{protocol}' (expected rest, graphql, or sql)" + "{loc}: unknown protocol '{protocol}' (expected rest, websocket, graphql, or sql)" )); } @@ -510,9 +512,10 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< .and_then(serde_json::Value::as_bool) .unwrap_or(false) && protocol != "rest" + && protocol != "websocket" { warnings.push(format!( - "{loc}: websocket_credential_rewrite is ignored unless protocol is rest" + "{loc}: 
websocket_credential_rewrite is ignored unless protocol is rest or websocket" )); } @@ -592,14 +595,13 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< // Validate method if let Some(method) = deny_rule.get("method").and_then(|m| m.as_str()) && !method.is_empty() - && protocol == "rest" + && (protocol == "rest" || protocol == "websocket") { - let valid_methods = [ - "GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "*", - ]; + let valid_methods = valid_methods_for_protocol(protocol); if !valid_methods.contains(&method.to_ascii_uppercase().as_str()) { warnings.push(format!( - "{deny_loc}: Unknown HTTP method '{method}'. Standard methods: GET, HEAD, POST, PUT, DELETE, PATCH, OPTIONS." + "{deny_loc}: Unknown HTTP/WebSocket method '{method}'. Standard methods: {}." + , valid_methods.join(", ") )); } } @@ -751,10 +753,8 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< } // Validate HTTP methods in rules - if has_rules && protocol == "rest" { - let valid_methods = [ - "GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "*", - ]; + if has_rules && (protocol == "rest" || protocol == "websocket") { + let valid_methods = valid_methods_for_protocol(protocol); if let Some(rules) = ep.get("rules").and_then(|v| v.as_array()) { for (rule_idx, rule) in rules.iter().enumerate() { if let Some(method) = rule @@ -765,7 +765,8 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< && !valid_methods.contains(&method.to_ascii_uppercase().as_str()) { warnings.push(format!( - "{loc}: Unknown HTTP method '{method}'. Standard methods: GET, HEAD, POST, PUT, DELETE, PATCH, OPTIONS." + "{loc}: Unknown HTTP/WebSocket method '{method}'. Standard methods: {}." 
+ , valid_methods.join(", ") )); } @@ -939,6 +940,13 @@ pub fn expand_access_presets(data: &mut serde_json::Value) { "full" => vec![graphql_rule_json("*")], _ => continue, } + } else if protocol == "websocket" { + match access.as_str() { + "read-only" => vec![rule_json("GET", "**")], + "read-write" => vec![rule_json("GET", "**"), rule_json("WEBSOCKET_TEXT", "**")], + "full" => vec![rule_json("*", "**")], + _ => continue, + } } else { match access.as_str() { "read-only" => vec![ @@ -975,6 +983,15 @@ fn rule_json(method: &str, path: &str) -> serde_json::Value { }) } +fn valid_methods_for_protocol(protocol: &str) -> &'static [&'static str] { + match protocol { + "websocket" => &["GET", "WEBSOCKET_TEXT", "*"], + _ => &[ + "GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "*", + ], + } +} + fn graphql_rule_json(operation_type: &str) -> serde_json::Value { serde_json::json!({ "allow": { @@ -1012,6 +1029,16 @@ mod tests { assert_eq!(config.enforcement, EnforcementMode::Audit); } + #[test] + fn parse_l7_config_websocket_protocol() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "websocket", "host": "gateway.example.com", "port": 443}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert_eq!(config.protocol, L7Protocol::Websocket); + } + #[test] fn parse_l7_config_skip() { let val = regorus::Value::from_json_str( @@ -1070,7 +1097,7 @@ mod tests { } #[test] - fn validate_websocket_credential_rewrite_warns_unless_rest() { + fn validate_websocket_credential_rewrite_warns_unless_rest_or_websocket() { let data = serde_json::json!({ "network_policies": { "test": { @@ -1092,6 +1119,34 @@ mod tests { ); } + #[test] + fn expand_websocket_read_write_access_includes_text_messages() { + let mut data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "protocol": "websocket", + "access": "read-write" + }], + "binaries": [] + } + } + }); + + expand_access_presets(&mut 
data); + let rules = data["network_policies"]["test"]["endpoints"][0]["rules"] + .as_array() + .unwrap(); + let methods: Vec<&str> = rules + .iter() + .map(|r| r["allow"]["method"].as_str().unwrap()) + .collect(); + assert!(methods.contains(&"GET")); + assert!(methods.contains(&"WEBSOCKET_TEXT")); + } + #[test] fn validate_rules_and_access_mutual_exclusion() { let data = serde_json::json!({ diff --git a/crates/openshell-sandbox/src/l7/provider.rs b/crates/openshell-sandbox/src/l7/provider.rs index 7516aa85c..864d94ad2 100644 --- a/crates/openshell-sandbox/src/l7/provider.rs +++ b/crates/openshell-sandbox/src/l7/provider.rs @@ -27,7 +27,10 @@ pub enum RelayOutcome { /// Contains any overflow bytes read from upstream past the 101 response /// headers that belong to the upgraded protocol. The 101 headers /// themselves have already been forwarded to the client. - Upgraded { overflow: Vec }, + Upgraded { + overflow: Vec, + websocket_permessage_deflate: bool, + }, } /// Body framing for HTTP requests/responses. diff --git a/crates/openshell-sandbox/src/l7/relay.rs b/crates/openshell-sandbox/src/l7/relay.rs index 244f3cf99..7af4a82ea 100644 --- a/crates/openshell-sandbox/src/l7/relay.rs +++ b/crates/openshell-sandbox/src/l7/relay.rs @@ -8,6 +8,7 @@ //! and either forwards or denies the request. 
use crate::l7::provider::{L7Provider, RelayOutcome}; +use crate::l7::rest::WebSocketExtensionMode; use crate::l7::{EnforcementMode, L7EndpointConfig, L7Protocol, L7RequestInfo}; use crate::opa::{PolicyGenerationGuard, TunnelPolicyEngine}; use crate::secrets::{self, SecretResolver}; @@ -38,14 +39,26 @@ pub struct L7EvalContext { pub(crate) secret_resolver: Option>, } -#[derive(Clone, Default)] -pub(crate) struct UpgradeRelayOptions { +#[derive(Default)] +pub(crate) struct UpgradeRelayOptions<'a> { pub(crate) websocket_request: bool, - pub(crate) websocket_credential_rewrite: bool, + pub(crate) websocket: WebSocketUpgradeBehavior, pub(crate) secret_resolver: Option>, + pub(crate) engine: Option<&'a TunnelPolicyEngine>, + pub(crate) ctx: Option<&'a L7EvalContext>, + pub(crate) enforcement: EnforcementMode, + pub(crate) target: String, + pub(crate) query_params: std::collections::HashMap>, pub(crate) policy_name: String, } +#[derive(Default)] +pub(crate) struct WebSocketUpgradeBehavior { + pub(crate) credential_rewrite: bool, + pub(crate) message_inspection: bool, + pub(crate) permessage_deflate: bool, +} + #[derive(Debug, Clone, Copy)] enum ParseRejectionMode { L7Endpoint, @@ -109,7 +122,9 @@ where U: AsyncRead + AsyncWrite + Unpin + Send, { match config.protocol { - L7Protocol::Rest => relay_rest(config, &engine, client, upstream, ctx).await, + L7Protocol::Rest | L7Protocol::Websocket => { + relay_rest(config, &engine, client, upstream, ctx).await + } L7Protocol::Graphql => relay_graphql(config, &engine, client, upstream, ctx).await, L7Protocol::Sql => { if close_if_stale(engine.generation_guard(), ctx) { @@ -250,6 +265,24 @@ where query_params: req.query_params.clone(), graphql: graphql_info.clone(), }; + let websocket_request = crate::l7::rest::request_is_websocket_upgrade(&req.raw_header); + if config.protocol == L7Protocol::Websocket && !websocket_request { + crate::l7::rest::RestProvider::default() + .deny_with_redacted_target( + &req, + &ctx.policy_name, + 
"websocket endpoint requires a valid WebSocket upgrade request", + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } let parse_error_reason = graphql_info .as_ref() @@ -272,10 +305,10 @@ where (false, EnforcementMode::Audit) => "audit", (false, EnforcementMode::Enforce) => "deny", }; - let engine_type = if config.protocol == L7Protocol::Graphql { - "l7-graphql" - } else { - "l7" + let engine_type = match config.protocol { + L7Protocol::Graphql => "l7-graphql", + L7Protocol::Websocket => "l7-websocket", + L7Protocol::Rest | L7Protocol::Sql => "l7", }; emit_l7_request_log( ctx, @@ -290,7 +323,6 @@ where let _ = &eval_target; if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { - let websocket_request = crate::l7::rest::request_is_websocket_upgrade(&req.raw_header); let outcome = crate::l7::rest::relay_http_request_with_options_guarded( &req, client, @@ -298,22 +330,28 @@ where crate::l7::rest::RelayRequestOptions { resolver: ctx.secret_resolver.as_deref(), generation_guard: Some(engine.generation_guard()), - strip_websocket_extensions: config.protocol == L7Protocol::Rest - && config.websocket_credential_rewrite, + websocket_extensions: websocket_extension_mode(config), }, ) .await?; match outcome { RelayOutcome::Reusable => {} RelayOutcome::Consumed => return Ok(()), - RelayOutcome::Upgraded { overflow } => { + RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + } => { + let mut options = upgrade_options( + config, + ctx, + websocket_request, + &redacted_target, + &req.query_params, + Some(&engine), + ); + options.websocket.permessage_deflate = websocket_permessage_deflate; return handle_upgrade( - client, - upstream, - overflow, - &ctx.host, - ctx.port, - upgrade_options(config, ctx, websocket_request), + client, upstream, overflow, &ctx.host, ctx.port, options, ) 
.await; } @@ -395,26 +433,26 @@ fn emit_l7_request_log( /// Handle an upgraded connection (101 Switching Protocols). /// /// Forwards any overflow bytes from the upgrade response to the client, then -/// switches to raw bidirectional TCP copy for the upgraded protocol (WebSocket, -/// HTTP/2, etc.). L7 policy enforcement does not apply after the upgrade โ€” -/// the initial HTTP request was already evaluated. +/// either switches to a parsed WebSocket relay for opted-in message policy / +/// credential rewriting or to raw bidirectional TCP copy for other upgrades. pub(crate) async fn handle_upgrade( client: &mut C, upstream: &mut U, overflow: Vec, host: &str, port: u16, - options: UpgradeRelayOptions, + options: UpgradeRelayOptions<'_>, ) -> Result<()> where C: AsyncRead + AsyncWrite + Unpin + Send, U: AsyncRead + AsyncWrite + Unpin + Send, { - let use_websocket_rewrite = options.websocket_request - && options.websocket_credential_rewrite - && options.secret_resolver.is_some(); - let relay_mode = if use_websocket_rewrite { - "websocket credential rewrite relay" + let use_websocket_relay = options.websocket_request + && (options.websocket.message_inspection + || options.websocket.permessage_deflate + || (options.websocket.credential_rewrite && options.secret_resolver.is_some())); + let relay_mode = if use_websocket_relay { + "websocket parsed relay" } else { "raw bidirectional relay (L7 enforcement no longer active)" }; @@ -430,15 +468,47 @@ where )) .build() ); - if use_websocket_rewrite && let Some(resolver) = options.secret_resolver.as_deref() { - return crate::l7::websocket::relay_with_credential_rewrite( + if use_websocket_relay { + let resolver = if options.websocket.credential_rewrite { + options.secret_resolver.as_deref() + } else { + None + }; + let inspector = if options.websocket.message_inspection { + match (options.engine, options.ctx) { + (Some(engine), Some(ctx)) => Some(crate::l7::websocket::InspectionOptions { + engine, + ctx, + enforcement: 
options.enforcement, + target: options.target.clone(), + query_params: options.query_params.clone(), + }), + _ => { + return Err(miette!( + "websocket message inspection missing policy context" + )); + } + } + } else { + None + }; + let compression = if options.websocket.permessage_deflate { + crate::l7::websocket::WebSocketCompression::PermessageDeflate + } else { + crate::l7::websocket::WebSocketCompression::None + }; + return crate::l7::websocket::relay_with_options( client, upstream, overflow, host, port, - &options.policy_name, - resolver, + crate::l7::websocket::RelayOptions { + policy_name: &options.policy_name, + resolver, + inspector, + compression, + }, ) .await; } @@ -452,25 +522,49 @@ where Ok(()) } -fn upgrade_options( +fn upgrade_options<'a>( config: &L7EndpointConfig, - ctx: &L7EvalContext, + ctx: &'a L7EvalContext, websocket_request: bool, -) -> UpgradeRelayOptions { + target: &str, + query_params: &std::collections::HashMap>, + engine: Option<&'a TunnelPolicyEngine>, +) -> UpgradeRelayOptions<'a> { let websocket_credential_rewrite = - config.protocol == L7Protocol::Rest && config.websocket_credential_rewrite; + matches!(config.protocol, L7Protocol::Rest | L7Protocol::Websocket) + && config.websocket_credential_rewrite; + let websocket_message_inspection = config.protocol == L7Protocol::Websocket; UpgradeRelayOptions { websocket_request, - websocket_credential_rewrite, + websocket: WebSocketUpgradeBehavior { + credential_rewrite: websocket_credential_rewrite, + message_inspection: websocket_message_inspection, + permessage_deflate: false, + }, secret_resolver: if websocket_credential_rewrite { ctx.secret_resolver.clone() } else { None }, + engine, + ctx: engine.map(|_| ctx), + enforcement: config.enforcement, + target: target.to_string(), + query_params: query_params.clone(), policy_name: ctx.policy_name.clone(), } } +fn websocket_extension_mode(config: &L7EndpointConfig) -> WebSocketExtensionMode { + if config.protocol == L7Protocol::Websocket + || 
(config.protocol == L7Protocol::Rest && config.websocket_credential_rewrite) + { + WebSocketExtensionMode::PermessageDeflate + } else { + WebSocketExtensionMode::Preserve + } +} + /// REST relay loop: parse request -> evaluate -> allow/deny -> relay response -> repeat. async fn relay_rest( config: &L7EndpointConfig, @@ -550,6 +644,24 @@ where query_params: req.query_params.clone(), graphql: None, }; + let websocket_request = crate::l7::rest::request_is_websocket_upgrade(&req.raw_header); + if config.protocol == L7Protocol::Websocket && !websocket_request { + provider + .deny_with_redacted_target( + &req, + &ctx.policy_name, + "websocket endpoint requires a valid WebSocket upgrade request", + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } // Evaluate L7 policy via Rego (using redacted target) let (allowed, reason) = evaluate_l7_request(engine, ctx, &request_info)?; @@ -618,7 +730,6 @@ where if allowed || config.enforcement == EnforcementMode::Audit { // Forward request to upstream and relay response - let websocket_request = crate::l7::rest::request_is_websocket_upgrade(&req.raw_header); let outcome = crate::l7::rest::relay_http_request_with_options_guarded( &req, client, @@ -626,8 +737,7 @@ where crate::l7::rest::RelayRequestOptions { resolver: ctx.secret_resolver.as_deref(), generation_guard: Some(engine.generation_guard()), - strip_websocket_extensions: config.protocol == L7Protocol::Rest - && config.websocket_credential_rewrite, + websocket_extensions: websocket_extension_mode(config), }, ) .await?; @@ -641,14 +751,21 @@ where ); return Ok(()); } - RelayOutcome::Upgraded { overflow } => { + RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + } => { + let mut options = upgrade_options( + config, + ctx, + websocket_request, + &redacted_target, + &req.query_params, + Some(engine), + ); + 
options.websocket.permessage_deflate = websocket_permessage_deflate; return handle_upgrade( - client, - upstream, - overflow, - &ctx.host, - ctx.port, - upgrade_options(config, ctx, websocket_request), + client, upstream, overflow, &ctx.host, ctx.port, options, ) .await; } @@ -860,14 +977,19 @@ where ); return Ok(()); } - RelayOutcome::Upgraded { overflow } => { + RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + } => { + let options = UpgradeRelayOptions { + websocket: WebSocketUpgradeBehavior { + permessage_deflate: websocket_permessage_deflate, + ..Default::default() + }, + ..Default::default() + }; return handle_upgrade( - client, - upstream, - overflow, - &ctx.host, - ctx.port, - UpgradeRelayOptions::default(), + client, upstream, overflow, &ctx.host, ctx.port, options, ) .await; } @@ -1109,7 +1231,7 @@ where match outcome { RelayOutcome::Reusable => {} // continue loop RelayOutcome::Consumed => break, - RelayOutcome::Upgraded { overflow } => { + RelayOutcome::Upgraded { overflow, .. 
} => { return handle_upgrade( client, upstream, @@ -1175,6 +1297,60 @@ mod tests { ); } + #[test] + fn websocket_text_policy_requires_explicit_message_rule() { + let data = r#" +network_policies: + ws_api: + name: ws_api + endpoints: + - host: gateway.example.test + port: 443 + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/ws" + binaries: + - { path: /usr/bin/node } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "gateway.example.test".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let generation = engine + .evaluate_network_action_with_generation(&input) + .unwrap() + .1; + let tunnel_engine = engine.clone_engine_for_tunnel(generation).unwrap(); + let ctx = L7EvalContext { + host: "gateway.example.test".into(), + port: 443, + policy_name: "ws_api".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + }; + let request = L7RequestInfo { + action: "WEBSOCKET_TEXT".into(), + target: "/ws".into(), + query_params: std::collections::HashMap::new(), + graphql: None, + }; + + let (allowed, reason) = evaluate_l7_request(&tunnel_engine, &ctx, &request).unwrap(); + + assert!(!allowed); + assert!(reason.contains("WEBSOCKET_TEXT /ws not permitted")); + } + #[tokio::test] async fn l7_relay_closes_keep_alive_tunnel_after_policy_generation_change() { let initial_data = r#" diff --git a/crates/openshell-sandbox/src/l7/rest.rs b/crates/openshell-sandbox/src/l7/rest.rs index 4a82b95c9..f3dfc0f0e 100644 --- a/crates/openshell-sandbox/src/l7/rest.rs +++ b/crates/openshell-sandbox/src/l7/rest.rs @@ -371,17 +371,24 @@ where RelayRequestOptions { resolver, generation_guard, - strip_websocket_extensions: false, + websocket_extensions: WebSocketExtensionMode::Preserve, }, ) .await } +#[derive(Clone, Copy, Debug, 
Default, PartialEq, Eq)] +pub(crate) enum WebSocketExtensionMode { + #[default] + Preserve, + PermessageDeflate, +} + #[derive(Clone, Copy, Default)] pub(crate) struct RelayRequestOptions<'a> { pub(crate) resolver: Option<&'a SecretResolver>, pub(crate) generation_guard: Option<&'a PolicyGenerationGuard>, - pub(crate) strip_websocket_extensions: bool, + pub(crate) websocket_extensions: WebSocketExtensionMode, } pub(crate) async fn relay_http_request_with_options_guarded( @@ -399,18 +406,18 @@ where .windows(4) .position(|w| w == b"\r\n\r\n") .map_or(req.raw_header.len(), |p| p + 4); - let websocket_rewrite_request = if options.strip_websocket_extensions { - validate_websocket_upgrade_request(&req.raw_header[..header_end])? - } else { + let websocket_request = if options.websocket_extensions == WebSocketExtensionMode::Preserve { false - }; - - let header_bytes = if websocket_rewrite_request { - strip_websocket_extensions_if_requested(&req.raw_header[..header_end])? } else { - req.raw_header[..header_end].to_vec() + validate_websocket_upgrade_request(&req.raw_header[..header_end])? 
}; + let (header_bytes, permessage_deflate_offered) = rewrite_websocket_extensions_for_mode( + &req.raw_header[..header_end], + options.websocket_extensions, + websocket_request, + )?; + let rewrite_result = rewrite_http_header_block(&header_bytes, options.resolver) .map_err(|e| miette!("credential injection failed: {e}"))?; @@ -457,7 +464,8 @@ where upstream, client, RelayResponseOptions { - reject_websocket_extensions: websocket_rewrite_request, + websocket_extensions: options.websocket_extensions, + permessage_deflate_offered, }, ) .await?; @@ -510,14 +518,31 @@ pub(crate) fn request_is_websocket_upgrade(raw_header: &[u8]) -> bool { validate_websocket_upgrade_request(&raw_header[..header_end]).unwrap_or(false) } -fn strip_websocket_extensions_if_requested(raw_header: &[u8]) -> Result> { - let header_str = std::str::from_utf8(raw_header) - .map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; - if !validate_websocket_upgrade_request(raw_header)? { - return Ok(raw_header.to_vec()); +fn rewrite_websocket_extensions_for_mode( + raw_header: &[u8], + mode: WebSocketExtensionMode, + websocket_request: bool, +) -> Result<(Vec, bool)> { + if !websocket_request || mode == WebSocketExtensionMode::Preserve { + return Ok((raw_header.to_vec(), false)); + } + match mode { + WebSocketExtensionMode::Preserve => Ok((raw_header.to_vec(), false)), + WebSocketExtensionMode::PermessageDeflate => { + rewrite_websocket_extensions_for_permessage_deflate(raw_header) + } } +} +fn rewrite_websocket_extensions_for_permessage_deflate( + raw_header: &[u8], +) -> Result<(Vec, bool)> { + let header_str = std::str::from_utf8(raw_header) + .map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + let safe_offer = supported_permessage_deflate_offer(header_str); let mut out = Vec::with_capacity(raw_header.len()); + let mut inserted = false; + for line in header_str.split_inclusive("\r\n") { let bare = line.strip_suffix("\r\n").unwrap_or(line); if bare @@ -526,9 +551,74 @@ fn 
strip_websocket_extensions_if_requested(raw_header: &[u8]) -> Result> { continue; } + if bare.is_empty() && !inserted { + if let Some(offer) = safe_offer.as_deref() { + out.extend_from_slice(b"Sec-WebSocket-Extensions: "); + out.extend_from_slice(offer.as_bytes()); + out.extend_from_slice(b"\r\n"); + } + inserted = true; + } out.extend_from_slice(line.as_bytes()); } - Ok(out) + Ok((out, safe_offer.is_some())) +} + +fn supported_permessage_deflate_offer(header_str: &str) -> Option { + for params in websocket_extension_offers(header_str) { + let Some((extension, rest)) = params.split_first() else { + continue; + }; + if !extension.eq_ignore_ascii_case("permessage-deflate") { + continue; + } + let mut client_no_context_takeover = false; + let mut server_no_context_takeover = false; + let mut unsupported = false; + for param in rest { + let (name, value) = param.split_once('=').unwrap_or((param, "")); + if name.eq_ignore_ascii_case("client_no_context_takeover") && value.is_empty() { + client_no_context_takeover = true; + } else if name.eq_ignore_ascii_case("server_no_context_takeover") && value.is_empty() { + server_no_context_takeover = true; + } else { + unsupported = true; + break; + } + } + if client_no_context_takeover && !unsupported { + let mut offer = "permessage-deflate; client_no_context_takeover".to_string(); + if server_no_context_takeover { + offer.push_str("; server_no_context_takeover"); + } + return Some(offer); + } + } + None +} + +fn websocket_extension_offers(header_str: &str) -> Vec> { + let mut offers = Vec::new(); + for line in header_str.lines().skip(1) { + let Some((name, value)) = line.split_once(':') else { + continue; + }; + if !name.trim().eq_ignore_ascii_case("sec-websocket-extensions") { + continue; + } + for extension in value.split(',') { + let params: Vec = extension + .split(';') + .map(str::trim) + .filter(|v| !v.is_empty()) + .map(ToOwned::to_owned) + .collect(); + if !params.is_empty() { + offers.push(params); + } + } + } + offers } 
fn validate_websocket_upgrade_request(raw_header: &[u8]) -> Result { @@ -947,7 +1037,8 @@ fn find_crlf(buf: &[u8], start: usize) -> Option { #[derive(Clone, Copy, Default)] struct RelayResponseOptions { - reject_websocket_extensions: bool, + websocket_extensions: WebSocketExtensionMode, + permessage_deflate_offered: bool, } async fn relay_response( @@ -1008,17 +1099,11 @@ where // from upstream beyond the headers are overflow that belong to the // upgraded protocol and must be forwarded before switching. if status_code == 101 { - if options.reject_websocket_extensions - && header_str.lines().skip(1).any(|line| { - line.split_once(':').is_some_and(|(name, _)| { - name.trim().eq_ignore_ascii_case("sec-websocket-extensions") - }) - }) - { - return Err(miette!( - "upstream negotiated unsupported WebSocket extensions" - )); - } + let websocket_permessage_deflate = validate_websocket_response_extensions( + &header_str, + options.websocket_extensions, + options.permessage_deflate_offered, + )?; client .write_all(&buf[..header_end]) .await @@ -1030,7 +1115,10 @@ where overflow_bytes = overflow.len(), "101 Switching Protocols โ€” signaling protocol upgrade" ); - return Ok(RelayOutcome::Upgraded { overflow }); + return Ok(RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + }); } // Bodiless responses (HEAD, 1xx, 204, 304): forward headers only, skip body @@ -1132,6 +1220,59 @@ fn parse_connection_close(headers: &str) -> bool { false } +fn validate_websocket_response_extensions( + headers: &str, + mode: WebSocketExtensionMode, + permessage_deflate_offered: bool, +) -> Result { + let offers = websocket_extension_offers(headers); + if offers.is_empty() { + return Ok(false); + } + + match mode { + WebSocketExtensionMode::Preserve => Ok(false), + WebSocketExtensionMode::PermessageDeflate => { + if !permessage_deflate_offered { + return Err(miette!( + "upstream negotiated WebSocket compression that was not offered" + )); + } + if offers.len() != 1 { + return 
Err(miette!("upstream negotiated multiple WebSocket extensions")); + } + let params = &offers[0]; + let Some((extension, rest)) = params.split_first() else { + return Ok(false); + }; + if !extension.eq_ignore_ascii_case("permessage-deflate") { + return Err(miette!( + "upstream negotiated unsupported WebSocket extension" + )); + } + let mut client_no_context_takeover = false; + for param in rest { + let (name, value) = param.split_once('=').unwrap_or((param, "")); + if name.eq_ignore_ascii_case("client_no_context_takeover") && value.is_empty() { + client_no_context_takeover = true; + } else if !(name.eq_ignore_ascii_case("server_no_context_takeover") + && value.is_empty()) + { + return Err(miette!( + "upstream negotiated unsupported permessage-deflate parameter" + )); + } + } + if !client_no_context_takeover { + return Err(miette!( + "upstream negotiated permessage-deflate without client_no_context_takeover" + )); + } + Ok(true) + } + } +} + /// Check if the client request headers contain both `Upgrade` and /// `Connection: Upgrade` headers, indicating the client requested a /// protocol upgrade (e.g. WebSocket). @@ -2251,7 +2392,7 @@ mod tests { let outcome = result.expect("relay_response should succeed"); match outcome { - RelayOutcome::Upgraded { overflow } => { + RelayOutcome::Upgraded { overflow, .. } => { assert_eq!( &overflow, b"\x81\x05hello", "overflow should contain WebSocket frame data" @@ -2297,7 +2438,7 @@ mod tests { .expect("relay_response should not deadlock"); match result.expect("should succeed") { - RelayOutcome::Upgraded { overflow } => { + RelayOutcome::Upgraded { overflow, .. 
} => { assert!(overflow.is_empty(), "no overflow expected"); } other => panic!("Expected Upgraded, got {other:?}"), @@ -2440,7 +2581,7 @@ mod tests { &mut proxy_to_client, &mut proxy_to_upstream, RelayRequestOptions { - strip_websocket_extensions: true, + websocket_extensions: WebSocketExtensionMode::PermessageDeflate, ..Default::default() }, ) @@ -2509,14 +2650,14 @@ mod tests { &mut proxy_to_client, &mut proxy_to_upstream, RelayRequestOptions { - strip_websocket_extensions: true, + websocket_extensions: WebSocketExtensionMode::PermessageDeflate, ..Default::default() }, ) .await; let err = result.expect_err("upstream extension negotiation must fail closed"); - assert!(err.to_string().contains("unsupported WebSocket extensions")); + assert!(err.to_string().contains("not offered")); upstream_task.await.expect("upstream task should complete"); drop(proxy_to_client); @@ -2528,6 +2669,81 @@ mod tests { ); } + #[tokio::test] + async fn permessage_deflate_mode_allows_supported_no_context_takeover() { + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let req = L7Request { + action: "GET".to_string(), + target: "/ws".to_string(), + query_params: HashMap::new(), + raw_header: format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover; server_no_context_takeover\r\nSec-WebSocket-Version: 13\r\n\r\n" + ) + .into_bytes(), + body_length: BodyLength::None, + }; + + let upstream_task = tokio::spawn(async move { + let mut buf = vec![0u8; 4096]; + let mut total = 0; + loop { + let n = upstream_side.read(&mut buf[total..]).await.unwrap(); + if n == 0 { + break; + } + total += n; + if buf[..total].windows(4).any(|w| w == b"\r\n\r\n") { + break; + } + } + let forwarded = String::from_utf8_lossy(&buf[..total]).to_ascii_lowercase(); + 
assert!(forwarded.contains( + "sec-websocket-extensions: permessage-deflate; client_no_context_takeover; server_no_context_takeover" + )); + upstream_side + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover; server_no_context_takeover\r\n\r\n", + ) + .await + .unwrap(); + upstream_side.flush().await.unwrap(); + }); + + let outcome = relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + websocket_extensions: WebSocketExtensionMode::PermessageDeflate, + ..Default::default() + }, + ) + .await + .expect("safe permessage-deflate negotiation should pass"); + + assert!( + matches!( + outcome, + RelayOutcome::Upgraded { + websocket_permessage_deflate: true, + .. + } + ), + "safe permessage-deflate must be marked negotiated" + ); + upstream_task.await.expect("upstream task should complete"); + } + + #[test] + fn permessage_deflate_offer_requires_client_no_context_takeover() { + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\nSec-WebSocket-Version: 13\r\n\r\n" + ); + assert!(supported_permessage_deflate_offer(&raw).is_none()); + } + #[tokio::test] async fn relay_request_guard_blocks_stale_generation_before_upstream_write() { let policy_data = "network_policies: {}\n"; @@ -2646,8 +2862,13 @@ mod tests { "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\nSec-WebSocket-Version: 13\r\n\r\n" ); - let stripped = - strip_websocket_extensions_if_requested(raw.as_bytes()).expect("strip should succeed"); + let (stripped, offered) = rewrite_websocket_extensions_for_mode( + raw.as_bytes(), + 
WebSocketExtensionMode::PermessageDeflate, + true, + ) + .expect("strip should succeed"); + assert!(!offered); let stripped = String::from_utf8(stripped).unwrap(); assert!(stripped.contains("Upgrade: websocket\r\n")); @@ -2665,8 +2886,14 @@ mod tests { fn strip_websocket_extensions_leaves_non_websocket_request_unchanged() { let raw = b"GET /api HTTP/1.1\r\nHost: example.com\r\nSec-WebSocket-Extensions: permessage-deflate\r\n\r\n"; - let stripped = strip_websocket_extensions_if_requested(raw).expect("strip should succeed"); + let (stripped, offered) = rewrite_websocket_extensions_for_mode( + raw, + WebSocketExtensionMode::PermessageDeflate, + false, + ) + .expect("strip should succeed"); + assert!(!offered); assert_eq!(stripped, raw); } diff --git a/crates/openshell-sandbox/src/l7/websocket.rs b/crates/openshell-sandbox/src/l7/websocket.rs index e2649f811..758697185 100644 --- a/crates/openshell-sandbox/src/l7/websocket.rs +++ b/crates/openshell-sandbox/src/l7/websocket.rs @@ -1,18 +1,22 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -//! Minimal WebSocket relay for opt-in credential placeholder rewriting. +//! WebSocket relay for opt-in credential placeholder rewriting and message policy. //! //! The relay parses only client-to-server frames. Server-to-client bytes stay -//! raw passthrough so this remains a narrow post-upgrade credential boundary, -//! not a general WebSocket inspection engine. +//! raw passthrough so inspection and rewriting cannot expose response payloads. 
+use crate::l7::relay::{L7EvalContext, evaluate_l7_request}; +use crate::l7::{EnforcementMode, L7RequestInfo}; +use crate::opa::TunnelPolicyEngine; use crate::secrets::SecretResolver; +use flate2::{Compress, Compression, Decompress, FlushCompress, FlushDecompress, Status}; use miette::{IntoDiagnostic, Result, miette}; use openshell_ocsf::{ ActionId, ActivityId, DispositionId, Endpoint, NetworkActivityBuilder, SeverityId, StatusId, ocsf_emit, }; +use std::collections::HashMap; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; const MAX_TEXT_MESSAGE_BYTES: usize = 1024 * 1024; @@ -39,20 +43,40 @@ struct FrameHeader { #[derive(Debug)] enum FragmentState { None, - Text { payload: Vec }, + Text { payload: Vec, compressed: bool }, Binary, } -/// Relay an upgraded WebSocket connection, rewriting credential placeholders -/// in client-to-server UTF-8 text messages. -pub(super) async fn relay_with_credential_rewrite( +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(super) enum WebSocketCompression { + None, + PermessageDeflate, +} + +pub(super) struct InspectionOptions<'a> { + pub(super) engine: &'a TunnelPolicyEngine, + pub(super) ctx: &'a L7EvalContext, + pub(super) enforcement: EnforcementMode, + pub(super) target: String, + pub(super) query_params: HashMap>, +} + +pub(super) struct RelayOptions<'a> { + pub(super) policy_name: &'a str, + pub(super) resolver: Option<&'a SecretResolver>, + pub(super) inspector: Option>, + pub(super) compression: WebSocketCompression, +} + +/// Relay an upgraded WebSocket connection with optional client text inspection, +/// credential rewriting, and strict permessage-deflate handling. 
+pub(super) async fn relay_with_options( client: &mut C, upstream: &mut U, overflow: Vec, host: &str, port: u16, - policy_name: &str, - resolver: &SecretResolver, + options: RelayOptions<'_>, ) -> Result<()> where C: AsyncRead + AsyncWrite + Unpin + Send, @@ -66,14 +90,8 @@ where client_write.flush().await.into_diagnostic()?; } - let client_to_server = relay_client_to_server( - &mut client_read, - &mut upstream_write, - host, - port, - policy_name, - resolver, - ); + let client_to_server = + relay_client_to_server(&mut client_read, &mut upstream_write, host, port, &options); let server_to_client = async { tokio::io::copy(&mut upstream_read, &mut client_write) .await @@ -96,8 +114,7 @@ async fn relay_client_to_server( writer: &mut W, host: &str, port: u16, - policy_name: &str, - resolver: &SecretResolver, + options: &RelayOptions<'_>, ) -> Result<()> where R: AsyncRead + Unpin, @@ -108,7 +125,7 @@ where loop { let Some(frame) = read_frame_header(reader).await.inspect_err(|e| { - emit_protocol_failure(host, port, policy_name, protocol_failure_class(e)); + emit_protocol_failure(host, port, options.policy_name, protocol_failure_class(e)); })? 
else { writer.shutdown().await.into_diagnostic()?; @@ -117,67 +134,88 @@ where if close_seen { let e = miette!("websocket frame received after close frame"); - emit_protocol_failure(host, port, policy_name, protocol_failure_class(&e)); + emit_protocol_failure(host, port, options.policy_name, protocol_failure_class(&e)); return Err(e); } - if let Err(e) = validate_frame_header(&frame, &fragments) { - emit_protocol_failure(host, port, policy_name, protocol_failure_class(&e)); + if let Err(e) = validate_frame_header(&frame, &fragments, options.compression) { + emit_protocol_failure(host, port, options.policy_name, protocol_failure_class(&e)); return Err(e); } match frame.opcode { OPCODE_TEXT => { let payload = read_masked_payload(reader, &frame).await.inspect_err(|e| { - emit_protocol_failure(host, port, policy_name, protocol_failure_class(e)); + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); })?; + let compressed = frame.rsv == 0x40; if frame.fin { relay_text_payload( - writer, - &frame, - payload, - false, - host, - port, - policy_name, - resolver, + writer, &frame, payload, false, compressed, host, port, options, ) .await .inspect_err(|e| { - emit_protocol_failure(host, port, policy_name, protocol_failure_class(e)); + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); })?; } else { - fragments = FragmentState::Text { payload }; + fragments = FragmentState::Text { + payload, + compressed, + }; } } OPCODE_CONTINUATION => match &mut fragments { - FragmentState::Text { payload } => { + FragmentState::Text { + payload, + compressed, + } => { let next = read_masked_payload(reader, &frame).await.inspect_err(|e| { - emit_protocol_failure(host, port, policy_name, protocol_failure_class(e)); + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); })?; if let Err(e) = append_text_fragment(payload, next) { - emit_protocol_failure(host, port, 
policy_name, protocol_failure_class(&e)); + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(&e), + ); return Err(e); } if frame.fin { let complete = std::mem::take(payload); + let was_compressed = *compressed; fragments = FragmentState::None; relay_text_payload( writer, &frame, complete, true, + was_compressed, host, port, - policy_name, - resolver, + options, ) .await .inspect_err(|e| { emit_protocol_failure( host, port, - policy_name, + options.policy_name, protocol_failure_class(e), ); })?; @@ -190,7 +228,7 @@ where emit_protocol_failure( host, port, - policy_name, + options.policy_name, protocol_failure_class(e), ); })?; @@ -201,7 +239,12 @@ where FragmentState::None => { let e = miette!("websocket continuation frame without active fragmented message"); - emit_protocol_failure(host, port, policy_name, protocol_failure_class(&e)); + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(&e), + ); return Err(e); } }, @@ -212,14 +255,24 @@ where copy_raw_frame_payload(reader, writer, &frame) .await .inspect_err(|e| { - emit_protocol_failure(host, port, policy_name, protocol_failure_class(e)); + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); })?; } OPCODE_CLOSE | OPCODE_PING | OPCODE_PONG => { relay_control_frame(reader, writer, &frame) .await .inspect_err(|e| { - emit_protocol_failure(host, port, policy_name, protocol_failure_class(e)); + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); })?; if frame.opcode == OPCODE_CLOSE { close_seen = true; @@ -314,10 +367,14 @@ async fn read_frame_header(reader: &mut R) -> Result Result<()> { - if frame.rsv != 0 { +fn validate_frame_header( + frame: &FrameHeader, + fragments: &FragmentState, + compression: WebSocketCompression, +) -> Result<()> { + if !valid_rsv_bits(frame, fragments, compression) { return Err(miette!( - "websocket frame has RSV bits set; 
compression/extensions are not supported" + "websocket frame has unsupported RSV bits or extension state" )); } if !frame.masked { @@ -365,6 +422,20 @@ fn validate_frame_header(frame: &FrameHeader, fragments: &FragmentState) -> Resu Ok(()) } +fn valid_rsv_bits( + frame: &FrameHeader, + fragments: &FragmentState, + compression: WebSocketCompression, +) -> bool { + if frame.rsv == 0 { + return true; + } + if compression != WebSocketCompression::PermessageDeflate || frame.rsv != 0x40 { + return false; + } + matches!(fragments, FragmentState::None) && matches!(frame.opcode, OPCODE_TEXT | OPCODE_BINARY) +} + async fn read_masked_payload( reader: &mut R, frame: &FrameHeader, @@ -408,18 +479,31 @@ async fn relay_text_payload( frame: &FrameHeader, payload: Vec, force_reframe: bool, + compressed: bool, host: &str, port: u16, - policy_name: &str, - resolver: &SecretResolver, + options: &RelayOptions<'_>, ) -> Result<()> { - let mut text = String::from_utf8(payload) + let message_payload = if compressed { + decompress_permessage_deflate(&payload)? + } else { + payload + }; + let mut text = String::from_utf8(message_payload) .map_err(|_| miette!("websocket text message is not valid UTF-8"))?; - let replacements = resolver - .rewrite_websocket_text_placeholders(&mut text) - .map_err(|_| miette!("websocket credential placeholder resolution failed"))?; + let replacements = if let Some(resolver) = options.resolver { + resolver + .rewrite_websocket_text_placeholders(&mut text) + .map_err(|_| miette!("websocket credential placeholder resolution failed"))? 
+ } else { + 0 + }; + + if let Some(inspector) = options.inspector.as_ref() { + inspect_websocket_text_message(host, port, options.policy_name, inspector)?; + } - if replacements == 0 && !force_reframe { + if replacements == 0 && !force_reframe && !compressed { writer .write_all(&frame.raw_header) .await @@ -435,11 +519,40 @@ async fn relay_text_payload( } if replacements > 0 { - emit_rewrite_event(host, port, policy_name, replacements); + emit_rewrite_event(host, port, options.policy_name, replacements); + } + if compressed { + let compressed_payload = compress_permessage_deflate(text.as_bytes())?; + return write_masked_frame_with_rsv(writer, OPCODE_TEXT, 0x40, &compressed_payload).await; } write_masked_frame(writer, OPCODE_TEXT, text.as_bytes()).await } +fn inspect_websocket_text_message( + host: &str, + port: u16, + policy_name: &str, + inspector: &InspectionOptions<'_>, +) -> Result<()> { + let request_info = L7RequestInfo { + action: "WEBSOCKET_TEXT".to_string(), + target: inspector.target.clone(), + query_params: inspector.query_params.clone(), + graphql: None, + }; + let (allowed, reason) = evaluate_l7_request(inspector.engine, inspector.ctx, &request_info)?; + let decision = match (allowed, inspector.enforcement) { + (true, _) => "allow", + (false, EnforcementMode::Audit) => "audit", + (false, EnforcementMode::Enforce) => "deny", + }; + emit_websocket_l7_event(host, port, policy_name, &request_info, decision, &reason); + if !allowed && inspector.enforcement == EnforcementMode::Enforce { + return Err(miette!("websocket text message denied by policy")); + } + Ok(()) +} + async fn relay_control_frame( reader: &mut R, writer: &mut W, @@ -533,9 +646,18 @@ async fn write_masked_frame( writer: &mut W, opcode: u8, payload: &[u8], +) -> Result<()> { + write_masked_frame_with_rsv(writer, opcode, 0, payload).await +} + +async fn write_masked_frame_with_rsv( + writer: &mut W, + opcode: u8, + rsv: u8, + payload: &[u8], ) -> Result<()> { let mut header = 
Vec::with_capacity(14); - header.push(0x80 | opcode); + header.push(0x80 | rsv | opcode); match payload.len() { 0..=125 => header.push(0x80 | u8::try_from(payload.len()).expect("payload <= 125")), 126..=65_535 => { @@ -562,6 +684,91 @@ async fn write_masked_frame( Ok(()) } +fn decompress_permessage_deflate(payload: &[u8]) -> Result> { + let mut decoder = Decompress::new(false); + let mut input = Vec::with_capacity(payload.len() + 4); + input.extend_from_slice(payload); + input.extend_from_slice(&[0x00, 0x00, 0xff, 0xff]); + let mut out = Vec::with_capacity(payload.len().saturating_mul(2).min(MAX_TEXT_MESSAGE_BYTES)); + let mut input_pos = 0usize; + let mut scratch = [0u8; COPY_BUF_SIZE]; + loop { + let before_in = decoder.total_in(); + let before_out = decoder.total_out(); + let status = decoder + .decompress(&input[input_pos..], &mut scratch, FlushDecompress::Sync) + .map_err(|e| miette!("websocket permessage-deflate decompression failed: {e}"))?; + let read = usize::try_from(decoder.total_in() - before_in) + .map_err(|_| miette!("websocket permessage-deflate input length overflow"))?; + let written = usize::try_from(decoder.total_out() - before_out) + .map_err(|_| miette!("websocket permessage-deflate output length overflow"))?; + input_pos = input_pos + .checked_add(read) + .ok_or_else(|| miette!("websocket permessage-deflate input length overflow"))?; + if out.len().saturating_add(written) > MAX_TEXT_MESSAGE_BYTES { + return Err(miette!( + "websocket text message exceeds {MAX_TEXT_MESSAGE_BYTES} byte limit" + )); + } + out.extend_from_slice(&scratch[..written]); + if matches!(status, Status::StreamEnd) { + break; + } + if input_pos >= input.len() && written < scratch.len() { + break; + } + if read == 0 && written == 0 { + return Err(miette!( + "websocket permessage-deflate decompression did not make progress" + )); + } + } + Ok(out) +} + +fn compress_permessage_deflate(payload: &[u8]) -> Result> { + let mut compressor = Compress::new(Compression::fast(), 
false); + let expansion = payload.len() / 16; + let mut out = Vec::with_capacity(payload.len().saturating_add(expansion).saturating_add(128)); + loop { + let consumed = usize::try_from(compressor.total_in()) + .map_err(|_| miette!("websocket permessage-deflate input length overflow"))?; + if consumed >= payload.len() { + break; + } + let before_in = compressor.total_in(); + let before_out = compressor.total_out(); + let status = compressor + .compress_vec(&payload[consumed..], &mut out, FlushCompress::None) + .map_err(|e| miette!("websocket permessage-deflate compression failed: {e}"))?; + if matches!(status, Status::BufError) + || (compressor.total_in() == before_in && compressor.total_out() == before_out) + { + out.reserve(out.capacity().max(1024)); + } + } + loop { + out.reserve(64); + let before_out = compressor.total_out(); + compressor + .compress_vec(&[], &mut out, FlushCompress::Sync) + .map_err(|e| miette!("websocket permessage-deflate compression failed: {e}"))?; + if out.ends_with(&[0x00, 0x00, 0xff, 0xff]) { + break; + } + if compressor.total_out() == before_out { + out.reserve(out.capacity().max(1024)); + } + } + if !out.ends_with(&[0x00, 0x00, 0xff, 0xff]) { + return Err(miette!( + "websocket permessage-deflate compression missing sync marker" + )); + } + out.truncate(out.len() - 4); + Ok(out) +} + fn new_mask_key() -> [u8; 4] { let bytes = uuid::Uuid::new_v4().into_bytes(); [bytes[0], bytes[1], bytes[2], bytes[3]] @@ -594,6 +801,48 @@ fn emit_rewrite_event(host: &str, port: u16, policy_name: &str, replacements: us ocsf_emit!(event); } +fn emit_websocket_l7_event( + host: &str, + port: u16, + policy_name: &str, + request_info: &L7RequestInfo, + decision: &str, + reason: &str, +) { + let policy_name = if policy_name.is_empty() { + "-" + } else { + policy_name + }; + let (action_id, disposition_id, severity) = match decision { + "deny" => (ActionId::Denied, DispositionId::Blocked, SeverityId::Medium), + "allow" | "audit" => ( + ActionId::Allowed, + 
DispositionId::Allowed, + SeverityId::Informational, + ), + _ => ( + ActionId::Other, + DispositionId::Other, + SeverityId::Informational, + ), + }; + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(action_id) + .disposition(disposition_id) + .severity(severity) + .status(StatusId::Success) + .dst_endpoint(Endpoint::from_domain(host, port)) + .firewall_rule(policy_name, "l7-websocket") + .message(format!( + "WEBSOCKET_L7_REQUEST {decision} {} {host}:{port}{} reason={reason}", + request_info.action, request_info.target + )) + .build(); + ocsf_emit!(event); +} + fn protocol_failure_class(error: &miette::Report) -> &'static str { let msg = error.to_string().to_ascii_lowercase(); if msg.contains("credential") { @@ -653,7 +902,7 @@ mod tests { use crate::secrets::SecretResolver; use tokio::io::{AsyncReadExt, AsyncWriteExt}; - fn resolver() -> (std::collections::HashMap, SecretResolver) { + fn resolver() -> (HashMap, SecretResolver) { let (child_env, resolver) = SecretResolver::from_provider_env( std::iter::once(("DISCORD_BOT_TOKEN".to_string(), "real-token".to_string())).collect(), ); @@ -661,9 +910,13 @@ mod tests { } fn masked_frame(fin: bool, opcode: u8, payload: &[u8]) -> Vec { + masked_frame_with_rsv(fin, opcode, 0, payload) + } + + fn masked_frame_with_rsv(fin: bool, opcode: u8, rsv: u8, payload: &[u8]) -> Vec { let mask_key = [0x37, 0xfa, 0x21, 0x3d]; let mut frame = Vec::new(); - frame.push(if fin { 0x80 | opcode } else { opcode }); + frame.push((if fin { 0x80 } else { 0 }) | rsv | opcode); match payload.len() { 0..=125 => frame.push(0x80 | u8::try_from(payload.len()).expect("payload <= 125")), 126..=65_535 => { @@ -735,13 +988,47 @@ mod tests { client_write.write_all(&input).await.unwrap(); drop(client_write); + let options = RelayOptions { + policy_name: "test-policy", + resolver: Some(&resolver), + inspector: None, + compression: WebSocketCompression::None, + }; + let result = relay_client_to_server( + &mut 
relay_read, + &mut relay_write, + "gateway.example.test", + 443, + &options, + ) + .await; + drop(relay_write); + + let mut output = Vec::new(); + upstream_read.read_to_end(&mut output).await.unwrap(); + result.map(|()| output) + } + + async fn run_client_to_server_compressed(input: Vec) -> Result> { + let (_, resolver) = resolver(); + let (mut client_write, mut relay_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); + let (mut relay_write, mut upstream_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); + + client_write.write_all(&input).await.unwrap(); + drop(client_write); + + let options = RelayOptions { + policy_name: "test-policy", + resolver: Some(&resolver), + inspector: None, + compression: WebSocketCompression::PermessageDeflate, + }; let result = relay_client_to_server( &mut relay_read, &mut relay_write, "gateway.example.test", 443, - "test-policy", - &resolver, + &options, ) .await; drop(relay_write); @@ -753,6 +1040,11 @@ mod tests { fn decode_masked_text_frame(frame: &[u8]) -> String { assert_eq!(frame[0] & 0x0F, OPCODE_TEXT); + assert_ne!(frame[1] & 0x80, 0); + String::from_utf8(decode_masked_payload(frame)).unwrap() + } + + fn decode_masked_payload(frame: &[u8]) -> Vec { assert_ne!(frame[1] & 0x80, 0); let len_code = frame[1] & 0x7F; let (payload_len, mask_offset) = match len_code { @@ -767,7 +1059,14 @@ mod tests { let mask_key: [u8; 4] = frame[mask_offset..mask_offset + 4].try_into().unwrap(); let mut payload = frame[mask_offset + 4..mask_offset + 4 + payload_len].to_vec(); apply_mask(&mut payload, mask_key); - String::from_utf8(payload).unwrap() + payload + } + + fn decode_compressed_masked_text_frame(frame: &[u8]) -> String { + assert_eq!(frame[0] & 0x0F, OPCODE_TEXT); + assert_eq!(frame[0] & 0x40, 0x40); + let payload = decode_masked_payload(frame); + String::from_utf8(decompress_permessage_deflate(&payload).unwrap()).unwrap() } async fn read_one_frame(reader: &mut R) -> Vec { @@ -842,14 +1141,18 @@ mod tests { let (mut 
client_app, mut relay_client) = tokio::io::duplex(4096); let (mut relay_upstream, mut upstream_app) = tokio::io::duplex(4096); let relay = tokio::spawn(async move { - relay_with_credential_rewrite( + relay_with_options( &mut relay_client, &mut relay_upstream, Vec::new(), "gateway.example.test", 443, - "test-policy", - &resolver, + RelayOptions { + policy_name: "test-policy", + resolver: Some(&resolver), + inspector: None, + compression: WebSocketCompression::None, + }, ) .await }); @@ -991,6 +1294,37 @@ mod tests { ); } + #[tokio::test] + async fn compressed_text_rewrites_with_permessage_deflate() { + let (child_env, _) = resolver(); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let payload = format!(r#"{{"token":"{placeholder}"}}"#); + let compressed = compress_permessage_deflate(payload.as_bytes()).unwrap(); + let input = masked_frame_with_rsv(true, OPCODE_TEXT, 0x40, &compressed); + + let output = run_client_to_server_compressed(input) + .await + .expect("compressed text should relay"); + + assert_eq!( + decode_compressed_masked_text_frame(&output), + r#"{"token":"real-token"}"# + ); + } + + #[tokio::test] + async fn compressed_text_rejects_decompressed_oversize_message() { + let payload = vec![b'a'; MAX_TEXT_MESSAGE_BYTES + 1]; + let compressed = compress_permessage_deflate(&payload).unwrap(); + let input = masked_frame_with_rsv(true, OPCODE_TEXT, 0x40, &compressed); + + let err = run_client_to_server_compressed(input) + .await + .expect_err("oversize decompressed text should fail"); + + assert!(err.to_string().contains("exceeds")); + } + #[tokio::test] async fn binary_frame_passes_through_unchanged() { let frame = masked_frame(true, OPCODE_BINARY, &[0, 1, 2, 3, 255]); diff --git a/crates/openshell-sandbox/src/proxy.rs b/crates/openshell-sandbox/src/proxy.rs index 44b9e8c45..73cc649f3 100644 --- a/crates/openshell-sandbox/src/proxy.rs +++ b/crates/openshell-sandbox/src/proxy.rs @@ -2376,7 +2376,7 @@ async fn relay_rewritten_forward_request( 
client: &mut C, upstream: &mut U, generation_guard: &PolicyGenerationGuard, - strip_websocket_extensions: bool, + websocket_extensions: crate::l7::rest::WebSocketExtensionMode, ) -> Result where C: TokioAsyncRead + TokioAsyncWrite + Unpin, @@ -2404,7 +2404,7 @@ where crate::l7::rest::RelayRequestOptions { resolver: None, generation_guard: Some(generation_guard), - strip_websocket_extensions, + websocket_extensions, }, ) .await @@ -2627,7 +2627,7 @@ async fn handle_forward_proxy( }; let mut forward_request_bytes = buf[..used].to_vec(); let mut upstream_target = path.clone(); - let mut strip_websocket_extensions = false; + let mut websocket_extensions = crate::l7::rest::WebSocketExtensionMode::Preserve; let mut upgrade_options = crate::l7::relay::UpgradeRelayOptions::default(); // 4b. If the endpoint has L7 config, evaluate the request against @@ -2771,12 +2771,16 @@ async fn handle_forward_proxy( if l7_config.config.protocol == crate::l7::L7Protocol::Rest && l7_config.config.websocket_credential_rewrite { - strip_websocket_extensions = true; + websocket_extensions = crate::l7::rest::WebSocketExtensionMode::PermessageDeflate; upgrade_options = crate::l7::relay::UpgradeRelayOptions { websocket_request, - websocket_credential_rewrite: true, + websocket: crate::l7::relay::WebSocketUpgradeBehavior { + credential_rewrite: true, + ..Default::default() + }, secret_resolver: secret_resolver.clone(), policy_name: matched_policy.clone().unwrap_or_default(), + ..Default::default() }; } let graphql = if l7_config.config.protocol == crate::l7::L7Protocol::Graphql { @@ -3242,10 +3246,15 @@ async fn handle_forward_proxy( client, &mut upstream, &forward_generation_guard, - strip_websocket_extensions, + websocket_extensions, ) .await?; - if let crate::l7::provider::RelayOutcome::Upgraded { overflow } = outcome { + if let crate::l7::provider::RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + } = outcome + { + upgrade_options.websocket.permessage_deflate = 
websocket_permessage_deflate; crate::l7::relay::handle_upgrade( client, &mut upstream, @@ -4417,7 +4426,7 @@ mod tests { &mut proxy_to_client, &mut proxy_to_upstream, &guard, - false, + crate::l7::rest::WebSocketExtensionMode::Preserve, ) .await; assert!( @@ -4456,7 +4465,7 @@ mod tests { &mut proxy_to_client, &mut proxy_to_upstream, &guard, - false, + crate::l7::rest::WebSocketExtensionMode::Preserve, ) .await; assert!(result.is_err(), "forward relay must reject CL/TE ambiguity"); diff --git a/crates/openshell-sandbox/tests/websocket_upgrade.rs b/crates/openshell-sandbox/tests/websocket_upgrade.rs index e4cd232ce..b35076a9a 100644 --- a/crates/openshell-sandbox/tests/websocket_upgrade.rs +++ b/crates/openshell-sandbox/tests/websocket_upgrade.rs @@ -124,7 +124,7 @@ async fn websocket_upgrade_through_l7_relay_exchanges_message() { .expect("relay should succeed"); match outcome { - RelayOutcome::Upgraded { overflow } => { + RelayOutcome::Upgraded { overflow, .. } => { // This is what handle_upgrade() does in relay.rs if !overflow.is_empty() { client_proxy.write_all(&overflow).await.unwrap(); diff --git a/proto/sandbox.proto b/proto/sandbox.proto index 5ff88f6cf..db1b15448 100644 --- a/proto/sandbox.proto +++ b/proto/sandbox.proto @@ -70,7 +70,7 @@ message NetworkEndpoint { // Single port (backwards compat). Use `ports` for multiple ports. // Mutually exclusive with `ports` โ€” if both are set, `ports` takes precedence. uint32 port = 2; - // Application protocol for L7 inspection: "rest", "graphql", "sql", or "" (L4-only). + // Application protocol for L7 inspection: "rest", "websocket", "graphql", "sql", or "" (L4-only). string protocol = 3; // TLS handling: "terminate" or "passthrough" (default). 
string tls = 4; From bac23a10ef354c5b05638eec030fa657ce5d440a Mon Sep 17 00:00:00 2001 From: Aaron Erickson Date: Fri, 8 May 2026 20:44:18 -0700 Subject: [PATCH 04/17] fix(sandbox): harden websocket upgrade validation Signed-off-by: Aaron Erickson --- Cargo.lock | 1 + crates/openshell-sandbox/Cargo.toml | 1 + crates/openshell-sandbox/src/l7/rest.rs | 384 ++++++++++++++----- crates/openshell-sandbox/src/l7/websocket.rs | 89 ++++- crates/openshell-sandbox/src/proxy.rs | 30 +- 5 files changed, 401 insertions(+), 104 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 63ba19ced..05a1bdff2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3595,6 +3595,7 @@ dependencies = [ "serde", "serde_json", "serde_yml", + "sha1 0.10.6", "sha2 0.10.9", "temp-env", "tempfile", diff --git a/crates/openshell-sandbox/Cargo.toml b/crates/openshell-sandbox/Cargo.toml index 6d8310b26..29919ede4 100644 --- a/crates/openshell-sandbox/Cargo.toml +++ b/crates/openshell-sandbox/Cargo.toml @@ -59,6 +59,7 @@ uuid = { workspace = true } # Encoding base64 = { workspace = true } flate2 = "1" +sha1 = "0.10" # IP network / CIDR parsing ipnet = "2" diff --git a/crates/openshell-sandbox/src/l7/rest.rs b/crates/openshell-sandbox/src/l7/rest.rs index f3dfc0f0e..9b39bd7b9 100644 --- a/crates/openshell-sandbox/src/l7/rest.rs +++ b/crates/openshell-sandbox/src/l7/rest.rs @@ -12,6 +12,7 @@ use crate::opa::PolicyGenerationGuard; use crate::secrets::{SecretResolver, rewrite_http_header_block}; use base64::Engine as _; use miette::{IntoDiagnostic, Result, miette}; +use sha1::{Digest, Sha1}; use std::collections::HashMap; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tracing::debug; @@ -406,17 +407,27 @@ where .windows(4) .position(|w| w == b"\r\n\r\n") .map_or(req.raw_header.len(), |p| p + 4); + let header_str = std::str::from_utf8(&req.raw_header[..header_end]) + .map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + let client_requested_upgrade = 
client_requested_upgrade(header_str); let websocket_request = if options.websocket_extensions == WebSocketExtensionMode::Preserve { - false + None } else { - validate_websocket_upgrade_request(&req.raw_header[..header_end])? + parse_websocket_upgrade_request(&req.raw_header[..header_end])? }; - let (header_bytes, permessage_deflate_offered) = rewrite_websocket_extensions_for_mode( + let (header_bytes, expected_websocket_extension) = rewrite_websocket_extensions_for_mode( &req.raw_header[..header_end], options.websocket_extensions, - websocket_request, + websocket_request.is_some(), )?; + let websocket_response = + websocket_request + .as_ref() + .map(|request| WebSocketResponseValidation { + expected_accept: websocket_accept_for_key(&request.sec_key), + expected_extension: expected_websocket_extension.clone(), + }); let rewrite_result = rewrite_http_header_block(&header_bytes, options.resolver) .map_err(|e| miette!("credential injection failed: {e}"))?; @@ -465,48 +476,12 @@ where client, RelayResponseOptions { websocket_extensions: options.websocket_extensions, - permessage_deflate_offered, + websocket: websocket_response, + client_requested_upgrade, }, ) .await?; - // Validate that the client actually requested an upgrade before accepting - // a 101 from upstream. Per RFC 9110 Section 7.8, the server MUST NOT send - // 101 unless the client sent Upgrade + Connection: Upgrade headers. A - // non-compliant or malicious upstream could send an unsolicited 101 to - // bypass L7 inspection. - if matches!(outcome, RelayOutcome::Upgraded { .. 
}) { - let header_str = String::from_utf8_lossy(&req.raw_header[..header_end]); - if !client_requested_upgrade(&header_str) { - openshell_ocsf::ocsf_emit!( - openshell_ocsf::DetectionFindingBuilder::new(crate::ocsf_ctx()) - .activity(openshell_ocsf::ActivityId::Open) - .action(openshell_ocsf::ActionId::Denied) - .disposition(openshell_ocsf::DispositionId::Blocked) - .severity(openshell_ocsf::SeverityId::High) - .confidence(openshell_ocsf::ConfidenceId::High) - .is_alert(true) - .finding_info( - openshell_ocsf::FindingInfo::new( - "unsolicited-101-upgrade", - "Unsolicited 101 Switching Protocols", - ) - .with_desc(&format!( - "Upstream sent 101 without client Upgrade request for {} {} โ€” \ - possible L7 inspection bypass. Connection closed.", - req.action, req.target, - )), - ) - .message(format!( - "Unsolicited 101 upgrade blocked: {} {}", - req.action, req.target, - )) - .build() - ); - return Ok(RelayOutcome::Consumed); - } - } - Ok(outcome) } @@ -522,12 +497,12 @@ fn rewrite_websocket_extensions_for_mode( raw_header: &[u8], mode: WebSocketExtensionMode, websocket_request: bool, -) -> Result<(Vec, bool)> { +) -> Result<(Vec, Option)> { if !websocket_request || mode == WebSocketExtensionMode::Preserve { - return Ok((raw_header.to_vec(), false)); + return Ok((raw_header.to_vec(), None)); } match mode { - WebSocketExtensionMode::Preserve => Ok((raw_header.to_vec(), false)), + WebSocketExtensionMode::Preserve => Ok((raw_header.to_vec(), None)), WebSocketExtensionMode::PermessageDeflate => { rewrite_websocket_extensions_for_permessage_deflate(raw_header) } @@ -536,7 +511,7 @@ fn rewrite_websocket_extensions_for_mode( fn rewrite_websocket_extensions_for_permessage_deflate( raw_header: &[u8], -) -> Result<(Vec, bool)> { +) -> Result<(Vec, Option)> { let header_str = std::str::from_utf8(raw_header) .map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; let safe_offer = supported_permessage_deflate_offer(header_str); @@ -561,7 +536,7 @@ fn 
rewrite_websocket_extensions_for_permessage_deflate( } out.extend_from_slice(line.as_bytes()); } - Ok((out, safe_offer.is_some())) + Ok((out, safe_offer)) } fn supported_permessage_deflate_offer(header_str: &str) -> Option { @@ -621,12 +596,27 @@ fn websocket_extension_offers(header_str: &str) -> Vec> { offers } +#[derive(Debug, Clone, PartialEq, Eq)] +struct WebSocketUpgradeRequest { + sec_key: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct WebSocketResponseValidation { + expected_accept: String, + expected_extension: Option, +} + fn validate_websocket_upgrade_request(raw_header: &[u8]) -> Result { + parse_websocket_upgrade_request(raw_header).map(|request| request.is_some()) +} + +fn parse_websocket_upgrade_request(raw_header: &[u8]) -> Result> { let header_str = std::str::from_utf8(raw_header) .map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; let mut lines = header_str.lines(); let Some(request_line) = lines.next() else { - return Ok(false); + return Ok(None); }; let method = request_line.split_whitespace().next().unwrap_or_default(); let mut headers = WebSocketUpgradeHeaders::default(); @@ -660,7 +650,7 @@ fn validate_websocket_upgrade_request(raw_header: &[u8]) -> Result { } if !headers.is_attempt() { - return Ok(false); + return Ok(None); } if !method.eq_ignore_ascii_case("GET") { return Err(miette!("websocket upgrade request must use GET")); @@ -694,7 +684,17 @@ fn validate_websocket_upgrade_request(raw_header: &[u8]) -> Result { "websocket upgrade request must use Sec-WebSocket-Version: 13" )); } - Ok(true) + Ok(Some(WebSocketUpgradeRequest { + sec_key: key.to_string(), + })) +} + +fn websocket_accept_for_key(sec_key: &str) -> String { + const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + let mut hasher = Sha1::new(); + hasher.update(sec_key.as_bytes()); + hasher.update(WEBSOCKET_GUID.as_bytes()); + base64::engine::general_purpose::STANDARD.encode(hasher.finalize()) } fn header_value_contains_token(value: &str, 
expected: &str) -> bool { @@ -1035,10 +1035,21 @@ fn find_crlf(buf: &[u8], start: usize) -> Option { .map(|offset| start + offset) } -#[derive(Clone, Copy, Default)] +#[derive(Clone)] struct RelayResponseOptions { websocket_extensions: WebSocketExtensionMode, - permessage_deflate_offered: bool, + client_requested_upgrade: bool, + websocket: Option, +} + +impl Default for RelayResponseOptions { + fn default() -> Self { + Self { + websocket_extensions: WebSocketExtensionMode::Preserve, + client_requested_upgrade: true, + websocket: None, + } + } } async fn relay_response( @@ -1099,10 +1110,13 @@ where // from upstream beyond the headers are overflow that belong to the // upgraded protocol and must be forwarded before switching. if status_code == 101 { - let websocket_permessage_deflate = validate_websocket_response_extensions( + if !options.client_requested_upgrade { + return Ok(RelayOutcome::Consumed); + } + let websocket_permessage_deflate = validate_websocket_response( &header_str, options.websocket_extensions, - options.permessage_deflate_offered, + options.websocket.as_ref(), )?; client .write_all(&buf[..header_end]) @@ -1220,10 +1234,73 @@ fn parse_connection_close(headers: &str) -> bool { false } -fn validate_websocket_response_extensions( +fn validate_websocket_response( + headers: &str, + mode: WebSocketExtensionMode, + websocket: Option<&WebSocketResponseValidation>, +) -> Result { + let Some(validation) = websocket else { + return validate_websocket_response_extensions_preserved(headers, mode); + }; + + let mut upgrade_websocket = false; + let mut connection_upgrade = false; + let mut accept_count = 0usize; + let mut accept_matches = false; + + for line in headers.lines().skip(1) { + let Some((name, value)) = line.split_once(':') else { + continue; + }; + let name = name.trim().to_ascii_lowercase(); + let value = value.trim(); + match name.as_str() { + "upgrade" if header_value_contains_token(value, "websocket") => { + upgrade_websocket = true; + } + 
"connection" if header_value_contains_token(value, "upgrade") => { + connection_upgrade = true; + } + "sec-websocket-accept" => { + accept_count += 1; + accept_matches = value == validation.expected_accept; + } + _ => {} + } + } + + if !upgrade_websocket { + return Err(miette!( + "websocket upgrade response missing Upgrade: websocket" + )); + } + if !connection_upgrade { + return Err(miette!( + "websocket upgrade response missing Connection: Upgrade" + )); + } + if accept_count != 1 || !accept_matches { + return Err(miette!( + "websocket upgrade response has invalid Sec-WebSocket-Accept" + )); + } + + let actual_extension = normalized_websocket_extension(headers)?; + match (&validation.expected_extension, actual_extension.as_deref()) { + (None, Some(_)) => Err(miette!( + "upstream negotiated WebSocket extension that was not offered" + )), + (None | Some(_), None) => Ok(false), + (Some(expected), Some(actual)) if expected.eq_ignore_ascii_case(actual) => Ok(true), + (Some(_), Some(_)) => Err(miette!( + "upstream negotiated WebSocket extension that does not match the safe offer" + )), + } +} + +fn validate_websocket_response_extensions_preserved( headers: &str, mode: WebSocketExtensionMode, - permessage_deflate_offered: bool, ) -> Result { let offers = websocket_extension_offers(headers); if offers.is_empty() { @@ -1232,45 +1309,47 @@ fn validate_websocket_response_extensions( match mode { WebSocketExtensionMode::Preserve => Ok(false), - WebSocketExtensionMode::PermessageDeflate => { - if !permessage_deflate_offered { - return Err(miette!( - "upstream negotiated WebSocket compression that was not offered" - )); - } - if offers.len() != 1 { - return Err(miette!("upstream negotiated multiple WebSocket extensions")); - } - let params = &offers[0]; - let Some((extension, rest)) = params.split_first() else { - return Ok(false); - }; - if !extension.eq_ignore_ascii_case("permessage-deflate") { - return Err(miette!( - "upstream negotiated unsupported WebSocket extension" - 
)); - } - let mut client_no_context_takeover = false; - for param in rest { - let (name, value) = param.split_once('=').unwrap_or((param, "")); - if name.eq_ignore_ascii_case("client_no_context_takeover") && value.is_empty() { - client_no_context_takeover = true; - } else if !(name.eq_ignore_ascii_case("server_no_context_takeover") - && value.is_empty()) - { - return Err(miette!( - "upstream negotiated unsupported permessage-deflate parameter" - )); - } - } - if !client_no_context_takeover { - return Err(miette!( - "upstream negotiated permessage-deflate without client_no_context_takeover" - )); - } - Ok(true) + WebSocketExtensionMode::PermessageDeflate => Err(miette!( + "upstream negotiated WebSocket extension that was not offered" + )), + } +} + +fn normalized_websocket_extension(headers: &str) -> Result> { + let offers = websocket_extension_offers(headers); + if offers.is_empty() { + return Ok(None); + } + if offers.len() != 1 { + return Err(miette!("upstream negotiated multiple WebSocket extensions")); + } + let params = &offers[0]; + let Some((extension, rest)) = params.split_first() else { + return Ok(None); + }; + if !extension.eq_ignore_ascii_case("permessage-deflate") { + return Err(miette!( + "upstream negotiated unsupported WebSocket extension" + )); + } + let mut normalized = String::from("permessage-deflate"); + for param in rest { + let (name, value) = param.split_once('=').unwrap_or((param, "")); + if !value.is_empty() { + return Err(miette!( + "upstream negotiated unsupported permessage-deflate parameter" + )); } + let name = name.to_ascii_lowercase(); + if name != "client_no_context_takeover" && name != "server_no_context_takeover" { + return Err(miette!( + "upstream negotiated unsupported permessage-deflate parameter" + )); + } + normalized.push_str("; "); + normalized.push_str(&name); } + Ok(Some(normalized)) } /// Check if the client request headers contain both `Upgrade` and @@ -1384,6 +1463,7 @@ mod tests { const TEST_POLICY: &str = 
include_str!("../../data/sandbox-policy.rego"); const VALID_WS_KEY: &str = "dGhlIHNhbXBsZSBub25jZQ=="; + const VALID_WS_ACCEPT: &str = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="; #[test] fn deny_response_body_is_agent_readable_and_redacted() { @@ -2638,7 +2718,10 @@ mod tests { ); upstream_side .write_all( - b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Extensions: permessage-deflate\r\n\r\n", + format!( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: {VALID_WS_ACCEPT}\r\nSec-WebSocket-Extensions: permessage-deflate\r\n\r\n" + ) + .as_bytes(), ) .await .unwrap(); @@ -2704,7 +2787,10 @@ mod tests { )); upstream_side .write_all( - b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover; server_no_context_takeover\r\n\r\n", + format!( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: {VALID_WS_ACCEPT}\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover; server_no_context_takeover\r\n\r\n" + ) + .as_bytes(), ) .await .unwrap(); @@ -2736,6 +2822,110 @@ mod tests { upstream_task.await.expect("upstream task should complete"); } + #[tokio::test] + async fn opted_in_websocket_relay_rejects_invalid_accept_before_forwarding_101() { + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let req = L7Request { + action: "GET".to_string(), + target: "/ws".to_string(), + query_params: HashMap::new(), + raw_header: format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Version: 13\r\n\r\n" + ) + .into_bytes(), + body_length: BodyLength::None, + }; + + let upstream_task = tokio::spawn(async move { + let mut buf = vec![0u8; 
4096]; + let mut total = 0; + loop { + let n = upstream_side.read(&mut buf[total..]).await.unwrap(); + if n == 0 { + break; + } + total += n; + if buf[..total].windows(4).any(|w| w == b"\r\n\r\n") { + break; + } + } + upstream_side + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: invalid\r\n\r\n", + ) + .await + .unwrap(); + upstream_side.flush().await.unwrap(); + }); + + let result = relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + websocket_extensions: WebSocketExtensionMode::PermessageDeflate, + ..Default::default() + }, + ) + .await; + + let err = result.expect_err("invalid Sec-WebSocket-Accept must fail closed"); + assert!(err.to_string().contains("Sec-WebSocket-Accept")); + upstream_task.await.expect("upstream task should complete"); + + drop(proxy_to_client); + let mut received = Vec::new(); + app_side.read_to_end(&mut received).await.unwrap(); + assert!( + received.is_empty(), + "invalid websocket response must not forward 101 headers" + ); + } + + #[test] + fn websocket_accept_matches_rfc_6455_sample() { + assert_eq!(websocket_accept_for_key(VALID_WS_KEY), VALID_WS_ACCEPT); + } + + #[test] + fn strict_response_validation_rejects_missing_upgrade_headers() { + let validation = WebSocketResponseValidation { + expected_accept: VALID_WS_ACCEPT.to_string(), + expected_extension: None, + }; + + let err = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n", + WebSocketExtensionMode::PermessageDeflate, + Some(&validation), + ) + .expect_err("missing Upgrade/Connection must fail"); + + assert!(err.to_string().contains("Upgrade: websocket")); + } + + #[test] + fn permessage_deflate_response_must_match_exact_safe_offer() { + let validation = WebSocketResponseValidation { + expected_accept: VALID_WS_ACCEPT.to_string(), + expected_extension: Some( + 
"permessage-deflate; client_no_context_takeover; server_no_context_takeover" + .to_string(), + ), + }; + + let err = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover\r\n\r\n", + WebSocketExtensionMode::PermessageDeflate, + Some(&validation), + ) + .expect_err("extension response must exactly match the safe offer"); + + assert!(err.to_string().contains("safe offer")); + } + #[test] fn permessage_deflate_offer_requires_client_no_context_takeover() { let raw = format!( @@ -2868,7 +3058,7 @@ mod tests { true, ) .expect("strip should succeed"); - assert!(!offered); + assert!(offered.is_none()); let stripped = String::from_utf8(stripped).unwrap(); assert!(stripped.contains("Upgrade: websocket\r\n")); @@ -2893,7 +3083,7 @@ mod tests { ) .expect("strip should succeed"); - assert!(!offered); + assert!(offered.is_none()); assert_eq!(stripped, raw); } diff --git a/crates/openshell-sandbox/src/l7/websocket.rs b/crates/openshell-sandbox/src/l7/websocket.rs index 758697185..0777bfb1e 100644 --- a/crates/openshell-sandbox/src/l7/websocket.rs +++ b/crates/openshell-sandbox/src/l7/websocket.rs @@ -794,13 +794,17 @@ fn emit_rewrite_event(host: &str, port: u16, policy_name: &str, replacements: us .status(StatusId::Success) .dst_endpoint(Endpoint::from_domain(host, port)) .firewall_rule(policy_name, "l7-websocket") - .message(format!( - "WEBSOCKET_CREDENTIAL_REWRITE rewrote client text message [host:{host} port:{port} replacements:{replacements}]" - )) + .message(rewrite_event_message(host, port, replacements)) .build(); ocsf_emit!(event); } +fn rewrite_event_message(host: &str, port: u16, replacements: usize) -> String { + format!( + "WEBSOCKET_CREDENTIAL_REWRITE rewrote client text message [host:{host} port:{port} replacements:{replacements}]" + ) +} + fn emit_websocket_l7_event( host: &str, 
port: u16, @@ -888,14 +892,16 @@ fn emit_protocol_failure(host: &str, port: u16, policy_name: &str, failure_class .status(StatusId::Failure) .dst_endpoint(Endpoint::from_domain(host, port)) .firewall_rule(policy_name, "l7-websocket") - .message(format!( - "WEBSOCKET_CREDENTIAL_REWRITE closed ambiguous client frame [host:{host} port:{port}]" - )) + .message(protocol_failure_message(host, port)) .status_detail(failure_class) .build(); ocsf_emit!(event); } +fn protocol_failure_message(host: &str, port: u16) -> String { + format!("WEBSOCKET_CREDENTIAL_REWRITE closed ambiguous client frame [host:{host} port:{port}]") +} + #[cfg(test)] mod tests { use super::*; @@ -1345,6 +1351,46 @@ mod tests { assert!(err.to_string().contains("reserved opcode")); } + #[tokio::test] + async fn rejects_continuation_without_active_message() { + let err = run_client_to_server(masked_frame(true, OPCODE_CONTINUATION, b"orphan")) + .await + .expect_err("orphan continuation should fail"); + + assert!(err.to_string().contains("continuation")); + } + + #[tokio::test] + async fn rejects_new_data_frame_before_fragment_completion() { + let mut input = masked_frame(false, OPCODE_TEXT, b"partial"); + input.extend(masked_frame(true, OPCODE_TEXT, b"second")); + + let err = run_client_to_server(input) + .await + .expect_err("new data frame during fragmentation should fail"); + + assert!(err.to_string().contains("previous fragmented message")); + } + + #[tokio::test] + async fn rejects_fragmented_control_frame() { + let err = run_client_to_server(masked_frame(false, OPCODE_PING, b"ping")) + .await + .expect_err("fragmented control frame should fail"); + + assert!(err.to_string().contains("control frame is fragmented")); + } + + #[tokio::test] + async fn rejects_control_frame_over_125_bytes() { + let payload = vec![b'a'; 126]; + let err = run_client_to_server(masked_frame(true, OPCODE_PING, &payload)) + .await + .expect_err("oversize control frame should fail"); + + 
assert!(err.to_string().contains("control frame exceeds")); + } + #[tokio::test] async fn rejects_non_minimal_extended_length() { let err = run_client_to_server(masked_frame_with_non_minimal_16_bit_len( @@ -1410,4 +1456,35 @@ mod tests { assert!(err.to_string().contains("valid UTF-8")); } + + #[tokio::test] + async fn rejects_frames_after_client_close_frame() { + let mut input = masked_frame(true, OPCODE_CLOSE, &close_payload(1000, b"done")); + input.extend(masked_frame(true, OPCODE_TEXT, b"late")); + + let err = run_client_to_server(input) + .await + .expect_err("frames after close should fail"); + + assert!(err.to_string().contains("after close")); + } + + #[test] + fn websocket_ocsf_messages_do_not_include_payload_or_secret_material() { + let placeholder = "openshell:resolve:env:DISCORD_BOT_TOKEN"; + let secret = "real-token"; + let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); + + let rewrite = rewrite_event_message("gateway.example.test", 443, 1); + let failure = protocol_failure_message("gateway.example.test", 443); + let messages = [rewrite, failure]; + + for message in messages { + assert!(!message.contains(placeholder)); + assert!(!message.contains(secret)); + assert!(!message.contains(&payload)); + assert!(!message.contains("secret_len")); + assert!(!message.contains("payload_len")); + } + } } diff --git a/crates/openshell-sandbox/src/proxy.rs b/crates/openshell-sandbox/src/proxy.rs index 73cc649f3..94c1f53f5 100644 --- a/crates/openshell-sandbox/src/proxy.rs +++ b/crates/openshell-sandbox/src/proxy.rs @@ -2282,6 +2282,7 @@ fn rewrite_forward_request( .windows(4) .position(|w| w == b"\r\n\r\n") .map_or(used, |p| p + 4); + let websocket_upgrade = crate::l7::rest::request_is_websocket_upgrade(&raw[..header_end]); let header_str = String::from_utf8_lossy(&raw[..header_end]); let lines = header_str.split("\r\n").collect::>(); @@ -2325,6 +2326,11 @@ fn rewrite_forward_request( // Replace Connection header if 
lower.starts_with("connection:") { has_connection = true; + if websocket_upgrade { + output.extend_from_slice(line.as_bytes()); + output.extend_from_slice(b"\r\n"); + continue; + } output.extend_from_slice(b"Connection: close\r\n"); continue; } @@ -2343,7 +2349,7 @@ fn rewrite_forward_request( } // Inject missing headers - if !has_connection { + if !has_connection && !websocket_upgrade { output.extend_from_slice(b"Connection: close\r\n"); } if !has_via { @@ -4403,6 +4409,28 @@ mod tests { assert!(!result_str.contains("openshell:resolve:env:ANTHROPIC_API_KEY")); } + #[test] + fn test_forward_rewrite_preserves_websocket_upgrade_connection_header() { + let raw = "GET http://gateway.example.test/ws HTTP/1.1\r\n\ + Host: gateway.example.test\r\n\ + Upgrade: websocket\r\n\ + Connection: keep-alive, Upgrade\r\n\ + Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ + Sec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover\r\n\ + Sec-WebSocket-Version: 13\r\n\r\n"; + + let result = rewrite_forward_request(raw.as_bytes(), raw.len(), "/ws", None) + .expect("websocket forward rewrite should succeed"); + let result_str = String::from_utf8_lossy(&result); + + assert!(result_str.starts_with("GET /ws HTTP/1.1\r\n")); + assert!(result_str.contains("Connection: keep-alive, Upgrade\r\n")); + assert!( + !result_str.contains("Connection: close\r\n"), + "websocket forward proxy must not strip the upgrade token" + ); + } + #[tokio::test] async fn test_forward_relay_guard_blocks_stale_generation_before_upstream_write() { let policy = include_str!("../data/sandbox-policy.rego"); From 9a6f696bb10cb387e0126b044b47e0a440fc0903 Mon Sep 17 00:00:00 2001 From: Aaron Erickson Date: Fri, 8 May 2026 20:48:16 -0700 Subject: [PATCH 05/17] test(sandbox): cover route-selected websocket upgrades Signed-off-by: Aaron Erickson --- crates/openshell-sandbox/src/l7/relay.rs | 94 ++++++++++++++++++++++++ 1 file changed, 94 insertions(+) diff --git a/crates/openshell-sandbox/src/l7/relay.rs 
b/crates/openshell-sandbox/src/l7/relay.rs index 7af4a82ea..b80dd03c0 100644 --- a/crates/openshell-sandbox/src/l7/relay.rs +++ b/crates/openshell-sandbox/src/l7/relay.rs @@ -1351,6 +1351,100 @@ network_policies: assert!(reason.contains("WEBSOCKET_TEXT /ws not permitted")); } + #[tokio::test] + async fn route_selected_websocket_upgrade_rejects_invalid_accept_without_forwarding_101() { + let data = r#" +network_policies: + route_api: + name: route_api + endpoints: + - host: gateway.example.test + port: 443 + protocol: rest + enforcement: enforce + rules: + - allow: + method: GET + path: "/ws" + binaries: + - { path: /usr/bin/node } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let tunnel_engine = engine + .clone_engine_for_tunnel(engine.current_generation()) + .unwrap(); + let configs = vec![L7EndpointConfig { + protocol: L7Protocol::Rest, + path: "/ws".into(), + tls: crate::l7::TlsMode::Auto, + enforcement: EnforcementMode::Enforce, + graphql_max_body_bytes: 0, + allow_encoded_slash: false, + websocket_credential_rewrite: true, + }]; + let ctx = L7EvalContext { + host: "gateway.example.test".into(), + port: 443, + policy_name: "route_api".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + }; + + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_route_selection( + &configs, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + app.write_all( + b"GET /ws HTTP/1.1\r\nHost: gateway.example.test\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n", + ) + .await + .unwrap(); + + let mut forwarded = [0u8; 1024]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut forwarded), + ) + .await + 
.expect("upgrade request should reach upstream") + .unwrap(); + let forwarded = String::from_utf8_lossy(&forwarded[..n]); + assert!(forwarded.contains("Upgrade: websocket\r\n")); + assert!(forwarded.contains("Connection: Upgrade\r\n")); + + upstream + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: invalid\r\n\r\n", + ) + .await + .unwrap(); + + let err = tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should fail closed on invalid accept") + .unwrap() + .expect_err("invalid accept must fail the route-selected relay"); + assert!(err.to_string().contains("Sec-WebSocket-Accept")); + + let mut response = [0u8; 1]; + let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) + .await + .expect("client side should close without 101") + .unwrap(); + assert_eq!(n, 0, "invalid response must not forward 101 headers"); + } + #[tokio::test] async fn l7_relay_closes_keep_alive_tunnel_after_policy_generation_change() { let initial_data = r#" From 143237f83068f2be3340298466fda7acb4da77db Mon Sep 17 00:00:00 2001 From: Aaron Erickson Date: Fri, 8 May 2026 21:19:05 -0700 Subject: [PATCH 06/17] fix(sandbox): harden websocket negotiation parsing Signed-off-by: Aaron Erickson --- crates/openshell-sandbox/src/l7/rest.rs | 376 +++++++++++++++++++++--- 1 file changed, 329 insertions(+), 47 deletions(-) diff --git a/crates/openshell-sandbox/src/l7/rest.rs b/crates/openshell-sandbox/src/l7/rest.rs index 9b39bd7b9..a550ebe38 100644 --- a/crates/openshell-sandbox/src/l7/rest.rs +++ b/crates/openshell-sandbox/src/l7/rest.rs @@ -13,7 +13,7 @@ use crate::secrets::{SecretResolver, rewrite_http_header_block}; use base64::Engine as _; use miette::{IntoDiagnostic, Result, miette}; use sha1::{Digest, Sha1}; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use 
tracing::debug; @@ -427,6 +427,7 @@ where .map(|request| WebSocketResponseValidation { expected_accept: websocket_accept_for_key(&request.sec_key), expected_extension: expected_websocket_extension.clone(), + offered_subprotocols: request.subprotocols.clone(), }); let rewrite_result = rewrite_http_header_block(&header_bytes, options.resolver) @@ -514,7 +515,7 @@ fn rewrite_websocket_extensions_for_permessage_deflate( ) -> Result<(Vec, Option)> { let header_str = std::str::from_utf8(raw_header) .map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; - let safe_offer = supported_permessage_deflate_offer(header_str); + let safe_offer = supported_permessage_deflate_offer(header_str)?; let mut out = Vec::with_capacity(raw_header.len()); let mut inserted = false; @@ -539,22 +540,24 @@ fn rewrite_websocket_extensions_for_permessage_deflate( Ok((out, safe_offer)) } -fn supported_permessage_deflate_offer(header_str: &str) -> Option { - for params in websocket_extension_offers(header_str) { - let Some((extension, rest)) = params.split_first() else { - continue; - }; - if !extension.eq_ignore_ascii_case("permessage-deflate") { +fn supported_permessage_deflate_offer(header_str: &str) -> Result> { + for offer in websocket_extension_offers(header_str)? 
{ + if !offer.name.eq_ignore_ascii_case("permessage-deflate") { continue; } let mut client_no_context_takeover = false; let mut server_no_context_takeover = false; let mut unsupported = false; - for param in rest { - let (name, value) = param.split_once('=').unwrap_or((param, "")); - if name.eq_ignore_ascii_case("client_no_context_takeover") && value.is_empty() { + let mut seen = HashSet::new(); + for param in &offer.params { + let name = param.name.to_ascii_lowercase(); + if param.value.is_some() || !seen.insert(name.clone()) { + unsupported = true; + break; + } + if name == "client_no_context_takeover" { client_no_context_takeover = true; - } else if name.eq_ignore_ascii_case("server_no_context_takeover") && value.is_empty() { + } else if name == "server_no_context_takeover" { server_no_context_takeover = true; } else { unsupported = true; @@ -566,13 +569,25 @@ fn supported_permessage_deflate_offer(header_str: &str) -> Option { if server_no_context_takeover { offer.push_str("; server_no_context_takeover"); } - return Some(offer); + return Ok(Some(offer)); } } - None + Ok(None) +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct WebSocketExtensionOffer { + name: String, + params: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct WebSocketExtensionParam { + name: String, + value: Option, } -fn websocket_extension_offers(header_str: &str) -> Vec> { +fn websocket_extension_offers(header_str: &str) -> Result> { let mut offers = Vec::new(); for line in header_str.lines().skip(1) { let Some((name, value)) = line.split_once(':') else { @@ -582,29 +597,56 @@ fn websocket_extension_offers(header_str: &str) -> Vec> { continue; } for extension in value.split(',') { - let params: Vec = extension - .split(';') - .map(str::trim) - .filter(|v| !v.is_empty()) - .map(ToOwned::to_owned) - .collect(); - if !params.is_empty() { - offers.push(params); + let mut parts = extension.split(';').map(str::trim); + let Some(extension_name) = parts.next().filter(|name| 
!name.is_empty()) else { + return Err(miette!("invalid WebSocket extension offer")); + }; + if !is_http_token(extension_name) { + return Err(miette!("invalid WebSocket extension token")); + } + let mut params = Vec::new(); + for param in parts { + if param.is_empty() { + return Err(miette!("invalid WebSocket extension parameter")); + } + let (param_name, param_value) = match param.split_once('=') { + Some((name, value)) => { + let value = value.trim(); + if value.is_empty() || value.starts_with('"') || !is_http_token(value) { + return Err(miette!("unsupported WebSocket extension parameter value")); + } + (name.trim(), Some(value.to_string())) + } + None => (param, None), + }; + if param_name.is_empty() || !is_http_token(param_name) { + return Err(miette!("invalid WebSocket extension parameter")); + } + params.push(WebSocketExtensionParam { + name: param_name.to_string(), + value: param_value, + }); } + offers.push(WebSocketExtensionOffer { + name: extension_name.to_string(), + params, + }); } } - offers + Ok(offers) } #[derive(Debug, Clone, PartialEq, Eq)] struct WebSocketUpgradeRequest { sec_key: String, + subprotocols: Vec, } #[derive(Debug, Clone, PartialEq, Eq)] struct WebSocketResponseValidation { expected_accept: String, expected_extension: Option, + offered_subprotocols: Vec, } fn validate_websocket_upgrade_request(raw_header: &[u8]) -> Result { @@ -645,6 +687,9 @@ fn parse_websocket_upgrade_request(raw_header: &[u8]) -> Result { + headers.subprotocols.extend(parse_http_token_list(value)?); + } _ => {} } } @@ -686,6 +731,7 @@ fn parse_websocket_upgrade_request(raw_header: &[u8]) -> Result bool { .any(|token| token.trim().eq_ignore_ascii_case(expected)) } +fn parse_http_token_list(value: &str) -> Result> { + let mut tokens = Vec::new(); + for token in value.split(',') { + let token = token.trim(); + if token.is_empty() || !is_http_token(token) { + return Err(miette!("invalid HTTP token list")); + } + tokens.push(token.to_string()); + } + Ok(tokens) +} + +fn 
is_http_token(value: &str) -> bool { + !value.is_empty() + && value.as_bytes().iter().all(|byte| { + matches!( + byte, + b'!' | b'#' + | b'$' + | b'%' + | b'&' + | b'\'' + | b'*' + | b'+' + | b'-' + | b'.' + | b'^' + | b'_' + | b'`' + | b'|' + | b'~' + | b'0'..=b'9' + | b'A'..=b'Z' + | b'a'..=b'z' + ) + }) +} + #[derive(Default)] struct WebSocketUpgradeHeaders { upgrade_websocket: bool, @@ -711,6 +795,7 @@ struct WebSocketUpgradeHeaders { sec_key_count: usize, version: Option, version_count: usize, + subprotocols: Vec, } impl WebSocketUpgradeHeaders { @@ -1247,6 +1332,8 @@ fn validate_websocket_response( let mut connection_upgrade = false; let mut accept_count = 0usize; let mut accept_matches = false; + let mut subprotocol_count = 0usize; + let mut selected_subprotocol = None; for line in headers.lines().skip(1) { let Some((name, value)) = line.split_once(':') else { @@ -1265,6 +1352,15 @@ fn validate_websocket_response( accept_count += 1; accept_matches = value == validation.expected_accept; } + "sec-websocket-protocol" => { + subprotocol_count += 1; + if !is_http_token(value) { + return Err(miette!( + "websocket upgrade response has invalid Sec-WebSocket-Protocol" + )); + } + selected_subprotocol = Some(value.to_string()); + } _ => {} } } @@ -1284,6 +1380,21 @@ fn validate_websocket_response( "websocket upgrade response has invalid Sec-WebSocket-Accept" )); } + if subprotocol_count > 1 { + return Err(miette!( + "websocket upgrade response has multiple Sec-WebSocket-Protocol headers" + )); + } + if let Some(protocol) = selected_subprotocol + && !validation + .offered_subprotocols + .iter() + .any(|offered| offered == &protocol) + { + return Err(miette!( + "upstream selected WebSocket subprotocol that was not offered" + )); + } let actual_extension = normalized_websocket_extension(headers)?; match (&validation.expected_extension, actual_extension.as_deref()) { @@ -1302,52 +1413,61 @@ fn validate_websocket_response_extensions_preserved( headers: &str, mode: 
WebSocketExtensionMode, ) -> Result { - let offers = websocket_extension_offers(headers); - if offers.is_empty() { - return Ok(false); - } - match mode { WebSocketExtensionMode::Preserve => Ok(false), - WebSocketExtensionMode::PermessageDeflate => Err(miette!( - "upstream negotiated WebSocket extension that was not offered" - )), + WebSocketExtensionMode::PermessageDeflate => { + let offers = websocket_extension_offers(headers)?; + if offers.is_empty() { + Ok(false) + } else { + Err(miette!( + "upstream negotiated WebSocket extension that was not offered" + )) + } + } } } fn normalized_websocket_extension(headers: &str) -> Result> { - let offers = websocket_extension_offers(headers); + let offers = websocket_extension_offers(headers)?; if offers.is_empty() { return Ok(None); } if offers.len() != 1 { return Err(miette!("upstream negotiated multiple WebSocket extensions")); } - let params = &offers[0]; - let Some((extension, rest)) = params.split_first() else { - return Ok(None); - }; - if !extension.eq_ignore_ascii_case("permessage-deflate") { + let offer = &offers[0]; + if !offer.name.eq_ignore_ascii_case("permessage-deflate") { return Err(miette!( "upstream negotiated unsupported WebSocket extension" )); } - let mut normalized = String::from("permessage-deflate"); - for param in rest { - let (name, value) = param.split_once('=').unwrap_or((param, "")); - if !value.is_empty() { + let mut client_no_context_takeover = false; + let mut server_no_context_takeover = false; + let mut seen = HashSet::new(); + for param in &offer.params { + let name = param.name.to_ascii_lowercase(); + if param.value.is_some() || !seen.insert(name.clone()) { return Err(miette!( "upstream negotiated unsupported permessage-deflate parameter" )); } - let name = name.to_ascii_lowercase(); - if name != "client_no_context_takeover" && name != "server_no_context_takeover" { + if name == "client_no_context_takeover" { + client_no_context_takeover = true; + } else if name == 
"server_no_context_takeover" { + server_no_context_takeover = true; + } else { return Err(miette!( "upstream negotiated unsupported permessage-deflate parameter" )); } - normalized.push_str("; "); - normalized.push_str(&name); + } + let mut normalized = String::from("permessage-deflate"); + if client_no_context_takeover { + normalized.push_str("; client_no_context_takeover"); + } + if server_no_context_takeover { + normalized.push_str("; server_no_context_takeover"); } Ok(Some(normalized)) } @@ -2894,6 +3014,7 @@ mod tests { let validation = WebSocketResponseValidation { expected_accept: VALID_WS_ACCEPT.to_string(), expected_extension: None, + offered_subprotocols: Vec::new(), }; let err = validate_websocket_response( @@ -2914,6 +3035,7 @@ mod tests { "permessage-deflate; client_no_context_takeover; server_no_context_takeover" .to_string(), ), + offered_subprotocols: Vec::new(), }; let err = validate_websocket_response( @@ -2931,7 +3053,167 @@ mod tests { let raw = format!( "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\nSec-WebSocket-Version: 13\r\n\r\n" ); - assert!(supported_permessage_deflate_offer(&raw).is_none()); + assert!( + supported_permessage_deflate_offer(&raw) + .expect("valid unsupported extension offer should parse") + .is_none() + ); + } + + #[test] + fn permessage_deflate_offer_canonicalizes_safe_params() { + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\nSec-WebSocket-Version: 13\r\n\r\n" + ); + + assert_eq!( + supported_permessage_deflate_offer(&raw) + .expect("safe extension offer should parse") + .as_deref(), + Some("permessage-deflate; client_no_context_takeover; server_no_context_takeover") + ); + } 
+ + #[test] + fn permessage_deflate_offer_rejects_duplicate_safe_params() { + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover; client_no_context_takeover\r\nSec-WebSocket-Version: 13\r\n\r\n" + ); + + assert!( + supported_permessage_deflate_offer(&raw) + .expect("duplicate safe param should parse but not be supported") + .is_none() + ); + } + + #[test] + fn permessage_deflate_offer_rejects_quoted_values() { + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover=\"true\"\r\nSec-WebSocket-Version: 13\r\n\r\n" + ); + + let err = supported_permessage_deflate_offer(&raw) + .expect_err("quoted permessage-deflate parameter values should fail closed"); + assert!(err.to_string().contains("parameter value")); + } + + #[test] + fn permessage_deflate_response_accepts_reordered_safe_params() { + let validation = WebSocketResponseValidation { + expected_accept: VALID_WS_ACCEPT.to_string(), + expected_extension: Some( + "permessage-deflate; client_no_context_takeover; server_no_context_takeover" + .to_string(), + ), + offered_subprotocols: Vec::new(), + }; + + let negotiated = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nSec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n\r\n", + WebSocketExtensionMode::PermessageDeflate, + Some(&validation), + ) + .expect("reordered safe extension params should canonicalize"); + + assert!(negotiated); + } + + #[test] + fn permessage_deflate_response_rejects_duplicate_safe_params() { + let validation = WebSocketResponseValidation { + 
expected_accept: VALID_WS_ACCEPT.to_string(), + expected_extension: Some("permessage-deflate; client_no_context_takeover".to_string()), + offered_subprotocols: Vec::new(), + }; + + let err = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover; client_no_context_takeover\r\n\r\n", + WebSocketExtensionMode::PermessageDeflate, + Some(&validation), + ) + .expect_err("duplicate extension params should fail closed"); + + assert!(err.to_string().contains("unsupported permessage-deflate")); + } + + #[test] + fn preserve_mode_leaves_malformed_extension_response_raw() { + let negotiated = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover=\"true\"\r\n\r\n", + WebSocketExtensionMode::Preserve, + None, + ) + .expect("preserve mode should not parse or reject raw extension negotiation"); + + assert!(!negotiated); + } + + #[test] + fn parse_websocket_upgrade_request_tracks_subprotocols() { + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Protocol: chat, superchat\r\nSec-WebSocket-Version: 13\r\n\r\n" + ); + + let request = parse_websocket_upgrade_request(raw.as_bytes()) + .expect("request should parse") + .expect("request should be websocket"); + + assert_eq!(request.subprotocols, ["chat", "superchat"]); + } + + #[test] + fn strict_response_validation_allows_offered_subprotocol() { + let validation = WebSocketResponseValidation { + expected_accept: VALID_WS_ACCEPT.to_string(), + expected_extension: None, + offered_subprotocols: vec!["chat".to_string(), "superchat".to_string()], + }; + + let negotiated = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: 
websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nSec-WebSocket-Protocol: superchat\r\n\r\n", + WebSocketExtensionMode::PermessageDeflate, + Some(&validation), + ) + .expect("offered subprotocol should validate"); + + assert!(!negotiated); + } + + #[test] + fn strict_response_validation_rejects_unoffered_subprotocol() { + let validation = WebSocketResponseValidation { + expected_accept: VALID_WS_ACCEPT.to_string(), + expected_extension: None, + offered_subprotocols: vec!["chat".to_string()], + }; + + let err = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nSec-WebSocket-Protocol: admin\r\n\r\n", + WebSocketExtensionMode::PermessageDeflate, + Some(&validation), + ) + .expect_err("unoffered subprotocol should fail closed"); + + assert!(err.to_string().contains("subprotocol")); + } + + #[test] + fn strict_response_validation_rejects_multiple_subprotocol_headers() { + let validation = WebSocketResponseValidation { + expected_accept: VALID_WS_ACCEPT.to_string(), + expected_extension: None, + offered_subprotocols: vec!["chat".to_string(), "superchat".to_string()], + }; + + let err = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nSec-WebSocket-Protocol: chat\r\nSec-WebSocket-Protocol: superchat\r\n\r\n", + WebSocketExtensionMode::PermessageDeflate, + Some(&validation), + ) + .expect_err("multiple selected subprotocols should fail closed"); + + assert!(err.to_string().contains("Sec-WebSocket-Protocol")); } #[tokio::test] From 7bff1c43273ee5237d928c28b3387c340b1c91fb Mon Sep 17 00:00:00 2001 From: Aaron Erickson Date: Fri, 8 May 2026 21:31:56 -0700 Subject: [PATCH 07/17] test(sandbox): add websocket conformance relay matrix Signed-off-by: Aaron Erickson --- crates/openshell-sandbox/src/l7/rest.rs | 
369 ++++++++++++++++++++++++ 1 file changed, 369 insertions(+) diff --git a/crates/openshell-sandbox/src/l7/rest.rs b/crates/openshell-sandbox/src/l7/rest.rs index a550ebe38..e8e8e5b52 100644 --- a/crates/openshell-sandbox/src/l7/rest.rs +++ b/crates/openshell-sandbox/src/l7/rest.rs @@ -1580,10 +1580,285 @@ mod tests { use super::*; use crate::opa::OpaEngine; use crate::secrets::SecretResolver; + use flate2::{Compress, Compression, Decompress, FlushCompress, FlushDecompress, Status}; + use std::sync::Arc; const TEST_POLICY: &str = include_str!("../../data/sandbox-policy.rego"); const VALID_WS_KEY: &str = "dGhlIHNhbXBsZSBub25jZQ=="; const VALID_WS_ACCEPT: &str = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="; + const TEXT_OPCODE: u8 = 0x1; + + #[derive(Debug)] + struct CapturedFrame { + fin_opcode: u8, + masked: bool, + payload: Vec, + } + + async fn read_http_header_block(reader: &mut R) -> Vec { + tokio::time::timeout(std::time::Duration::from_secs(2), async { + let mut header = Vec::new(); + let mut byte = [0u8; 1]; + loop { + reader.read_exact(&mut byte).await.unwrap(); + header.push(byte[0]); + if header.ends_with(b"\r\n\r\n") { + break; + } + } + header + }) + .await + .expect("HTTP header block should arrive") + } + + async fn read_websocket_frame(reader: &mut R) -> CapturedFrame { + tokio::time::timeout(std::time::Duration::from_secs(2), async { + let mut prefix = [0u8; 2]; + reader.read_exact(&mut prefix).await.unwrap(); + let masked = prefix[1] & 0x80 != 0; + let mut payload_len = u64::from(prefix[1] & 0x7f); + if payload_len == 126 { + let mut extended = [0u8; 2]; + reader.read_exact(&mut extended).await.unwrap(); + payload_len = u64::from(u16::from_be_bytes(extended)); + } else if payload_len == 127 { + let mut extended = [0u8; 8]; + reader.read_exact(&mut extended).await.unwrap(); + payload_len = u64::from_be_bytes(extended); + } + let mut mask_key = [0u8; 4]; + if masked { + reader.read_exact(&mut mask_key).await.unwrap(); + } + let payload_len = 
usize::try_from(payload_len).unwrap(); + let mut payload = vec![0u8; payload_len]; + reader.read_exact(&mut payload).await.unwrap(); + if masked { + apply_test_mask(&mut payload, mask_key); + } + CapturedFrame { + fin_opcode: prefix[0], + masked, + payload, + } + }) + .await + .expect("WebSocket frame should arrive") + } + + fn masked_frame_with_rsv(opcode: u8, rsv: u8, payload: &[u8]) -> Vec { + let mask_key = [0x37, 0xfa, 0x21, 0x3d]; + let mut frame = Vec::new(); + frame.push(0x80 | rsv | opcode); + write_test_payload_len(&mut frame, 0x80, payload.len()); + frame.extend_from_slice(&mask_key); + let mut masked = payload.to_vec(); + apply_test_mask(&mut masked, mask_key); + frame.extend_from_slice(&masked); + frame + } + + fn unmasked_frame(opcode: u8, payload: &[u8]) -> Vec { + let mut frame = Vec::new(); + frame.push(0x80 | opcode); + write_test_payload_len(&mut frame, 0, payload.len()); + frame.extend_from_slice(payload); + frame + } + + fn write_test_payload_len(frame: &mut Vec, mask_bit: u8, payload_len: usize) { + if payload_len < 126 { + frame.push(mask_bit | payload_len as u8); + } else if u16::try_from(payload_len).is_ok() { + frame.push(mask_bit | 0x7e); + frame.extend_from_slice(&(payload_len as u16).to_be_bytes()); + } else { + frame.push(mask_bit | 0x7f); + frame.extend_from_slice(&(payload_len as u64).to_be_bytes()); + } + } + + fn apply_test_mask(payload: &mut [u8], mask_key: [u8; 4]) { + for (index, byte) in payload.iter_mut().enumerate() { + *byte ^= mask_key[index % 4]; + } + } + + fn compress_test_permessage_deflate(payload: &[u8]) -> Vec { + let mut compressor = Compress::new(Compression::fast(), false); + let mut out = Vec::with_capacity(payload.len().saturating_add(128)); + loop { + let consumed = usize::try_from(compressor.total_in()).unwrap(); + if consumed >= payload.len() { + break; + } + let before_in = compressor.total_in(); + let before_out = compressor.total_out(); + let status = compressor + .compress_vec(&payload[consumed..], &mut 
out, FlushCompress::None) + .unwrap(); + if matches!(status, Status::BufError) + || (compressor.total_in() == before_in && compressor.total_out() == before_out) + { + out.reserve(out.capacity().max(1024)); + } + } + loop { + out.reserve(64); + let before_out = compressor.total_out(); + compressor + .compress_vec(&[], &mut out, FlushCompress::Sync) + .unwrap(); + if out.ends_with(&[0x00, 0x00, 0xff, 0xff]) { + break; + } + if compressor.total_out() == before_out { + out.reserve(out.capacity().max(1024)); + } + } + out.truncate(out.len() - 4); + out + } + + fn decompress_test_permessage_deflate(payload: &[u8]) -> Vec { + let mut decoder = Decompress::new(false); + let mut input = Vec::with_capacity(payload.len() + 4); + input.extend_from_slice(payload); + input.extend_from_slice(&[0x00, 0x00, 0xff, 0xff]); + let mut out = Vec::new(); + let mut input_pos = 0usize; + let mut scratch = [0u8; RELAY_BUF_SIZE]; + loop { + let before_in = decoder.total_in(); + let before_out = decoder.total_out(); + let status = decoder + .decompress(&input[input_pos..], &mut scratch, FlushDecompress::Sync) + .unwrap(); + let read = usize::try_from(decoder.total_in() - before_in).unwrap(); + let written = usize::try_from(decoder.total_out() - before_out).unwrap(); + input_pos += read; + out.extend_from_slice(&scratch[..written]); + if matches!(status, Status::StreamEnd) { + break; + } + if input_pos >= input.len() && written < scratch.len() { + break; + } + assert!( + read != 0 || written != 0, + "test permessage-deflate decompression did not make progress" + ); + } + out + } + + fn websocket_request(extension: Option<&str>) -> L7Request { + let mut raw_header = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\n" + ); + if let Some(extension) = extension { + raw_header.push_str("Sec-WebSocket-Extensions: "); + raw_header.push_str(extension); + raw_header.push_str("\r\n"); + } + 
raw_header.push_str("Sec-WebSocket-Version: 13\r\n\r\n"); + L7Request { + action: "GET".to_string(), + target: "/ws".to_string(), + query_params: HashMap::new(), + raw_header: raw_header.into_bytes(), + body_length: BodyLength::None, + } + } + + async fn run_upgraded_websocket_case( + request_extension: Option<&'static str>, + response_extension: Option<&'static str>, + extension_mode: WebSocketExtensionMode, + resolver: Option>, + client_frame: Vec, + ) -> (String, CapturedFrame) { + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(16384); + let (mut client_app, mut proxy_to_client) = tokio::io::duplex(16384); + let req = websocket_request(request_extension); + let resolver_for_header = resolver.clone(); + let resolver_for_upgrade = resolver.clone(); + + let upstream_task = tokio::spawn(async move { + let forwarded = read_http_header_block(&mut upstream_side).await; + let forwarded = String::from_utf8(forwarded).unwrap(); + let mut response = format!( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: {VALID_WS_ACCEPT}\r\n" + ); + if let Some(extension) = response_extension { + response.push_str("Sec-WebSocket-Extensions: "); + response.push_str(extension); + response.push_str("\r\n"); + } + response.push_str("\r\n"); + upstream_side.write_all(response.as_bytes()).await.unwrap(); + upstream_side.flush().await.unwrap(); + let frame = read_websocket_frame(&mut upstream_side).await; + (forwarded, frame) + }); + + let relay_task = tokio::spawn(async move { + let outcome = relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + resolver: resolver_for_header.as_deref(), + websocket_extensions: extension_mode, + ..Default::default() + }, + ) + .await + .expect("handshake relay should succeed"); + let RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + } = outcome + else { + panic!("expected upgraded relay 
outcome"); + }; + let credential_rewrite = resolver_for_upgrade.is_some(); + crate::l7::relay::handle_upgrade( + &mut proxy_to_client, + &mut proxy_to_upstream, + overflow, + "example.com", + 443, + crate::l7::relay::UpgradeRelayOptions { + websocket_request: true, + websocket: crate::l7::relay::WebSocketUpgradeBehavior { + credential_rewrite, + permessage_deflate: websocket_permessage_deflate, + ..Default::default() + }, + secret_resolver: resolver_for_upgrade, + target: "/ws".to_string(), + policy_name: "test-policy".to_string(), + ..Default::default() + }, + ) + .await + }); + + let response = read_http_header_block(&mut client_app).await; + assert!( + String::from_utf8_lossy(&response).contains("101 Switching Protocols"), + "client must receive the upgrade before frame relay starts" + ); + client_app.write_all(&client_frame).await.unwrap(); + client_app.flush().await.unwrap(); + + let result = upstream_task.await.expect("upstream task should complete"); + drop(client_app); + let _ = tokio::time::timeout(std::time::Duration::from_secs(2), relay_task).await; + result + } #[test] fn deny_response_body_is_agent_readable_and_redacted() { @@ -2942,6 +3217,100 @@ mod tests { upstream_task.await.expect("upstream task should complete"); } + #[tokio::test] + async fn websocket_conformance_preserve_mode_relays_raw_frames_without_validation() { + let (forwarded, frame) = run_upgraded_websocket_case( + None, + None, + WebSocketExtensionMode::Preserve, + None, + unmasked_frame(TEXT_OPCODE, b"raw-unmasked"), + ) + .await; + + assert!( + forwarded.contains("Upgrade: websocket"), + "raw preserve path should still forward the upgrade request" + ); + assert!( + !frame.masked, + "raw preserve path must not validate or rewrite client frame masking" + ); + assert_eq!(frame.fin_opcode & 0x0f, TEXT_OPCODE); + assert_eq!(frame.payload, b"raw-unmasked"); + } + + #[tokio::test] + async fn websocket_conformance_rewrite_mode_rewrites_text_after_upgrade() { + let (child_env, resolver) = 
SecretResolver::from_provider_env( + [("DISCORD_BOT_TOKEN".to_string(), "real-token".to_string())] + .into_iter() + .collect(), + ); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); + + let (forwarded, frame) = run_upgraded_websocket_case( + None, + None, + WebSocketExtensionMode::PermessageDeflate, + resolver.map(Arc::new), + masked_frame_with_rsv(TEXT_OPCODE, 0, payload.as_bytes()), + ) + .await; + + assert!( + !forwarded + .to_ascii_lowercase() + .contains("sec-websocket-extensions"), + "plain rewrite path should not offer compression when the client did not offer a safe subset" + ); + assert!(frame.masked, "parsed relay must preserve client masking"); + assert_eq!(frame.fin_opcode & 0x0f, TEXT_OPCODE); + assert_eq!( + String::from_utf8(frame.payload).unwrap(), + r#"{"op":2,"d":{"token":"real-token"}}"# + ); + } + + #[tokio::test] + async fn websocket_conformance_deflate_rewrites_compressed_text_after_upgrade() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [("DISCORD_BOT_TOKEN".to_string(), "real-token".to_string())] + .into_iter() + .collect(), + ); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); + let compressed = compress_test_permessage_deflate(payload.as_bytes()); + + let (forwarded, frame) = run_upgraded_websocket_case( + Some("permessage-deflate; server_no_context_takeover; client_no_context_takeover"), + Some("permessage-deflate; server_no_context_takeover; client_no_context_takeover"), + WebSocketExtensionMode::PermessageDeflate, + resolver.map(Arc::new), + masked_frame_with_rsv(TEXT_OPCODE, 0x40, &compressed), + ) + .await; + + assert!( + forwarded.to_ascii_lowercase().contains( + "sec-websocket-extensions: permessage-deflate; client_no_context_takeover; server_no_context_takeover" + ), + "safe extension offer should be canonicalized before forwarding" 
+ ); + assert!(frame.masked, "parsed relay must preserve client masking"); + assert_eq!(frame.fin_opcode & 0x0f, TEXT_OPCODE); + assert!( + frame.fin_opcode & 0x40 != 0, + "rewritten compressed text must retain RSV1" + ); + assert_eq!( + String::from_utf8(decompress_test_permessage_deflate(&frame.payload)).unwrap(), + r#"{"op":2,"d":{"token":"real-token"}}"# + ); + } + #[tokio::test] async fn opted_in_websocket_relay_rejects_invalid_accept_before_forwarding_101() { let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); From 2f97ff3cc1c72687b90e4024973d262e3fa25673 Mon Sep 17 00:00:00 2001 From: Aaron Erickson Date: Fri, 8 May 2026 22:04:44 -0700 Subject: [PATCH 08/17] test(e2e): add websocket conformance lane Signed-off-by: Aaron Erickson --- .github/workflows/websocket-conformance.yml | 65 ++++ docs/reference/policy-schema.mdx | 60 +++- docs/security/best-practices.mdx | 4 +- e2e/rust/Cargo.toml | 5 + e2e/rust/tests/websocket_conformance.rs | 376 ++++++++++++++++++++ tasks/test.toml | 7 + 6 files changed, 508 insertions(+), 9 deletions(-) create mode 100644 .github/workflows/websocket-conformance.yml create mode 100644 e2e/rust/tests/websocket_conformance.rs diff --git a/.github/workflows/websocket-conformance.yml b/.github/workflows/websocket-conformance.yml new file mode 100644 index 000000000..c1c689c54 --- /dev/null +++ b/.github/workflows/websocket-conformance.yml @@ -0,0 +1,65 @@ +name: WebSocket Conformance + +on: + workflow_dispatch: {} + # Add `schedule:` here after this focused lane has burned in manually. 
+ +permissions: {} + +jobs: + build-gateway: + permissions: + contents: read + packages: write + uses: ./.github/workflows/docker-build.yml + with: + component: gateway + platform: linux/amd64 + + build-supervisor: + permissions: + contents: read + packages: write + uses: ./.github/workflows/docker-build.yml + with: + component: supervisor + platform: linux/amd64 + + websocket-conformance: + name: WebSocket Conformance + needs: [build-gateway, build-supervisor] + runs-on: linux-amd64-cpu8 + timeout-minutes: 30 + permissions: + contents: read + packages: read + container: + image: ghcr.io/nvidia/openshell/ci:latest + credentials: + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + options: --privileged + volumes: + - /var/run/docker.sock:/var/run/docker.sock + - /home/runner/_work:/home/runner/_work + env: + MISE_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + IMAGE_TAG: ${{ github.sha }} + OPENSHELL_REGISTRY: ghcr.io/nvidia/openshell + OPENSHELL_REGISTRY_HOST: ghcr.io + OPENSHELL_REGISTRY_NAMESPACE: nvidia/openshell + OPENSHELL_REGISTRY_USERNAME: ${{ github.actor }} + OPENSHELL_REGISTRY_PASSWORD: ${{ secrets.GITHUB_TOKEN }} + steps: + - uses: actions/checkout@v6 + + - name: Install OS test dependencies + run: apt-get update && apt-get install -y openssh-client && rm -rf /var/lib/apt/lists/* + + - name: Log in to GHCR + run: echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u "${{ github.actor }}" --password-stdin + + - name: Run WebSocket conformance e2e + env: + OPENSHELL_SUPERVISOR_IMAGE: ${{ format('ghcr.io/nvidia/openshell/supervisor:{0}', github.sha) }} + run: mise run --no-deps --skip-deps e2e:websocket-conformance diff --git a/docs/reference/policy-schema.mdx b/docs/reference/policy-schema.mdx index 35967e802..3d51a3da2 100644 --- a/docs/reference/policy-schema.mdx +++ b/docs/reference/policy-schema.mdx @@ -155,7 +155,7 @@ Each endpoint defines a reachable destination and optional inspection rules. 
| `host` | string | Yes | Hostname or IP address. Supports wildcards: `*.example.com` matches any subdomain. | | `port` | integer | Yes | TCP port number. | | `path` | string | No | Optional HTTP path glob used to select between L7 endpoints that share the same host and port. Empty means all paths. Use this when REST and GraphQL live under the same host, such as `/repos/**` and `/graphql`. | -| `protocol` | string | No | Set to `rest` for HTTP method/path inspection or `graphql` for GraphQL operation inspection. Omit for TCP passthrough. | +| `protocol` | string | No | Set to `rest` for HTTP method/path inspection, `websocket` for RFC 6455 upgrade and client text-message inspection, or `graphql` for GraphQL operation inspection. Omit for TCP passthrough. | | `tls` | string | No | TLS handling mode. The proxy auto-detects TLS by peeking the first bytes of each connection and terminates it for inspected HTTPS traffic, so this field is optional in most cases. Set to `skip` to disable auto-detection for edge cases such as client-certificate mTLS or non-standard protocols. The values `terminate` and `passthrough` are deprecated and log a warning; they are still accepted for backward compatibility but have no effect on behavior. | | `enforcement` | string | No | `enforce` actively blocks disallowed requests. `audit` logs violations but allows traffic through. | | `access` | string | No | Access preset. One of `read-only`, `read-write`, or `full`. Mutually exclusive with `rules`. | @@ -163,7 +163,7 @@ Each endpoint defines a reachable destination and optional inspection rules. | `deny_rules` | list of deny rule objects | No | L7 deny rules that block specific requests even when allowed by `access` or `rules`. Deny rules take precedence over allow rules. | | `allowed_ips` | list of string | No | CIDR or IP allowlist for SSRF override. Entries overlapping loopback (`127.0.0.0/8`), link-local (`169.254.0.0/16`), or unspecified (`0.0.0.0`) are rejected at load time. 
| | `allow_encoded_slash` | bool | No | When `true`, L7 request parsing preserves `%2F` inside path segments instead of rejecting it. Use this for registries and APIs such as npm scoped packages (`/@scope%2Fname`). Defaults to `false`. | -| `websocket_credential_rewrite` | bool | No | When `true` on a `protocol: rest` endpoint, OpenShell rewrites `openshell:resolve:env:*` placeholders in client-to-server WebSocket text messages after an allowed HTTP `101` upgrade. Defaults to `false`. | +| `websocket_credential_rewrite` | bool | No | When `true` on a `protocol: rest` or `protocol: websocket` endpoint, OpenShell rewrites `openshell:resolve:env:*` placeholders in client-to-server WebSocket text messages after an allowed HTTP `101` upgrade. Binary frames are relayed but not rewritten. Defaults to `false`. | | `persisted_queries` | string | No | GraphQL hash-only behavior. Default is `deny`; use `allow_registered` only with `graphql_persisted_queries`. | | `graphql_persisted_queries` | map | No | Trusted GraphQL persisted-query registry keyed by hash or saved-query ID. Values contain `operation_type`, optional `operation_name`, and optional root `fields`. | | `graphql_max_body_bytes` | integer | No | Maximum GraphQL request body bytes buffered for inspection. Defaults to `65536`. | @@ -172,11 +172,11 @@ Each endpoint defines a reachable destination and optional inspection rules. The `access` field accepts one of the following values: -| Value | REST expansion | GraphQL expansion | -|---|---|---| -| `full` | All methods and paths. | All operation types. | -| `read-only` | `GET`, `HEAD`, `OPTIONS`. | `query` operations. | -| `read-write` | `GET`, `HEAD`, `OPTIONS`, `POST`, `PUT`, `PATCH`. | `query` and `mutation` operations. | +| Value | REST expansion | WebSocket expansion | GraphQL expansion | +|---|---|---|---| +| `full` | All methods and paths. | WebSocket upgrade and all inspected client text-message paths. | All operation types. 
| +| `read-only` | `GET`, `HEAD`, `OPTIONS`. | WebSocket upgrade handshake only. | `query` operations. | +| `read-write` | `GET`, `HEAD`, `OPTIONS`, `POST`, `PUT`, `PATCH`. | WebSocket upgrade handshake and client text messages. | `query` and `mutation` operations. | #### Allow Rule Objects @@ -209,6 +209,28 @@ rules: any: ["v1.*", "v2.*"] ``` +##### WebSocket Allow Rule (`protocol: websocket`) + +WebSocket allow rules match the RFC 6455 HTTP upgrade by path and match client-to-server text messages on the same upgraded connection with the synthetic `WEBSOCKET_TEXT` method. Binary frames are relayed but are not rewritten. + +| Field | Type | Required | Description | +|---|---|---|---| +| `method` | string | Yes | `GET` allows the upgrade handshake, `WEBSOCKET_TEXT` allows client text messages after upgrade, and `*` matches both inspected actions. | +| `path` | string | Yes | URL path pattern from the original upgrade request. Supports `*` and `**` glob syntax. | +| `query` | map | No | Query parameter matchers from the original upgrade request. Matcher syntax is the same as REST allow rules. | + +Example WebSocket allow rules: + +```yaml showLineNumbers={false} +rules: + - allow: + method: GET + path: /v1/realtime/** + - allow: + method: WEBSOCKET_TEXT + path: /v1/realtime/** +``` + ##### GraphQL Allow Rule (`protocol: graphql`) GraphQL allow rules match parsed GraphQL operations by operation type, optional operation name, and optional root fields. @@ -264,6 +286,30 @@ endpoints: path: "/repos/*/rulesets" ``` +##### WebSocket Deny Rule (`protocol: websocket`) + +WebSocket deny rules use the same field names as WebSocket allow rules, but they appear directly under each `deny_rules` entry instead of under an `allow` wrapper. + +| Field | Type | Required | Description | +|---|---|---|---| +| `method` | string | Yes | `GET` denies matching upgrade handshakes, `WEBSOCKET_TEXT` denies matching client text messages after upgrade, and `*` matches both inspected actions. 
| +| `path` | string | Yes | URL path pattern from the original upgrade request. Same glob syntax as allow rules. | +| `query` | map | No | Query parameter matchers from the original upgrade request. Same syntax as allow rule `query`. | + +Example WebSocket deny rules: + +```yaml showLineNumbers={false} +endpoints: + - host: realtime.example.com + port: 443 + protocol: websocket + enforcement: enforce + access: read-write + deny_rules: + - method: WEBSOCKET_TEXT + path: "/v1/admin/**" +``` + ##### GraphQL Deny Rule (`protocol: graphql`) GraphQL deny rules use the same field names as GraphQL allow rules, but they appear directly under each `deny_rules` entry instead of under an `allow` wrapper. diff --git a/docs/security/best-practices.mdx b/docs/security/best-practices.mdx index 11afd41c8..e9ba48194 100644 --- a/docs/security/best-practices.mdx +++ b/docs/security/best-practices.mdx @@ -97,9 +97,9 @@ The `protocol` field on an endpoint controls whether the proxy inspects individu | Aspect | Detail | |---|---| | Default | Endpoints without a `protocol` field use L4-only enforcement: the proxy checks host, port, and binary, then relays the TCP stream without inspecting payloads. | -| What you can change | Add `protocol: rest` to enable per-request HTTP method/path inspection, or `protocol: graphql` to inspect GraphQL operation type, operation name, and root fields. Pair either protocol with `rules` or access presets (`full`, `read-only`, `read-write`). | +| What you can change | Add `protocol: rest` to enable per-request HTTP method/path inspection, `protocol: websocket` to inspect RFC 6455 upgrade handshakes and client text messages, or `protocol: graphql` to inspect GraphQL operation type, operation name, and root fields. Pair inspected protocols with `rules` or access presets (`full`, `read-only`, `read-write`). | | Risk if relaxed | L4-only endpoints allow the agent to send any data through the tunnel after the initial connection is permitted. 
The proxy cannot see HTTP methods, paths, or GraphQL operations. Adding `access: full` with L7 inspection enables observability but permits all inspected actions. | -| Recommendation | Use `protocol: rest` with specific `rules` for APIs where intent is encoded in method and path. Use `protocol: graphql` for GraphQL APIs where destructive operations are body-encoded. Prefer `access: read-only` or explicit allowlists, and deny hash-only persisted queries unless you maintain a trusted registry. Omit `protocol` for non-HTTP protocols. For WebSocket endpoints that begin with an HTTP upgrade and must carry placeholder credentials in client text frames, use `protocol: rest` with `websocket_credential_rewrite: true`. | +| Recommendation | Use `protocol: rest` with specific `rules` for APIs where intent is encoded in method and path. Use `protocol: graphql` for GraphQL APIs where destructive operations are body-encoded. Use `protocol: websocket` for RFC 6455 endpoints, with explicit `GET` and `WEBSOCKET_TEXT` rules when finer control is needed. Prefer `access: read-only` or explicit allowlists, and deny hash-only persisted queries unless you maintain a trusted registry. Omit `protocol` for non-HTTP protocols. For WebSocket endpoints that must carry placeholder credentials in client text frames, add `websocket_credential_rewrite: true`. 
| ### Enforcement Mode (`audit` vs `enforce`) diff --git a/e2e/rust/Cargo.toml b/e2e/rust/Cargo.toml index 0da2e417b..73b723718 100644 --- a/e2e/rust/Cargo.toml +++ b/e2e/rust/Cargo.toml @@ -41,6 +41,11 @@ name = "gateway_resume" path = "tests/gateway_resume.rs" required-features = ["e2e-docker"] +[[test]] +name = "websocket_conformance" +path = "tests/websocket_conformance.rs" +required-features = ["e2e-docker"] + [[test]] name = "user_namespaces" path = "tests/user_namespaces.rs" diff --git a/e2e/rust/tests/websocket_conformance.rs b/e2e/rust/tests/websocket_conformance.rs new file mode 100644 index 000000000..16606448c --- /dev/null +++ b/e2e/rust/tests/websocket_conformance.rs @@ -0,0 +1,376 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#![cfg(feature = "e2e")] + +//! E2E regression: WebSocket credential placeholders are resolved on the real +//! Docker-backed sandbox path after an RFC 6455 upgrade. +//! +//! The sandbox process sends its provider-managed placeholder in a masked text +//! frame. The local upstream only reports whether it saw the real secret and +//! whether any placeholder survived; it never echoes payload bytes, placeholder +//! text, or secret material back into test output. 
+ +use std::io::Write; +use std::process::Stdio; +use std::sync::Mutex; + +use openshell_e2e::harness::binary::openshell_cmd; +use openshell_e2e::harness::container::ContainerHttpServer; +use openshell_e2e::harness::sandbox::SandboxGuard; +use tempfile::NamedTempFile; + +const PROVIDER_NAME: &str = "e2e-websocket-conformance"; +const TEST_SERVER_ALIAS: &str = "websocket-conformance.openshell.test"; +const TEST_SECRET: &str = "sk-e2e-websocket-conformance-secret"; +const TOKEN_ENV: &str = "WS_E2E_TOKEN"; +const PLACEHOLDER_PREFIX: &str = "openshell:resolve:env:"; +static PROVIDER_LOCK: Mutex<()> = Mutex::new(()); + +async fn run_cli(args: &[&str]) -> Result { + let mut cmd = openshell_cmd(); + cmd.args(args).stdout(Stdio::piped()).stderr(Stdio::piped()); + + let output = cmd + .output() + .await + .map_err(|e| format!("failed to spawn openshell {}: {e}", args.join(" ")))?; + + let stdout = String::from_utf8_lossy(&output.stdout).to_string(); + let stderr = String::from_utf8_lossy(&output.stderr).to_string(); + let combined = format!("{stdout}{stderr}"); + + if !output.status.success() { + return Err(format!( + "openshell {} failed (exit {:?}):\n{combined}", + args.join(" "), + output.status.code() + )); + } + + Ok(combined) +} + +async fn delete_provider(name: &str) { + let mut cmd = openshell_cmd(); + cmd.arg("provider") + .arg("delete") + .arg(name) + .stdout(Stdio::null()) + .stderr(Stdio::null()); + let _ = cmd.status().await; +} + +async fn create_generic_provider(name: &str) -> Result { + let credential = format!("{TOKEN_ENV}={TEST_SECRET}"); + run_cli(&[ + "provider", + "create", + "--name", + name, + "--type", + "generic", + "--credential", + &credential, + ]) + .await +} + +async fn start_websocket_probe_server() -> Result { + let script = format!( + r#" +import base64 +import hashlib +import json +import socketserver +import struct + +GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" +SECRET = {secret:?} +PLACEHOLDER_PREFIX = {placeholder_prefix:?} + +def 
recv_until(sock, marker): + data = b"" + while marker not in data: + chunk = sock.recv(4096) + if not chunk: + break + data += chunk + return data + +def read_exact(sock, size): + data = b"" + while len(data) < size: + chunk = sock.recv(size - len(data)) + if not chunk: + raise EOFError("unexpected end of websocket frame") + data += chunk + return data + +def read_frame(sock): + header = read_exact(sock, 2) + first, second = header[0], header[1] + length = second & 0x7F + if length == 126: + length = struct.unpack("!H", read_exact(sock, 2))[0] + elif length == 127: + length = struct.unpack("!Q", read_exact(sock, 8))[0] + mask = read_exact(sock, 4) if second & 0x80 else b"" + payload = read_exact(sock, length) + if mask: + payload = bytes(byte ^ mask[index % 4] for index, byte in enumerate(payload)) + return first, payload + +def send_text(sock, payload): + data = payload.encode("utf-8") + if len(data) < 126: + header = bytes([0x81, len(data)]) + elif len(data) <= 0xFFFF: + header = bytes([0x81, 126]) + struct.pack("!H", len(data)) + else: + header = bytes([0x81, 127]) + struct.pack("!Q", len(data)) + sock.sendall(header + data) + +def header_value(request, name): + prefix = name.lower() + ":" + for line in request.split("\r\n"): + if line.lower().startswith(prefix): + return line.split(":", 1)[1].strip() + return "" + +class Handler(socketserver.BaseRequestHandler): + def handle(self): + request_bytes = recv_until(self.request, b"\r\n\r\n") + request = request_bytes.decode("iso-8859-1", "replace") + if "upgrade: websocket" not in request.lower(): + self.request.sendall( + b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: close\r\n\r\nok" + ) + return + + key = header_value(request, "Sec-WebSocket-Key") + accept = base64.b64encode(hashlib.sha1((key + GUID).encode("ascii")).digest()).decode("ascii") + response = ( + "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + f"Sec-WebSocket-Accept: {{accept}}\r\n" + "\r\n" + ) 
+ self.request.sendall(response.encode("ascii")) + + _, payload = read_frame(self.request) + text = payload.decode("utf-8", "replace") + result = {{ + "saw_placeholder": PLACEHOLDER_PREFIX in text, + "saw_secret": SECRET in text, + }} + send_text(self.request, json.dumps(result, sort_keys=True)) + +class Server(socketserver.ThreadingTCPServer): + allow_reuse_address = True + +Server(("0.0.0.0", 8000), Handler).serve_forever() +"#, + secret = TEST_SECRET, + placeholder_prefix = PLACEHOLDER_PREFIX, + ); + + ContainerHttpServer::start_python(TEST_SERVER_ALIAS, &script).await +} + +fn write_websocket_policy(host: &str, port: u16) -> Result { + let mut file = NamedTempFile::new().map_err(|e| format!("create temp policy file: {e}"))?; + let policy = format!( + r#"version: 1 + +filesystem_policy: + include_workdir: true + read_only: + - /usr + - /lib + - /proc + - /dev/urandom + - /app + - /etc + - /var/log + read_write: + - /sandbox + - /tmp + - /dev/null + +landlock: + compatibility: best_effort + +process: + run_as_user: sandbox + run_as_group: sandbox + +network_policies: + websocket_conformance: + name: websocket_conformance + endpoints: + - host: {host} + port: {port} + protocol: websocket + enforcement: enforce + access: read-write + websocket_credential_rewrite: true + allowed_ips: + - "10.0.0.0/8" + - "172.0.0.0/8" + - "192.168.0.0/16" + - "fc00::/7" + binaries: + - path: /usr/bin/python* + - path: /usr/local/bin/python* + - path: /sandbox/.uv/python/*/bin/python* +"# + ); + file.write_all(policy.as_bytes()) + .map_err(|e| format!("write temp policy file: {e}"))?; + file.flush() + .map_err(|e| format!("flush temp policy file: {e}"))?; + Ok(file) +} + +fn websocket_client_script(host: &str, port: u16) -> String { + format!( + r#" +import base64 +import json +import os +import socket +import struct + +HOST = {host:?} +PORT = {port} +TOKEN_ENV = {token_env:?} + +def recv_until(sock, marker): + data = b"" + while marker not in data: + chunk = sock.recv(4096) + if not 
chunk: + break + data += chunk + return data + +def read_exact(sock, size): + data = b"" + while len(data) < size: + chunk = sock.recv(size - len(data)) + if not chunk: + raise EOFError("unexpected end of websocket frame") + data += chunk + return data + +def masked_text_frame(payload): + data = payload.encode("utf-8") + mask = os.urandom(4) + if len(data) < 126: + header = bytes([0x81, 0x80 | len(data)]) + elif len(data) <= 0xFFFF: + header = bytes([0x81, 0x80 | 126]) + struct.pack("!H", len(data)) + else: + header = bytes([0x81, 0x80 | 127]) + struct.pack("!Q", len(data)) + masked = bytes(byte ^ mask[index % 4] for index, byte in enumerate(data)) + return header + mask + masked + +def read_frame(sock): + first, second = read_exact(sock, 2) + length = second & 0x7F + if length == 126: + length = struct.unpack("!H", read_exact(sock, 2))[0] + elif length == 127: + length = struct.unpack("!Q", read_exact(sock, 8))[0] + mask = read_exact(sock, 4) if second & 0x80 else b"" + payload = read_exact(sock, length) + if mask: + payload = bytes(byte ^ mask[index % 4] for index, byte in enumerate(payload)) + return first, payload + +token = os.environ[TOKEN_ENV] +payload = json.dumps({{"authorization": "Bearer " + token}}, sort_keys=True) +key = base64.b64encode(os.urandom(16)).decode("ascii") + +with socket.create_connection((HOST, PORT), timeout=20) as sock: + request = ( + f"GET /ws HTTP/1.1\r\n" + f"Host: {{HOST}}:{{PORT}}\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + f"Sec-WebSocket-Key: {{key}}\r\n" + "Sec-WebSocket-Version: 13\r\n" + "\r\n" + ) + sock.sendall(request.encode("ascii")) + response = recv_until(sock, b"\r\n\r\n").decode("iso-8859-1", "replace") + if not response.startswith("HTTP/1.1 101"): + raise RuntimeError("websocket upgrade failed") + sock.sendall(masked_text_frame(payload)) + _, response_payload = read_frame(sock) + print(response_payload.decode("utf-8")) +"#, + host = host, + port = port, + token_env = TOKEN_ENV, + ) +} + 
+#[tokio::test] +async fn websocket_text_placeholder_is_rewritten_in_docker_sandbox() { + let _provider_lock = PROVIDER_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + + delete_provider(PROVIDER_NAME).await; + create_generic_provider(PROVIDER_NAME) + .await + .expect("create generic provider"); + + let result = async { + let server = start_websocket_probe_server().await?; + let policy = write_websocket_policy(&server.host, server.port)?; + let policy_path = policy + .path() + .to_str() + .ok_or_else(|| "temp policy path should be utf-8".to_string())? + .to_string(); + let script = websocket_client_script(&server.host, server.port); + + SandboxGuard::create(&[ + "--policy", + &policy_path, + "--provider", + PROVIDER_NAME, + "--", + "python3", + "-c", + &script, + ]) + .await + } + .await; + + delete_provider(PROVIDER_NAME).await; + + let guard = result.expect("sandbox create"); + assert!( + guard + .create_output + .contains(r#"{"saw_placeholder": false, "saw_secret": true}"#), + "expected upstream to see only the resolved secret marker:\n{}", + guard.create_output + ); + assert!( + !guard.create_output.contains(TEST_SECRET), + "test output should not expose the raw WebSocket credential:\n{}", + guard.create_output + ); + assert!( + !guard.create_output.contains(PLACEHOLDER_PREFIX), + "test output should not expose unresolved credential placeholders:\n{}", + guard.create_output + ); +} diff --git a/tasks/test.toml b/tasks/test.toml index bf5741c72..a1f4c6429 100644 --- a/tasks/test.toml +++ b/tasks/test.toml @@ -34,6 +34,13 @@ run = [ "e2e/with-docker-gateway.sh cargo test --manifest-path e2e/rust/Cargo.toml --features e2e-docker", ] +["e2e:websocket-conformance"] +description = "Run focused WebSocket conformance e2e tests against a Docker-backed gateway" +run = [ + "cargo build -p openshell-cli --features openshell-core/dev-settings", + "e2e/with-docker-gateway.sh cargo test --manifest-path e2e/rust/Cargo.toml --features e2e-docker --test 
websocket_conformance", +] + ["e2e:python"] description = "Run Python e2e tests against a Docker-backed gateway (E2E_PARALLEL=N or 'auto'; default 5)" depends = ["python:proto"] From a6701784df384331dab0588a2371586c273247e7 Mon Sep 17 00:00:00 2001 From: Aaron Erickson Date: Fri, 8 May 2026 22:20:06 -0700 Subject: [PATCH 09/17] fix(policy): support websocket incremental rules Signed-off-by: Aaron Erickson --- crates/openshell-cli/src/main.rs | 4 +- crates/openshell-cli/src/policy_update.rs | 51 +++++++++- crates/openshell-policy/src/merge.rs | 116 +++++++++++++++++++--- docs/sandboxes/policies.mdx | 80 ++++++++++----- docs/security/best-practices.mdx | 4 +- 5 files changed, 209 insertions(+), 46 deletions(-) diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index 25fa07cf2..ae71fbac0 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -1411,11 +1411,11 @@ enum PolicyCommands { #[arg(long = "remove-endpoint")] remove_endpoints: Vec, - /// Add a REST allow rule: `host:port:METHOD:path_glob`. + /// Add a REST or WebSocket method/path allow rule: `host:port:METHOD:path_glob`. #[arg(long = "add-allow")] add_allow: Vec, - /// Add a REST deny rule: `host:port:METHOD:path_glob`. + /// Add a REST or WebSocket method/path deny rule: `host:port:METHOD:path_glob`. 
#[arg(long = "add-deny")] add_deny: Vec, diff --git a/crates/openshell-cli/src/policy_update.rs b/crates/openshell-cli/src/policy_update.rs index 322a28df6..6a17e5c32 100644 --- a/crates/openshell-cli/src/policy_update.rs +++ b/crates/openshell-cli/src/policy_update.rs @@ -285,9 +285,9 @@ fn parse_add_endpoint_spec(spec: &str) -> Result { "--add-endpoint access segment must be one of read-only, read-write, or full; got '{access}' in '{spec}'" )); } - if !protocol.is_empty() && !matches!(protocol, "rest" | "sql") { + if !protocol.is_empty() && !matches!(protocol, "rest" | "websocket" | "sql") { return Err(miette!( - "--add-endpoint protocol segment must be 'rest' or 'sql'; got '{protocol}' in '{spec}'" + "--add-endpoint protocol segment must be 'rest', 'websocket', or 'sql'; got '{protocol}' in '{spec}'" )); } if !enforcement.is_empty() && !matches!(enforcement, "enforce" | "audit") { @@ -353,6 +353,7 @@ fn dedup_strings(values: &[String]) -> Vec { #[cfg(test)] mod tests { use super::build_policy_update_plan; + use openshell_policy::PolicyMergeOp; #[test] fn parse_add_endpoint_basic_l4() { @@ -392,6 +393,52 @@ mod tests { .expect("plan should build"); } + #[test] + fn parse_add_endpoint_accepts_websocket_protocol() { + let plan = build_policy_update_plan( + &["realtime.example.com:443:read-write:websocket:enforce".to_string()], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect("plan should build"); + + let PolicyMergeOp::AddRule { rule, .. 
} = &plan.preview_operations[0] else { + panic!("expected add-rule preview"); + }; + let endpoint = &rule.endpoints[0]; + assert_eq!(endpoint.host, "realtime.example.com"); + assert_eq!(endpoint.protocol, "websocket"); + assert_eq!(endpoint.access, "read-write"); + assert_eq!(endpoint.enforcement, "enforce"); + } + + #[test] + fn parse_add_allow_accepts_websocket_text_method() { + let plan = build_policy_update_plan( + &[], + &[], + &[], + &["realtime.example.com:443:websocket_text:/v1/messages/**".to_string()], + &[], + &[], + None, + ) + .expect("plan should build"); + + let PolicyMergeOp::AddAllowRules { host, port, rules } = &plan.preview_operations[0] else { + panic!("expected add-allow preview"); + }; + assert_eq!(host, "realtime.example.com"); + assert_eq!(*port, 443); + let allow = rules[0].allow.as_ref().expect("allow rule"); + assert_eq!(allow.method, "WEBSOCKET_TEXT"); + assert_eq!(allow.path, "/v1/messages/**"); + } + #[test] fn parse_add_deny_rejects_empty_method() { let error = build_policy_update_plan( diff --git a/crates/openshell-policy/src/merge.rs b/crates/openshell-policy/src/merge.rs index ca4748b5b..df5a953b5 100644 --- a/crates/openshell-policy/src/merge.rs +++ b/crates/openshell-policy/src/merge.rs @@ -184,7 +184,7 @@ impl std::fmt::Display for PolicyMergeError { protocol, } => write!( f, - "endpoint {host}:{port} uses unsupported protocol '{protocol}'; this operation currently supports only protocol 'rest'" + "endpoint {host}:{port} uses unsupported protocol '{protocol}'; this operation currently supports only protocol 'rest' or 'websocket'" ), Self::EndpointHasNoAllowBase { host, port } => write!( f, @@ -265,7 +265,7 @@ fn apply_operation( port: *port, } })?; - ensure_rest_endpoint(endpoint, host, *port)?; + ensure_method_path_endpoint(endpoint, host, *port)?; if endpoint.access.is_empty() && endpoint.rules.is_empty() { return Err(PolicyMergeError::EndpointHasNoAllowBase { host: host.clone(), @@ -281,7 +281,7 @@ fn apply_operation( port: 
*port, } })?; - ensure_rest_endpoint(endpoint, host, *port)?; + ensure_method_path_endpoint(endpoint, host, *port)?; expand_existing_access(endpoint, host, *port, warnings)?; append_unique_l7_rules(&mut endpoint.rules, rules); } @@ -570,7 +570,7 @@ fn endpoint_matches_host_port(endpoint: &NetworkEndpoint, host: &str, port: u32) endpoint.host.eq_ignore_ascii_case(host) && canonical_ports(endpoint).contains(&port) } -fn ensure_rest_endpoint( +fn ensure_method_path_endpoint( endpoint: &NetworkEndpoint, host: &str, port: u32, @@ -581,7 +581,7 @@ fn ensure_rest_endpoint( port, }); } - if endpoint.protocol != "rest" { + if !matches!(endpoint.protocol.as_str(), "rest" | "websocket") { return Err(PolicyMergeError::UnsupportedEndpointProtocol { host: host.to_string(), port, @@ -602,12 +602,13 @@ fn expand_existing_access( } let access = endpoint.access.clone(); - let expanded = - expand_access_preset(&access).ok_or_else(|| PolicyMergeError::UnsupportedAccessPreset { + let expanded = expand_access_preset(&endpoint.protocol, &access).ok_or_else(|| { + PolicyMergeError::UnsupportedAccessPreset { host: host.to_string(), port, access: access.clone(), - })?; + } + })?; endpoint.access.clear(); append_unique_l7_rules(&mut endpoint.rules, &expanded); warnings.push(PolicyMergeWarning::ExpandedAccessPreset { @@ -618,11 +619,13 @@ fn expand_existing_access( Ok(()) } -fn expand_access_preset(access: &str) -> Option> { - let methods = match access { - "read-only" => vec!["GET", "HEAD", "OPTIONS"], - "read-write" => vec!["GET", "HEAD", "OPTIONS", "POST", "PUT", "PATCH"], - "full" => vec!["*"], +fn expand_access_preset(protocol: &str, access: &str) -> Option> { + let methods = match (protocol, access) { + (_, "full") => vec!["*"], + ("websocket", "read-only") => vec!["GET"], + ("websocket", "read-write") => vec!["GET", "WEBSOCKET_TEXT"], + (_, "read-only") => vec!["GET", "HEAD", "OPTIONS"], + (_, "read-write") => vec!["GET", "HEAD", "OPTIONS", "POST", "PUT", "PATCH"], _ => return None, }; 
@@ -911,7 +914,92 @@ mod tests { } #[test] - fn add_deny_requires_rest_protocol() { + fn add_allow_expands_websocket_access_preset() { + let mut policy = restrictive_default_policy(); + policy.network_policies.insert( + "realtime".to_string(), + NetworkPolicyRule { + name: "realtime".to_string(), + endpoints: vec![NetworkEndpoint { + host: "realtime.example.com".to_string(), + port: 443, + ports: vec![443], + protocol: "websocket".to_string(), + access: "read-write".to_string(), + ..Default::default() + }], + ..Default::default() + }, + ); + + let result = merge_policy( + policy, + &[PolicyMergeOp::AddAllowRules { + host: "realtime.example.com".to_string(), + port: 443, + rules: vec![rest_rule("WEBSOCKET_TEXT", "/rooms/private/**")], + }], + ) + .expect("merge should succeed"); + + let endpoint = &result.policy.network_policies["realtime"].endpoints[0]; + assert!(endpoint.access.is_empty()); + assert_eq!(endpoint.rules.len(), 3); + assert!(endpoint.rules.contains(&rest_rule("GET", "**"))); + assert!(endpoint.rules.contains(&rest_rule("WEBSOCKET_TEXT", "**"))); + assert!( + endpoint + .rules + .contains(&rest_rule("WEBSOCKET_TEXT", "/rooms/private/**")) + ); + assert!(!endpoint.rules.contains(&rest_rule("POST", "**"))); + assert!(result.warnings.iter().any(|warning| matches!( + warning, + PolicyMergeWarning::ExpandedAccessPreset { access, .. 
} if access == "read-write" + ))); + } + + #[test] + fn add_deny_accepts_websocket_protocol() { + let mut policy = restrictive_default_policy(); + policy.network_policies.insert( + "realtime".to_string(), + NetworkPolicyRule { + name: "realtime".to_string(), + endpoints: vec![NetworkEndpoint { + host: "realtime.example.com".to_string(), + port: 443, + ports: vec![443], + protocol: "websocket".to_string(), + access: "read-write".to_string(), + ..Default::default() + }], + ..Default::default() + }, + ); + + let result = merge_policy( + policy, + &[PolicyMergeOp::AddDenyRules { + host: "realtime.example.com".to_string(), + port: 443, + deny_rules: vec![L7DenyRule { + method: "WEBSOCKET_TEXT".to_string(), + path: "/admin/**".to_string(), + ..Default::default() + }], + }], + ) + .expect("merge should succeed"); + + let endpoint = &result.policy.network_policies["realtime"].endpoints[0]; + assert_eq!(endpoint.deny_rules.len(), 1); + assert_eq!(endpoint.deny_rules[0].method, "WEBSOCKET_TEXT"); + assert_eq!(endpoint.deny_rules[0].path, "/admin/**"); + } + + #[test] + fn add_deny_rejects_unsupported_protocol() { let mut policy = restrictive_default_policy(); policy.network_policies.insert( "db".to_string(), diff --git a/docs/sandboxes/policies.mdx b/docs/sandboxes/policies.mdx index 1939cf651..491271c42 100644 --- a/docs/sandboxes/policies.mdx +++ b/docs/sandboxes/policies.mdx @@ -49,14 +49,14 @@ network_policies: Static sections are locked at sandbox creation. Changing them requires destroying and recreating the sandbox. Dynamic sections can be updated on a running sandbox with `openshell policy update` for incremental merges or `openshell policy set` for full replacement, and take effect without restarting. When a hot reload changes rules on an active HTTP L7 endpoint, existing keep-alive tunnels are closed before forwarding another parsed request. Credential-injection-only HTTP passthrough tunnels use the same reload boundary. 
Most HTTP clients reconnect automatically, and the next request is evaluated against the current policy. -Raw streams are connection-scoped and outside L7 live-reload guarantees. This includes `tls: skip`, non-HTTP TCP payloads, HTTP upgrades such as WebSocket, and long-lived response streams such as SSE. A reload applies to the next connection or next parsed HTTP request; it does not interrupt an already-forwarded raw stream. The exception is `websocket_credential_rewrite: true` on `protocol: rest` endpoints, which keeps policy evaluation on the HTTP upgrade request and rewrites credential placeholders only in client-to-server WebSocket text messages after the allowed upgrade. +Raw streams are connection-scoped and outside L7 live-reload guarantees. This includes `tls: skip`, non-HTTP TCP payloads, HTTP upgrades such as WebSocket, and long-lived response streams such as SSE. A reload applies to the next connection or next parsed HTTP request; it does not interrupt an already-forwarded raw stream. Use `protocol: websocket` when policy should stay attached to the RFC 6455 upgrade and client text messages after the allowed upgrade. Add `websocket_credential_rewrite: true` only when the relay should rewrite credential placeholders in client-to-server WebSocket text messages. | Section | Type | Description | |---|---|---| | `filesystem_policy` | Static | Controls which directories the agent can access on disk. Paths are split into `read_only` and `read_write` lists. Any path not listed in either list is inaccessible. Set `include_workdir: true` to automatically add the agent's working directory to `read_write`. [Landlock LSM](https://docs.kernel.org/security/landlock.html) enforces these restrictions at the kernel level. | | `landlock` | Static | Configures Landlock LSM enforcement behavior. 
Set `compatibility` to `best_effort` (skip individual inaccessible paths while applying remaining rules) or `hard_requirement` (fail if any path is inaccessible or the required kernel ABI is unavailable). See the [Policy Schema Reference](/reference/policy-schema#landlock) for the full behavior table. | | `process` | Static | Sets the OS-level identity for the agent process. `run_as_user` and `run_as_group` default to `sandbox`. Root (`root` or `0`) is rejected. The agent also runs with seccomp filters that block dangerous system calls. | -| `network_policies` | Dynamic | Controls network access for ordinary outbound traffic from the sandbox. Each block has a name, a list of endpoints (host, port, protocol, and optional rules), and a list of binaries allowed to use those endpoints.
Every outbound connection except `https://inference.local` goes through the proxy, which queries the [policy engine](/about/how-it-works#core-components) with the destination and calling binary. A connection is allowed only when both match an entry in the same policy block.
For endpoints with `protocol: rest`, the proxy auto-detects TLS and terminates it so each HTTP request can be checked against that endpoint's `rules` (method and path). Set `websocket_credential_rewrite: true` only when a REST-shaped WebSocket upgrade must keep placeholder credentials in sandbox-owned payloads and resolve them at the OpenShell relay boundary.
Endpoints without `protocol` allow the TCP stream through without inspecting payloads.
If no endpoint matches, the connection is denied. Configure managed inference separately through [Inference Routing](/sandboxes/inference-routing). | +| `network_policies` | Dynamic | Controls network access for ordinary outbound traffic from the sandbox. Each block has a name, a list of endpoints (host, port, protocol, and optional rules), and a list of binaries allowed to use those endpoints.
Every outbound connection except `https://inference.local` goes through the proxy, which queries the [policy engine](/about/how-it-works#core-components) with the destination and calling binary. A connection is allowed only when both match an entry in the same policy block.
For endpoints with `protocol: rest`, the proxy auto-detects TLS and terminates it so each HTTP request can be checked against that endpoint's `rules` (method and path). For endpoints with `protocol: websocket`, the proxy validates the RFC 6455 upgrade and evaluates `GET` rules for the handshake plus `WEBSOCKET_TEXT` rules for client text messages on the upgraded request path. Set `websocket_credential_rewrite: true` only when a `protocol: websocket` endpoint or a `protocol: rest` compatibility endpoint must keep placeholder credentials in sandbox-owned text frames and resolve them at the OpenShell relay boundary.
Endpoints without `protocol` allow the TCP stream through without inspecting payloads.
If no endpoint matches, the connection is denied. Configure managed inference separately through [Inference Routing](/sandboxes/inference-routing). | ## Baseline Filesystem Paths @@ -123,7 +123,7 @@ The following steps outline the hot-reload policy update workflow. openshell logs --tail --source sandbox ``` -3. For additive network changes, use `openshell policy update`. This is the fastest path for adding endpoints, binaries, or REST allow/deny rules without replacing the full policy. The full option and format reference is in [Incremental Policy Updates](#incremental-policy-updates). +3. For additive network changes, use `openshell policy update`. This is the fastest path for adding endpoints, binaries, or REST and WebSocket allow/deny rules without replacing the full policy. The full option and format reference is in [Incremental Policy Updates](#incremental-policy-updates). ```shell openshell policy update \ @@ -136,7 +136,7 @@ The following steps outline the hot-reload policy update workflow. --wait ``` - `--add-allow` and `--add-deny` currently target existing `protocol: rest` endpoints only. If you pass multiple update flags in one command, OpenShell applies them as one atomic merge batch and persists at most one new revision. + `--add-allow` and `--add-deny` target existing `protocol: rest` or `protocol: websocket` endpoints. If you pass multiple update flags in one command, OpenShell applies them as one atomic merge batch and persists at most one new revision. 4. For larger edits, pull the current policy and edit the YAML directly. Strip the metadata header (Version, Hash, Status) before reusing the file. @@ -165,7 +165,7 @@ Use `openshell policy update` when you want to merge network policy changes into `openshell policy update` is useful when you want to: - add a new endpoint for an existing binary without touching other policy sections. -- add a few REST allow or deny rules after you see a blocked request in the logs. 
+- add a few REST or WebSocket allow/deny rules after you see a blocked request in the logs. - remove one endpoint or one named rule without rewriting the rest of the file. - preview a merged result locally with `--dry-run` before you send it to the gateway. @@ -173,15 +173,15 @@ Use `openshell policy set` instead when you want to replace the full policy, upd ### Update Commands -The incremental update surface is split into endpoint-level operations and REST rule-level operations. +The incremental update surface is split into endpoint-level operations and method/path rule-level operations for REST and WebSocket endpoints. | Flag | What it changes | Typical use | |---|---|---| | `--add-endpoint ` | Creates or merges a network rule and endpoint. | Allow a new host and port, optionally with `access`, `protocol`, `enforcement`, and binaries. | | `--remove-endpoint ` | Removes one host and port match from the current policy. | Drop a stale endpoint or remove one port from a multi-port endpoint. | | `--remove-rule ` | Deletes a named `network_policies` entry. | Remove a whole rule by name when you no longer need it. | -| `--add-allow ` | Appends REST allow rules to an existing endpoint. | Permit one additional method and path on a REST API that is already configured. | -| `--add-deny ` | Appends REST deny rules to an existing endpoint. | Block a sensitive REST path under an endpoint that is otherwise allowed. | +| `--add-allow ` | Appends method/path allow rules to an existing REST or WebSocket endpoint. | Permit one additional REST method/path or WebSocket `WEBSOCKET_TEXT` path on an API that is already configured. | +| `--add-deny ` | Appends method/path deny rules to an existing REST or WebSocket endpoint. | Block a sensitive REST path or WebSocket text-message path under an endpoint that is otherwise allowed. | | `--binary ` | Adds binaries to every `--add-endpoint` rule in the same command. | Bind a new endpoint to one or more executables. 
| | `--rule-name ` | Overrides the generated rule name. | Keep a stable human-chosen rule name when adding exactly one endpoint. | | `--dry-run` | Shows the merged policy locally and does not call the gateway. | Review the result before persisting it. | @@ -194,17 +194,18 @@ The incremental update surface is split into endpoint-level operations and REST `--add-endpoint` works at the endpoint and rule level. It creates a new `network_policies` entry when needed, or merges into an existing rule that already covers the same host and port. Use it when you are defining where traffic may go and which binaries may send it. -`--add-allow` and `--add-deny` work at the REST request level. They do not create binaries, and they do not create a new endpoint. They modify an existing endpoint that already has `protocol: rest`. +`--add-allow` and `--add-deny` work at the method/path rule level. They do not create binaries, and they do not create a new endpoint. They modify an existing endpoint that already has `protocol: rest` or `protocol: websocket`. This is the practical difference: - Use `--add-endpoint` to say "allow this binary to reach `api.github.com:443`." - Use `--add-allow` to say "for that existing REST endpoint, also allow `POST /repos/*/issues`." - Use `--add-deny` to say "for that existing REST endpoint, explicitly deny `POST /admin/**`." +- Use `--add-allow` to say "for that existing WebSocket endpoint, also allow client text messages on `/v1/realtime/**`." -In the first pass of this feature: +Current constraints: -- `--add-allow` and `--add-deny` only work on `protocol: rest` endpoints. +- `--add-allow` and `--add-deny` work on `protocol: rest` and `protocol: websocket` endpoints. - `--add-deny` requires the endpoint to already have an allow base, either an `access` preset or explicit allow `rules`. - `protocol: sql` is not a practical incremental workflow today. OpenShell does not do full SQL parsing, and SQL enforcement is not meaningfully supported yet. 
@@ -222,8 +223,8 @@ Each segment has a fixed meaning: |---|---|---| | `host` | Yes | Destination hostname. | | `port` | Yes | Destination port, `1` through `65535`. | -| `access` | No | Access preset for L7 endpoints: `read-only`, `read-write`, or `full`. Incremental updates currently expand presets for REST-shaped access. | -| `protocol` | No | L7 inspection mode: `rest` or `sql`. `sql` is audit-only and not a recommended workflow today. | +| `access` | No | Access preset for L7 endpoints: `read-only`, `read-write`, or `full`. Incremental updates expand presets into protocol-specific method/path rules for REST and WebSocket endpoints. | +| `protocol` | No | L7 inspection mode: `rest`, `websocket`, or `sql`. `sql` is audit-only and not a recommended workflow today. | | `enforcement` | No | Enforcement mode for inspected traffic: `enforce` or `audit`. | Examples: @@ -232,19 +233,21 @@ Examples: |---|---| | `pypi.org:443` | Add a plain L4 endpoint. The proxy allows the TCP stream and does not inspect HTTP requests. | | `api.github.com:443:read-only:rest:enforce` | Add a REST endpoint with the `read-only` preset expanded by the policy engine into GET, HEAD, and OPTIONS access. | +| `realtime.example.com:443:read-write:websocket:enforce` | Add a WebSocket endpoint with the `read-write` preset expanded by the policy engine into the upgrade `GET` and client `WEBSOCKET_TEXT` access. | -If you set `protocol: rest`, you also need an allow shape. With incremental updates, that means you should provide an `access` preset on `--add-endpoint`, then use `--add-allow` or `--add-deny` to refine REST endpoints later. +If you set `protocol: rest` or `protocol: websocket`, you also need an allow shape. With incremental updates, that means you should provide an `access` preset on `--add-endpoint`, then use `--add-allow` or `--add-deny` to refine method/path rules later. For example: - `api.github.com:443:read-only:rest` is valid. 
+- `realtime.example.com:443:read-write:websocket` is valid. - `api.github.com:443::rest` is invalid. It does not mean "allow all traffic." An L7 endpoint with `protocol` but no `access` or `rules` is rejected when the policy loads. When you pass multiple `--add-endpoint` flags in one command, every `--binary` value applies to every added endpoint in that command. If different endpoints need different binaries, use separate `policy update` commands. If you do not pass `--rule-name`, OpenShell generates one from the host and port, such as `allow_api_github_com_443`. -### REST Rule Specs +### Method/Path Rule Specs `--add-allow` and `--add-deny` use this format: @@ -252,7 +255,7 @@ If you do not pass `--rule-name`, OpenShell generates one from the host and port host:port:METHOD:path_glob ``` -This string identifies an existing REST endpoint and the request pattern you want to add. +This string identifies an existing REST or WebSocket endpoint and the request pattern you want to add. In shell commands, quote the full `SPEC` when it contains `*` or `**` so your shell passes it literally instead of expanding it as a local file glob. @@ -260,8 +263,8 @@ In shell commands, quote the full `SPEC` when it contains `*` or `**` so your sh |---|---| | `host` | Existing endpoint host. | | `port` | Existing endpoint port. | -| `METHOD` | HTTP method. The CLI normalizes it to uppercase. | -| `path_glob` | URL path glob. It must start with `/`, or be `**`, or start with `**/`. | +| `METHOD` | HTTP method for REST endpoints, or `GET` / `WEBSOCKET_TEXT` for WebSocket endpoints. The CLI normalizes it to uppercase. | +| `path_glob` | URL path glob. For WebSocket text messages, this still matches the upgraded request path, not message payload content. It must start with `/`, or be `**`, or start with `**/`. 
| This example: @@ -283,11 +286,11 @@ Path globs follow the same semantics as YAML allow and deny rules: - `/repos/*/issues` matches one repository owner or name segment in the middle. - `/repos/**` matches everything under `/repos/`. -The rule-level commands only modify method and path constraints. They do not change binaries, hostnames, ports, or protocol settings. +The rule-level commands only modify method and path constraints. They do not change binaries, hostnames, ports, protocol settings, or WebSocket message payload matching. ### Common Workflows -Use these patterns as starting points when you decide whether to update an endpoint or append REST rules. +Use these patterns as starting points when you decide whether to update an endpoint or append REST/WebSocket rules. #### Add a new L4 endpoint @@ -302,7 +305,7 @@ openshell policy update demo \ --wait ``` -This creates or merges endpoint entries and binds them to the listed binaries. It does not create per-path REST rules. +This creates or merges endpoint entries and binds them to the listed binaries. It does not create inspected method/path rules. #### Create a REST endpoint with a base allow set @@ -341,6 +344,31 @@ openshell policy update demo \ This adds a deny rule to the existing REST endpoint. The endpoint must already have an allow base. +#### Create a WebSocket endpoint with a base allow set + +Use `--add-endpoint` with `protocol: websocket` when the destination is an RFC 6455 WebSocket API. + +```shell +openshell policy update demo \ + --add-endpoint realtime.example.com:443:read-write:websocket:enforce \ + --binary /usr/bin/node \ + --wait +``` + +This creates a WebSocket endpoint and sets its base allow behavior through the `read-write` access preset. For WebSocket endpoints, `read-write` expands to the upgrade `GET` and client `WEBSOCKET_TEXT` messages on the upgraded request path. 
+ +#### Add a WebSocket text-message deny rule + +Use `WEBSOCKET_TEXT` when you want to refine client-to-server text-frame policy without matching message payload content. + +```shell +openshell policy update demo \ + --add-deny 'realtime.example.com:443:WEBSOCKET_TEXT:/v1/admin/**' \ + --wait +``` + +This adds a deny rule to the existing WebSocket endpoint. The path glob matches the WebSocket upgrade path. + #### Remove one endpoint or rule Use `--remove-endpoint` to remove one host and port pair, or `--remove-rule` to delete the whole named rule. @@ -379,7 +407,7 @@ The CLI validates the argument shapes before it sends the request. The gateway t - a required segment is missing. - a port is outside `1` through `65535`. - `--add-allow` or `--add-deny` points at an endpoint that does not exist. -- `--add-allow` or `--add-deny` targets a non-REST endpoint. +- `--add-allow` or `--add-deny` targets an endpoint that is neither REST nor WebSocket. - `--add-deny` targets an endpoint that has no base allow set. ## Global Policy Override @@ -415,7 +443,7 @@ When triaging denied requests, check: - Destination host and port to confirm which endpoint is missing. - Calling binary path to confirm which `binaries` entry needs to be added or adjusted. -- HTTP method and path (for REST endpoints) to confirm which `rules` entry needs to be added or adjusted. +- HTTP method and path for REST endpoints, or `GET` / `WEBSOCKET_TEXT` and the upgraded request path for WebSocket endpoints, to confirm which `rules` entry needs to be added or adjusted. Then push the updated policy as described above. @@ -427,7 +455,7 @@ openshell policy update --add-allow 'api.github.com:443:GET:/repos/**' -- ## Examples -Add these blocks to the `network_policies` section of your sandbox policy. Apply simple endpoints and REST rule additions with `openshell policy update`, or apply any complete YAML block with `openshell policy set --policy --wait`. 
+Add these blocks to the `network_policies` section of your sandbox policy. Apply simple endpoints and REST/WebSocket rule additions with `openshell policy update`, or apply any complete YAML block with `openshell policy set --policy --wait`. Use **Simple endpoint** for host-level allowlists and **Granular rules** for method/path control. @@ -447,7 +475,7 @@ Allow `pip install` and `uv pip install` to reach PyPI: - { path: /usr/local/bin/uv } ``` -Endpoints without `protocol` use TCP passthrough, where the proxy allows the stream without inspecting payloads. If the stream is HTTP and TLS is auto-terminated, the proxy can still rewrite configured credential placeholders and closes keep-alive passthrough tunnels on policy reload before forwarding another request. WebSocket payload credential rewrite requires an explicit `protocol: rest` endpoint with `websocket_credential_rewrite: true`. +Endpoints without `protocol` use TCP passthrough, where the proxy allows the stream without inspecting payloads. If the stream is HTTP and TLS is auto-terminated, the proxy can still rewrite configured credential placeholders and closes keep-alive passthrough tunnels on policy reload before forwarding another request. WebSocket text-frame policy requires an explicit `protocol: websocket` endpoint. WebSocket payload credential rewrite can also be enabled on a `protocol: rest` compatibility endpoint with `websocket_credential_rewrite: true`.
@@ -505,7 +533,7 @@ For an end-to-end walkthrough that combines this policy with a GitHub credential - { path: /usr/bin/gh } ``` -Endpoints with `protocol: rest` enable HTTP request inspection. Endpoints with `protocol: graphql` parse GraphQL-over-HTTP payloads before evaluating rules. The endpoint-level `path` field lets both protocols share `api.github.com:443` without treating GraphQL payloads as plain REST `POST /graphql` requests. +Endpoints with `protocol: rest` enable HTTP request inspection. Endpoints with `protocol: websocket` validate WebSocket upgrades and inspect client text messages on the upgraded request path. Endpoints with `protocol: graphql` parse GraphQL-over-HTTP payloads before evaluating rules. The endpoint-level `path` field lets these protocols share `api.github.com:443` without treating GraphQL payloads as plain REST `POST /graphql` requests. diff --git a/docs/security/best-practices.mdx b/docs/security/best-practices.mdx index e9ba48194..224cdf644 100644 --- a/docs/security/best-practices.mdx +++ b/docs/security/best-practices.mdx @@ -283,8 +283,8 @@ The following patterns weaken security without providing meaningful benefit. | Mistake | Why it matters | What to do instead | |---------|---------------|-------------------| -| Omitting `protocol: rest` on REST API endpoints | Without `protocol: rest`, the proxy uses L4-only enforcement. It allows the TCP stream through after checking host, port, and binary, but cannot inspect individual HTTP requests. | Add `protocol: rest` with specific `rules` to enable per-request method and path control. | -| Using `access: full` when finer rules would suffice | `access: full` with `protocol: rest` enables inspection but allows all HTTP methods and paths. | Use `access: read-only` or explicit `rules` to restrict what the agent can do at the HTTP level. 
| +| Omitting an inspected protocol on REST or WebSocket API endpoints | Without `protocol: rest` or `protocol: websocket`, the proxy uses L4-only enforcement. It allows the TCP stream through after checking host, port, and binary, but cannot inspect individual HTTP requests or WebSocket text messages. | Add `protocol: rest` or `protocol: websocket` with specific `rules` to enable method and path control. | +| Using `access: full` when finer rules would suffice | `access: full` with `protocol: rest` or `protocol: websocket` enables inspection but allows all methods and paths for that protocol. | Use `access: read-only` or explicit `rules` to restrict what the agent can do at the L7 level. | | Adding endpoints permanently when operator approval would suffice | Adding endpoints to the policy YAML makes them permanently reachable across all instances. | Use operator approval. Approved endpoints persist within the sandbox instance but reset on re-creation. | | Using broad binary globs | A glob like `/**` allows any binary to reach the endpoint, defeating binary-scoped enforcement. | Scope globs to specific directories (for example, `/sandbox/.vscode-server/**`). | | Skipping TLS termination on HTTPS APIs | Setting `tls: skip` disables credential injection and L7 inspection. | Use the default auto-detect behavior unless the upstream requires client-certificate mTLS. 
| From 4009a4a46585b5b50843834f42ddc3a2c5fafa08 Mon Sep 17 00:00:00 2001 From: Aaron Erickson Date: Fri, 8 May 2026 22:33:43 -0700 Subject: [PATCH 10/17] feat(policy): enable websocket credential rewrite updates Signed-off-by: Aaron Erickson --- crates/openshell-cli/src/main.rs | 7 + crates/openshell-cli/src/policy_update.rs | 98 ++++++++++++- crates/openshell-cli/src/run.rs | 2 + crates/openshell-policy/src/merge.rs | 45 ++++++ crates/openshell-sandbox/src/l7/relay.rs | 157 ++++++++++++++++++++- crates/openshell-server/src/grpc/policy.rs | 31 ++++ docs/sandboxes/policies.mdx | 6 +- 7 files changed, 342 insertions(+), 4 deletions(-) diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index ae71fbac0..14ab9811a 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -287,6 +287,7 @@ const POLICY_EXAMPLES: &str = "\x1b[1mALIAS\x1b[0m $ openshell policy get my-sandbox $ openshell policy set my-sandbox --policy policy.yaml $ openshell policy update my-sandbox --add-endpoint api.github.com:443:read-only:rest:enforce + $ openshell policy update my-sandbox --add-endpoint realtime.example.com:443:read-write:websocket:enforce --websocket-credential-rewrite $ openshell policy update my-sandbox --add-allow 'api.github.com:443:GET:/repos/**' $ openshell policy set --global --policy policy.yaml $ openshell policy delete --global @@ -1427,6 +1428,10 @@ enum PolicyCommands { #[arg(long = "binary", value_hint = ValueHint::FilePath)] binaries: Vec, + /// Rewrite credential placeholders in WebSocket client text frames for added REST/WebSocket endpoints. + #[arg(long = "websocket-credential-rewrite")] + websocket_credential_rewrite: bool, + /// Override the generated rule name when exactly one --add-endpoint is provided. 
#[arg(long = "rule-name")]
    rule_name: Option<String>,
@@ -1974,6 +1979,7 @@ async fn main() -> Result<()> {
                 add_deny,
                 remove_rules,
                 binaries,
+                websocket_credential_rewrite,
                 rule_name,
                 dry_run,
                 wait,
@@ -1989,6 +1995,7 @@ async fn main() -> Result<()> {
                 &add_allow,
                 &remove_rules,
                 &binaries,
+                websocket_credential_rewrite,
                 rule_name.as_deref(),
                 dry_run,
                 wait,
diff --git a/crates/openshell-cli/src/policy_update.rs b/crates/openshell-cli/src/policy_update.rs
index 6a17e5c32..f5caa9c00 100644
--- a/crates/openshell-cli/src/policy_update.rs
+++ b/crates/openshell-cli/src/policy_update.rs
@@ -18,6 +18,7 @@ pub struct PolicyUpdatePlan {
     pub preview_operations: Vec<PolicyMergeOp>,
 }
 
+#[allow(clippy::too_many_arguments)]
 pub fn build_policy_update_plan(
     add_endpoints: &[String],
@@ -25,6 +26,7 @@ pub fn build_policy_update_plan(
     add_allow: &[String],
     remove_rules: &[String],
     binaries: &[String],
+    websocket_credential_rewrite: bool,
     rule_name: Option<&str>,
 ) -> Result<PolicyUpdatePlan> {
     if binaries.iter().any(|binary| binary.trim().is_empty()) {
@@ -41,13 +43,22 @@ pub fn build_policy_update_plan(
             "--rule-name is only supported when exactly one --add-endpoint is provided"
         ));
     }
+    if websocket_credential_rewrite && add_endpoints.is_empty() {
+        return Err(miette!(
+            "--websocket-credential-rewrite can only be used with --add-endpoint"
+        ));
+    }
 
     let mut merge_operations = Vec::new();
     let mut preview_operations = Vec::new();
     let deduped_binaries = dedup_strings(binaries);
 
     for spec in add_endpoints {
-        let endpoint = parse_add_endpoint_spec(spec)?;
+        let mut endpoint = parse_add_endpoint_spec(spec)?;
+        if websocket_credential_rewrite {
+            ensure_websocket_credential_rewrite_protocol(spec, &endpoint)?;
+            endpoint.websocket_credential_rewrite = true;
+        }
         let target_rule_name = rule_name
             .map(str::trim)
             .filter(|name| !name.is_empty())
@@ -155,6 +166,23 @@ pub fn build_policy_update_plan(
     })
 }
 
+fn ensure_websocket_credential_rewrite_protocol(
+    spec: &str,
+    endpoint: &NetworkEndpoint,
+) -> 
Result<()> { + if matches!(endpoint.protocol.as_str(), "rest" | "websocket") { + return Ok(()); + } + let protocol = if endpoint.protocol.is_empty() { + "" + } else { + endpoint.protocol.as_str() + }; + Err(miette!( + "--websocket-credential-rewrite requires --add-endpoint protocol segment to be 'rest' or 'websocket'; got '{protocol}' in '{spec}'" + )) +} + fn group_allow_rules(specs: &[String]) -> Result>> { let mut grouped = BTreeMap::new(); for spec in specs { @@ -352,9 +380,32 @@ fn dedup_strings(values: &[String]) -> Vec { #[cfg(test)] mod tests { - use super::build_policy_update_plan; + use super::{ + PolicyUpdatePlan, build_policy_update_plan as build_policy_update_plan_with_options, + }; use openshell_policy::PolicyMergeOp; + fn build_policy_update_plan( + add_endpoints: &[String], + remove_endpoints: &[String], + add_deny: &[String], + add_allow: &[String], + remove_rules: &[String], + binaries: &[String], + rule_name: Option<&str>, + ) -> miette::Result { + build_policy_update_plan_with_options( + add_endpoints, + remove_endpoints, + add_deny, + add_allow, + remove_rules, + binaries, + false, + rule_name, + ) + } + #[test] fn parse_add_endpoint_basic_l4() { let plan = @@ -416,6 +467,49 @@ mod tests { assert_eq!(endpoint.enforcement, "enforce"); } + #[test] + fn parse_add_endpoint_enables_websocket_credential_rewrite() { + let plan = build_policy_update_plan_with_options( + &["realtime.example.com:443:read-write:websocket:enforce".to_string()], + &[], + &[], + &[], + &[], + &[], + true, + None, + ) + .expect("plan should build"); + + let PolicyMergeOp::AddRule { rule, .. 
} = &plan.preview_operations[0] else { + panic!("expected add-rule preview"); + }; + assert!(rule.endpoints[0].websocket_credential_rewrite); + } + + #[test] + fn websocket_credential_rewrite_requires_add_endpoint() { + let error = build_policy_update_plan_with_options(&[], &[], &[], &[], &[], &[], true, None) + .expect_err("plan should fail"); + assert!(error.to_string().contains("--websocket-credential-rewrite")); + } + + #[test] + fn websocket_credential_rewrite_rejects_l4_endpoint() { + let error = build_policy_update_plan_with_options( + &["realtime.example.com:443".to_string()], + &[], + &[], + &[], + &[], + &[], + true, + None, + ) + .expect_err("plan should fail"); + assert!(error.to_string().contains("protocol segment")); + } + #[test] fn parse_add_allow_accepts_websocket_text_method() { let plan = build_policy_update_plan( diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index 165713b6e..84a18983a 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -4705,6 +4705,7 @@ pub async fn sandbox_policy_update( add_allow: &[String], remove_rules: &[String], binaries: &[String], + websocket_credential_rewrite: bool, rule_name: Option<&str>, dry_run: bool, wait: bool, @@ -4722,6 +4723,7 @@ pub async fn sandbox_policy_update( add_allow, remove_rules, binaries, + websocket_credential_rewrite, rule_name, )?; diff --git a/crates/openshell-policy/src/merge.rs b/crates/openshell-policy/src/merge.rs index df5a953b5..5b102cdb8 100644 --- a/crates/openshell-policy/src/merge.rs +++ b/crates/openshell-policy/src/merge.rs @@ -875,6 +875,51 @@ mod tests { assert_eq!(rule.binaries.len(), 2); } + #[test] + fn add_rule_merges_websocket_credential_rewrite_flag() { + let mut policy = restrictive_default_policy(); + policy.network_policies.insert( + "existing".to_string(), + NetworkPolicyRule { + name: "existing".to_string(), + endpoints: vec![NetworkEndpoint { + host: "realtime.example.com".to_string(), + port: 443, + 
ports: vec![443], + protocol: "websocket".to_string(), + access: "read-write".to_string(), + ..Default::default() + }], + ..Default::default() + }, + ); + + let incoming = NetworkPolicyRule { + name: "incoming".to_string(), + endpoints: vec![NetworkEndpoint { + host: "realtime.example.com".to_string(), + port: 443, + ports: vec![443], + protocol: "websocket".to_string(), + websocket_credential_rewrite: true, + ..Default::default() + }], + ..Default::default() + }; + + let result = merge_policy( + policy, + &[PolicyMergeOp::AddRule { + rule_name: "allow_realtime_example_com_443".to_string(), + rule: incoming, + }], + ) + .expect("merge should succeed"); + + let endpoint = &result.policy.network_policies["existing"].endpoints[0]; + assert!(endpoint.websocket_credential_rewrite); + } + #[test] fn add_allow_expands_access_preset() { let mut policy = restrictive_default_policy(); diff --git a/crates/openshell-sandbox/src/l7/relay.rs b/crates/openshell-sandbox/src/l7/relay.rs index b80dd03c0..09d0a9ef8 100644 --- a/crates/openshell-sandbox/src/l7/relay.rs +++ b/crates/openshell-sandbox/src/l7/relay.rs @@ -1260,7 +1260,7 @@ mod tests { use super::*; use crate::opa::{NetworkInput, OpaEngine}; use std::path::PathBuf; - use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; const TEST_POLICY: &str = include_str!("../../data/sandbox-policy.rego"); @@ -1445,6 +1445,161 @@ network_policies: assert_eq!(n, 0, "invalid response must not forward 101 headers"); } + #[tokio::test] + async fn route_selected_websocket_rewrites_text_credentials_after_upgrade() { + let data = r#" +network_policies: + route_api: + name: route_api + endpoints: + - host: gateway.example.test + port: 443 + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/ws" + - allow: + method: WEBSOCKET_TEXT + path: "/ws" + websocket_credential_rewrite: true + binaries: + - { path: /usr/bin/node } +"#; + let engine = 
OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let tunnel_engine = engine + .clone_engine_for_tunnel(engine.current_generation()) + .unwrap(); + let configs = vec![L7EndpointConfig { + protocol: L7Protocol::Websocket, + path: "/ws".into(), + tls: crate::l7::TlsMode::Auto, + enforcement: EnforcementMode::Enforce, + graphql_max_body_bytes: 0, + allow_encoded_slash: false, + websocket_credential_rewrite: true, + }]; + let (child_env, resolver) = SecretResolver::from_provider_env( + std::iter::once(("DISCORD_BOT_TOKEN".to_string(), "real-token".to_string())).collect(), + ); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").expect("placeholder env"); + let ctx = L7EvalContext { + host: "gateway.example.test".into(), + port: 443, + policy_name: "route_api".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: resolver.map(Arc::new), + }; + + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_route_selection( + &configs, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + app.write_all( + b"GET /ws HTTP/1.1\r\nHost: gateway.example.test\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n", + ) + .await + .unwrap(); + + let mut forwarded = [0u8; 1024]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut forwarded), + ) + .await + .expect("upgrade request should reach upstream") + .unwrap(); + let forwarded = String::from_utf8_lossy(&forwarded[..n]); + assert!(forwarded.contains("Upgrade: websocket\r\n")); + + upstream + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n", + ) + .await + .unwrap(); + + let mut response = 
[0u8; 1024];
+        let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response))
+            .await
+            .expect("client should receive upgrade response")
+            .unwrap();
+        assert!(String::from_utf8_lossy(&response[..n]).contains("101 Switching Protocols"));
+
+        let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#);
+        app.write_all(&masked_text_frame(payload.as_bytes()))
+            .await
+            .unwrap();
+
+        let (masked, rewritten) = tokio::time::timeout(
+            std::time::Duration::from_secs(1),
+            read_text_frame(&mut upstream),
+        )
+        .await
+        .expect("rewritten websocket text should reach upstream")
+        .unwrap();
+        assert!(masked, "client-to-server frame must remain masked");
+        assert_eq!(rewritten, r#"{"op":2,"d":{"token":"real-token"}}"#);
+        assert!(!rewritten.contains(placeholder));
+
+        drop(app);
+        drop(upstream);
+        let _ = tokio::time::timeout(std::time::Duration::from_secs(1), relay).await;
+    }
+
+    fn masked_text_frame(payload: &[u8]) -> Vec<u8> {
+        let mask = [0x11, 0x22, 0x33, 0x44];
+        assert!(
+            payload.len() <= 125,
+            "test helper only supports small frames"
+        );
+        let payload_len = u8::try_from(payload.len()).expect("small frame length");
+        let mut frame = vec![0x81, 0x80 | payload_len];
+        frame.extend_from_slice(&mask);
+        frame.extend(
+            payload
+                .iter()
+                .enumerate()
+                .map(|(idx, byte)| byte ^ mask[idx % 4]),
+        );
+        frame
+    }
+
+    async fn read_text_frame<R: AsyncRead + Unpin>(
+        reader: &mut R,
+    ) -> std::io::Result<(bool, String)> {
+        let mut header = [0u8; 2];
+        reader.read_exact(&mut header).await?;
+        assert_eq!(header[0] & 0x0f, 0x1, "expected text frame");
+        let masked = header[1] & 0x80 != 0;
+        let payload_len = usize::from(header[1] & 0x7f);
+        assert!(payload_len <= 125, "test helper only supports small frames");
+        let mut mask = [0u8; 4];
+        if masked {
+            reader.read_exact(&mut mask).await?;
+        }
+        let mut payload = vec![0u8; payload_len];
+        reader.read_exact(&mut payload).await?;
+        if masked {
+            for (idx, byte) in payload.iter_mut().enumerate() {
+                *byte ^= 
mask[idx % 4]; + } + } + Ok((masked, String::from_utf8(payload).expect("text payload"))) + } + #[tokio::test] async fn l7_relay_closes_keep_alive_tunnel_after_policy_generation_change() { let initial_data = r#" diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index d5a47bcba..bdc8d8c1a 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -216,6 +216,9 @@ fn summarize_endpoint(endpoint: &NetworkEndpoint) -> String { if !endpoint.tls.is_empty() { parts.push(format!("tls={}", endpoint.tls)); } + if endpoint.websocket_credential_rewrite { + parts.push("websocket_credential_rewrite=true".to_string()); + } if !endpoint.allowed_ips.is_empty() { parts.push(format!("allowed_ips={}", endpoint.allowed_ips.len())); } @@ -4318,6 +4321,34 @@ mod tests { ); } + #[test] + fn summarize_cli_policy_merge_op_formats_websocket_credential_rewrite() { + let operation = PolicyMergeOp::AddRule { + rule_name: "realtime_api".to_string(), + rule: NetworkPolicyRule { + name: "realtime_api".to_string(), + endpoints: vec![NetworkEndpoint { + host: "realtime.example.com".to_string(), + port: 443, + protocol: "websocket".to_string(), + access: "read-write".to_string(), + enforcement: "enforce".to_string(), + websocket_credential_rewrite: true, + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/node".to_string(), + ..Default::default() + }], + }, + }; + + assert_eq!( + summarize_cli_policy_merge_op(&operation), + "add-endpoint realtime_api endpoints=[realtime.example.com:443 protocol=websocket access=read-write enforcement=enforce websocket_credential_rewrite=true] binaries=[/usr/bin/node]" + ); + } + // ---- merge_chunk_into_policy ---- #[tokio::test] diff --git a/docs/sandboxes/policies.mdx b/docs/sandboxes/policies.mdx index 491271c42..b29f60010 100644 --- a/docs/sandboxes/policies.mdx +++ b/docs/sandboxes/policies.mdx @@ -183,6 +183,7 @@ The incremental update 
surface is split into endpoint-level operations and metho | `--add-allow ` | Appends method/path allow rules to an existing REST or WebSocket endpoint. | Permit one additional REST method/path or WebSocket `WEBSOCKET_TEXT` path on an API that is already configured. | | `--add-deny ` | Appends method/path deny rules to an existing REST or WebSocket endpoint. | Block a sensitive REST path or WebSocket text-message path under an endpoint that is otherwise allowed. | | `--binary ` | Adds binaries to every `--add-endpoint` rule in the same command. | Bind a new endpoint to one or more executables. | +| `--websocket-credential-rewrite` | Enables WebSocket text-frame credential placeholder rewriting on every `--add-endpoint` in the same command. | Let sandbox code send `openshell:resolve:env:*` placeholders in client WebSocket text frames while OpenShell resolves them at the relay boundary. | | `--rule-name ` | Overrides the generated rule name. | Keep a stable human-chosen rule name when adding exactly one endpoint. | | `--dry-run` | Shows the merged policy locally and does not call the gateway. | Review the result before persisting it. | | `--wait` | Polls until the sandbox reports that the new revision loaded. | Confirm the change took effect before continuing. | @@ -237,6 +238,8 @@ Examples: If you set `protocol: rest` or `protocol: websocket`, you also need an allow shape. With incremental updates, that means you should provide an `access` preset on `--add-endpoint`, then use `--add-allow` or `--add-deny` to refine method/path rules later. +Use `--websocket-credential-rewrite` with `protocol: websocket` when the sandbox should send credential placeholders in client text frames and have OpenShell resolve them after the allowed upgrade. The flag can also be used with `protocol: rest` compatibility endpoints that perform a WebSocket upgrade. It applies to every `--add-endpoint` in the same command and is rejected for plain L4 or `protocol: sql` endpoints. 
+ For example: - `api.github.com:443:read-only:rest` is valid. @@ -351,11 +354,12 @@ Use `--add-endpoint` with `protocol: websocket` when the destination is an RFC 6 ```shell openshell policy update demo \ --add-endpoint realtime.example.com:443:read-write:websocket:enforce \ + --websocket-credential-rewrite \ --binary /usr/bin/node \ --wait ``` -This creates a WebSocket endpoint and sets its base allow behavior through the `read-write` access preset. For WebSocket endpoints, `read-write` expands to the upgrade `GET` and client `WEBSOCKET_TEXT` messages on the upgraded request path. +This creates a WebSocket endpoint and sets its base allow behavior through the `read-write` access preset. For WebSocket endpoints, `read-write` expands to the upgrade `GET` and client `WEBSOCKET_TEXT` messages on the upgraded request path. The rewrite flag lets the sandbox send `openshell:resolve:env:*` placeholders in client text frames; OpenShell resolves them before forwarding to the upstream service. #### Add a WebSocket text-message deny rule From 1562b70677555aa3a6fc24629d05be4711d98931 Mon Sep 17 00:00:00 2001 From: Aaron Erickson Date: Fri, 8 May 2026 22:57:17 -0700 Subject: [PATCH 11/17] fix(cli): make websocket rewrite endpoint-local Signed-off-by: Aaron Erickson --- crates/openshell-cli/src/main.rs | 10 +- crates/openshell-cli/src/policy_update.rs | 113 ++++++++++++++++------ crates/openshell-cli/src/run.rs | 2 - docs/sandboxes/policies.mdx | 16 +-- 4 files changed, 95 insertions(+), 46 deletions(-) diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index 14ab9811a..94951c6c7 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -287,7 +287,7 @@ const POLICY_EXAMPLES: &str = "\x1b[1mALIAS\x1b[0m $ openshell policy get my-sandbox $ openshell policy set my-sandbox --policy policy.yaml $ openshell policy update my-sandbox --add-endpoint api.github.com:443:read-only:rest:enforce - $ openshell policy update 
my-sandbox --add-endpoint realtime.example.com:443:read-write:websocket:enforce --websocket-credential-rewrite + $ openshell policy update my-sandbox --add-endpoint realtime.example.com:443:read-write:websocket:enforce:websocket-credential-rewrite $ openshell policy update my-sandbox --add-allow 'api.github.com:443:GET:/repos/**' $ openshell policy set --global --policy policy.yaml $ openshell policy delete --global @@ -1404,7 +1404,7 @@ enum PolicyCommands { #[arg(add = ArgValueCompleter::new(completers::complete_sandbox_names))] name: Option, - /// Add or merge an endpoint: host:port[:access[:protocol[:enforcement]]]. + /// Add or merge an endpoint: host:port[:access[:protocol[:enforcement[:options]]]]. #[arg(long = "add-endpoint")] add_endpoints: Vec, @@ -1428,10 +1428,6 @@ enum PolicyCommands { #[arg(long = "binary", value_hint = ValueHint::FilePath)] binaries: Vec, - /// Rewrite credential placeholders in WebSocket client text frames for added REST/WebSocket endpoints. - #[arg(long = "websocket-credential-rewrite")] - websocket_credential_rewrite: bool, - /// Override the generated rule name when exactly one --add-endpoint is provided. 
#[arg(long = "rule-name")] rule_name: Option, @@ -1979,7 +1975,6 @@ async fn main() -> Result<()> { add_deny, remove_rules, binaries, - websocket_credential_rewrite, rule_name, dry_run, wait, @@ -1995,7 +1990,6 @@ async fn main() -> Result<()> { &add_allow, &remove_rules, &binaries, - websocket_credential_rewrite, rule_name.as_deref(), dry_run, wait, diff --git a/crates/openshell-cli/src/policy_update.rs b/crates/openshell-cli/src/policy_update.rs index f5caa9c00..8fe4db269 100644 --- a/crates/openshell-cli/src/policy_update.rs +++ b/crates/openshell-cli/src/policy_update.rs @@ -26,7 +26,6 @@ pub fn build_policy_update_plan( add_allow: &[String], remove_rules: &[String], binaries: &[String], - websocket_credential_rewrite: bool, rule_name: Option<&str>, ) -> Result { if binaries.iter().any(|binary| binary.trim().is_empty()) { @@ -43,22 +42,12 @@ pub fn build_policy_update_plan( "--rule-name is only supported when exactly one --add-endpoint is provided" )); } - if websocket_credential_rewrite && add_endpoints.is_empty() { - return Err(miette!( - "--websocket-credential-rewrite can only be used with --add-endpoint" - )); - } - let mut merge_operations = Vec::new(); let mut preview_operations = Vec::new(); let deduped_binaries = dedup_strings(binaries); for spec in add_endpoints { - let mut endpoint = parse_add_endpoint_spec(spec)?; - if websocket_credential_rewrite { - ensure_websocket_credential_rewrite_protocol(spec, &endpoint)?; - endpoint.websocket_credential_rewrite = true; - } + let endpoint = parse_add_endpoint_spec(spec)?; let target_rule_name = rule_name .map(str::trim) .filter(|name| !name.is_empty()) @@ -179,7 +168,7 @@ fn ensure_websocket_credential_rewrite_protocol( endpoint.protocol.as_str() }; Err(miette!( - "--websocket-credential-rewrite requires --add-endpoint protocol segment to be 'rest' or 'websocket'; got '{protocol}' in '{spec}'" + "websocket-credential-rewrite endpoint option requires --add-endpoint protocol segment to be 'rest' or 
'websocket'; got '{protocol}' in '{spec}'" )) } @@ -285,9 +274,9 @@ fn parse_remove_endpoint_spec(spec: &str) -> Result<(String, u32)> { fn parse_add_endpoint_spec(spec: &str) -> Result { let parts = spec.split(':').collect::>(); - if !(2..=5).contains(&parts.len()) { + if !(2..=6).contains(&parts.len()) { return Err(miette!( - "--add-endpoint expects host:port[:access[:protocol[:enforcement]]], got '{spec}'" + "--add-endpoint expects host:port[:access[:protocol[:enforcement[:options]]]], got '{spec}'" )); } @@ -297,12 +286,18 @@ fn parse_add_endpoint_spec(spec: &str) -> Result { let access = parts.get(2).copied().unwrap_or("").trim(); let protocol = parts.get(3).copied().unwrap_or("").trim(); let enforcement = parts.get(4).copied().unwrap_or("").trim(); + let options = parts.get(5).copied().unwrap_or("").trim(); if parts.len() == 3 && access.is_empty() { return Err(miette!( "--add-endpoint has an empty access segment in '{spec}'; omit it entirely if you do not need access or protocol fields" )); } + if parts.len() == 6 && options.is_empty() { + return Err(miette!( + "--add-endpoint has an empty options segment in '{spec}'; omit it entirely if you do not need endpoint options" + )); + } if !enforcement.is_empty() && protocol.is_empty() { return Err(miette!( "--add-endpoint cannot set enforcement without protocol in '{spec}'" @@ -324,7 +319,7 @@ fn parse_add_endpoint_spec(spec: &str) -> Result { )); } - Ok(NetworkEndpoint { + let mut endpoint = NetworkEndpoint { host, port, ports: vec![port], @@ -332,7 +327,41 @@ fn parse_add_endpoint_spec(spec: &str) -> Result { enforcement: enforcement.to_string(), access: access.to_string(), ..Default::default() - }) + }; + apply_add_endpoint_options(spec, &mut endpoint, options)?; + Ok(endpoint) +} + +fn apply_add_endpoint_options( + spec: &str, + endpoint: &mut NetworkEndpoint, + options: &str, +) -> Result<()> { + if options.is_empty() { + return Ok(()); + } + + for option in options.split(',') { + let option = option.trim(); 
+ if option.is_empty() { + return Err(miette!( + "--add-endpoint options segment must not contain empty options in '{spec}'" + )); + } + match option { + "websocket-credential-rewrite" => { + ensure_websocket_credential_rewrite_protocol(spec, endpoint)?; + endpoint.websocket_credential_rewrite = true; + } + _ => { + return Err(miette!( + "--add-endpoint options segment supports only 'websocket-credential-rewrite'; got '{option}' in '{spec}'" + )); + } + } + } + + Ok(()) } fn parse_host(flag: &str, spec: &str, host: &str) -> Result { @@ -401,7 +430,6 @@ mod tests { add_allow, remove_rules, binaries, - false, rule_name, ) } @@ -469,14 +497,14 @@ mod tests { #[test] fn parse_add_endpoint_enables_websocket_credential_rewrite() { - let plan = build_policy_update_plan_with_options( - &["realtime.example.com:443:read-write:websocket:enforce".to_string()], + let plan = build_policy_update_plan( + &["realtime.example.com:443:read-write:websocket:enforce:websocket-credential-rewrite" + .to_string()], &[], &[], &[], &[], &[], - true, None, ) .expect("plan should build"); @@ -488,28 +516,57 @@ mod tests { } #[test] - fn websocket_credential_rewrite_requires_add_endpoint() { - let error = build_policy_update_plan_with_options(&[], &[], &[], &[], &[], &[], true, None) - .expect_err("plan should fail"); - assert!(error.to_string().contains("--websocket-credential-rewrite")); + fn parse_add_endpoint_enables_websocket_credential_rewrite_on_rest_compat_endpoint() { + let plan = build_policy_update_plan( + &[ + "realtime.example.com:443:read-write:rest:enforce:websocket-credential-rewrite" + .to_string(), + ], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect("plan should build"); + + let PolicyMergeOp::AddRule { rule, .. 
} = &plan.preview_operations[0] else { + panic!("expected add-rule preview"); + }; + assert!(rule.endpoints[0].websocket_credential_rewrite); } #[test] fn websocket_credential_rewrite_rejects_l4_endpoint() { - let error = build_policy_update_plan_with_options( - &["realtime.example.com:443".to_string()], + let error = build_policy_update_plan( + &["realtime.example.com:443::::websocket-credential-rewrite".to_string()], &[], &[], &[], &[], &[], - true, None, ) .expect_err("plan should fail"); assert!(error.to_string().contains("protocol segment")); } + #[test] + fn parse_add_endpoint_rejects_unknown_options() { + let error = build_policy_update_plan( + &["realtime.example.com:443:read-write:websocket:enforce:future-option".to_string()], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect_err("plan should fail"); + assert!(error.to_string().contains("options segment")); + } + #[test] fn parse_add_allow_accepts_websocket_text_method() { let plan = build_policy_update_plan( diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index 84a18983a..165713b6e 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -4705,7 +4705,6 @@ pub async fn sandbox_policy_update( add_allow: &[String], remove_rules: &[String], binaries: &[String], - websocket_credential_rewrite: bool, rule_name: Option<&str>, dry_run: bool, wait: bool, @@ -4723,7 +4722,6 @@ pub async fn sandbox_policy_update( add_allow, remove_rules, binaries, - websocket_credential_rewrite, rule_name, )?; diff --git a/docs/sandboxes/policies.mdx b/docs/sandboxes/policies.mdx index b29f60010..eed8d59ab 100644 --- a/docs/sandboxes/policies.mdx +++ b/docs/sandboxes/policies.mdx @@ -177,13 +177,12 @@ The incremental update surface is split into endpoint-level operations and metho | Flag | What it changes | Typical use | |---|---|---| -| `--add-endpoint ` | Creates or merges a network rule and endpoint. 
| Allow a new host and port, optionally with `access`, `protocol`, `enforcement`, and binaries. | +| `--add-endpoint ` | Creates or merges a network rule and endpoint. | Allow a new host and port, optionally with `access`, `protocol`, `enforcement`, endpoint options, and binaries. | | `--remove-endpoint ` | Removes one host and port match from the current policy. | Drop a stale endpoint or remove one port from a multi-port endpoint. | | `--remove-rule ` | Deletes a named `network_policies` entry. | Remove a whole rule by name when you no longer need it. | | `--add-allow ` | Appends method/path allow rules to an existing REST or WebSocket endpoint. | Permit one additional REST method/path or WebSocket `WEBSOCKET_TEXT` path on an API that is already configured. | | `--add-deny ` | Appends method/path deny rules to an existing REST or WebSocket endpoint. | Block a sensitive REST path or WebSocket text-message path under an endpoint that is otherwise allowed. | | `--binary ` | Adds binaries to every `--add-endpoint` rule in the same command. | Bind a new endpoint to one or more executables. | -| `--websocket-credential-rewrite` | Enables WebSocket text-frame credential placeholder rewriting on every `--add-endpoint` in the same command. | Let sandbox code send `openshell:resolve:env:*` placeholders in client WebSocket text frames while OpenShell resolves them at the relay boundary. | | `--rule-name ` | Overrides the generated rule name. | Keep a stable human-chosen rule name when adding exactly one endpoint. | | `--dry-run` | Shows the merged policy locally and does not call the gateway. | Review the result before persisting it. | | `--wait` | Polls until the sandbox reports that the new revision loaded. | Confirm the change took effect before continuing. 
| @@ -215,7 +214,7 @@ Current constraints: `--add-endpoint` uses this format: ```text -host:port[:access[:protocol[:enforcement]]] +host:port[:access[:protocol[:enforcement[:options]]]] ``` Each segment has a fixed meaning: @@ -227,6 +226,7 @@ Each segment has a fixed meaning: | `access` | No | Access preset for L7 endpoints: `read-only`, `read-write`, or `full`. Incremental updates expand presets into protocol-specific method/path rules for REST and WebSocket endpoints. | | `protocol` | No | L7 inspection mode: `rest`, `websocket`, or `sql`. `sql` is audit-only and not a recommended workflow today. | | `enforcement` | No | Enforcement mode for inspected traffic: `enforce` or `audit`. | +| `options` | No | Comma-separated endpoint options. Currently only `websocket-credential-rewrite` is supported, and only for `protocol: websocket` or REST compatibility endpoints that perform a WebSocket upgrade. | Examples: @@ -235,10 +235,11 @@ Examples: | `pypi.org:443` | Add a plain L4 endpoint. The proxy allows the TCP stream and does not inspect HTTP requests. | | `api.github.com:443:read-only:rest:enforce` | Add a REST endpoint with the `read-only` preset expanded by the policy engine into GET, HEAD, and OPTIONS access. | | `realtime.example.com:443:read-write:websocket:enforce` | Add a WebSocket endpoint with the `read-write` preset expanded by the policy engine into the upgrade `GET` and client `WEBSOCKET_TEXT` access. | +| `realtime.example.com:443:read-write:websocket:enforce:websocket-credential-rewrite` | Add a WebSocket endpoint that rewrites `openshell:resolve:env:*` placeholders in client text frames after an allowed upgrade. | If you set `protocol: rest` or `protocol: websocket`, you also need an allow shape. With incremental updates, that means you should provide an `access` preset on `--add-endpoint`, then use `--add-allow` or `--add-deny` to refine method/path rules later. 
-Use `--websocket-credential-rewrite` with `protocol: websocket` when the sandbox should send credential placeholders in client text frames and have OpenShell resolve them after the allowed upgrade. The flag can also be used with `protocol: rest` compatibility endpoints that perform a WebSocket upgrade. It applies to every `--add-endpoint` in the same command and is rejected for plain L4 or `protocol: sql` endpoints. +Use the `websocket-credential-rewrite` endpoint option with `protocol: websocket` when the sandbox should send credential placeholders in client text frames and have OpenShell resolve them after the allowed upgrade. The option can also be used with `protocol: rest` compatibility endpoints that perform a WebSocket upgrade. It is rejected for plain L4 or `protocol: sql` endpoints. For example: @@ -246,7 +247,7 @@ For example: - `realtime.example.com:443:read-write:websocket` is valid. - `api.github.com:443::rest` is invalid. It does not mean "allow all traffic." An L7 endpoint with `protocol` but no `access` or `rules` is rejected when the policy loads. -When you pass multiple `--add-endpoint` flags in one command, every `--binary` value applies to every added endpoint in that command. If different endpoints need different binaries, use separate `policy update` commands. +Endpoint options belong to the individual `--add-endpoint` spec. When you pass multiple `--add-endpoint` flags in one command, every `--binary` value applies to every added endpoint in that command. If different endpoints need different binaries, use separate `policy update` commands. If you do not pass `--rule-name`, OpenShell generates one from the host and port, such as `allow_api_github_com_443`. 
@@ -353,13 +354,12 @@ Use `--add-endpoint` with `protocol: websocket` when the destination is an RFC 6 ```shell openshell policy update demo \ - --add-endpoint realtime.example.com:443:read-write:websocket:enforce \ - --websocket-credential-rewrite \ + --add-endpoint realtime.example.com:443:read-write:websocket:enforce:websocket-credential-rewrite \ --binary /usr/bin/node \ --wait ``` -This creates a WebSocket endpoint and sets its base allow behavior through the `read-write` access preset. For WebSocket endpoints, `read-write` expands to the upgrade `GET` and client `WEBSOCKET_TEXT` messages on the upgraded request path. The rewrite flag lets the sandbox send `openshell:resolve:env:*` placeholders in client text frames; OpenShell resolves them before forwarding to the upstream service. +This creates a WebSocket endpoint and sets its base allow behavior through the `read-write` access preset. For WebSocket endpoints, `read-write` expands to the upgrade `GET` and client `WEBSOCKET_TEXT` messages on the upgraded request path. The rewrite option lets the sandbox send `openshell:resolve:env:*` placeholders in client text frames; OpenShell resolves them before forwarding to the upstream service. 
#### Add a WebSocket text-message deny rule From 55f52e73b0194aa68330c224429efb1f919bc142 Mon Sep 17 00:00:00 2001 From: Aaron Erickson Date: Sat, 9 May 2026 06:33:48 -0700 Subject: [PATCH 12/17] feat(sandbox): support graphql websocket policy Signed-off-by: Aaron Erickson --- .../data/sandbox-policy.rego | 10 + crates/openshell-sandbox/src/l7/graphql.rs | 13 + crates/openshell-sandbox/src/l7/mod.rs | 276 ++++++++++- crates/openshell-sandbox/src/l7/relay.rs | 161 ++++++- crates/openshell-sandbox/src/l7/websocket.rs | 455 +++++++++++++++++- crates/openshell-sandbox/src/opa.rs | 132 +++++ crates/openshell-sandbox/src/proxy.rs | 2 + docs/reference/policy-schema.mdx | 31 +- docs/sandboxes/policies.mdx | 34 +- docs/security/best-practices.mdx | 4 +- 10 files changed, 1086 insertions(+), 32 deletions(-) diff --git a/crates/openshell-sandbox/data/sandbox-policy.rego b/crates/openshell-sandbox/data/sandbox-policy.rego index 9fa820627..a8e4affce 100644 --- a/crates/openshell-sandbox/data/sandbox-policy.rego +++ b/crates/openshell-sandbox/data/sandbox-policy.rego @@ -260,6 +260,16 @@ request_denied_for_endpoint(request, endpoint) if { not graphql_request_allowed(request, endpoint) } +# The same authority applies when a WebSocket endpoint opts into GraphQL +# operation policy. Once the relay classifies a client text message as a +# GraphQL-over-WebSocket operation, generic WEBSOCKET_TEXT rules must not bypass +# operation_type / operation_name / fields policy. +request_denied_for_endpoint(request, endpoint) if { + endpoint.protocol == "websocket" + is_object(request.graphql) + not graphql_request_allowed(request, endpoint) +} + # Deny query matching: fail-closed semantics. # If no query rules on the deny rule, match unconditionally (any query params). 
# If query rules present, trigger the deny if ANY value for a configured key diff --git a/crates/openshell-sandbox/src/l7/graphql.rs b/crates/openshell-sandbox/src/l7/graphql.rs index db91ecb45..5d0746d01 100644 --- a/crates/openshell-sandbox/src/l7/graphql.rs +++ b/crates/openshell-sandbox/src/l7/graphql.rs @@ -78,6 +78,19 @@ pub fn classify_request(request: &L7Request, body: &[u8]) -> GraphqlRequestInfo } } +pub fn classify_json_envelope_value(value: &Value) -> GraphqlRequestInfo { + match classify_json_envelope(value) { + Ok(operations) => GraphqlRequestInfo { + operations, + error: None, + }, + Err(err) => GraphqlRequestInfo { + operations: Vec::new(), + error: Some(err), + }, + } +} + fn classify_request_inner( request: &L7Request, body: &[u8], diff --git a/crates/openshell-sandbox/src/l7/mod.rs b/crates/openshell-sandbox/src/l7/mod.rs index e553a9e05..e1dc74b21 100644 --- a/crates/openshell-sandbox/src/l7/mod.rs +++ b/crates/openshell-sandbox/src/l7/mod.rs @@ -78,6 +78,9 @@ pub struct L7EndpointConfig { /// Opt-in rewrite of credential placeholders in client-to-server /// WebSocket text messages after an allowed HTTP 101 upgrade. pub websocket_credential_rewrite: bool, + /// When true, client-to-server GraphQL-over-WebSocket operation messages + /// are classified with the same operation policy used by GraphQL-over-HTTP. + pub websocket_graphql_policy: bool, } /// Result of an L7 policy decision for a single request. 
@@ -146,6 +149,8 @@ pub fn parse_l7_config(val: ®orus::Value) -> Option { let allow_encoded_slash = get_object_bool(val, "allow_encoded_slash").unwrap_or(false); let websocket_credential_rewrite = get_object_bool(val, "websocket_credential_rewrite").unwrap_or(false); + let websocket_graphql_policy = + protocol == L7Protocol::Websocket && endpoint_has_graphql_policy(val); let graphql_max_body_bytes = get_object_u64(val, "graphql_max_body_bytes") .and_then(|v| usize::try_from(v).ok()) .filter(|v| *v > 0) @@ -159,6 +164,7 @@ pub fn parse_l7_config(val: ®orus::Value) -> Option { graphql_max_body_bytes, allow_encoded_slash, websocket_credential_rewrite, + websocket_graphql_policy, }) } @@ -240,6 +246,60 @@ fn get_object_str(val: ®orus::Value, key: &str) -> Option { } } +fn endpoint_has_graphql_policy(val: ®orus::Value) -> bool { + has_non_empty_object_field(val, "graphql_persisted_queries") + || has_graphql_persisted_query_mode(val) + || rules_have_graphql_policy(val, "rules", true) + || rules_have_graphql_policy(val, "deny_rules", false) +} + +fn rules_have_graphql_policy(val: ®orus::Value, key: &str, allow_wrapped: bool) -> bool { + let Some(regorus::Value::Array(rules)) = get_object_value(val, key) else { + return false; + }; + rules.iter().any(|rule| { + let rule = if allow_wrapped { + get_object_value(rule, "allow").unwrap_or(rule) + } else { + rule + }; + has_graphql_rule_fields(rule) + }) +} + +fn has_graphql_rule_fields(val: ®orus::Value) -> bool { + has_non_empty_string_field(val, "operation_type") + || has_non_empty_string_field(val, "operation_name") + || has_non_empty_array_field(val, "fields") +} + +fn has_non_empty_string_field(val: ®orus::Value, key: &str) -> bool { + matches!(get_object_value(val, key), Some(regorus::Value::String(s)) if !s.is_empty()) +} + +fn has_non_empty_array_field(val: ®orus::Value, key: &str) -> bool { + matches!(get_object_value(val, key), Some(regorus::Value::Array(values)) if !values.is_empty()) +} + +fn 
has_non_empty_object_field(val: ®orus::Value, key: &str) -> bool { + matches!(get_object_value(val, key), Some(regorus::Value::Object(values)) if !values.is_empty()) +} + +fn has_graphql_persisted_query_mode(val: ®orus::Value) -> bool { + matches!( + get_object_value(val, "persisted_queries"), + Some(regorus::Value::String(mode)) if !mode.is_empty() && mode.as_ref() != "deny" + ) +} + +fn get_object_value<'a>(val: &'a regorus::Value, key: &str) -> Option<&'a regorus::Value> { + let key_val = regorus::Value::String(key.into()); + match val { + regorus::Value::Object(map) => map.get(&key_val), + _ => None, + } +} + /// Check a glob pattern for obvious syntax issues. /// /// Returns `Some(warning_message)` if the pattern looks malformed. @@ -362,6 +422,45 @@ fn validate_graphql_rule( validate_graphql_fields(errors, warnings, loc, rule.get("fields")); } +fn json_rule_has_graphql_fields(rule: &serde_json::Value) -> bool { + rule.get("operation_type") + .and_then(|v| v.as_str()) + .is_some_and(|v| !v.is_empty()) + || rule + .get("operation_name") + .and_then(|v| v.as_str()) + .is_some_and(|v| !v.is_empty()) + || rule.get("fields").is_some() +} + +fn json_rule_has_transport_fields(rule: &serde_json::Value) -> bool { + rule.get("method").is_some() || rule.get("path").is_some() || rule.get("query").is_some() +} + +fn json_endpoint_has_graphql_policy(ep: &serde_json::Value) -> bool { + ep.get("graphql_persisted_queries") + .and_then(|v| v.as_object()) + .is_some_and(|v| !v.is_empty()) + || ep + .get("persisted_queries") + .and_then(|v| v.as_str()) + .is_some_and(|v| !v.is_empty() && v != "deny") + || ep + .get("rules") + .and_then(|v| v.as_array()) + .is_some_and(|rules| { + rules.iter().any(|rule| { + rule.get("allow") + .or(Some(rule)) + .is_some_and(json_rule_has_graphql_fields) + }) + }) + || ep + .get("deny_rules") + .and_then(|v| v.as_array()) + .is_some_and(|rules| rules.iter().any(json_rule_has_graphql_fields)) +} + /// Validate L7 policy configuration in the loaded 
OPA data. /// /// Returns a list of errors and warnings. Errors should prevent sandbox startup; @@ -391,6 +490,8 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< .get("rules") .and_then(|v| v.as_array()) .is_some_and(|a| !a.is_empty()); + let websocket_has_graphql_policy = + protocol == "websocket" && json_endpoint_has_graphql_policy(ep); let host = ep.get("host").and_then(|v| v.as_str()).unwrap_or(""); let endpoint_path = ep.get("path").and_then(|v| v.as_str()).unwrap_or(""); @@ -498,12 +599,13 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< } if protocol != "graphql" + && protocol != "websocket" && (ep.get("persisted_queries").is_some() || ep.get("graphql_persisted_queries").is_some() || ep.get("graphql_max_body_bytes").is_some()) { warnings.push(format!( - "{loc}: GraphQL-specific endpoint fields are ignored unless protocol is graphql" + "{loc}: GraphQL-specific endpoint fields are ignored unless protocol is graphql or websocket" )); } @@ -721,7 +823,17 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< .push(format!("{deny_loc}: command is for SQL protocol, not REST")); } - if protocol == "graphql" { + let deny_has_graphql = json_rule_has_graphql_fields(deny_rule); + if protocol == "websocket" + && deny_has_graphql + && json_rule_has_transport_fields(deny_rule) + { + errors.push(format!( + "{deny_loc}: WebSocket GraphQL deny rules must not combine method/path/query with operation_type/operation_name/fields" + )); + } + + if protocol == "graphql" || (protocol == "websocket" && deny_has_graphql) { validate_graphql_rule( &mut errors, &mut warnings, @@ -729,12 +841,9 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< deny_rule, true, ); - } else if deny_rule.get("operation_type").is_some() - || deny_rule.get("operation_name").is_some() - || deny_rule.get("fields").is_some() - { + } else if deny_has_graphql { warnings.push(format!( - "{deny_loc}: GraphQL rule 
fields are ignored unless protocol is graphql" + "{deny_loc}: GraphQL rule fields are ignored unless protocol is graphql or websocket" )); } } @@ -877,14 +986,36 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< } } - if has_rules - && protocol == "graphql" - && let Some(rules) = ep.get("rules").and_then(|v| v.as_array()) - { + if has_rules && let Some(rules) = ep.get("rules").and_then(|v| v.as_array()) { for (rule_idx, rule) in rules.iter().enumerate() { let allow = rule.get("allow").unwrap_or(rule); let rule_loc = format!("{loc}.rules[{rule_idx}].allow"); - validate_graphql_rule(&mut errors, &mut warnings, &rule_loc, allow, true); + let allow_has_graphql = json_rule_has_graphql_fields(allow); + if websocket_has_graphql_policy + && allow + .get("method") + .and_then(|m| m.as_str()) + .is_some_and(|method| method.eq_ignore_ascii_case("WEBSOCKET_TEXT")) + { + errors.push(format!( + "{rule_loc}: WebSocket endpoints with GraphQL operation policy must use operation_type/operation_name/fields rules for client messages instead of WEBSOCKET_TEXT" + )); + } + if protocol == "websocket" + && allow_has_graphql + && json_rule_has_transport_fields(allow) + { + errors.push(format!( + "{rule_loc}: WebSocket GraphQL allow rules must not combine method/path/query with operation_type/operation_name/fields" + )); + } + if protocol == "graphql" || (protocol == "websocket" && allow_has_graphql) { + validate_graphql_rule(&mut errors, &mut warnings, &rule_loc, allow, true); + } else if allow_has_graphql { + warnings.push(format!( + "{rule_loc}: GraphQL rule fields are ignored unless protocol is graphql or websocket" + )); + } } } } @@ -1096,6 +1227,26 @@ mod tests { assert!(config.websocket_credential_rewrite); } + #[test] + fn parse_l7_config_websocket_graphql_policy_defaults_false() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "websocket", "host": "gateway.example.com", "port": 443, "rules": [{"allow": {"method": "GET", "path": "/graphql"}}, 
{"allow": {"method": "WEBSOCKET_TEXT", "path": "/graphql"}}]}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(!config.websocket_graphql_policy); + } + + #[test] + fn parse_l7_config_websocket_graphql_policy_detects_operation_rules() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "websocket", "host": "gateway.example.com", "port": 443, "rules": [{"allow": {"method": "GET", "path": "/graphql"}}, {"allow": {"operation_type": "subscription", "fields": ["messageAdded"]}}]}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(config.websocket_graphql_policy); + } + #[test] fn validate_websocket_credential_rewrite_warns_unless_rest_or_websocket() { let data = serde_json::json!({ @@ -1147,6 +1298,107 @@ mod tests { assert!(methods.contains(&"WEBSOCKET_TEXT")); } + #[test] + fn validate_websocket_accepts_graphql_operation_rules() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "protocol": "websocket", + "rules": [ + {"allow": {"method": "GET", "path": "/graphql"}}, + {"allow": {"operation_type": "subscription", "fields": ["messageAdded"]}} + ] + }], + "binaries": [] + } + } + }); + let (errors, warnings) = validate_l7_policies(&data); + assert!(errors.is_empty(), "expected no errors: {errors:?}"); + assert!(warnings.is_empty(), "expected no warnings: {warnings:?}"); + } + + #[test] + fn validate_websocket_graphql_rule_requires_operation_type() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "protocol": "websocket", + "rules": [ + {"allow": {"method": "GET", "path": "/graphql"}}, + {"allow": {"fields": ["messageAdded"]}} + ] + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors.iter().any(|e| e.contains("operation_type")), + "expected missing operation_type error: 
{errors:?}" + ); + } + + #[test] + fn validate_websocket_graphql_rule_rejects_mixed_transport_fields() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "protocol": "websocket", + "rules": [ + {"allow": {"method": "GET", "path": "/graphql"}}, + {"allow": {"method": "WEBSOCKET_TEXT", "path": "/graphql", "operation_type": "subscription"}} + ] + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors.iter().any(|e| e.contains("must not combine")), + "expected mixed-field error: {errors:?}" + ); + } + + #[test] + fn validate_websocket_graphql_policy_rejects_raw_text_message_rule() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "protocol": "websocket", + "rules": [ + {"allow": {"method": "GET", "path": "/graphql"}}, + {"allow": {"method": "WEBSOCKET_TEXT", "path": "/graphql"}}, + {"allow": {"operation_type": "query"}} + ] + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors + .iter() + .any(|e| e.contains("instead of WEBSOCKET_TEXT")), + "expected raw WEBSOCKET_TEXT rejection: {errors:?}" + ); + } + #[test] fn validate_rules_and_access_mutual_exclusion() { let data = serde_json::json!({ diff --git a/crates/openshell-sandbox/src/l7/relay.rs b/crates/openshell-sandbox/src/l7/relay.rs index 09d0a9ef8..4014bd273 100644 --- a/crates/openshell-sandbox/src/l7/relay.rs +++ b/crates/openshell-sandbox/src/l7/relay.rs @@ -55,10 +55,28 @@ pub(crate) struct UpgradeRelayOptions<'a> { #[derive(Default)] pub(crate) struct WebSocketUpgradeBehavior { pub(crate) credential_rewrite: bool, - pub(crate) message_inspection: bool, + pub(crate) message_policy: WebSocketMessagePolicy, pub(crate) permessage_deflate: bool, } +#[derive(Clone, Copy, Default, PartialEq, Eq)] +pub(crate) enum 
WebSocketMessagePolicy { + #[default] + None, + Transport, + Graphql, +} + +impl WebSocketMessagePolicy { + fn inspects_messages(self) -> bool { + self != Self::None + } + + fn is_graphql(self) -> bool { + self == Self::Graphql + } +} + #[derive(Debug, Clone, Copy)] enum ParseRejectionMode { L7Endpoint, @@ -448,7 +466,7 @@ where U: AsyncRead + AsyncWrite + Unpin + Send, { let use_websocket_relay = options.websocket_request - && (options.websocket.message_inspection + && (options.websocket.message_policy.inspects_messages() || options.websocket.permessage_deflate || (options.websocket.credential_rewrite && options.secret_resolver.is_some())); let relay_mode = if use_websocket_relay { @@ -474,7 +492,7 @@ where } else { None }; - let inspector = if options.websocket.message_inspection { + let inspector = if options.websocket.message_policy.inspects_messages() { match (options.engine, options.ctx) { (Some(engine), Some(ctx)) => Some(crate::l7::websocket::InspectionOptions { engine, @@ -482,6 +500,7 @@ where enforcement: options.enforcement, target: options.target.clone(), query_params: options.query_params.clone(), + graphql_policy: options.websocket.message_policy.is_graphql(), }), _ => { return Err(miette!( @@ -533,12 +552,20 @@ fn upgrade_options<'a>( let websocket_credential_rewrite = matches!(config.protocol, L7Protocol::Rest | L7Protocol::Websocket) && config.websocket_credential_rewrite; - let websocket_message_inspection = config.protocol == L7Protocol::Websocket; + let websocket_message_policy = if config.protocol == L7Protocol::Websocket { + if config.websocket_graphql_policy { + WebSocketMessagePolicy::Graphql + } else { + WebSocketMessagePolicy::Transport + } + } else { + WebSocketMessagePolicy::None + }; UpgradeRelayOptions { websocket_request, websocket: WebSocketUpgradeBehavior { credential_rewrite: websocket_credential_rewrite, - message_inspection: websocket_message_inspection, + message_policy: websocket_message_policy, permessage_deflate: false, }, 
secret_resolver: if websocket_credential_rewrite { @@ -1381,6 +1408,7 @@ network_policies: graphql_max_body_bytes: 0, allow_encoded_slash: false, websocket_credential_rewrite: true, + websocket_graphql_policy: false, }]; let ctx = L7EvalContext { host: "gateway.example.test".into(), @@ -1479,6 +1507,7 @@ network_policies: graphql_max_body_bytes: 0, allow_encoded_slash: false, websocket_credential_rewrite: true, + websocket_graphql_policy: false, }]; let (child_env, resolver) = SecretResolver::from_provider_env( std::iter::once(("DISCORD_BOT_TOKEN".to_string(), "real-token".to_string())).collect(), @@ -1559,6 +1588,128 @@ network_policies: let _ = tokio::time::timeout(std::time::Duration::from_secs(1), relay).await; } + #[tokio::test] + async fn route_selected_graphql_websocket_rewrites_connection_init_credentials_after_upgrade() { + let data = r#" +network_policies: + route_api: + name: route_api + endpoints: + - host: gateway.example.test + port: 443 + path: "/graphql" + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/graphql" + - allow: + operation_type: query + fields: [viewer] + websocket_credential_rewrite: true + binaries: + - { path: /usr/bin/node } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let tunnel_engine = engine + .clone_engine_for_tunnel(engine.current_generation()) + .unwrap(); + let configs = vec![L7EndpointConfig { + protocol: L7Protocol::Websocket, + path: "/graphql".into(), + tls: crate::l7::TlsMode::Auto, + enforcement: EnforcementMode::Enforce, + graphql_max_body_bytes: 0, + allow_encoded_slash: false, + websocket_credential_rewrite: true, + websocket_graphql_policy: true, + }]; + let (child_env, resolver) = SecretResolver::from_provider_env( + std::iter::once(("T".to_string(), "real-token".to_string())).collect(), + ); + let placeholder = child_env.get("T").expect("placeholder env"); + let ctx = L7EvalContext { + host: "gateway.example.test".into(), + port: 443, + 
policy_name: "route_api".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: resolver.map(Arc::new), + }; + + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_route_selection( + &configs, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + app.write_all( + b"GET /graphql HTTP/1.1\r\nHost: gateway.example.test\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n", + ) + .await + .unwrap(); + + let mut forwarded = [0u8; 1024]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut forwarded), + ) + .await + .expect("upgrade request should reach upstream") + .unwrap(); + let forwarded = String::from_utf8_lossy(&forwarded[..n]); + assert!(forwarded.contains("GET /graphql HTTP/1.1")); + assert!(forwarded.contains("Upgrade: websocket\r\n")); + + upstream + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n", + ) + .await + .unwrap(); + + let mut response = [0u8; 1024]; + let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) + .await + .expect("client should receive upgrade response") + .unwrap(); + assert!(String::from_utf8_lossy(&response[..n]).contains("101 Switching Protocols")); + + let payload = format!( + r#"{{"type":"connection_init","payload":{{"authorization":"{placeholder}"}}}}"# + ); + app.write_all(&masked_text_frame(payload.as_bytes())) + .await + .unwrap(); + + let (masked, rewritten) = tokio::time::timeout( + std::time::Duration::from_secs(1), + read_text_frame(&mut upstream), + ) + .await + .expect("rewritten GraphQL WebSocket control message should reach upstream") + .unwrap(); + 
assert!(masked, "client-to-server frame must remain masked"); + assert_eq!( + rewritten, + r#"{"type":"connection_init","payload":{"authorization":"real-token"}}"# + ); + assert!(!rewritten.contains(placeholder)); + + drop(app); + drop(upstream); + let _ = tokio::time::timeout(std::time::Duration::from_secs(1), relay).await; + } + fn masked_text_frame(payload: &[u8]) -> Vec { let mask = [0x11, 0x22, 0x33, 0x44]; assert!( diff --git a/crates/openshell-sandbox/src/l7/websocket.rs b/crates/openshell-sandbox/src/l7/websocket.rs index 0777bfb1e..2dc1b25c3 100644 --- a/crates/openshell-sandbox/src/l7/websocket.rs +++ b/crates/openshell-sandbox/src/l7/websocket.rs @@ -59,6 +59,7 @@ pub(super) struct InspectionOptions<'a> { pub(super) enforcement: EnforcementMode, pub(super) target: String, pub(super) query_params: HashMap>, + pub(super) graphql_policy: bool, } pub(super) struct RelayOptions<'a> { @@ -500,7 +501,7 @@ async fn relay_text_payload( }; if let Some(inspector) = options.inspector.as_ref() { - inspect_websocket_text_message(host, port, options.policy_name, inspector)?; + inspect_websocket_text_message(host, port, options.policy_name, inspector, &text)?; } if replacements == 0 && !force_reframe && !compressed { @@ -533,7 +534,12 @@ fn inspect_websocket_text_message( port: u16, policy_name: &str, inspector: &InspectionOptions<'_>, + text: &str, ) -> Result<()> { + if inspector.graphql_policy { + return inspect_graphql_websocket_message(host, port, policy_name, inspector, text); + } + let request_info = L7RequestInfo { action: "WEBSOCKET_TEXT".to_string(), target: inspector.target.clone(), @@ -546,13 +552,175 @@ fn inspect_websocket_text_message( (false, EnforcementMode::Audit) => "audit", (false, EnforcementMode::Enforce) => "deny", }; - emit_websocket_l7_event(host, port, policy_name, &request_info, decision, &reason); + emit_websocket_l7_event( + host, + port, + policy_name, + &request_info, + decision, + &reason, + None, + ); if !allowed && inspector.enforcement 
== EnforcementMode::Enforce { return Err(miette!("websocket text message denied by policy")); } Ok(()) } +fn inspect_graphql_websocket_message( + host: &str, + port: u16, + policy_name: &str, + inspector: &InspectionOptions<'_>, + text: &str, +) -> Result<()> { + match classify_graphql_websocket_message(text) { + GraphqlWebSocketMessage::Control { message_type } => { + let request_info = L7RequestInfo { + action: "WEBSOCKET_CONTROL".to_string(), + target: inspector.target.clone(), + query_params: inspector.query_params.clone(), + graphql: None, + }; + emit_websocket_l7_event( + host, + port, + policy_name, + &request_info, + "allow", + &format!("GraphQL WebSocket control message {message_type}"), + None, + ); + Ok(()) + } + GraphqlWebSocketMessage::Operation { + message_type, + graphql, + } => { + let request_info = L7RequestInfo { + action: "WEBSOCKET_TEXT".to_string(), + target: inspector.target.clone(), + query_params: inspector.query_params.clone(), + graphql: Some(graphql.clone()), + }; + let parse_error_reason = graphql + .error + .as_deref() + .map(|error| format!("GraphQL WebSocket message rejected: {error}")); + let force_deny = parse_error_reason.is_some(); + let (allowed, reason) = if let Some(reason) = parse_error_reason { + (false, reason) + } else { + evaluate_l7_request(inspector.engine, inspector.ctx, &request_info)? 
+ }; + let decision = match (allowed, inspector.enforcement) { + (_, _) if force_deny => "deny", + (true, _) => "allow", + (false, EnforcementMode::Audit) => "audit", + (false, EnforcementMode::Enforce) => "deny", + }; + let reason = format!("graphql_ws_type={message_type} {reason}"); + emit_websocket_l7_event( + host, + port, + policy_name, + &request_info, + decision, + &reason, + Some(&graphql), + ); + if (!allowed && inspector.enforcement == EnforcementMode::Enforce) || force_deny { + return Err(miette!("websocket GraphQL message denied by policy")); + } + Ok(()) + } + } +} + +#[derive(Debug)] +enum GraphqlWebSocketMessage { + Control { + message_type: String, + }, + Operation { + message_type: String, + graphql: crate::l7::graphql::GraphqlRequestInfo, + }, +} + +fn classify_graphql_websocket_message(text: &str) -> GraphqlWebSocketMessage { + let value = match serde_json::from_str::(text) { + Ok(value) => value, + Err(err) => { + return GraphqlWebSocketMessage::Operation { + message_type: "unknown".to_string(), + graphql: graphql_error(format!( + "GraphQL WebSocket message is not valid JSON: {err}" + )), + }; + } + }; + let Some(obj) = value.as_object() else { + return GraphqlWebSocketMessage::Operation { + message_type: "unknown".to_string(), + graphql: graphql_error("GraphQL WebSocket message must be a JSON object"), + }; + }; + let Some(message_type) = obj.get("type").and_then(serde_json::Value::as_str) else { + return GraphqlWebSocketMessage::Operation { + message_type: "unknown".to_string(), + graphql: graphql_error("GraphQL WebSocket message missing string type"), + }; + }; + + match message_type { + "subscribe" | "start" => { + if obj + .get("id") + .and_then(serde_json::Value::as_str) + .is_none_or(str::is_empty) + { + return GraphqlWebSocketMessage::Operation { + message_type: message_type.to_string(), + graphql: graphql_error( + "GraphQL WebSocket operation message missing non-empty id", + ), + }; + } + let Some(payload) = 
obj.get("payload").filter(|value| value.is_object()) else { + return GraphqlWebSocketMessage::Operation { + message_type: message_type.to_string(), + graphql: graphql_error( + "GraphQL WebSocket operation message missing object payload", + ), + }; + }; + GraphqlWebSocketMessage::Operation { + message_type: message_type.to_string(), + graphql: crate::l7::graphql::classify_json_envelope_value(payload), + } + } + "connection_init" | "connection_terminate" | "ping" | "pong" | "complete" | "stop" => { + GraphqlWebSocketMessage::Control { + message_type: message_type.to_string(), + } + } + _ => GraphqlWebSocketMessage::Operation { + message_type: message_type.to_string(), + graphql: graphql_error(format!( + "unsupported GraphQL WebSocket client message type {message_type:?}" + )), + }, + } +} + +fn graphql_error(message: impl Into) -> crate::l7::graphql::GraphqlRequestInfo { + crate::l7::graphql::GraphqlRequestInfo { + operations: Vec::new(), + error: Some(message.into()), + } +} + async fn relay_control_frame( reader: &mut R, writer: &mut W, @@ -812,6 +980,7 @@ fn emit_websocket_l7_event( request_info: &L7RequestInfo, decision: &str, reason: &str, + graphql: Option<&crate::l7::graphql::GraphqlRequestInfo>, ) { let policy_name = if policy_name.is_empty() { "-" @@ -831,6 +1000,7 @@ fn emit_websocket_l7_event( SeverityId::Informational, ), }; + let summary = graphql.map(graphql_log_summary).unwrap_or_default(); let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) .activity(ActivityId::Other) .action(action_id) @@ -840,13 +1010,41 @@ fn emit_websocket_l7_event( .dst_endpoint(Endpoint::from_domain(host, port)) .firewall_rule(policy_name, "l7-websocket") .message(format!( - "WEBSOCKET_L7_REQUEST {decision} {} {host}:{port}{} reason={reason}", - request_info.action, request_info.target + "WEBSOCKET_L7_REQUEST {decision} {} {host}:{port}{}{} reason={reason}", + request_info.action, request_info.target, summary )) .build(); ocsf_emit!(event); } +fn 
graphql_log_summary(info: &crate::l7::graphql::GraphqlRequestInfo) -> String { + if let Some(error) = info.error.as_deref() { + return format!(" graphql_error={error:?}"); + } + let ops: Vec = info + .operations + .iter() + .map(|op| { + let name = op.operation_name.as_deref().unwrap_or("-"); + let fields = if op.fields.is_empty() { + "-".to_string() + } else { + op.fields.join(",") + }; + let persisted = op + .persisted_query_hash + .as_deref() + .or(op.persisted_query_id.as_deref()) + .unwrap_or("-"); + format!( + "type={} name={} fields={} persisted={}", + op.operation_type, name, fields, persisted + ) + }) + .collect(); + format!(" graphql_ops={}", ops.join(";")) +} + fn protocol_failure_class(error: &miette::Report) -> &'static str { let msg = error.to_string().to_ascii_lowercase(); if msg.contains("credential") { @@ -905,9 +1103,37 @@ fn protocol_failure_message(host: &str, port: u16) -> String { #[cfg(test)] mod tests { use super::*; + use crate::l7::relay::L7EvalContext; + use crate::opa::{NetworkInput, OpaEngine}; use crate::secrets::SecretResolver; + use std::path::PathBuf; use tokio::io::{AsyncReadExt, AsyncWriteExt}; + const TEST_POLICY: &str = include_str!("../../data/sandbox-policy.rego"); + const GRAPHQL_WS_POLICY: &str = r#" +network_policies: + graphql_ws: + name: graphql_ws + endpoints: + - host: realtime.graphql.test + port: 443 + path: "/graphql" + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/graphql" + - allow: + operation_type: query + fields: [viewer] + - allow: + operation_type: subscription + fields: [messageAdded] + binaries: + - { path: /usr/bin/node } +"#; + fn resolver() -> (HashMap, SecretResolver) { let (child_env, resolver) = SecretResolver::from_provider_env( std::iter::once(("DISCORD_BOT_TOKEN".to_string(), "real-token".to_string())).collect(), @@ -1015,6 +1241,70 @@ mod tests { result.map(|()| output) } + async fn run_client_to_server_with_graphql_policy( + input: Vec, + resolver: 
Option<&SecretResolver>, + ) -> Result> { + let engine = OpaEngine::from_strings(TEST_POLICY, GRAPHQL_WS_POLICY) + .expect("GraphQL WebSocket policy should load"); + let network_input = NetworkInput { + host: "realtime.graphql.test".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let generation = engine + .evaluate_network_action_with_generation(&network_input) + .expect("network action should evaluate") + .1; + let tunnel_engine = engine + .clone_engine_for_tunnel(generation) + .expect("tunnel engine"); + let ctx = L7EvalContext { + host: "realtime.graphql.test".into(), + port: 443, + policy_name: "graphql_ws".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + }; + let (mut client_write, mut relay_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); + let (mut relay_write, mut upstream_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); + + client_write.write_all(&input).await.unwrap(); + drop(client_write); + + let options = RelayOptions { + policy_name: "graphql_ws", + resolver, + inspector: Some(InspectionOptions { + engine: &tunnel_engine, + ctx: &ctx, + enforcement: EnforcementMode::Enforce, + target: "/graphql".to_string(), + query_params: HashMap::new(), + graphql_policy: true, + }), + compression: WebSocketCompression::None, + }; + let result = relay_client_to_server( + &mut relay_read, + &mut relay_write, + "realtime.graphql.test", + 443, + &options, + ) + .await; + drop(relay_write); + + let mut output = Vec::new(); + upstream_read.read_to_end(&mut output).await.unwrap(); + result.map(|()| output) + } + async fn run_client_to_server_compressed(input: Vec) -> Result> { let (_, resolver) = resolver(); let (mut client_write, mut relay_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); @@ -1117,6 +1407,113 @@ mod tests { frame } + #[test] + fn 
classifies_graphql_transport_ws_subscribe_operation() { + let message = r#"{"type":"subscribe","id":"1","payload":{"query":"subscription NewMessages { messageAdded }"}}"#; + + match classify_graphql_websocket_message(message) { + GraphqlWebSocketMessage::Operation { + message_type, + graphql, + } => { + assert_eq!(message_type, "subscribe"); + assert!( + graphql.error.is_none(), + "unexpected error: {:?}", + graphql.error + ); + assert_eq!(graphql.operations.len(), 1); + assert_eq!(graphql.operations[0].operation_type, "subscription"); + assert_eq!( + graphql.operations[0].operation_name.as_deref(), + Some("NewMessages") + ); + assert_eq!(graphql.operations[0].fields, vec!["messageAdded"]); + } + other @ GraphqlWebSocketMessage::Control { .. } => { + panic!("expected operation, got {other:?}") + } + } + } + + #[test] + fn classifies_legacy_graphql_ws_start_operation() { + let message = r#"{"type":"start","id":"1","payload":{"query":"query Viewer { viewer }"}}"#; + + match classify_graphql_websocket_message(message) { + GraphqlWebSocketMessage::Operation { + message_type, + graphql, + } => { + assert_eq!(message_type, "start"); + assert!( + graphql.error.is_none(), + "unexpected error: {:?}", + graphql.error + ); + assert_eq!(graphql.operations[0].operation_type, "query"); + assert_eq!(graphql.operations[0].fields, vec!["viewer"]); + } + other @ GraphqlWebSocketMessage::Control { .. } => { + panic!("expected operation, got {other:?}") + } + } + } + + #[test] + fn classifies_graphql_websocket_control_message_without_payload_logging() { + match classify_graphql_websocket_message( + r#"{"type":"connection_init","payload":{"authorization":"secret"}}"#, + ) { + GraphqlWebSocketMessage::Control { message_type } => { + assert_eq!(message_type, "connection_init"); + } + other @ GraphqlWebSocketMessage::Operation { .. 
} => { + panic!("expected control message, got {other:?}") + } + } + } + + #[test] + fn unsupported_graphql_websocket_message_type_fails_closed() { + match classify_graphql_websocket_message(r#"{"type":"next","id":"1"}"#) { + GraphqlWebSocketMessage::Operation { graphql, .. } => { + assert!( + graphql + .error + .as_deref() + .is_some_and(|error| error.contains("unsupported")) + ); + } + other @ GraphqlWebSocketMessage::Control { .. } => { + panic!("expected operation error, got {other:?}") + } + } + } + + #[test] + fn graphql_websocket_log_summary_excludes_payload_variables_and_secrets() { + let placeholder = "openshell:resolve:env:T"; + let message = format!( + r#"{{"type":"subscribe","id":"1","payload":{{"query":"query Viewer {{ viewer }}","variables":{{"token":"{placeholder}"}}}}}}"# + ); + let graphql = match classify_graphql_websocket_message(&message) { + GraphqlWebSocketMessage::Operation { graphql, .. } => graphql, + other @ GraphqlWebSocketMessage::Control { .. } => { + panic!("expected operation, got {other:?}") + } + }; + let summary = graphql_log_summary(&graphql); + + assert!(summary.contains("type=query")); + assert!(summary.contains("fields=viewer")); + assert!(!summary.contains(placeholder)); + assert!(!summary.contains("real-token")); + assert!(!summary.contains("variables")); + assert!(!summary.contains("token")); + assert!(!summary.contains("secret_len")); + } + #[tokio::test] async fn rewrites_discord_like_identify_text_payload() { let (child_env, _) = resolver(); @@ -1182,6 +1579,56 @@ mod tests { let _ = tokio::time::timeout(std::time::Duration::from_secs(2), relay).await; } + #[tokio::test] + async fn graphql_websocket_policy_allows_subscription_operation() { + let payload = r#"{"type":"subscribe","id":"1","payload":{"query":"subscription NewMessages { messageAdded }"}}"#; + let frame = masked_frame(true, OPCODE_TEXT, payload.as_bytes()); + + let output = run_client_to_server_with_graphql_policy(frame.clone(), None) + .await + 
.expect("allowed subscription should relay"); + + assert_eq!(output, frame); + assert_eq!(decode_masked_text_frame(&output), payload); + } + + #[tokio::test] + async fn graphql_websocket_policy_denies_unlisted_operation_field() { + let payload = + r#"{"type":"subscribe","id":"1","payload":{"query":"query Admin { adminAuditLog }"}}"#; + let frame = masked_frame(true, OPCODE_TEXT, payload.as_bytes()); + + let err = run_client_to_server_with_graphql_policy(frame, None) + .await + .expect_err("unlisted field should be denied"); + + assert!(err.to_string().contains("websocket GraphQL message denied")); + } + + #[tokio::test] + async fn graphql_websocket_control_message_rewrites_credentials_before_relay() { + let (child_env, resolver) = SecretResolver::from_provider_env( + std::iter::once(("T".to_string(), "real-token".to_string())).collect(), + ); + let resolver = resolver.expect("resolver"); + let placeholder = child_env.get("T").expect("placeholder env"); + let payload = format!( + r#"{{"type":"connection_init","payload":{{"authorization":"{placeholder}"}}}}"# + ); + let frame = masked_frame(true, OPCODE_TEXT, payload.as_bytes()); + + let output = run_client_to_server_with_graphql_policy(frame, Some(&resolver)) + .await + .expect("control message should relay after credential rewrite"); + + let rewritten = decode_masked_text_frame(&output); + assert_eq!( + rewritten, + r#"{"type":"connection_init","payload":{"authorization":"real-token"}}"# + ); + assert!(!rewritten.contains(placeholder)); + } + #[tokio::test] async fn text_without_placeholder_passes_semantically_unchanged() { let frame = masked_frame(true, OPCODE_TEXT, br#"{"op":1,"d":42}"#); diff --git a/crates/openshell-sandbox/src/opa.rs b/crates/openshell-sandbox/src/opa.rs index 803dc2ad1..c6cd32f8a 100644 --- a/crates/openshell-sandbox/src/opa.rs +++ b/crates/openshell-sandbox/src/opa.rs @@ -1814,6 +1814,28 @@ network_policies: access: read-only binaries: - { path: /usr/bin/curl } + graphql_ws: + name: 
graphql_ws + endpoints: + - host: realtime.graphql.com + ports: [443] + path: "/graphql" + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/graphql" + - allow: + operation_type: query + fields: [viewer] + - allow: + operation_type: subscription + fields: [messageAdded] + deny_rules: + - operation_type: mutation + binaries: + - { path: /usr/bin/curl } l4_only: name: l4_only endpoints: @@ -1900,6 +1922,25 @@ process: }) } + fn l7_websocket_graphql_input(host: &str, operations: serde_json::Value) -> serde_json::Value { + serde_json::json!({ + "network": { "host": host, "port": 443 }, + "exec": { + "path": "/usr/bin/curl", + "ancestors": [], + "cmdline_paths": [] + }, + "request": { + "method": "WEBSOCKET_TEXT", + "path": "/graphql", + "query_params": {}, + "graphql": { + "operations": operations + } + } + }) + } + fn eval_l7(engine: &OpaEngine, input: &serde_json::Value) -> bool { let mut eng = engine.engine.lock().unwrap(); eng.set_input_json(&input.to_string()).unwrap(); @@ -2137,6 +2178,97 @@ process: assert!(!eval_l7(&engine, &mutation)); } + #[test] + fn l7_websocket_graphql_subscription_allowed_by_field_rule() { + let engine = l7_engine(); + let input = l7_websocket_graphql_input( + "realtime.graphql.com", + serde_json::json!([{ + "operation_type": "subscription", + "operation_name": "NewMessages", + "fields": ["messageAdded"], + "persisted_query": false + }]), + ); + assert!(eval_l7(&engine, &input)); + } + + #[test] + fn l7_websocket_graphql_unlisted_field_denied() { + let engine = l7_engine(); + let input = l7_websocket_graphql_input( + "realtime.graphql.com", + serde_json::json!([{ + "operation_type": "query", + "fields": ["adminAuditLog"], + "persisted_query": false + }]), + ); + assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn l7_websocket_graphql_deny_rule_takes_precedence() { + let engine = l7_engine(); + let input = l7_websocket_graphql_input( + "realtime.graphql.com", + serde_json::json!([{ + 
"operation_type": "mutation", + "operation_name": "DeleteRepo", + "fields": ["deleteRepository"], + "persisted_query": false + }]), + ); + assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn l7_websocket_graphql_not_bypassed_by_generic_text_rule() { + let data = r#" +network_policies: + graphql_ws: + name: graphql_ws + endpoints: + - host: realtime.graphql.com + ports: [443] + path: "/graphql" + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/graphql" + - allow: + method: WEBSOCKET_TEXT + path: "/graphql" + - allow: + operation_type: query + fields: [viewer] + binaries: + - { path: /usr/bin/curl } +"#; + let data_json: serde_json::Value = + serde_yml::from_str(data).expect("fixture should parse as YAML"); + let mut rego = regorus::Engine::new(); + rego.add_policy("policy.rego".into(), TEST_POLICY.into()) + .expect("policy should load"); + rego.add_data_json(&data_json.to_string()) + .expect("data should load"); + let engine = OpaEngine { + engine: Mutex::new(rego), + generation: Arc::new(AtomicU64::new(0)), + }; + let input = l7_websocket_graphql_input( + "realtime.graphql.com", + serde_json::json!([{ + "operation_type": "query", + "fields": ["adminAuditLog"], + "persisted_query": false + }]), + ); + assert!(!eval_l7(&engine, &input)); + } + #[test] fn l7_endpoint_path_scopes_rest_and_graphql_on_same_host() { let data = r#" diff --git a/crates/openshell-sandbox/src/proxy.rs b/crates/openshell-sandbox/src/proxy.rs index 94c1f53f5..7413b895d 100644 --- a/crates/openshell-sandbox/src/proxy.rs +++ b/crates/openshell-sandbox/src/proxy.rs @@ -3354,6 +3354,7 @@ mod tests { graphql_max_body_bytes: crate::l7::graphql::DEFAULT_MAX_BODY_BYTES, allow_encoded_slash: false, websocket_credential_rewrite: false, + websocket_graphql_policy: false, }, }, L7ConfigSnapshot { @@ -3365,6 +3366,7 @@ mod tests { graphql_max_body_bytes: crate::l7::graphql::DEFAULT_MAX_BODY_BYTES, allow_encoded_slash: false, websocket_credential_rewrite: false, + 
websocket_graphql_policy: false, }, }, ]; diff --git a/docs/reference/policy-schema.mdx b/docs/reference/policy-schema.mdx index 3d51a3da2..5865f78e1 100644 --- a/docs/reference/policy-schema.mdx +++ b/docs/reference/policy-schema.mdx @@ -155,7 +155,7 @@ Each endpoint defines a reachable destination and optional inspection rules. | `host` | string | Yes | Hostname or IP address. Supports wildcards: `*.example.com` matches any subdomain. | | `port` | integer | Yes | TCP port number. | | `path` | string | No | Optional HTTP path glob used to select between L7 endpoints that share the same host and port. Empty means all paths. Use this when REST and GraphQL live under the same host, such as `/repos/**` and `/graphql`. | -| `protocol` | string | No | Set to `rest` for HTTP method/path inspection, `websocket` for RFC 6455 upgrade and client text-message inspection, or `graphql` for GraphQL operation inspection. Omit for TCP passthrough. | +| `protocol` | string | No | Set to `rest` for HTTP method/path inspection, `websocket` for RFC 6455 upgrade and client text-message inspection, or `graphql` for GraphQL-over-HTTP operation inspection. WebSocket endpoints can also use GraphQL operation rules for GraphQL-over-WebSocket traffic. Omit for TCP passthrough. | | `tls` | string | No | TLS handling mode. The proxy auto-detects TLS by peeking the first bytes of each connection and terminates it for inspected HTTPS traffic, so this field is optional in most cases. Set to `skip` to disable auto-detection for edge cases such as client-certificate mTLS or non-standard protocols. The values `terminate` and `passthrough` are deprecated and log a warning; they are still accepted for backward compatibility but have no effect on behavior. | | `enforcement` | string | No | `enforce` actively blocks disallowed requests. `audit` logs violations but allows traffic through. | | `access` | string | No | Access preset. One of `read-only`, `read-write`, or `full`. 
Mutually exclusive with `rules`. | @@ -164,9 +164,9 @@ Each endpoint defines a reachable destination and optional inspection rules. | `allowed_ips` | list of string | No | CIDR or IP allowlist for SSRF override. Entries overlapping loopback (`127.0.0.0/8`), link-local (`169.254.0.0/16`), or unspecified (`0.0.0.0`) are rejected at load time. | | `allow_encoded_slash` | bool | No | When `true`, L7 request parsing preserves `%2F` inside path segments instead of rejecting it. Use this for registries and APIs such as npm scoped packages (`/@scope%2Fname`). Defaults to `false`. | | `websocket_credential_rewrite` | bool | No | When `true` on a `protocol: rest` or `protocol: websocket` endpoint, OpenShell rewrites `openshell:resolve:env:*` placeholders in client-to-server WebSocket text messages after an allowed HTTP `101` upgrade. Binary frames are relayed but not rewritten. Defaults to `false`. | -| `persisted_queries` | string | No | GraphQL hash-only behavior. Default is `deny`; use `allow_registered` only with `graphql_persisted_queries`. | +| `persisted_queries` | string | No | GraphQL hash-only behavior for `protocol: graphql` and GraphQL-over-WebSocket operation policy. Default is `deny`; use `allow_registered` only with `graphql_persisted_queries`. | | `graphql_persisted_queries` | map | No | Trusted GraphQL persisted-query registry keyed by hash or saved-query ID. Values contain `operation_type`, optional `operation_name`, and optional root `fields`. | -| `graphql_max_body_bytes` | integer | No | Maximum GraphQL request body bytes buffered for inspection. Defaults to `65536`. | +| `graphql_max_body_bytes` | integer | No | Maximum GraphQL-over-HTTP request body bytes buffered for inspection. Defaults to `65536`. 
| #### Access Levels @@ -231,9 +231,9 @@ rules: path: /v1/realtime/** ``` -##### GraphQL Allow Rule (`protocol: graphql`) +##### GraphQL Allow Rule (`protocol: graphql` or GraphQL-over-WebSocket) -GraphQL allow rules match parsed GraphQL operations by operation type, optional operation name, and optional root fields. +GraphQL allow rules match parsed GraphQL operations by operation type, optional operation name, and optional root fields. On `protocol: graphql`, they apply to GraphQL-over-HTTP `GET` and `POST` requests. On `protocol: websocket`, include a separate `GET` allow rule for the RFC 6455 upgrade, then use GraphQL allow rules for client operation messages using the `graphql-transport-ws` `subscribe` message type or the legacy `graphql-ws` `start` message type. | Field | Type | Required | Description | |---|---|---|---| @@ -254,6 +254,23 @@ rules: fields: [createIssue] ``` +Example GraphQL-over-WebSocket allow rules: + +```yaml showLineNumbers={false} +rules: + - allow: + method: GET + path: /graphql + - allow: + operation_type: subscription + fields: [messageAdded] + - allow: + operation_type: query + fields: [viewer] +``` + +Do not combine `method`, `path`, or `query` with `operation_type`, `operation_name`, or `fields` inside the same WebSocket rule. When a WebSocket endpoint has GraphQL operation policy, use GraphQL rules for client messages instead of a raw `WEBSOCKET_TEXT` allow rule. + #### Deny Rule Objects Blocks specific operations on endpoints that otherwise have broad access. Deny rules are evaluated after allow rules and take precedence: if a request matches any deny rule, it is blocked regardless of what the allow rules or access preset permit. 
@@ -310,9 +327,9 @@ endpoints: path: "/v1/admin/**" ``` -##### GraphQL Deny Rule (`protocol: graphql`) +##### GraphQL Deny Rule (`protocol: graphql` or GraphQL-over-WebSocket) -GraphQL deny rules use the same field names as GraphQL allow rules, but they appear directly under each `deny_rules` entry instead of under an `allow` wrapper. +GraphQL deny rules use the same field names as GraphQL allow rules, but they appear directly under each `deny_rules` entry instead of under an `allow` wrapper. On WebSocket GraphQL endpoints, they apply only to classified GraphQL operation messages; protocol lifecycle messages such as `connection_init`, `ping`, `pong`, and `complete` are allowed as WebSocket control-plane messages and are not payload-logged. | Field | Type | Required | Description | |---|---|---|---| diff --git a/docs/sandboxes/policies.mdx b/docs/sandboxes/policies.mdx index eed8d59ab..8cc0ee586 100644 --- a/docs/sandboxes/policies.mdx +++ b/docs/sandboxes/policies.mdx @@ -56,7 +56,7 @@ Raw streams are connection-scoped and outside L7 live-reload guarantees. This in | `filesystem_policy` | Static | Controls which directories the agent can access on disk. Paths are split into `read_only` and `read_write` lists. Any path not listed in either list is inaccessible. Set `include_workdir: true` to automatically add the agent's working directory to `read_write`. [Landlock LSM](https://docs.kernel.org/security/landlock.html) enforces these restrictions at the kernel level. | | `landlock` | Static | Configures Landlock LSM enforcement behavior. Set `compatibility` to `best_effort` (skip individual inaccessible paths while applying remaining rules) or `hard_requirement` (fail if any path is inaccessible or the required kernel ABI is unavailable). See the [Policy Schema Reference](/reference/policy-schema#landlock) for the full behavior table. | | `process` | Static | Sets the OS-level identity for the agent process. `run_as_user` and `run_as_group` default to `sandbox`. 
Root (`root` or `0`) is rejected. The agent also runs with seccomp filters that block dangerous system calls. | -| `network_policies` | Dynamic | Controls network access for ordinary outbound traffic from the sandbox. Each block has a name, a list of endpoints (host, port, protocol, and optional rules), and a list of binaries allowed to use those endpoints.
Every outbound connection except `https://inference.local` goes through the proxy, which queries the [policy engine](/about/how-it-works#core-components) with the destination and calling binary. A connection is allowed only when both match an entry in the same policy block.
For endpoints with `protocol: rest`, the proxy auto-detects TLS and terminates it so each HTTP request can be checked against that endpoint's `rules` (method and path). For endpoints with `protocol: websocket`, the proxy validates the RFC 6455 upgrade and evaluates `GET` rules for the handshake plus `WEBSOCKET_TEXT` rules for client text messages on the upgraded request path. Set `websocket_credential_rewrite: true` only when a WebSocket or REST compatibility endpoint must keep placeholder credentials in sandbox-owned text frames and resolve them at the OpenShell relay boundary.
Endpoints without `protocol` allow the TCP stream through without inspecting payloads.
If no endpoint matches, the connection is denied. Configure managed inference separately through [Inference Routing](/sandboxes/inference-routing). | +| `network_policies` | Dynamic | Controls network access for ordinary outbound traffic from the sandbox. Each block has a name, a list of endpoints (host, port, protocol, and optional rules), and a list of binaries allowed to use those endpoints.
Every outbound connection except `https://inference.local` goes through the proxy, which queries the [policy engine](/about/how-it-works#core-components) with the destination and calling binary. A connection is allowed only when both match an entry in the same policy block.
For endpoints with `protocol: rest`, the proxy auto-detects TLS and terminates it so each HTTP request can be checked against that endpoint's `rules` (method and path). For endpoints with `protocol: websocket`, the proxy validates the RFC 6455 upgrade and evaluates `GET` rules for the handshake plus either `WEBSOCKET_TEXT` rules for raw client text messages or GraphQL operation rules for GraphQL-over-WebSocket messages. Set `websocket_credential_rewrite: true` only when a WebSocket or REST compatibility endpoint must keep placeholder credentials in sandbox-owned text frames and resolve them at the OpenShell relay boundary.
Endpoints without `protocol` allow the TCP stream through without inspecting payloads.
If no endpoint matches, the connection is denied. Configure managed inference separately through [Inference Routing](/sandboxes/inference-routing). | ## Baseline Filesystem Paths @@ -537,7 +537,7 @@ For an end-to-end walkthrough that combines this policy with a GitHub credential - { path: /usr/bin/gh } ``` -Endpoints with `protocol: rest` enable HTTP request inspection. Endpoints with `protocol: websocket` validate WebSocket upgrades and inspect client text messages on the upgraded request path. Endpoints with `protocol: graphql` parse GraphQL-over-HTTP payloads before evaluating rules. The endpoint-level `path` field lets these protocols share `api.github.com:443` without treating GraphQL payloads as plain REST `POST /graphql` requests. +Endpoints with `protocol: rest` enable HTTP request inspection. Endpoints with `protocol: websocket` validate WebSocket upgrades and inspect client text messages on the upgraded request path. WebSocket endpoints can also classify GraphQL-over-WebSocket operation messages with the same operation rules used by GraphQL-over-HTTP. Endpoints with `protocol: graphql` parse GraphQL-over-HTTP payloads before evaluating rules. The endpoint-level `path` field lets these protocols share `api.github.com:443` without treating GraphQL payloads as plain REST `POST /graphql` requests. @@ -602,6 +602,36 @@ For allow rules, every selected root field in an operation must match one of the Hash-only persisted queries cannot be classified from the request alone. OpenShell denies them unless the endpoint uses `persisted_queries: allow_registered` and provides a trusted `graphql_persisted_queries` entry keyed by hash or saved-query ID. +### GraphQL-over-WebSocket matching + +Some APIs carry GraphQL operations over RFC 6455 WebSockets, commonly for subscriptions and realtime updates. Configure these as `protocol: websocket`, allow the upgrade with a normal `GET` rule, then add GraphQL operation rules for client operation messages. 
OpenShell recognizes modern `graphql-transport-ws` `subscribe` messages and legacy `graphql-ws` `start` messages. + +```yaml showLineNumbers={false} + realtime_graphql: + name: realtime_graphql + endpoints: + - host: realtime.example.com + port: 443 + path: "/graphql" + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/graphql" + - allow: + operation_type: subscription + fields: [messageAdded] + - allow: + operation_type: query + fields: [viewer] + websocket_credential_rewrite: true + binaries: + - { path: /usr/bin/node } +``` + +When a WebSocket endpoint has GraphQL operation policy, client operation messages are fail-closed on malformed JSON, unsupported message types, parse errors, unregistered hash-only persisted queries, or unallowed operations. Use GraphQL operation rules for client messages rather than a raw `WEBSOCKET_TEXT` allow rule. Protocol lifecycle messages such as `connection_init`, `ping`, `pong`, and `complete` are allowed without payload logging; if `websocket_credential_rewrite: true` is set, placeholders inside those text messages are resolved before forwarding. + ### GraphQL service policy shapes GraphQL field names are application-specific, so treat these as starting shapes to review against the actual app schema: diff --git a/docs/security/best-practices.mdx b/docs/security/best-practices.mdx index 224cdf644..4571ed94c 100644 --- a/docs/security/best-practices.mdx +++ b/docs/security/best-practices.mdx @@ -97,9 +97,9 @@ The `protocol` field on an endpoint controls whether the proxy inspects individu | Aspect | Detail | |---|---| | Default | Endpoints without a `protocol` field use L4-only enforcement: the proxy checks host, port, and binary, then relays the TCP stream without inspecting payloads. 
| -| What you can change | Add `protocol: rest` to enable per-request HTTP method/path inspection, `protocol: websocket` to inspect RFC 6455 upgrade handshakes and client text messages, or `protocol: graphql` to inspect GraphQL operation type, operation name, and root fields. Pair inspected protocols with `rules` or access presets (`full`, `read-only`, `read-write`). | +| What you can change | Add `protocol: rest` to enable per-request HTTP method/path inspection, `protocol: websocket` to inspect RFC 6455 upgrade handshakes and client text messages, or `protocol: graphql` to inspect GraphQL-over-HTTP operation type, operation name, and root fields. WebSocket endpoints can also use GraphQL operation rules for GraphQL-over-WebSocket messages. Pair inspected protocols with `rules` or access presets (`full`, `read-only`, `read-write`). | | Risk if relaxed | L4-only endpoints allow the agent to send any data through the tunnel after the initial connection is permitted. The proxy cannot see HTTP methods, paths, or GraphQL operations. Adding `access: full` with L7 inspection enables observability but permits all inspected actions. | -| Recommendation | Use `protocol: rest` with specific `rules` for APIs where intent is encoded in method and path. Use `protocol: graphql` for GraphQL APIs where destructive operations are body-encoded. Use `protocol: websocket` for RFC 6455 endpoints, with explicit `GET` and `WEBSOCKET_TEXT` rules when finer control is needed. Prefer `access: read-only` or explicit allowlists, and deny hash-only persisted queries unless you maintain a trusted registry. Omit `protocol` for non-HTTP protocols. For WebSocket endpoints that must carry placeholder credentials in client text frames, add `websocket_credential_rewrite: true`. | +| Recommendation | Use `protocol: rest` with specific `rules` for APIs where intent is encoded in method and path. Use `protocol: graphql` for GraphQL-over-HTTP APIs where destructive operations are body-encoded. 
Use `protocol: websocket` for RFC 6455 endpoints, with explicit `GET` and `WEBSOCKET_TEXT` rules for raw text protocols or explicit GraphQL operation rules for GraphQL-over-WebSocket. Prefer `access: read-only` or explicit allowlists, and deny hash-only persisted queries unless you maintain a trusted registry. Omit `protocol` for non-HTTP protocols. For WebSocket endpoints that must carry placeholder credentials in client text frames, add `websocket_credential_rewrite: true`. | ### Enforcement Mode (`audit` vs `enforce`) From eab184f20bb27c1db8b62deb33717590b018a24a Mon Sep 17 00:00:00 2001 From: Aaron Erickson Date: Sun, 10 May 2026 16:06:45 -0700 Subject: [PATCH 13/17] fix(policy): allow private IPs for websocket endpoints --- crates/openshell-cli/src/main.rs | 2 +- crates/openshell-cli/src/policy_update.rs | 87 ++++++++++++- crates/openshell-driver-docker/src/lib.rs | 20 ++- crates/openshell-driver-docker/src/tests.rs | 31 ++++- .../src/provider_credentials.rs | 4 + crates/openshell-sandbox/src/proxy.rs | 114 +++++++++++++++--- crates/openshell-sandbox/src/secrets.rs | 6 +- 7 files changed, 238 insertions(+), 26 deletions(-) diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index 94951c6c7..96714ab30 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -287,7 +287,7 @@ const POLICY_EXAMPLES: &str = "\x1b[1mALIAS\x1b[0m $ openshell policy get my-sandbox $ openshell policy set my-sandbox --policy policy.yaml $ openshell policy update my-sandbox --add-endpoint api.github.com:443:read-only:rest:enforce - $ openshell policy update my-sandbox --add-endpoint realtime.example.com:443:read-write:websocket:enforce:websocket-credential-rewrite + $ openshell policy update my-sandbox --add-endpoint realtime.example.com:443:read-write:websocket:enforce:websocket-credential-rewrite,allowed-ip=10.0.0.0/8 $ openshell policy update my-sandbox --add-allow 'api.github.com:443:GET:/repos/**' $ openshell policy set 
--global --policy policy.yaml $ openshell policy delete --global diff --git a/crates/openshell-cli/src/policy_update.rs b/crates/openshell-cli/src/policy_update.rs index 8fe4db269..10f66ff09 100644 --- a/crates/openshell-cli/src/policy_update.rs +++ b/crates/openshell-cli/src/policy_update.rs @@ -354,9 +354,29 @@ fn apply_add_endpoint_options( endpoint.websocket_credential_rewrite = true; } _ => { - return Err(miette!( - "--add-endpoint options segment supports only 'websocket-credential-rewrite'; got '{option}' in '{spec}'" - )); + let Some(allowed_ip) = option.strip_prefix("allowed-ip=") else { + return Err(miette!( + "--add-endpoint options segment supports only 'websocket-credential-rewrite' and 'allowed-ip='; got '{option}' in '{spec}'" + )); + }; + let allowed_ip = allowed_ip.trim(); + if allowed_ip.is_empty() { + return Err(miette!( + "--add-endpoint allowed-ip option must include a CIDR or IP value in '{spec}'" + )); + } + if allowed_ip.contains(char::is_whitespace) { + return Err(miette!( + "--add-endpoint allowed-ip option must not contain whitespace in '{spec}'" + )); + } + if !endpoint + .allowed_ips + .iter() + .any(|existing| existing == allowed_ip) + { + endpoint.allowed_ips.push(allowed_ip.to_string()); + } } } } @@ -537,6 +557,67 @@ mod tests { assert!(rule.endpoints[0].websocket_credential_rewrite); } + #[test] + fn parse_add_endpoint_merges_allowed_ips_with_websocket_options() { + let plan = build_policy_update_plan( + &[ + "realtime.example.com:443:read-write:websocket:enforce:websocket-credential-rewrite,allowed-ip=10.0.0.0/8,allowed-ip=172.16.0.0/12,allowed-ip=10.0.0.0/8" + .to_string(), + ], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect("plan should build"); + + let PolicyMergeOp::AddRule { rule, .. 
} = &plan.preview_operations[0] else { + panic!("expected add-rule preview"); + }; + let endpoint = &rule.endpoints[0]; + assert!(endpoint.websocket_credential_rewrite); + assert_eq!( + endpoint.allowed_ips, + vec!["10.0.0.0/8".to_string(), "172.16.0.0/12".to_string()] + ); + } + + #[test] + fn parse_add_endpoint_accepts_allowed_ip_on_rest_endpoint() { + let plan = build_policy_update_plan( + &["api.example.com:443:read-write:rest:enforce:allowed-ip=192.168.0.0/16".to_string()], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect("plan should build"); + + let PolicyMergeOp::AddRule { rule, .. } = &plan.preview_operations[0] else { + panic!("expected add-rule preview"); + }; + assert_eq!(rule.endpoints[0].allowed_ips, vec!["192.168.0.0/16"]); + } + + #[test] + fn parse_add_endpoint_rejects_empty_allowed_ip() { + let error = build_policy_update_plan( + &["api.example.com:443:read-write:rest:enforce:allowed-ip=".to_string()], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect_err("plan should fail"); + assert!(error.to_string().contains("allowed-ip option")); + } + #[test] fn websocket_credential_rewrite_rejects_l4_endpoint() { let error = build_policy_update_plan( diff --git a/crates/openshell-driver-docker/src/lib.rs b/crates/openshell-driver-docker/src/lib.rs index a74609fb5..b6be2fce5 100644 --- a/crates/openshell-driver-docker/src/lib.rs +++ b/crates/openshell-driver-docker/src/lib.rs @@ -1099,7 +1099,7 @@ fn docker_gateway_route( }; } - if is_docker_desktop(info) { + if uses_host_gateway_alias(info) { DockerGatewayRoute::HostGateway } else { DockerGatewayRoute::Bridge { @@ -1109,7 +1109,7 @@ fn docker_gateway_route( } } -fn is_docker_desktop(info: &SystemInfo) -> bool { +fn uses_host_gateway_alias(info: &SystemInfo) -> bool { let operating_system = info .operating_system .as_deref() @@ -1119,6 +1119,15 @@ fn is_docker_desktop(info: &SystemInfo) -> bool { return true; } + let name = info + .name + .as_deref() + .unwrap_or_default() + 
.to_ascii_lowercase(); + if name == "colima" { + return true; + } + info.labels.as_ref().is_some_and(|labels| { labels .iter() @@ -1132,9 +1141,10 @@ fn docker_extra_hosts(route: &DockerGatewayRoute) -> Vec { format!("{HOST_DOCKER_INTERNAL}:{host_alias_ip}"), format!("{HOST_OPENSHELL_INTERNAL}:{host_alias_ip}"), ], - DockerGatewayRoute::HostGateway => { - vec![format!("{HOST_OPENSHELL_INTERNAL}:host-gateway")] - } + DockerGatewayRoute::HostGateway => vec![ + format!("{HOST_DOCKER_INTERNAL}:host-gateway"), + format!("{HOST_OPENSHELL_INTERNAL}:host-gateway"), + ], } } diff --git a/crates/openshell-driver-docker/src/tests.rs b/crates/openshell-driver-docker/src/tests.rs index c89019398..83906e73c 100644 --- a/crates/openshell-driver-docker/src/tests.rs +++ b/crates/openshell-driver-docker/src/tests.rs @@ -160,7 +160,36 @@ fn docker_gateway_route_uses_host_gateway_for_docker_desktop() { ); assert_eq!( docker_extra_hosts(&DockerGatewayRoute::HostGateway), - vec!["host.openshell.internal:host-gateway".to_string()] + vec![ + "host.docker.internal:host-gateway".to_string(), + "host.openshell.internal:host-gateway".to_string() + ] + ); +} + +#[test] +fn docker_gateway_route_uses_host_gateway_for_colima() { + let info = SystemInfo { + name: Some("colima".to_string()), + operating_system: Some("Ubuntu 24.04.4 LTS".to_string()), + ..Default::default() + }; + + assert_eq!( + docker_gateway_route( + &info, + IpAddr::V4(Ipv4Addr::new(172, 20, 0, 1)), + DEFAULT_SERVER_PORT, + None, + ), + DockerGatewayRoute::HostGateway + ); + assert_eq!( + docker_extra_hosts(&DockerGatewayRoute::HostGateway), + vec![ + "host.docker.internal:host-gateway".to_string(), + "host.openshell.internal:host-gateway".to_string() + ] ); } diff --git a/crates/openshell-sandbox/src/provider_credentials.rs b/crates/openshell-sandbox/src/provider_credentials.rs index bd28824ae..ffe0148a4 100644 --- a/crates/openshell-sandbox/src/provider_credentials.rs +++ b/crates/openshell-sandbox/src/provider_credentials.rs 
@@ -122,6 +122,10 @@ mod tests { resolver.resolve_placeholder("openshell:resolve:env:v11_GITHUB_TOKEN"), Some("new") ); + assert_eq!( + resolver.resolve_placeholder("openshell:resolve:env:GITHUB_TOKEN"), + Some("new") + ); } #[test] diff --git a/crates/openshell-sandbox/src/proxy.rs b/crates/openshell-sandbox/src/proxy.rs index 7413b895d..325c93859 100644 --- a/crates/openshell-sandbox/src/proxy.rs +++ b/crates/openshell-sandbox/src/proxy.rs @@ -2416,6 +2416,44 @@ where .await } +fn forward_websocket_upgrade_settings( + config: &crate::l7::L7EndpointConfig, + websocket_request: bool, + secret_resolver: Option>, +) -> ( + crate::l7::rest::WebSocketExtensionMode, + crate::l7::relay::UpgradeRelayOptions<'static>, +) { + let websocket_credential_rewrite = matches!( + config.protocol, + crate::l7::L7Protocol::Rest | crate::l7::L7Protocol::Websocket + ) && config.websocket_credential_rewrite; + let websocket_extensions = if config.protocol == crate::l7::L7Protocol::Websocket + || (config.protocol == crate::l7::L7Protocol::Rest && websocket_credential_rewrite) + { + crate::l7::rest::WebSocketExtensionMode::PermessageDeflate + } else { + crate::l7::rest::WebSocketExtensionMode::Preserve + }; + + let upgrade_options = crate::l7::relay::UpgradeRelayOptions { + websocket_request, + websocket: crate::l7::relay::WebSocketUpgradeBehavior { + credential_rewrite: websocket_credential_rewrite, + ..Default::default() + }, + secret_resolver: if websocket_credential_rewrite { + secret_resolver + } else { + None + }, + enforcement: config.enforcement, + ..Default::default() + }; + + (websocket_extensions, upgrade_options) +} + /// Handle a plain HTTP forward proxy request (non-CONNECT). /// /// Public IPs are allowed through when the endpoint passes OPA evaluation. 
@@ -2774,21 +2812,12 @@ async fn handle_forward_proxy( }; let websocket_request = crate::l7::rest::request_is_websocket_upgrade(&forward_request_bytes); - if l7_config.config.protocol == crate::l7::L7Protocol::Rest - && l7_config.config.websocket_credential_rewrite - { - websocket_extensions = crate::l7::rest::WebSocketExtensionMode::PermessageDeflate; - upgrade_options = crate::l7::relay::UpgradeRelayOptions { - websocket_request, - websocket: crate::l7::relay::WebSocketUpgradeBehavior { - credential_rewrite: true, - ..Default::default() - }, - secret_resolver: secret_resolver.clone(), - policy_name: matched_policy.clone().unwrap_or_default(), - ..Default::default() - }; - } + (websocket_extensions, upgrade_options) = forward_websocket_upgrade_settings( + &l7_config.config, + websocket_request, + secret_resolver.clone(), + ); + upgrade_options.policy_name = matched_policy.clone().unwrap_or_default(); let graphql = if l7_config.config.protocol == crate::l7::L7Protocol::Graphql { let header_end = forward_request_bytes .windows(4) @@ -3341,6 +3370,61 @@ fn is_benign_relay_error(err: &miette::Report) -> bool { mod tests { use super::*; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; + use std::sync::Arc; + + fn websocket_l7_config( + protocol: crate::l7::L7Protocol, + websocket_credential_rewrite: bool, + ) -> crate::l7::L7EndpointConfig { + crate::l7::L7EndpointConfig { + protocol, + path: "/**".to_string(), + tls: crate::l7::TlsMode::Auto, + enforcement: crate::l7::EnforcementMode::Enforce, + graphql_max_body_bytes: crate::l7::graphql::DEFAULT_MAX_BODY_BYTES, + allow_encoded_slash: false, + websocket_credential_rewrite, + websocket_graphql_policy: false, + } + } + + #[test] + fn forward_websocket_upgrade_enables_rewrite_for_native_websocket_endpoint() { + let (_, resolver) = SecretResolver::from_provider_env( + [("DISCORD_BOT_TOKEN".to_string(), "discord-real".to_string())] + .into_iter() + .collect(), + ); + + let (extensions, options) = 
forward_websocket_upgrade_settings( + &websocket_l7_config(crate::l7::L7Protocol::Websocket, true), + true, + resolver.map(Arc::new), + ); + + assert_eq!( + extensions, + crate::l7::rest::WebSocketExtensionMode::PermessageDeflate + ); + assert!(options.websocket.credential_rewrite); + assert!(options.secret_resolver.is_some()); + } + + #[test] + fn forward_websocket_upgrade_preserves_rest_without_rewrite() { + let (extensions, options) = forward_websocket_upgrade_settings( + &websocket_l7_config(crate::l7::L7Protocol::Rest, false), + true, + None, + ); + + assert_eq!( + extensions, + crate::l7::rest::WebSocketExtensionMode::Preserve + ); + assert!(!options.websocket.credential_rewrite); + assert!(options.secret_resolver.is_none()); + } #[test] fn l7_route_selection_prefers_path_specific_graphql_endpoint() { diff --git a/crates/openshell-sandbox/src/secrets.rs b/crates/openshell-sandbox/src/secrets.rs index d9b5af2c1..f2cacf44c 100644 --- a/crates/openshell-sandbox/src/secrets.rs +++ b/crates/openshell-sandbox/src/secrets.rs @@ -90,8 +90,12 @@ impl SecretResolver { for (key, value) in provider_env { let placeholder = placeholder_for_env_key_for_revision(&key, revision); + let canonical_placeholder = (revision != 0).then(|| placeholder_for_env_key(&key)); child_env.insert(key, placeholder.clone()); - by_placeholder.insert(placeholder, value); + by_placeholder.insert(placeholder, value.clone()); + if let Some(canonical_placeholder) = canonical_placeholder { + by_placeholder.insert(canonical_placeholder, value); + } } (child_env, Some(Self { by_placeholder })) From ecad571d63a1f2909a6cb9dc3363a1bf3b866a82 Mon Sep 17 00:00:00 2001 From: Aaron Erickson Date: Sun, 10 May 2026 17:43:48 -0700 Subject: [PATCH 14/17] feat(sandbox): rewrite REST credential placeholders Signed-off-by: Aaron Erickson --- crates/openshell-cli/src/policy_update.rs | 64 ++- crates/openshell-policy/src/lib.rs | 34 ++ crates/openshell-policy/src/merge.rs | 46 ++ 
crates/openshell-providers/src/profiles.rs | 4 + crates/openshell-sandbox/src/l7/mod.rs | 65 +++ crates/openshell-sandbox/src/l7/relay.rs | 16 +- crates/openshell-sandbox/src/l7/rest.rs | 532 ++++++++++++++++++- crates/openshell-sandbox/src/opa.rs | 60 +++ crates/openshell-sandbox/src/policy_local.rs | 1 + crates/openshell-sandbox/src/proxy.rs | 67 ++- crates/openshell-sandbox/src/secrets.rs | 445 +++++++++++----- crates/openshell-server/src/grpc/policy.rs | 31 ++ docs/reference/policy-schema.mdx | 5 +- docs/sandboxes/policies.mdx | 13 +- docs/security/best-practices.mdx | 4 +- proto/sandbox.proto | 4 + 16 files changed, 1206 insertions(+), 185 deletions(-) diff --git a/crates/openshell-cli/src/policy_update.rs b/crates/openshell-cli/src/policy_update.rs index 10f66ff09..9c9e91188 100644 --- a/crates/openshell-cli/src/policy_update.rs +++ b/crates/openshell-cli/src/policy_update.rs @@ -172,6 +172,23 @@ fn ensure_websocket_credential_rewrite_protocol( )) } +fn ensure_request_body_credential_rewrite_protocol( + spec: &str, + endpoint: &NetworkEndpoint, +) -> Result<()> { + if endpoint.protocol == "rest" { + return Ok(()); + } + let protocol = if endpoint.protocol.is_empty() { + "" + } else { + endpoint.protocol.as_str() + }; + Err(miette!( + "request-body-credential-rewrite endpoint option requires --add-endpoint protocol segment to be 'rest'; got '{protocol}' in '{spec}'" + )) +} + fn group_allow_rules(specs: &[String]) -> Result>> { let mut grouped = BTreeMap::new(); for spec in specs { @@ -353,10 +370,14 @@ fn apply_add_endpoint_options( ensure_websocket_credential_rewrite_protocol(spec, endpoint)?; endpoint.websocket_credential_rewrite = true; } + "request-body-credential-rewrite" => { + ensure_request_body_credential_rewrite_protocol(spec, endpoint)?; + endpoint.request_body_credential_rewrite = true; + } _ => { let Some(allowed_ip) = option.strip_prefix("allowed-ip=") else { return Err(miette!( - "--add-endpoint options segment supports only 
'websocket-credential-rewrite' and 'allowed-ip='; got '{option}' in '{spec}'" + "--add-endpoint options segment supports only 'websocket-credential-rewrite', 'request-body-credential-rewrite', and 'allowed-ip='; got '{option}' in '{spec}'" )); }; let allowed_ip = allowed_ip.trim(); @@ -557,6 +578,27 @@ mod tests { assert!(rule.endpoints[0].websocket_credential_rewrite); } + #[test] + fn parse_add_endpoint_enables_request_body_credential_rewrite_on_rest_endpoint() { + let plan = build_policy_update_plan( + &["slack.com:443:read-write:rest:enforce:request-body-credential-rewrite".to_string()], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect("plan should build"); + + let PolicyMergeOp::AddRule { rule, .. } = &plan.preview_operations[0] else { + panic!("expected add-rule preview"); + }; + let endpoint = &rule.endpoints[0]; + assert_eq!(endpoint.protocol, "rest"); + assert!(endpoint.request_body_credential_rewrite); + } + #[test] fn parse_add_endpoint_merges_allowed_ips_with_websocket_options() { let plan = build_policy_update_plan( @@ -633,6 +675,26 @@ mod tests { assert!(error.to_string().contains("protocol segment")); } + #[test] + fn request_body_credential_rewrite_rejects_non_rest_endpoint() { + let error = build_policy_update_plan( + &[ + "realtime.example.com:443:read-write:websocket:enforce:request-body-credential-rewrite" + .to_string(), + ], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect_err("plan should fail"); + + assert!(error.to_string().contains("protocol segment")); + assert!(error.to_string().contains("'rest'")); + } + #[test] fn parse_add_endpoint_rejects_unknown_options() { let error = build_policy_update_plan( diff --git a/crates/openshell-policy/src/lib.rs b/crates/openshell-policy/src/lib.rs index 0eb42b647..908450111 100644 --- a/crates/openshell-policy/src/lib.rs +++ b/crates/openshell-policy/src/lib.rs @@ -125,6 +125,10 @@ struct NetworkEndpointDef { /// Defaults to false. 
#[serde(default, skip_serializing_if = "std::ops::Not::not")] websocket_credential_rewrite: bool, + /// When true, supported textual REST request bodies rewrite credential + /// placeholders before forwarding upstream. Defaults to false. + #[serde(default, skip_serializing_if = "std::ops::Not::not")] + request_body_credential_rewrite: bool, #[serde(default, skip_serializing_if = "String::is_empty")] persisted_queries: String, #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] @@ -323,6 +327,7 @@ fn to_proto(raw: PolicyFile) -> SandboxPolicy { .collect(), allow_encoded_slash: e.allow_encoded_slash, websocket_credential_rewrite: e.websocket_credential_rewrite, + request_body_credential_rewrite: e.request_body_credential_rewrite, persisted_queries: e.persisted_queries, graphql_persisted_queries: e .graphql_persisted_queries @@ -487,6 +492,7 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile { .collect(), allow_encoded_slash: e.allow_encoded_slash, websocket_credential_rewrite: e.websocket_credential_rewrite, + request_body_credential_rewrite: e.request_body_credential_rewrite, persisted_queries: e.persisted_queries.clone(), graphql_persisted_queries: e .graphql_persisted_queries @@ -1690,6 +1696,33 @@ network_policies: assert!(yaml_out.contains("websocket_credential_rewrite: true")); } + #[test] + fn round_trip_preserves_request_body_credential_rewrite() { + let yaml = r" +version: 1 +network_policies: + slack_api: + name: slack_api + endpoints: + - host: slack.com + port: 443 + protocol: rest + enforcement: enforce + access: read-write + request_body_credential_rewrite: true + binaries: + - path: /usr/bin/node +"; + let proto1 = parse_sandbox_policy(yaml).expect("parse failed"); + let yaml_out = serialize_sandbox_policy(&proto1).expect("serialize failed"); + let proto2 = parse_sandbox_policy(&yaml_out).expect("re-parse failed"); + + let ep = &proto2.network_policies["slack_api"].endpoints[0]; + assert_eq!(ep.protocol, "rest"); + 
assert!(ep.request_body_credential_rewrite); + assert!(yaml_out.contains("request_body_credential_rewrite: true")); + } + #[test] fn websocket_credential_rewrite_defaults_false() { let yaml = r" @@ -1707,6 +1740,7 @@ network_policies: let proto = parse_sandbox_policy(yaml).expect("parse failed"); let ep = &proto.network_policies["gateway"].endpoints[0]; assert!(!ep.websocket_credential_rewrite); + assert!(!ep.request_body_credential_rewrite); } #[test] diff --git a/crates/openshell-policy/src/merge.rs b/crates/openshell-policy/src/merge.rs index 5b102cdb8..d99d9c216 100644 --- a/crates/openshell-policy/src/merge.rs +++ b/crates/openshell-policy/src/merge.rs @@ -464,6 +464,7 @@ fn merge_endpoint( append_unique_strings(&mut existing.allowed_ips, &incoming.allowed_ips); existing.allow_encoded_slash |= incoming.allow_encoded_slash; existing.websocket_credential_rewrite |= incoming.websocket_credential_rewrite; + existing.request_body_credential_rewrite |= incoming.request_body_credential_rewrite; normalize_endpoint(existing); Ok(()) } @@ -920,6 +921,51 @@ mod tests { assert!(endpoint.websocket_credential_rewrite); } + #[test] + fn add_rule_merges_request_body_credential_rewrite_flag() { + let mut policy = restrictive_default_policy(); + policy.network_policies.insert( + "existing".to_string(), + NetworkPolicyRule { + name: "existing".to_string(), + endpoints: vec![NetworkEndpoint { + host: "slack.com".to_string(), + port: 443, + ports: vec![443], + protocol: "rest".to_string(), + access: "read-write".to_string(), + ..Default::default() + }], + ..Default::default() + }, + ); + + let incoming = NetworkPolicyRule { + name: "incoming".to_string(), + endpoints: vec![NetworkEndpoint { + host: "slack.com".to_string(), + port: 443, + ports: vec![443], + protocol: "rest".to_string(), + request_body_credential_rewrite: true, + ..Default::default() + }], + ..Default::default() + }; + + let result = merge_policy( + policy, + &[PolicyMergeOp::AddRule { + rule_name: 
"allow_slack_com_443".to_string(), + rule: incoming, + }], + ) + .expect("merge should succeed"); + + let endpoint = &result.policy.network_policies["existing"].endpoints[0]; + assert!(endpoint.request_body_credential_rewrite); + } + #[test] fn add_allow_expands_access_preset() { let mut policy = restrictive_default_policy(); diff --git a/crates/openshell-providers/src/profiles.rs b/crates/openshell-providers/src/profiles.rs index c15fd0dac..588e77702 100644 --- a/crates/openshell-providers/src/profiles.rs +++ b/crates/openshell-providers/src/profiles.rs @@ -116,6 +116,8 @@ pub struct EndpointProfile { pub allow_encoded_slash: bool, #[serde(default, skip_serializing_if = "is_false")] pub websocket_credential_rewrite: bool, + #[serde(default, skip_serializing_if = "is_false")] + pub request_body_credential_rewrite: bool, #[serde(default, skip_serializing_if = "String::is_empty")] pub persisted_queries: String, #[serde(default, skip_serializing_if = "HashMap::is_empty")] @@ -417,6 +419,7 @@ fn endpoint_to_proto(endpoint: &EndpointProfile) -> NetworkEndpoint { deny_rules: endpoint.deny_rules.iter().map(deny_rule_to_proto).collect(), allow_encoded_slash: endpoint.allow_encoded_slash, websocket_credential_rewrite: endpoint.websocket_credential_rewrite, + request_body_credential_rewrite: endpoint.request_body_credential_rewrite, persisted_queries: endpoint.persisted_queries.clone(), graphql_persisted_queries: endpoint .graphql_persisted_queries @@ -446,6 +449,7 @@ fn endpoint_from_proto(endpoint: &NetworkEndpoint) -> EndpointProfile { .collect(), allow_encoded_slash: endpoint.allow_encoded_slash, websocket_credential_rewrite: endpoint.websocket_credential_rewrite, + request_body_credential_rewrite: endpoint.request_body_credential_rewrite, persisted_queries: endpoint.persisted_queries.clone(), graphql_persisted_queries: endpoint .graphql_persisted_queries diff --git a/crates/openshell-sandbox/src/l7/mod.rs b/crates/openshell-sandbox/src/l7/mod.rs index 
e1dc74b21..09278b4f8 100644 --- a/crates/openshell-sandbox/src/l7/mod.rs +++ b/crates/openshell-sandbox/src/l7/mod.rs @@ -61,6 +61,10 @@ pub enum EnforcementMode { } /// L7 configuration for an endpoint, extracted from policy data. +#[allow( + clippy::struct_excessive_bools, + reason = "Endpoint config mirrors independent policy schema toggles." +)] #[derive(Debug, Clone)] pub struct L7EndpointConfig { pub protocol: L7Protocol, @@ -78,6 +82,9 @@ pub struct L7EndpointConfig { /// Opt-in rewrite of credential placeholders in client-to-server /// WebSocket text messages after an allowed HTTP 101 upgrade. pub websocket_credential_rewrite: bool, + /// Opt-in rewrite of credential placeholders in supported textual REST + /// request bodies before forwarding upstream. + pub request_body_credential_rewrite: bool, /// When true, client-to-server GraphQL-over-WebSocket operation messages /// are classified with the same operation policy used by GraphQL-over-HTTP. pub websocket_graphql_policy: bool, @@ -149,6 +156,8 @@ pub fn parse_l7_config(val: ®orus::Value) -> Option { let allow_encoded_slash = get_object_bool(val, "allow_encoded_slash").unwrap_or(false); let websocket_credential_rewrite = get_object_bool(val, "websocket_credential_rewrite").unwrap_or(false); + let request_body_credential_rewrite = + get_object_bool(val, "request_body_credential_rewrite").unwrap_or(false); let websocket_graphql_policy = protocol == L7Protocol::Websocket && endpoint_has_graphql_policy(val); let graphql_max_body_bytes = get_object_u64(val, "graphql_max_body_bytes") @@ -164,6 +173,7 @@ pub fn parse_l7_config(val: ®orus::Value) -> Option { graphql_max_body_bytes, allow_encoded_slash, websocket_credential_rewrite, + request_body_credential_rewrite, websocket_graphql_policy, }) } @@ -621,6 +631,17 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< )); } + if ep + .get("request_body_credential_rewrite") + .and_then(serde_json::Value::as_bool) + .unwrap_or(false) + && 
protocol != "rest" + { + warnings.push(format!( + "{loc}: request_body_credential_rewrite is ignored unless protocol is rest" + )); + } + if let Some(registry_value) = ep.get("graphql_persisted_queries") { let Some(registry) = registry_value.as_object() else { errors.push(format!( @@ -1227,6 +1248,26 @@ mod tests { assert!(config.websocket_credential_rewrite); } + #[test] + fn parse_l7_config_request_body_credential_rewrite_defaults_false() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "rest", "host": "slack.com", "port": 443}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(!config.request_body_credential_rewrite); + } + + #[test] + fn parse_l7_config_request_body_credential_rewrite_opt_in() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "rest", "host": "slack.com", "port": 443, "request_body_credential_rewrite": true}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(config.request_body_credential_rewrite); + } + #[test] fn parse_l7_config_websocket_graphql_policy_defaults_false() { let val = regorus::Value::from_json_str( @@ -1270,6 +1311,30 @@ mod tests { ); } + #[test] + fn validate_request_body_credential_rewrite_warns_unless_rest() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "protocol": "websocket", + "request_body_credential_rewrite": true + }], + "binaries": [] + } + } + }); + let (_errors, warnings) = validate_l7_policies(&data); + assert!( + warnings + .iter() + .any(|w| w.contains("request_body_credential_rewrite is ignored")), + "expected request_body_credential_rewrite warning: {warnings:?}" + ); + } + #[test] fn expand_websocket_read_write_access_includes_text_messages() { let mut data = serde_json::json!({ diff --git a/crates/openshell-sandbox/src/l7/relay.rs b/crates/openshell-sandbox/src/l7/relay.rs index 4014bd273..971b2e8e5 100644 --- 
a/crates/openshell-sandbox/src/l7/relay.rs +++ b/crates/openshell-sandbox/src/l7/relay.rs @@ -349,6 +349,8 @@ where resolver: ctx.secret_resolver.as_deref(), generation_guard: Some(engine.generation_guard()), websocket_extensions: websocket_extension_mode(config), + request_body_credential_rewrite: config.protocol == L7Protocol::Rest + && config.request_body_credential_rewrite, }, ) .await?; @@ -765,6 +767,8 @@ where resolver: ctx.secret_resolver.as_deref(), generation_guard: Some(engine.generation_guard()), websocket_extensions: websocket_extension_mode(config), + request_body_credential_rewrite: config.protocol == L7Protocol::Rest + && config.request_body_credential_rewrite, }, ) .await?; @@ -1246,12 +1250,15 @@ where // Forward request with credential rewriting and relay the response. // relay_http_request_with_resolver handles both directions: it sends // the request upstream and reads the response back to the client. - let outcome = crate::l7::rest::relay_http_request_with_resolver_guarded( + let outcome = crate::l7::rest::relay_http_request_with_options_guarded( &req, client, upstream, - resolver, - Some(generation_guard), + crate::l7::rest::RelayRequestOptions { + resolver, + generation_guard: Some(generation_guard), + ..Default::default() + }, ) .await?; @@ -1408,6 +1415,7 @@ network_policies: graphql_max_body_bytes: 0, allow_encoded_slash: false, websocket_credential_rewrite: true, + request_body_credential_rewrite: false, websocket_graphql_policy: false, }]; let ctx = L7EvalContext { @@ -1507,6 +1515,7 @@ network_policies: graphql_max_body_bytes: 0, allow_encoded_slash: false, websocket_credential_rewrite: true, + request_body_credential_rewrite: false, websocket_graphql_policy: false, }]; let (child_env, resolver) = SecretResolver::from_provider_env( @@ -1623,6 +1632,7 @@ network_policies: graphql_max_body_bytes: 0, allow_encoded_slash: false, websocket_credential_rewrite: true, + request_body_credential_rewrite: false, websocket_graphql_policy: true, 
}]; let (child_env, resolver) = SecretResolver::from_provider_env( diff --git a/crates/openshell-sandbox/src/l7/rest.rs b/crates/openshell-sandbox/src/l7/rest.rs index e8e8e5b52..b4e2d1675 100644 --- a/crates/openshell-sandbox/src/l7/rest.rs +++ b/crates/openshell-sandbox/src/l7/rest.rs @@ -9,7 +9,9 @@ use crate::l7::provider::{BodyLength, L7Provider, L7Request, RelayOutcome}; use crate::opa::PolicyGenerationGuard; -use crate::secrets::{SecretResolver, rewrite_http_header_block}; +use crate::secrets::{ + SecretResolver, contains_reserved_credential_marker, rewrite_http_header_block, +}; use base64::Engine as _; use miette::{IntoDiagnostic, Result, miette}; use sha1::{Digest, Sha1}; @@ -18,6 +20,7 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tracing::debug; const MAX_HEADER_BYTES: usize = 16384; // 16 KiB for HTTP headers +const MAX_REWRITE_BODY_BYTES: usize = 256 * 1024; const RELAY_BUF_SIZE: usize = 8192; /// Idle timeout for `relay_until_eof`. If no data arrives within this window /// the body is considered complete. 
Prevents blocking on servers that keep @@ -373,6 +376,7 @@ where resolver, generation_guard, websocket_extensions: WebSocketExtensionMode::Preserve, + request_body_credential_rewrite: false, }, ) .await @@ -390,6 +394,7 @@ pub(crate) struct RelayRequestOptions<'a> { pub(crate) resolver: Option<&'a SecretResolver>, pub(crate) generation_guard: Option<&'a PolicyGenerationGuard>, pub(crate) websocket_extensions: WebSocketExtensionMode, + pub(crate) request_body_credential_rewrite: bool, } pub(crate) async fn relay_http_request_with_options_guarded( @@ -437,37 +442,54 @@ where guard.ensure_current()?; } - upstream - .write_all(&rewrite_result.rewritten) - .await - .into_diagnostic()?; - - let overflow = &req.raw_header[header_end..]; - if !overflow.is_empty() { - if let Some(guard) = options.generation_guard { - guard.ensure_current()?; + if options.request_body_credential_rewrite { + let body = collect_and_rewrite_request_body( + req, + client, + &rewrite_result.rewritten, + header_str, + &req.raw_header[header_end..], + options.resolver, + options.generation_guard, + ) + .await?; + upstream.write_all(&body.headers).await.into_diagnostic()?; + if !body.body.is_empty() { + upstream.write_all(&body.body).await.into_diagnostic()?; } - upstream.write_all(overflow).await.into_diagnostic()?; - } - let overflow_len = overflow.len() as u64; + } else { + upstream + .write_all(&rewrite_result.rewritten) + .await + .into_diagnostic()?; - match req.body_length { - BodyLength::ContentLength(len) => { - let remaining = len.saturating_sub(overflow_len); - if remaining > 0 { - relay_fixed(client, upstream, remaining, options.generation_guard).await?; + let overflow = &req.raw_header[header_end..]; + if !overflow.is_empty() { + if let Some(guard) = options.generation_guard { + guard.ensure_current()?; } + upstream.write_all(overflow).await.into_diagnostic()?; } - BodyLength::Chunked => { - relay_chunked( - client, - upstream, - &req.raw_header[header_end..], - 
options.generation_guard, - ) - .await?; + let overflow_len = overflow.len() as u64; + + match req.body_length { + BodyLength::ContentLength(len) => { + let remaining = len.saturating_sub(overflow_len); + if remaining > 0 { + relay_fixed(client, upstream, remaining, options.generation_guard).await?; + } + } + BodyLength::Chunked => { + relay_chunked( + client, + upstream, + &req.raw_header[header_end..], + options.generation_guard, + ) + .await?; + } + BodyLength::None => {} } - BodyLength::None => {} } upstream.flush().await.into_diagnostic()?; @@ -486,6 +508,290 @@ where Ok(outcome) } +struct PreparedRequestBody { + headers: Vec, + body: Vec, +} + +async fn collect_and_rewrite_request_body( + req: &L7Request, + client: &mut C, + rewritten_headers: &[u8], + original_header_str: &str, + already_read: &[u8], + resolver: Option<&SecretResolver>, + generation_guard: Option<&PolicyGenerationGuard>, +) -> Result { + match req.body_length { + BodyLength::None => { + if body_bytes_contain_reserved_marker(already_read) { + return Err(miette!( + "request body credential rewrite cannot resolve placeholders without explicit body framing" + )); + } + Ok(PreparedRequestBody { + headers: rewritten_headers.to_vec(), + body: already_read.to_vec(), + }) + } + BodyLength::ContentLength(len) => { + let len = usize::try_from(len) + .map_err(|_| miette!("request body is too large for credential rewrite"))?; + if len > MAX_REWRITE_BODY_BYTES { + return Err(miette!( + "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" + )); + } + let mut body = Vec::with_capacity(len); + let initial_len = already_read.len().min(len); + body.extend_from_slice(&already_read[..initial_len]); + let mut remaining = len.saturating_sub(initial_len); + let mut buf = [0u8; RELAY_BUF_SIZE]; + while remaining > 0 { + let to_read = remaining.min(buf.len()); + let n = client.read(&mut buf[..to_read]).await.into_diagnostic()?; + if n == 0 { + return Err(miette!( + "Connection closed with 
{remaining} body bytes remaining" + )); + } + if let Some(guard) = generation_guard { + guard.ensure_current()?; + } + body.extend_from_slice(&buf[..n]); + remaining -= n; + } + let (headers, body) = + rewrite_buffered_body(rewritten_headers, original_header_str, body, resolver)?; + Ok(PreparedRequestBody { headers, body }) + } + BodyLength::Chunked => { + let body = collect_chunked_body(client, already_read, generation_guard).await?; + if body_bytes_contain_reserved_marker(&body) { + return Err(miette!( + "request body credential rewrite does not support chunked bodies containing credential placeholders" + )); + } + Ok(PreparedRequestBody { + headers: rewritten_headers.to_vec(), + body, + }) + } + } +} + +fn rewrite_buffered_body( + headers: &[u8], + original_header_str: &str, + body: Vec, + resolver: Option<&SecretResolver>, +) -> Result<(Vec, Vec)> { + if body.is_empty() { + return Ok((headers.to_vec(), body)); + } + + let content_type = content_type(original_header_str); + if !is_rewritable_content_type(content_type.as_deref()) { + if body_bytes_contain_reserved_marker(&body) { + return Err(miette!( + "request body credential rewrite found placeholders in an unsupported content type" + )); + } + return Ok((headers.to_vec(), body)); + } + + let mut text = String::from_utf8(body) + .map_err(|_| miette!("request body credential rewrite requires UTF-8 text bodies"))?; + if !contains_reserved_credential_marker(&text) { + return Ok((headers.to_vec(), text.into_bytes())); + } + + let Some(resolver) = resolver else { + return Err(miette!( + "request body credential rewrite found placeholders but no resolver is available" + )); + }; + let replacements = resolver + .rewrite_text_placeholders(&mut text, "request_body") + .map_err(|e| miette!("credential injection failed: {e}"))?; + if replacements == 0 || contains_reserved_credential_marker(&text) { + return Err(miette!( + "request body credential rewrite left unresolved credential placeholders" + )); + } + + let body = 
text.into_bytes(); + let headers = set_content_length(headers, body.len())?; + Ok((headers, body)) +} + +async fn collect_chunked_body( + client: &mut C, + already_read: &[u8], + generation_guard: Option<&PolicyGenerationGuard>, +) -> Result> { + let mut read_buf = [0u8; RELAY_BUF_SIZE]; + let mut parse_buf = Vec::from(already_read); + let mut pos = 0usize; + + loop { + if parse_buf.len() > MAX_REWRITE_BODY_BYTES { + return Err(miette!( + "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" + )); + } + + let size_line_end = loop { + if let Some(end) = find_crlf(&parse_buf, pos) { + break end; + } + let n = client.read(&mut read_buf).await.into_diagnostic()?; + if n == 0 { + return Err(miette!("Chunked body ended before chunk-size line")); + } + if let Some(guard) = generation_guard { + guard.ensure_current()?; + } + parse_buf.extend_from_slice(&read_buf[..n]); + if parse_buf.len() > MAX_REWRITE_BODY_BYTES { + return Err(miette!( + "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" + )); + } + }; + + let size_line = std::str::from_utf8(&parse_buf[pos..size_line_end]) + .into_diagnostic() + .map_err(|_| miette!("Invalid UTF-8 in chunk-size line"))?; + let size_token = size_line + .split(';') + .next() + .map(str::trim) + .unwrap_or_default(); + let chunk_size = usize::from_str_radix(size_token, 16) + .into_diagnostic() + .map_err(|_| miette!("Invalid chunk size token: {size_token:?}"))?; + pos = size_line_end + 2; + + if chunk_size == 0 { + loop { + let trailer_end = loop { + if let Some(end) = find_crlf(&parse_buf, pos) { + break end; + } + let n = client.read(&mut read_buf).await.into_diagnostic()?; + if n == 0 { + return Err(miette!("Chunked body ended before trailer terminator")); + } + if let Some(guard) = generation_guard { + guard.ensure_current()?; + } + parse_buf.extend_from_slice(&read_buf[..n]); + if parse_buf.len() > MAX_REWRITE_BODY_BYTES { + return Err(miette!( + "request body credential rewrite 
buffers at most {MAX_REWRITE_BODY_BYTES} bytes" + )); + } + }; + let trailer_line = &parse_buf[pos..trailer_end]; + pos = trailer_end + 2; + if trailer_line.is_empty() { + return Ok(parse_buf); + } + } + } + + let chunk_end = pos + .checked_add(chunk_size) + .ok_or_else(|| miette!("Chunk size overflow"))?; + let chunk_with_crlf_end = chunk_end + .checked_add(2) + .ok_or_else(|| miette!("Chunk size overflow"))?; + while parse_buf.len() < chunk_with_crlf_end { + let n = client.read(&mut read_buf).await.into_diagnostic()?; + if n == 0 { + return Err(miette!("Chunked body ended mid-chunk")); + } + if let Some(guard) = generation_guard { + guard.ensure_current()?; + } + parse_buf.extend_from_slice(&read_buf[..n]); + if parse_buf.len() > MAX_REWRITE_BODY_BYTES { + return Err(miette!( + "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" + )); + } + } + if &parse_buf[chunk_end..chunk_with_crlf_end] != b"\r\n" { + return Err(miette!("Chunk missing terminating CRLF")); + } + pos = chunk_with_crlf_end; + } +} + +fn content_type(headers: &str) -> Option { + headers.lines().skip(1).find_map(|line| { + let (name, value) = line.split_once(':')?; + name.trim().eq_ignore_ascii_case("content-type").then(|| { + value + .split(';') + .next() + .unwrap_or("") + .trim() + .to_ascii_lowercase() + }) + }) +} + +fn is_rewritable_content_type(content_type: Option<&str>) -> bool { + let Some(content_type) = content_type else { + return false; + }; + content_type == "application/json" + || content_type == "application/x-www-form-urlencoded" + || content_type.starts_with("text/") +} + +fn body_bytes_contain_reserved_marker(body: &[u8]) -> bool { + if body.is_empty() { + return false; + } + String::from_utf8_lossy(body) + .split('\0') + .any(contains_reserved_credential_marker) +} + +fn set_content_length(headers: &[u8], len: usize) -> Result> { + use std::fmt::Write as _; + + let header_str = + std::str::from_utf8(headers).map_err(|_| miette!("HTTP headers contain 
invalid UTF-8"))?; + let mut out = String::with_capacity(header_str.len() + 32); + let mut inserted = false; + for line in header_str.split("\r\n") { + if line.is_empty() { + if !inserted { + let _ = write!(out, "Content-Length: {len}\r\n"); + } + out.push_str("\r\n"); + break; + } + if line + .split_once(':') + .is_some_and(|(name, _)| name.trim().eq_ignore_ascii_case("content-length")) + { + if !inserted { + let _ = write!(out, "Content-Length: {len}\r\n"); + inserted = true; + } + continue; + } + out.push_str(line); + out.push_str("\r\n"); + } + Ok(out.into_bytes()) +} + pub(crate) fn request_is_websocket_upgrade(raw_header: &[u8]) -> bool { let header_end = raw_header .windows(4) @@ -4009,6 +4315,178 @@ mod tests { Ok(forwarded) } + async fn relay_and_capture_with_options( + raw_header: Vec, + body_length: BodyLength, + resolver: Option<&SecretResolver>, + request_body_credential_rewrite: bool, + ) -> Result { + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let header_str = String::from_utf8_lossy(&raw_header); + let first_line = header_str.lines().next().unwrap_or(""); + let parts: Vec<&str> = first_line.splitn(3, ' ').collect(); + let action = parts.first().unwrap_or(&"GET").to_string(); + let target = parts.get(1).unwrap_or(&"/").to_string(); + + let req = L7Request { + action, + target, + query_params: HashMap::new(), + raw_header, + body_length, + }; + + let upstream_task = tokio::spawn(async move { + let mut buf = vec![0u8; 8192]; + let mut total = 0usize; + let mut header_end = None; + let mut expected_total = None; + loop { + let n = upstream_side.read(&mut buf[total..]).await.unwrap(); + if n == 0 { + break; + } + total += n; + if header_end.is_none() + && let Some(end) = buf[..total].windows(4).position(|w| w == b"\r\n\r\n") + { + let end = end + 4; + let headers = String::from_utf8_lossy(&buf[..end]); + let len = headers + .lines() + .find_map(|line| 
{ + let (name, value) = line.split_once(':')?; + name.eq_ignore_ascii_case("content-length") + .then(|| value.trim().parse::().ok()) + .flatten() + }) + .unwrap_or(0); + header_end = Some(end); + expected_total = Some(end + len); + } + if expected_total.is_some_and(|expected| total >= expected) { + break; + } + } + upstream_side + .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok") + .await + .unwrap(); + upstream_side.flush().await.unwrap(); + String::from_utf8_lossy(&buf[..total]).to_string() + }); + + relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + resolver, + request_body_credential_rewrite, + ..Default::default() + }, + ) + .await?; + + upstream_task + .await + .map_err(|e| miette!("upstream task failed: {e}")) + } + + #[tokio::test] + async fn relay_request_body_rewrites_slack_header_and_urlencoded_token() { + let (_, resolver) = SecretResolver::from_provider_env( + [( + "SLACK_BOT_TOKEN".to_string(), + "xoxb-real-bot-token".to_string(), + )] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let alias = "xoxb-OPENSHELL-RESOLVE-ENV-SLACK_BOT_TOKEN"; + let body = format!("token={alias}&channel=C123"); + let raw = format!( + "POST /api/chat.postMessage HTTP/1.1\r\n\ + Host: slack.com\r\n\ + Authorization: Bearer {alias}\r\n\ + Content-Type: application/x-www-form-urlencoded\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body + ); + + let forwarded = relay_and_capture_with_options( + raw.into_bytes(), + BodyLength::ContentLength(body.len() as u64), + Some(&resolver), + true, + ) + .await + .expect("relay should succeed"); + + let expected_body = "token=xoxb-real-bot-token&channel=C123"; + assert!(forwarded.contains("Authorization: Bearer xoxb-real-bot-token\r\n")); + assert!(forwarded.contains(&format!("Content-Length: {}\r\n", expected_body.len()))); + assert!(forwarded.ends_with(expected_body)); + 
assert!(!forwarded.contains("OPENSHELL-RESOLVE-ENV")); + } + + #[tokio::test] + async fn relay_request_body_unresolved_alias_fails_before_upstream_write() { + let (_, resolver) = SecretResolver::from_provider_env( + [( + "SLACK_BOT_TOKEN".to_string(), + "xoxb-real-bot-token".to_string(), + )] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let body = "token=xapp-OPENSHELL-RESOLVE-ENV-SLACK_APP_TOKEN"; + let raw = format!( + "POST /api/apps.connections.open HTTP/1.1\r\n\ + Host: slack.com\r\n\ + Content-Type: application/x-www-form-urlencoded\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body + ); + let req = L7Request { + action: "POST".to_string(), + target: "/api/apps.connections.open".to_string(), + query_params: HashMap::new(), + raw_header: raw.into_bytes(), + body_length: BodyLength::ContentLength(body.len() as u64), + }; + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let err = relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + resolver: Some(&resolver), + request_body_credential_rewrite: true, + ..Default::default() + }, + ) + .await + .expect_err("unknown body alias should fail closed"); + + assert!(!err.to_string().contains("xoxb-real-bot-token")); + drop(proxy_to_upstream); + let mut forwarded = Vec::new(); + upstream_side.read_to_end(&mut forwarded).await.unwrap(); + assert!( + forwarded.is_empty(), + "failed body rewrite must not reach upstream" + ); + } + #[tokio::test] async fn relay_injects_bearer_header_credential() { let (child_env, resolver) = SecretResolver::from_provider_env( diff --git a/crates/openshell-sandbox/src/opa.rs b/crates/openshell-sandbox/src/opa.rs index c6cd32f8a..a9ab94a2b 100644 --- a/crates/openshell-sandbox/src/opa.rs +++ b/crates/openshell-sandbox/src/opa.rs @@ -1064,6 +1064,9 @@ fn proto_to_opa_data_json(proto: 
&ProtoSandboxPolicy, entrypoint_pid: u32) -> St if e.websocket_credential_rewrite { ep["websocket_credential_rewrite"] = true.into(); } + if e.request_body_credential_rewrite { + ep["request_body_credential_rewrite"] = true.into(); + } if !e.persisted_queries.is_empty() { ep["persisted_queries"] = e.persisted_queries.clone().into(); } @@ -2655,6 +2658,63 @@ network_policies: assert!(l7.websocket_credential_rewrite); } + #[test] + fn l7_endpoint_config_preserves_proto_request_body_credential_rewrite() { + let mut network_policies = std::collections::HashMap::new(); + network_policies.insert( + "slack".to_string(), + NetworkPolicyRule { + name: "slack".to_string(), + endpoints: vec![NetworkEndpoint { + host: "slack.com".to_string(), + port: 443, + protocol: "rest".to_string(), + enforcement: "enforce".to_string(), + access: "read-write".to_string(), + request_body_credential_rewrite: true, + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/node".to_string(), + ..Default::default() + }], + }, + ); + let proto = ProtoSandboxPolicy { + version: 1, + filesystem: Some(ProtoFs { + include_workdir: true, + read_only: vec![], + read_write: vec![], + }), + landlock: Some(openshell_core::proto::LandlockPolicy { + compatibility: "best_effort".to_string(), + }), + process: Some(ProtoProc { + run_as_user: "sandbox".to_string(), + run_as_group: "sandbox".to_string(), + }), + network_policies, + }; + + let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); + let input = NetworkInput { + host: "slack.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + + let config = engine + .query_endpoint_config(&input) + .unwrap() + .expect("endpoint config"); + let l7 = crate::l7::parse_l7_config(&config).unwrap(); + assert!(l7.request_body_credential_rewrite); + } + #[test] fn l7_endpoint_config_none_for_l4_only() { let engine = l7_engine(); 
diff --git a/crates/openshell-sandbox/src/policy_local.rs b/crates/openshell-sandbox/src/policy_local.rs index d40bc31c9..165b0c1bd 100644 --- a/crates/openshell-sandbox/src/policy_local.rs +++ b/crates/openshell-sandbox/src/policy_local.rs @@ -620,6 +620,7 @@ fn network_endpoint_from_json( deny_rules, allow_encoded_slash: endpoint.allow_encoded_slash, websocket_credential_rewrite: false, + request_body_credential_rewrite: false, // GraphQL persisted-query knobs and path scoping default empty โ€” // agent proposals don't author them today. persisted_queries: String::new(), diff --git a/crates/openshell-sandbox/src/proxy.rs b/crates/openshell-sandbox/src/proxy.rs index 325c93859..dca522c12 100644 --- a/crates/openshell-sandbox/src/proxy.rs +++ b/crates/openshell-sandbox/src/proxy.rs @@ -10,7 +10,7 @@ use crate::opa::{NetworkAction, OpaEngine, PolicyGenerationGuard}; use crate::policy::ProxyPolicy; use crate::policy_local::{POLICY_LOCAL_HOST, PolicyLocalContext}; use crate::provider_credentials::ProviderCredentialState; -use crate::secrets::{SecretResolver, rewrite_header_line}; +use crate::secrets::{SecretResolver, rewrite_header_line_checked}; use miette::{IntoDiagnostic, Result}; use openshell_core::net::{is_always_blocked_ip, is_internal_ip}; use openshell_ocsf::{ @@ -2283,6 +2283,10 @@ fn rewrite_forward_request( .position(|w| w == b"\r\n\r\n") .map_or(used, |p| p + 4); let websocket_upgrade = crate::l7::rest::request_is_websocket_upgrade(&raw[..header_end]); + let upstream_path = match secret_resolver { + Some(resolver) => crate::secrets::rewrite_target_for_eval(path, resolver)?.resolved, + None => path.to_string(), + }; let header_str = String::from_utf8_lossy(&raw[..header_end]); let lines = header_str.split("\r\n").collect::>(); @@ -2299,7 +2303,7 @@ fn rewrite_forward_request( if parts.len() == 3 { output.extend_from_slice(parts[0].as_bytes()); output.push(b' '); - output.extend_from_slice(path.as_bytes()); + 
output.extend_from_slice(upstream_path.as_bytes()); output.push(b' '); output.extend_from_slice(parts[2].as_bytes()); } else { @@ -2335,10 +2339,10 @@ fn rewrite_forward_request( continue; } - let rewritten_line = secret_resolver.map_or_else( - || line.to_string(), - |resolver| rewrite_header_line(line, resolver), - ); + let rewritten_line = match secret_resolver { + Some(resolver) => rewrite_header_line_checked(line, resolver)?, + None => line.to_string(), + }; output.extend_from_slice(rewritten_line.as_bytes()); output.extend_from_slice(b"\r\n"); @@ -2367,7 +2371,9 @@ fn rewrite_forward_request( // Fail-closed: scan for any remaining unresolved placeholders if secret_resolver.is_some() { let output_str = String::from_utf8_lossy(&output); - if output_str.contains(crate::secrets::PLACEHOLDER_PREFIX_PUBLIC) { + if output_str.contains(crate::secrets::PLACEHOLDER_PREFIX_PUBLIC) + || output_str.contains(crate::secrets::PROVIDER_ALIAS_MARKER_PUBLIC) + { return Err(crate::secrets::UnresolvedPlaceholderError { location: "header" }); } } @@ -2375,14 +2381,20 @@ fn rewrite_forward_request( Ok(output) } +struct ForwardRelayOptions<'a> { + generation_guard: &'a PolicyGenerationGuard, + websocket_extensions: crate::l7::rest::WebSocketExtensionMode, + secret_resolver: Option<&'a SecretResolver>, + request_body_credential_rewrite: bool, +} + async fn relay_rewritten_forward_request( method: &str, path: &str, rewritten: Vec, client: &mut C, upstream: &mut U, - generation_guard: &PolicyGenerationGuard, - websocket_extensions: crate::l7::rest::WebSocketExtensionMode, + options: ForwardRelayOptions<'_>, ) -> Result where C: TokioAsyncRead + TokioAsyncWrite + Unpin, @@ -2408,9 +2420,10 @@ where client, upstream, crate::l7::rest::RelayRequestOptions { - resolver: None, - generation_guard: Some(generation_guard), - websocket_extensions, + resolver: options.secret_resolver, + generation_guard: Some(options.generation_guard), + websocket_extensions: options.websocket_extensions, + 
request_body_credential_rewrite: options.request_body_credential_rewrite, }, ) .await @@ -2673,6 +2686,7 @@ async fn handle_forward_proxy( let mut upstream_target = path.clone(); let mut websocket_extensions = crate::l7::rest::WebSocketExtensionMode::Preserve; let mut upgrade_options = crate::l7::relay::UpgradeRelayOptions::default(); + let mut request_body_credential_rewrite = false; // 4b. If the endpoint has L7 config, evaluate the request against // L7 policy. The forward proxy handles exactly one request per @@ -2817,6 +2831,8 @@ async fn handle_forward_proxy( websocket_request, secret_resolver.clone(), ); + request_body_credential_rewrite = l7_config.config.protocol == crate::l7::L7Protocol::Rest + && l7_config.config.request_body_credential_rewrite; upgrade_options.policy_name = matched_policy.clone().unwrap_or_default(); let graphql = if l7_config.config.protocol == crate::l7::L7Protocol::Graphql { let header_end = forward_request_bytes @@ -3280,8 +3296,12 @@ async fn handle_forward_proxy( rewritten, client, &mut upstream, - &forward_generation_guard, - websocket_extensions, + ForwardRelayOptions { + generation_guard: &forward_generation_guard, + websocket_extensions, + secret_resolver: secret_resolver.as_deref(), + request_body_credential_rewrite, + }, ) .await?; if let crate::l7::provider::RelayOutcome::Upgraded { @@ -3384,6 +3404,7 @@ mod tests { graphql_max_body_bytes: crate::l7::graphql::DEFAULT_MAX_BODY_BYTES, allow_encoded_slash: false, websocket_credential_rewrite, + request_body_credential_rewrite: false, websocket_graphql_policy: false, } } @@ -3438,6 +3459,7 @@ mod tests { graphql_max_body_bytes: crate::l7::graphql::DEFAULT_MAX_BODY_BYTES, allow_encoded_slash: false, websocket_credential_rewrite: false, + request_body_credential_rewrite: false, websocket_graphql_policy: false, }, }, @@ -3450,6 +3472,7 @@ mod tests { graphql_max_body_bytes: crate::l7::graphql::DEFAULT_MAX_BODY_BYTES, allow_encoded_slash: false, websocket_credential_rewrite: false, 
+ request_body_credential_rewrite: false, websocket_graphql_policy: false, }, }, @@ -4539,8 +4562,12 @@ mod tests { rewritten, &mut proxy_to_client, &mut proxy_to_upstream, - &guard, - crate::l7::rest::WebSocketExtensionMode::Preserve, + ForwardRelayOptions { + generation_guard: &guard, + websocket_extensions: crate::l7::rest::WebSocketExtensionMode::Preserve, + secret_resolver: None, + request_body_credential_rewrite: false, + }, ) .await; assert!( @@ -4578,8 +4605,12 @@ mod tests { rewritten, &mut proxy_to_client, &mut proxy_to_upstream, - &guard, - crate::l7::rest::WebSocketExtensionMode::Preserve, + ForwardRelayOptions { + generation_guard: &guard, + websocket_extensions: crate::l7::rest::WebSocketExtensionMode::Preserve, + secret_resolver: None, + request_body_credential_rewrite: false, + }, ) .await; assert!(result.is_err(), "forward relay must reject CL/TE ambiguity"); diff --git a/crates/openshell-sandbox/src/secrets.rs b/crates/openshell-sandbox/src/secrets.rs index f2cacf44c..3f9b74346 100644 --- a/crates/openshell-sandbox/src/secrets.rs +++ b/crates/openshell-sandbox/src/secrets.rs @@ -6,9 +6,13 @@ use std::collections::HashMap; use std::fmt; const PLACEHOLDER_PREFIX: &str = "openshell:resolve:env:"; +const PROVIDER_ALIAS_MARKER: &str = "OPENSHELL-RESOLVE-ENV-"; +const SLACK_BOT_ALIAS_PREFIX: &str = "xoxb-OPENSHELL-RESOLVE-ENV-"; +const SLACK_APP_ALIAS_PREFIX: &str = "xapp-OPENSHELL-RESOLVE-ENV-"; /// Public access to the placeholder prefix for fail-closed scanning in other modules. pub const PLACEHOLDER_PREFIX_PUBLIC: &str = PLACEHOLDER_PREFIX; +pub const PROVIDER_ALIAS_MARKER_PUBLIC: &str = PROVIDER_ALIAS_MARKER; /// Characters that are valid in an env var key name (used to extract /// placeholder boundaries within concatenated strings like path segments). 
@@ -16,6 +20,22 @@ fn is_env_key_char(b: u8) -> bool { b.is_ascii_alphanumeric() || b == b'_' } +fn is_alias_token_char(b: u8) -> bool { + b.is_ascii_alphanumeric() || b == b'_' || b == b'-' +} + +fn contains_raw_reserved_marker(value: &str) -> bool { + value.contains(PLACEHOLDER_PREFIX) || value.contains(PROVIDER_ALIAS_MARKER) +} + +pub fn contains_reserved_credential_marker(value: &str) -> bool { + if contains_raw_reserved_marker(value) { + return true; + } + let decoded = percent_decode(value); + contains_raw_reserved_marker(&decoded) +} + // --------------------------------------------------------------------------- // Error and result types // --------------------------------------------------------------------------- @@ -91,11 +111,13 @@ impl SecretResolver { for (key, value) in provider_env { let placeholder = placeholder_for_env_key_for_revision(&key, revision); let canonical_placeholder = (revision != 0).then(|| placeholder_for_env_key(&key)); - child_env.insert(key, placeholder.clone()); + child_env.insert(key.clone(), placeholder.clone()); by_placeholder.insert(placeholder, value.clone()); if let Some(canonical_placeholder) = canonical_placeholder { - by_placeholder.insert(canonical_placeholder, value); + by_placeholder.insert(canonical_placeholder, value.clone()); } + by_placeholder.insert(slack_bot_alias_for_env_key(&key), value.clone()); + by_placeholder.insert(slack_app_alias_for_env_key(&key), value); } (child_env, Some(Self { by_placeholder })) @@ -132,10 +154,13 @@ impl SecretResolver { } } - pub(crate) fn rewrite_header_value(&self, value: &str) -> Option { + pub(crate) fn rewrite_header_value( + &self, + value: &str, + ) -> Result, UnresolvedPlaceholderError> { // Direct placeholder match: `x-api-key: openshell:resolve:env:KEY` if let Some(secret) = self.resolve_placeholder(value.trim()) { - return Some(secret.to_string()); + return Ok(Some(secret.to_string())); } let trimmed = value.trim(); @@ -146,29 +171,37 @@ impl SecretResolver { 
.strip_prefix("Basic ") .or_else(|| trimmed.strip_prefix("basic ")) .map(str::trim) - && let Some(rewritten) = self.rewrite_basic_auth_token(encoded) + && let Some(rewritten) = self.rewrite_basic_auth_token(encoded)? { - return Some(format!("Basic {rewritten}")); + return Ok(Some(format!("Basic {rewritten}"))); } // Prefixed placeholder: `Bearer openshell:resolve:env:KEY` - let split_at = trimmed.find(char::is_whitespace)?; + let Some(split_at) = trimmed.find(char::is_whitespace) else { + if contains_reserved_credential_marker(trimmed) { + return Err(UnresolvedPlaceholderError { location: "header" }); + } + return Ok(None); + }; let prefix = &trimmed[..split_at]; let candidate = trimmed[split_at..].trim(); - let secret = self.resolve_placeholder(candidate)?; - Some(format!("{prefix} {secret}")) + if let Some(secret) = self.resolve_placeholder(candidate) { + return Ok(Some(format!("{prefix} {secret}"))); + } + + if contains_reserved_credential_marker(candidate) { + return Err(UnresolvedPlaceholderError { location: "header" }); + } + + Ok(None) } - /// Rewrite credential placeholders inside a WebSocket text message. - /// - /// The message is mutated only after all placeholders resolve - /// successfully. The return value is the number of replacements; callers - /// must not log the rewritten text. 
- pub(crate) fn rewrite_websocket_text_placeholders( + pub(crate) fn rewrite_text_placeholders( &self, text: &mut String, + location: &'static str, ) -> Result { - if !text.contains(PLACEHOLDER_PREFIX) { + if !contains_raw_reserved_marker(text) { return Ok(0); } @@ -177,55 +210,80 @@ impl SecretResolver { let mut replacements = 0; while pos < text.len() { - let Some(start) = text[pos..].find(PLACEHOLDER_PREFIX) else { + let next_canonical = text[pos..].find(PLACEHOLDER_PREFIX).map(|p| pos + p); + let next_alias = text[pos..].find(PROVIDER_ALIAS_MARKER).map(|marker_pos| { + let marker_abs = pos + marker_pos; + alias_start_for_marker(text, marker_abs).unwrap_or(marker_abs) + }); + let Some(abs_start) = [next_canonical, next_alias].into_iter().flatten().min() else { rewritten.push_str(&text[pos..]); break; }; - let abs_start = pos + start; + rewritten.push_str(&text[pos..abs_start]); - let key_start = abs_start + PLACEHOLDER_PREFIX.len(); - if let Some((key_end, full_placeholder)) = - self.longest_known_placeholder_match(text, abs_start) - { - let Some(secret) = self.resolve_placeholder(full_placeholder) else { - return Err(UnresolvedPlaceholderError { - location: "websocket", - }); + if text[abs_start..].starts_with(PLACEHOLDER_PREFIX) { + let Some((token_end, token)) = self.credential_token_at(text, abs_start) else { + return Err(UnresolvedPlaceholderError { location }); + }; + let Some(secret) = self.resolve_placeholder(token) else { + return Err(UnresolvedPlaceholderError { location }); }; rewritten.push_str(secret); replacements += 1; - pos = key_end; + pos = token_end; continue; } - let key_end = text[key_start..] 
- .bytes() - .position(|b| !is_env_key_char(b)) - .map_or(text.len(), |p| key_start + p); - - if key_end == key_start { - return Err(UnresolvedPlaceholderError { - location: "websocket", - }); + if text[abs_start..].starts_with(SLACK_BOT_ALIAS_PREFIX) + || text[abs_start..].starts_with(SLACK_APP_ALIAS_PREFIX) + { + let Some((token_end, token)) = self.credential_token_at(text, abs_start) else { + return Err(UnresolvedPlaceholderError { location }); + }; + let Some(secret) = self.resolve_placeholder(token) else { + return Err(UnresolvedPlaceholderError { location }); + }; + rewritten.push_str(secret); + replacements += 1; + pos = token_end; + continue; } - let full_placeholder = &text[abs_start..key_end]; - let Some(secret) = self.resolve_placeholder(full_placeholder) else { - return Err(UnresolvedPlaceholderError { - location: "websocket", - }); - }; - rewritten.push_str(secret); - replacements += 1; - pos = key_end; + return Err(UnresolvedPlaceholderError { location }); + } + + if contains_raw_reserved_marker(&rewritten) { + return Err(UnresolvedPlaceholderError { location }); } *text = rewritten; Ok(replacements) } - fn longest_known_placeholder_match<'a>( + /// Rewrite credential placeholders inside a WebSocket text message. + /// + /// The message is mutated only after all placeholders resolve + /// successfully. The return value is the number of replacements; callers + /// must not log the rewritten text. 
+ pub(crate) fn rewrite_websocket_text_placeholders( + &self, + text: &mut String, + ) -> Result { + self.rewrite_text_placeholders(text, "websocket") + } + + fn credential_token_at<'a>( + &'a self, + text: &'a str, + abs_start: usize, + ) -> Option<(usize, &'a str)> { + self.longest_known_token_match(text, abs_start) + .or_else(|| canonical_token_at(text, abs_start)) + .or_else(|| alias_token_at(text, abs_start)) + } + + fn longest_known_token_match<'a>( &'a self, text: &str, abs_start: usize, @@ -238,9 +296,7 @@ impl SecretResolver { return None; } let key_end = abs_start + placeholder.len(); - let boundary_ok = key_end == text.len() - || !is_env_key_char(text.as_bytes()[key_end]) - || text[key_end..].starts_with(PLACEHOLDER_PREFIX); + let boundary_ok = token_boundary_ok(text, abs_start, key_end, placeholder); boundary_ok.then_some((key_end, placeholder.as_str())) }) .max_by_key(|(_, placeholder)| placeholder.len()) @@ -250,45 +306,100 @@ impl SecretResolver { /// the decoded `username:password` string, and re-encode. /// /// Returns `None` if decoding fails or no placeholders are found. 
- fn rewrite_basic_auth_token(&self, encoded: &str) -> Option { + fn rewrite_basic_auth_token( + &self, + encoded: &str, + ) -> Result, UnresolvedPlaceholderError> { let b64 = base64::engine::general_purpose::STANDARD; - let decoded_bytes = b64.decode(encoded.trim()).ok()?; - let decoded = std::str::from_utf8(&decoded_bytes).ok()?; - - // Check if the decoded string contains any placeholder - if !decoded.contains(PLACEHOLDER_PREFIX) { - return None; + let Some(decoded_bytes) = b64.decode(encoded.trim()).ok() else { + return Ok(None); + }; + let Some(decoded) = std::str::from_utf8(&decoded_bytes).ok() else { + return Ok(None); + }; + + if !contains_raw_reserved_marker(decoded) { + return Ok(None); } - // Rewrite all placeholder occurrences in the decoded string let mut rewritten = decoded.to_string(); - for (placeholder, secret) in &self.by_placeholder { - if rewritten.contains(placeholder.as_str()) { - // Validate the resolved secret for control characters - if validate_resolved_secret(secret).is_err() { - tracing::warn!( - location = "basic_auth", - "credential resolution rejected: resolved value contains prohibited characters" - ); - return None; - } - rewritten = rewritten.replace(placeholder.as_str(), secret); - } - } + let replacements = self.rewrite_text_placeholders(&mut rewritten, "header")?; - // Only return if we actually changed something - if rewritten == decoded { - return None; + if replacements == 0 { + return Ok(None); } - Some(b64.encode(rewritten.as_bytes())) + Ok(Some(b64.encode(rewritten.as_bytes()))) } } +fn alias_start_for_marker(text: &str, marker_abs: usize) -> Option { + marker_abs + .checked_sub("xoxb-".len()) + .filter(|start| text[*start..].starts_with(SLACK_BOT_ALIAS_PREFIX)) + .or_else(|| { + marker_abs + .checked_sub("xapp-".len()) + .filter(|start| text[*start..].starts_with(SLACK_APP_ALIAS_PREFIX)) + }) +} + +fn canonical_token_at(text: &str, abs_start: usize) -> Option<(usize, &str)> { + if 
!text[abs_start..].starts_with(PLACEHOLDER_PREFIX) { + return None; + } + let key_start = abs_start + PLACEHOLDER_PREFIX.len(); + let key_end = text[key_start..] + .bytes() + .position(|b| !is_env_key_char(b)) + .map_or(text.len(), |p| key_start + p); + (key_end > key_start).then_some((key_end, &text[abs_start..key_end])) +} + +fn alias_token_at(text: &str, abs_start: usize) -> Option<(usize, &str)> { + let prefix_len = if text[abs_start..].starts_with(SLACK_BOT_ALIAS_PREFIX) { + SLACK_BOT_ALIAS_PREFIX.len() + } else if text[abs_start..].starts_with(SLACK_APP_ALIAS_PREFIX) { + SLACK_APP_ALIAS_PREFIX.len() + } else { + return None; + }; + let key_start = abs_start + prefix_len; + let key_end = text[key_start..] + .bytes() + .position(|b| !is_env_key_char(b)) + .map_or(text.len(), |p| key_start + p); + if key_end == key_start { + return None; + } + let before_ok = abs_start == 0 || !is_alias_token_char(text.as_bytes()[abs_start - 1]); + let after_ok = key_end == text.len() || !is_alias_token_char(text.as_bytes()[key_end]); + (before_ok && after_ok).then_some((key_end, &text[abs_start..key_end])) +} + +fn token_boundary_ok(text: &str, abs_start: usize, token_end: usize, token: &str) -> bool { + if token.starts_with(PLACEHOLDER_PREFIX) { + return token_end == text.len() + || !is_env_key_char(text.as_bytes()[token_end]) + || text[token_end..].starts_with(PLACEHOLDER_PREFIX); + } + let before_ok = abs_start == 0 || !is_alias_token_char(text.as_bytes()[abs_start - 1]); + let after_ok = token_end == text.len() || !is_alias_token_char(text.as_bytes()[token_end]); + before_ok && after_ok +} + pub fn placeholder_for_env_key(key: &str) -> String { format!("{PLACEHOLDER_PREFIX}{key}") } +pub fn slack_bot_alias_for_env_key(key: &str) -> String { + format!("{SLACK_BOT_ALIAS_PREFIX}{key}") +} + +pub fn slack_app_alias_for_env_key(key: &str) -> String { + format!("{SLACK_APP_ALIAS_PREFIX}{key}") +} + pub fn placeholder_for_env_key_for_revision(key: &str, revision: u64) -> String { 
if revision == 0 { placeholder_for_env_key(key) @@ -478,8 +589,9 @@ fn rewrite_request_line( return unchanged(); }; - // Only rewrite if the URI contains a placeholder - if !uri.contains(PLACEHOLDER_PREFIX) { + // Only rewrite if the URI contains a placeholder or a provider-shaped + // credential alias, including percent-encoded canonical placeholders. + if !contains_reserved_credential_marker(uri) { return unchanged(); } @@ -535,10 +647,6 @@ fn rewrite_uri_path( path: &str, resolver: &SecretResolver, ) -> Result, UnresolvedPlaceholderError> { - if !path.contains(PLACEHOLDER_PREFIX) { - return Ok(None); - } - let segments: Vec<&str> = path.split('/').collect(); let mut resolved_segments = Vec::with_capacity(segments.len()); let mut redacted_segments = Vec::with_capacity(segments.len()); @@ -546,7 +654,7 @@ fn rewrite_uri_path( for segment in &segments { let decoded = percent_decode(segment); - if !decoded.contains(PLACEHOLDER_PREFIX) { + if !contains_raw_reserved_marker(&decoded) { resolved_segments.push(segment.to_string()); redacted_segments.push(segment.to_string()); continue; @@ -586,28 +694,23 @@ fn rewrite_path_segment( let bytes = segment.as_bytes(); while pos < bytes.len() { - if let Some(start) = segment[pos..].find(PLACEHOLDER_PREFIX) { - let abs_start = pos + start; + let next_canonical = segment[pos..].find(PLACEHOLDER_PREFIX).map(|p| pos + p); + let next_alias = segment[pos..] + .find(PROVIDER_ALIAS_MARKER) + .map(|marker_pos| { + let marker_abs = pos + marker_pos; + alias_start_for_marker(segment, marker_abs).unwrap_or(marker_abs) + }); + if let Some(abs_start) = [next_canonical, next_alias].into_iter().flatten().min() { // Copy literal prefix before the placeholder resolved.push_str(&segment[pos..abs_start]); redacted.push_str(&segment[pos..abs_start]); - // Extract the key name using the env var grammar: [A-Za-z_][A-Za-z0-9_]* - let key_start = abs_start + PLACEHOLDER_PREFIX.len(); - let key_end = segment[key_start..] 
- .bytes() - .position(|b| !is_env_key_char(b)) - .map_or(segment.len(), |p| key_start + p); - - if key_end == key_start { - // Empty key โ€” not a valid placeholder, copy literally - resolved.push_str(&segment[abs_start..abs_start + PLACEHOLDER_PREFIX.len()]); - redacted.push_str(&segment[abs_start..abs_start + PLACEHOLDER_PREFIX.len()]); - pos = abs_start + PLACEHOLDER_PREFIX.len(); - continue; - } - - let full_placeholder = &segment[abs_start..key_end]; + let Some((token_end, full_placeholder)) = canonical_token_at(segment, abs_start) + .or_else(|| alias_token_at(segment, abs_start)) + else { + return Err(UnresolvedPlaceholderError { location: "path" }); + }; if let Some(secret) = resolver.resolve_placeholder(full_placeholder) { validate_credential_for_path(secret).map_err(|reason| { tracing::warn!( @@ -622,7 +725,7 @@ fn rewrite_path_segment( } else { return Err(UnresolvedPlaceholderError { location: "path" }); } - pos = key_end; + pos = token_end; } else { // No more placeholders in remainder resolved.push_str(&segment[pos..]); @@ -641,7 +744,7 @@ fn rewrite_uri_query_params( query: &str, resolver: &SecretResolver, ) -> Result, UnresolvedPlaceholderError> { - if !query.contains(PLACEHOLDER_PREFIX) { + if !contains_reserved_credential_marker(query) { return Ok(None); } @@ -652,15 +755,18 @@ fn rewrite_uri_query_params( for param in query.split('&') { if let Some((key, value)) = param.split_once('=') { let decoded_value = percent_decode(value); - if let Some(secret) = resolver.resolve_placeholder(&decoded_value) { - resolved_params.push(format!("{key}={}", percent_encode_query(secret))); + if contains_raw_reserved_marker(&decoded_value) { + let mut rewritten = decoded_value.clone(); + let replacements = + resolver.rewrite_text_placeholders(&mut rewritten, "query_param")?; + if replacements == 0 || contains_raw_reserved_marker(&rewritten) { + return Err(UnresolvedPlaceholderError { + location: "query_param", + }); + } + resolved_params.push(format!("{key}={}", 
percent_encode_query(&rewritten))); redacted_params.push(format!("{key}=[CREDENTIAL]")); any_rewritten = true; - } else if decoded_value.contains(PLACEHOLDER_PREFIX) { - // Placeholder detected but not resolved - return Err(UnresolvedPlaceholderError { - location: "query_param", - }); } else { resolved_params.push(param.to_string()); redacted_params.push(param.to_string()); @@ -730,41 +836,42 @@ pub fn rewrite_http_header_block( break; } - output.extend_from_slice(rewrite_header_line(line, resolver).as_bytes()); + output.extend_from_slice(rewrite_header_line_checked(line, resolver)?.as_bytes()); output.extend_from_slice(b"\r\n"); } output.extend_from_slice(b"\r\n"); output.extend_from_slice(&raw[header_end..]); - // Fail-closed scan: check for any remaining unresolved placeholders - // in both raw form and percent-decoded form of the output header block. + // Fail-closed scan: check for any remaining unresolved placeholders or + // provider-shaped aliases in both raw and percent-decoded header bytes. 
let output_header = String::from_utf8_lossy(&output[..output.len().min(header_end + 256)]); - if output_header.contains(PLACEHOLDER_PREFIX) { + if contains_reserved_credential_marker(&output_header) { return Err(UnresolvedPlaceholderError { location: "header" }); } - // Also check percent-decoded form of the request line (F5 โ€” encoded placeholder bypass) - let rewritten_rl = output_header.split("\r\n").next().unwrap_or(""); - let decoded_rl = percent_decode(rewritten_rl); - if decoded_rl.contains(PLACEHOLDER_PREFIX) { - return Err(UnresolvedPlaceholderError { location: "path" }); - } - Ok(RewriteResult { rewritten: output, redacted_target: rl_result.redacted_target, }) } +#[cfg_attr(not(test), allow(dead_code))] pub fn rewrite_header_line(line: &str, resolver: &SecretResolver) -> String { + rewrite_header_line_checked(line, resolver).unwrap_or_else(|_| line.to_string()) +} + +pub fn rewrite_header_line_checked( + line: &str, + resolver: &SecretResolver, +) -> Result { let Some((name, value)) = line.split_once(':') else { - return line.to_string(); + return Ok(line.to_string()); }; - resolver.rewrite_header_value(value.trim()).map_or_else( - || line.to_string(), - |rewritten| format!("{name}: {rewritten}"), + resolver.rewrite_header_value(value.trim())?.map_or_else( + || Ok(line.to_string()), + |rewritten| Ok(format!("{name}: {rewritten}")), ) } @@ -779,12 +886,7 @@ pub fn rewrite_target_for_eval( target: &str, resolver: &SecretResolver, ) -> Result { - if !target.contains(PLACEHOLDER_PREFIX) { - // Also check percent-decoded form - let decoded = percent_decode(target); - if decoded.contains(PLACEHOLDER_PREFIX) { - return Err(UnresolvedPlaceholderError { location: "path" }); - } + if !contains_reserved_credential_marker(target) { return Ok(RewriteTargetResult { resolved: target.to_string(), redacted: target.to_string(), @@ -891,6 +993,50 @@ mod tests { ); } + #[test] + fn rewrites_slack_shaped_alias_header_values() { + let (_, resolver) = 
SecretResolver::from_provider_env( + [ + ("SLACK_BOT_TOKEN".to_string(), "xoxb-real-bot".to_string()), + ("SLACK_APP_TOKEN".to_string(), "xapp-real-app".to_string()), + ] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + + assert_eq!( + rewrite_header_line( + "Authorization: Bearer xoxb-OPENSHELL-RESOLVE-ENV-SLACK_BOT_TOKEN", + &resolver, + ), + "Authorization: Bearer xoxb-real-bot" + ); + assert_eq!( + rewrite_header_line( + "Authorization: Bearer xapp-OPENSHELL-RESOLVE-ENV-SLACK_APP_TOKEN", + &resolver, + ), + "Authorization: Bearer xapp-real-app" + ); + } + + #[test] + fn unresolved_slack_shaped_alias_fails_closed() { + let (_, resolver) = SecretResolver::from_provider_env( + [("SLACK_BOT_TOKEN".to_string(), "xoxb-real-bot".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let raw = b"GET / HTTP/1.1\r\nAuthorization: Bearer xapp-OPENSHELL-RESOLVE-ENV-SLACK_APP_TOKEN\r\n\r\n"; + + let err = rewrite_http_header_block(raw, Some(&resolver)) + .expect_err("unknown alias should fail closed"); + + assert_eq!(err.location, "header"); + } + #[test] fn rewrites_http_header_blocks_and_preserves_body() { let (_, resolver) = SecretResolver::from_provider_env( @@ -1501,6 +1647,29 @@ mod tests { ); } + #[test] + fn percent_encoded_canonical_placeholder_in_query_rewrites() { + let (_, resolver) = SecretResolver::from_provider_env( + [("SLACK_BOT_TOKEN".to_string(), "xoxb-real-bot".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let encoded = "openshell%3Aresolve%3Aenv%3ASLACK_BOT_TOKEN"; + let raw = format!("GET /api?token={encoded} HTTP/1.1\r\nHost: x\r\n\r\n"); + + let result = + rewrite_http_header_block(raw.as_bytes(), Some(&resolver)).expect("should rewrite"); + let rewritten = String::from_utf8(result.rewritten).expect("utf8"); + + assert!(rewritten.starts_with("GET /api?token=xoxb-real-bot HTTP/1.1")); + assert!(!rewritten.contains("openshell")); 
+ assert_eq!( + result.redacted_target.as_deref(), + Some("/api?token=[CREDENTIAL]") + ); + } + #[test] fn all_resolved_succeeds() { let (child_env, resolver) = SecretResolver::from_provider_env( @@ -1561,6 +1730,24 @@ mod tests { assert!(!payload.contains(PLACEHOLDER_PREFIX)); } + #[test] + fn rewrite_websocket_text_replaces_slack_shaped_alias() { + let (_, resolver) = SecretResolver::from_provider_env( + [("SLACK_APP_TOKEN".to_string(), "xapp-real-app".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let mut payload = r#"{"token":"xapp-OPENSHELL-RESOLVE-ENV-SLACK_APP_TOKEN"}"#.to_string(); + + let count = resolver + .rewrite_websocket_text_placeholders(&mut payload) + .expect("alias should rewrite"); + + assert_eq!(count, 1); + assert_eq!(payload, r#"{"token":"xapp-real-app"}"#); + } + #[test] fn rewrite_websocket_text_without_placeholder_is_unchanged() { let (_, resolver) = SecretResolver::from_provider_env( diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index bdc8d8c1a..885dbc9ad 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -219,6 +219,9 @@ fn summarize_endpoint(endpoint: &NetworkEndpoint) -> String { if endpoint.websocket_credential_rewrite { parts.push("websocket_credential_rewrite=true".to_string()); } + if endpoint.request_body_credential_rewrite { + parts.push("request_body_credential_rewrite=true".to_string()); + } if !endpoint.allowed_ips.is_empty() { parts.push(format!("allowed_ips={}", endpoint.allowed_ips.len())); } @@ -4349,6 +4352,34 @@ mod tests { ); } + #[test] + fn summarize_cli_policy_merge_op_formats_request_body_credential_rewrite() { + let operation = PolicyMergeOp::AddRule { + rule_name: "slack_api".to_string(), + rule: NetworkPolicyRule { + name: "slack_api".to_string(), + endpoints: vec![NetworkEndpoint { + host: "slack.com".to_string(), + port: 443, + protocol: 
"rest".to_string(), + access: "read-write".to_string(), + enforcement: "enforce".to_string(), + request_body_credential_rewrite: true, + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/node".to_string(), + ..Default::default() + }], + }, + }; + + assert_eq!( + summarize_cli_policy_merge_op(&operation), + "add-endpoint slack_api endpoints=[slack.com:443 protocol=rest access=read-write enforcement=enforce request_body_credential_rewrite=true] binaries=[/usr/bin/node]" + ); + } + // ---- merge_chunk_into_policy ---- #[tokio::test] diff --git a/docs/reference/policy-schema.mdx b/docs/reference/policy-schema.mdx index 5865f78e1..7d38cd878 100644 --- a/docs/reference/policy-schema.mdx +++ b/docs/reference/policy-schema.mdx @@ -163,11 +163,14 @@ Each endpoint defines a reachable destination and optional inspection rules. | `deny_rules` | list of deny rule objects | No | L7 deny rules that block specific requests even when allowed by `access` or `rules`. Deny rules take precedence over allow rules. | | `allowed_ips` | list of string | No | CIDR or IP allowlist for SSRF override. Entries overlapping loopback (`127.0.0.0/8`), link-local (`169.254.0.0/16`), or unspecified (`0.0.0.0`) are rejected at load time. | | `allow_encoded_slash` | bool | No | When `true`, L7 request parsing preserves `%2F` inside path segments instead of rejecting it. Use this for registries and APIs such as npm scoped packages (`/@scope%2Fname`). Defaults to `false`. | -| `websocket_credential_rewrite` | bool | No | When `true` on a `protocol: rest` or `protocol: websocket` endpoint, OpenShell rewrites `openshell:resolve:env:*` placeholders in client-to-server WebSocket text messages after an allowed HTTP `101` upgrade. Binary frames are relayed but not rewritten. Defaults to `false`. 
| +| `websocket_credential_rewrite` | bool | No | When `true` on a `protocol: rest` or `protocol: websocket` endpoint, OpenShell rewrites credential placeholders in client-to-server WebSocket text messages after an allowed HTTP `101` upgrade. Binary frames are relayed but not rewritten. Defaults to `false`. | +| `request_body_credential_rewrite` | bool | No | When `true` on a `protocol: rest` endpoint, OpenShell rewrites credential placeholders in UTF-8 `application/json`, `application/x-www-form-urlencoded`, and `text/*` request bodies before forwarding upstream. The proxy buffers at most 256 KiB and updates `Content-Length` after rewriting. Defaults to `false`. | | `persisted_queries` | string | No | GraphQL hash-only behavior for `protocol: graphql` and GraphQL-over-WebSocket operation policy. Default is `deny`; use `allow_registered` only with `graphql_persisted_queries`. | | `graphql_persisted_queries` | map | No | Trusted GraphQL persisted-query registry keyed by hash or saved-query ID. Values contain `operation_type`, optional `operation_name`, and optional root `fields`. | | `graphql_max_body_bytes` | integer | No | Maximum GraphQL-over-HTTP request body bytes buffered for inspection. Defaults to `65536`. | +Credential rewrite recognizes the canonical `openshell:resolve:env:KEY` placeholder form and whole-token Slack-shaped aliases such as `xoxb-OPENSHELL-RESOLVE-ENV-SLACK_BOT_TOKEN` and `xapp-OPENSHELL-RESOLVE-ENV-SLACK_APP_TOKEN` when the referenced environment key exists in the configured provider credentials. + #### Access Levels The `access` field accepts one of the following values: diff --git a/docs/sandboxes/policies.mdx b/docs/sandboxes/policies.mdx index 8cc0ee586..9037e9687 100644 --- a/docs/sandboxes/policies.mdx +++ b/docs/sandboxes/policies.mdx @@ -49,7 +49,7 @@ network_policies: Static sections are locked at sandbox creation. Changing them requires destroying and recreating the sandbox. 
Dynamic sections can be updated on a running sandbox with `openshell policy update` for incremental merges or `openshell policy set` for full replacement, and take effect without restarting. When a hot reload changes rules on an active HTTP L7 endpoint, existing keep-alive tunnels are closed before forwarding another parsed request. Credential-injection-only HTTP passthrough tunnels use the same reload boundary. Most HTTP clients reconnect automatically, and the next request is evaluated against the current policy. -Raw streams are connection-scoped and outside L7 live-reload guarantees. This includes `tls: skip`, non-HTTP TCP payloads, HTTP upgrades such as WebSocket, and long-lived response streams such as SSE. A reload applies to the next connection or next parsed HTTP request; it does not interrupt an already-forwarded raw stream. Use `protocol: websocket` when policy should stay attached to the RFC 6455 upgrade and client text messages after the allowed upgrade. Add `websocket_credential_rewrite: true` only when the relay should rewrite credential placeholders in client-to-server WebSocket text messages. +Raw streams are connection-scoped and outside L7 live-reload guarantees. This includes `tls: skip`, non-HTTP TCP payloads, HTTP upgrades such as WebSocket, and long-lived response streams such as SSE. A reload applies to the next connection or next parsed HTTP request; it does not interrupt an already-forwarded raw stream. Use `protocol: websocket` when policy should stay attached to the RFC 6455 upgrade and client text messages after the allowed upgrade. Add `websocket_credential_rewrite: true` only when the relay should rewrite credential placeholders in client-to-server WebSocket text messages. Add `request_body_credential_rewrite: true` only on inspected REST endpoints that need OpenShell to rewrite placeholders in supported text request bodies. 
| Section | Type | Description | |---|---|---| @@ -226,7 +226,7 @@ Each segment has a fixed meaning: | `access` | No | Access preset for L7 endpoints: `read-only`, `read-write`, or `full`. Incremental updates expand presets into protocol-specific method/path rules for REST and WebSocket endpoints. | | `protocol` | No | L7 inspection mode: `rest`, `websocket`, or `sql`. `sql` is audit-only and not a recommended workflow today. | | `enforcement` | No | Enforcement mode for inspected traffic: `enforce` or `audit`. | -| `options` | No | Comma-separated endpoint options. Currently only `websocket-credential-rewrite` is supported, and only for `protocol: websocket` or REST compatibility endpoints that perform a WebSocket upgrade. | +| `options` | No | Comma-separated endpoint options. Use `websocket-credential-rewrite` with `protocol: websocket` or REST compatibility endpoints that perform a WebSocket upgrade. Use `request-body-credential-rewrite` only with `protocol: rest`. | Examples: @@ -234,6 +234,7 @@ Examples: |---|---| | `pypi.org:443` | Add a plain L4 endpoint. The proxy allows the TCP stream and does not inspect HTTP requests. | | `api.github.com:443:read-only:rest:enforce` | Add a REST endpoint with the `read-only` preset expanded by the policy engine into GET, HEAD, and OPTIONS access. | +| `slack.com:443:read-write:rest:enforce:request-body-credential-rewrite` | Add a REST endpoint that rewrites credential placeholders in supported text request bodies. | | `realtime.example.com:443:read-write:websocket:enforce` | Add a WebSocket endpoint with the `read-write` preset expanded by the policy engine into the upgrade `GET` and client `WEBSOCKET_TEXT` access. | | `realtime.example.com:443:read-write:websocket:enforce:websocket-credential-rewrite` | Add a WebSocket endpoint that rewrites `openshell:resolve:env:*` placeholders in client text frames after an allowed upgrade. 
| @@ -241,6 +242,10 @@ If you set `protocol: rest` or `protocol: websocket`, you also need an allow sha Use the `websocket-credential-rewrite` endpoint option with `protocol: websocket` when the sandbox should send credential placeholders in client text frames and have OpenShell resolve them after the allowed upgrade. The option can also be used with `protocol: rest` compatibility endpoints that perform a WebSocket upgrade. It is rejected for plain L4 or `protocol: sql` endpoints. +Use the `request-body-credential-rewrite` endpoint option with `protocol: rest` when an API expects OpenShell-managed credentials in UTF-8 JSON, form, or text request bodies. OpenShell buffers up to 256 KiB, rewrites recognized credential placeholders, updates `Content-Length`, and rejects unresolved placeholders instead of forwarding them. The option is rejected for WebSocket, GraphQL, SQL, and plain L4 endpoints. + +Credential rewrite recognizes the canonical `openshell:resolve:env:KEY` placeholder form and whole-token Slack-shaped aliases such as `xoxb-OPENSHELL-RESOLVE-ENV-SLACK_BOT_TOKEN` and `xapp-OPENSHELL-RESOLVE-ENV-SLACK_APP_TOKEN` when the referenced environment key exists in the configured provider credentials. + For example: - `api.github.com:443:read-only:rest` is valid. @@ -479,7 +484,7 @@ Allow `pip install` and `uv pip install` to reach PyPI: - { path: /usr/local/bin/uv } ``` -Endpoints without `protocol` use TCP passthrough, where the proxy allows the stream without inspecting payloads. If the stream is HTTP and TLS is auto-terminated, the proxy can still rewrite configured credential placeholders and closes keep-alive passthrough tunnels on policy reload before forwarding another request. WebSocket text-frame policy requires an explicit `protocol: websocket` endpoint. WebSocket payload credential rewrite can also be enabled on a `protocol: rest` compatibility endpoint with `websocket_credential_rewrite: true`. 
+Endpoints without `protocol` use TCP passthrough, where the proxy allows the stream without inspecting payloads. If the stream is HTTP and TLS is auto-terminated, the proxy can still rewrite configured credential placeholders and closes keep-alive passthrough tunnels on policy reload before forwarding another request. WebSocket text-frame policy requires an explicit `protocol: websocket` endpoint. WebSocket payload credential rewrite can also be enabled on a `protocol: rest` compatibility endpoint with `websocket_credential_rewrite: true`. REST request body credential rewrite requires an inspected `protocol: rest` endpoint with `request_body_credential_rewrite: true`. @@ -537,7 +542,7 @@ For an end-to-end walkthrough that combines this policy with a GitHub credential - { path: /usr/bin/gh } ``` -Endpoints with `protocol: rest` enable HTTP request inspection. Endpoints with `protocol: websocket` validate WebSocket upgrades and inspect client text messages on the upgraded request path. WebSocket endpoints can also classify GraphQL-over-WebSocket operation messages with the same operation rules used by GraphQL-over-HTTP. Endpoints with `protocol: graphql` parse GraphQL-over-HTTP payloads before evaluating rules. The endpoint-level `path` field lets these protocols share `api.github.com:443` without treating GraphQL payloads as plain REST `POST /graphql` requests. +Endpoints with `protocol: rest` enable HTTP request inspection and can opt in to supported text request body credential rewrite. Endpoints with `protocol: websocket` validate WebSocket upgrades and inspect client text messages on the upgraded request path. WebSocket endpoints can also classify GraphQL-over-WebSocket operation messages with the same operation rules used by GraphQL-over-HTTP. Endpoints with `protocol: graphql` parse GraphQL-over-HTTP payloads before evaluating rules. 
The endpoint-level `path` field lets these protocols share `api.github.com:443` without treating GraphQL payloads as plain REST `POST /graphql` requests. diff --git a/docs/security/best-practices.mdx b/docs/security/best-practices.mdx index 4571ed94c..f51781144 100644 --- a/docs/security/best-practices.mdx +++ b/docs/security/best-practices.mdx @@ -97,9 +97,9 @@ The `protocol` field on an endpoint controls whether the proxy inspects individu | Aspect | Detail | |---|---| | Default | Endpoints without a `protocol` field use L4-only enforcement: the proxy checks host, port, and binary, then relays the TCP stream without inspecting payloads. | -| What you can change | Add `protocol: rest` to enable per-request HTTP method/path inspection, `protocol: websocket` to inspect RFC 6455 upgrade handshakes and client text messages, or `protocol: graphql` to inspect GraphQL-over-HTTP operation type, operation name, and root fields. WebSocket endpoints can also use GraphQL operation rules for GraphQL-over-WebSocket messages. Pair inspected protocols with `rules` or access presets (`full`, `read-only`, `read-write`). | +| What you can change | Add `protocol: rest` to enable per-request HTTP method/path inspection, `protocol: websocket` to inspect RFC 6455 upgrade handshakes and client text messages, or `protocol: graphql` to inspect GraphQL-over-HTTP operation type, operation name, and root fields. WebSocket endpoints can also use GraphQL operation rules for GraphQL-over-WebSocket messages. Pair inspected protocols with `rules` or access presets (`full`, `read-only`, `read-write`). REST endpoints that need credential placeholders in supported text request bodies can set `request_body_credential_rewrite: true`. | | Risk if relaxed | L4-only endpoints allow the agent to send any data through the tunnel after the initial connection is permitted. The proxy cannot see HTTP methods, paths, or GraphQL operations. 
Adding `access: full` with L7 inspection enables observability but permits all inspected actions. | -| Recommendation | Use `protocol: rest` with specific `rules` for APIs where intent is encoded in method and path. Use `protocol: graphql` for GraphQL-over-HTTP APIs where destructive operations are body-encoded. Use `protocol: websocket` for RFC 6455 endpoints, with explicit `GET` and `WEBSOCKET_TEXT` rules for raw text protocols or explicit GraphQL operation rules for GraphQL-over-WebSocket. Prefer `access: read-only` or explicit allowlists, and deny hash-only persisted queries unless you maintain a trusted registry. Omit `protocol` for non-HTTP protocols. For WebSocket endpoints that must carry placeholder credentials in client text frames, add `websocket_credential_rewrite: true`. | +| Recommendation | Use `protocol: rest` with specific `rules` for APIs where intent is encoded in method and path. Add `request_body_credential_rewrite: true` only for REST APIs that require OpenShell-managed credentials in UTF-8 JSON, form, or text request bodies. Use `protocol: graphql` for GraphQL-over-HTTP APIs where destructive operations are body-encoded. Use `protocol: websocket` for RFC 6455 endpoints, with explicit `GET` and `WEBSOCKET_TEXT` rules for raw text protocols or explicit GraphQL operation rules for GraphQL-over-WebSocket. Prefer `access: read-only` or explicit allowlists, and deny hash-only persisted queries unless you maintain a trusted registry. Omit `protocol` for non-HTTP protocols. For WebSocket endpoints that must carry placeholder credentials in client text frames, add `websocket_credential_rewrite: true`. | ### Enforcement Mode (`audit` vs `enforce`) diff --git a/proto/sandbox.proto b/proto/sandbox.proto index db1b15448..b40d95cb1 100644 --- a/proto/sandbox.proto +++ b/proto/sandbox.proto @@ -120,6 +120,10 @@ message NetworkEndpoint { // inside client-to-server WebSocket text messages after an allowed HTTP 101 // upgrade. Defaults to false. 
bool websocket_credential_rewrite = 16; + // When true on a "rest" endpoint, OpenShell rewrites credential placeholders + // inside supported textual HTTP request bodies before forwarding upstream. + // Defaults to false. + bool request_body_credential_rewrite = 17; } // Trusted GraphQL operation classification. From e59ab4dbcf7f7753b25850d66a23127d107cdc85 Mon Sep 17 00:00:00 2001 From: Aaron Erickson Date: Sun, 10 May 2026 17:53:26 -0700 Subject: [PATCH 15/17] refactor(sandbox): generalize credential aliases Signed-off-by: Aaron Erickson --- crates/openshell-cli/src/policy_update.rs | 5 +- crates/openshell-sandbox/src/l7/rest.rs | 40 ++++---- crates/openshell-sandbox/src/secrets.rs | 113 +++++++++++----------- docs/reference/policy-schema.mdx | 2 +- docs/sandboxes/policies.mdx | 4 +- 5 files changed, 81 insertions(+), 83 deletions(-) diff --git a/crates/openshell-cli/src/policy_update.rs b/crates/openshell-cli/src/policy_update.rs index 9c9e91188..57656b878 100644 --- a/crates/openshell-cli/src/policy_update.rs +++ b/crates/openshell-cli/src/policy_update.rs @@ -581,7 +581,10 @@ mod tests { #[test] fn parse_add_endpoint_enables_request_body_credential_rewrite_on_rest_endpoint() { let plan = build_policy_update_plan( - &["slack.com:443:read-write:rest:enforce:request-body-credential-rewrite".to_string()], + &[ + "api.example.com:443:read-write:rest:enforce:request-body-credential-rewrite" + .to_string(), + ], &[], &[], &[], diff --git a/crates/openshell-sandbox/src/l7/rest.rs b/crates/openshell-sandbox/src/l7/rest.rs index b4e2d1675..ade126828 100644 --- a/crates/openshell-sandbox/src/l7/rest.rs +++ b/crates/openshell-sandbox/src/l7/rest.rs @@ -4396,21 +4396,18 @@ mod tests { } #[tokio::test] - async fn relay_request_body_rewrites_slack_header_and_urlencoded_token() { + async fn relay_request_body_rewrites_provider_alias_header_and_urlencoded_token() { let (_, resolver) = SecretResolver::from_provider_env( - [( - "SLACK_BOT_TOKEN".to_string(), - 
"xoxb-real-bot-token".to_string(), - )] - .into_iter() - .collect(), + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), ); let resolver = resolver.expect("resolver"); - let alias = "xoxb-OPENSHELL-RESOLVE-ENV-SLACK_BOT_TOKEN"; + let alias = "provider.v1-OPENSHELL-RESOLVE-ENV-API_TOKEN"; let body = format!("token={alias}&channel=C123"); let raw = format!( - "POST /api/chat.postMessage HTTP/1.1\r\n\ - Host: slack.com\r\n\ + "POST /api/messages HTTP/1.1\r\n\ + Host: api.example.com\r\n\ Authorization: Bearer {alias}\r\n\ Content-Type: application/x-www-form-urlencoded\r\n\ Content-Length: {}\r\n\r\n{}", @@ -4427,8 +4424,8 @@ mod tests { .await .expect("relay should succeed"); - let expected_body = "token=xoxb-real-bot-token&channel=C123"; - assert!(forwarded.contains("Authorization: Bearer xoxb-real-bot-token\r\n")); + let expected_body = "token=provider-real-token&channel=C123"; + assert!(forwarded.contains("Authorization: Bearer provider-real-token\r\n")); assert!(forwarded.contains(&format!("Content-Length: {}\r\n", expected_body.len()))); assert!(forwarded.ends_with(expected_body)); assert!(!forwarded.contains("OPENSHELL-RESOLVE-ENV")); @@ -4437,18 +4434,15 @@ mod tests { #[tokio::test] async fn relay_request_body_unresolved_alias_fails_before_upstream_write() { let (_, resolver) = SecretResolver::from_provider_env( - [( - "SLACK_BOT_TOKEN".to_string(), - "xoxb-real-bot-token".to_string(), - )] - .into_iter() - .collect(), + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), ); let resolver = resolver.expect("resolver"); - let body = "token=xapp-OPENSHELL-RESOLVE-ENV-SLACK_APP_TOKEN"; + let body = "token=provider-OPENSHELL-RESOLVE-ENV-APP_TOKEN"; let raw = format!( - "POST /api/apps.connections.open HTTP/1.1\r\n\ - Host: slack.com\r\n\ + "POST /api/connections.open HTTP/1.1\r\n\ + Host: api.example.com\r\n\ Content-Type: application/x-www-form-urlencoded\r\n\ Content-Length: 
{}\r\n\r\n{}", body.len(), @@ -4456,7 +4450,7 @@ mod tests { ); let req = L7Request { action: "POST".to_string(), - target: "/api/apps.connections.open".to_string(), + target: "/api/connections.open".to_string(), query_params: HashMap::new(), raw_header: raw.into_bytes(), body_length: BodyLength::ContentLength(body.len() as u64), @@ -4477,7 +4471,7 @@ mod tests { .await .expect_err("unknown body alias should fail closed"); - assert!(!err.to_string().contains("xoxb-real-bot-token")); + assert!(!err.to_string().contains("provider-real-token")); drop(proxy_to_upstream); let mut forwarded = Vec::new(); upstream_side.read_to_end(&mut forwarded).await.unwrap(); diff --git a/crates/openshell-sandbox/src/secrets.rs b/crates/openshell-sandbox/src/secrets.rs index 3f9b74346..54c43d07a 100644 --- a/crates/openshell-sandbox/src/secrets.rs +++ b/crates/openshell-sandbox/src/secrets.rs @@ -7,8 +7,6 @@ use std::fmt; const PLACEHOLDER_PREFIX: &str = "openshell:resolve:env:"; const PROVIDER_ALIAS_MARKER: &str = "OPENSHELL-RESOLVE-ENV-"; -const SLACK_BOT_ALIAS_PREFIX: &str = "xoxb-OPENSHELL-RESOLVE-ENV-"; -const SLACK_APP_ALIAS_PREFIX: &str = "xapp-OPENSHELL-RESOLVE-ENV-"; /// Public access to the placeholder prefix for fail-closed scanning in other modules. pub const PLACEHOLDER_PREFIX_PUBLIC: &str = PLACEHOLDER_PREFIX; @@ -21,7 +19,7 @@ fn is_env_key_char(b: u8) -> bool { } fn is_alias_token_char(b: u8) -> bool { - b.is_ascii_alphanumeric() || b == b'_' || b == b'-' + b.is_ascii_alphanumeric() || matches!(b, b'_' | b'-' | b'.' 
| b'~') } fn contains_raw_reserved_marker(value: &str) -> bool { @@ -51,7 +49,7 @@ impl fmt::Display for UnresolvedPlaceholderError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "unresolved credential placeholder in {}: detected openshell:resolve:env:* token that could not be resolved", + "unresolved credential placeholder in {}: detected reserved credential token that could not be resolved", self.location ) } @@ -116,8 +114,6 @@ impl SecretResolver { if let Some(canonical_placeholder) = canonical_placeholder { by_placeholder.insert(canonical_placeholder, value.clone()); } - by_placeholder.insert(slack_bot_alias_for_env_key(&key), value.clone()); - by_placeholder.insert(slack_app_alias_for_env_key(&key), value); } (child_env, Some(Self { by_placeholder })) @@ -140,7 +136,13 @@ impl SecretResolver { /// Returns `None` if the placeholder is unknown or the resolved value /// contains prohibited control characters (CRLF, null byte). pub(crate) fn resolve_placeholder(&self, value: &str) -> Option<&str> { - let secret = self.by_placeholder.get(value).map(String::as_str)?; + let secret = if let Some(secret) = self.by_placeholder.get(value) { + secret.as_str() + } else { + let key = alias_env_key(value)?; + let canonical = placeholder_for_env_key(key); + self.by_placeholder.get(&canonical).map(String::as_str)? 
+ }; match validate_resolved_secret(secret) { Ok(s) => Some(s), Err(reason) => { @@ -213,7 +215,7 @@ impl SecretResolver { let next_canonical = text[pos..].find(PLACEHOLDER_PREFIX).map(|p| pos + p); let next_alias = text[pos..].find(PROVIDER_ALIAS_MARKER).map(|marker_pos| { let marker_abs = pos + marker_pos; - alias_start_for_marker(text, marker_abs).unwrap_or(marker_abs) + alias_start_for_marker(text, marker_abs) }); let Some(abs_start) = [next_canonical, next_alias].into_iter().flatten().min() else { rewritten.push_str(&text[pos..]); @@ -235,12 +237,7 @@ impl SecretResolver { continue; } - if text[abs_start..].starts_with(SLACK_BOT_ALIAS_PREFIX) - || text[abs_start..].starts_with(SLACK_APP_ALIAS_PREFIX) - { - let Some((token_end, token)) = self.credential_token_at(text, abs_start) else { - return Err(UnresolvedPlaceholderError { location }); - }; + if let Some((token_end, token)) = alias_token_at(text, abs_start) { let Some(secret) = self.resolve_placeholder(token) else { return Err(UnresolvedPlaceholderError { location }); }; @@ -333,15 +330,13 @@ impl SecretResolver { } } -fn alias_start_for_marker(text: &str, marker_abs: usize) -> Option { - marker_abs - .checked_sub("xoxb-".len()) - .filter(|start| text[*start..].starts_with(SLACK_BOT_ALIAS_PREFIX)) - .or_else(|| { - marker_abs - .checked_sub("xapp-".len()) - .filter(|start| text[*start..].starts_with(SLACK_APP_ALIAS_PREFIX)) - }) +fn alias_start_for_marker(text: &str, marker_abs: usize) -> usize { + let mut start = marker_abs; + let bytes = text.as_bytes(); + while start > 0 && is_alias_token_char(bytes[start - 1]) { + start -= 1; + } + start } fn canonical_token_at(text: &str, abs_start: usize) -> Option<(usize, &str)> { @@ -357,14 +352,12 @@ fn canonical_token_at(text: &str, abs_start: usize) -> Option<(usize, &str)> { } fn alias_token_at(text: &str, abs_start: usize) -> Option<(usize, &str)> { - let prefix_len = if text[abs_start..].starts_with(SLACK_BOT_ALIAS_PREFIX) { - SLACK_BOT_ALIAS_PREFIX.len() - } 
else if text[abs_start..].starts_with(SLACK_APP_ALIAS_PREFIX) { - SLACK_APP_ALIAS_PREFIX.len() - } else { + let suffix = &text[abs_start..]; + let marker_rel = suffix.find(PROVIDER_ALIAS_MARKER)?; + if marker_rel == 0 { return None; - }; - let key_start = abs_start + prefix_len; + } + let key_start = abs_start + marker_rel + PROVIDER_ALIAS_MARKER.len(); let key_end = text[key_start..] .bytes() .position(|b| !is_env_key_char(b)) @@ -377,6 +370,22 @@ fn alias_token_at(text: &str, abs_start: usize) -> Option<(usize, &str)> { (before_ok && after_ok).then_some((key_end, &text[abs_start..key_end])) } +fn alias_env_key(token: &str) -> Option<&str> { + let marker_start = token.find(PROVIDER_ALIAS_MARKER)?; + if marker_start == 0 { + return None; + } + if !token[..marker_start].bytes().all(is_alias_token_char) { + return None; + } + let key_start = marker_start + PROVIDER_ALIAS_MARKER.len(); + let key_end = token[key_start..] + .bytes() + .position(|b| !is_env_key_char(b)) + .map_or(token.len(), |p| key_start + p); + (key_end == token.len() && key_end > key_start).then_some(&token[key_start..key_end]) +} + fn token_boundary_ok(text: &str, abs_start: usize, token_end: usize, token: &str) -> bool { if token.starts_with(PLACEHOLDER_PREFIX) { return token_end == text.len() @@ -392,14 +401,6 @@ pub fn placeholder_for_env_key(key: &str) -> String { format!("{PLACEHOLDER_PREFIX}{key}") } -pub fn slack_bot_alias_for_env_key(key: &str) -> String { - format!("{SLACK_BOT_ALIAS_PREFIX}{key}") -} - -pub fn slack_app_alias_for_env_key(key: &str) -> String { - format!("{SLACK_APP_ALIAS_PREFIX}{key}") -} - pub fn placeholder_for_env_key_for_revision(key: &str, revision: u64) -> String { if revision == 0 { placeholder_for_env_key(key) @@ -699,7 +700,7 @@ fn rewrite_path_segment( .find(PROVIDER_ALIAS_MARKER) .map(|marker_pos| { let marker_abs = pos + marker_pos; - alias_start_for_marker(segment, marker_abs).unwrap_or(marker_abs) + alias_start_for_marker(segment, marker_abs) }); if let 
Some(abs_start) = [next_canonical, next_alias].into_iter().flatten().min() { // Copy literal prefix before the placeholder @@ -994,11 +995,11 @@ mod tests { } #[test] - fn rewrites_slack_shaped_alias_header_values() { + fn rewrites_provider_shaped_alias_header_values() { let (_, resolver) = SecretResolver::from_provider_env( [ - ("SLACK_BOT_TOKEN".to_string(), "xoxb-real-bot".to_string()), - ("SLACK_APP_TOKEN".to_string(), "xapp-real-app".to_string()), + ("API_TOKEN".to_string(), "provider-real-token".to_string()), + ("CHAT_APP_TOKEN".to_string(), "app-real-token".to_string()), ] .into_iter() .collect(), @@ -1007,29 +1008,29 @@ mod tests { assert_eq!( rewrite_header_line( - "Authorization: Bearer xoxb-OPENSHELL-RESOLVE-ENV-SLACK_BOT_TOKEN", + "Authorization: Bearer vendor-OPENSHELL-RESOLVE-ENV-API_TOKEN", &resolver, ), - "Authorization: Bearer xoxb-real-bot" + "Authorization: Bearer provider-real-token" ); assert_eq!( rewrite_header_line( - "Authorization: Bearer xapp-OPENSHELL-RESOLVE-ENV-SLACK_APP_TOKEN", + "x-app-token: token.v1-OPENSHELL-RESOLVE-ENV-CHAT_APP_TOKEN", &resolver, ), - "Authorization: Bearer xapp-real-app" + "x-app-token: app-real-token" ); } #[test] - fn unresolved_slack_shaped_alias_fails_closed() { + fn unresolved_provider_shaped_alias_fails_closed() { let (_, resolver) = SecretResolver::from_provider_env( - [("SLACK_BOT_TOKEN".to_string(), "xoxb-real-bot".to_string())] + [("API_TOKEN".to_string(), "provider-real-token".to_string())] .into_iter() .collect(), ); let resolver = resolver.expect("resolver"); - let raw = b"GET / HTTP/1.1\r\nAuthorization: Bearer xapp-OPENSHELL-RESOLVE-ENV-SLACK_APP_TOKEN\r\n\r\n"; + let raw = b"GET / HTTP/1.1\r\nAuthorization: Bearer vendor-OPENSHELL-RESOLVE-ENV-UNKNOWN_TOKEN\r\n\r\n"; let err = rewrite_http_header_block(raw, Some(&resolver)) .expect_err("unknown alias should fail closed"); @@ -1650,19 +1651,19 @@ mod tests { #[test] fn percent_encoded_canonical_placeholder_in_query_rewrites() { let (_, resolver) = 
SecretResolver::from_provider_env( - [("SLACK_BOT_TOKEN".to_string(), "xoxb-real-bot".to_string())] + [("API_TOKEN".to_string(), "provider-real-token".to_string())] .into_iter() .collect(), ); let resolver = resolver.expect("resolver"); - let encoded = "openshell%3Aresolve%3Aenv%3ASLACK_BOT_TOKEN"; + let encoded = "openshell%3Aresolve%3Aenv%3AAPI_TOKEN"; let raw = format!("GET /api?token={encoded} HTTP/1.1\r\nHost: x\r\n\r\n"); let result = rewrite_http_header_block(raw.as_bytes(), Some(&resolver)).expect("should rewrite"); let rewritten = String::from_utf8(result.rewritten).expect("utf8"); - assert!(rewritten.starts_with("GET /api?token=xoxb-real-bot HTTP/1.1")); + assert!(rewritten.starts_with("GET /api?token=provider-real-token HTTP/1.1")); assert!(!rewritten.contains("openshell")); assert_eq!( result.redacted_target.as_deref(), @@ -1731,21 +1732,21 @@ mod tests { } #[test] - fn rewrite_websocket_text_replaces_slack_shaped_alias() { + fn rewrite_websocket_text_replaces_provider_shaped_alias() { let (_, resolver) = SecretResolver::from_provider_env( - [("SLACK_APP_TOKEN".to_string(), "xapp-real-app".to_string())] + [("APP_TOKEN".to_string(), "app-real-token".to_string())] .into_iter() .collect(), ); let resolver = resolver.expect("resolver"); - let mut payload = r#"{"token":"xapp-OPENSHELL-RESOLVE-ENV-SLACK_APP_TOKEN"}"#.to_string(); + let mut payload = r#"{"token":"provider-OPENSHELL-RESOLVE-ENV-APP_TOKEN"}"#.to_string(); let count = resolver .rewrite_websocket_text_placeholders(&mut payload) .expect("alias should rewrite"); assert_eq!(count, 1); - assert_eq!(payload, r#"{"token":"xapp-real-app"}"#); + assert_eq!(payload, r#"{"token":"app-real-token"}"#); } #[test] diff --git a/docs/reference/policy-schema.mdx b/docs/reference/policy-schema.mdx index 7d38cd878..295f850df 100644 --- a/docs/reference/policy-schema.mdx +++ b/docs/reference/policy-schema.mdx @@ -169,7 +169,7 @@ Each endpoint defines a reachable destination and optional inspection rules. 
| `graphql_persisted_queries` | map | No | Trusted GraphQL persisted-query registry keyed by hash or saved-query ID. Values contain `operation_type`, optional `operation_name`, and optional root `fields`. | | `graphql_max_body_bytes` | integer | No | Maximum GraphQL-over-HTTP request body bytes buffered for inspection. Defaults to `65536`. | -Credential rewrite recognizes the canonical `openshell:resolve:env:KEY` placeholder form and whole-token Slack-shaped aliases such as `xoxb-OPENSHELL-RESOLVE-ENV-SLACK_BOT_TOKEN` and `xapp-OPENSHELL-RESOLVE-ENV-SLACK_APP_TOKEN` when the referenced environment key exists in the configured provider credentials. +Credential rewrite recognizes the canonical `openshell:resolve:env:KEY` placeholder form and whole-token provider-shaped aliases such as `provider-OPENSHELL-RESOLVE-ENV-API_TOKEN` when the referenced environment key exists in the configured provider credentials. #### Access Levels diff --git a/docs/sandboxes/policies.mdx b/docs/sandboxes/policies.mdx index 9037e9687..fb0b04cfe 100644 --- a/docs/sandboxes/policies.mdx +++ b/docs/sandboxes/policies.mdx @@ -234,7 +234,7 @@ Examples: |---|---| | `pypi.org:443` | Add a plain L4 endpoint. The proxy allows the TCP stream and does not inspect HTTP requests. | | `api.github.com:443:read-only:rest:enforce` | Add a REST endpoint with the `read-only` preset expanded by the policy engine into GET, HEAD, and OPTIONS access. | -| `slack.com:443:read-write:rest:enforce:request-body-credential-rewrite` | Add a REST endpoint that rewrites credential placeholders in supported text request bodies. | +| `api.example.com:443:read-write:rest:enforce:request-body-credential-rewrite` | Add a REST endpoint that rewrites credential placeholders in supported text request bodies. | | `realtime.example.com:443:read-write:websocket:enforce` | Add a WebSocket endpoint with the `read-write` preset expanded by the policy engine into the upgrade `GET` and client `WEBSOCKET_TEXT` access. 
| | `realtime.example.com:443:read-write:websocket:enforce:websocket-credential-rewrite` | Add a WebSocket endpoint that rewrites `openshell:resolve:env:*` placeholders in client text frames after an allowed upgrade. | @@ -244,7 +244,7 @@ Use the `websocket-credential-rewrite` endpoint option with `protocol: websocket Use the `request-body-credential-rewrite` endpoint option with `protocol: rest` when an API expects OpenShell-managed credentials in UTF-8 JSON, form, or text request bodies. OpenShell buffers up to 256 KiB, rewrites recognized credential placeholders, updates `Content-Length`, and rejects unresolved placeholders instead of forwarding them. The option is rejected for WebSocket, GraphQL, SQL, and plain L4 endpoints. -Credential rewrite recognizes the canonical `openshell:resolve:env:KEY` placeholder form and whole-token Slack-shaped aliases such as `xoxb-OPENSHELL-RESOLVE-ENV-SLACK_BOT_TOKEN` and `xapp-OPENSHELL-RESOLVE-ENV-SLACK_APP_TOKEN` when the referenced environment key exists in the configured provider credentials. +Credential rewrite recognizes the canonical `openshell:resolve:env:KEY` placeholder form and whole-token provider-shaped aliases such as `provider-OPENSHELL-RESOLVE-ENV-API_TOKEN` when the referenced environment key exists in the configured provider credentials. 
For example: From 59a51603713fa0c2b1c05cad02050daa83037925 Mon Sep 17 00:00:00 2001 From: Aaron Erickson Date: Mon, 11 May 2026 08:41:57 -0700 Subject: [PATCH 16/17] fix(sandbox): rewrite encoded form credentials Signed-off-by: Aaron Erickson --- crates/openshell-sandbox/src/l7/rest.rs | 204 +++++++++++++++++- crates/openshell-sandbox/src/proxy.rs | 261 ++++++++++++++++++++++-- 2 files changed, 448 insertions(+), 17 deletions(-) diff --git a/crates/openshell-sandbox/src/l7/rest.rs b/crates/openshell-sandbox/src/l7/rest.rs index ade126828..c513499f4 100644 --- a/crates/openshell-sandbox/src/l7/rest.rs +++ b/crates/openshell-sandbox/src/l7/rest.rs @@ -611,9 +611,16 @@ fn rewrite_buffered_body( "request body credential rewrite found placeholders but no resolver is available" )); }; - let replacements = resolver - .rewrite_text_placeholders(&mut text, "request_body") - .map_err(|e| miette!("credential injection failed: {e}"))?; + + let replacements = if content_type.as_deref() == Some("application/x-www-form-urlencoded") { + let (rewritten, replacements) = rewrite_form_urlencoded_body(&text, resolver)?; + text = rewritten; + replacements + } else { + resolver + .rewrite_text_placeholders(&mut text, "request_body") + .map_err(|e| miette!("credential injection failed: {e}"))? 
+ }; if replacements == 0 || contains_reserved_credential_marker(&text) { return Err(miette!( "request body credential rewrite left unresolved credential placeholders" @@ -625,6 +632,112 @@ fn rewrite_buffered_body( Ok((headers, body)) } +fn rewrite_form_urlencoded_body(body: &str, resolver: &SecretResolver) -> Result<(String, usize)> { + let mut rewritten = String::with_capacity(body.len()); + let mut replacements = 0usize; + + for (idx, field) in body.split('&').enumerate() { + if idx > 0 { + rewritten.push('&'); + } + + let (name, value) = field + .split_once('=') + .map_or((field, None), |(name, value)| (name, Some(value))); + let decoded_name = form_url_decode(name)?; + if contains_reserved_credential_marker(&decoded_name) { + return Err(miette!( + "request body credential rewrite does not support placeholders in form field names" + )); + } + + rewritten.push_str(name); + let Some(value) = value else { + continue; + }; + + rewritten.push('='); + let decoded_value = form_url_decode(value)?; + if !contains_reserved_credential_marker(&decoded_value) { + rewritten.push_str(value); + continue; + } + + let mut rewritten_value = decoded_value; + let field_replacements = resolver + .rewrite_text_placeholders(&mut rewritten_value, "request_body") + .map_err(|e| miette!("credential injection failed: {e}"))?; + if field_replacements == 0 || contains_reserved_credential_marker(&rewritten_value) { + return Err(miette!( + "request body credential rewrite left unresolved credential placeholders" + )); + } + replacements += field_replacements; + rewritten.push_str(&form_url_encode(&rewritten_value)); + } + + Ok((rewritten, replacements)) +} + +fn form_url_decode(input: &str) -> Result { + let bytes = input.as_bytes(); + let mut decoded = Vec::with_capacity(bytes.len()); + let mut pos = 0usize; + + while pos < bytes.len() { + match bytes[pos] { + b'+' => { + decoded.push(b' '); + pos += 1; + } + b'%' if pos + 2 < bytes.len() => { + if let (Some(hi), Some(lo)) = 
(hex_value(bytes[pos + 1]), hex_value(bytes[pos + 2])) + { + decoded.push((hi << 4) | lo); + pos += 3; + } else { + decoded.push(bytes[pos]); + pos += 1; + } + } + byte => { + decoded.push(byte); + pos += 1; + } + } + } + + String::from_utf8(decoded).map_err(|_| { + miette!("request body credential rewrite requires UTF-8 form-url-encoded fields") + }) +} + +fn form_url_encode(input: &str) -> String { + let mut encoded = String::with_capacity(input.len()); + for byte in input.bytes() { + match byte { + b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'*' => { + encoded.push(byte as char); + } + b' ' => encoded.push('+'), + _ => { + use std::fmt::Write as _; + let _ = write!(encoded, "%{byte:02X}"); + } + } + } + encoded +} + +fn hex_value(byte: u8) -> Option { + match byte { + b'0'..=b'9' => Some(byte - b'0'), + b'a'..=b'f' => Some(byte - b'a' + 10), + b'A'..=b'F' => Some(byte - b'A' + 10), + _ => None, + } +} + async fn collect_chunked_body( client: &mut C, already_read: &[u8], @@ -4431,6 +4544,40 @@ mod tests { assert!(!forwarded.contains("OPENSHELL-RESOLVE-ENV")); } + #[tokio::test] + async fn relay_request_body_rewrites_percent_encoded_canonical_urlencoded_token() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let body = "token=openshell%3Aresolve%3Aenv%3AAPI_TOKEN¬e=hello+world"; + let raw = format!( + "POST /api/messages HTTP/1.1\r\n\ + Host: api.example.com\r\n\ + Content-Type: application/x-www-form-urlencoded\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body + ); + + let forwarded = relay_and_capture_with_options( + raw.into_bytes(), + BodyLength::ContentLength(body.len() as u64), + Some(&resolver), + true, + ) + .await + .expect("relay should succeed"); + + let expected_body = "token=provider-real-token¬e=hello+world"; + assert!(forwarded.contains(&format!("Content-Length: 
{}\r\n", expected_body.len()))); + assert!(forwarded.ends_with(expected_body)); + assert!(!forwarded.contains("openshell%3Aresolve%3Aenv%3AAPI_TOKEN")); + assert!(!forwarded.contains("openshell:resolve:env:API_TOKEN")); + } + #[tokio::test] async fn relay_request_body_unresolved_alias_fails_before_upstream_write() { let (_, resolver) = SecretResolver::from_provider_env( @@ -4481,6 +4628,57 @@ mod tests { ); } + #[tokio::test] + async fn relay_request_body_unresolved_encoded_canonical_fails_before_upstream_write() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let body = "token=openshell%3Aresolve%3Aenv%3AMISSING_TOKEN"; + let raw = format!( + "POST /api/messages HTTP/1.1\r\n\ + Host: api.example.com\r\n\ + Content-Type: application/x-www-form-urlencoded\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body + ); + let req = L7Request { + action: "POST".to_string(), + target: "/api/messages".to_string(), + query_params: HashMap::new(), + raw_header: raw.into_bytes(), + body_length: BodyLength::ContentLength(body.len() as u64), + }; + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let err = relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + resolver: Some(&resolver), + request_body_credential_rewrite: true, + ..Default::default() + }, + ) + .await + .expect_err("unknown encoded body placeholder should fail closed"); + + assert!(!err.to_string().contains("provider-real-token")); + assert!(!err.to_string().contains("MISSING_TOKEN")); + drop(proxy_to_upstream); + let mut forwarded = Vec::new(); + upstream_side.read_to_end(&mut forwarded).await.unwrap(); + assert!( + forwarded.is_empty(), + "failed body rewrite must not reach upstream" + 
); + } + #[tokio::test] async fn relay_injects_bearer_header_credential() { let (child_env, resolver) = SecretResolver::from_provider_env( diff --git a/crates/openshell-sandbox/src/proxy.rs b/crates/openshell-sandbox/src/proxy.rs index dca522c12..0395562fb 100644 --- a/crates/openshell-sandbox/src/proxy.rs +++ b/crates/openshell-sandbox/src/proxy.rs @@ -2277,6 +2277,7 @@ fn rewrite_forward_request( used: usize, path: &str, secret_resolver: Option<&SecretResolver>, + request_body_credential_rewrite: bool, ) -> Result, crate::secrets::UnresolvedPlaceholderError> { let header_end = raw[..used] .windows(4) @@ -2362,6 +2363,7 @@ fn rewrite_forward_request( // End of headers output.extend_from_slice(b"\r\n"); + let rewritten_header_end = output.len(); // Append any overflow body bytes from the original buffer if header_end < used { @@ -2370,7 +2372,12 @@ fn rewrite_forward_request( // Fail-closed: scan for any remaining unresolved placeholders if secret_resolver.is_some() { - let output_str = String::from_utf8_lossy(&output); + let scan_end = if request_body_credential_rewrite { + rewritten_header_end + } else { + output.len() + }; + let output_str = String::from_utf8_lossy(&output[..scan_end]); if output_str.contains(crate::secrets::PLACEHOLDER_PREFIX_PUBLIC) || output_str.contains(crate::secrets::PROVIDER_ALIAS_MARKER_PUBLIC) { @@ -3254,6 +3261,7 @@ async fn handle_forward_proxy( forward_request_bytes.len(), &upstream_target, secret_resolver.as_deref(), + request_body_credential_rewrite, ) { Ok(bytes) => bytes, Err(e) => { @@ -3409,6 +3417,92 @@ mod tests { } } + fn forward_test_guard() -> PolicyGenerationGuard { + let policy = include_str!("../data/sandbox-policy.rego"); + let policy_data = "network_policies: {}\n"; + let engine = OpaEngine::from_strings(policy, policy_data).unwrap(); + engine + .generation_guard(engine.current_generation()) + .unwrap() + } + + async fn relay_forward_request_and_capture( + method: &str, + path: &str, + raw: &[u8], + resolver: 
Option<&SecretResolver>, + request_body_credential_rewrite: bool, + ) -> Result { + let guard = forward_test_guard(); + let rewritten = rewrite_forward_request( + raw, + raw.len(), + path, + resolver, + request_body_credential_rewrite, + ) + .map_err(|e| miette::miette!("{e}"))?; + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let upstream_task = tokio::spawn(async move { + let mut buf = vec![0u8; 8192]; + let mut total = 0usize; + let mut expected_total = None; + loop { + let n = upstream_side.read(&mut buf[total..]).await.unwrap(); + if n == 0 { + break; + } + total += n; + if expected_total.is_none() + && let Some(end) = buf[..total].windows(4).position(|w| w == b"\r\n\r\n") + { + let header_end = end + 4; + let headers = String::from_utf8_lossy(&buf[..header_end]); + let len = headers + .lines() + .find_map(|line| { + let (name, value) = line.split_once(':')?; + name.eq_ignore_ascii_case("content-length") + .then(|| value.trim().parse::().ok()) + .flatten() + }) + .unwrap_or(0); + expected_total = Some(header_end + len); + } + if expected_total.is_some_and(|expected| total >= expected) { + break; + } + } + upstream_side + .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok") + .await + .unwrap(); + upstream_side.flush().await.unwrap(); + String::from_utf8_lossy(&buf[..total]).to_string() + }); + + relay_rewritten_forward_request( + method, + path, + rewritten, + &mut proxy_to_client, + &mut proxy_to_upstream, + ForwardRelayOptions { + generation_guard: &guard, + websocket_extensions: crate::l7::rest::WebSocketExtensionMode::Preserve, + secret_resolver: resolver, + request_body_credential_rewrite, + }, + ) + .await?; + + upstream_task + .await + .map_err(|e| miette::miette!("upstream task failed: {e}")) + } + #[test] fn forward_websocket_upgrade_enables_rewrite_for_native_websocket_endpoint() { let (_, resolver) = SecretResolver::from_provider_env( @@ 
-4400,7 +4494,8 @@ mod tests { fn test_rewrite_get_request() { let raw = b"GET http://10.0.0.1:8000/api HTTP/1.1\r\nHost: 10.0.0.1:8000\r\nAccept: */*\r\n\r\n"; - let result = rewrite_forward_request(raw, raw.len(), "/api", None).expect("should succeed"); + let result = + rewrite_forward_request(raw, raw.len(), "/api", None, false).expect("should succeed"); let result_str = String::from_utf8_lossy(&result); assert!(result_str.starts_with("GET /api HTTP/1.1\r\n")); assert!(result_str.contains("Host: 10.0.0.1:8000")); @@ -4411,7 +4506,8 @@ mod tests { #[test] fn test_rewrite_strips_proxy_headers() { let raw = b"GET http://host/p HTTP/1.1\r\nHost: host\r\nProxy-Authorization: Basic abc\r\nProxy-Connection: keep-alive\r\nAccept: */*\r\n\r\n"; - let result = rewrite_forward_request(raw, raw.len(), "/p", None).expect("should succeed"); + let result = + rewrite_forward_request(raw, raw.len(), "/p", None, false).expect("should succeed"); let result_str = String::from_utf8_lossy(&result); assert!( !result_str @@ -4425,7 +4521,8 @@ mod tests { #[test] fn test_rewrite_replaces_connection_header() { let raw = b"GET http://host/p HTTP/1.1\r\nHost: host\r\nConnection: keep-alive\r\n\r\n"; - let result = rewrite_forward_request(raw, raw.len(), "/p", None).expect("should succeed"); + let result = + rewrite_forward_request(raw, raw.len(), "/p", None, false).expect("should succeed"); let result_str = String::from_utf8_lossy(&result); assert!(result_str.contains("Connection: close")); assert!(!result_str.contains("keep-alive")); @@ -4434,7 +4531,8 @@ mod tests { #[test] fn test_rewrite_preserves_body_overflow() { let raw = b"POST http://host/api HTTP/1.1\r\nHost: host\r\nContent-Length: 13\r\n\r\n{\"key\":\"val\"}"; - let result = rewrite_forward_request(raw, raw.len(), "/api", None).expect("should succeed"); + let result = + rewrite_forward_request(raw, raw.len(), "/api", None, false).expect("should succeed"); let result_str = String::from_utf8_lossy(&result); 
assert!(result_str.contains("{\"key\":\"val\"}")); assert!(result_str.contains("POST /api HTTP/1.1")); @@ -4443,7 +4541,8 @@ mod tests { #[test] fn test_rewrite_preserves_existing_via() { let raw = b"GET http://host/p HTTP/1.1\r\nHost: host\r\nVia: 1.0 upstream\r\n\r\n"; - let result = rewrite_forward_request(raw, raw.len(), "/p", None).expect("should succeed"); + let result = + rewrite_forward_request(raw, raw.len(), "/p", None, false).expect("should succeed"); let result_str = String::from_utf8_lossy(&result); assert!(result_str.contains("Via: 1.0 upstream")); // Should not add a second Via header @@ -4466,7 +4565,7 @@ mod tests { .expect("canonicalization should succeed for the attack payload"); assert_eq!(canon.path, "/secret"); - let rewritten = rewrite_forward_request(raw, raw.len(), &canon.path, None) + let rewritten = rewrite_forward_request(raw, raw.len(), &canon.path, None, false) .expect("rewrite_forward_request should succeed"); let rewritten_str = String::from_utf8_lossy(&rewritten); assert!( @@ -4492,7 +4591,7 @@ mod tests { _ => canon.path, }; - let rewritten = rewrite_forward_request(raw, raw.len(), &upstream_target, None) + let rewritten = rewrite_forward_request(raw, raw.len(), &upstream_target, None, false) .expect("rewrite_forward_request should succeed"); let rewritten_str = String::from_utf8_lossy(&rewritten); assert!( @@ -4511,13 +4610,147 @@ mod tests { .collect(), ); let raw = b"GET http://host/p HTTP/1.1\r\nHost: host\r\nAuthorization: Bearer openshell:resolve:env:ANTHROPIC_API_KEY\r\n\r\n"; - let result = rewrite_forward_request(raw, raw.len(), "/p", resolver.as_ref()) + let result = rewrite_forward_request(raw, raw.len(), "/p", resolver.as_ref(), false) .expect("should succeed"); let result_str = String::from_utf8_lossy(&result); assert!(result_str.contains("Authorization: Bearer sk-test")); assert!(!result_str.contains("openshell:resolve:env:ANTHROPIC_API_KEY")); } + #[tokio::test] + async fn 
forward_relay_rewrites_urlencoded_body_alias_from_initial_read() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let alias = "provider-OPENSHELL-RESOLVE-ENV-API_TOKEN"; + let body = format!("token={alias}&channel=C123"); + let raw = format!( + "POST http://api.example.com/api/messages HTTP/1.1\r\n\ + Host: api.example.com\r\n\ + Authorization: Bearer {alias}\r\n\ + Content-Type: application/x-www-form-urlencoded\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body + ); + + let forwarded = relay_forward_request_and_capture( + "POST", + "/api/messages", + raw.as_bytes(), + Some(&resolver), + true, + ) + .await + .expect("forward relay should rewrite credentials"); + + let expected_body = "token=provider-real-token&channel=C123"; + assert!(forwarded.starts_with("POST /api/messages HTTP/1.1\r\n")); + assert!(forwarded.contains("Authorization: Bearer provider-real-token\r\n")); + assert!(forwarded.contains(&format!("Content-Length: {}\r\n", expected_body.len()))); + assert!(forwarded.ends_with(expected_body)); + assert!(!forwarded.contains("OPENSHELL-RESOLVE-ENV")); + } + + #[tokio::test] + async fn forward_relay_rewrites_urlencoded_canonical_body_from_initial_read() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let alias = "provider-OPENSHELL-RESOLVE-ENV-API_TOKEN"; + let body = "token=openshell%3Aresolve%3Aenv%3AAPI_TOKEN&channel=C123"; + let raw = format!( + "POST http://api.example.com/api/messages HTTP/1.1\r\n\ + Host: api.example.com\r\n\ + Authorization: Bearer {alias}\r\n\ + Content-Type: application/x-www-form-urlencoded\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body + ); + + let forwarded = relay_forward_request_and_capture( + 
"POST", + "/api/messages", + raw.as_bytes(), + Some(&resolver), + true, + ) + .await + .expect("forward relay should rewrite credentials"); + + let expected_body = "token=provider-real-token&channel=C123"; + assert!(forwarded.contains("Authorization: Bearer provider-real-token\r\n")); + assert!(forwarded.contains(&format!("Content-Length: {}\r\n", expected_body.len()))); + assert!(forwarded.ends_with(expected_body)); + assert!(!forwarded.contains("openshell%3Aresolve%3Aenv%3AAPI_TOKEN")); + assert!(!forwarded.contains("openshell:resolve:env:API_TOKEN")); + } + + #[tokio::test] + async fn forward_relay_unresolved_body_placeholder_fails_before_upstream_write() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let alias = "provider-OPENSHELL-RESOLVE-ENV-API_TOKEN"; + let body = "token=provider-OPENSHELL-RESOLVE-ENV-MISSING_TOKEN"; + let raw = format!( + "POST http://api.example.com/api/messages HTTP/1.1\r\n\ + Host: api.example.com\r\n\ + Authorization: Bearer {alias}\r\n\ + Content-Type: application/x-www-form-urlencoded\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body + ); + let guard = forward_test_guard(); + let rewritten = rewrite_forward_request( + raw.as_bytes(), + raw.len(), + "/api/messages", + Some(&resolver), + true, + ) + .expect("header rewrite should defer body overflow to body rewriter"); + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let err = relay_rewritten_forward_request( + "POST", + "/api/messages", + rewritten, + &mut proxy_to_client, + &mut proxy_to_upstream, + ForwardRelayOptions { + generation_guard: &guard, + websocket_extensions: crate::l7::rest::WebSocketExtensionMode::Preserve, + secret_resolver: Some(&resolver), + request_body_credential_rewrite: true, + }, + ) + .await + 
.expect_err("unresolved body placeholder should fail closed"); + + assert!(!err.to_string().contains("provider-real-token")); + assert!(!err.to_string().contains("MISSING_TOKEN")); + drop(proxy_to_upstream); + let mut forwarded = Vec::new(); + upstream_side.read_to_end(&mut forwarded).await.unwrap(); + assert!( + forwarded.is_empty(), + "failed forward body rewrite must not reach upstream" + ); + } + #[test] fn test_forward_rewrite_preserves_websocket_upgrade_connection_header() { let raw = "GET http://gateway.example.test/ws HTTP/1.1\r\n\ @@ -4528,7 +4761,7 @@ mod tests { Sec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover\r\n\ Sec-WebSocket-Version: 13\r\n\r\n"; - let result = rewrite_forward_request(raw.as_bytes(), raw.len(), "/ws", None) + let result = rewrite_forward_request(raw.as_bytes(), raw.len(), "/ws", None, false) .expect("websocket forward rewrite should succeed"); let result_str = String::from_utf8_lossy(&result); @@ -4551,8 +4784,8 @@ mod tests { engine.reload(policy, policy_data).unwrap(); let raw = b"GET http://host/api HTTP/1.1\r\nHost: host\r\n\r\n"; - let rewritten = - rewrite_forward_request(raw, raw.len(), "/api", None).expect("rewrite should succeed"); + let rewritten = rewrite_forward_request(raw, raw.len(), "/api", None, false) + .expect("rewrite should succeed"); let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); @@ -4594,8 +4827,8 @@ mod tests { .unwrap(); let raw = b"POST http://host/api HTTP/1.1\r\nHost: host\r\nContent-Length: 4\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n"; - let rewritten = - rewrite_forward_request(raw, raw.len(), "/api", None).expect("rewrite should succeed"); + let rewritten = rewrite_forward_request(raw, raw.len(), "/api", None, false) + .expect("rewrite should succeed"); let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); let (mut _app_side, mut proxy_to_client) = 
tokio::io::duplex(8192); From 170655385581ee15f84c1a8d48801dd67a51a8ba Mon Sep 17 00:00:00 2001 From: Aaron Erickson Date: Mon, 11 May 2026 09:23:45 -0700 Subject: [PATCH 17/17] fix(sandbox): close websocket policy and provider alias gaps Signed-off-by: Aaron Erickson --- crates/openshell-sandbox/src/l7/relay.rs | 4 +- .../src/provider_credentials.rs | 46 +- crates/openshell-sandbox/src/proxy.rs | 462 +++++++++++++++--- crates/openshell-sandbox/src/secrets.rs | 37 +- 4 files changed, 459 insertions(+), 90 deletions(-) diff --git a/crates/openshell-sandbox/src/l7/relay.rs b/crates/openshell-sandbox/src/l7/relay.rs index 971b2e8e5..6d271af21 100644 --- a/crates/openshell-sandbox/src/l7/relay.rs +++ b/crates/openshell-sandbox/src/l7/relay.rs @@ -543,7 +543,7 @@ where Ok(()) } -fn upgrade_options<'a>( +pub(crate) fn upgrade_options<'a>( config: &L7EndpointConfig, ctx: &'a L7EvalContext, websocket_request: bool, @@ -584,7 +584,7 @@ fn upgrade_options<'a>( } } -fn websocket_extension_mode(config: &L7EndpointConfig) -> WebSocketExtensionMode { +pub(crate) fn websocket_extension_mode(config: &L7EndpointConfig) -> WebSocketExtensionMode { if config.protocol == L7Protocol::Websocket || (config.protocol == L7Protocol::Rest && config.websocket_credential_rewrite) { diff --git a/crates/openshell-sandbox/src/provider_credentials.rs b/crates/openshell-sandbox/src/provider_credentials.rs index ffe0148a4..829e1b226 100644 --- a/crates/openshell-sandbox/src/provider_credentials.rs +++ b/crates/openshell-sandbox/src/provider_credentials.rs @@ -19,6 +19,7 @@ pub struct ProviderCredentialSnapshot { struct ProviderCredentialStateInner { current: Arc, generations: VecDeque>, + current_resolver: Option>, combined_resolver: Option>, } @@ -29,19 +30,21 @@ pub struct ProviderCredentialState { impl ProviderCredentialState { pub fn from_environment(revision: u64, env: HashMap) -> Self { - let (child_env, resolver) = SecretResolver::from_provider_env_for_revision(env, revision); + let 
(child_env, generation_resolver, current_resolver) = + SecretResolver::from_provider_env_for_current_revision(env, revision); let snapshot = Arc::new(ProviderCredentialSnapshot { revision, child_env, }); - let generations: VecDeque<_> = resolver.map(Arc::new).into_iter().collect(); - let combined_resolver = - SecretResolver::merge(generations.iter().map(Arc::as_ref)).map(Arc::new); + let generations: VecDeque<_> = generation_resolver.map(Arc::new).into_iter().collect(); + let current_resolver = current_resolver.map(Arc::new); + let combined_resolver = merge_resolvers(&generations, current_resolver.as_ref()); Self { inner: Arc::new(RwLock::new(ProviderCredentialStateInner { current: snapshot, generations, + current_resolver, combined_resolver, })), } @@ -64,7 +67,8 @@ impl ProviderCredentialState { } pub fn install_environment(&self, revision: u64, env: HashMap) -> usize { - let (child_env, resolver) = SecretResolver::from_provider_env_for_revision(env, revision); + let (child_env, generation_resolver, current_resolver) = + SecretResolver::from_provider_env_for_current_revision(env, revision); let mut inner = self .inner .write() @@ -74,19 +78,33 @@ impl ProviderCredentialState { revision, child_env, }); + inner.current_resolver = current_resolver.map(Arc::new); - if let Some(resolver) = resolver { + if let Some(resolver) = generation_resolver { inner.generations.push_back(Arc::new(resolver)); while inner.generations.len() > MAX_RETAINED_CREDENTIAL_GENERATIONS { inner.generations.pop_front(); } } inner.combined_resolver = - SecretResolver::merge(inner.generations.iter().map(Arc::as_ref)).map(Arc::new); + merge_resolvers(&inner.generations, inner.current_resolver.as_ref()); inner.current.child_env.len() } } +fn merge_resolvers( + generations: &VecDeque>, + current_resolver: Option<&Arc>, +) -> Option> { + SecretResolver::merge( + generations + .iter() + .map(Arc::as_ref) + .chain(current_resolver.into_iter().map(Arc::as_ref)), + ) + .map(Arc::new) +} + #[cfg(test)] 
mod tests { use super::*; @@ -126,10 +144,14 @@ mod tests { resolver.resolve_placeholder("openshell:resolve:env:GITHUB_TOKEN"), Some("new") ); + assert_eq!( + resolver.resolve_placeholder("provider-OPENSHELL-RESOLVE-ENV-GITHUB_TOKEN"), + Some("new") + ); } #[test] - fn empty_refresh_removes_env_from_new_snapshots_but_retains_old_resolver() { + fn empty_refresh_removes_current_aliases_but_retains_revisioned_resolver() { let state = ProviderCredentialState::from_environment( 10, HashMap::from([("GITHUB_TOKEN".to_string(), "old".to_string())]), @@ -143,5 +165,13 @@ mod tests { resolver.resolve_placeholder("openshell:resolve:env:v10_GITHUB_TOKEN"), Some("old") ); + assert_eq!( + resolver.resolve_placeholder("openshell:resolve:env:GITHUB_TOKEN"), + None + ); + assert_eq!( + resolver.resolve_placeholder("provider-OPENSHELL-RESOLVE-ENV-GITHUB_TOKEN"), + None + ); } } diff --git a/crates/openshell-sandbox/src/proxy.rs b/crates/openshell-sandbox/src/proxy.rs index 0395562fb..3012930e2 100644 --- a/crates/openshell-sandbox/src/proxy.rs +++ b/crates/openshell-sandbox/src/proxy.rs @@ -2436,44 +2436,6 @@ where .await } -fn forward_websocket_upgrade_settings( - config: &crate::l7::L7EndpointConfig, - websocket_request: bool, - secret_resolver: Option>, -) -> ( - crate::l7::rest::WebSocketExtensionMode, - crate::l7::relay::UpgradeRelayOptions<'static>, -) { - let websocket_credential_rewrite = matches!( - config.protocol, - crate::l7::L7Protocol::Rest | crate::l7::L7Protocol::Websocket - ) && config.websocket_credential_rewrite; - let websocket_extensions = if config.protocol == crate::l7::L7Protocol::Websocket - || (config.protocol == crate::l7::L7Protocol::Rest && websocket_credential_rewrite) - { - crate::l7::rest::WebSocketExtensionMode::PermessageDeflate - } else { - crate::l7::rest::WebSocketExtensionMode::Preserve - }; - - let upgrade_options = crate::l7::relay::UpgradeRelayOptions { - websocket_request, - websocket: crate::l7::relay::WebSocketUpgradeBehavior { - 
credential_rewrite: websocket_credential_rewrite, - ..Default::default() - }, - secret_resolver: if websocket_credential_rewrite { - secret_resolver - } else { - None - }, - enforcement: config.enforcement, - ..Default::default() - }; - - (websocket_extensions, upgrade_options) -} - /// Handle a plain HTTP forward proxy request (non-CONNECT). /// /// Public IPs are allowed through when the endpoint passes OPA evaluation. @@ -2692,8 +2654,34 @@ async fn handle_forward_proxy( let mut forward_request_bytes = buf[..used].to_vec(); let mut upstream_target = path.clone(); let mut websocket_extensions = crate::l7::rest::WebSocketExtensionMode::Preserve; - let mut upgrade_options = crate::l7::relay::UpgradeRelayOptions::default(); + let mut forward_tunnel_engine: Option = None; + let mut forward_upgrade_config: Option = None; + let mut forward_upgrade_target = String::new(); + let mut forward_upgrade_query_params = std::collections::HashMap::new(); + let mut forward_websocket_request = + crate::l7::rest::request_is_websocket_upgrade(&forward_request_bytes); let mut request_body_credential_rewrite = false; + let l7_ctx = crate::l7::relay::L7EvalContext { + host: host_lc.clone(), + port, + policy_name: matched_policy.clone().unwrap_or_default(), + binary_path: decision + .binary + .as_ref() + .map(|p| p.to_string_lossy().into_owned()) + .unwrap_or_default(), + ancestors: decision + .ancestors + .iter() + .map(|p| p.to_string_lossy().into_owned()) + .collect(), + cmdline_paths: decision + .cmdline_paths + .iter() + .map(|p| p.to_string_lossy().into_owned()) + .collect(), + secret_resolver: secret_resolver.clone(), + }; // 4b. If the endpoint has L7 config, evaluate the request against // L7 policy. 
The forward proxy handles exactly one request per @@ -2741,28 +2729,6 @@ async fn handle_forward_proxy( } }; - let l7_ctx = crate::l7::relay::L7EvalContext { - host: host_lc.clone(), - port, - policy_name: matched_policy.clone().unwrap_or_default(), - binary_path: decision - .binary - .as_ref() - .map(|p| p.to_string_lossy().into_owned()) - .unwrap_or_default(), - ancestors: decision - .ancestors - .iter() - .map(|p| p.to_string_lossy().into_owned()) - .collect(), - cmdline_paths: decision - .cmdline_paths - .iter() - .map(|p| p.to_string_lossy().into_owned()) - .collect(), - secret_resolver: secret_resolver.clone(), - }; - // Canonicalize the request-target. The canonical form is fed to OPA // AND reassigned to the outer `path` variable so the later call to // `rewrite_forward_request` writes canonical bytes to the upstream. @@ -2831,16 +2797,14 @@ async fn handle_forward_proxy( .await?; return Ok(()); }; - let websocket_request = + forward_websocket_request = crate::l7::rest::request_is_websocket_upgrade(&forward_request_bytes); - (websocket_extensions, upgrade_options) = forward_websocket_upgrade_settings( - &l7_config.config, - websocket_request, - secret_resolver.clone(), - ); + websocket_extensions = crate::l7::relay::websocket_extension_mode(&l7_config.config); request_body_credential_rewrite = l7_config.config.protocol == crate::l7::L7Protocol::Rest && l7_config.config.request_body_credential_rewrite; - upgrade_options.policy_name = matched_policy.clone().unwrap_or_default(); + forward_upgrade_config = Some(l7_config.config.clone()); + forward_upgrade_target = path.clone(); + forward_upgrade_query_params = query_params.clone(); let graphql = if l7_config.config.protocol == crate::l7::L7Protocol::Graphql { let header_end = forward_request_bytes .windows(4) @@ -3001,6 +2965,7 @@ async fn handle_forward_proxy( .await?; return Ok(()); } + forward_tunnel_engine = Some(tunnel_engine); } // 5. DNS resolution + SSRF defence (mirrors the CONNECT path logic). 
@@ -3317,6 +3282,24 @@ async fn handle_forward_proxy( websocket_permessage_deflate, } = outcome { + let mut upgrade_options = if let (Some(config), Some(engine)) = ( + forward_upgrade_config.as_ref(), + forward_tunnel_engine.as_ref(), + ) { + crate::l7::relay::upgrade_options( + config, + &l7_ctx, + forward_websocket_request, + &forward_upgrade_target, + &forward_upgrade_query_params, + Some(engine), + ) + } else { + crate::l7::relay::UpgradeRelayOptions { + websocket_request: forward_websocket_request, + ..Default::default() + } + }; upgrade_options.websocket.permessage_deflate = websocket_permessage_deflate; crate::l7::relay::handle_upgrade( client, @@ -3503,18 +3486,225 @@ mod tests { .map_err(|e| miette::miette!("upstream task failed: {e}")) } + fn forward_websocket_policy_parts( + data: &str, + host: &str, + port: u16, + path: &str, + policy_name: &str, + ) -> ( + crate::l7::L7EndpointConfig, + crate::opa::TunnelPolicyEngine, + crate::l7::relay::L7EvalContext, + ) { + let policy = include_str!("../data/sandbox-policy.rego"); + let engine = OpaEngine::from_strings(policy, data).unwrap(); + let decision = ConnectDecision { + action: NetworkAction::Allow { + matched_policy: Some(policy_name.to_string()), + }, + generation: engine.current_generation(), + binary: Some(PathBuf::from("/usr/bin/node")), + binary_pid: None, + ancestors: vec![], + cmdline_paths: vec![], + }; + let route = + query_l7_route_snapshot(&engine, &decision, host, port).expect("L7 route should match"); + let config = select_l7_config_for_path(&route.configs, path) + .expect("path-specific L7 config should match") + .config + .clone(); + let tunnel_engine = engine + .clone_engine_for_tunnel(route.generation) + .expect("tunnel engine"); + let ctx = crate::l7::relay::L7EvalContext { + host: host.to_string(), + port, + policy_name: policy_name.to_string(), + binary_path: "/usr/bin/node".to_string(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + }; + (config, 
tunnel_engine, ctx) + } + + async fn read_http_headers(reader: &mut R) -> Vec { + let mut bytes = Vec::new(); + let mut chunk = [0u8; 256]; + loop { + let n = + tokio::time::timeout(std::time::Duration::from_secs(1), reader.read(&mut chunk)) + .await + .expect("HTTP headers should arrive") + .expect("header read should succeed"); + assert!(n > 0, "stream closed before HTTP headers"); + bytes.extend_from_slice(&chunk[..n]); + if bytes.windows(4).any(|w| w == b"\r\n\r\n") { + return bytes; + } + } + } + + fn masked_text_frame(payload: &[u8]) -> Vec { + let mask = [0x11, 0x22, 0x33, 0x44]; + assert!( + payload.len() <= 125, + "test helper only supports small frames" + ); + let payload_len = u8::try_from(payload.len()).expect("small frame length"); + let mut frame = vec![0x81, 0x80 | payload_len]; + frame.extend_from_slice(&mask); + frame.extend( + payload + .iter() + .enumerate() + .map(|(idx, byte)| byte ^ mask[idx % 4]), + ); + frame + } + + async fn forward_websocket_denied_after_upgrade( + config: crate::l7::L7EndpointConfig, + tunnel_engine: crate::opa::TunnelPolicyEngine, + ctx: crate::l7::relay::L7EvalContext, + path: &str, + payload: &str, + ) -> (miette::Report, Vec) { + let host = ctx.host.clone(); + let port = ctx.port; + let raw = format!( + "GET http://{host}{path} HTTP/1.1\r\n\ + Host: {host}\r\n\ + Upgrade: websocket\r\n\ + Connection: Upgrade\r\n\ + Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ + Sec-WebSocket-Version: 13\r\n\r\n" + ); + let rewritten = rewrite_forward_request(raw.as_bytes(), raw.len(), path, None, false) + .expect("forward websocket request should rewrite to origin form"); + let websocket_extensions = crate::l7::relay::websocket_extension_mode(&config); + let target = path.to_string(); + let query_params = std::collections::HashMap::new(); + let (mut proxy_to_upstream, mut upstream) = tokio::io::duplex(8192); + let (mut app, mut proxy_to_client) = tokio::io::duplex(8192); + + let relay = tokio::spawn(async move { + let guard = 
tunnel_engine.generation_guard(); + let outcome = relay_rewritten_forward_request( + "GET", + &target, + rewritten, + &mut proxy_to_client, + &mut proxy_to_upstream, + ForwardRelayOptions { + generation_guard: guard, + websocket_extensions, + secret_resolver: None, + request_body_credential_rewrite: false, + }, + ) + .await?; + if let crate::l7::provider::RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + } = outcome + { + let mut options = crate::l7::relay::upgrade_options( + &config, + &ctx, + true, + &target, + &query_params, + Some(&tunnel_engine), + ); + options.websocket.permessage_deflate = websocket_permessage_deflate; + crate::l7::relay::handle_upgrade( + &mut proxy_to_client, + &mut proxy_to_upstream, + overflow, + &host, + port, + options, + ) + .await?; + } + Ok::<(), miette::Report>(()) + }); + + let forwarded_headers = read_http_headers(&mut upstream).await; + let forwarded_headers = String::from_utf8_lossy(&forwarded_headers); + assert!(forwarded_headers.starts_with(&format!("GET {path} HTTP/1.1\r\n"))); + assert!(forwarded_headers.contains("Upgrade: websocket\r\n")); + + upstream + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n", + ) + .await + .unwrap(); + + let response = read_http_headers(&mut app).await; + assert!(String::from_utf8_lossy(&response).contains("101 Switching Protocols")); + + app.write_all(&masked_text_frame(payload.as_bytes())) + .await + .unwrap(); + + let err = tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("websocket relay should fail closed after denied frame") + .expect("relay task should not panic") + .expect_err("denied websocket frame should fail the forward relay"); + + let mut leaked = Vec::new(); + tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read_to_end(&mut leaked), + ) + .await + .expect("upstream side should close") + 
.expect("upstream read should succeed"); + (err, leaked) + } + #[test] - fn forward_websocket_upgrade_enables_rewrite_for_native_websocket_endpoint() { + fn forward_websocket_upgrade_options_enable_native_policy_context() { let (_, resolver) = SecretResolver::from_provider_env( [("DISCORD_BOT_TOKEN".to_string(), "discord-real".to_string())] .into_iter() .collect(), ); + let resolver = resolver.map(Arc::new); + let policy = include_str!("../data/sandbox-policy.rego"); + let policy_data = "network_policies: {}\n"; + let engine = OpaEngine::from_strings(policy, policy_data).unwrap(); + let tunnel_engine = engine + .clone_engine_for_tunnel(engine.current_generation()) + .unwrap(); + let ctx = crate::l7::relay::L7EvalContext { + host: "gateway.example.test".to_string(), + port: 80, + policy_name: "ws_api".to_string(), + binary_path: "/usr/bin/node".to_string(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: resolver, + }; + let query_params = std::collections::HashMap::new(); - let (extensions, options) = forward_websocket_upgrade_settings( + let extensions = crate::l7::relay::websocket_extension_mode(&websocket_l7_config( + crate::l7::L7Protocol::Websocket, + true, + )); + let options = crate::l7::relay::upgrade_options( &websocket_l7_config(crate::l7::L7Protocol::Websocket, true), + &ctx, true, - resolver.map(Arc::new), + "/ws", + &query_params, + Some(&tunnel_engine), ); assert_eq!( @@ -3523,15 +3713,30 @@ mod tests { ); assert!(options.websocket.credential_rewrite); assert!(options.secret_resolver.is_some()); + assert!(options.engine.is_some()); + assert!(options.ctx.is_some()); + assert!(matches!( + options.websocket.message_policy, + crate::l7::relay::WebSocketMessagePolicy::Transport + )); } #[test] - fn forward_websocket_upgrade_preserves_rest_without_rewrite() { - let (extensions, options) = forward_websocket_upgrade_settings( - &websocket_l7_config(crate::l7::L7Protocol::Rest, false), - true, - None, - ); + fn 
forward_websocket_upgrade_options_preserve_rest_without_rewrite() { + let ctx = crate::l7::relay::L7EvalContext { + host: "gateway.example.test".to_string(), + port: 80, + policy_name: "rest_api".to_string(), + binary_path: "/usr/bin/node".to_string(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + }; + let query_params = std::collections::HashMap::new(); + let config = websocket_l7_config(crate::l7::L7Protocol::Rest, false); + let extensions = crate::l7::relay::websocket_extension_mode(&config); + let options = + crate::l7::relay::upgrade_options(&config, &ctx, true, "/ws", &query_params, None); assert_eq!( extensions, @@ -3539,6 +3744,109 @@ mod tests { ); assert!(!options.websocket.credential_rewrite); assert!(options.secret_resolver.is_none()); + assert!(options.engine.is_none()); + assert!(options.ctx.is_none()); + assert!(matches!( + options.websocket.message_policy, + crate::l7::relay::WebSocketMessagePolicy::None + )); + } + + #[tokio::test] + async fn forward_websocket_upgrade_blocks_text_frame_by_policy() { + let data = r#" +network_policies: + ws_api: + name: ws_api + endpoints: + - host: gateway.example.test + port: 80 + path: "/ws" + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/ws" + - allow: + method: WEBSOCKET_TEXT + path: "/ws" + deny_rules: + - method: WEBSOCKET_TEXT + path: "/ws" + binaries: + - { path: /usr/bin/node } +"#; + let (config, tunnel_engine, ctx) = + forward_websocket_policy_parts(data, "gateway.example.test", 80, "/ws", "ws_api"); + + let (err, leaked) = forward_websocket_denied_after_upgrade( + config, + tunnel_engine, + ctx, + "/ws", + r#"{"type":"unsafe"}"#, + ) + .await; + + assert!(err.to_string().contains("websocket text message denied")); + assert!( + leaked.is_empty(), + "denied forward-proxy WebSocket text frames must not reach upstream" + ); + } + + #[tokio::test] + async fn forward_graphql_websocket_upgrade_blocks_unallowed_operation() { + let data = r#" 
+network_policies: + graphql_ws: + name: graphql_ws + endpoints: + - host: gateway.example.test + port: 80 + path: "/graphql" + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/graphql" + - allow: + operation_type: query + fields: [viewer] + deny_rules: + - operation_type: query + fields: [admin] + binaries: + - { path: /usr/bin/node } +"#; + let (config, tunnel_engine, ctx) = forward_websocket_policy_parts( + data, + "gateway.example.test", + 80, + "/graphql", + "graphql_ws", + ); + assert!( + config.websocket_graphql_policy, + "operation rules should enable GraphQL-over-WebSocket inspection" + ); + + let (err, leaked) = forward_websocket_denied_after_upgrade( + config, + tunnel_engine, + ctx, + "/graphql", + r#"{"id":"1","type":"subscribe","payload":{"query":"query { admin }"}}"#, + ) + .await; + + assert!(err.to_string().contains("websocket GraphQL message denied")); + assert!( + leaked.is_empty(), + "denied forward-proxy GraphQL WebSocket operations must not reach upstream" + ); } #[test] diff --git a/crates/openshell-sandbox/src/secrets.rs b/crates/openshell-sandbox/src/secrets.rs index 54c43d07a..6dbd34dcb 100644 --- a/crates/openshell-sandbox/src/secrets.rs +++ b/crates/openshell-sandbox/src/secrets.rs @@ -98,6 +98,38 @@ impl SecretResolver { pub(crate) fn from_provider_env_for_revision( provider_env: HashMap, revision: u64, + ) -> (HashMap, Option) { + Self::from_provider_env_for_revision_with_current_aliases(provider_env, revision, false) + } + + pub(crate) fn from_provider_env_for_current_revision( + provider_env: HashMap, + revision: u64, + ) -> (HashMap, Option, Option) { + if revision == 0 { + let (child_env, current_resolver) = + Self::from_provider_env_for_revision_with_current_aliases(provider_env, 0, true); + return (child_env, None, current_resolver); + } + let provider_env_for_current = provider_env.clone(); + let (child_env, revision_resolver) = + Self::from_provider_env_for_revision_with_current_aliases( + 
provider_env, + revision, + false, + ); + let (_, current_resolver) = Self::from_provider_env_for_revision_with_current_aliases( + provider_env_for_current, + revision, + true, + ); + (child_env, revision_resolver, current_resolver) + } + + fn from_provider_env_for_revision_with_current_aliases( + provider_env: HashMap, + revision: u64, + include_current_aliases: bool, ) -> (HashMap, Option) { if provider_env.is_empty() { return (HashMap::new(), None); @@ -108,11 +140,10 @@ impl SecretResolver { for (key, value) in provider_env { let placeholder = placeholder_for_env_key_for_revision(&key, revision); - let canonical_placeholder = (revision != 0).then(|| placeholder_for_env_key(&key)); child_env.insert(key.clone(), placeholder.clone()); by_placeholder.insert(placeholder, value.clone()); - if let Some(canonical_placeholder) = canonical_placeholder { - by_placeholder.insert(canonical_placeholder, value.clone()); + if include_current_aliases && revision != 0 { + by_placeholder.insert(placeholder_for_env_key(&key), value.clone()); } }