diff --git a/.github/workflows/websocket-conformance.yml b/.github/workflows/websocket-conformance.yml new file mode 100644 index 000000000..c1c689c54 --- /dev/null +++ b/.github/workflows/websocket-conformance.yml @@ -0,0 +1,65 @@ +name: WebSocket Conformance + +on: + workflow_dispatch: {} + # Add `schedule:` here after this focused lane has burned in manually. + +permissions: {} + +jobs: + build-gateway: + permissions: + contents: read + packages: write + uses: ./.github/workflows/docker-build.yml + with: + component: gateway + platform: linux/amd64 + + build-supervisor: + permissions: + contents: read + packages: write + uses: ./.github/workflows/docker-build.yml + with: + component: supervisor + platform: linux/amd64 + + websocket-conformance: + name: WebSocket Conformance + needs: [build-gateway, build-supervisor] + runs-on: linux-amd64-cpu8 + timeout-minutes: 30 + permissions: + contents: read + packages: read + container: + image: ghcr.io/nvidia/openshell/ci:latest + credentials: + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + options: --privileged + volumes: + - /var/run/docker.sock:/var/run/docker.sock + - /home/runner/_work:/home/runner/_work + env: + MISE_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + IMAGE_TAG: ${{ github.sha }} + OPENSHELL_REGISTRY: ghcr.io/nvidia/openshell + OPENSHELL_REGISTRY_HOST: ghcr.io + OPENSHELL_REGISTRY_NAMESPACE: nvidia/openshell + OPENSHELL_REGISTRY_USERNAME: ${{ github.actor }} + OPENSHELL_REGISTRY_PASSWORD: ${{ secrets.GITHUB_TOKEN }} + steps: + - uses: actions/checkout@v6 + + - name: Install OS test dependencies + run: apt-get update && apt-get install -y openssh-client && rm -rf /var/lib/apt/lists/* + + - name: Log in to GHCR + run: echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u "${{ github.actor }}" --password-stdin + + - name: Run WebSocket conformance e2e + env: + OPENSHELL_SUPERVISOR_IMAGE: ${{ format('ghcr.io/nvidia/openshell/supervisor:{0}', github.sha) }} + run: mise run 
--no-deps --skip-deps e2e:websocket-conformance diff --git a/Cargo.lock b/Cargo.lock index 808956cd9..05a1bdff2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3571,6 +3571,7 @@ dependencies = [ "base64 0.22.1", "bytes", "clap", + "flate2", "futures", "glob", "hex", @@ -3594,6 +3595,7 @@ dependencies = [ "serde", "serde_json", "serde_yml", + "sha1 0.10.6", "sha2 0.10.9", "temp-env", "tempfile", diff --git a/architecture/security-policy.md b/architecture/security-policy.md index e5f179dc1..5c04bebf5 100644 --- a/architecture/security-policy.md +++ b/architecture/security-policy.md @@ -43,9 +43,13 @@ with the sandbox's ephemeral CA and inspect method/path or protocol-specific metadata before forwarding. The proxy also supports credential injection on terminated HTTP streams when policy allows the endpoint. -Raw streams, HTTP upgrades, and long-lived response bodies are connection -scoped. Policy reloads affect the next connection or the next parsed HTTP -request; they do not rewrite bytes already being relayed. +Raw streams and long-lived response bodies are connection scoped. Policy +reloads affect the next connection or the next parsed HTTP request; they do not +rewrite bytes already being relayed. HTTP upgrades switch to raw relay by +default. A `protocol: rest` endpoint can opt in to +`websocket_credential_rewrite` for client-to-server WebSocket text messages +after an allowed `101` upgrade; server-to-client traffic and all other upgraded +protocols remain raw passthrough. 
## Live Updates diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index 25fa07cf2..96714ab30 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -287,6 +287,7 @@ const POLICY_EXAMPLES: &str = "\x1b[1mALIAS\x1b[0m $ openshell policy get my-sandbox $ openshell policy set my-sandbox --policy policy.yaml $ openshell policy update my-sandbox --add-endpoint api.github.com:443:read-only:rest:enforce + $ openshell policy update my-sandbox --add-endpoint realtime.example.com:443:read-write:websocket:enforce:websocket-credential-rewrite,allowed-ip=10.0.0.0/8 $ openshell policy update my-sandbox --add-allow 'api.github.com:443:GET:/repos/**' $ openshell policy set --global --policy policy.yaml $ openshell policy delete --global @@ -1403,7 +1404,7 @@ enum PolicyCommands { #[arg(add = ArgValueCompleter::new(completers::complete_sandbox_names))] name: Option, - /// Add or merge an endpoint: host:port[:access[:protocol[:enforcement]]]. + /// Add or merge an endpoint: host:port[:access[:protocol[:enforcement[:options]]]]. #[arg(long = "add-endpoint")] add_endpoints: Vec, @@ -1411,11 +1412,11 @@ enum PolicyCommands { #[arg(long = "remove-endpoint")] remove_endpoints: Vec, - /// Add a REST allow rule: `host:port:METHOD:path_glob`. + /// Add a REST or WebSocket method/path allow rule: `host:port:METHOD:path_glob`. #[arg(long = "add-allow")] add_allow: Vec, - /// Add a REST deny rule: `host:port:METHOD:path_glob`. + /// Add a REST or WebSocket method/path deny rule: `host:port:METHOD:path_glob`. 
#[arg(long = "add-deny")] add_deny: Vec, diff --git a/crates/openshell-cli/src/policy_update.rs b/crates/openshell-cli/src/policy_update.rs index 322a28df6..57656b878 100644 --- a/crates/openshell-cli/src/policy_update.rs +++ b/crates/openshell-cli/src/policy_update.rs @@ -18,6 +18,7 @@ pub struct PolicyUpdatePlan { pub preview_operations: Vec, } +#[allow(clippy::too_many_arguments)] pub fn build_policy_update_plan( add_endpoints: &[String], remove_endpoints: &[String], @@ -41,7 +42,6 @@ pub fn build_policy_update_plan( "--rule-name is only supported when exactly one --add-endpoint is provided" )); } - let mut merge_operations = Vec::new(); let mut preview_operations = Vec::new(); @@ -155,6 +155,40 @@ pub fn build_policy_update_plan( }) } +fn ensure_websocket_credential_rewrite_protocol( + spec: &str, + endpoint: &NetworkEndpoint, +) -> Result<()> { + if matches!(endpoint.protocol.as_str(), "rest" | "websocket") { + return Ok(()); + } + let protocol = if endpoint.protocol.is_empty() { + "" + } else { + endpoint.protocol.as_str() + }; + Err(miette!( + "websocket-credential-rewrite endpoint option requires --add-endpoint protocol segment to be 'rest' or 'websocket'; got '{protocol}' in '{spec}'" + )) +} + +fn ensure_request_body_credential_rewrite_protocol( + spec: &str, + endpoint: &NetworkEndpoint, +) -> Result<()> { + if endpoint.protocol == "rest" { + return Ok(()); + } + let protocol = if endpoint.protocol.is_empty() { + "" + } else { + endpoint.protocol.as_str() + }; + Err(miette!( + "request-body-credential-rewrite endpoint option requires --add-endpoint protocol segment to be 'rest'; got '{protocol}' in '{spec}'" + )) +} + fn group_allow_rules(specs: &[String]) -> Result>> { let mut grouped = BTreeMap::new(); for spec in specs { @@ -257,9 +291,9 @@ fn parse_remove_endpoint_spec(spec: &str) -> Result<(String, u32)> { fn parse_add_endpoint_spec(spec: &str) -> Result { let parts = spec.split(':').collect::>(); - if !(2..=5).contains(&parts.len()) { + if 
!(2..=6).contains(&parts.len()) { return Err(miette!( - "--add-endpoint expects host:port[:access[:protocol[:enforcement]]], got '{spec}'" + "--add-endpoint expects host:port[:access[:protocol[:enforcement[:options]]]], got '{spec}'" )); } @@ -269,12 +303,18 @@ fn parse_add_endpoint_spec(spec: &str) -> Result { let access = parts.get(2).copied().unwrap_or("").trim(); let protocol = parts.get(3).copied().unwrap_or("").trim(); let enforcement = parts.get(4).copied().unwrap_or("").trim(); + let options = parts.get(5).copied().unwrap_or("").trim(); if parts.len() == 3 && access.is_empty() { return Err(miette!( "--add-endpoint has an empty access segment in '{spec}'; omit it entirely if you do not need access or protocol fields" )); } + if parts.len() == 6 && options.is_empty() { + return Err(miette!( + "--add-endpoint has an empty options segment in '{spec}'; omit it entirely if you do not need endpoint options" + )); + } if !enforcement.is_empty() && protocol.is_empty() { return Err(miette!( "--add-endpoint cannot set enforcement without protocol in '{spec}'" @@ -285,9 +325,9 @@ fn parse_add_endpoint_spec(spec: &str) -> Result { "--add-endpoint access segment must be one of read-only, read-write, or full; got '{access}' in '{spec}'" )); } - if !protocol.is_empty() && !matches!(protocol, "rest" | "sql") { + if !protocol.is_empty() && !matches!(protocol, "rest" | "websocket" | "sql") { return Err(miette!( - "--add-endpoint protocol segment must be 'rest' or 'sql'; got '{protocol}' in '{spec}'" + "--add-endpoint protocol segment must be 'rest', 'websocket', or 'sql'; got '{protocol}' in '{spec}'" )); } if !enforcement.is_empty() && !matches!(enforcement, "enforce" | "audit") { @@ -296,7 +336,7 @@ fn parse_add_endpoint_spec(spec: &str) -> Result { )); } - Ok(NetworkEndpoint { + let mut endpoint = NetworkEndpoint { host, port, ports: vec![port], @@ -304,7 +344,65 @@ fn parse_add_endpoint_spec(spec: &str) -> Result { enforcement: enforcement.to_string(), access: 
access.to_string(), ..Default::default() - }) + }; + apply_add_endpoint_options(spec, &mut endpoint, options)?; + Ok(endpoint) +} + +fn apply_add_endpoint_options( + spec: &str, + endpoint: &mut NetworkEndpoint, + options: &str, +) -> Result<()> { + if options.is_empty() { + return Ok(()); + } + + for option in options.split(',') { + let option = option.trim(); + if option.is_empty() { + return Err(miette!( + "--add-endpoint options segment must not contain empty options in '{spec}'" + )); + } + match option { + "websocket-credential-rewrite" => { + ensure_websocket_credential_rewrite_protocol(spec, endpoint)?; + endpoint.websocket_credential_rewrite = true; + } + "request-body-credential-rewrite" => { + ensure_request_body_credential_rewrite_protocol(spec, endpoint)?; + endpoint.request_body_credential_rewrite = true; + } + _ => { + let Some(allowed_ip) = option.strip_prefix("allowed-ip=") else { + return Err(miette!( + "--add-endpoint options segment supports only 'websocket-credential-rewrite', 'request-body-credential-rewrite', and 'allowed-ip='; got '{option}' in '{spec}'" + )); + }; + let allowed_ip = allowed_ip.trim(); + if allowed_ip.is_empty() { + return Err(miette!( + "--add-endpoint allowed-ip option must include a CIDR or IP value in '{spec}'" + )); + } + if allowed_ip.contains(char::is_whitespace) { + return Err(miette!( + "--add-endpoint allowed-ip option must not contain whitespace in '{spec}'" + )); + } + if !endpoint + .allowed_ips + .iter() + .any(|existing| existing == allowed_ip) + { + endpoint.allowed_ips.push(allowed_ip.to_string()); + } + } + } + } + + Ok(()) } fn parse_host(flag: &str, spec: &str, host: &str) -> Result { @@ -352,7 +450,30 @@ fn dedup_strings(values: &[String]) -> Vec { #[cfg(test)] mod tests { - use super::build_policy_update_plan; + use super::{ + PolicyUpdatePlan, build_policy_update_plan as build_policy_update_plan_with_options, + }; + use openshell_policy::PolicyMergeOp; + + fn build_policy_update_plan( + add_endpoints: 
&[String], + remove_endpoints: &[String], + add_deny: &[String], + add_allow: &[String], + remove_rules: &[String], + binaries: &[String], + rule_name: Option<&str>, + ) -> miette::Result { + build_policy_update_plan_with_options( + add_endpoints, + remove_endpoints, + add_deny, + add_allow, + remove_rules, + binaries, + rule_name, + ) + } #[test] fn parse_add_endpoint_basic_l4() { @@ -392,6 +513,229 @@ mod tests { .expect("plan should build"); } + #[test] + fn parse_add_endpoint_accepts_websocket_protocol() { + let plan = build_policy_update_plan( + &["realtime.example.com:443:read-write:websocket:enforce".to_string()], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect("plan should build"); + + let PolicyMergeOp::AddRule { rule, .. } = &plan.preview_operations[0] else { + panic!("expected add-rule preview"); + }; + let endpoint = &rule.endpoints[0]; + assert_eq!(endpoint.host, "realtime.example.com"); + assert_eq!(endpoint.protocol, "websocket"); + assert_eq!(endpoint.access, "read-write"); + assert_eq!(endpoint.enforcement, "enforce"); + } + + #[test] + fn parse_add_endpoint_enables_websocket_credential_rewrite() { + let plan = build_policy_update_plan( + &["realtime.example.com:443:read-write:websocket:enforce:websocket-credential-rewrite" + .to_string()], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect("plan should build"); + + let PolicyMergeOp::AddRule { rule, .. } = &plan.preview_operations[0] else { + panic!("expected add-rule preview"); + }; + assert!(rule.endpoints[0].websocket_credential_rewrite); + } + + #[test] + fn parse_add_endpoint_enables_websocket_credential_rewrite_on_rest_compat_endpoint() { + let plan = build_policy_update_plan( + &[ + "realtime.example.com:443:read-write:rest:enforce:websocket-credential-rewrite" + .to_string(), + ], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect("plan should build"); + + let PolicyMergeOp::AddRule { rule, .. 
} = &plan.preview_operations[0] else { + panic!("expected add-rule preview"); + }; + assert!(rule.endpoints[0].websocket_credential_rewrite); + } + + #[test] + fn parse_add_endpoint_enables_request_body_credential_rewrite_on_rest_endpoint() { + let plan = build_policy_update_plan( + &[ + "api.example.com:443:read-write:rest:enforce:request-body-credential-rewrite" + .to_string(), + ], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect("plan should build"); + + let PolicyMergeOp::AddRule { rule, .. } = &plan.preview_operations[0] else { + panic!("expected add-rule preview"); + }; + let endpoint = &rule.endpoints[0]; + assert_eq!(endpoint.protocol, "rest"); + assert!(endpoint.request_body_credential_rewrite); + } + + #[test] + fn parse_add_endpoint_merges_allowed_ips_with_websocket_options() { + let plan = build_policy_update_plan( + &[ + "realtime.example.com:443:read-write:websocket:enforce:websocket-credential-rewrite,allowed-ip=10.0.0.0/8,allowed-ip=172.16.0.0/12,allowed-ip=10.0.0.0/8" + .to_string(), + ], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect("plan should build"); + + let PolicyMergeOp::AddRule { rule, .. } = &plan.preview_operations[0] else { + panic!("expected add-rule preview"); + }; + let endpoint = &rule.endpoints[0]; + assert!(endpoint.websocket_credential_rewrite); + assert_eq!( + endpoint.allowed_ips, + vec!["10.0.0.0/8".to_string(), "172.16.0.0/12".to_string()] + ); + } + + #[test] + fn parse_add_endpoint_accepts_allowed_ip_on_rest_endpoint() { + let plan = build_policy_update_plan( + &["api.example.com:443:read-write:rest:enforce:allowed-ip=192.168.0.0/16".to_string()], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect("plan should build"); + + let PolicyMergeOp::AddRule { rule, .. 
} = &plan.preview_operations[0] else { + panic!("expected add-rule preview"); + }; + assert_eq!(rule.endpoints[0].allowed_ips, vec!["192.168.0.0/16"]); + } + + #[test] + fn parse_add_endpoint_rejects_empty_allowed_ip() { + let error = build_policy_update_plan( + &["api.example.com:443:read-write:rest:enforce:allowed-ip=".to_string()], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect_err("plan should fail"); + assert!(error.to_string().contains("allowed-ip option")); + } + + #[test] + fn websocket_credential_rewrite_rejects_l4_endpoint() { + let error = build_policy_update_plan( + &["realtime.example.com:443::::websocket-credential-rewrite".to_string()], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect_err("plan should fail"); + assert!(error.to_string().contains("protocol segment")); + } + + #[test] + fn request_body_credential_rewrite_rejects_non_rest_endpoint() { + let error = build_policy_update_plan( + &[ + "realtime.example.com:443:read-write:websocket:enforce:request-body-credential-rewrite" + .to_string(), + ], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect_err("plan should fail"); + + assert!(error.to_string().contains("protocol segment")); + assert!(error.to_string().contains("'rest'")); + } + + #[test] + fn parse_add_endpoint_rejects_unknown_options() { + let error = build_policy_update_plan( + &["realtime.example.com:443:read-write:websocket:enforce:future-option".to_string()], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect_err("plan should fail"); + assert!(error.to_string().contains("options segment")); + } + + #[test] + fn parse_add_allow_accepts_websocket_text_method() { + let plan = build_policy_update_plan( + &[], + &[], + &[], + &["realtime.example.com:443:websocket_text:/v1/messages/**".to_string()], + &[], + &[], + None, + ) + .expect("plan should build"); + + let PolicyMergeOp::AddAllowRules { host, port, rules } = &plan.preview_operations[0] else { + panic!("expected add-allow preview"); + }; + 
assert_eq!(host, "realtime.example.com"); + assert_eq!(*port, 443); + let allow = rules[0].allow.as_ref().expect("allow rule"); + assert_eq!(allow.method, "WEBSOCKET_TEXT"); + assert_eq!(allow.path, "/v1/messages/**"); + } + #[test] fn parse_add_deny_rejects_empty_method() { let error = build_policy_update_plan( diff --git a/crates/openshell-driver-docker/src/lib.rs b/crates/openshell-driver-docker/src/lib.rs index a74609fb5..b6be2fce5 100644 --- a/crates/openshell-driver-docker/src/lib.rs +++ b/crates/openshell-driver-docker/src/lib.rs @@ -1099,7 +1099,7 @@ fn docker_gateway_route( }; } - if is_docker_desktop(info) { + if uses_host_gateway_alias(info) { DockerGatewayRoute::HostGateway } else { DockerGatewayRoute::Bridge { @@ -1109,7 +1109,7 @@ fn docker_gateway_route( } } -fn is_docker_desktop(info: &SystemInfo) -> bool { +fn uses_host_gateway_alias(info: &SystemInfo) -> bool { let operating_system = info .operating_system .as_deref() @@ -1119,6 +1119,15 @@ fn is_docker_desktop(info: &SystemInfo) -> bool { return true; } + let name = info + .name + .as_deref() + .unwrap_or_default() + .to_ascii_lowercase(); + if name == "colima" { + return true; + } + info.labels.as_ref().is_some_and(|labels| { labels .iter() @@ -1132,9 +1141,10 @@ fn docker_extra_hosts(route: &DockerGatewayRoute) -> Vec { format!("{HOST_DOCKER_INTERNAL}:{host_alias_ip}"), format!("{HOST_OPENSHELL_INTERNAL}:{host_alias_ip}"), ], - DockerGatewayRoute::HostGateway => { - vec![format!("{HOST_OPENSHELL_INTERNAL}:host-gateway")] - } + DockerGatewayRoute::HostGateway => vec![ + format!("{HOST_DOCKER_INTERNAL}:host-gateway"), + format!("{HOST_OPENSHELL_INTERNAL}:host-gateway"), + ], } } diff --git a/crates/openshell-driver-docker/src/tests.rs b/crates/openshell-driver-docker/src/tests.rs index c89019398..83906e73c 100644 --- a/crates/openshell-driver-docker/src/tests.rs +++ b/crates/openshell-driver-docker/src/tests.rs @@ -160,7 +160,36 @@ fn docker_gateway_route_uses_host_gateway_for_docker_desktop() { 
); assert_eq!( docker_extra_hosts(&DockerGatewayRoute::HostGateway), - vec!["host.openshell.internal:host-gateway".to_string()] + vec![ + "host.docker.internal:host-gateway".to_string(), + "host.openshell.internal:host-gateway".to_string() + ] + ); +} + +#[test] +fn docker_gateway_route_uses_host_gateway_for_colima() { + let info = SystemInfo { + name: Some("colima".to_string()), + operating_system: Some("Ubuntu 24.04.4 LTS".to_string()), + ..Default::default() + }; + + assert_eq!( + docker_gateway_route( + &info, + IpAddr::V4(Ipv4Addr::new(172, 20, 0, 1)), + DEFAULT_SERVER_PORT, + None, + ), + DockerGatewayRoute::HostGateway + ); + assert_eq!( + docker_extra_hosts(&DockerGatewayRoute::HostGateway), + vec![ + "host.docker.internal:host-gateway".to_string(), + "host.openshell.internal:host-gateway".to_string() + ] ); } diff --git a/crates/openshell-policy/src/lib.rs b/crates/openshell-policy/src/lib.rs index 61df0aadb..908450111 100644 --- a/crates/openshell-policy/src/lib.rs +++ b/crates/openshell-policy/src/lib.rs @@ -120,6 +120,15 @@ struct NetworkEndpointDef { /// Defaults to false (strict). #[serde(default, skip_serializing_if = "std::ops::Not::not")] allow_encoded_slash: bool, + /// When true, client-to-server WebSocket text messages on this REST + /// endpoint rewrite credential placeholders after an allowed 101 upgrade. + /// Defaults to false. + #[serde(default, skip_serializing_if = "std::ops::Not::not")] + websocket_credential_rewrite: bool, + /// When true, supported textual REST request bodies rewrite credential + /// placeholders before forwarding upstream. Defaults to false. 
+ #[serde(default, skip_serializing_if = "std::ops::Not::not")] + request_body_credential_rewrite: bool, #[serde(default, skip_serializing_if = "String::is_empty")] persisted_queries: String, #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] @@ -317,6 +326,8 @@ fn to_proto(raw: PolicyFile) -> SandboxPolicy { }) .collect(), allow_encoded_slash: e.allow_encoded_slash, + websocket_credential_rewrite: e.websocket_credential_rewrite, + request_body_credential_rewrite: e.request_body_credential_rewrite, persisted_queries: e.persisted_queries, graphql_persisted_queries: e .graphql_persisted_queries @@ -480,6 +491,8 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile { }) .collect(), allow_encoded_slash: e.allow_encoded_slash, + websocket_credential_rewrite: e.websocket_credential_rewrite, + request_body_credential_rewrite: e.request_body_credential_rewrite, persisted_queries: e.persisted_queries.clone(), graphql_persisted_queries: e .graphql_persisted_queries @@ -1656,6 +1669,80 @@ network_policies: assert_eq!(ep.deny_rules[0].fields, vec!["deleteRepository"]); } + #[test] + fn round_trip_preserves_websocket_credential_rewrite() { + let yaml = r" +version: 1 +network_policies: + discord_gateway: + name: discord_gateway + endpoints: + - host: gateway.example.com + port: 443 + protocol: rest + enforcement: enforce + access: full + websocket_credential_rewrite: true + binaries: + - path: /usr/bin/node +"; + let proto1 = parse_sandbox_policy(yaml).expect("parse failed"); + let yaml_out = serialize_sandbox_policy(&proto1).expect("serialize failed"); + let proto2 = parse_sandbox_policy(&yaml_out).expect("re-parse failed"); + + let ep = &proto2.network_policies["discord_gateway"].endpoints[0]; + assert_eq!(ep.protocol, "rest"); + assert!(ep.websocket_credential_rewrite); + assert!(yaml_out.contains("websocket_credential_rewrite: true")); + } + + #[test] + fn round_trip_preserves_request_body_credential_rewrite() { + let yaml = r" +version: 1 +network_policies: + 
slack_api: + name: slack_api + endpoints: + - host: slack.com + port: 443 + protocol: rest + enforcement: enforce + access: read-write + request_body_credential_rewrite: true + binaries: + - path: /usr/bin/node +"; + let proto1 = parse_sandbox_policy(yaml).expect("parse failed"); + let yaml_out = serialize_sandbox_policy(&proto1).expect("serialize failed"); + let proto2 = parse_sandbox_policy(&yaml_out).expect("re-parse failed"); + + let ep = &proto2.network_policies["slack_api"].endpoints[0]; + assert_eq!(ep.protocol, "rest"); + assert!(ep.request_body_credential_rewrite); + assert!(yaml_out.contains("request_body_credential_rewrite: true")); + } + + #[test] + fn websocket_credential_rewrite_defaults_false() { + let yaml = r" +version: 1 +network_policies: + gateway: + endpoints: + - host: gateway.example.com + port: 443 + protocol: rest + access: full + binaries: + - path: /usr/bin/node +"; + let proto = parse_sandbox_policy(yaml).expect("parse failed"); + let ep = &proto.network_policies["gateway"].endpoints[0]; + assert!(!ep.websocket_credential_rewrite); + assert!(!ep.request_body_credential_rewrite); + } + #[test] fn parse_rejects_unknown_fields_in_deny_rule() { let yaml = r" diff --git a/crates/openshell-policy/src/merge.rs b/crates/openshell-policy/src/merge.rs index 7a5dec916..d99d9c216 100644 --- a/crates/openshell-policy/src/merge.rs +++ b/crates/openshell-policy/src/merge.rs @@ -184,7 +184,7 @@ impl std::fmt::Display for PolicyMergeError { protocol, } => write!( f, - "endpoint {host}:{port} uses unsupported protocol '{protocol}'; this operation currently supports only protocol 'rest'" + "endpoint {host}:{port} uses unsupported protocol '{protocol}'; this operation currently supports only protocol 'rest' or 'websocket'" ), Self::EndpointHasNoAllowBase { host, port } => write!( f, @@ -265,7 +265,7 @@ fn apply_operation( port: *port, } })?; - ensure_rest_endpoint(endpoint, host, *port)?; + ensure_method_path_endpoint(endpoint, host, *port)?; if 
endpoint.access.is_empty() && endpoint.rules.is_empty() { return Err(PolicyMergeError::EndpointHasNoAllowBase { host: host.clone(), @@ -281,7 +281,7 @@ fn apply_operation( port: *port, } })?; - ensure_rest_endpoint(endpoint, host, *port)?; + ensure_method_path_endpoint(endpoint, host, *port)?; expand_existing_access(endpoint, host, *port, warnings)?; append_unique_l7_rules(&mut endpoint.rules, rules); } @@ -462,6 +462,9 @@ fn merge_endpoint( append_unique_deny_rules(&mut existing.deny_rules, &incoming.deny_rules); append_unique_strings(&mut existing.allowed_ips, &incoming.allowed_ips); + existing.allow_encoded_slash |= incoming.allow_encoded_slash; + existing.websocket_credential_rewrite |= incoming.websocket_credential_rewrite; + existing.request_body_credential_rewrite |= incoming.request_body_credential_rewrite; normalize_endpoint(existing); Ok(()) } @@ -568,7 +571,7 @@ fn endpoint_matches_host_port(endpoint: &NetworkEndpoint, host: &str, port: u32) endpoint.host.eq_ignore_ascii_case(host) && canonical_ports(endpoint).contains(&port) } -fn ensure_rest_endpoint( +fn ensure_method_path_endpoint( endpoint: &NetworkEndpoint, host: &str, port: u32, @@ -579,7 +582,7 @@ fn ensure_rest_endpoint( port, }); } - if endpoint.protocol != "rest" { + if !matches!(endpoint.protocol.as_str(), "rest" | "websocket") { return Err(PolicyMergeError::UnsupportedEndpointProtocol { host: host.to_string(), port, @@ -600,12 +603,13 @@ fn expand_existing_access( } let access = endpoint.access.clone(); - let expanded = - expand_access_preset(&access).ok_or_else(|| PolicyMergeError::UnsupportedAccessPreset { + let expanded = expand_access_preset(&endpoint.protocol, &access).ok_or_else(|| { + PolicyMergeError::UnsupportedAccessPreset { host: host.to_string(), port, access: access.clone(), - })?; + } + })?; endpoint.access.clear(); append_unique_l7_rules(&mut endpoint.rules, &expanded); warnings.push(PolicyMergeWarning::ExpandedAccessPreset { @@ -616,11 +620,13 @@ fn expand_existing_access( 
Ok(()) } -fn expand_access_preset(access: &str) -> Option> { - let methods = match access { - "read-only" => vec!["GET", "HEAD", "OPTIONS"], - "read-write" => vec!["GET", "HEAD", "OPTIONS", "POST", "PUT", "PATCH"], - "full" => vec!["*"], +fn expand_access_preset(protocol: &str, access: &str) -> Option> { + let methods = match (protocol, access) { + (_, "full") => vec!["*"], + ("websocket", "read-only") => vec!["GET"], + ("websocket", "read-write") => vec!["GET", "WEBSOCKET_TEXT"], + (_, "read-only") => vec!["GET", "HEAD", "OPTIONS"], + (_, "read-write") => vec!["GET", "HEAD", "OPTIONS", "POST", "PUT", "PATCH"], _ => return None, }; @@ -870,6 +876,96 @@ mod tests { assert_eq!(rule.binaries.len(), 2); } + #[test] + fn add_rule_merges_websocket_credential_rewrite_flag() { + let mut policy = restrictive_default_policy(); + policy.network_policies.insert( + "existing".to_string(), + NetworkPolicyRule { + name: "existing".to_string(), + endpoints: vec![NetworkEndpoint { + host: "realtime.example.com".to_string(), + port: 443, + ports: vec![443], + protocol: "websocket".to_string(), + access: "read-write".to_string(), + ..Default::default() + }], + ..Default::default() + }, + ); + + let incoming = NetworkPolicyRule { + name: "incoming".to_string(), + endpoints: vec![NetworkEndpoint { + host: "realtime.example.com".to_string(), + port: 443, + ports: vec![443], + protocol: "websocket".to_string(), + websocket_credential_rewrite: true, + ..Default::default() + }], + ..Default::default() + }; + + let result = merge_policy( + policy, + &[PolicyMergeOp::AddRule { + rule_name: "allow_realtime_example_com_443".to_string(), + rule: incoming, + }], + ) + .expect("merge should succeed"); + + let endpoint = &result.policy.network_policies["existing"].endpoints[0]; + assert!(endpoint.websocket_credential_rewrite); + } + + #[test] + fn add_rule_merges_request_body_credential_rewrite_flag() { + let mut policy = restrictive_default_policy(); + policy.network_policies.insert( + 
"existing".to_string(), + NetworkPolicyRule { + name: "existing".to_string(), + endpoints: vec![NetworkEndpoint { + host: "slack.com".to_string(), + port: 443, + ports: vec![443], + protocol: "rest".to_string(), + access: "read-write".to_string(), + ..Default::default() + }], + ..Default::default() + }, + ); + + let incoming = NetworkPolicyRule { + name: "incoming".to_string(), + endpoints: vec![NetworkEndpoint { + host: "slack.com".to_string(), + port: 443, + ports: vec![443], + protocol: "rest".to_string(), + request_body_credential_rewrite: true, + ..Default::default() + }], + ..Default::default() + }; + + let result = merge_policy( + policy, + &[PolicyMergeOp::AddRule { + rule_name: "allow_slack_com_443".to_string(), + rule: incoming, + }], + ) + .expect("merge should succeed"); + + let endpoint = &result.policy.network_policies["existing"].endpoints[0]; + assert!(endpoint.request_body_credential_rewrite); + } + #[test] fn add_allow_expands_access_preset() { let mut policy = restrictive_default_policy(); @@ -909,7 +1005,92 @@ mod tests { } #[test] - fn add_deny_requires_rest_protocol() { + fn add_allow_expands_websocket_access_preset() { + let mut policy = restrictive_default_policy(); + policy.network_policies.insert( + "realtime".to_string(), + NetworkPolicyRule { + name: "realtime".to_string(), + endpoints: vec![NetworkEndpoint { + host: "realtime.example.com".to_string(), + port: 443, + ports: vec![443], + protocol: "websocket".to_string(), + access: "read-write".to_string(), + ..Default::default() + }], + ..Default::default() + }, + ); + + let result = merge_policy( + policy, + &[PolicyMergeOp::AddAllowRules { + host: "realtime.example.com".to_string(), + port: 443, + rules: vec![rest_rule("WEBSOCKET_TEXT", "/rooms/private/**")], + }], + ) + .expect("merge should succeed"); + + let endpoint = &result.policy.network_policies["realtime"].endpoints[0]; + assert!(endpoint.access.is_empty()); + assert_eq!(endpoint.rules.len(), 3); + 
assert!(endpoint.rules.contains(&rest_rule("GET", "**"))); + assert!(endpoint.rules.contains(&rest_rule("WEBSOCKET_TEXT", "**"))); + assert!( + endpoint + .rules + .contains(&rest_rule("WEBSOCKET_TEXT", "/rooms/private/**")) + ); + assert!(!endpoint.rules.contains(&rest_rule("POST", "**"))); + assert!(result.warnings.iter().any(|warning| matches!( + warning, + PolicyMergeWarning::ExpandedAccessPreset { access, .. } if access == "read-write" + ))); + } + + #[test] + fn add_deny_accepts_websocket_protocol() { + let mut policy = restrictive_default_policy(); + policy.network_policies.insert( + "realtime".to_string(), + NetworkPolicyRule { + name: "realtime".to_string(), + endpoints: vec![NetworkEndpoint { + host: "realtime.example.com".to_string(), + port: 443, + ports: vec![443], + protocol: "websocket".to_string(), + access: "read-write".to_string(), + ..Default::default() + }], + ..Default::default() + }, + ); + + let result = merge_policy( + policy, + &[PolicyMergeOp::AddDenyRules { + host: "realtime.example.com".to_string(), + port: 443, + deny_rules: vec![L7DenyRule { + method: "WEBSOCKET_TEXT".to_string(), + path: "/admin/**".to_string(), + ..Default::default() + }], + }], + ) + .expect("merge should succeed"); + + let endpoint = &result.policy.network_policies["realtime"].endpoints[0]; + assert_eq!(endpoint.deny_rules.len(), 1); + assert_eq!(endpoint.deny_rules[0].method, "WEBSOCKET_TEXT"); + assert_eq!(endpoint.deny_rules[0].path, "/admin/**"); + } + + #[test] + fn add_deny_rejects_unsupported_protocol() { let mut policy = restrictive_default_policy(); policy.network_policies.insert( "db".to_string(), diff --git a/crates/openshell-providers/src/profiles.rs b/crates/openshell-providers/src/profiles.rs index 8c3f247cf..588e77702 100644 --- a/crates/openshell-providers/src/profiles.rs +++ b/crates/openshell-providers/src/profiles.rs @@ -114,6 +114,10 @@ pub struct EndpointProfile { pub deny_rules: Vec, #[serde(default, skip_serializing_if = "is_false")] pub 
allow_encoded_slash: bool, + #[serde(default, skip_serializing_if = "is_false")] + pub websocket_credential_rewrite: bool, + #[serde(default, skip_serializing_if = "is_false")] + pub request_body_credential_rewrite: bool, #[serde(default, skip_serializing_if = "String::is_empty")] pub persisted_queries: String, #[serde(default, skip_serializing_if = "HashMap::is_empty")] @@ -414,6 +418,8 @@ fn endpoint_to_proto(endpoint: &EndpointProfile) -> NetworkEndpoint { ports: endpoint.ports.clone(), deny_rules: endpoint.deny_rules.iter().map(deny_rule_to_proto).collect(), allow_encoded_slash: endpoint.allow_encoded_slash, + websocket_credential_rewrite: endpoint.websocket_credential_rewrite, + request_body_credential_rewrite: endpoint.request_body_credential_rewrite, persisted_queries: endpoint.persisted_queries.clone(), graphql_persisted_queries: endpoint .graphql_persisted_queries @@ -442,6 +448,8 @@ fn endpoint_from_proto(endpoint: &NetworkEndpoint) -> EndpointProfile { .map(deny_rule_from_proto) .collect(), allow_encoded_slash: endpoint.allow_encoded_slash, + websocket_credential_rewrite: endpoint.websocket_credential_rewrite, + request_body_credential_rewrite: endpoint.request_body_credential_rewrite, persisted_queries: endpoint.persisted_queries.clone(), graphql_persisted_queries: endpoint .graphql_persisted_queries diff --git a/crates/openshell-sandbox/Cargo.toml b/crates/openshell-sandbox/Cargo.toml index 4e07521ce..29919ede4 100644 --- a/crates/openshell-sandbox/Cargo.toml +++ b/crates/openshell-sandbox/Cargo.toml @@ -58,6 +58,8 @@ uuid = { workspace = true } # Encoding base64 = { workspace = true } +flate2 = "1" +sha1 = "0.10" # IP network / CIDR parsing ipnet = "2" diff --git a/crates/openshell-sandbox/data/sandbox-policy.rego b/crates/openshell-sandbox/data/sandbox-policy.rego index 9fa820627..a8e4affce 100644 --- a/crates/openshell-sandbox/data/sandbox-policy.rego +++ b/crates/openshell-sandbox/data/sandbox-policy.rego @@ -260,6 +260,16 @@ 
request_denied_for_endpoint(request, endpoint) if { not graphql_request_allowed(request, endpoint) } +# The same authority applies when a WebSocket endpoint opts into GraphQL +# operation policy. Once the relay classifies a client text message as a +# GraphQL-over-WebSocket operation, generic WEBSOCKET_TEXT rules must not bypass +# operation_type / operation_name / fields policy. +request_denied_for_endpoint(request, endpoint) if { + endpoint.protocol == "websocket" + is_object(request.graphql) + not graphql_request_allowed(request, endpoint) +} + # Deny query matching: fail-closed semantics. # If no query rules on the deny rule, match unconditionally (any query params). # If query rules present, trigger the deny if ANY value for a configured key diff --git a/crates/openshell-sandbox/src/l7/graphql.rs b/crates/openshell-sandbox/src/l7/graphql.rs index db91ecb45..5d0746d01 100644 --- a/crates/openshell-sandbox/src/l7/graphql.rs +++ b/crates/openshell-sandbox/src/l7/graphql.rs @@ -78,6 +78,19 @@ pub fn classify_request(request: &L7Request, body: &[u8]) -> GraphqlRequestInfo } } +pub fn classify_json_envelope_value(value: &Value) -> GraphqlRequestInfo { + match classify_json_envelope(value) { + Ok(operations) => GraphqlRequestInfo { + operations, + error: None, + }, + Err(err) => GraphqlRequestInfo { + operations: Vec::new(), + error: Some(err), + }, + } +} + fn classify_request_inner( request: &L7Request, body: &[u8], diff --git a/crates/openshell-sandbox/src/l7/mod.rs b/crates/openshell-sandbox/src/l7/mod.rs index 5301ac4d5..09278b4f8 100644 --- a/crates/openshell-sandbox/src/l7/mod.rs +++ b/crates/openshell-sandbox/src/l7/mod.rs @@ -15,11 +15,13 @@ pub mod provider; pub mod relay; pub mod rest; pub mod tls; +pub(crate) mod websocket; /// Application-layer protocol for L7 inspection. 
#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum L7Protocol { Rest, + Websocket, Graphql, Sql, } @@ -28,6 +30,7 @@ impl L7Protocol { pub fn parse(s: &str) -> Option { match s.to_ascii_lowercase().as_str() { "rest" => Some(Self::Rest), + "websocket" => Some(Self::Websocket), "graphql" => Some(Self::Graphql), "sql" => Some(Self::Sql), _ => None, @@ -58,6 +61,10 @@ pub enum EnforcementMode { } /// L7 configuration for an endpoint, extracted from policy data. +#[allow( + clippy::struct_excessive_bools, + reason = "Endpoint config mirrors independent policy schema toggles." +)] #[derive(Debug, Clone)] pub struct L7EndpointConfig { pub protocol: L7Protocol, @@ -72,6 +79,15 @@ pub struct L7EndpointConfig { /// rather than rejected at the parser. Needed by upstreams like GitLab /// that embed `%2F` in namespaced project paths. Defaults to false. pub allow_encoded_slash: bool, + /// Opt-in rewrite of credential placeholders in client-to-server + /// WebSocket text messages after an allowed HTTP 101 upgrade. + pub websocket_credential_rewrite: bool, + /// Opt-in rewrite of credential placeholders in supported textual REST + /// request bodies before forwarding upstream. + pub request_body_credential_rewrite: bool, + /// When true, client-to-server GraphQL-over-WebSocket operation messages + /// are classified with the same operation policy used by GraphQL-over-HTTP. + pub websocket_graphql_policy: bool, } /// Result of an L7 policy decision for a single request. 
@@ -138,6 +154,12 @@ pub fn parse_l7_config(val: ®orus::Value) -> Option { }; let allow_encoded_slash = get_object_bool(val, "allow_encoded_slash").unwrap_or(false); + let websocket_credential_rewrite = + get_object_bool(val, "websocket_credential_rewrite").unwrap_or(false); + let request_body_credential_rewrite = + get_object_bool(val, "request_body_credential_rewrite").unwrap_or(false); + let websocket_graphql_policy = + protocol == L7Protocol::Websocket && endpoint_has_graphql_policy(val); let graphql_max_body_bytes = get_object_u64(val, "graphql_max_body_bytes") .and_then(|v| usize::try_from(v).ok()) .filter(|v| *v > 0) @@ -150,6 +172,9 @@ pub fn parse_l7_config(val: ®orus::Value) -> Option { enforcement, graphql_max_body_bytes, allow_encoded_slash, + websocket_credential_rewrite, + request_body_credential_rewrite, + websocket_graphql_policy, }) } @@ -231,6 +256,60 @@ fn get_object_str(val: ®orus::Value, key: &str) -> Option { } } +fn endpoint_has_graphql_policy(val: ®orus::Value) -> bool { + has_non_empty_object_field(val, "graphql_persisted_queries") + || has_graphql_persisted_query_mode(val) + || rules_have_graphql_policy(val, "rules", true) + || rules_have_graphql_policy(val, "deny_rules", false) +} + +fn rules_have_graphql_policy(val: ®orus::Value, key: &str, allow_wrapped: bool) -> bool { + let Some(regorus::Value::Array(rules)) = get_object_value(val, key) else { + return false; + }; + rules.iter().any(|rule| { + let rule = if allow_wrapped { + get_object_value(rule, "allow").unwrap_or(rule) + } else { + rule + }; + has_graphql_rule_fields(rule) + }) +} + +fn has_graphql_rule_fields(val: ®orus::Value) -> bool { + has_non_empty_string_field(val, "operation_type") + || has_non_empty_string_field(val, "operation_name") + || has_non_empty_array_field(val, "fields") +} + +fn has_non_empty_string_field(val: ®orus::Value, key: &str) -> bool { + matches!(get_object_value(val, key), Some(regorus::Value::String(s)) if !s.is_empty()) +} + +fn 
has_non_empty_array_field(val: ®orus::Value, key: &str) -> bool { + matches!(get_object_value(val, key), Some(regorus::Value::Array(values)) if !values.is_empty()) +} + +fn has_non_empty_object_field(val: ®orus::Value, key: &str) -> bool { + matches!(get_object_value(val, key), Some(regorus::Value::Object(values)) if !values.is_empty()) +} + +fn has_graphql_persisted_query_mode(val: ®orus::Value) -> bool { + matches!( + get_object_value(val, "persisted_queries"), + Some(regorus::Value::String(mode)) if !mode.is_empty() && mode.as_ref() != "deny" + ) +} + +fn get_object_value<'a>(val: &'a regorus::Value, key: &str) -> Option<&'a regorus::Value> { + let key_val = regorus::Value::String(key.into()); + match val { + regorus::Value::Object(map) => map.get(&key_val), + _ => None, + } +} + /// Check a glob pattern for obvious syntax issues. /// /// Returns `Some(warning_message)` if the pattern looks malformed. @@ -353,6 +432,45 @@ fn validate_graphql_rule( validate_graphql_fields(errors, warnings, loc, rule.get("fields")); } +fn json_rule_has_graphql_fields(rule: &serde_json::Value) -> bool { + rule.get("operation_type") + .and_then(|v| v.as_str()) + .is_some_and(|v| !v.is_empty()) + || rule + .get("operation_name") + .and_then(|v| v.as_str()) + .is_some_and(|v| !v.is_empty()) + || rule.get("fields").is_some() +} + +fn json_rule_has_transport_fields(rule: &serde_json::Value) -> bool { + rule.get("method").is_some() || rule.get("path").is_some() || rule.get("query").is_some() +} + +fn json_endpoint_has_graphql_policy(ep: &serde_json::Value) -> bool { + ep.get("graphql_persisted_queries") + .and_then(|v| v.as_object()) + .is_some_and(|v| !v.is_empty()) + || ep + .get("persisted_queries") + .and_then(|v| v.as_str()) + .is_some_and(|v| !v.is_empty() && v != "deny") + || ep + .get("rules") + .and_then(|v| v.as_array()) + .is_some_and(|rules| { + rules.iter().any(|rule| { + rule.get("allow") + .or(Some(rule)) + .is_some_and(json_rule_has_graphql_fields) + }) + }) + || ep + 
.get("deny_rules") + .and_then(|v| v.as_array()) + .is_some_and(|rules| rules.iter().any(json_rule_has_graphql_fields)) +} + /// Validate L7 policy configuration in the loaded OPA data. /// /// Returns a list of errors and warnings. Errors should prevent sandbox startup; @@ -382,6 +500,8 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< .get("rules") .and_then(|v| v.as_array()) .is_some_and(|a| !a.is_empty()); + let websocket_has_graphql_policy = + protocol == "websocket" && json_endpoint_has_graphql_policy(ep); let host = ep.get("host").and_then(|v| v.as_str()).unwrap_or(""); let endpoint_path = ep.get("path").and_then(|v| v.as_str()).unwrap_or(""); @@ -462,7 +582,7 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< if !protocol.is_empty() && L7Protocol::parse(protocol).is_none() { errors.push(format!( - "{loc}: unknown protocol '{protocol}' (expected rest, graphql, or sql)" + "{loc}: unknown protocol '{protocol}' (expected rest, websocket, graphql, or sql)" )); } @@ -489,12 +609,36 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< } if protocol != "graphql" + && protocol != "websocket" && (ep.get("persisted_queries").is_some() || ep.get("graphql_persisted_queries").is_some() || ep.get("graphql_max_body_bytes").is_some()) { warnings.push(format!( - "{loc}: GraphQL-specific endpoint fields are ignored unless protocol is graphql" + "{loc}: GraphQL-specific endpoint fields are ignored unless protocol is graphql or websocket" + )); + } + + if ep + .get("websocket_credential_rewrite") + .and_then(serde_json::Value::as_bool) + .unwrap_or(false) + && protocol != "rest" + && protocol != "websocket" + { + warnings.push(format!( + "{loc}: websocket_credential_rewrite is ignored unless protocol is rest or websocket" + )); + } + + if ep + .get("request_body_credential_rewrite") + .and_then(serde_json::Value::as_bool) + .unwrap_or(false) + && protocol != "rest" + { + warnings.push(format!( + "{loc}: 
request_body_credential_rewrite is ignored unless protocol is rest" )); } @@ -574,14 +718,13 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< // Validate method if let Some(method) = deny_rule.get("method").and_then(|m| m.as_str()) && !method.is_empty() - && protocol == "rest" + && (protocol == "rest" || protocol == "websocket") { - let valid_methods = [ - "GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "*", - ]; + let valid_methods = valid_methods_for_protocol(protocol); if !valid_methods.contains(&method.to_ascii_uppercase().as_str()) { warnings.push(format!( - "{deny_loc}: Unknown HTTP method '{method}'. Standard methods: GET, HEAD, POST, PUT, DELETE, PATCH, OPTIONS." + "{deny_loc}: Unknown HTTP/WebSocket method '{method}'. Standard methods: {}." + , valid_methods.join(", ") )); } } @@ -701,7 +844,17 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< .push(format!("{deny_loc}: command is for SQL protocol, not REST")); } - if protocol == "graphql" { + let deny_has_graphql = json_rule_has_graphql_fields(deny_rule); + if protocol == "websocket" + && deny_has_graphql + && json_rule_has_transport_fields(deny_rule) + { + errors.push(format!( + "{deny_loc}: WebSocket GraphQL deny rules must not combine method/path/query with operation_type/operation_name/fields" + )); + } + + if protocol == "graphql" || (protocol == "websocket" && deny_has_graphql) { validate_graphql_rule( &mut errors, &mut warnings, @@ -709,12 +862,9 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< deny_rule, true, ); - } else if deny_rule.get("operation_type").is_some() - || deny_rule.get("operation_name").is_some() - || deny_rule.get("fields").is_some() - { + } else if deny_has_graphql { warnings.push(format!( - "{deny_loc}: GraphQL rule fields are ignored unless protocol is graphql" + "{deny_loc}: GraphQL rule fields are ignored unless protocol is graphql or websocket" )); } } @@ -733,10 +883,8 @@ pub fn 
validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< } // Validate HTTP methods in rules - if has_rules && protocol == "rest" { - let valid_methods = [ - "GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "*", - ]; + if has_rules && (protocol == "rest" || protocol == "websocket") { + let valid_methods = valid_methods_for_protocol(protocol); if let Some(rules) = ep.get("rules").and_then(|v| v.as_array()) { for (rule_idx, rule) in rules.iter().enumerate() { if let Some(method) = rule @@ -747,7 +895,8 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< && !valid_methods.contains(&method.to_ascii_uppercase().as_str()) { warnings.push(format!( - "{loc}: Unknown HTTP method '{method}'. Standard methods: GET, HEAD, POST, PUT, DELETE, PATCH, OPTIONS." + "{loc}: Unknown HTTP/WebSocket method '{method}'. Standard methods: {}." + , valid_methods.join(", ") )); } @@ -858,14 +1007,36 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< } } - if has_rules - && protocol == "graphql" - && let Some(rules) = ep.get("rules").and_then(|v| v.as_array()) - { + if has_rules && let Some(rules) = ep.get("rules").and_then(|v| v.as_array()) { for (rule_idx, rule) in rules.iter().enumerate() { let allow = rule.get("allow").unwrap_or(rule); let rule_loc = format!("{loc}.rules[{rule_idx}].allow"); - validate_graphql_rule(&mut errors, &mut warnings, &rule_loc, allow, true); + let allow_has_graphql = json_rule_has_graphql_fields(allow); + if websocket_has_graphql_policy + && allow + .get("method") + .and_then(|m| m.as_str()) + .is_some_and(|method| method.eq_ignore_ascii_case("WEBSOCKET_TEXT")) + { + errors.push(format!( + "{rule_loc}: WebSocket endpoints with GraphQL operation policy must use operation_type/operation_name/fields rules for client messages instead of WEBSOCKET_TEXT" + )); + } + if protocol == "websocket" + && allow_has_graphql + && json_rule_has_transport_fields(allow) + { + errors.push(format!( + "{rule_loc}: 
WebSocket GraphQL allow rules must not combine method/path/query with operation_type/operation_name/fields" + )); + } + if protocol == "graphql" || (protocol == "websocket" && allow_has_graphql) { + validate_graphql_rule(&mut errors, &mut warnings, &rule_loc, allow, true); + } else if allow_has_graphql { + warnings.push(format!( + "{rule_loc}: GraphQL rule fields are ignored unless protocol is graphql or websocket" + )); + } } } } @@ -921,6 +1092,13 @@ pub fn expand_access_presets(data: &mut serde_json::Value) { "full" => vec![graphql_rule_json("*")], _ => continue, } + } else if protocol == "websocket" { + match access.as_str() { + "read-only" => vec![rule_json("GET", "**")], + "read-write" => vec![rule_json("GET", "**"), rule_json("WEBSOCKET_TEXT", "**")], + "full" => vec![rule_json("*", "**")], + _ => continue, + } } else { match access.as_str() { "read-only" => vec![ @@ -957,6 +1135,15 @@ fn rule_json(method: &str, path: &str) -> serde_json::Value { }) } +fn valid_methods_for_protocol(protocol: &str) -> &'static [&'static str] { + match protocol { + "websocket" => &["GET", "WEBSOCKET_TEXT", "*"], + _ => &[ + "GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "*", + ], + } +} + fn graphql_rule_json(operation_type: &str) -> serde_json::Value { serde_json::json!({ "allow": { @@ -994,6 +1181,16 @@ mod tests { assert_eq!(config.enforcement, EnforcementMode::Audit); } + #[test] + fn parse_l7_config_websocket_protocol() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "websocket", "host": "gateway.example.com", "port": 443}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert_eq!(config.protocol, L7Protocol::Websocket); + } + #[test] fn parse_l7_config_skip() { let val = regorus::Value::from_json_str( @@ -1031,6 +1228,242 @@ mod tests { assert!(config.allow_encoded_slash); } + #[test] + fn parse_l7_config_websocket_credential_rewrite_defaults_false() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "rest", 
"host": "gateway.example.com", "port": 443}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(!config.websocket_credential_rewrite); + } + + #[test] + fn parse_l7_config_websocket_credential_rewrite_opt_in() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "rest", "host": "gateway.example.com", "port": 443, "websocket_credential_rewrite": true}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(config.websocket_credential_rewrite); + } + + #[test] + fn parse_l7_config_request_body_credential_rewrite_defaults_false() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "rest", "host": "slack.com", "port": 443}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(!config.request_body_credential_rewrite); + } + + #[test] + fn parse_l7_config_request_body_credential_rewrite_opt_in() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "rest", "host": "slack.com", "port": 443, "request_body_credential_rewrite": true}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(config.request_body_credential_rewrite); + } + + #[test] + fn parse_l7_config_websocket_graphql_policy_defaults_false() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "websocket", "host": "gateway.example.com", "port": 443, "rules": [{"allow": {"method": "GET", "path": "/graphql"}}, {"allow": {"method": "WEBSOCKET_TEXT", "path": "/graphql"}}]}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(!config.websocket_graphql_policy); + } + + #[test] + fn parse_l7_config_websocket_graphql_policy_detects_operation_rules() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "websocket", "host": "gateway.example.com", "port": 443, "rules": [{"allow": {"method": "GET", "path": "/graphql"}}, {"allow": {"operation_type": "subscription", "fields": ["messageAdded"]}}]}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); 
+ assert!(config.websocket_graphql_policy); + } + + #[test] + fn validate_websocket_credential_rewrite_warns_unless_rest_or_websocket() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "websocket_credential_rewrite": true + }], + "binaries": [] + } + } + }); + let (_errors, warnings) = validate_l7_policies(&data); + assert!( + warnings + .iter() + .any(|w| w.contains("websocket_credential_rewrite is ignored")), + "expected websocket_credential_rewrite warning: {warnings:?}" + ); + } + + #[test] + fn validate_request_body_credential_rewrite_warns_unless_rest() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "protocol": "websocket", + "request_body_credential_rewrite": true + }], + "binaries": [] + } + } + }); + let (_errors, warnings) = validate_l7_policies(&data); + assert!( + warnings + .iter() + .any(|w| w.contains("request_body_credential_rewrite is ignored")), + "expected request_body_credential_rewrite warning: {warnings:?}" + ); + } + + #[test] + fn expand_websocket_read_write_access_includes_text_messages() { + let mut data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "protocol": "websocket", + "access": "read-write" + }], + "binaries": [] + } + } + }); + + expand_access_presets(&mut data); + let rules = data["network_policies"]["test"]["endpoints"][0]["rules"] + .as_array() + .unwrap(); + let methods: Vec<&str> = rules + .iter() + .map(|r| r["allow"]["method"].as_str().unwrap()) + .collect(); + assert!(methods.contains(&"GET")); + assert!(methods.contains(&"WEBSOCKET_TEXT")); + } + + #[test] + fn validate_websocket_accepts_graphql_operation_rules() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "protocol": 
"websocket", + "rules": [ + {"allow": {"method": "GET", "path": "/graphql"}}, + {"allow": {"operation_type": "subscription", "fields": ["messageAdded"]}} + ] + }], + "binaries": [] + } + } + }); + let (errors, warnings) = validate_l7_policies(&data); + assert!(errors.is_empty(), "expected no errors: {errors:?}"); + assert!(warnings.is_empty(), "expected no warnings: {warnings:?}"); + } + + #[test] + fn validate_websocket_graphql_rule_requires_operation_type() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "protocol": "websocket", + "rules": [ + {"allow": {"method": "GET", "path": "/graphql"}}, + {"allow": {"fields": ["messageAdded"]}} + ] + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors.iter().any(|e| e.contains("operation_type")), + "expected missing operation_type error: {errors:?}" + ); + } + + #[test] + fn validate_websocket_graphql_rule_rejects_mixed_transport_fields() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "protocol": "websocket", + "rules": [ + {"allow": {"method": "GET", "path": "/graphql"}}, + {"allow": {"method": "WEBSOCKET_TEXT", "path": "/graphql", "operation_type": "subscription"}} + ] + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors.iter().any(|e| e.contains("must not combine")), + "expected mixed-field error: {errors:?}" + ); + } + + #[test] + fn validate_websocket_graphql_policy_rejects_raw_text_message_rule() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "protocol": "websocket", + "rules": [ + {"allow": {"method": "GET", "path": "/graphql"}}, + {"allow": {"method": "WEBSOCKET_TEXT", "path": "/graphql"}}, + {"allow": {"operation_type": 
"query"}} + ] + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors + .iter() + .any(|e| e.contains("instead of WEBSOCKET_TEXT")), + "expected raw WEBSOCKET_TEXT rejection: {errors:?}" + ); + } + #[test] fn validate_rules_and_access_mutual_exclusion() { let data = serde_json::json!({ diff --git a/crates/openshell-sandbox/src/l7/provider.rs b/crates/openshell-sandbox/src/l7/provider.rs index 7516aa85c..864d94ad2 100644 --- a/crates/openshell-sandbox/src/l7/provider.rs +++ b/crates/openshell-sandbox/src/l7/provider.rs @@ -27,7 +27,10 @@ pub enum RelayOutcome { /// Contains any overflow bytes read from upstream past the 101 response /// headers that belong to the upgraded protocol. The 101 headers /// themselves have already been forwarded to the client. - Upgraded { overflow: Vec }, + Upgraded { + overflow: Vec, + websocket_permessage_deflate: bool, + }, } /// Body framing for HTTP requests/responses. diff --git a/crates/openshell-sandbox/src/l7/relay.rs b/crates/openshell-sandbox/src/l7/relay.rs index f099c3558..971b2e8e5 100644 --- a/crates/openshell-sandbox/src/l7/relay.rs +++ b/crates/openshell-sandbox/src/l7/relay.rs @@ -8,6 +8,7 @@ //! and either forwards or denies the request. 
use crate::l7::provider::{L7Provider, RelayOutcome}; +use crate::l7::rest::WebSocketExtensionMode; use crate::l7::{EnforcementMode, L7EndpointConfig, L7Protocol, L7RequestInfo}; use crate::opa::{PolicyGenerationGuard, TunnelPolicyEngine}; use crate::secrets::{self, SecretResolver}; @@ -38,6 +39,44 @@ pub struct L7EvalContext { pub(crate) secret_resolver: Option>, } +#[derive(Default)] +pub(crate) struct UpgradeRelayOptions<'a> { + pub(crate) websocket_request: bool, + pub(crate) websocket: WebSocketUpgradeBehavior, + pub(crate) secret_resolver: Option>, + pub(crate) engine: Option<&'a TunnelPolicyEngine>, + pub(crate) ctx: Option<&'a L7EvalContext>, + pub(crate) enforcement: EnforcementMode, + pub(crate) target: String, + pub(crate) query_params: std::collections::HashMap>, + pub(crate) policy_name: String, +} + +#[derive(Default)] +pub(crate) struct WebSocketUpgradeBehavior { + pub(crate) credential_rewrite: bool, + pub(crate) message_policy: WebSocketMessagePolicy, + pub(crate) permessage_deflate: bool, +} + +#[derive(Clone, Copy, Default, PartialEq, Eq)] +pub(crate) enum WebSocketMessagePolicy { + #[default] + None, + Transport, + Graphql, +} + +impl WebSocketMessagePolicy { + fn inspects_messages(self) -> bool { + self != Self::None + } + + fn is_graphql(self) -> bool { + self == Self::Graphql + } +} + #[derive(Debug, Clone, Copy)] enum ParseRejectionMode { L7Endpoint, @@ -101,7 +140,9 @@ where U: AsyncRead + AsyncWrite + Unpin + Send, { match config.protocol { - L7Protocol::Rest => relay_rest(config, &engine, client, upstream, ctx).await, + L7Protocol::Rest | L7Protocol::Websocket => { + relay_rest(config, &engine, client, upstream, ctx).await + } L7Protocol::Graphql => relay_graphql(config, &engine, client, upstream, ctx).await, L7Protocol::Sql => { if close_if_stale(engine.generation_guard(), ctx) { @@ -242,6 +283,24 @@ where query_params: req.query_params.clone(), graphql: graphql_info.clone(), }; + let websocket_request = 
crate::l7::rest::request_is_websocket_upgrade(&req.raw_header); + if config.protocol == L7Protocol::Websocket && !websocket_request { + crate::l7::rest::RestProvider::default() + .deny_with_redacted_target( + &req, + &ctx.policy_name, + "websocket endpoint requires a valid WebSocket upgrade request", + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } let parse_error_reason = graphql_info .as_ref() @@ -264,10 +323,10 @@ where (false, EnforcementMode::Audit) => "audit", (false, EnforcementMode::Enforce) => "deny", }; - let engine_type = if config.protocol == L7Protocol::Graphql { - "l7-graphql" - } else { - "l7" + let engine_type = match config.protocol { + L7Protocol::Graphql => "l7-graphql", + L7Protocol::Websocket => "l7-websocket", + L7Protocol::Rest | L7Protocol::Sql => "l7", }; emit_l7_request_log( ctx, @@ -282,19 +341,39 @@ where let _ = &eval_target; if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { - let outcome = crate::l7::rest::relay_http_request_with_resolver_guarded( + let outcome = crate::l7::rest::relay_http_request_with_options_guarded( &req, client, upstream, - ctx.secret_resolver.as_deref(), - Some(engine.generation_guard()), + crate::l7::rest::RelayRequestOptions { + resolver: ctx.secret_resolver.as_deref(), + generation_guard: Some(engine.generation_guard()), + websocket_extensions: websocket_extension_mode(config), + request_body_credential_rewrite: config.protocol == L7Protocol::Rest + && config.request_body_credential_rewrite, + }, ) .await?; match outcome { RelayOutcome::Reusable => {} RelayOutcome::Consumed => return Ok(()), - RelayOutcome::Upgraded { overflow } => { - return handle_upgrade(client, upstream, overflow, &ctx.host, ctx.port).await; + RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + } => { + let mut options = upgrade_options( 
+ config, + ctx, + websocket_request, + &redacted_target, + &req.query_params, + Some(&engine), + ); + options.websocket.permessage_deflate = websocket_permessage_deflate; + return handle_upgrade( + client, upstream, overflow, &ctx.host, ctx.port, options, + ) + .await; } } } else { @@ -374,20 +453,29 @@ fn emit_l7_request_log( /// Handle an upgraded connection (101 Switching Protocols). /// /// Forwards any overflow bytes from the upgrade response to the client, then -/// switches to raw bidirectional TCP copy for the upgraded protocol (WebSocket, -/// HTTP/2, etc.). L7 policy enforcement does not apply after the upgrade — -/// the initial HTTP request was already evaluated. +/// either switches to a parsed WebSocket relay for opted-in message policy / +/// credential rewriting or to raw bidirectional TCP copy for other upgrades. pub(crate) async fn handle_upgrade( client: &mut C, upstream: &mut U, overflow: Vec, host: &str, port: u16, + options: UpgradeRelayOptions<'_>, ) -> Result<()> where C: AsyncRead + AsyncWrite + Unpin + Send, U: AsyncRead + AsyncWrite + Unpin + Send, { + let use_websocket_relay = options.websocket_request + && (options.websocket.message_policy.inspects_messages() + || options.websocket.permessage_deflate + || (options.websocket.credential_rewrite && options.secret_resolver.is_some())); + let relay_mode = if use_websocket_relay { + "websocket parsed relay" + } else { + "raw bidirectional relay (L7 enforcement no longer active)" + }; ocsf_emit!( NetworkActivityBuilder::new(crate::ocsf_ctx()) .activity(ActivityId::Other) @@ -395,12 +483,56 @@ where .severity(SeverityId::Informational) .dst_endpoint(Endpoint::from_domain(host, port)) .message(format!( - "101 Switching Protocols — raw bidirectional relay (L7 enforcement no longer active) \ - [host:{host} port:{port} overflow_bytes:{}]", + "101 Switching Protocols — {relay_mode} [host:{host} port:{port} overflow_bytes:{}]", overflow.len() )) .build() ); + if use_websocket_relay { + let resolver 
= if options.websocket.credential_rewrite { + options.secret_resolver.as_deref() + } else { + None + }; + let inspector = if options.websocket.message_policy.inspects_messages() { + match (options.engine, options.ctx) { + (Some(engine), Some(ctx)) => Some(crate::l7::websocket::InspectionOptions { + engine, + ctx, + enforcement: options.enforcement, + target: options.target.clone(), + query_params: options.query_params.clone(), + graphql_policy: options.websocket.message_policy.is_graphql(), + }), + _ => { + return Err(miette!( + "websocket message inspection missing policy context" + )); + } + } + } else { + None + }; + let compression = if options.websocket.permessage_deflate { + crate::l7::websocket::WebSocketCompression::PermessageDeflate + } else { + crate::l7::websocket::WebSocketCompression::None + }; + return crate::l7::websocket::relay_with_options( + client, + upstream, + overflow, + host, + port, + crate::l7::websocket::RelayOptions { + policy_name: &options.policy_name, + resolver, + inspector, + compression, + }, + ) + .await; + } if !overflow.is_empty() { client.write_all(&overflow).await.into_diagnostic()?; client.flush().await.into_diagnostic()?; @@ -411,6 +543,57 @@ where Ok(()) } +fn upgrade_options<'a>( + config: &L7EndpointConfig, + ctx: &'a L7EvalContext, + websocket_request: bool, + target: &str, + query_params: &std::collections::HashMap>, + engine: Option<&'a TunnelPolicyEngine>, +) -> UpgradeRelayOptions<'a> { + let websocket_credential_rewrite = + matches!(config.protocol, L7Protocol::Rest | L7Protocol::Websocket) + && config.websocket_credential_rewrite; + let websocket_message_policy = if config.protocol == L7Protocol::Websocket { + if config.websocket_graphql_policy { + WebSocketMessagePolicy::Graphql + } else { + WebSocketMessagePolicy::Transport + } + } else { + WebSocketMessagePolicy::None + }; + UpgradeRelayOptions { + websocket_request, + websocket: WebSocketUpgradeBehavior { + credential_rewrite: websocket_credential_rewrite, + 
message_policy: websocket_message_policy, + permessage_deflate: false, + }, + secret_resolver: if websocket_credential_rewrite { + ctx.secret_resolver.clone() + } else { + None + }, + engine, + ctx: engine.map(|_| ctx), + enforcement: config.enforcement, + target: target.to_string(), + query_params: query_params.clone(), + policy_name: ctx.policy_name.clone(), + } +} + +fn websocket_extension_mode(config: &L7EndpointConfig) -> WebSocketExtensionMode { + if config.protocol == L7Protocol::Websocket + || (config.protocol == L7Protocol::Rest && config.websocket_credential_rewrite) + { + WebSocketExtensionMode::PermessageDeflate + } else { + WebSocketExtensionMode::Preserve + } +} + /// REST relay loop: parse request -> evaluate -> allow/deny -> relay response -> repeat. async fn relay_rest( config: &L7EndpointConfig, @@ -490,6 +673,24 @@ where query_params: req.query_params.clone(), graphql: None, }; + let websocket_request = crate::l7::rest::request_is_websocket_upgrade(&req.raw_header); + if config.protocol == L7Protocol::Websocket && !websocket_request { + provider + .deny_with_redacted_target( + &req, + &ctx.policy_name, + "websocket endpoint requires a valid WebSocket upgrade request", + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } // Evaluate L7 policy via Rego (using redacted target) let (allowed, reason) = evaluate_l7_request(engine, ctx, &request_info)?; @@ -558,12 +759,17 @@ where if allowed || config.enforcement == EnforcementMode::Audit { // Forward request to upstream and relay response - let outcome = crate::l7::rest::relay_http_request_with_resolver_guarded( + let outcome = crate::l7::rest::relay_http_request_with_options_guarded( &req, client, upstream, - ctx.secret_resolver.as_deref(), - Some(engine.generation_guard()), + crate::l7::rest::RelayRequestOptions { + resolver: 
ctx.secret_resolver.as_deref(), + generation_guard: Some(engine.generation_guard()), + websocket_extensions: websocket_extension_mode(config), + request_body_credential_rewrite: config.protocol == L7Protocol::Rest + && config.request_body_credential_rewrite, + }, ) .await?; match outcome { @@ -576,8 +782,23 @@ where ); return Ok(()); } - RelayOutcome::Upgraded { overflow } => { - return handle_upgrade(client, upstream, overflow, &ctx.host, ctx.port).await; + RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + } => { + let mut options = upgrade_options( + config, + ctx, + websocket_request, + &redacted_target, + &req.query_params, + Some(engine), + ); + options.websocket.permessage_deflate = websocket_permessage_deflate; + return handle_upgrade( + client, upstream, overflow, &ctx.host, ctx.port, options, + ) + .await; } } } else { @@ -787,8 +1008,21 @@ where ); return Ok(()); } - RelayOutcome::Upgraded { overflow } => { - return handle_upgrade(client, upstream, overflow, &ctx.host, ctx.port).await; + RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + } => { + let options = UpgradeRelayOptions { + websocket: WebSocketUpgradeBehavior { + permessage_deflate: websocket_permessage_deflate, + ..Default::default() + }, + ..Default::default() + }; + return handle_upgrade( + client, upstream, overflow, &ctx.host, ctx.port, options, + ) + .await; } } } else { @@ -1016,20 +1250,31 @@ where // Forward request with credential rewriting and relay the response. // relay_http_request_with_resolver handles both directions: it sends // the request upstream and reads the response back to the client. 
- let outcome = crate::l7::rest::relay_http_request_with_resolver_guarded( + let outcome = crate::l7::rest::relay_http_request_with_options_guarded( &req, client, upstream, - resolver, - Some(generation_guard), + crate::l7::rest::RelayRequestOptions { + resolver, + generation_guard: Some(generation_guard), + ..Default::default() + }, ) .await?; match outcome { RelayOutcome::Reusable => {} // continue loop RelayOutcome::Consumed => break, - RelayOutcome::Upgraded { overflow } => { - return handle_upgrade(client, upstream, overflow, &ctx.host, ctx.port).await; + RelayOutcome::Upgraded { overflow, .. } => { + return handle_upgrade( + client, + upstream, + overflow, + &ctx.host, + ctx.port, + UpgradeRelayOptions::default(), + ) + .await; } } } @@ -1049,7 +1294,7 @@ mod tests { use super::*; use crate::opa::{NetworkInput, OpaEngine}; use std::path::PathBuf; - use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; const TEST_POLICY: &str = include_str!("../../data/sandbox-policy.rego"); @@ -1086,6 +1331,436 @@ mod tests { ); } + #[test] + fn websocket_text_policy_requires_explicit_message_rule() { + let data = r#" +network_policies: + ws_api: + name: ws_api + endpoints: + - host: gateway.example.test + port: 443 + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/ws" + binaries: + - { path: /usr/bin/node } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "gateway.example.test".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let generation = engine + .evaluate_network_action_with_generation(&input) + .unwrap() + .1; + let tunnel_engine = engine.clone_engine_for_tunnel(generation).unwrap(); + let ctx = L7EvalContext { + host: "gateway.example.test".into(), + port: 443, + policy_name: "ws_api".into(), + binary_path: 
"/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + }; + let request = L7RequestInfo { + action: "WEBSOCKET_TEXT".into(), + target: "/ws".into(), + query_params: std::collections::HashMap::new(), + graphql: None, + }; + + let (allowed, reason) = evaluate_l7_request(&tunnel_engine, &ctx, &request).unwrap(); + + assert!(!allowed); + assert!(reason.contains("WEBSOCKET_TEXT /ws not permitted")); + } + + #[tokio::test] + async fn route_selected_websocket_upgrade_rejects_invalid_accept_without_forwarding_101() { + let data = r#" +network_policies: + route_api: + name: route_api + endpoints: + - host: gateway.example.test + port: 443 + protocol: rest + enforcement: enforce + rules: + - allow: + method: GET + path: "/ws" + binaries: + - { path: /usr/bin/node } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let tunnel_engine = engine + .clone_engine_for_tunnel(engine.current_generation()) + .unwrap(); + let configs = vec![L7EndpointConfig { + protocol: L7Protocol::Rest, + path: "/ws".into(), + tls: crate::l7::TlsMode::Auto, + enforcement: EnforcementMode::Enforce, + graphql_max_body_bytes: 0, + allow_encoded_slash: false, + websocket_credential_rewrite: true, + request_body_credential_rewrite: false, + websocket_graphql_policy: false, + }]; + let ctx = L7EvalContext { + host: "gateway.example.test".into(), + port: 443, + policy_name: "route_api".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + }; + + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_route_selection( + &configs, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + app.write_all( + b"GET /ws HTTP/1.1\r\nHost: gateway.example.test\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: 
dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n", + ) + .await + .unwrap(); + + let mut forwarded = [0u8; 1024]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut forwarded), + ) + .await + .expect("upgrade request should reach upstream") + .unwrap(); + let forwarded = String::from_utf8_lossy(&forwarded[..n]); + assert!(forwarded.contains("Upgrade: websocket\r\n")); + assert!(forwarded.contains("Connection: Upgrade\r\n")); + + upstream + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: invalid\r\n\r\n", + ) + .await + .unwrap(); + + let err = tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should fail closed on invalid accept") + .unwrap() + .expect_err("invalid accept must fail the route-selected relay"); + assert!(err.to_string().contains("Sec-WebSocket-Accept")); + + let mut response = [0u8; 1]; + let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) + .await + .expect("client side should close without 101") + .unwrap(); + assert_eq!(n, 0, "invalid response must not forward 101 headers"); + } + + #[tokio::test] + async fn route_selected_websocket_rewrites_text_credentials_after_upgrade() { + let data = r#" +network_policies: + route_api: + name: route_api + endpoints: + - host: gateway.example.test + port: 443 + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/ws" + - allow: + method: WEBSOCKET_TEXT + path: "/ws" + websocket_credential_rewrite: true + binaries: + - { path: /usr/bin/node } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let tunnel_engine = engine + .clone_engine_for_tunnel(engine.current_generation()) + .unwrap(); + let configs = vec![L7EndpointConfig { + protocol: L7Protocol::Websocket, + path: "/ws".into(), + tls: crate::l7::TlsMode::Auto, + enforcement: EnforcementMode::Enforce, + 
graphql_max_body_bytes: 0, + allow_encoded_slash: false, + websocket_credential_rewrite: true, + request_body_credential_rewrite: false, + websocket_graphql_policy: false, + }]; + let (child_env, resolver) = SecretResolver::from_provider_env( + std::iter::once(("DISCORD_BOT_TOKEN".to_string(), "real-token".to_string())).collect(), + ); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").expect("placeholder env"); + let ctx = L7EvalContext { + host: "gateway.example.test".into(), + port: 443, + policy_name: "route_api".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: resolver.map(Arc::new), + }; + + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_route_selection( + &configs, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + app.write_all( + b"GET /ws HTTP/1.1\r\nHost: gateway.example.test\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n", + ) + .await + .unwrap(); + + let mut forwarded = [0u8; 1024]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut forwarded), + ) + .await + .expect("upgrade request should reach upstream") + .unwrap(); + let forwarded = String::from_utf8_lossy(&forwarded[..n]); + assert!(forwarded.contains("Upgrade: websocket\r\n")); + + upstream + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n", + ) + .await + .unwrap(); + + let mut response = [0u8; 1024]; + let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) + .await + .expect("client should receive upgrade response") + .unwrap(); + assert!(String::from_utf8_lossy(&response[..n]).contains("101 
Switching Protocols")); + + let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); + app.write_all(&masked_text_frame(payload.as_bytes())) + .await + .unwrap(); + + let (masked, rewritten) = tokio::time::timeout( + std::time::Duration::from_secs(1), + read_text_frame(&mut upstream), + ) + .await + .expect("rewritten websocket text should reach upstream") + .unwrap(); + assert!(masked, "client-to-server frame must remain masked"); + assert_eq!(rewritten, r#"{"op":2,"d":{"token":"real-token"}}"#); + assert!(!rewritten.contains(placeholder)); + + drop(app); + drop(upstream); + let _ = tokio::time::timeout(std::time::Duration::from_secs(1), relay).await; + } + + #[tokio::test] + async fn route_selected_graphql_websocket_rewrites_connection_init_credentials_after_upgrade() { + let data = r#" +network_policies: + route_api: + name: route_api + endpoints: + - host: gateway.example.test + port: 443 + path: "/graphql" + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/graphql" + - allow: + operation_type: query + fields: [viewer] + websocket_credential_rewrite: true + binaries: + - { path: /usr/bin/node } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let tunnel_engine = engine + .clone_engine_for_tunnel(engine.current_generation()) + .unwrap(); + let configs = vec![L7EndpointConfig { + protocol: L7Protocol::Websocket, + path: "/graphql".into(), + tls: crate::l7::TlsMode::Auto, + enforcement: EnforcementMode::Enforce, + graphql_max_body_bytes: 0, + allow_encoded_slash: false, + websocket_credential_rewrite: true, + request_body_credential_rewrite: false, + websocket_graphql_policy: true, + }]; + let (child_env, resolver) = SecretResolver::from_provider_env( + std::iter::once(("T".to_string(), "real-token".to_string())).collect(), + ); + let placeholder = child_env.get("T").expect("placeholder env"); + let ctx = L7EvalContext { + host: "gateway.example.test".into(), + port: 443, + policy_name: 
"route_api".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: resolver.map(Arc::new), + }; + + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_route_selection( + &configs, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + app.write_all( + b"GET /graphql HTTP/1.1\r\nHost: gateway.example.test\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n", + ) + .await + .unwrap(); + + let mut forwarded = [0u8; 1024]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut forwarded), + ) + .await + .expect("upgrade request should reach upstream") + .unwrap(); + let forwarded = String::from_utf8_lossy(&forwarded[..n]); + assert!(forwarded.contains("GET /graphql HTTP/1.1")); + assert!(forwarded.contains("Upgrade: websocket\r\n")); + + upstream + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n", + ) + .await + .unwrap(); + + let mut response = [0u8; 1024]; + let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) + .await + .expect("client should receive upgrade response") + .unwrap(); + assert!(String::from_utf8_lossy(&response[..n]).contains("101 Switching Protocols")); + + let payload = format!( + r#"{{"type":"connection_init","payload":{{"authorization":"{placeholder}"}}}}"# + ); + app.write_all(&masked_text_frame(payload.as_bytes())) + .await + .unwrap(); + + let (masked, rewritten) = tokio::time::timeout( + std::time::Duration::from_secs(1), + read_text_frame(&mut upstream), + ) + .await + .expect("rewritten GraphQL WebSocket control message should reach upstream") + .unwrap(); + 
assert!(masked, "client-to-server frame must remain masked"); + assert_eq!( + rewritten, + r#"{"type":"connection_init","payload":{"authorization":"real-token"}}"# + ); + assert!(!rewritten.contains(placeholder)); + + drop(app); + drop(upstream); + let _ = tokio::time::timeout(std::time::Duration::from_secs(1), relay).await; + } + + fn masked_text_frame(payload: &[u8]) -> Vec { + let mask = [0x11, 0x22, 0x33, 0x44]; + assert!( + payload.len() <= 125, + "test helper only supports small frames" + ); + let payload_len = u8::try_from(payload.len()).expect("small frame length"); + let mut frame = vec![0x81, 0x80 | payload_len]; + frame.extend_from_slice(&mask); + frame.extend( + payload + .iter() + .enumerate() + .map(|(idx, byte)| byte ^ mask[idx % 4]), + ); + frame + } + + async fn read_text_frame( + reader: &mut R, + ) -> std::io::Result<(bool, String)> { + let mut header = [0u8; 2]; + reader.read_exact(&mut header).await?; + assert_eq!(header[0] & 0x0f, 0x1, "expected text frame"); + let masked = header[1] & 0x80 != 0; + let payload_len = usize::from(header[1] & 0x7f); + assert!(payload_len <= 125, "test helper only supports small frames"); + let mut mask = [0u8; 4]; + if masked { + reader.read_exact(&mut mask).await?; + } + let mut payload = vec![0u8; payload_len]; + reader.read_exact(&mut payload).await?; + if masked { + for (idx, byte) in payload.iter_mut().enumerate() { + *byte ^= mask[idx % 4]; + } + } + Ok((masked, String::from_utf8(payload).expect("text payload"))) + } + #[tokio::test] async fn l7_relay_closes_keep_alive_tunnel_after_policy_generation_change() { let initial_data = r#" diff --git a/crates/openshell-sandbox/src/l7/rest.rs b/crates/openshell-sandbox/src/l7/rest.rs index 85ae01290..ade126828 100644 --- a/crates/openshell-sandbox/src/l7/rest.rs +++ b/crates/openshell-sandbox/src/l7/rest.rs @@ -9,13 +9,18 @@ use crate::l7::provider::{BodyLength, L7Provider, L7Request, RelayOutcome}; use crate::opa::PolicyGenerationGuard; -use 
crate::secrets::rewrite_http_header_block; +use crate::secrets::{ + SecretResolver, contains_reserved_credential_marker, rewrite_http_header_block, +}; +use base64::Engine as _; use miette::{IntoDiagnostic, Result, miette}; -use std::collections::HashMap; +use sha1::{Digest, Sha1}; +use std::collections::{HashMap, HashSet}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tracing::debug; const MAX_HEADER_BYTES: usize = 16384; // 16 KiB for HTTP headers +const MAX_REWRITE_BODY_BYTES: usize = 256 * 1024; const RELAY_BUF_SIZE: usize = 8192; /// Idle timeout for `relay_until_eof`. If no data arrives within this window /// the body is considered complete. Prevents blocking on servers that keep @@ -343,7 +348,7 @@ pub(crate) async fn relay_http_request_with_resolver( req: &L7Request, client: &mut C, upstream: &mut U, - resolver: Option<&crate::secrets::SecretResolver>, + resolver: Option<&SecretResolver>, ) -> Result where C: AsyncRead + AsyncWrite + Unpin, @@ -356,9 +361,48 @@ pub(crate) async fn relay_http_request_with_resolver_guarded( req: &L7Request, client: &mut C, upstream: &mut U, - resolver: Option<&crate::secrets::SecretResolver>, + resolver: Option<&SecretResolver>, generation_guard: Option<&PolicyGenerationGuard>, ) -> Result +where + C: AsyncRead + AsyncWrite + Unpin, + U: AsyncRead + AsyncWrite + Unpin, +{ + relay_http_request_with_options_guarded( + req, + client, + upstream, + RelayRequestOptions { + resolver, + generation_guard, + websocket_extensions: WebSocketExtensionMode::Preserve, + request_body_credential_rewrite: false, + }, + ) + .await +} + +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +pub(crate) enum WebSocketExtensionMode { + #[default] + Preserve, + PermessageDeflate, +} + +#[derive(Clone, Copy, Default)] +pub(crate) struct RelayRequestOptions<'a> { + pub(crate) resolver: Option<&'a SecretResolver>, + pub(crate) generation_guard: Option<&'a PolicyGenerationGuard>, + pub(crate) websocket_extensions: 
WebSocketExtensionMode, + pub(crate) request_body_credential_rewrite: bool, +} + +pub(crate) async fn relay_http_request_with_options_guarded( + req: &L7Request, + client: &mut C, + upstream: &mut U, + options: RelayRequestOptions<'_>, +) -> Result where C: AsyncRead + AsyncWrite + Unpin, U: AsyncRead + AsyncWrite + Unpin, @@ -368,88 +412,702 @@ where .windows(4) .position(|w| w == b"\r\n\r\n") .map_or(req.raw_header.len(), |p| p + 4); + let header_str = std::str::from_utf8(&req.raw_header[..header_end]) + .map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + let client_requested_upgrade = client_requested_upgrade(header_str); + let websocket_request = if options.websocket_extensions == WebSocketExtensionMode::Preserve { + None + } else { + parse_websocket_upgrade_request(&req.raw_header[..header_end])? + }; - let rewrite_result = rewrite_http_header_block(&req.raw_header[..header_end], resolver) + let (header_bytes, expected_websocket_extension) = rewrite_websocket_extensions_for_mode( + &req.raw_header[..header_end], + options.websocket_extensions, + websocket_request.is_some(), + )?; + let websocket_response = + websocket_request + .as_ref() + .map(|request| WebSocketResponseValidation { + expected_accept: websocket_accept_for_key(&request.sec_key), + expected_extension: expected_websocket_extension.clone(), + offered_subprotocols: request.subprotocols.clone(), + }); + + let rewrite_result = rewrite_http_header_block(&header_bytes, options.resolver) .map_err(|e| miette!("credential injection failed: {e}"))?; - if let Some(guard) = generation_guard { + if let Some(guard) = options.generation_guard { guard.ensure_current()?; } - upstream - .write_all(&rewrite_result.rewritten) - .await - .into_diagnostic()?; + if options.request_body_credential_rewrite { + let body = collect_and_rewrite_request_body( + req, + client, + &rewrite_result.rewritten, + header_str, + &req.raw_header[header_end..], + options.resolver, + options.generation_guard, + ) + .await?; 
+ upstream.write_all(&body.headers).await.into_diagnostic()?; + if !body.body.is_empty() { + upstream.write_all(&body.body).await.into_diagnostic()?; + } + } else { + upstream + .write_all(&rewrite_result.rewritten) + .await + .into_diagnostic()?; - let overflow = &req.raw_header[header_end..]; - if !overflow.is_empty() { - if let Some(guard) = generation_guard { - guard.ensure_current()?; + let overflow = &req.raw_header[header_end..]; + if !overflow.is_empty() { + if let Some(guard) = options.generation_guard { + guard.ensure_current()?; + } + upstream.write_all(overflow).await.into_diagnostic()?; + } + let overflow_len = overflow.len() as u64; + + match req.body_length { + BodyLength::ContentLength(len) => { + let remaining = len.saturating_sub(overflow_len); + if remaining > 0 { + relay_fixed(client, upstream, remaining, options.generation_guard).await?; + } + } + BodyLength::Chunked => { + relay_chunked( + client, + upstream, + &req.raw_header[header_end..], + options.generation_guard, + ) + .await?; + } + BodyLength::None => {} } - upstream.write_all(overflow).await.into_diagnostic()?; } - let overflow_len = overflow.len() as u64; + upstream.flush().await.into_diagnostic()?; + + let outcome = relay_response( + &req.action, + upstream, + client, + RelayResponseOptions { + websocket_extensions: options.websocket_extensions, + websocket: websocket_response, + client_requested_upgrade, + }, + ) + .await?; + + Ok(outcome) +} + +struct PreparedRequestBody { + headers: Vec, + body: Vec, +} +async fn collect_and_rewrite_request_body( + req: &L7Request, + client: &mut C, + rewritten_headers: &[u8], + original_header_str: &str, + already_read: &[u8], + resolver: Option<&SecretResolver>, + generation_guard: Option<&PolicyGenerationGuard>, +) -> Result { match req.body_length { + BodyLength::None => { + if body_bytes_contain_reserved_marker(already_read) { + return Err(miette!( + "request body credential rewrite cannot resolve placeholders without explicit body framing" 
+ )); + } + Ok(PreparedRequestBody { + headers: rewritten_headers.to_vec(), + body: already_read.to_vec(), + }) + } BodyLength::ContentLength(len) => { - let remaining = len.saturating_sub(overflow_len); - if remaining > 0 { - relay_fixed(client, upstream, remaining, generation_guard).await?; + let len = usize::try_from(len) + .map_err(|_| miette!("request body is too large for credential rewrite"))?; + if len > MAX_REWRITE_BODY_BYTES { + return Err(miette!( + "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" + )); + } + let mut body = Vec::with_capacity(len); + let initial_len = already_read.len().min(len); + body.extend_from_slice(&already_read[..initial_len]); + let mut remaining = len.saturating_sub(initial_len); + let mut buf = [0u8; RELAY_BUF_SIZE]; + while remaining > 0 { + let to_read = remaining.min(buf.len()); + let n = client.read(&mut buf[..to_read]).await.into_diagnostic()?; + if n == 0 { + return Err(miette!( + "Connection closed with {remaining} body bytes remaining" + )); + } + if let Some(guard) = generation_guard { + guard.ensure_current()?; + } + body.extend_from_slice(&buf[..n]); + remaining -= n; } + let (headers, body) = + rewrite_buffered_body(rewritten_headers, original_header_str, body, resolver)?; + Ok(PreparedRequestBody { headers, body }) } BodyLength::Chunked => { - relay_chunked( - client, - upstream, - &req.raw_header[header_end..], - generation_guard, - ) - .await?; + let body = collect_chunked_body(client, already_read, generation_guard).await?; + if body_bytes_contain_reserved_marker(&body) { + return Err(miette!( + "request body credential rewrite does not support chunked bodies containing credential placeholders" + )); + } + Ok(PreparedRequestBody { + headers: rewritten_headers.to_vec(), + body, + }) } - BodyLength::None => {} } - upstream.flush().await.into_diagnostic()?; +} - let outcome = relay_response(&req.action, upstream, client).await?; - - // Validate that the client actually requested an 
upgrade before accepting - // a 101 from upstream. Per RFC 9110 Section 7.8, the server MUST NOT send - // 101 unless the client sent Upgrade + Connection: Upgrade headers. A - // non-compliant or malicious upstream could send an unsolicited 101 to - // bypass L7 inspection. - if matches!(outcome, RelayOutcome::Upgraded { .. }) { - let header_str = String::from_utf8_lossy(&req.raw_header[..header_end]); - if !client_requested_upgrade(&header_str) { - openshell_ocsf::ocsf_emit!( - openshell_ocsf::DetectionFindingBuilder::new(crate::ocsf_ctx()) - .activity(openshell_ocsf::ActivityId::Open) - .action(openshell_ocsf::ActionId::Denied) - .disposition(openshell_ocsf::DispositionId::Blocked) - .severity(openshell_ocsf::SeverityId::High) - .confidence(openshell_ocsf::ConfidenceId::High) - .is_alert(true) - .finding_info( - openshell_ocsf::FindingInfo::new( - "unsolicited-101-upgrade", - "Unsolicited 101 Switching Protocols", - ) - .with_desc(&format!( - "Upstream sent 101 without client Upgrade request for {} {} — \ - possible L7 inspection bypass. 
Connection closed.", - req.action, req.target, - )), - ) - .message(format!( - "Unsolicited 101 upgrade blocked: {} {}", - req.action, req.target, - )) - .build() - ); - return Ok(RelayOutcome::Consumed); +fn rewrite_buffered_body( + headers: &[u8], + original_header_str: &str, + body: Vec, + resolver: Option<&SecretResolver>, +) -> Result<(Vec, Vec)> { + if body.is_empty() { + return Ok((headers.to_vec(), body)); + } + + let content_type = content_type(original_header_str); + if !is_rewritable_content_type(content_type.as_deref()) { + if body_bytes_contain_reserved_marker(&body) { + return Err(miette!( + "request body credential rewrite found placeholders in an unsupported content type" + )); } + return Ok((headers.to_vec(), body)); } - Ok(outcome) + let mut text = String::from_utf8(body) + .map_err(|_| miette!("request body credential rewrite requires UTF-8 text bodies"))?; + if !contains_reserved_credential_marker(&text) { + return Ok((headers.to_vec(), text.into_bytes())); + } + + let Some(resolver) = resolver else { + return Err(miette!( + "request body credential rewrite found placeholders but no resolver is available" + )); + }; + let replacements = resolver + .rewrite_text_placeholders(&mut text, "request_body") + .map_err(|e| miette!("credential injection failed: {e}"))?; + if replacements == 0 || contains_reserved_credential_marker(&text) { + return Err(miette!( + "request body credential rewrite left unresolved credential placeholders" + )); + } + + let body = text.into_bytes(); + let headers = set_content_length(headers, body.len())?; + Ok((headers, body)) +} + +async fn collect_chunked_body( + client: &mut C, + already_read: &[u8], + generation_guard: Option<&PolicyGenerationGuard>, +) -> Result> { + let mut read_buf = [0u8; RELAY_BUF_SIZE]; + let mut parse_buf = Vec::from(already_read); + let mut pos = 0usize; + + loop { + if parse_buf.len() > MAX_REWRITE_BODY_BYTES { + return Err(miette!( + "request body credential rewrite buffers at most 
{MAX_REWRITE_BODY_BYTES} bytes" + )); + } + + let size_line_end = loop { + if let Some(end) = find_crlf(&parse_buf, pos) { + break end; + } + let n = client.read(&mut read_buf).await.into_diagnostic()?; + if n == 0 { + return Err(miette!("Chunked body ended before chunk-size line")); + } + if let Some(guard) = generation_guard { + guard.ensure_current()?; + } + parse_buf.extend_from_slice(&read_buf[..n]); + if parse_buf.len() > MAX_REWRITE_BODY_BYTES { + return Err(miette!( + "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" + )); + } + }; + + let size_line = std::str::from_utf8(&parse_buf[pos..size_line_end]) + .into_diagnostic() + .map_err(|_| miette!("Invalid UTF-8 in chunk-size line"))?; + let size_token = size_line + .split(';') + .next() + .map(str::trim) + .unwrap_or_default(); + let chunk_size = usize::from_str_radix(size_token, 16) + .into_diagnostic() + .map_err(|_| miette!("Invalid chunk size token: {size_token:?}"))?; + pos = size_line_end + 2; + + if chunk_size == 0 { + loop { + let trailer_end = loop { + if let Some(end) = find_crlf(&parse_buf, pos) { + break end; + } + let n = client.read(&mut read_buf).await.into_diagnostic()?; + if n == 0 { + return Err(miette!("Chunked body ended before trailer terminator")); + } + if let Some(guard) = generation_guard { + guard.ensure_current()?; + } + parse_buf.extend_from_slice(&read_buf[..n]); + if parse_buf.len() > MAX_REWRITE_BODY_BYTES { + return Err(miette!( + "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" + )); + } + }; + let trailer_line = &parse_buf[pos..trailer_end]; + pos = trailer_end + 2; + if trailer_line.is_empty() { + return Ok(parse_buf); + } + } + } + + let chunk_end = pos + .checked_add(chunk_size) + .ok_or_else(|| miette!("Chunk size overflow"))?; + let chunk_with_crlf_end = chunk_end + .checked_add(2) + .ok_or_else(|| miette!("Chunk size overflow"))?; + while parse_buf.len() < chunk_with_crlf_end { + let n = client.read(&mut 
read_buf).await.into_diagnostic()?; + if n == 0 { + return Err(miette!("Chunked body ended mid-chunk")); + } + if let Some(guard) = generation_guard { + guard.ensure_current()?; + } + parse_buf.extend_from_slice(&read_buf[..n]); + if parse_buf.len() > MAX_REWRITE_BODY_BYTES { + return Err(miette!( + "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" + )); + } + } + if &parse_buf[chunk_end..chunk_with_crlf_end] != b"\r\n" { + return Err(miette!("Chunk missing terminating CRLF")); + } + pos = chunk_with_crlf_end; + } +} + +fn content_type(headers: &str) -> Option { + headers.lines().skip(1).find_map(|line| { + let (name, value) = line.split_once(':')?; + name.trim().eq_ignore_ascii_case("content-type").then(|| { + value + .split(';') + .next() + .unwrap_or("") + .trim() + .to_ascii_lowercase() + }) + }) +} + +fn is_rewritable_content_type(content_type: Option<&str>) -> bool { + let Some(content_type) = content_type else { + return false; + }; + content_type == "application/json" + || content_type == "application/x-www-form-urlencoded" + || content_type.starts_with("text/") +} + +fn body_bytes_contain_reserved_marker(body: &[u8]) -> bool { + if body.is_empty() { + return false; + } + String::from_utf8_lossy(body) + .split('\0') + .any(contains_reserved_credential_marker) +} + +fn set_content_length(headers: &[u8], len: usize) -> Result> { + use std::fmt::Write as _; + + let header_str = + std::str::from_utf8(headers).map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + let mut out = String::with_capacity(header_str.len() + 32); + let mut inserted = false; + for line in header_str.split("\r\n") { + if line.is_empty() { + if !inserted { + let _ = write!(out, "Content-Length: {len}\r\n"); + } + out.push_str("\r\n"); + break; + } + if line + .split_once(':') + .is_some_and(|(name, _)| name.trim().eq_ignore_ascii_case("content-length")) + { + if !inserted { + let _ = write!(out, "Content-Length: {len}\r\n"); + inserted = true; + } + 
continue; + } + out.push_str(line); + out.push_str("\r\n"); + } + Ok(out.into_bytes()) +} + +pub(crate) fn request_is_websocket_upgrade(raw_header: &[u8]) -> bool { + let header_end = raw_header + .windows(4) + .position(|w| w == b"\r\n\r\n") + .map_or(raw_header.len(), |p| p + 4); + validate_websocket_upgrade_request(&raw_header[..header_end]).unwrap_or(false) +} + +fn rewrite_websocket_extensions_for_mode( + raw_header: &[u8], + mode: WebSocketExtensionMode, + websocket_request: bool, +) -> Result<(Vec, Option)> { + if !websocket_request || mode == WebSocketExtensionMode::Preserve { + return Ok((raw_header.to_vec(), None)); + } + match mode { + WebSocketExtensionMode::Preserve => Ok((raw_header.to_vec(), None)), + WebSocketExtensionMode::PermessageDeflate => { + rewrite_websocket_extensions_for_permessage_deflate(raw_header) + } + } +} + +fn rewrite_websocket_extensions_for_permessage_deflate( + raw_header: &[u8], +) -> Result<(Vec, Option)> { + let header_str = std::str::from_utf8(raw_header) + .map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + let safe_offer = supported_permessage_deflate_offer(header_str)?; + let mut out = Vec::with_capacity(raw_header.len()); + let mut inserted = false; + + for line in header_str.split_inclusive("\r\n") { + let bare = line.strip_suffix("\r\n").unwrap_or(line); + if bare + .to_ascii_lowercase() + .starts_with("sec-websocket-extensions:") + { + continue; + } + if bare.is_empty() && !inserted { + if let Some(offer) = safe_offer.as_deref() { + out.extend_from_slice(b"Sec-WebSocket-Extensions: "); + out.extend_from_slice(offer.as_bytes()); + out.extend_from_slice(b"\r\n"); + } + inserted = true; + } + out.extend_from_slice(line.as_bytes()); + } + Ok((out, safe_offer)) +} + +fn supported_permessage_deflate_offer(header_str: &str) -> Result> { + for offer in websocket_extension_offers(header_str)? 
{ + if !offer.name.eq_ignore_ascii_case("permessage-deflate") { + continue; + } + let mut client_no_context_takeover = false; + let mut server_no_context_takeover = false; + let mut unsupported = false; + let mut seen = HashSet::new(); + for param in &offer.params { + let name = param.name.to_ascii_lowercase(); + if param.value.is_some() || !seen.insert(name.clone()) { + unsupported = true; + break; + } + if name == "client_no_context_takeover" { + client_no_context_takeover = true; + } else if name == "server_no_context_takeover" { + server_no_context_takeover = true; + } else { + unsupported = true; + break; + } + } + if client_no_context_takeover && !unsupported { + let mut offer = "permessage-deflate; client_no_context_takeover".to_string(); + if server_no_context_takeover { + offer.push_str("; server_no_context_takeover"); + } + return Ok(Some(offer)); + } + } + Ok(None) +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct WebSocketExtensionOffer { + name: String, + params: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct WebSocketExtensionParam { + name: String, + value: Option, +} + +fn websocket_extension_offers(header_str: &str) -> Result> { + let mut offers = Vec::new(); + for line in header_str.lines().skip(1) { + let Some((name, value)) = line.split_once(':') else { + continue; + }; + if !name.trim().eq_ignore_ascii_case("sec-websocket-extensions") { + continue; + } + for extension in value.split(',') { + let mut parts = extension.split(';').map(str::trim); + let Some(extension_name) = parts.next().filter(|name| !name.is_empty()) else { + return Err(miette!("invalid WebSocket extension offer")); + }; + if !is_http_token(extension_name) { + return Err(miette!("invalid WebSocket extension token")); + } + let mut params = Vec::new(); + for param in parts { + if param.is_empty() { + return Err(miette!("invalid WebSocket extension parameter")); + } + let (param_name, param_value) = match param.split_once('=') { + Some((name, value)) => { + let value 
= value.trim(); + if value.is_empty() || value.starts_with('"') || !is_http_token(value) { + return Err(miette!("unsupported WebSocket extension parameter value")); + } + (name.trim(), Some(value.to_string())) + } + None => (param, None), + }; + if param_name.is_empty() || !is_http_token(param_name) { + return Err(miette!("invalid WebSocket extension parameter")); + } + params.push(WebSocketExtensionParam { + name: param_name.to_string(), + value: param_value, + }); + } + offers.push(WebSocketExtensionOffer { + name: extension_name.to_string(), + params, + }); + } + } + Ok(offers) +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct WebSocketUpgradeRequest { + sec_key: String, + subprotocols: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct WebSocketResponseValidation { + expected_accept: String, + expected_extension: Option, + offered_subprotocols: Vec, +} + +fn validate_websocket_upgrade_request(raw_header: &[u8]) -> Result { + parse_websocket_upgrade_request(raw_header).map(|request| request.is_some()) +} + +fn parse_websocket_upgrade_request(raw_header: &[u8]) -> Result> { + let header_str = std::str::from_utf8(raw_header) + .map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + let mut lines = header_str.lines(); + let Some(request_line) = lines.next() else { + return Ok(None); + }; + let method = request_line.split_whitespace().next().unwrap_or_default(); + let mut headers = WebSocketUpgradeHeaders::default(); + + for line in lines { + if line.is_empty() { + break; + } + let Some((name, value)) = line.split_once(':') else { + continue; + }; + let name = name.trim().to_ascii_lowercase(); + let value = value.trim(); + match name.as_str() { + "upgrade" if header_value_contains_token(value, "websocket") => { + headers.upgrade_websocket = true; + } + "connection" if header_value_contains_token(value, "upgrade") => { + headers.connection_upgrade = true; + } + "sec-websocket-key" => { + headers.sec_key_count += 1; + headers.sec_key = 
/// True when a comma-separated HTTP header value contains `expected` as one
/// of its tokens (ASCII case-insensitive), e.g. `Connection: keep-alive,
/// Upgrade` contains `upgrade`.
fn header_value_contains_token(value: &str, expected: &str) -> bool {
    value
        .split(',')
        .map(str::trim)
        .any(|token| token.eq_ignore_ascii_case(expected))
}
/// True when `value` is a non-empty RFC 7230 `token`: ASCII letters, digits,
/// and the tchar punctuation set. Anything else (separators, whitespace,
/// quotes, non-ASCII) disqualifies the string.
fn is_http_token(value: &str) -> bool {
    const TCHAR_PUNCT: &[u8] = b"!#$%&'*+-.^_`|~";
    !value.is_empty()
        && value
            .bytes()
            .all(|byte| byte.is_ascii_alphanumeric() || TCHAR_PUNCT.contains(&byte))
}
if status_code == 101 { + if !options.client_requested_upgrade { + return Ok(RelayOutcome::Consumed); + } + let websocket_permessage_deflate = validate_websocket_response( + &header_str, + options.websocket_extensions, + options.websocket.as_ref(), + )?; client .write_all(&buf[..header_end]) .await @@ -836,7 +1520,10 @@ where overflow_bytes = overflow.len(), "101 Switching Protocols — signaling protocol upgrade" ); - return Ok(RelayOutcome::Upgraded { overflow }); + return Ok(RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + }); } // Bodiless responses (HEAD, 1xx, 204, 304): forward headers only, skip body @@ -938,6 +1625,159 @@ fn parse_connection_close(headers: &str) -> bool { false } +fn validate_websocket_response( + headers: &str, + mode: WebSocketExtensionMode, + websocket: Option<&WebSocketResponseValidation>, +) -> Result { + let Some(validation) = websocket else { + return validate_websocket_response_extensions_preserved(headers, mode); + }; + + let mut upgrade_websocket = false; + let mut connection_upgrade = false; + let mut accept_count = 0usize; + let mut accept_matches = false; + let mut subprotocol_count = 0usize; + let mut selected_subprotocol = None; + + for line in headers.lines().skip(1) { + let Some((name, value)) = line.split_once(':') else { + continue; + }; + let name = name.trim().to_ascii_lowercase(); + let value = value.trim(); + match name.as_str() { + "upgrade" if header_value_contains_token(value, "websocket") => { + upgrade_websocket = true; + } + "connection" if header_value_contains_token(value, "upgrade") => { + connection_upgrade = true; + } + "sec-websocket-accept" => { + accept_count += 1; + accept_matches = value == validation.expected_accept; + } + "sec-websocket-protocol" => { + subprotocol_count += 1; + if !is_http_token(value) { + return Err(miette!( + "websocket upgrade response has invalid Sec-WebSocket-Protocol" + )); + } + selected_subprotocol = Some(value.to_string()); + } + _ => {} + } + } + + if 
!upgrade_websocket { + return Err(miette!( + "websocket upgrade response missing Upgrade: websocket" + )); + } + if !connection_upgrade { + return Err(miette!( + "websocket upgrade response missing Connection: Upgrade" + )); + } + if accept_count != 1 || !accept_matches { + return Err(miette!( + "websocket upgrade response has invalid Sec-WebSocket-Accept" + )); + } + if subprotocol_count > 1 { + return Err(miette!( + "websocket upgrade response has multiple Sec-WebSocket-Protocol headers" + )); + } + if let Some(protocol) = selected_subprotocol + && !validation + .offered_subprotocols + .iter() + .any(|offered| offered == &protocol) + { + return Err(miette!( + "upstream selected WebSocket subprotocol that was not offered" + )); + } + + let actual_extension = normalized_websocket_extension(headers)?; + match (&validation.expected_extension, actual_extension.as_deref()) { + (None, Some(_)) => Err(miette!( + "upstream negotiated WebSocket extension that was not offered" + )), + (None | Some(_), None) => Ok(false), + (Some(expected), Some(actual)) if expected.eq_ignore_ascii_case(actual) => Ok(true), + (Some(_), Some(_)) => Err(miette!( + "upstream negotiated WebSocket extension that does not match the safe offer" + )), + } +} + +fn validate_websocket_response_extensions_preserved( + headers: &str, + mode: WebSocketExtensionMode, +) -> Result { + match mode { + WebSocketExtensionMode::Preserve => Ok(false), + WebSocketExtensionMode::PermessageDeflate => { + let offers = websocket_extension_offers(headers)?; + if offers.is_empty() { + Ok(false) + } else { + Err(miette!( + "upstream negotiated WebSocket extension that was not offered" + )) + } + } + } +} + +fn normalized_websocket_extension(headers: &str) -> Result> { + let offers = websocket_extension_offers(headers)?; + if offers.is_empty() { + return Ok(None); + } + if offers.len() != 1 { + return Err(miette!("upstream negotiated multiple WebSocket extensions")); + } + let offer = &offers[0]; + if 
!offer.name.eq_ignore_ascii_case("permessage-deflate") { + return Err(miette!( + "upstream negotiated unsupported WebSocket extension" + )); + } + let mut client_no_context_takeover = false; + let mut server_no_context_takeover = false; + let mut seen = HashSet::new(); + for param in &offer.params { + let name = param.name.to_ascii_lowercase(); + if param.value.is_some() || !seen.insert(name.clone()) { + return Err(miette!( + "upstream negotiated unsupported permessage-deflate parameter" + )); + } + if name == "client_no_context_takeover" { + client_no_context_takeover = true; + } else if name == "server_no_context_takeover" { + server_no_context_takeover = true; + } else { + return Err(miette!( + "upstream negotiated unsupported permessage-deflate parameter" + )); + } + } + let mut normalized = String::from("permessage-deflate"); + if client_no_context_takeover { + normalized.push_str("; client_no_context_takeover"); + } + if server_no_context_takeover { + normalized.push_str("; server_no_context_takeover"); + } + Ok(Some(normalized)) +} + /// Check if the client request headers contain both `Upgrade` and /// `Connection: Upgrade` headers, indicating the client requested a /// protocol upgrade (e.g. WebSocket). @@ -1034,21 +1874,297 @@ fn is_benign_close(err: &std::io::Error) -> bool { ) } -#[cfg(test)] -#[allow( - clippy::iter_on_single_items, - clippy::manual_string_new, - clippy::collapsible_if, - clippy::cast_possible_truncation, - reason = "Test code: test fixtures and explicit value-shape assertions are idiomatic in tests." -)] -mod tests { - use super::*; - use crate::opa::OpaEngine; - use crate::secrets::SecretResolver; - use base64::Engine as _; +#[cfg(test)] +#[allow( + clippy::iter_on_single_items, + clippy::manual_string_new, + clippy::collapsible_if, + clippy::cast_possible_truncation, + reason = "Test code: test fixtures and explicit value-shape assertions are idiomatic in tests." 
+)] +mod tests { + use super::*; + use crate::opa::OpaEngine; + use crate::secrets::SecretResolver; + use flate2::{Compress, Compression, Decompress, FlushCompress, FlushDecompress, Status}; + use std::sync::Arc; + + const TEST_POLICY: &str = include_str!("../../data/sandbox-policy.rego"); + const VALID_WS_KEY: &str = "dGhlIHNhbXBsZSBub25jZQ=="; + const VALID_WS_ACCEPT: &str = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="; + const TEXT_OPCODE: u8 = 0x1; + + #[derive(Debug)] + struct CapturedFrame { + fin_opcode: u8, + masked: bool, + payload: Vec, + } + + async fn read_http_header_block(reader: &mut R) -> Vec { + tokio::time::timeout(std::time::Duration::from_secs(2), async { + let mut header = Vec::new(); + let mut byte = [0u8; 1]; + loop { + reader.read_exact(&mut byte).await.unwrap(); + header.push(byte[0]); + if header.ends_with(b"\r\n\r\n") { + break; + } + } + header + }) + .await + .expect("HTTP header block should arrive") + } + + async fn read_websocket_frame(reader: &mut R) -> CapturedFrame { + tokio::time::timeout(std::time::Duration::from_secs(2), async { + let mut prefix = [0u8; 2]; + reader.read_exact(&mut prefix).await.unwrap(); + let masked = prefix[1] & 0x80 != 0; + let mut payload_len = u64::from(prefix[1] & 0x7f); + if payload_len == 126 { + let mut extended = [0u8; 2]; + reader.read_exact(&mut extended).await.unwrap(); + payload_len = u64::from(u16::from_be_bytes(extended)); + } else if payload_len == 127 { + let mut extended = [0u8; 8]; + reader.read_exact(&mut extended).await.unwrap(); + payload_len = u64::from_be_bytes(extended); + } + let mut mask_key = [0u8; 4]; + if masked { + reader.read_exact(&mut mask_key).await.unwrap(); + } + let payload_len = usize::try_from(payload_len).unwrap(); + let mut payload = vec![0u8; payload_len]; + reader.read_exact(&mut payload).await.unwrap(); + if masked { + apply_test_mask(&mut payload, mask_key); + } + CapturedFrame { + fin_opcode: prefix[0], + masked, + payload, + } + }) + .await + .expect("WebSocket frame should 
arrive") + } + + fn masked_frame_with_rsv(opcode: u8, rsv: u8, payload: &[u8]) -> Vec { + let mask_key = [0x37, 0xfa, 0x21, 0x3d]; + let mut frame = Vec::new(); + frame.push(0x80 | rsv | opcode); + write_test_payload_len(&mut frame, 0x80, payload.len()); + frame.extend_from_slice(&mask_key); + let mut masked = payload.to_vec(); + apply_test_mask(&mut masked, mask_key); + frame.extend_from_slice(&masked); + frame + } + + fn unmasked_frame(opcode: u8, payload: &[u8]) -> Vec { + let mut frame = Vec::new(); + frame.push(0x80 | opcode); + write_test_payload_len(&mut frame, 0, payload.len()); + frame.extend_from_slice(payload); + frame + } + + fn write_test_payload_len(frame: &mut Vec, mask_bit: u8, payload_len: usize) { + if payload_len < 126 { + frame.push(mask_bit | payload_len as u8); + } else if u16::try_from(payload_len).is_ok() { + frame.push(mask_bit | 0x7e); + frame.extend_from_slice(&(payload_len as u16).to_be_bytes()); + } else { + frame.push(mask_bit | 0x7f); + frame.extend_from_slice(&(payload_len as u64).to_be_bytes()); + } + } + + fn apply_test_mask(payload: &mut [u8], mask_key: [u8; 4]) { + for (index, byte) in payload.iter_mut().enumerate() { + *byte ^= mask_key[index % 4]; + } + } + + fn compress_test_permessage_deflate(payload: &[u8]) -> Vec { + let mut compressor = Compress::new(Compression::fast(), false); + let mut out = Vec::with_capacity(payload.len().saturating_add(128)); + loop { + let consumed = usize::try_from(compressor.total_in()).unwrap(); + if consumed >= payload.len() { + break; + } + let before_in = compressor.total_in(); + let before_out = compressor.total_out(); + let status = compressor + .compress_vec(&payload[consumed..], &mut out, FlushCompress::None) + .unwrap(); + if matches!(status, Status::BufError) + || (compressor.total_in() == before_in && compressor.total_out() == before_out) + { + out.reserve(out.capacity().max(1024)); + } + } + loop { + out.reserve(64); + let before_out = compressor.total_out(); + compressor + 
.compress_vec(&[], &mut out, FlushCompress::Sync) + .unwrap(); + if out.ends_with(&[0x00, 0x00, 0xff, 0xff]) { + break; + } + if compressor.total_out() == before_out { + out.reserve(out.capacity().max(1024)); + } + } + out.truncate(out.len() - 4); + out + } + + fn decompress_test_permessage_deflate(payload: &[u8]) -> Vec { + let mut decoder = Decompress::new(false); + let mut input = Vec::with_capacity(payload.len() + 4); + input.extend_from_slice(payload); + input.extend_from_slice(&[0x00, 0x00, 0xff, 0xff]); + let mut out = Vec::new(); + let mut input_pos = 0usize; + let mut scratch = [0u8; RELAY_BUF_SIZE]; + loop { + let before_in = decoder.total_in(); + let before_out = decoder.total_out(); + let status = decoder + .decompress(&input[input_pos..], &mut scratch, FlushDecompress::Sync) + .unwrap(); + let read = usize::try_from(decoder.total_in() - before_in).unwrap(); + let written = usize::try_from(decoder.total_out() - before_out).unwrap(); + input_pos += read; + out.extend_from_slice(&scratch[..written]); + if matches!(status, Status::StreamEnd) { + break; + } + if input_pos >= input.len() && written < scratch.len() { + break; + } + assert!( + read != 0 || written != 0, + "test permessage-deflate decompression did not make progress" + ); + } + out + } + + fn websocket_request(extension: Option<&str>) -> L7Request { + let mut raw_header = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\n" + ); + if let Some(extension) = extension { + raw_header.push_str("Sec-WebSocket-Extensions: "); + raw_header.push_str(extension); + raw_header.push_str("\r\n"); + } + raw_header.push_str("Sec-WebSocket-Version: 13\r\n\r\n"); + L7Request { + action: "GET".to_string(), + target: "/ws".to_string(), + query_params: HashMap::new(), + raw_header: raw_header.into_bytes(), + body_length: BodyLength::None, + } + } + + async fn run_upgraded_websocket_case( + request_extension: Option<&'static str>, + 
response_extension: Option<&'static str>, + extension_mode: WebSocketExtensionMode, + resolver: Option>, + client_frame: Vec, + ) -> (String, CapturedFrame) { + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(16384); + let (mut client_app, mut proxy_to_client) = tokio::io::duplex(16384); + let req = websocket_request(request_extension); + let resolver_for_header = resolver.clone(); + let resolver_for_upgrade = resolver.clone(); + + let upstream_task = tokio::spawn(async move { + let forwarded = read_http_header_block(&mut upstream_side).await; + let forwarded = String::from_utf8(forwarded).unwrap(); + let mut response = format!( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: {VALID_WS_ACCEPT}\r\n" + ); + if let Some(extension) = response_extension { + response.push_str("Sec-WebSocket-Extensions: "); + response.push_str(extension); + response.push_str("\r\n"); + } + response.push_str("\r\n"); + upstream_side.write_all(response.as_bytes()).await.unwrap(); + upstream_side.flush().await.unwrap(); + let frame = read_websocket_frame(&mut upstream_side).await; + (forwarded, frame) + }); + + let relay_task = tokio::spawn(async move { + let outcome = relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + resolver: resolver_for_header.as_deref(), + websocket_extensions: extension_mode, + ..Default::default() + }, + ) + .await + .expect("handshake relay should succeed"); + let RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + } = outcome + else { + panic!("expected upgraded relay outcome"); + }; + let credential_rewrite = resolver_for_upgrade.is_some(); + crate::l7::relay::handle_upgrade( + &mut proxy_to_client, + &mut proxy_to_upstream, + overflow, + "example.com", + 443, + crate::l7::relay::UpgradeRelayOptions { + websocket_request: true, + websocket: crate::l7::relay::WebSocketUpgradeBehavior { + 
credential_rewrite, + permessage_deflate: websocket_permessage_deflate, + ..Default::default() + }, + secret_resolver: resolver_for_upgrade, + target: "/ws".to_string(), + policy_name: "test-policy".to_string(), + ..Default::default() + }, + ) + .await + }); + + let response = read_http_header_block(&mut client_app).await; + assert!( + String::from_utf8_lossy(&response).contains("101 Switching Protocols"), + "client must receive the upgrade before frame relay starts" + ); + client_app.write_all(&client_frame).await.unwrap(); + client_app.flush().await.unwrap(); - const TEST_POLICY: &str = include_str!("../../data/sandbox-policy.rego"); + let result = upstream_task.await.expect("upstream task should complete"); + drop(client_app); + let _ = tokio::time::timeout(std::time::Duration::from_secs(2), relay_task).await; + result + } #[test] fn deny_response_body_is_agent_readable_and_redacted() { @@ -1711,7 +2827,12 @@ mod tests { let result = tokio::time::timeout( std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), ) .await .expect("relay_response should not deadlock"); @@ -1752,7 +2873,12 @@ mod tests { let result = tokio::time::timeout( std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), ) .await .expect("must not block when no Connection: close"); @@ -1788,7 +2914,12 @@ mod tests { let result = tokio::time::timeout( std::time::Duration::from_secs(2), - relay_response("HEAD", &mut upstream_read, &mut client_write), + relay_response( + "HEAD", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), ) .await .expect("HEAD relay must not deadlock waiting for body"); @@ -1821,7 +2952,12 @@ mod tests { let result = tokio::time::timeout( 
std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), ) .await .expect("204 relay must not deadlock"); @@ -1856,7 +2992,12 @@ mod tests { let result = tokio::time::timeout( std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), ) .await .expect("must not block when chunked body is complete in overflow"); @@ -1895,7 +3036,12 @@ mod tests { let result = tokio::time::timeout( std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), ) .await .expect("must not block when chunked response has trailers"); @@ -1933,7 +3079,12 @@ mod tests { let result = tokio::time::timeout( std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), ) .await .expect("normal relay must not deadlock"); @@ -1964,7 +3115,12 @@ mod tests { let result = tokio::time::timeout( std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), ) .await .expect("relay must not deadlock"); @@ -2005,14 +3161,19 @@ mod tests { let result = tokio::time::timeout( std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), ) .await .expect("relay_response should not deadlock"); let outcome = result.expect("relay_response 
should succeed"); match outcome { - RelayOutcome::Upgraded { overflow } => { + RelayOutcome::Upgraded { overflow, .. } => { assert_eq!( &overflow, b"\x81\x05hello", "overflow should contain WebSocket frame data" @@ -2047,13 +3208,18 @@ mod tests { let result = tokio::time::timeout( std::time::Duration::from_secs(2), - relay_response("GET", &mut upstream_read, &mut client_write), + relay_response( + "GET", + &mut upstream_read, + &mut client_write, + RelayResponseOptions::default(), + ), ) .await .expect("relay_response should not deadlock"); match result.expect("should succeed") { - RelayOutcome::Upgraded { overflow } => { + RelayOutcome::Upgraded { overflow, .. } => { assert!(overflow.is_empty(), "no overflow expected"); } other => panic!("Expected Upgraded, got {other:?}"), @@ -2125,13 +3291,345 @@ mod tests { async fn relay_accepts_101_with_client_upgrade_header() { // Client sends a proper upgrade request with Upgrade + Connection headers. let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); - let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let req = L7Request { + action: "GET".to_string(), + target: "/ws".to_string(), + query_params: HashMap::new(), + raw_header: b"GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n".to_vec(), + body_length: BodyLength::None, + }; + + let upstream_task = tokio::spawn(async move { + let mut buf = vec![0u8; 4096]; + let mut total = 0; + loop { + let n = upstream_side.read(&mut buf[total..]).await.unwrap(); + if n == 0 { + break; + } + total += n; + if buf[..total].windows(4).any(|w| w == b"\r\n\r\n") { + break; + } + } + upstream_side + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n", + ) + .await + .unwrap(); + upstream_side.flush().await.unwrap(); + }); + + let result = tokio::time::timeout( + 
std::time::Duration::from_secs(5), + relay_http_request_with_resolver( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + None, + ), + ) + .await + .expect("relay must not deadlock"); + + let outcome = result.expect("relay should succeed"); + assert!( + matches!(outcome, RelayOutcome::Upgraded { .. }), + "proper upgrade request should be accepted, got {outcome:?}" + ); + + upstream_task.await.expect("upstream task should complete"); + } + + #[tokio::test] + async fn opted_in_websocket_relay_rejects_invalid_upgrade_before_upstream_write() { + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let req = L7Request { + action: "GET".to_string(), + target: "/ws".to_string(), + query_params: HashMap::new(), + raw_header: b"GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: 13\r\n\r\n".to_vec(), + body_length: BodyLength::None, + }; + + let result = relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + websocket_extensions: WebSocketExtensionMode::PermessageDeflate, + ..Default::default() + }, + ) + .await; + + assert!( + result.is_err(), + "missing Sec-WebSocket-Key must fail closed" + ); + drop(proxy_to_upstream); + let mut forwarded = Vec::new(); + upstream_side.read_to_end(&mut forwarded).await.unwrap(); + assert!( + forwarded.is_empty(), + "invalid opted-in upgrade must not reach upstream" + ); + } + + #[tokio::test] + async fn opted_in_websocket_relay_strips_request_extensions_and_rejects_response_extensions() { + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let req = L7Request { + action: "GET".to_string(), + target: "/ws".to_string(), + query_params: HashMap::new(), + raw_header: format!( + "GET /ws HTTP/1.1\r\nHost: 
example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Extensions: permessage-deflate\r\nSec-WebSocket-Version: 13\r\n\r\n" + ) + .into_bytes(), + body_length: BodyLength::None, + }; + + let upstream_task = tokio::spawn(async move { + let mut buf = vec![0u8; 4096]; + let mut total = 0; + loop { + let n = upstream_side.read(&mut buf[total..]).await.unwrap(); + if n == 0 { + break; + } + total += n; + if buf[..total].windows(4).any(|w| w == b"\r\n\r\n") { + break; + } + } + let forwarded = String::from_utf8_lossy(&buf[..total]); + assert!( + !forwarded + .to_ascii_lowercase() + .contains("sec-websocket-extensions"), + "opted-in request must strip extension negotiation" + ); + upstream_side + .write_all( + format!( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: {VALID_WS_ACCEPT}\r\nSec-WebSocket-Extensions: permessage-deflate\r\n\r\n" + ) + .as_bytes(), + ) + .await + .unwrap(); + upstream_side.flush().await.unwrap(); + }); + + let result = relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + websocket_extensions: WebSocketExtensionMode::PermessageDeflate, + ..Default::default() + }, + ) + .await; + + let err = result.expect_err("upstream extension negotiation must fail closed"); + assert!(err.to_string().contains("not offered")); + upstream_task.await.expect("upstream task should complete"); + + drop(proxy_to_client); + let mut received = Vec::new(); + app_side.read_to_end(&mut received).await.unwrap(); + assert!( + received.is_empty(), + "rejected extension negotiation must not forward 101 headers" + ); + } + + #[tokio::test] + async fn permessage_deflate_mode_allows_supported_no_context_takeover() { + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let req = L7Request { + action: 
"GET".to_string(), + target: "/ws".to_string(), + query_params: HashMap::new(), + raw_header: format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover; server_no_context_takeover\r\nSec-WebSocket-Version: 13\r\n\r\n" + ) + .into_bytes(), + body_length: BodyLength::None, + }; + + let upstream_task = tokio::spawn(async move { + let mut buf = vec![0u8; 4096]; + let mut total = 0; + loop { + let n = upstream_side.read(&mut buf[total..]).await.unwrap(); + if n == 0 { + break; + } + total += n; + if buf[..total].windows(4).any(|w| w == b"\r\n\r\n") { + break; + } + } + let forwarded = String::from_utf8_lossy(&buf[..total]).to_ascii_lowercase(); + assert!(forwarded.contains( + "sec-websocket-extensions: permessage-deflate; client_no_context_takeover; server_no_context_takeover" + )); + upstream_side + .write_all( + format!( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: {VALID_WS_ACCEPT}\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover; server_no_context_takeover\r\n\r\n" + ) + .as_bytes(), + ) + .await + .unwrap(); + upstream_side.flush().await.unwrap(); + }); + + let outcome = relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + websocket_extensions: WebSocketExtensionMode::PermessageDeflate, + ..Default::default() + }, + ) + .await + .expect("safe permessage-deflate negotiation should pass"); + + assert!( + matches!( + outcome, + RelayOutcome::Upgraded { + websocket_permessage_deflate: true, + .. 
+ } + ), + "safe permessage-deflate must be marked negotiated" + ); + upstream_task.await.expect("upstream task should complete"); + } + + #[tokio::test] + async fn websocket_conformance_preserve_mode_relays_raw_frames_without_validation() { + let (forwarded, frame) = run_upgraded_websocket_case( + None, + None, + WebSocketExtensionMode::Preserve, + None, + unmasked_frame(TEXT_OPCODE, b"raw-unmasked"), + ) + .await; + + assert!( + forwarded.contains("Upgrade: websocket"), + "raw preserve path should still forward the upgrade request" + ); + assert!( + !frame.masked, + "raw preserve path must not validate or rewrite client frame masking" + ); + assert_eq!(frame.fin_opcode & 0x0f, TEXT_OPCODE); + assert_eq!(frame.payload, b"raw-unmasked"); + } + + #[tokio::test] + async fn websocket_conformance_rewrite_mode_rewrites_text_after_upgrade() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [("DISCORD_BOT_TOKEN".to_string(), "real-token".to_string())] + .into_iter() + .collect(), + ); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); + + let (forwarded, frame) = run_upgraded_websocket_case( + None, + None, + WebSocketExtensionMode::PermessageDeflate, + resolver.map(Arc::new), + masked_frame_with_rsv(TEXT_OPCODE, 0, payload.as_bytes()), + ) + .await; + + assert!( + !forwarded + .to_ascii_lowercase() + .contains("sec-websocket-extensions"), + "plain rewrite path should not offer compression when the client did not offer a safe subset" + ); + assert!(frame.masked, "parsed relay must preserve client masking"); + assert_eq!(frame.fin_opcode & 0x0f, TEXT_OPCODE); + assert_eq!( + String::from_utf8(frame.payload).unwrap(), + r#"{"op":2,"d":{"token":"real-token"}}"# + ); + } + + #[tokio::test] + async fn websocket_conformance_deflate_rewrites_compressed_text_after_upgrade() { + let (child_env, resolver) = SecretResolver::from_provider_env( + 
[("DISCORD_BOT_TOKEN".to_string(), "real-token".to_string())] + .into_iter() + .collect(), + ); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); + let compressed = compress_test_permessage_deflate(payload.as_bytes()); + + let (forwarded, frame) = run_upgraded_websocket_case( + Some("permessage-deflate; server_no_context_takeover; client_no_context_takeover"), + Some("permessage-deflate; server_no_context_takeover; client_no_context_takeover"), + WebSocketExtensionMode::PermessageDeflate, + resolver.map(Arc::new), + masked_frame_with_rsv(TEXT_OPCODE, 0x40, &compressed), + ) + .await; + + assert!( + forwarded.to_ascii_lowercase().contains( + "sec-websocket-extensions: permessage-deflate; client_no_context_takeover; server_no_context_takeover" + ), + "safe extension offer should be canonicalized before forwarding" + ); + assert!(frame.masked, "parsed relay must preserve client masking"); + assert_eq!(frame.fin_opcode & 0x0f, TEXT_OPCODE); + assert!( + frame.fin_opcode & 0x40 != 0, + "rewritten compressed text must retain RSV1" + ); + assert_eq!( + String::from_utf8(decompress_test_permessage_deflate(&frame.payload)).unwrap(), + r#"{"op":2,"d":{"token":"real-token"}}"# + ); + } + + #[tokio::test] + async fn opted_in_websocket_relay_rejects_invalid_accept_before_forwarding_101() { + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut app_side, mut proxy_to_client) = tokio::io::duplex(8192); let req = L7Request { action: "GET".to_string(), target: "/ws".to_string(), query_params: HashMap::new(), - raw_header: b"GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n".to_vec(), + raw_header: format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Version: 13\r\n\r\n" + ) + .into_bytes(), body_length: BodyLength::None, }; @@ 
-2150,32 +3648,247 @@ mod tests { } upstream_side .write_all( - b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n", + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: invalid\r\n\r\n", ) .await .unwrap(); upstream_side.flush().await.unwrap(); }); - let result = tokio::time::timeout( - std::time::Duration::from_secs(5), - relay_http_request_with_resolver( - &req, - &mut proxy_to_client, - &mut proxy_to_upstream, - None, + let result = relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + websocket_extensions: WebSocketExtensionMode::PermessageDeflate, + ..Default::default() + }, + ) + .await; + + let err = result.expect_err("invalid Sec-WebSocket-Accept must fail closed"); + assert!(err.to_string().contains("Sec-WebSocket-Accept")); + upstream_task.await.expect("upstream task should complete"); + + drop(proxy_to_client); + let mut received = Vec::new(); + app_side.read_to_end(&mut received).await.unwrap(); + assert!( + received.is_empty(), + "invalid websocket response must not forward 101 headers" + ); + } + + #[test] + fn websocket_accept_matches_rfc_6455_sample() { + assert_eq!(websocket_accept_for_key(VALID_WS_KEY), VALID_WS_ACCEPT); + } + + #[test] + fn strict_response_validation_rejects_missing_upgrade_headers() { + let validation = WebSocketResponseValidation { + expected_accept: VALID_WS_ACCEPT.to_string(), + expected_extension: None, + offered_subprotocols: Vec::new(), + }; + + let err = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n", + WebSocketExtensionMode::PermessageDeflate, + Some(&validation), + ) + .expect_err("missing Upgrade/Connection must fail"); + + assert!(err.to_string().contains("Upgrade: websocket")); + } + + #[test] + fn permessage_deflate_response_must_match_exact_safe_offer() { + let validation = 
WebSocketResponseValidation { + expected_accept: VALID_WS_ACCEPT.to_string(), + expected_extension: Some( + "permessage-deflate; client_no_context_takeover; server_no_context_takeover" + .to_string(), ), + offered_subprotocols: Vec::new(), + }; + + let err = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover\r\n\r\n", + WebSocketExtensionMode::PermessageDeflate, + Some(&validation), ) - .await - .expect("relay must not deadlock"); + .expect_err("extension response must exactly match the safe offer"); - let outcome = result.expect("relay should succeed"); + assert!(err.to_string().contains("safe offer")); + } + + #[test] + fn permessage_deflate_offer_requires_client_no_context_takeover() { + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\nSec-WebSocket-Version: 13\r\n\r\n" + ); assert!( - matches!(outcome, RelayOutcome::Upgraded { .. 
}), - "proper upgrade request should be accepted, got {outcome:?}" + supported_permessage_deflate_offer(&raw) + .expect("valid unsupported extension offer should parse") + .is_none() ); + } - upstream_task.await.expect("upstream task should complete"); + #[test] + fn permessage_deflate_offer_canonicalizes_safe_params() { + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\nSec-WebSocket-Version: 13\r\n\r\n" + ); + + assert_eq!( + supported_permessage_deflate_offer(&raw) + .expect("safe extension offer should parse") + .as_deref(), + Some("permessage-deflate; client_no_context_takeover; server_no_context_takeover") + ); + } + + #[test] + fn permessage_deflate_offer_rejects_duplicate_safe_params() { + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover; client_no_context_takeover\r\nSec-WebSocket-Version: 13\r\n\r\n" + ); + + assert!( + supported_permessage_deflate_offer(&raw) + .expect("duplicate safe param should parse but not be supported") + .is_none() + ); + } + + #[test] + fn permessage_deflate_offer_rejects_quoted_values() { + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover=\"true\"\r\nSec-WebSocket-Version: 13\r\n\r\n" + ); + + let err = supported_permessage_deflate_offer(&raw) + .expect_err("quoted permessage-deflate parameter values should fail closed"); + assert!(err.to_string().contains("parameter value")); + } + + #[test] + fn permessage_deflate_response_accepts_reordered_safe_params() { + let validation = 
WebSocketResponseValidation { + expected_accept: VALID_WS_ACCEPT.to_string(), + expected_extension: Some( + "permessage-deflate; client_no_context_takeover; server_no_context_takeover" + .to_string(), + ), + offered_subprotocols: Vec::new(), + }; + + let negotiated = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nSec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n\r\n", + WebSocketExtensionMode::PermessageDeflate, + Some(&validation), + ) + .expect("reordered safe extension params should canonicalize"); + + assert!(negotiated); + } + + #[test] + fn permessage_deflate_response_rejects_duplicate_safe_params() { + let validation = WebSocketResponseValidation { + expected_accept: VALID_WS_ACCEPT.to_string(), + expected_extension: Some("permessage-deflate; client_no_context_takeover".to_string()), + offered_subprotocols: Vec::new(), + }; + + let err = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover; client_no_context_takeover\r\n\r\n", + WebSocketExtensionMode::PermessageDeflate, + Some(&validation), + ) + .expect_err("duplicate extension params should fail closed"); + + assert!(err.to_string().contains("unsupported permessage-deflate")); + } + + #[test] + fn preserve_mode_leaves_malformed_extension_response_raw() { + let negotiated = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nSec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover=\"true\"\r\n\r\n", + WebSocketExtensionMode::Preserve, + None, + ) + .expect("preserve mode should not parse or reject raw extension negotiation"); + + assert!(!negotiated); + } + + #[test] + fn 
parse_websocket_upgrade_request_tracks_subprotocols() { + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Protocol: chat, superchat\r\nSec-WebSocket-Version: 13\r\n\r\n" + ); + + let request = parse_websocket_upgrade_request(raw.as_bytes()) + .expect("request should parse") + .expect("request should be websocket"); + + assert_eq!(request.subprotocols, ["chat", "superchat"]); + } + + #[test] + fn strict_response_validation_allows_offered_subprotocol() { + let validation = WebSocketResponseValidation { + expected_accept: VALID_WS_ACCEPT.to_string(), + expected_extension: None, + offered_subprotocols: vec!["chat".to_string(), "superchat".to_string()], + }; + + let negotiated = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nSec-WebSocket-Protocol: superchat\r\n\r\n", + WebSocketExtensionMode::PermessageDeflate, + Some(&validation), + ) + .expect("offered subprotocol should validate"); + + assert!(!negotiated); + } + + #[test] + fn strict_response_validation_rejects_unoffered_subprotocol() { + let validation = WebSocketResponseValidation { + expected_accept: VALID_WS_ACCEPT.to_string(), + expected_extension: None, + offered_subprotocols: vec!["chat".to_string()], + }; + + let err = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nSec-WebSocket-Protocol: admin\r\n\r\n", + WebSocketExtensionMode::PermessageDeflate, + Some(&validation), + ) + .expect_err("unoffered subprotocol should fail closed"); + + assert!(err.to_string().contains("subprotocol")); + } + + #[test] + fn strict_response_validation_rejects_multiple_subprotocol_headers() { + let validation = WebSocketResponseValidation { + expected_accept: 
VALID_WS_ACCEPT.to_string(), + expected_extension: None, + offered_subprotocols: vec!["chat".to_string(), "superchat".to_string()], + }; + + let err = validate_websocket_response( + "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\nSec-WebSocket-Protocol: chat\r\nSec-WebSocket-Protocol: superchat\r\n\r\n", + WebSocketExtensionMode::PermessageDeflate, + Some(&validation), + ) + .expect_err("multiple selected subprotocols should fail closed"); + + assert!(err.to_string().contains("Sec-WebSocket-Protocol")); } #[tokio::test] @@ -2243,6 +3956,94 @@ mod tests { assert!(client_requested_upgrade(headers)); } + #[test] + fn request_is_websocket_upgrade_detects_websocket_upgrade() { + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: keep-alive, Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Version: 13\r\n\r\n" + ); + assert!(request_is_websocket_upgrade(raw.as_bytes())); + } + + #[test] + fn request_is_websocket_upgrade_rejects_missing_key() { + let raw = b"GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: 13\r\n\r\n"; + assert!(!request_is_websocket_upgrade(raw)); + assert!(validate_websocket_upgrade_request(raw).is_err()); + } + + #[test] + fn request_is_websocket_upgrade_rejects_wrong_method() { + let raw = format!( + "POST /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Version: 13\r\n\r\n" + ); + assert!(!request_is_websocket_upgrade(raw.as_bytes())); + assert!(validate_websocket_upgrade_request(raw.as_bytes()).is_err()); + } + + #[test] + fn request_is_websocket_upgrade_rejects_wrong_version() { + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Version: 12\r\n\r\n" + ); 
+ assert!(!request_is_websocket_upgrade(raw.as_bytes())); + assert!(validate_websocket_upgrade_request(raw.as_bytes()).is_err()); + } + + #[test] + fn validate_websocket_upgrade_ignores_plain_rest_request() { + let raw = b"GET /api HTTP/1.1\r\nHost: example.com\r\n\r\n"; + assert!(!request_is_websocket_upgrade(raw)); + assert!(!validate_websocket_upgrade_request(raw).expect("plain request should parse")); + } + + #[test] + fn validate_websocket_upgrade_ignores_non_websocket_upgrade() { + let raw = b"GET /h2c HTTP/1.1\r\nHost: example.com\r\nUpgrade: h2c\r\nConnection: Upgrade\r\n\r\n"; + assert!(!request_is_websocket_upgrade(raw)); + assert!(!validate_websocket_upgrade_request(raw).expect("h2c request should parse")); + } + + #[test] + fn strip_websocket_extensions_removes_extension_negotiation() { + let raw = format!( + "GET /ws HTTP/1.1\r\nHost: example.com\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: {VALID_WS_KEY}\r\nSec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\nSec-WebSocket-Version: 13\r\n\r\n" + ); + + let (stripped, offered) = rewrite_websocket_extensions_for_mode( + raw.as_bytes(), + WebSocketExtensionMode::PermessageDeflate, + true, + ) + .expect("strip should succeed"); + assert!(offered.is_none()); + let stripped = String::from_utf8(stripped).unwrap(); + + assert!(stripped.contains("Upgrade: websocket\r\n")); + assert!(stripped.contains("Sec-WebSocket-Key: ")); + assert!(stripped.contains("Sec-WebSocket-Version: 13\r\n")); + assert!( + !stripped + .to_ascii_lowercase() + .contains("sec-websocket-extensions") + ); + assert!(stripped.ends_with("\r\n\r\n")); + } + + #[test] + fn strip_websocket_extensions_leaves_non_websocket_request_unchanged() { + let raw = b"GET /api HTTP/1.1\r\nHost: example.com\r\nSec-WebSocket-Extensions: permessage-deflate\r\n\r\n"; + + let (stripped, offered) = rewrite_websocket_extensions_for_mode( + raw, + WebSocketExtensionMode::PermessageDeflate, + false, + ) + .expect("strip 
should succeed"); + + assert!(offered.is_none()); + assert_eq!(stripped, raw); + } + #[test] fn rewrite_header_block_resolves_placeholder_auth_headers() { let (_, resolver) = SecretResolver::from_provider_env( @@ -2514,6 +4315,172 @@ mod tests { Ok(forwarded) } + async fn relay_and_capture_with_options( + raw_header: Vec, + body_length: BodyLength, + resolver: Option<&SecretResolver>, + request_body_credential_rewrite: bool, + ) -> Result { + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let header_str = String::from_utf8_lossy(&raw_header); + let first_line = header_str.lines().next().unwrap_or(""); + let parts: Vec<&str> = first_line.splitn(3, ' ').collect(); + let action = parts.first().unwrap_or(&"GET").to_string(); + let target = parts.get(1).unwrap_or(&"/").to_string(); + + let req = L7Request { + action, + target, + query_params: HashMap::new(), + raw_header, + body_length, + }; + + let upstream_task = tokio::spawn(async move { + let mut buf = vec![0u8; 8192]; + let mut total = 0usize; + let mut header_end = None; + let mut expected_total = None; + loop { + let n = upstream_side.read(&mut buf[total..]).await.unwrap(); + if n == 0 { + break; + } + total += n; + if header_end.is_none() + && let Some(end) = buf[..total].windows(4).position(|w| w == b"\r\n\r\n") + { + let end = end + 4; + let headers = String::from_utf8_lossy(&buf[..end]); + let len = headers + .lines() + .find_map(|line| { + let (name, value) = line.split_once(':')?; + name.eq_ignore_ascii_case("content-length") + .then(|| value.trim().parse::().ok()) + .flatten() + }) + .unwrap_or(0); + header_end = Some(end); + expected_total = Some(end + len); + } + if expected_total.is_some_and(|expected| total >= expected) { + break; + } + } + upstream_side + .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok") + .await + .unwrap(); + upstream_side.flush().await.unwrap(); + 
String::from_utf8_lossy(&buf[..total]).to_string() + }); + + relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + resolver, + request_body_credential_rewrite, + ..Default::default() + }, + ) + .await?; + + upstream_task + .await + .map_err(|e| miette!("upstream task failed: {e}")) + } + + #[tokio::test] + async fn relay_request_body_rewrites_provider_alias_header_and_urlencoded_token() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let alias = "provider.v1-OPENSHELL-RESOLVE-ENV-API_TOKEN"; + let body = format!("token={alias}&channel=C123"); + let raw = format!( + "POST /api/messages HTTP/1.1\r\n\ + Host: api.example.com\r\n\ + Authorization: Bearer {alias}\r\n\ + Content-Type: application/x-www-form-urlencoded\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body + ); + + let forwarded = relay_and_capture_with_options( + raw.into_bytes(), + BodyLength::ContentLength(body.len() as u64), + Some(&resolver), + true, + ) + .await + .expect("relay should succeed"); + + let expected_body = "token=provider-real-token&channel=C123"; + assert!(forwarded.contains("Authorization: Bearer provider-real-token\r\n")); + assert!(forwarded.contains(&format!("Content-Length: {}\r\n", expected_body.len()))); + assert!(forwarded.ends_with(expected_body)); + assert!(!forwarded.contains("OPENSHELL-RESOLVE-ENV")); + } + + #[tokio::test] + async fn relay_request_body_unresolved_alias_fails_before_upstream_write() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let body = "token=provider-OPENSHELL-RESOLVE-ENV-APP_TOKEN"; + let raw = format!( + "POST /api/connections.open HTTP/1.1\r\n\ + Host: 
api.example.com\r\n\ + Content-Type: application/x-www-form-urlencoded\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body + ); + let req = L7Request { + action: "POST".to_string(), + target: "/api/connections.open".to_string(), + query_params: HashMap::new(), + raw_header: raw.into_bytes(), + body_length: BodyLength::ContentLength(body.len() as u64), + }; + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let err = relay_http_request_with_options_guarded( + &req, + &mut proxy_to_client, + &mut proxy_to_upstream, + RelayRequestOptions { + resolver: Some(&resolver), + request_body_credential_rewrite: true, + ..Default::default() + }, + ) + .await + .expect_err("unknown body alias should fail closed"); + + assert!(!err.to_string().contains("provider-real-token")); + drop(proxy_to_upstream); + let mut forwarded = Vec::new(); + upstream_side.read_to_end(&mut forwarded).await.unwrap(); + assert!( + forwarded.is_empty(), + "failed body rewrite must not reach upstream" + ); + } + #[tokio::test] async fn relay_injects_bearer_header_credential() { let (child_env, resolver) = SecretResolver::from_provider_env( diff --git a/crates/openshell-sandbox/src/l7/websocket.rs b/crates/openshell-sandbox/src/l7/websocket.rs new file mode 100644 index 000000000..2dc1b25c3 --- /dev/null +++ b/crates/openshell-sandbox/src/l7/websocket.rs @@ -0,0 +1,1937 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! WebSocket relay for opt-in credential placeholder rewriting and message policy. +//! +//! The relay parses only client-to-server frames. Server-to-client bytes stay +//! raw passthrough so inspection and rewriting cannot expose response payloads. 
+ +use crate::l7::relay::{L7EvalContext, evaluate_l7_request}; +use crate::l7::{EnforcementMode, L7RequestInfo}; +use crate::opa::TunnelPolicyEngine; +use crate::secrets::SecretResolver; +use flate2::{Compress, Compression, Decompress, FlushCompress, FlushDecompress, Status}; +use miette::{IntoDiagnostic, Result, miette}; +use openshell_ocsf::{ + ActionId, ActivityId, DispositionId, Endpoint, NetworkActivityBuilder, SeverityId, StatusId, + ocsf_emit, +}; +use std::collections::HashMap; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +const MAX_TEXT_MESSAGE_BYTES: usize = 1024 * 1024; +const MAX_RAW_FRAME_PAYLOAD_BYTES: u64 = 16 * 1024 * 1024; +const COPY_BUF_SIZE: usize = 8192; +const OPCODE_CONTINUATION: u8 = 0x0; +const OPCODE_TEXT: u8 = 0x1; +const OPCODE_BINARY: u8 = 0x2; +const OPCODE_CLOSE: u8 = 0x8; +const OPCODE_PING: u8 = 0x9; +const OPCODE_PONG: u8 = 0xA; + +#[derive(Debug)] +struct FrameHeader { + fin: bool, + rsv: u8, + opcode: u8, + masked: bool, + payload_len: u64, + mask_key: Option<[u8; 4]>, + raw_header: Vec, +} + +#[derive(Debug)] +enum FragmentState { + None, + Text { payload: Vec, compressed: bool }, + Binary, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(super) enum WebSocketCompression { + None, + PermessageDeflate, +} + +pub(super) struct InspectionOptions<'a> { + pub(super) engine: &'a TunnelPolicyEngine, + pub(super) ctx: &'a L7EvalContext, + pub(super) enforcement: EnforcementMode, + pub(super) target: String, + pub(super) query_params: HashMap>, + pub(super) graphql_policy: bool, +} + +pub(super) struct RelayOptions<'a> { + pub(super) policy_name: &'a str, + pub(super) resolver: Option<&'a SecretResolver>, + pub(super) inspector: Option>, + pub(super) compression: WebSocketCompression, +} + +/// Relay an upgraded WebSocket connection with optional client text inspection, +/// credential rewriting, and strict permessage-deflate handling. 
+pub(super) async fn relay_with_options( + client: &mut C, + upstream: &mut U, + overflow: Vec, + host: &str, + port: u16, + options: RelayOptions<'_>, +) -> Result<()> +where + C: AsyncRead + AsyncWrite + Unpin + Send, + U: AsyncRead + AsyncWrite + Unpin + Send, +{ + let (mut client_read, mut client_write) = tokio::io::split(client); + let (mut upstream_read, mut upstream_write) = tokio::io::split(upstream); + + if !overflow.is_empty() { + client_write.write_all(&overflow).await.into_diagnostic()?; + client_write.flush().await.into_diagnostic()?; + } + + let client_to_server = + relay_client_to_server(&mut client_read, &mut upstream_write, host, port, &options); + let server_to_client = async { + tokio::io::copy(&mut upstream_read, &mut client_write) + .await + .into_diagnostic()?; + client_write.flush().await.into_diagnostic()?; + Ok::<(), miette::Report>(()) + }; + + let result = tokio::select! { + result = client_to_server => result, + result = server_to_client => result, + }; + let _ = upstream_write.shutdown().await; + let _ = client_write.shutdown().await; + result +} + +async fn relay_client_to_server( + reader: &mut R, + writer: &mut W, + host: &str, + port: u16, + options: &RelayOptions<'_>, +) -> Result<()> +where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, +{ + let mut fragments = FragmentState::None; + let mut close_seen = false; + + loop { + let Some(frame) = read_frame_header(reader).await.inspect_err(|e| { + emit_protocol_failure(host, port, options.policy_name, protocol_failure_class(e)); + })? 
+ else { + writer.shutdown().await.into_diagnostic()?; + return Ok(()); + }; + + if close_seen { + let e = miette!("websocket frame received after close frame"); + emit_protocol_failure(host, port, options.policy_name, protocol_failure_class(&e)); + return Err(e); + } + + if let Err(e) = validate_frame_header(&frame, &fragments, options.compression) { + emit_protocol_failure(host, port, options.policy_name, protocol_failure_class(&e)); + return Err(e); + } + + match frame.opcode { + OPCODE_TEXT => { + let payload = read_masked_payload(reader, &frame).await.inspect_err(|e| { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); + })?; + let compressed = frame.rsv == 0x40; + if frame.fin { + relay_text_payload( + writer, &frame, payload, false, compressed, host, port, options, + ) + .await + .inspect_err(|e| { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); + })?; + } else { + fragments = FragmentState::Text { + payload, + compressed, + }; + } + } + OPCODE_CONTINUATION => match &mut fragments { + FragmentState::Text { + payload, + compressed, + } => { + let next = read_masked_payload(reader, &frame).await.inspect_err(|e| { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); + })?; + if let Err(e) = append_text_fragment(payload, next) { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(&e), + ); + return Err(e); + } + if frame.fin { + let complete = std::mem::take(payload); + let was_compressed = *compressed; + fragments = FragmentState::None; + relay_text_payload( + writer, + &frame, + complete, + true, + was_compressed, + host, + port, + options, + ) + .await + .inspect_err(|e| { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); + })?; + } + } + FragmentState::Binary => { + copy_raw_frame_payload(reader, writer, &frame) + .await + 
.inspect_err(|e| { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); + })?; + if frame.fin { + fragments = FragmentState::None; + } + } + FragmentState::None => { + let e = + miette!("websocket continuation frame without active fragmented message"); + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(&e), + ); + return Err(e); + } + }, + OPCODE_BINARY => { + if !frame.fin { + fragments = FragmentState::Binary; + } + copy_raw_frame_payload(reader, writer, &frame) + .await + .inspect_err(|e| { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); + })?; + } + OPCODE_CLOSE | OPCODE_PING | OPCODE_PONG => { + relay_control_frame(reader, writer, &frame) + .await + .inspect_err(|e| { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); + })?; + if frame.opcode == OPCODE_CLOSE { + close_seen = true; + } + } + _ => unreachable!("validated opcode"), + } + } +} + +async fn read_frame_header(reader: &mut R) -> Result> { + let first = match reader.read_u8().await { + Ok(byte) => byte, + Err(e) + if matches!( + e.kind(), + std::io::ErrorKind::UnexpectedEof + | std::io::ErrorKind::ConnectionReset + | std::io::ErrorKind::BrokenPipe + ) => + { + return Ok(None); + } + Err(e) => return Err(miette!("{e}")), + }; + let second = reader + .read_u8() + .await + .map_err(|e| miette!("malformed websocket frame header: {e}"))?; + + let mut raw_header = vec![first, second]; + let len_code = second & 0x7F; + let payload_len = match len_code { + 0..=125 => u64::from(len_code), + 126 => { + let mut bytes = [0u8; 2]; + reader + .read_exact(&mut bytes) + .await + .map_err(|e| miette!("malformed websocket extended length: {e}"))?; + raw_header.extend_from_slice(&bytes); + let len = u64::from(u16::from_be_bytes(bytes)); + if len < 126 { + return Err(miette!( + "websocket frame uses non-minimal 16-bit extended length" + )); + } 
+ len + } + 127 => { + let mut bytes = [0u8; 8]; + reader + .read_exact(&mut bytes) + .await + .map_err(|e| miette!("malformed websocket extended length: {e}"))?; + if bytes[0] & 0x80 != 0 { + return Err(miette!("websocket frame uses non-canonical 64-bit length")); + } + raw_header.extend_from_slice(&bytes); + let len = u64::from_be_bytes(bytes); + if u16::try_from(len).is_ok() { + return Err(miette!( + "websocket frame uses non-minimal 64-bit extended length" + )); + } + len + } + _ => unreachable!("7-bit length code"), + }; + + let masked = second & 0x80 != 0; + let mask_key = if masked { + let mut key = [0u8; 4]; + reader + .read_exact(&mut key) + .await + .map_err(|e| miette!("malformed websocket mask key: {e}"))?; + raw_header.extend_from_slice(&key); + Some(key) + } else { + None + }; + + Ok(Some(FrameHeader { + fin: first & 0x80 != 0, + rsv: first & 0x70, + opcode: first & 0x0F, + masked, + payload_len, + mask_key, + raw_header, + })) +} + +fn validate_frame_header( + frame: &FrameHeader, + fragments: &FragmentState, + compression: WebSocketCompression, +) -> Result<()> { + if !valid_rsv_bits(frame, fragments, compression) { + return Err(miette!( + "websocket frame has unsupported RSV bits or extension state" + )); + } + if !frame.masked { + return Err(miette!("websocket client frame is not masked")); + } + if !matches!( + frame.opcode, + OPCODE_CONTINUATION + | OPCODE_TEXT + | OPCODE_BINARY + | OPCODE_CLOSE + | OPCODE_PING + | OPCODE_PONG + ) { + return Err(miette!("websocket frame uses reserved opcode")); + } + if matches!(frame.opcode, OPCODE_CLOSE | OPCODE_PING | OPCODE_PONG) { + if !frame.fin { + return Err(miette!("websocket control frame is fragmented")); + } + if frame.payload_len > 125 { + return Err(miette!("websocket control frame exceeds 125 bytes")); + } + } + if matches!(frame.opcode, OPCODE_TEXT | OPCODE_BINARY) + && !matches!(fragments, FragmentState::None) + { + return Err(miette!( + "websocket data frame started before previous fragmented 
message completed" + )); + } + if matches!(frame.opcode, OPCODE_CONTINUATION) && matches!(fragments, FragmentState::None) { + return Err(miette!( + "websocket continuation frame without active fragmented message" + )); + } + if (frame.opcode == OPCODE_BINARY + || (frame.opcode == OPCODE_CONTINUATION && matches!(fragments, FragmentState::Binary))) + && frame.payload_len > MAX_RAW_FRAME_PAYLOAD_BYTES + { + return Err(miette!( + "websocket binary frame exceeds {MAX_RAW_FRAME_PAYLOAD_BYTES} byte relay limit" + )); + } + Ok(()) +} + +fn valid_rsv_bits( + frame: &FrameHeader, + fragments: &FragmentState, + compression: WebSocketCompression, +) -> bool { + if frame.rsv == 0 { + return true; + } + if compression != WebSocketCompression::PermessageDeflate || frame.rsv != 0x40 { + return false; + } + matches!(fragments, FragmentState::None) && matches!(frame.opcode, OPCODE_TEXT | OPCODE_BINARY) +} + +async fn read_masked_payload( + reader: &mut R, + frame: &FrameHeader, +) -> Result> { + let payload_len = usize::try_from(frame.payload_len) + .map_err(|_| miette!("websocket text frame is too large to buffer"))?; + if payload_len > MAX_TEXT_MESSAGE_BYTES { + return Err(miette!( + "websocket text message exceeds {MAX_TEXT_MESSAGE_BYTES} byte limit" + )); + } + let mut payload = vec![0u8; payload_len]; + reader + .read_exact(&mut payload) + .await + .map_err(|e| miette!("malformed websocket payload: {e}"))?; + let mask_key = frame + .mask_key + .ok_or_else(|| miette!("websocket client frame is not masked"))?; + apply_mask(&mut payload, mask_key); + Ok(payload) +} + +fn append_text_fragment(buffer: &mut Vec, next: Vec) -> Result<()> { + let new_len = buffer + .len() + .checked_add(next.len()) + .ok_or_else(|| miette!("websocket text message length overflow"))?; + if new_len > MAX_TEXT_MESSAGE_BYTES { + return Err(miette!( + "websocket text message exceeds {MAX_TEXT_MESSAGE_BYTES} byte limit" + )); + } + buffer.extend_from_slice(&next); + Ok(()) +} + 
+#[allow(clippy::too_many_arguments)] +async fn relay_text_payload( + writer: &mut W, + frame: &FrameHeader, + payload: Vec, + force_reframe: bool, + compressed: bool, + host: &str, + port: u16, + options: &RelayOptions<'_>, +) -> Result<()> { + let message_payload = if compressed { + decompress_permessage_deflate(&payload)? + } else { + payload + }; + let mut text = String::from_utf8(message_payload) + .map_err(|_| miette!("websocket text message is not valid UTF-8"))?; + let replacements = if let Some(resolver) = options.resolver { + resolver + .rewrite_websocket_text_placeholders(&mut text) + .map_err(|_| miette!("websocket credential placeholder resolution failed"))? + } else { + 0 + }; + + if let Some(inspector) = options.inspector.as_ref() { + inspect_websocket_text_message(host, port, options.policy_name, inspector, &text)?; + } + + if replacements == 0 && !force_reframe && !compressed { + writer + .write_all(&frame.raw_header) + .await + .into_diagnostic()?; + let mut payload = text.into_bytes(); + let mask_key = frame + .mask_key + .ok_or_else(|| miette!("websocket client frame is not masked"))?; + apply_mask(&mut payload, mask_key); + writer.write_all(&payload).await.into_diagnostic()?; + writer.flush().await.into_diagnostic()?; + return Ok(()); + } + + if replacements > 0 { + emit_rewrite_event(host, port, options.policy_name, replacements); + } + if compressed { + let compressed_payload = compress_permessage_deflate(text.as_bytes())?; + return write_masked_frame_with_rsv(writer, OPCODE_TEXT, 0x40, &compressed_payload).await; + } + write_masked_frame(writer, OPCODE_TEXT, text.as_bytes()).await +} + +fn inspect_websocket_text_message( + host: &str, + port: u16, + policy_name: &str, + inspector: &InspectionOptions<'_>, + text: &str, +) -> Result<()> { + if inspector.graphql_policy { + return inspect_graphql_websocket_message(host, port, policy_name, inspector, text); + } + + let request_info = L7RequestInfo { + action: "WEBSOCKET_TEXT".to_string(), + 
target: inspector.target.clone(), + query_params: inspector.query_params.clone(), + graphql: None, + }; + let (allowed, reason) = evaluate_l7_request(inspector.engine, inspector.ctx, &request_info)?; + let decision = match (allowed, inspector.enforcement) { + (true, _) => "allow", + (false, EnforcementMode::Audit) => "audit", + (false, EnforcementMode::Enforce) => "deny", + }; + emit_websocket_l7_event( + host, + port, + policy_name, + &request_info, + decision, + &reason, + None, + ); + if !allowed && inspector.enforcement == EnforcementMode::Enforce { + return Err(miette!("websocket text message denied by policy")); + } + Ok(()) +} + +fn inspect_graphql_websocket_message( + host: &str, + port: u16, + policy_name: &str, + inspector: &InspectionOptions<'_>, + text: &str, +) -> Result<()> { + match classify_graphql_websocket_message(text) { + GraphqlWebSocketMessage::Control { message_type } => { + let request_info = L7RequestInfo { + action: "WEBSOCKET_CONTROL".to_string(), + target: inspector.target.clone(), + query_params: inspector.query_params.clone(), + graphql: None, + }; + emit_websocket_l7_event( + host, + port, + policy_name, + &request_info, + "allow", + &format!("GraphQL WebSocket control message {message_type}"), + None, + ); + Ok(()) + } + GraphqlWebSocketMessage::Operation { + message_type, + graphql, + } => { + let request_info = L7RequestInfo { + action: "WEBSOCKET_TEXT".to_string(), + target: inspector.target.clone(), + query_params: inspector.query_params.clone(), + graphql: Some(graphql.clone()), + }; + let parse_error_reason = graphql + .error + .as_deref() + .map(|error| format!("GraphQL WebSocket message rejected: {error}")); + let force_deny = parse_error_reason.is_some(); + let (allowed, reason) = if let Some(reason) = parse_error_reason { + (false, reason) + } else { + evaluate_l7_request(inspector.engine, inspector.ctx, &request_info)? 
+ }; + let decision = match (allowed, inspector.enforcement) { + (_, _) if force_deny => "deny", + (true, _) => "allow", + (false, EnforcementMode::Audit) => "audit", + (false, EnforcementMode::Enforce) => "deny", + }; + let reason = format!("graphql_ws_type={message_type} {reason}"); + emit_websocket_l7_event( + host, + port, + policy_name, + &request_info, + decision, + &reason, + Some(&graphql), + ); + if (!allowed && inspector.enforcement == EnforcementMode::Enforce) || force_deny { + return Err(miette!("websocket GraphQL message denied by policy")); + } + Ok(()) + } + } +} + +#[derive(Debug)] +enum GraphqlWebSocketMessage { + Control { + message_type: String, + }, + Operation { + message_type: String, + graphql: crate::l7::graphql::GraphqlRequestInfo, + }, +} + +fn classify_graphql_websocket_message(text: &str) -> GraphqlWebSocketMessage { + let value = match serde_json::from_str::(text) { + Ok(value) => value, + Err(err) => { + return GraphqlWebSocketMessage::Operation { + message_type: "unknown".to_string(), + graphql: graphql_error(format!( + "GraphQL WebSocket message is not valid JSON: {err}" + )), + }; + } + }; + let Some(obj) = value.as_object() else { + return GraphqlWebSocketMessage::Operation { + message_type: "unknown".to_string(), + graphql: graphql_error("GraphQL WebSocket message must be a JSON object"), + }; + }; + let Some(message_type) = obj.get("type").and_then(serde_json::Value::as_str) else { + return GraphqlWebSocketMessage::Operation { + message_type: "unknown".to_string(), + graphql: graphql_error("GraphQL WebSocket message missing string type"), + }; + }; + + match message_type { + "subscribe" | "start" => { + if obj + .get("id") + .and_then(serde_json::Value::as_str) + .is_none_or(str::is_empty) + { + return GraphqlWebSocketMessage::Operation { + message_type: message_type.to_string(), + graphql: graphql_error( + "GraphQL WebSocket operation message missing non-empty id", + ), + }; + } + let Some(payload) = 
obj.get("payload").filter(|value| value.is_object()) else { + return GraphqlWebSocketMessage::Operation { + message_type: message_type.to_string(), + graphql: graphql_error( + "GraphQL WebSocket operation message missing object payload", + ), + }; + }; + GraphqlWebSocketMessage::Operation { + message_type: message_type.to_string(), + graphql: crate::l7::graphql::classify_json_envelope_value(payload), + } + } + "connection_init" | "connection_terminate" | "ping" | "pong" | "complete" | "stop" => { + GraphqlWebSocketMessage::Control { + message_type: message_type.to_string(), + } + } + _ => GraphqlWebSocketMessage::Operation { + message_type: message_type.to_string(), + graphql: graphql_error(format!( + "unsupported GraphQL WebSocket client message type {message_type:?}" + )), + }, + } +} + +fn graphql_error(message: impl Into) -> crate::l7::graphql::GraphqlRequestInfo { + crate::l7::graphql::GraphqlRequestInfo { + operations: Vec::new(), + error: Some(message.into()), + } +} + +async fn relay_control_frame( + reader: &mut R, + writer: &mut W, + frame: &FrameHeader, +) -> Result<()> +where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, +{ + let raw_payload_len = usize::try_from(frame.payload_len) + .map_err(|_| miette!("websocket control frame payload length overflow"))?; + let mut raw_payload = vec![0u8; raw_payload_len]; + reader + .read_exact(&mut raw_payload) + .await + .map_err(|e| miette!("malformed websocket control payload: {e}"))?; + + if frame.opcode == OPCODE_CLOSE { + let mut payload = raw_payload.clone(); + let mask_key = frame + .mask_key + .ok_or_else(|| miette!("websocket client frame is not masked"))?; + apply_mask(&mut payload, mask_key); + validate_close_payload(&payload)?; + } + + writer + .write_all(&frame.raw_header) + .await + .into_diagnostic()?; + writer.write_all(&raw_payload).await.into_diagnostic()?; + writer.flush().await.into_diagnostic()?; + Ok(()) +} + +fn validate_close_payload(payload: &[u8]) -> Result<()> { + if payload.len() == 1 
{ + return Err(miette!( + "websocket close frame payload cannot be exactly one byte" + )); + } + if payload.len() < 2 { + return Ok(()); + } + + let code = u16::from_be_bytes([payload[0], payload[1]]); + if !valid_close_code(code) { + return Err(miette!("websocket close frame uses invalid close code")); + } + if std::str::from_utf8(&payload[2..]).is_err() { + return Err(miette!("websocket close frame reason is not valid UTF-8")); + } + Ok(()) +} + +fn valid_close_code(code: u16) -> bool { + (matches!(code, 1000..=1014) && !matches!(code, 1004..=1006)) || (3000..=4999).contains(&code) +} + +async fn copy_raw_frame_payload( + reader: &mut R, + writer: &mut W, + frame: &FrameHeader, +) -> Result<()> +where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, +{ + writer + .write_all(&frame.raw_header) + .await + .into_diagnostic()?; + let mut remaining = frame.payload_len; + let mut buf = [0u8; COPY_BUF_SIZE]; + while remaining > 0 { + let to_read = usize::try_from(remaining) + .unwrap_or(buf.len()) + .min(buf.len()); + let n = reader.read(&mut buf[..to_read]).await.into_diagnostic()?; + if n == 0 { + return Err(miette!("websocket payload ended before declared length")); + } + writer.write_all(&buf[..n]).await.into_diagnostic()?; + remaining -= n as u64; + } + writer.flush().await.into_diagnostic()?; + Ok(()) +} + +async fn write_masked_frame( + writer: &mut W, + opcode: u8, + payload: &[u8], +) -> Result<()> { + write_masked_frame_with_rsv(writer, opcode, 0, payload).await +} + +async fn write_masked_frame_with_rsv( + writer: &mut W, + opcode: u8, + rsv: u8, + payload: &[u8], +) -> Result<()> { + let mut header = Vec::with_capacity(14); + header.push(0x80 | rsv | opcode); + match payload.len() { + 0..=125 => header.push(0x80 | u8::try_from(payload.len()).expect("payload <= 125")), + 126..=65_535 => { + header.push(0x80 | 0x7e); + header.extend_from_slice( + &u16::try_from(payload.len()) + .expect("payload <= 65535") + .to_be_bytes(), + ); + } + _ => { + header.push(0x80 
| 127); + header.extend_from_slice(&(payload.len() as u64).to_be_bytes()); + } + } + let mask_key = new_mask_key(); + header.extend_from_slice(&mask_key); + + let mut masked = payload.to_vec(); + apply_mask(&mut masked, mask_key); + writer.write_all(&header).await.into_diagnostic()?; + writer.write_all(&masked).await.into_diagnostic()?; + writer.flush().await.into_diagnostic()?; + Ok(()) +} + +fn decompress_permessage_deflate(payload: &[u8]) -> Result> { + let mut decoder = Decompress::new(false); + let mut input = Vec::with_capacity(payload.len() + 4); + input.extend_from_slice(payload); + input.extend_from_slice(&[0x00, 0x00, 0xff, 0xff]); + let mut out = Vec::with_capacity(payload.len().saturating_mul(2).min(MAX_TEXT_MESSAGE_BYTES)); + let mut input_pos = 0usize; + let mut scratch = [0u8; COPY_BUF_SIZE]; + loop { + let before_in = decoder.total_in(); + let before_out = decoder.total_out(); + let status = decoder + .decompress(&input[input_pos..], &mut scratch, FlushDecompress::Sync) + .map_err(|e| miette!("websocket permessage-deflate decompression failed: {e}"))?; + let read = usize::try_from(decoder.total_in() - before_in) + .map_err(|_| miette!("websocket permessage-deflate input length overflow"))?; + let written = usize::try_from(decoder.total_out() - before_out) + .map_err(|_| miette!("websocket permessage-deflate output length overflow"))?; + input_pos = input_pos + .checked_add(read) + .ok_or_else(|| miette!("websocket permessage-deflate input length overflow"))?; + if out.len().saturating_add(written) > MAX_TEXT_MESSAGE_BYTES { + return Err(miette!( + "websocket text message exceeds {MAX_TEXT_MESSAGE_BYTES} byte limit" + )); + } + out.extend_from_slice(&scratch[..written]); + if matches!(status, Status::StreamEnd) { + break; + } + if input_pos >= input.len() && written < scratch.len() { + break; + } + if read == 0 && written == 0 { + return Err(miette!( + "websocket permessage-deflate decompression did not make progress" + )); + } + } + Ok(out) +} + +fn 
compress_permessage_deflate(payload: &[u8]) -> Result> { + let mut compressor = Compress::new(Compression::fast(), false); + let expansion = payload.len() / 16; + let mut out = Vec::with_capacity(payload.len().saturating_add(expansion).saturating_add(128)); + loop { + let consumed = usize::try_from(compressor.total_in()) + .map_err(|_| miette!("websocket permessage-deflate input length overflow"))?; + if consumed >= payload.len() { + break; + } + let before_in = compressor.total_in(); + let before_out = compressor.total_out(); + let status = compressor + .compress_vec(&payload[consumed..], &mut out, FlushCompress::None) + .map_err(|e| miette!("websocket permessage-deflate compression failed: {e}"))?; + if matches!(status, Status::BufError) + || (compressor.total_in() == before_in && compressor.total_out() == before_out) + { + out.reserve(out.capacity().max(1024)); + } + } + loop { + out.reserve(64); + let before_out = compressor.total_out(); + compressor + .compress_vec(&[], &mut out, FlushCompress::Sync) + .map_err(|e| miette!("websocket permessage-deflate compression failed: {e}"))?; + if out.ends_with(&[0x00, 0x00, 0xff, 0xff]) { + break; + } + if compressor.total_out() == before_out { + out.reserve(out.capacity().max(1024)); + } + } + if !out.ends_with(&[0x00, 0x00, 0xff, 0xff]) { + return Err(miette!( + "websocket permessage-deflate compression missing sync marker" + )); + } + out.truncate(out.len() - 4); + Ok(out) +} + +fn new_mask_key() -> [u8; 4] { + let bytes = uuid::Uuid::new_v4().into_bytes(); + [bytes[0], bytes[1], bytes[2], bytes[3]] +} + +fn apply_mask(payload: &mut [u8], mask_key: [u8; 4]) { + for (i, byte) in payload.iter_mut().enumerate() { + *byte ^= mask_key[i % 4]; + } +} + +fn emit_rewrite_event(host: &str, port: u16, policy_name: &str, replacements: usize) { + let policy_name = if policy_name.is_empty() { + "-" + } else { + policy_name + }; + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + 
.action(ActionId::Allowed) + .disposition(DispositionId::Allowed) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .dst_endpoint(Endpoint::from_domain(host, port)) + .firewall_rule(policy_name, "l7-websocket") + .message(rewrite_event_message(host, port, replacements)) + .build(); + ocsf_emit!(event); +} + +fn rewrite_event_message(host: &str, port: u16, replacements: usize) -> String { + format!( + "WEBSOCKET_CREDENTIAL_REWRITE rewrote client text message [host:{host} port:{port} replacements:{replacements}]" + ) +} + +fn emit_websocket_l7_event( + host: &str, + port: u16, + policy_name: &str, + request_info: &L7RequestInfo, + decision: &str, + reason: &str, + graphql: Option<&crate::l7::graphql::GraphqlRequestInfo>, +) { + let policy_name = if policy_name.is_empty() { + "-" + } else { + policy_name + }; + let (action_id, disposition_id, severity) = match decision { + "deny" => (ActionId::Denied, DispositionId::Blocked, SeverityId::Medium), + "allow" | "audit" => ( + ActionId::Allowed, + DispositionId::Allowed, + SeverityId::Informational, + ), + _ => ( + ActionId::Other, + DispositionId::Other, + SeverityId::Informational, + ), + }; + let summary = graphql.map(graphql_log_summary).unwrap_or_default(); + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(action_id) + .disposition(disposition_id) + .severity(severity) + .status(StatusId::Success) + .dst_endpoint(Endpoint::from_domain(host, port)) + .firewall_rule(policy_name, "l7-websocket") + .message(format!( + "WEBSOCKET_L7_REQUEST {decision} {} {host}:{port}{}{} reason={reason}", + request_info.action, request_info.target, summary + )) + .build(); + ocsf_emit!(event); +} + +fn graphql_log_summary(info: &crate::l7::graphql::GraphqlRequestInfo) -> String { + if let Some(error) = info.error.as_deref() { + return format!(" graphql_error={error:?}"); + } + let ops: Vec = info + .operations + .iter() + .map(|op| { + let name = 
op.operation_name.as_deref().unwrap_or("-"); + let fields = if op.fields.is_empty() { + "-".to_string() + } else { + op.fields.join(",") + }; + let persisted = op + .persisted_query_hash + .as_deref() + .or(op.persisted_query_id.as_deref()) + .unwrap_or("-"); + format!( + "type={} name={} fields={} persisted={}", + op.operation_type, name, fields, persisted + ) + }) + .collect(); + format!(" graphql_ops={}", ops.join(";")) +} + +fn protocol_failure_class(error: &miette::Report) -> &'static str { + let msg = error.to_string().to_ascii_lowercase(); + if msg.contains("credential") { + "credential_resolution_failed" + } else if msg.contains("utf-8") { + "invalid_utf8" + } else if msg.contains("close frame") || msg.contains("after close") { + "invalid_close_frame" + } else if msg.contains("control frame") { + "invalid_control_frame" + } else if msg.contains("length") + || msg.contains("too large") + || msg.contains("exceeds") + || msg.contains("overflow") + { + "invalid_length" + } else if msg.contains("continuation") || msg.contains("fragmented") { + "invalid_fragmentation" + } else if msg.contains("reserved opcode") { + "reserved_opcode" + } else if msg.contains("not masked") { + "unmasked_client_frame" + } else if msg.contains("rsv") { + "rsv_bits" + } else if msg.contains("malformed") { + "malformed_frame" + } else { + "protocol_error" + } +} + +fn emit_protocol_failure(host: &str, port: u16, policy_name: &str, failure_class: &str) { + let policy_name = if policy_name.is_empty() { + "-" + } else { + policy_name + }; + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(host, port)) + .firewall_rule(policy_name, "l7-websocket") + .message(protocol_failure_message(host, port)) + .status_detail(failure_class) + .build(); + ocsf_emit!(event); +} + +fn 
protocol_failure_message(host: &str, port: u16) -> String { + format!("WEBSOCKET_CREDENTIAL_REWRITE closed ambiguous client frame [host:{host} port:{port}]") +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::l7::relay::L7EvalContext; + use crate::opa::{NetworkInput, OpaEngine}; + use crate::secrets::SecretResolver; + use std::path::PathBuf; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + const TEST_POLICY: &str = include_str!("../../data/sandbox-policy.rego"); + const GRAPHQL_WS_POLICY: &str = r#" +network_policies: + graphql_ws: + name: graphql_ws + endpoints: + - host: realtime.graphql.test + port: 443 + path: "/graphql" + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/graphql" + - allow: + operation_type: query + fields: [viewer] + - allow: + operation_type: subscription + fields: [messageAdded] + binaries: + - { path: /usr/bin/node } +"#; + + fn resolver() -> (HashMap, SecretResolver) { + let (child_env, resolver) = SecretResolver::from_provider_env( + std::iter::once(("DISCORD_BOT_TOKEN".to_string(), "real-token".to_string())).collect(), + ); + (child_env, resolver.expect("resolver")) + } + + fn masked_frame(fin: bool, opcode: u8, payload: &[u8]) -> Vec { + masked_frame_with_rsv(fin, opcode, 0, payload) + } + + fn masked_frame_with_rsv(fin: bool, opcode: u8, rsv: u8, payload: &[u8]) -> Vec { + let mask_key = [0x37, 0xfa, 0x21, 0x3d]; + let mut frame = Vec::new(); + frame.push((if fin { 0x80 } else { 0 }) | rsv | opcode); + match payload.len() { + 0..=125 => frame.push(0x80 | u8::try_from(payload.len()).expect("payload <= 125")), + 126..=65_535 => { + frame.push(0x80 | 0x7e); + frame.extend_from_slice( + &u16::try_from(payload.len()) + .expect("payload <= 65535") + .to_be_bytes(), + ); + } + _ => { + frame.push(0x80 | 127); + frame.extend_from_slice(&(payload.len() as u64).to_be_bytes()); + } + } + frame.extend_from_slice(&mask_key); + for (i, byte) in payload.iter().enumerate() { + frame.push(byte ^ 
mask_key[i % 4]); + } + frame + } + + fn unmasked_frame(opcode: u8, payload: &[u8]) -> Vec { + let mut frame = Vec::new(); + frame.push(0x80 | opcode); + frame.push(u8::try_from(payload.len()).expect("test payload fits in one byte")); + frame.extend_from_slice(payload); + frame + } + + fn masked_frame_with_declared_len(opcode: u8, declared_len: u64) -> Vec { + let mut frame = Vec::new(); + frame.push(0x80 | opcode); + frame.push(0x80 | 127); + frame.extend_from_slice(&declared_len.to_be_bytes()); + frame.extend_from_slice(&[0x37, 0xfa, 0x21, 0x3d]); + frame + } + + fn masked_frame_with_non_minimal_16_bit_len(opcode: u8, payload: &[u8]) -> Vec { + let mask_key = [0x37, 0xfa, 0x21, 0x3d]; + let mut frame = Vec::new(); + frame.push(0x80 | opcode); + frame.push(0x80 | 0x7e); + frame.extend_from_slice( + &u16::try_from(payload.len()) + .expect("test payload fits u16") + .to_be_bytes(), + ); + frame.extend_from_slice(&mask_key); + for (i, byte) in payload.iter().enumerate() { + frame.push(byte ^ mask_key[i % 4]); + } + frame + } + + fn close_payload(code: u16, reason: &[u8]) -> Vec { + let mut payload = Vec::with_capacity(2 + reason.len()); + payload.extend_from_slice(&code.to_be_bytes()); + payload.extend_from_slice(reason); + payload + } + + async fn run_client_to_server(input: Vec) -> Result> { + let (_, resolver) = resolver(); + let (mut client_write, mut relay_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); + let (mut relay_write, mut upstream_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); + + client_write.write_all(&input).await.unwrap(); + drop(client_write); + + let options = RelayOptions { + policy_name: "test-policy", + resolver: Some(&resolver), + inspector: None, + compression: WebSocketCompression::None, + }; + let result = relay_client_to_server( + &mut relay_read, + &mut relay_write, + "gateway.example.test", + 443, + &options, + ) + .await; + drop(relay_write); + + let mut output = Vec::new(); + upstream_read.read_to_end(&mut 
output).await.unwrap(); + result.map(|()| output) + } + + async fn run_client_to_server_with_graphql_policy( + input: Vec, + resolver: Option<&SecretResolver>, + ) -> Result> { + let engine = OpaEngine::from_strings(TEST_POLICY, GRAPHQL_WS_POLICY) + .expect("GraphQL WebSocket policy should load"); + let network_input = NetworkInput { + host: "realtime.graphql.test".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let generation = engine + .evaluate_network_action_with_generation(&network_input) + .expect("network action should evaluate") + .1; + let tunnel_engine = engine + .clone_engine_for_tunnel(generation) + .expect("tunnel engine"); + let ctx = L7EvalContext { + host: "realtime.graphql.test".into(), + port: 443, + policy_name: "graphql_ws".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + }; + let (mut client_write, mut relay_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); + let (mut relay_write, mut upstream_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); + + client_write.write_all(&input).await.unwrap(); + drop(client_write); + + let options = RelayOptions { + policy_name: "graphql_ws", + resolver, + inspector: Some(InspectionOptions { + engine: &tunnel_engine, + ctx: &ctx, + enforcement: EnforcementMode::Enforce, + target: "/graphql".to_string(), + query_params: HashMap::new(), + graphql_policy: true, + }), + compression: WebSocketCompression::None, + }; + let result = relay_client_to_server( + &mut relay_read, + &mut relay_write, + "realtime.graphql.test", + 443, + &options, + ) + .await; + drop(relay_write); + + let mut output = Vec::new(); + upstream_read.read_to_end(&mut output).await.unwrap(); + result.map(|()| output) + } + + async fn run_client_to_server_compressed(input: Vec) -> Result> { + let (_, resolver) = resolver(); + let (mut client_write, mut 
relay_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); + let (mut relay_write, mut upstream_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); + + client_write.write_all(&input).await.unwrap(); + drop(client_write); + + let options = RelayOptions { + policy_name: "test-policy", + resolver: Some(&resolver), + inspector: None, + compression: WebSocketCompression::PermessageDeflate, + }; + let result = relay_client_to_server( + &mut relay_read, + &mut relay_write, + "gateway.example.test", + 443, + &options, + ) + .await; + drop(relay_write); + + let mut output = Vec::new(); + upstream_read.read_to_end(&mut output).await.unwrap(); + result.map(|()| output) + } + + fn decode_masked_text_frame(frame: &[u8]) -> String { + assert_eq!(frame[0] & 0x0F, OPCODE_TEXT); + assert_ne!(frame[1] & 0x80, 0); + String::from_utf8(decode_masked_payload(frame)).unwrap() + } + + fn decode_masked_payload(frame: &[u8]) -> Vec { + assert_ne!(frame[1] & 0x80, 0); + let len_code = frame[1] & 0x7F; + let (payload_len, mask_offset) = match len_code { + 0..=125 => (usize::from(len_code), 2), + 126 => (usize::from(u16::from_be_bytes([frame[2], frame[3]])), 4), + 127 => { + let len = u64::from_be_bytes(frame[2..10].try_into().unwrap()); + (usize::try_from(len).unwrap(), 10) + } + _ => unreachable!(), + }; + let mask_key: [u8; 4] = frame[mask_offset..mask_offset + 4].try_into().unwrap(); + let mut payload = frame[mask_offset + 4..mask_offset + 4 + payload_len].to_vec(); + apply_mask(&mut payload, mask_key); + payload + } + + fn decode_compressed_masked_text_frame(frame: &[u8]) -> String { + assert_eq!(frame[0] & 0x0F, OPCODE_TEXT); + assert_eq!(frame[0] & 0x40, 0x40); + let payload = decode_masked_payload(frame); + String::from_utf8(decompress_permessage_deflate(&payload).unwrap()).unwrap() + } + + async fn read_one_frame(reader: &mut R) -> Vec { + let mut header = [0u8; 2]; + reader.read_exact(&mut header).await.unwrap(); + let len_code = header[1] & 0x7F; + let extended_len = 
match len_code { + 0..=125 => Vec::new(), + 126 => { + let mut bytes = vec![0u8; 2]; + reader.read_exact(&mut bytes).await.unwrap(); + bytes + } + 127 => { + let mut bytes = vec![0u8; 8]; + reader.read_exact(&mut bytes).await.unwrap(); + bytes + } + _ => unreachable!(), + }; + let payload_len = match len_code { + 0..=125 => usize::from(len_code), + 126 => usize::from(u16::from_be_bytes( + extended_len.as_slice().try_into().unwrap(), + )), + 127 => usize::try_from(u64::from_be_bytes( + extended_len.as_slice().try_into().unwrap(), + )) + .unwrap(), + _ => unreachable!(), + }; + let mask_len = if header[1] & 0x80 != 0 { 4 } else { 0 }; + let mut rest = vec![0u8; extended_len.len() + mask_len + payload_len]; + rest[..extended_len.len()].copy_from_slice(&extended_len); + reader + .read_exact(&mut rest[extended_len.len()..]) + .await + .unwrap(); + + let mut frame = header.to_vec(); + frame.extend_from_slice(&rest); + frame + } + + #[test] + fn classifies_graphql_transport_ws_subscribe_operation() { + let message = r#"{"type":"subscribe","id":"1","payload":{"query":"subscription NewMessages { messageAdded }"}}"#; + + match classify_graphql_websocket_message(message) { + GraphqlWebSocketMessage::Operation { + message_type, + graphql, + } => { + assert_eq!(message_type, "subscribe"); + assert!( + graphql.error.is_none(), + "unexpected error: {:?}", + graphql.error + ); + assert_eq!(graphql.operations.len(), 1); + assert_eq!(graphql.operations[0].operation_type, "subscription"); + assert_eq!( + graphql.operations[0].operation_name.as_deref(), + Some("NewMessages") + ); + assert_eq!(graphql.operations[0].fields, vec!["messageAdded"]); + } + other @ GraphqlWebSocketMessage::Control { .. 
} => { + panic!("expected operation, got {other:?}") + } + } + } + + #[test] + fn classifies_legacy_graphql_ws_start_operation() { + let message = r#"{"type":"start","id":"1","payload":{"query":"query Viewer { viewer }"}}"#; + + match classify_graphql_websocket_message(message) { + GraphqlWebSocketMessage::Operation { + message_type, + graphql, + } => { + assert_eq!(message_type, "start"); + assert!( + graphql.error.is_none(), + "unexpected error: {:?}", + graphql.error + ); + assert_eq!(graphql.operations[0].operation_type, "query"); + assert_eq!(graphql.operations[0].fields, vec!["viewer"]); + } + other @ GraphqlWebSocketMessage::Control { .. } => { + panic!("expected operation, got {other:?}") + } + } + } + + #[test] + fn classifies_graphql_websocket_control_message_without_payload_logging() { + match classify_graphql_websocket_message( + r#"{"type":"connection_init","payload":{"authorization":"secret"}}"#, + ) { + GraphqlWebSocketMessage::Control { message_type } => { + assert_eq!(message_type, "connection_init"); + } + other @ GraphqlWebSocketMessage::Operation { .. } => { + panic!("expected control message, got {other:?}") + } + } + } + + #[test] + fn unsupported_graphql_websocket_message_type_fails_closed() { + match classify_graphql_websocket_message(r#"{"type":"next","id":"1"}"#) { + GraphqlWebSocketMessage::Operation { graphql, .. } => { + assert!( + graphql + .error + .as_deref() + .is_some_and(|error| error.contains("unsupported")) + ); + } + other @ GraphqlWebSocketMessage::Control { .. 
} => { + panic!("expected operation error, got {other:?}") + } + } + } + + #[test] + fn graphql_websocket_log_summary_excludes_payload_variables_and_secrets() { + let placeholder = "openshell:resolve:env:T"; + let message = format!( + r#"{{"type":"subscribe","id":"1","payload":{{"query":"query Viewer {{ viewer }}","variables":{{"token":"{placeholder}"}}}}}}"# + ); + let graphql = match classify_graphql_websocket_message(&message) { + GraphqlWebSocketMessage::Operation { graphql, .. } => graphql, + other @ GraphqlWebSocketMessage::Control { .. } => { + panic!("expected operation, got {other:?}") + } + }; + let summary = graphql_log_summary(&graphql); + + assert!(summary.contains("type=query")); + assert!(summary.contains("fields=viewer")); + assert!(!summary.contains(placeholder)); + assert!(!summary.contains("real-token")); + assert!(!summary.contains("variables")); + assert!(!summary.contains("token")); + assert!(!summary.contains("secret_len")); + } + + #[tokio::test] + async fn rewrites_discord_like_identify_text_payload() { + let (child_env, _) = resolver(); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); + + let output = run_client_to_server(masked_frame(true, OPCODE_TEXT, payload.as_bytes())) + .await + .expect("relay should succeed"); + + assert_eq!( + decode_masked_text_frame(&output), + r#"{"op":2,"d":{"token":"real-token"}}"# + ); + } + + #[tokio::test] + async fn upgraded_relay_rewrites_client_text_before_upstream_receives_it() { + let (child_env, resolver) = resolver(); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); + let client_frame = masked_frame(true, OPCODE_TEXT, payload.as_bytes()); + assert!( + !String::from_utf8_lossy(&client_frame).contains("real-token"), + "client-side fixture must not contain the real token" + ); + + let (mut client_app, mut relay_client) = 
tokio::io::duplex(4096); + let (mut relay_upstream, mut upstream_app) = tokio::io::duplex(4096); + let relay = tokio::spawn(async move { + relay_with_options( + &mut relay_client, + &mut relay_upstream, + Vec::new(), + "gateway.example.test", + 443, + RelayOptions { + policy_name: "test-policy", + resolver: Some(&resolver), + inspector: None, + compression: WebSocketCompression::None, + }, + ) + .await + }); + + client_app.write_all(&client_frame).await.unwrap(); + client_app.flush().await.unwrap(); + + let upstream_frame = tokio::time::timeout( + std::time::Duration::from_secs(2), + read_one_frame(&mut upstream_app), + ) + .await + .expect("upstream should receive rewritten frame"); + assert_eq!( + decode_masked_text_frame(&upstream_frame), + r#"{"op":2,"d":{"token":"real-token"}}"# + ); + + drop(client_app); + drop(upstream_app); + let _ = tokio::time::timeout(std::time::Duration::from_secs(2), relay).await; + } + + #[tokio::test] + async fn graphql_websocket_policy_allows_subscription_operation() { + let payload = r#"{"type":"subscribe","id":"1","payload":{"query":"subscription NewMessages { messageAdded }"}}"#; + let frame = masked_frame(true, OPCODE_TEXT, payload.as_bytes()); + + let output = run_client_to_server_with_graphql_policy(frame.clone(), None) + .await + .expect("allowed subscription should relay"); + + assert_eq!(output, frame); + assert_eq!(decode_masked_text_frame(&output), payload); + } + + #[tokio::test] + async fn graphql_websocket_policy_denies_unlisted_operation_field() { + let payload = + r#"{"type":"subscribe","id":"1","payload":{"query":"query Admin { adminAuditLog }"}}"#; + let frame = masked_frame(true, OPCODE_TEXT, payload.as_bytes()); + + let err = run_client_to_server_with_graphql_policy(frame, None) + .await + .expect_err("unlisted field should be denied"); + + assert!(err.to_string().contains("websocket GraphQL message denied")); + } + + #[tokio::test] + async fn graphql_websocket_control_message_rewrites_credentials_before_relay() 
{ + let (child_env, resolver) = SecretResolver::from_provider_env( + std::iter::once(("T".to_string(), "real-token".to_string())).collect(), + ); + let resolver = resolver.expect("resolver"); + let placeholder = child_env.get("T").expect("placeholder env"); + let payload = format!( + r#"{{"type":"connection_init","payload":{{"authorization":"{placeholder}"}}}}"# + ); + let frame = masked_frame(true, OPCODE_TEXT, payload.as_bytes()); + + let output = run_client_to_server_with_graphql_policy(frame, Some(&resolver)) + .await + .expect("control message should relay after credential rewrite"); + + let rewritten = decode_masked_text_frame(&output); + assert_eq!( + rewritten, + r#"{"type":"connection_init","payload":{"authorization":"real-token"}}"# + ); + assert!(!rewritten.contains(placeholder)); + } + + #[tokio::test] + async fn text_without_placeholder_passes_semantically_unchanged() { + let frame = masked_frame(true, OPCODE_TEXT, br#"{"op":1,"d":42}"#); + let output = run_client_to_server(frame.clone()) + .await + .expect("relay should succeed"); + + assert_eq!(output, frame); + assert_eq!(decode_masked_text_frame(&output), r#"{"op":1,"d":42}"#); + } + + #[tokio::test] + async fn unknown_placeholder_fails_closed() { + let frame = masked_frame( + true, + OPCODE_TEXT, + br#"{"token":"openshell:resolve:env:UNKNOWN"}"#, + ); + + let err = run_client_to_server(frame) + .await + .expect_err("unknown placeholder should fail"); + + assert!( + err.to_string() + .contains("credential placeholder resolution") + ); + } + + #[tokio::test] + async fn fragmented_text_rewrites_after_final_continuation() { + let (child_env, _) = resolver(); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let first = format!(r#"{{"token":"{placeholder}"#); + let second = r#""}"#; + let mut input = masked_frame(false, OPCODE_TEXT, first.as_bytes()); + input.extend(masked_frame(true, OPCODE_CONTINUATION, second.as_bytes())); + + let output = run_client_to_server(input) + .await + 
.expect("relay should succeed"); + + assert_eq!( + decode_masked_text_frame(&output), + r#"{"token":"real-token"}"# + ); + } + + #[tokio::test] + async fn rejects_rsv_bits() { + let mut frame = masked_frame(true, OPCODE_TEXT, b"hello"); + frame[0] |= 0x40; + + let err = run_client_to_server(frame) + .await + .expect_err("RSV frame should fail"); + + assert!(err.to_string().contains("RSV bits")); + } + + #[tokio::test] + async fn rejects_unmasked_client_frame() { + let err = run_client_to_server(unmasked_frame(OPCODE_TEXT, b"hello")) + .await + .expect_err("unmasked frame should fail"); + + assert!(err.to_string().contains("not masked")); + } + + #[tokio::test] + async fn rejects_invalid_utf8_text() { + let err = run_client_to_server(masked_frame(true, OPCODE_TEXT, &[0xff])) + .await + .expect_err("invalid UTF-8 should fail"); + + assert!(err.to_string().contains("valid UTF-8")); + } + + #[tokio::test] + async fn rejects_oversize_text_message() { + let payload = vec![b'a'; MAX_TEXT_MESSAGE_BYTES + 1]; + let err = run_client_to_server(masked_frame(true, OPCODE_TEXT, &payload)) + .await + .expect_err("oversize text should fail"); + + assert!(err.to_string().contains("exceeds")); + } + + #[tokio::test] + async fn fragmented_text_allows_interleaved_ping_pong_and_rewrites_at_completion() { + let (child_env, _) = resolver(); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let first = format!(r#"{{"token":"{placeholder}"#); + let first_control_frame = masked_frame(true, OPCODE_PING, b"p"); + let second_control_frame = masked_frame(true, OPCODE_PONG, b"q"); + let mut input = masked_frame(false, OPCODE_TEXT, first.as_bytes()); + input.extend_from_slice(&first_control_frame); + input.extend_from_slice(&second_control_frame); + input.extend(masked_frame(true, OPCODE_CONTINUATION, br#""}"#)); + + let output = run_client_to_server(input) + .await + .expect("relay should allow interleaved control frames"); + + assert!(output.starts_with(&first_control_frame)); 
+ assert_eq!( + &output + [first_control_frame.len()..first_control_frame.len() + second_control_frame.len()], + second_control_frame.as_slice() + ); + assert_eq!( + decode_masked_text_frame( + &output[first_control_frame.len() + second_control_frame.len()..] + ), + r#"{"token":"real-token"}"# + ); + } + + #[tokio::test] + async fn compressed_text_rewrites_with_permessage_deflate() { + let (child_env, _) = resolver(); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let payload = format!(r#"{{"token":"{placeholder}"}}"#); + let compressed = compress_permessage_deflate(payload.as_bytes()).unwrap(); + let input = masked_frame_with_rsv(true, OPCODE_TEXT, 0x40, &compressed); + + let output = run_client_to_server_compressed(input) + .await + .expect("compressed text should relay"); + + assert_eq!( + decode_compressed_masked_text_frame(&output), + r#"{"token":"real-token"}"# + ); + } + + #[tokio::test] + async fn compressed_text_rejects_decompressed_oversize_message() { + let payload = vec![b'a'; MAX_TEXT_MESSAGE_BYTES + 1]; + let compressed = compress_permessage_deflate(&payload).unwrap(); + let input = masked_frame_with_rsv(true, OPCODE_TEXT, 0x40, &compressed); + + let err = run_client_to_server_compressed(input) + .await + .expect_err("oversize decompressed text should fail"); + + assert!(err.to_string().contains("exceeds")); + } + + #[tokio::test] + async fn binary_frame_passes_through_unchanged() { + let frame = masked_frame(true, OPCODE_BINARY, &[0, 1, 2, 3, 255]); + + let output = run_client_to_server(frame.clone()) + .await + .expect("binary frame should pass through"); + + assert_eq!(output, frame); + } + + #[tokio::test] + async fn rejects_reserved_opcode() { + let err = run_client_to_server(masked_frame(true, 0x3, b"reserved")) + .await + .expect_err("reserved opcode should fail"); + + assert!(err.to_string().contains("reserved opcode")); + } + + #[tokio::test] + async fn rejects_continuation_without_active_message() { + let err = 
run_client_to_server(masked_frame(true, OPCODE_CONTINUATION, b"orphan")) + .await + .expect_err("orphan continuation should fail"); + + assert!(err.to_string().contains("continuation")); + } + + #[tokio::test] + async fn rejects_new_data_frame_before_fragment_completion() { + let mut input = masked_frame(false, OPCODE_TEXT, b"partial"); + input.extend(masked_frame(true, OPCODE_TEXT, b"second")); + + let err = run_client_to_server(input) + .await + .expect_err("new data frame during fragmentation should fail"); + + assert!(err.to_string().contains("previous fragmented message")); + } + + #[tokio::test] + async fn rejects_fragmented_control_frame() { + let err = run_client_to_server(masked_frame(false, OPCODE_PING, b"ping")) + .await + .expect_err("fragmented control frame should fail"); + + assert!(err.to_string().contains("control frame is fragmented")); + } + + #[tokio::test] + async fn rejects_control_frame_over_125_bytes() { + let payload = vec![b'a'; 126]; + let err = run_client_to_server(masked_frame(true, OPCODE_PING, &payload)) + .await + .expect_err("oversize control frame should fail"); + + assert!(err.to_string().contains("control frame exceeds")); + } + + #[tokio::test] + async fn rejects_non_minimal_extended_length() { + let err = run_client_to_server(masked_frame_with_non_minimal_16_bit_len( + OPCODE_TEXT, + b"hello", + )) + .await + .expect_err("non-minimal length should fail"); + + assert!(err.to_string().contains("non-minimal")); + } + + #[tokio::test] + async fn rejects_oversize_binary_frame_before_payload_buffering() { + let err = run_client_to_server(masked_frame_with_declared_len( + OPCODE_BINARY, + MAX_RAW_FRAME_PAYLOAD_BYTES + 1, + )) + .await + .expect_err("oversize binary frame should fail"); + + assert!(err.to_string().contains("binary frame exceeds")); + } + + #[tokio::test] + async fn validates_close_frame_payloads() { + let frame = masked_frame(true, OPCODE_CLOSE, &close_payload(1000, b"done")); + + let output = 
run_client_to_server(frame.clone()) + .await + .expect("valid close frame should pass through"); + + assert_eq!(output, frame); + } + + #[tokio::test] + async fn rejects_close_frame_with_one_byte_payload() { + let err = run_client_to_server(masked_frame(true, OPCODE_CLOSE, &[0x03])) + .await + .expect_err("one-byte close frame should fail"); + + assert!(err.to_string().contains("exactly one byte")); + } + + #[tokio::test] + async fn rejects_reserved_close_code() { + let err = run_client_to_server(masked_frame(true, OPCODE_CLOSE, &close_payload(1005, b""))) + .await + .expect_err("reserved close code should fail"); + + assert!(err.to_string().contains("invalid close code")); + } + + #[tokio::test] + async fn rejects_close_reason_with_invalid_utf8() { + let err = run_client_to_server(masked_frame( + true, + OPCODE_CLOSE, + &close_payload(1000, &[0xff]), + )) + .await + .expect_err("invalid close reason should fail"); + + assert!(err.to_string().contains("valid UTF-8")); + } + + #[tokio::test] + async fn rejects_frames_after_client_close_frame() { + let mut input = masked_frame(true, OPCODE_CLOSE, &close_payload(1000, b"done")); + input.extend(masked_frame(true, OPCODE_TEXT, b"late")); + + let err = run_client_to_server(input) + .await + .expect_err("frames after close should fail"); + + assert!(err.to_string().contains("after close")); + } + + #[test] + fn websocket_ocsf_messages_do_not_include_payload_or_secret_material() { + let placeholder = "openshell:resolve:env:DISCORD_BOT_TOKEN"; + let secret = "real-token"; + let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); + + let rewrite = rewrite_event_message("gateway.example.test", 443, 1); + let failure = protocol_failure_message("gateway.example.test", 443); + let messages = [rewrite, failure]; + + for message in messages { + assert!(!message.contains(placeholder)); + assert!(!message.contains(secret)); + assert!(!message.contains(&payload)); + assert!(!message.contains("secret_len")); + 
assert!(!message.contains("payload_len")); + } + } +} diff --git a/crates/openshell-sandbox/src/opa.rs b/crates/openshell-sandbox/src/opa.rs index 5897679a0..a9ab94a2b 100644 --- a/crates/openshell-sandbox/src/opa.rs +++ b/crates/openshell-sandbox/src/opa.rs @@ -1061,6 +1061,12 @@ fn proto_to_opa_data_json(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> St if e.allow_encoded_slash { ep["allow_encoded_slash"] = true.into(); } + if e.websocket_credential_rewrite { + ep["websocket_credential_rewrite"] = true.into(); + } + if e.request_body_credential_rewrite { + ep["request_body_credential_rewrite"] = true.into(); + } if !e.persisted_queries.is_empty() { ep["persisted_queries"] = e.persisted_queries.clone().into(); } @@ -1811,6 +1817,28 @@ network_policies: access: read-only binaries: - { path: /usr/bin/curl } + graphql_ws: + name: graphql_ws + endpoints: + - host: realtime.graphql.com + ports: [443] + path: "/graphql" + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/graphql" + - allow: + operation_type: query + fields: [viewer] + - allow: + operation_type: subscription + fields: [messageAdded] + deny_rules: + - operation_type: mutation + binaries: + - { path: /usr/bin/curl } l4_only: name: l4_only endpoints: @@ -1897,6 +1925,25 @@ process: }) } + fn l7_websocket_graphql_input(host: &str, operations: serde_json::Value) -> serde_json::Value { + serde_json::json!({ + "network": { "host": host, "port": 443 }, + "exec": { + "path": "/usr/bin/curl", + "ancestors": [], + "cmdline_paths": [] + }, + "request": { + "method": "WEBSOCKET_TEXT", + "path": "/graphql", + "query_params": {}, + "graphql": { + "operations": operations + } + } + }) + } + fn eval_l7(engine: &OpaEngine, input: &serde_json::Value) -> bool { let mut eng = engine.engine.lock().unwrap(); eng.set_input_json(&input.to_string()).unwrap(); @@ -2134,6 +2181,97 @@ process: assert!(!eval_l7(&engine, &mutation)); } + #[test] + fn 
l7_websocket_graphql_subscription_allowed_by_field_rule() { + let engine = l7_engine(); + let input = l7_websocket_graphql_input( + "realtime.graphql.com", + serde_json::json!([{ + "operation_type": "subscription", + "operation_name": "NewMessages", + "fields": ["messageAdded"], + "persisted_query": false + }]), + ); + assert!(eval_l7(&engine, &input)); + } + + #[test] + fn l7_websocket_graphql_unlisted_field_denied() { + let engine = l7_engine(); + let input = l7_websocket_graphql_input( + "realtime.graphql.com", + serde_json::json!([{ + "operation_type": "query", + "fields": ["adminAuditLog"], + "persisted_query": false + }]), + ); + assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn l7_websocket_graphql_deny_rule_takes_precedence() { + let engine = l7_engine(); + let input = l7_websocket_graphql_input( + "realtime.graphql.com", + serde_json::json!([{ + "operation_type": "mutation", + "operation_name": "DeleteRepo", + "fields": ["deleteRepository"], + "persisted_query": false + }]), + ); + assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn l7_websocket_graphql_not_bypassed_by_generic_text_rule() { + let data = r#" +network_policies: + graphql_ws: + name: graphql_ws + endpoints: + - host: realtime.graphql.com + ports: [443] + path: "/graphql" + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/graphql" + - allow: + method: WEBSOCKET_TEXT + path: "/graphql" + - allow: + operation_type: query + fields: [viewer] + binaries: + - { path: /usr/bin/curl } +"#; + let data_json: serde_json::Value = + serde_yml::from_str(data).expect("fixture should parse as YAML"); + let mut rego = regorus::Engine::new(); + rego.add_policy("policy.rego".into(), TEST_POLICY.into()) + .expect("policy should load"); + rego.add_data_json(&data_json.to_string()) + .expect("data should load"); + let engine = OpaEngine { + engine: Mutex::new(rego), + generation: Arc::new(AtomicU64::new(0)), + }; + let input = l7_websocket_graphql_input( + 
"realtime.graphql.com", + serde_json::json!([{ + "operation_type": "query", + "fields": ["adminAuditLog"], + "persisted_query": false + }]), + ); + assert!(!eval_l7(&engine, &input)); + } + #[test] fn l7_endpoint_path_scopes_rest_and_graphql_on_same_host() { let data = r#" @@ -2463,6 +2601,120 @@ network_policies: assert!(l7.allow_encoded_slash); } + #[test] + fn l7_endpoint_config_preserves_proto_websocket_credential_rewrite() { + let mut network_policies = std::collections::HashMap::new(); + network_policies.insert( + "gateway".to_string(), + NetworkPolicyRule { + name: "gateway".to_string(), + endpoints: vec![NetworkEndpoint { + host: "gateway.example.com".to_string(), + port: 443, + protocol: "rest".to_string(), + enforcement: "enforce".to_string(), + access: "full".to_string(), + websocket_credential_rewrite: true, + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/node".to_string(), + ..Default::default() + }], + }, + ); + let proto = ProtoSandboxPolicy { + version: 1, + filesystem: Some(ProtoFs { + include_workdir: true, + read_only: vec![], + read_write: vec![], + }), + landlock: Some(openshell_core::proto::LandlockPolicy { + compatibility: "best_effort".to_string(), + }), + process: Some(ProtoProc { + run_as_user: "sandbox".to_string(), + run_as_group: "sandbox".to_string(), + }), + network_policies, + }; + + let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); + let input = NetworkInput { + host: "gateway.example.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + + let config = engine + .query_endpoint_config(&input) + .unwrap() + .expect("endpoint config"); + let l7 = crate::l7::parse_l7_config(&config).unwrap(); + assert!(l7.websocket_credential_rewrite); + } + + #[test] + fn l7_endpoint_config_preserves_proto_request_body_credential_rewrite() { + let mut network_policies = 
std::collections::HashMap::new(); + network_policies.insert( + "slack".to_string(), + NetworkPolicyRule { + name: "slack".to_string(), + endpoints: vec![NetworkEndpoint { + host: "slack.com".to_string(), + port: 443, + protocol: "rest".to_string(), + enforcement: "enforce".to_string(), + access: "read-write".to_string(), + request_body_credential_rewrite: true, + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/node".to_string(), + ..Default::default() + }], + }, + ); + let proto = ProtoSandboxPolicy { + version: 1, + filesystem: Some(ProtoFs { + include_workdir: true, + read_only: vec![], + read_write: vec![], + }), + landlock: Some(openshell_core::proto::LandlockPolicy { + compatibility: "best_effort".to_string(), + }), + process: Some(ProtoProc { + run_as_user: "sandbox".to_string(), + run_as_group: "sandbox".to_string(), + }), + network_policies, + }; + + let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); + let input = NetworkInput { + host: "slack.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + + let config = engine + .query_endpoint_config(&input) + .unwrap() + .expect("endpoint config"); + let l7 = crate::l7::parse_l7_config(&config).unwrap(); + assert!(l7.request_body_credential_rewrite); + } + #[test] fn l7_endpoint_config_none_for_l4_only() { let engine = l7_engine(); diff --git a/crates/openshell-sandbox/src/policy_local.rs b/crates/openshell-sandbox/src/policy_local.rs index 21556ec6a..165b0c1bd 100644 --- a/crates/openshell-sandbox/src/policy_local.rs +++ b/crates/openshell-sandbox/src/policy_local.rs @@ -619,6 +619,8 @@ fn network_endpoint_from_json( ports, deny_rules, allow_encoded_slash: endpoint.allow_encoded_slash, + websocket_credential_rewrite: false, + request_body_credential_rewrite: false, // GraphQL persisted-query knobs and path scoping default empty — // agent proposals don't 
author them today. persisted_queries: String::new(), diff --git a/crates/openshell-sandbox/src/provider_credentials.rs b/crates/openshell-sandbox/src/provider_credentials.rs index bd28824ae..ffe0148a4 100644 --- a/crates/openshell-sandbox/src/provider_credentials.rs +++ b/crates/openshell-sandbox/src/provider_credentials.rs @@ -122,6 +122,10 @@ mod tests { resolver.resolve_placeholder("openshell:resolve:env:v11_GITHUB_TOKEN"), Some("new") ); + assert_eq!( + resolver.resolve_placeholder("openshell:resolve:env:GITHUB_TOKEN"), + Some("new") + ); } #[test] diff --git a/crates/openshell-sandbox/src/proxy.rs b/crates/openshell-sandbox/src/proxy.rs index f20e51655..dca522c12 100644 --- a/crates/openshell-sandbox/src/proxy.rs +++ b/crates/openshell-sandbox/src/proxy.rs @@ -10,7 +10,7 @@ use crate::opa::{NetworkAction, OpaEngine, PolicyGenerationGuard}; use crate::policy::ProxyPolicy; use crate::policy_local::{POLICY_LOCAL_HOST, PolicyLocalContext}; use crate::provider_credentials::ProviderCredentialState; -use crate::secrets::{SecretResolver, rewrite_header_line}; +use crate::secrets::{SecretResolver, rewrite_header_line_checked}; use miette::{IntoDiagnostic, Result}; use openshell_core::net::{is_always_blocked_ip, is_internal_ip}; use openshell_ocsf::{ @@ -2282,6 +2282,11 @@ fn rewrite_forward_request( .windows(4) .position(|w| w == b"\r\n\r\n") .map_or(used, |p| p + 4); + let websocket_upgrade = crate::l7::rest::request_is_websocket_upgrade(&raw[..header_end]); + let upstream_path = match secret_resolver { + Some(resolver) => crate::secrets::rewrite_target_for_eval(path, resolver)?.resolved, + None => path.to_string(), + }; let header_str = String::from_utf8_lossy(&raw[..header_end]); let lines = header_str.split("\r\n").collect::>(); @@ -2298,7 +2303,7 @@ fn rewrite_forward_request( if parts.len() == 3 { output.extend_from_slice(parts[0].as_bytes()); output.push(b' '); - output.extend_from_slice(path.as_bytes()); + output.extend_from_slice(upstream_path.as_bytes()); 
output.push(b' '); output.extend_from_slice(parts[2].as_bytes()); } else { @@ -2325,14 +2330,19 @@ fn rewrite_forward_request( // Replace Connection header if lower.starts_with("connection:") { has_connection = true; + if websocket_upgrade { + output.extend_from_slice(line.as_bytes()); + output.extend_from_slice(b"\r\n"); + continue; + } output.extend_from_slice(b"Connection: close\r\n"); continue; } - let rewritten_line = secret_resolver.map_or_else( - || line.to_string(), - |resolver| rewrite_header_line(line, resolver), - ); + let rewritten_line = match secret_resolver { + Some(resolver) => rewrite_header_line_checked(line, resolver)?, + None => line.to_string(), + }; output.extend_from_slice(rewritten_line.as_bytes()); output.extend_from_slice(b"\r\n"); @@ -2343,7 +2353,7 @@ fn rewrite_forward_request( } // Inject missing headers - if !has_connection { + if !has_connection && !websocket_upgrade { output.extend_from_slice(b"Connection: close\r\n"); } if !has_via { @@ -2361,7 +2371,9 @@ fn rewrite_forward_request( // Fail-closed: scan for any remaining unresolved placeholders if secret_resolver.is_some() { let output_str = String::from_utf8_lossy(&output); - if output_str.contains(crate::secrets::PLACEHOLDER_PREFIX_PUBLIC) { + if output_str.contains(crate::secrets::PLACEHOLDER_PREFIX_PUBLIC) + || output_str.contains(crate::secrets::PROVIDER_ALIAS_MARKER_PUBLIC) + { return Err(crate::secrets::UnresolvedPlaceholderError { location: "header" }); } } @@ -2369,13 +2381,20 @@ fn rewrite_forward_request( Ok(output) } +struct ForwardRelayOptions<'a> { + generation_guard: &'a PolicyGenerationGuard, + websocket_extensions: crate::l7::rest::WebSocketExtensionMode, + secret_resolver: Option<&'a SecretResolver>, + request_body_credential_rewrite: bool, +} + async fn relay_rewritten_forward_request( method: &str, path: &str, rewritten: Vec, client: &mut C, upstream: &mut U, - generation_guard: &PolicyGenerationGuard, + options: ForwardRelayOptions<'_>, ) -> Result where C: 
TokioAsyncRead + TokioAsyncWrite + Unpin, @@ -2396,16 +2415,58 @@ where body_length, }; - crate::l7::rest::relay_http_request_with_resolver_guarded( + crate::l7::rest::relay_http_request_with_options_guarded( &req, client, upstream, - None, - Some(generation_guard), + crate::l7::rest::RelayRequestOptions { + resolver: options.secret_resolver, + generation_guard: Some(options.generation_guard), + websocket_extensions: options.websocket_extensions, + request_body_credential_rewrite: options.request_body_credential_rewrite, + }, ) .await } +fn forward_websocket_upgrade_settings( + config: &crate::l7::L7EndpointConfig, + websocket_request: bool, + secret_resolver: Option>, +) -> ( + crate::l7::rest::WebSocketExtensionMode, + crate::l7::relay::UpgradeRelayOptions<'static>, +) { + let websocket_credential_rewrite = matches!( + config.protocol, + crate::l7::L7Protocol::Rest | crate::l7::L7Protocol::Websocket + ) && config.websocket_credential_rewrite; + let websocket_extensions = if config.protocol == crate::l7::L7Protocol::Websocket + || (config.protocol == crate::l7::L7Protocol::Rest && websocket_credential_rewrite) + { + crate::l7::rest::WebSocketExtensionMode::PermessageDeflate + } else { + crate::l7::rest::WebSocketExtensionMode::Preserve + }; + + let upgrade_options = crate::l7::relay::UpgradeRelayOptions { + websocket_request, + websocket: crate::l7::relay::WebSocketUpgradeBehavior { + credential_rewrite: websocket_credential_rewrite, + ..Default::default() + }, + secret_resolver: if websocket_credential_rewrite { + secret_resolver + } else { + None + }, + enforcement: config.enforcement, + ..Default::default() + }; + + (websocket_extensions, upgrade_options) +} + /// Handle a plain HTTP forward proxy request (non-CONNECT). /// /// Public IPs are allowed through when the endpoint passes OPA evaluation. 
@@ -2623,6 +2684,9 @@ async fn handle_forward_proxy( }; let mut forward_request_bytes = buf[..used].to_vec(); let mut upstream_target = path.clone(); + let mut websocket_extensions = crate::l7::rest::WebSocketExtensionMode::Preserve; + let mut upgrade_options = crate::l7::relay::UpgradeRelayOptions::default(); + let mut request_body_credential_rewrite = false; // 4b. If the endpoint has L7 config, evaluate the request against // L7 policy. The forward proxy handles exactly one request per @@ -2760,6 +2824,16 @@ async fn handle_forward_proxy( .await?; return Ok(()); }; + let websocket_request = + crate::l7::rest::request_is_websocket_upgrade(&forward_request_bytes); + (websocket_extensions, upgrade_options) = forward_websocket_upgrade_settings( + &l7_config.config, + websocket_request, + secret_resolver.clone(), + ); + request_body_credential_rewrite = l7_config.config.protocol == crate::l7::L7Protocol::Rest + && l7_config.config.request_body_credential_rewrite; + upgrade_options.policy_name = matched_policy.clone().unwrap_or_default(); let graphql = if l7_config.config.protocol == crate::l7::L7Protocol::Graphql { let header_end = forward_request_bytes .windows(4) @@ -3222,11 +3296,29 @@ async fn handle_forward_proxy( rewritten, client, &mut upstream, - &forward_generation_guard, + ForwardRelayOptions { + generation_guard: &forward_generation_guard, + websocket_extensions, + secret_resolver: secret_resolver.as_deref(), + request_body_credential_rewrite, + }, ) .await?; - if let crate::l7::provider::RelayOutcome::Upgraded { overflow } = outcome { - crate::l7::relay::handle_upgrade(client, &mut upstream, overflow, &host_lc, port).await?; + if let crate::l7::provider::RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + } = outcome + { + upgrade_options.websocket.permessage_deflate = websocket_permessage_deflate; + crate::l7::relay::handle_upgrade( + client, + &mut upstream, + overflow, + &host_lc, + port, + upgrade_options, + ) + .await?; } Ok(()) @@ 
-3298,6 +3390,62 @@ fn is_benign_relay_error(err: &miette::Report) -> bool { mod tests { use super::*; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; + use std::sync::Arc; + + fn websocket_l7_config( + protocol: crate::l7::L7Protocol, + websocket_credential_rewrite: bool, + ) -> crate::l7::L7EndpointConfig { + crate::l7::L7EndpointConfig { + protocol, + path: "/**".to_string(), + tls: crate::l7::TlsMode::Auto, + enforcement: crate::l7::EnforcementMode::Enforce, + graphql_max_body_bytes: crate::l7::graphql::DEFAULT_MAX_BODY_BYTES, + allow_encoded_slash: false, + websocket_credential_rewrite, + request_body_credential_rewrite: false, + websocket_graphql_policy: false, + } + } + + #[test] + fn forward_websocket_upgrade_enables_rewrite_for_native_websocket_endpoint() { + let (_, resolver) = SecretResolver::from_provider_env( + [("DISCORD_BOT_TOKEN".to_string(), "discord-real".to_string())] + .into_iter() + .collect(), + ); + + let (extensions, options) = forward_websocket_upgrade_settings( + &websocket_l7_config(crate::l7::L7Protocol::Websocket, true), + true, + resolver.map(Arc::new), + ); + + assert_eq!( + extensions, + crate::l7::rest::WebSocketExtensionMode::PermessageDeflate + ); + assert!(options.websocket.credential_rewrite); + assert!(options.secret_resolver.is_some()); + } + + #[test] + fn forward_websocket_upgrade_preserves_rest_without_rewrite() { + let (extensions, options) = forward_websocket_upgrade_settings( + &websocket_l7_config(crate::l7::L7Protocol::Rest, false), + true, + None, + ); + + assert_eq!( + extensions, + crate::l7::rest::WebSocketExtensionMode::Preserve + ); + assert!(!options.websocket.credential_rewrite); + assert!(options.secret_resolver.is_none()); + } #[test] fn l7_route_selection_prefers_path_specific_graphql_endpoint() { @@ -3310,6 +3458,9 @@ mod tests { enforcement: crate::l7::EnforcementMode::Enforce, graphql_max_body_bytes: crate::l7::graphql::DEFAULT_MAX_BODY_BYTES, allow_encoded_slash: false, + 
websocket_credential_rewrite: false, + request_body_credential_rewrite: false, + websocket_graphql_policy: false, }, }, L7ConfigSnapshot { @@ -3320,6 +3471,9 @@ mod tests { enforcement: crate::l7::EnforcementMode::Enforce, graphql_max_body_bytes: crate::l7::graphql::DEFAULT_MAX_BODY_BYTES, allow_encoded_slash: false, + websocket_credential_rewrite: false, + request_body_credential_rewrite: false, + websocket_graphql_policy: false, }, }, ]; @@ -4364,6 +4518,28 @@ mod tests { assert!(!result_str.contains("openshell:resolve:env:ANTHROPIC_API_KEY")); } + #[test] + fn test_forward_rewrite_preserves_websocket_upgrade_connection_header() { + let raw = "GET http://gateway.example.test/ws HTTP/1.1\r\n\ + Host: gateway.example.test\r\n\ + Upgrade: websocket\r\n\ + Connection: keep-alive, Upgrade\r\n\ + Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ + Sec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover\r\n\ + Sec-WebSocket-Version: 13\r\n\r\n"; + + let result = rewrite_forward_request(raw.as_bytes(), raw.len(), "/ws", None) + .expect("websocket forward rewrite should succeed"); + let result_str = String::from_utf8_lossy(&result); + + assert!(result_str.starts_with("GET /ws HTTP/1.1\r\n")); + assert!(result_str.contains("Connection: keep-alive, Upgrade\r\n")); + assert!( + !result_str.contains("Connection: close\r\n"), + "websocket forward proxy must not strip the upgrade token" + ); + } + #[tokio::test] async fn test_forward_relay_guard_blocks_stale_generation_before_upstream_write() { let policy = include_str!("../data/sandbox-policy.rego"); @@ -4386,7 +4562,12 @@ mod tests { rewritten, &mut proxy_to_client, &mut proxy_to_upstream, - &guard, + ForwardRelayOptions { + generation_guard: &guard, + websocket_extensions: crate::l7::rest::WebSocketExtensionMode::Preserve, + secret_resolver: None, + request_body_credential_rewrite: false, + }, ) .await; assert!( @@ -4424,7 +4605,12 @@ mod tests { rewritten, &mut proxy_to_client, &mut proxy_to_upstream, - 
&guard, + ForwardRelayOptions { + generation_guard: &guard, + websocket_extensions: crate::l7::rest::WebSocketExtensionMode::Preserve, + secret_resolver: None, + request_body_credential_rewrite: false, + }, ) .await; assert!(result.is_err(), "forward relay must reject CL/TE ambiguity"); diff --git a/crates/openshell-sandbox/src/secrets.rs b/crates/openshell-sandbox/src/secrets.rs index d645e1482..54c43d07a 100644 --- a/crates/openshell-sandbox/src/secrets.rs +++ b/crates/openshell-sandbox/src/secrets.rs @@ -6,9 +6,11 @@ use std::collections::HashMap; use std::fmt; const PLACEHOLDER_PREFIX: &str = "openshell:resolve:env:"; +const PROVIDER_ALIAS_MARKER: &str = "OPENSHELL-RESOLVE-ENV-"; /// Public access to the placeholder prefix for fail-closed scanning in other modules. pub const PLACEHOLDER_PREFIX_PUBLIC: &str = PLACEHOLDER_PREFIX; +pub const PROVIDER_ALIAS_MARKER_PUBLIC: &str = PROVIDER_ALIAS_MARKER; /// Characters that are valid in an env var key name (used to extract /// placeholder boundaries within concatenated strings like path segments). @@ -16,6 +18,22 @@ fn is_env_key_char(b: u8) -> bool { b.is_ascii_alphanumeric() || b == b'_' } +fn is_alias_token_char(b: u8) -> bool { + b.is_ascii_alphanumeric() || matches!(b, b'_' | b'-' | b'.' 
| b'~') +} + +fn contains_raw_reserved_marker(value: &str) -> bool { + value.contains(PLACEHOLDER_PREFIX) || value.contains(PROVIDER_ALIAS_MARKER) +} + +pub fn contains_reserved_credential_marker(value: &str) -> bool { + if contains_raw_reserved_marker(value) { + return true; + } + let decoded = percent_decode(value); + contains_raw_reserved_marker(&decoded) +} + // --------------------------------------------------------------------------- // Error and result types // --------------------------------------------------------------------------- @@ -31,7 +49,7 @@ impl fmt::Display for UnresolvedPlaceholderError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "unresolved credential placeholder in {}: detected openshell:resolve:env:* token that could not be resolved", + "unresolved credential placeholder in {}: detected reserved credential token that could not be resolved", self.location ) } @@ -90,8 +108,12 @@ impl SecretResolver { for (key, value) in provider_env { let placeholder = placeholder_for_env_key_for_revision(&key, revision); - child_env.insert(key, placeholder.clone()); - by_placeholder.insert(placeholder, value); + let canonical_placeholder = (revision != 0).then(|| placeholder_for_env_key(&key)); + child_env.insert(key.clone(), placeholder.clone()); + by_placeholder.insert(placeholder, value.clone()); + if let Some(canonical_placeholder) = canonical_placeholder { + by_placeholder.insert(canonical_placeholder, value.clone()); + } } (child_env, Some(Self { by_placeholder })) @@ -114,7 +136,13 @@ impl SecretResolver { /// Returns `None` if the placeholder is unknown or the resolved value /// contains prohibited control characters (CRLF, null byte). 
pub(crate) fn resolve_placeholder(&self, value: &str) -> Option<&str> { - let secret = self.by_placeholder.get(value).map(String::as_str)?; + let secret = if let Some(secret) = self.by_placeholder.get(value) { + secret.as_str() + } else { + let key = alias_env_key(value)?; + let canonical = placeholder_for_env_key(key); + self.by_placeholder.get(&canonical).map(String::as_str)? + }; match validate_resolved_secret(secret) { Ok(s) => Some(s), Err(reason) => { @@ -128,10 +156,13 @@ impl SecretResolver { } } - pub(crate) fn rewrite_header_value(&self, value: &str) -> Option { + pub(crate) fn rewrite_header_value( + &self, + value: &str, + ) -> Result, UnresolvedPlaceholderError> { // Direct placeholder match: `x-api-key: openshell:resolve:env:KEY` if let Some(secret) = self.resolve_placeholder(value.trim()) { - return Some(secret.to_string()); + return Ok(Some(secret.to_string())); } let trimmed = value.trim(); @@ -142,56 +173,228 @@ impl SecretResolver { .strip_prefix("Basic ") .or_else(|| trimmed.strip_prefix("basic ")) .map(str::trim) - && let Some(rewritten) = self.rewrite_basic_auth_token(encoded) + && let Some(rewritten) = self.rewrite_basic_auth_token(encoded)? 
{ - return Some(format!("Basic {rewritten}")); + return Ok(Some(format!("Basic {rewritten}"))); } // Prefixed placeholder: `Bearer openshell:resolve:env:KEY` - let split_at = trimmed.find(char::is_whitespace)?; + let Some(split_at) = trimmed.find(char::is_whitespace) else { + if contains_reserved_credential_marker(trimmed) { + return Err(UnresolvedPlaceholderError { location: "header" }); + } + return Ok(None); + }; let prefix = &trimmed[..split_at]; let candidate = trimmed[split_at..].trim(); - let secret = self.resolve_placeholder(candidate)?; - Some(format!("{prefix} {secret}")) + if let Some(secret) = self.resolve_placeholder(candidate) { + return Ok(Some(format!("{prefix} {secret}"))); + } + + if contains_reserved_credential_marker(candidate) { + return Err(UnresolvedPlaceholderError { location: "header" }); + } + + Ok(None) + } + + pub(crate) fn rewrite_text_placeholders( + &self, + text: &mut String, + location: &'static str, + ) -> Result { + if !contains_raw_reserved_marker(text) { + return Ok(0); + } + + let mut rewritten = String::with_capacity(text.len()); + let mut pos = 0; + let mut replacements = 0; + + while pos < text.len() { + let next_canonical = text[pos..].find(PLACEHOLDER_PREFIX).map(|p| pos + p); + let next_alias = text[pos..].find(PROVIDER_ALIAS_MARKER).map(|marker_pos| { + let marker_abs = pos + marker_pos; + alias_start_for_marker(text, marker_abs) + }); + let Some(abs_start) = [next_canonical, next_alias].into_iter().flatten().min() else { + rewritten.push_str(&text[pos..]); + break; + }; + + rewritten.push_str(&text[pos..abs_start]); + + if text[abs_start..].starts_with(PLACEHOLDER_PREFIX) { + let Some((token_end, token)) = self.credential_token_at(text, abs_start) else { + return Err(UnresolvedPlaceholderError { location }); + }; + let Some(secret) = self.resolve_placeholder(token) else { + return Err(UnresolvedPlaceholderError { location }); + }; + rewritten.push_str(secret); + replacements += 1; + pos = token_end; + continue; + } + + 
if let Some((token_end, token)) = alias_token_at(text, abs_start) { + let Some(secret) = self.resolve_placeholder(token) else { + return Err(UnresolvedPlaceholderError { location }); + }; + rewritten.push_str(secret); + replacements += 1; + pos = token_end; + continue; + } + + return Err(UnresolvedPlaceholderError { location }); + } + + if contains_raw_reserved_marker(&rewritten) { + return Err(UnresolvedPlaceholderError { location }); + } + + *text = rewritten; + Ok(replacements) + } + + /// Rewrite credential placeholders inside a WebSocket text message. + /// + /// The message is mutated only after all placeholders resolve + /// successfully. The return value is the number of replacements; callers + /// must not log the rewritten text. + pub(crate) fn rewrite_websocket_text_placeholders( + &self, + text: &mut String, + ) -> Result { + self.rewrite_text_placeholders(text, "websocket") + } + + fn credential_token_at<'a>( + &'a self, + text: &'a str, + abs_start: usize, + ) -> Option<(usize, &'a str)> { + self.longest_known_token_match(text, abs_start) + .or_else(|| canonical_token_at(text, abs_start)) + .or_else(|| alias_token_at(text, abs_start)) + } + + fn longest_known_token_match<'a>( + &'a self, + text: &str, + abs_start: usize, + ) -> Option<(usize, &'a str)> { + let suffix = &text[abs_start..]; + self.by_placeholder + .keys() + .filter_map(|placeholder| { + if !suffix.starts_with(placeholder) { + return None; + } + let key_end = abs_start + placeholder.len(); + let boundary_ok = token_boundary_ok(text, abs_start, key_end, placeholder); + boundary_ok.then_some((key_end, placeholder.as_str())) + }) + .max_by_key(|(_, placeholder)| placeholder.len()) } /// Decode a Base64-encoded Basic auth token, resolve any placeholders in /// the decoded `username:password` string, and re-encode. /// /// Returns `None` if decoding fails or no placeholders are found. 
- fn rewrite_basic_auth_token(&self, encoded: &str) -> Option { + fn rewrite_basic_auth_token( + &self, + encoded: &str, + ) -> Result, UnresolvedPlaceholderError> { let b64 = base64::engine::general_purpose::STANDARD; - let decoded_bytes = b64.decode(encoded.trim()).ok()?; - let decoded = std::str::from_utf8(&decoded_bytes).ok()?; - - // Check if the decoded string contains any placeholder - if !decoded.contains(PLACEHOLDER_PREFIX) { - return None; + let Some(decoded_bytes) = b64.decode(encoded.trim()).ok() else { + return Ok(None); + }; + let Some(decoded) = std::str::from_utf8(&decoded_bytes).ok() else { + return Ok(None); + }; + + if !contains_raw_reserved_marker(decoded) { + return Ok(None); } - // Rewrite all placeholder occurrences in the decoded string let mut rewritten = decoded.to_string(); - for (placeholder, secret) in &self.by_placeholder { - if rewritten.contains(placeholder.as_str()) { - // Validate the resolved secret for control characters - if validate_resolved_secret(secret).is_err() { - tracing::warn!( - location = "basic_auth", - "credential resolution rejected: resolved value contains prohibited characters" - ); - return None; - } - rewritten = rewritten.replace(placeholder.as_str(), secret); - } - } + let replacements = self.rewrite_text_placeholders(&mut rewritten, "header")?; - // Only return if we actually changed something - if rewritten == decoded { - return None; + if replacements == 0 { + return Ok(None); } - Some(b64.encode(rewritten.as_bytes())) + Ok(Some(b64.encode(rewritten.as_bytes()))) + } +} + +fn alias_start_for_marker(text: &str, marker_abs: usize) -> usize { + let mut start = marker_abs; + let bytes = text.as_bytes(); + while start > 0 && is_alias_token_char(bytes[start - 1]) { + start -= 1; + } + start +} + +fn canonical_token_at(text: &str, abs_start: usize) -> Option<(usize, &str)> { + if !text[abs_start..].starts_with(PLACEHOLDER_PREFIX) { + return None; + } + let key_start = abs_start + PLACEHOLDER_PREFIX.len(); + let 
key_end = text[key_start..] + .bytes() + .position(|b| !is_env_key_char(b)) + .map_or(text.len(), |p| key_start + p); + (key_end > key_start).then_some((key_end, &text[abs_start..key_end])) +} + +fn alias_token_at(text: &str, abs_start: usize) -> Option<(usize, &str)> { + let suffix = &text[abs_start..]; + let marker_rel = suffix.find(PROVIDER_ALIAS_MARKER)?; + if marker_rel == 0 { + return None; + } + let key_start = abs_start + marker_rel + PROVIDER_ALIAS_MARKER.len(); + let key_end = text[key_start..] + .bytes() + .position(|b| !is_env_key_char(b)) + .map_or(text.len(), |p| key_start + p); + if key_end == key_start { + return None; + } + let before_ok = abs_start == 0 || !is_alias_token_char(text.as_bytes()[abs_start - 1]); + let after_ok = key_end == text.len() || !is_alias_token_char(text.as_bytes()[key_end]); + (before_ok && after_ok).then_some((key_end, &text[abs_start..key_end])) +} + +fn alias_env_key(token: &str) -> Option<&str> { + let marker_start = token.find(PROVIDER_ALIAS_MARKER)?; + if marker_start == 0 { + return None; + } + if !token[..marker_start].bytes().all(is_alias_token_char) { + return None; + } + let key_start = marker_start + PROVIDER_ALIAS_MARKER.len(); + let key_end = token[key_start..] 
+ .bytes() + .position(|b| !is_env_key_char(b)) + .map_or(token.len(), |p| key_start + p); + (key_end == token.len() && key_end > key_start).then_some(&token[key_start..key_end]) +} + +fn token_boundary_ok(text: &str, abs_start: usize, token_end: usize, token: &str) -> bool { + if token.starts_with(PLACEHOLDER_PREFIX) { + return token_end == text.len() + || !is_env_key_char(text.as_bytes()[token_end]) + || text[token_end..].starts_with(PLACEHOLDER_PREFIX); } + let before_ok = abs_start == 0 || !is_alias_token_char(text.as_bytes()[abs_start - 1]); + let after_ok = token_end == text.len() || !is_alias_token_char(text.as_bytes()[token_end]); + before_ok && after_ok } pub fn placeholder_for_env_key(key: &str) -> String { @@ -387,8 +590,9 @@ fn rewrite_request_line( return unchanged(); }; - // Only rewrite if the URI contains a placeholder - if !uri.contains(PLACEHOLDER_PREFIX) { + // Only rewrite if the URI contains a placeholder or a provider-shaped + // credential alias, including percent-encoded canonical placeholders. 
+ if !contains_reserved_credential_marker(uri) { return unchanged(); } @@ -444,10 +648,6 @@ fn rewrite_uri_path( path: &str, resolver: &SecretResolver, ) -> Result, UnresolvedPlaceholderError> { - if !path.contains(PLACEHOLDER_PREFIX) { - return Ok(None); - } - let segments: Vec<&str> = path.split('/').collect(); let mut resolved_segments = Vec::with_capacity(segments.len()); let mut redacted_segments = Vec::with_capacity(segments.len()); @@ -455,7 +655,7 @@ fn rewrite_uri_path( for segment in &segments { let decoded = percent_decode(segment); - if !decoded.contains(PLACEHOLDER_PREFIX) { + if !contains_raw_reserved_marker(&decoded) { resolved_segments.push(segment.to_string()); redacted_segments.push(segment.to_string()); continue; @@ -495,28 +695,23 @@ fn rewrite_path_segment( let bytes = segment.as_bytes(); while pos < bytes.len() { - if let Some(start) = segment[pos..].find(PLACEHOLDER_PREFIX) { - let abs_start = pos + start; + let next_canonical = segment[pos..].find(PLACEHOLDER_PREFIX).map(|p| pos + p); + let next_alias = segment[pos..] + .find(PROVIDER_ALIAS_MARKER) + .map(|marker_pos| { + let marker_abs = pos + marker_pos; + alias_start_for_marker(segment, marker_abs) + }); + if let Some(abs_start) = [next_canonical, next_alias].into_iter().flatten().min() { // Copy literal prefix before the placeholder resolved.push_str(&segment[pos..abs_start]); redacted.push_str(&segment[pos..abs_start]); - // Extract the key name using the env var grammar: [A-Za-z_][A-Za-z0-9_]* - let key_start = abs_start + PLACEHOLDER_PREFIX.len(); - let key_end = segment[key_start..] 
- .bytes() - .position(|b| !is_env_key_char(b)) - .map_or(segment.len(), |p| key_start + p); - - if key_end == key_start { - // Empty key — not a valid placeholder, copy literally - resolved.push_str(&segment[abs_start..abs_start + PLACEHOLDER_PREFIX.len()]); - redacted.push_str(&segment[abs_start..abs_start + PLACEHOLDER_PREFIX.len()]); - pos = abs_start + PLACEHOLDER_PREFIX.len(); - continue; - } - - let full_placeholder = &segment[abs_start..key_end]; + let Some((token_end, full_placeholder)) = canonical_token_at(segment, abs_start) + .or_else(|| alias_token_at(segment, abs_start)) + else { + return Err(UnresolvedPlaceholderError { location: "path" }); + }; if let Some(secret) = resolver.resolve_placeholder(full_placeholder) { validate_credential_for_path(secret).map_err(|reason| { tracing::warn!( @@ -531,7 +726,7 @@ fn rewrite_path_segment( } else { return Err(UnresolvedPlaceholderError { location: "path" }); } - pos = key_end; + pos = token_end; } else { // No more placeholders in remainder resolved.push_str(&segment[pos..]); @@ -550,7 +745,7 @@ fn rewrite_uri_query_params( query: &str, resolver: &SecretResolver, ) -> Result, UnresolvedPlaceholderError> { - if !query.contains(PLACEHOLDER_PREFIX) { + if !contains_reserved_credential_marker(query) { return Ok(None); } @@ -561,15 +756,18 @@ fn rewrite_uri_query_params( for param in query.split('&') { if let Some((key, value)) = param.split_once('=') { let decoded_value = percent_decode(value); - if let Some(secret) = resolver.resolve_placeholder(&decoded_value) { - resolved_params.push(format!("{key}={}", percent_encode_query(secret))); + if contains_raw_reserved_marker(&decoded_value) { + let mut rewritten = decoded_value.clone(); + let replacements = + resolver.rewrite_text_placeholders(&mut rewritten, "query_param")?; + if replacements == 0 || contains_raw_reserved_marker(&rewritten) { + return Err(UnresolvedPlaceholderError { + location: "query_param", + }); + } + resolved_params.push(format!("{key}={}", 
percent_encode_query(&rewritten))); redacted_params.push(format!("{key}=[CREDENTIAL]")); any_rewritten = true; - } else if decoded_value.contains(PLACEHOLDER_PREFIX) { - // Placeholder detected but not resolved - return Err(UnresolvedPlaceholderError { - location: "query_param", - }); } else { resolved_params.push(param.to_string()); redacted_params.push(param.to_string()); @@ -639,41 +837,42 @@ pub fn rewrite_http_header_block( break; } - output.extend_from_slice(rewrite_header_line(line, resolver).as_bytes()); + output.extend_from_slice(rewrite_header_line_checked(line, resolver)?.as_bytes()); output.extend_from_slice(b"\r\n"); } output.extend_from_slice(b"\r\n"); output.extend_from_slice(&raw[header_end..]); - // Fail-closed scan: check for any remaining unresolved placeholders - // in both raw form and percent-decoded form of the output header block. + // Fail-closed scan: check for any remaining unresolved placeholders or + // provider-shaped aliases in both raw and percent-decoded header bytes. 
let output_header = String::from_utf8_lossy(&output[..output.len().min(header_end + 256)]); - if output_header.contains(PLACEHOLDER_PREFIX) { + if contains_reserved_credential_marker(&output_header) { return Err(UnresolvedPlaceholderError { location: "header" }); } - // Also check percent-decoded form of the request line (F5 — encoded placeholder bypass) - let rewritten_rl = output_header.split("\r\n").next().unwrap_or(""); - let decoded_rl = percent_decode(rewritten_rl); - if decoded_rl.contains(PLACEHOLDER_PREFIX) { - return Err(UnresolvedPlaceholderError { location: "path" }); - } - Ok(RewriteResult { rewritten: output, redacted_target: rl_result.redacted_target, }) } +#[cfg_attr(not(test), allow(dead_code))] pub fn rewrite_header_line(line: &str, resolver: &SecretResolver) -> String { + rewrite_header_line_checked(line, resolver).unwrap_or_else(|_| line.to_string()) +} + +pub fn rewrite_header_line_checked( + line: &str, + resolver: &SecretResolver, +) -> Result { let Some((name, value)) = line.split_once(':') else { - return line.to_string(); + return Ok(line.to_string()); }; - resolver.rewrite_header_value(value.trim()).map_or_else( - || line.to_string(), - |rewritten| format!("{name}: {rewritten}"), + resolver.rewrite_header_value(value.trim())?.map_or_else( + || Ok(line.to_string()), + |rewritten| Ok(format!("{name}: {rewritten}")), ) } @@ -688,12 +887,7 @@ pub fn rewrite_target_for_eval( target: &str, resolver: &SecretResolver, ) -> Result { - if !target.contains(PLACEHOLDER_PREFIX) { - // Also check percent-decoded form - let decoded = percent_decode(target); - if decoded.contains(PLACEHOLDER_PREFIX) { - return Err(UnresolvedPlaceholderError { location: "path" }); - } + if !contains_reserved_credential_marker(target) { return Ok(RewriteTargetResult { resolved: target.to_string(), redacted: target.to_string(), @@ -800,6 +994,50 @@ mod tests { ); } + #[test] + fn rewrites_provider_shaped_alias_header_values() { + let (_, resolver) = 
SecretResolver::from_provider_env( + [ + ("API_TOKEN".to_string(), "provider-real-token".to_string()), + ("CHAT_APP_TOKEN".to_string(), "app-real-token".to_string()), + ] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + + assert_eq!( + rewrite_header_line( + "Authorization: Bearer vendor-OPENSHELL-RESOLVE-ENV-API_TOKEN", + &resolver, + ), + "Authorization: Bearer provider-real-token" + ); + assert_eq!( + rewrite_header_line( + "x-app-token: token.v1-OPENSHELL-RESOLVE-ENV-CHAT_APP_TOKEN", + &resolver, + ), + "x-app-token: app-real-token" + ); + } + + #[test] + fn unresolved_provider_shaped_alias_fails_closed() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let raw = b"GET / HTTP/1.1\r\nAuthorization: Bearer vendor-OPENSHELL-RESOLVE-ENV-UNKNOWN_TOKEN\r\n\r\n"; + + let err = rewrite_http_header_block(raw, Some(&resolver)) + .expect_err("unknown alias should fail closed"); + + assert_eq!(err.location, "header"); + } + #[test] fn rewrites_http_header_blocks_and_preserves_body() { let (_, resolver) = SecretResolver::from_provider_env( @@ -1410,6 +1648,29 @@ mod tests { ); } + #[test] + fn percent_encoded_canonical_placeholder_in_query_rewrites() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let encoded = "openshell%3Aresolve%3Aenv%3AAPI_TOKEN"; + let raw = format!("GET /api?token={encoded} HTTP/1.1\r\nHost: x\r\n\r\n"); + + let result = + rewrite_http_header_block(raw.as_bytes(), Some(&resolver)).expect("should rewrite"); + let rewritten = String::from_utf8(result.rewritten).expect("utf8"); + + assert!(rewritten.starts_with("GET /api?token=provider-real-token HTTP/1.1")); + assert!(!rewritten.contains("openshell")); + 
assert_eq!( + result.redacted_target.as_deref(), + Some("/api?token=[CREDENTIAL]") + ); + } + #[test] fn all_resolved_succeeds() { let (child_env, resolver) = SecretResolver::from_provider_env( @@ -1444,6 +1705,129 @@ mod tests { assert_eq!(raw.as_slice(), result.rewritten.as_slice()); } + #[test] + fn rewrite_websocket_text_replaces_placeholders_and_returns_count() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [ + ("DISCORD_BOT_TOKEN".to_string(), "real-token".to_string()), + ("APP_ID".to_string(), "app-123".to_string()), + ] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let token = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let app_id = child_env.get("APP_ID").unwrap(); + let mut payload = + format!(r#"{{"op":2,"d":{{"token":"{token}","properties":{{"app":"{app_id}"}}}}}}"#); + + let count = resolver + .rewrite_websocket_text_placeholders(&mut payload) + .expect("rewrite should succeed"); + + assert_eq!(count, 2); + assert!(payload.contains(r#""token":"real-token""#)); + assert!(payload.contains(r#""app":"app-123""#)); + assert!(!payload.contains(PLACEHOLDER_PREFIX)); + } + + #[test] + fn rewrite_websocket_text_replaces_provider_shaped_alias() { + let (_, resolver) = SecretResolver::from_provider_env( + [("APP_TOKEN".to_string(), "app-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let mut payload = r#"{"token":"provider-OPENSHELL-RESOLVE-ENV-APP_TOKEN"}"#.to_string(); + + let count = resolver + .rewrite_websocket_text_placeholders(&mut payload) + .expect("alias should rewrite"); + + assert_eq!(count, 1); + assert_eq!(payload, r#"{"token":"app-real-token"}"#); + } + + #[test] + fn rewrite_websocket_text_without_placeholder_is_unchanged() { + let (_, resolver) = SecretResolver::from_provider_env( + [("KEY".to_string(), "secret".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let mut payload = 
r#"{"op":1,"d":42}"#.to_string(); + + let count = resolver + .rewrite_websocket_text_placeholders(&mut payload) + .expect("rewrite should succeed"); + + assert_eq!(count, 0); + assert_eq!(payload, r#"{"op":1,"d":42}"#); + } + + #[test] + fn rewrite_websocket_text_unknown_placeholder_fails_closed_without_mutating() { + let (_, resolver) = SecretResolver::from_provider_env( + [("KEY".to_string(), "secret".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let original = r#"{"token":"openshell:resolve:env:UNKNOWN"}"#.to_string(); + let mut payload = original.clone(); + + let err = resolver + .rewrite_websocket_text_placeholders(&mut payload) + .expect_err("unknown placeholder should fail"); + + assert_eq!(err.location, "websocket"); + assert_eq!(payload, original); + } + + #[test] + fn rewrite_websocket_text_handles_repeated_adjacent_and_unicode_placeholders() { + let (child_env, resolver) = SecretResolver::from_provider_env( + [ + ("TOKEN".to_string(), "tok".to_string()), + ("APP".to_string(), "app".to_string()), + ] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let token = child_env.get("TOKEN").unwrap(); + let app = child_env.get("APP").unwrap(); + let mut payload = format!("prefix-☃-{token}{app}-{token}-suffix"); + + let count = resolver + .rewrite_websocket_text_placeholders(&mut payload) + .expect("rewrite should succeed"); + + assert_eq!(count, 3); + assert_eq!(payload, "prefix-☃-tokapp-tok-suffix"); + } + + #[test] + fn rewrite_websocket_text_placeholder_like_prefix_fails_without_mutating() { + let (_, resolver) = SecretResolver::from_provider_env( + [("KEY".to_string(), "secret".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let original = "openshell:resolve:env:-not-a-key".to_string(); + let mut payload = original.clone(); + + let err = resolver + .rewrite_websocket_text_placeholders(&mut payload) + .expect_err("placeholder-like 
prefix should fail closed"); + + assert_eq!(err.location, "websocket"); + assert_eq!(payload, original); + } + // === Redaction tests === #[test] diff --git a/crates/openshell-sandbox/tests/websocket_upgrade.rs b/crates/openshell-sandbox/tests/websocket_upgrade.rs index e4cd232ce..b35076a9a 100644 --- a/crates/openshell-sandbox/tests/websocket_upgrade.rs +++ b/crates/openshell-sandbox/tests/websocket_upgrade.rs @@ -124,7 +124,7 @@ async fn websocket_upgrade_through_l7_relay_exchanges_message() { .expect("relay should succeed"); match outcome { - RelayOutcome::Upgraded { overflow } => { + RelayOutcome::Upgraded { overflow, .. } => { // This is what handle_upgrade() does in relay.rs if !overflow.is_empty() { client_proxy.write_all(&overflow).await.unwrap(); diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index d5a47bcba..885dbc9ad 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -216,6 +216,12 @@ fn summarize_endpoint(endpoint: &NetworkEndpoint) -> String { if !endpoint.tls.is_empty() { parts.push(format!("tls={}", endpoint.tls)); } + if endpoint.websocket_credential_rewrite { + parts.push("websocket_credential_rewrite=true".to_string()); + } + if endpoint.request_body_credential_rewrite { + parts.push("request_body_credential_rewrite=true".to_string()); + } if !endpoint.allowed_ips.is_empty() { parts.push(format!("allowed_ips={}", endpoint.allowed_ips.len())); } @@ -4318,6 +4324,62 @@ mod tests { ); } + #[test] + fn summarize_cli_policy_merge_op_formats_websocket_credential_rewrite() { + let operation = PolicyMergeOp::AddRule { + rule_name: "realtime_api".to_string(), + rule: NetworkPolicyRule { + name: "realtime_api".to_string(), + endpoints: vec![NetworkEndpoint { + host: "realtime.example.com".to_string(), + port: 443, + protocol: "websocket".to_string(), + access: "read-write".to_string(), + enforcement: "enforce".to_string(), + 
websocket_credential_rewrite: true, + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/node".to_string(), + ..Default::default() + }], + }, + }; + + assert_eq!( + summarize_cli_policy_merge_op(&operation), + "add-endpoint realtime_api endpoints=[realtime.example.com:443 protocol=websocket access=read-write enforcement=enforce websocket_credential_rewrite=true] binaries=[/usr/bin/node]" + ); + } + + #[test] + fn summarize_cli_policy_merge_op_formats_request_body_credential_rewrite() { + let operation = PolicyMergeOp::AddRule { + rule_name: "slack_api".to_string(), + rule: NetworkPolicyRule { + name: "slack_api".to_string(), + endpoints: vec![NetworkEndpoint { + host: "slack.com".to_string(), + port: 443, + protocol: "rest".to_string(), + access: "read-write".to_string(), + enforcement: "enforce".to_string(), + request_body_credential_rewrite: true, + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/node".to_string(), + ..Default::default() + }], + }, + }; + + assert_eq!( + summarize_cli_policy_merge_op(&operation), + "add-endpoint slack_api endpoints=[slack.com:443 protocol=rest access=read-write enforcement=enforce request_body_credential_rewrite=true] binaries=[/usr/bin/node]" + ); + } + // ---- merge_chunk_into_policy ---- #[tokio::test] diff --git a/docs/reference/policy-schema.mdx b/docs/reference/policy-schema.mdx index a98e8087c..295f850df 100644 --- a/docs/reference/policy-schema.mdx +++ b/docs/reference/policy-schema.mdx @@ -155,7 +155,7 @@ Each endpoint defines a reachable destination and optional inspection rules. | `host` | string | Yes | Hostname or IP address. Supports wildcards: `*.example.com` matches any subdomain. | | `port` | integer | Yes | TCP port number. | | `path` | string | No | Optional HTTP path glob used to select between L7 endpoints that share the same host and port. Empty means all paths. 
Use this when REST and GraphQL live under the same host, such as `/repos/**` and `/graphql`. | -| `protocol` | string | No | Set to `rest` for HTTP method/path inspection or `graphql` for GraphQL operation inspection. Omit for TCP passthrough. | +| `protocol` | string | No | Set to `rest` for HTTP method/path inspection, `websocket` for RFC 6455 upgrade and client text-message inspection, or `graphql` for GraphQL-over-HTTP operation inspection. WebSocket endpoints can also use GraphQL operation rules for GraphQL-over-WebSocket traffic. Omit for TCP passthrough. | | `tls` | string | No | TLS handling mode. The proxy auto-detects TLS by peeking the first bytes of each connection and terminates it for inspected HTTPS traffic, so this field is optional in most cases. Set to `skip` to disable auto-detection for edge cases such as client-certificate mTLS or non-standard protocols. The values `terminate` and `passthrough` are deprecated and log a warning; they are still accepted for backward compatibility but have no effect on behavior. | | `enforcement` | string | No | `enforce` actively blocks disallowed requests. `audit` logs violations but allows traffic through. | | `access` | string | No | Access preset. One of `read-only`, `read-write`, or `full`. Mutually exclusive with `rules`. | @@ -163,19 +163,23 @@ Each endpoint defines a reachable destination and optional inspection rules. | `deny_rules` | list of deny rule objects | No | L7 deny rules that block specific requests even when allowed by `access` or `rules`. Deny rules take precedence over allow rules. | | `allowed_ips` | list of string | No | CIDR or IP allowlist for SSRF override. Entries overlapping loopback (`127.0.0.0/8`), link-local (`169.254.0.0/16`), or unspecified (`0.0.0.0`) are rejected at load time. | | `allow_encoded_slash` | bool | No | When `true`, L7 request parsing preserves `%2F` inside path segments instead of rejecting it. 
Use this for registries and APIs such as npm scoped packages (`/@scope%2Fname`). Defaults to `false`. | -| `persisted_queries` | string | No | GraphQL hash-only behavior. Default is `deny`; use `allow_registered` only with `graphql_persisted_queries`. | +| `websocket_credential_rewrite` | bool | No | When `true` on a `protocol: rest` or `protocol: websocket` endpoint, OpenShell rewrites credential placeholders in client-to-server WebSocket text messages after an allowed HTTP `101` upgrade. Binary frames are relayed but not rewritten. Defaults to `false`. | +| `request_body_credential_rewrite` | bool | No | When `true` on a `protocol: rest` endpoint, OpenShell rewrites credential placeholders in UTF-8 `application/json`, `application/x-www-form-urlencoded`, and `text/*` request bodies before forwarding upstream. The proxy buffers at most 256 KiB and updates `Content-Length` after rewriting. Defaults to `false`. | +| `persisted_queries` | string | No | GraphQL hash-only behavior for `protocol: graphql` and GraphQL-over-WebSocket operation policy. Default is `deny`; use `allow_registered` only with `graphql_persisted_queries`. | | `graphql_persisted_queries` | map | No | Trusted GraphQL persisted-query registry keyed by hash or saved-query ID. Values contain `operation_type`, optional `operation_name`, and optional root `fields`. | -| `graphql_max_body_bytes` | integer | No | Maximum GraphQL request body bytes buffered for inspection. Defaults to `65536`. | +| `graphql_max_body_bytes` | integer | No | Maximum GraphQL-over-HTTP request body bytes buffered for inspection. Defaults to `65536`. | + +Credential rewrite recognizes the canonical `openshell:resolve:env:KEY` placeholder form and whole-token provider-shaped aliases such as `provider-OPENSHELL-RESOLVE-ENV-API_TOKEN` when the referenced environment key exists in the configured provider credentials. 
#### Access Levels The `access` field accepts one of the following values: -| Value | REST expansion | GraphQL expansion | -|---|---|---| -| `full` | All methods and paths. | All operation types. | -| `read-only` | `GET`, `HEAD`, `OPTIONS`. | `query` operations. | -| `read-write` | `GET`, `HEAD`, `OPTIONS`, `POST`, `PUT`, `PATCH`. | `query` and `mutation` operations. | +| Value | REST expansion | WebSocket expansion | GraphQL expansion | +|---|---|---|---| +| `full` | All methods and paths. | WebSocket upgrade and all inspected client text-message paths. | All operation types. | +| `read-only` | `GET`, `HEAD`, `OPTIONS`. | WebSocket upgrade handshake only. | `query` operations. | +| `read-write` | `GET`, `HEAD`, `OPTIONS`, `POST`, `PUT`, `PATCH`. | WebSocket upgrade handshake and client text messages. | `query` and `mutation` operations. | #### Allow Rule Objects @@ -208,9 +212,31 @@ rules: any: ["v1.*", "v2.*"] ``` -##### GraphQL Allow Rule (`protocol: graphql`) +##### WebSocket Allow Rule (`protocol: websocket`) + +WebSocket allow rules match the RFC 6455 HTTP upgrade by path and match client-to-server text messages on the same upgraded connection with the synthetic `WEBSOCKET_TEXT` method. Binary frames are relayed but are not rewritten. + +| Field | Type | Required | Description | +|---|---|---|---| +| `method` | string | Yes | `GET` allows the upgrade handshake, `WEBSOCKET_TEXT` allows client text messages after upgrade, and `*` matches both inspected actions. | +| `path` | string | Yes | URL path pattern from the original upgrade request. Supports `*` and `**` glob syntax. | +| `query` | map | No | Query parameter matchers from the original upgrade request. Matcher syntax is the same as REST allow rules. 
| + +Example WebSocket allow rules: + +```yaml showLineNumbers={false} +rules: + - allow: + method: GET + path: /v1/realtime/** + - allow: + method: WEBSOCKET_TEXT + path: /v1/realtime/** +``` + +##### GraphQL Allow Rule (`protocol: graphql` or GraphQL-over-WebSocket) -GraphQL allow rules match parsed GraphQL operations by operation type, optional operation name, and optional root fields. +GraphQL allow rules match parsed GraphQL operations by operation type, optional operation name, and optional root fields. On `protocol: graphql`, they apply to GraphQL-over-HTTP `GET` and `POST` requests. On `protocol: websocket`, include a separate `GET` allow rule for the RFC 6455 upgrade, then use GraphQL allow rules for client operation messages using the `graphql-transport-ws` `subscribe` message type or the legacy `graphql-ws` `start` message type. | Field | Type | Required | Description | |---|---|---|---| @@ -231,6 +257,23 @@ rules: fields: [createIssue] ``` +Example GraphQL-over-WebSocket allow rules: + +```yaml showLineNumbers={false} +rules: + - allow: + method: GET + path: /graphql + - allow: + operation_type: subscription + fields: [messageAdded] + - allow: + operation_type: query + fields: [viewer] +``` + +Do not combine `method`, `path`, or `query` with `operation_type`, `operation_name`, or `fields` inside the same WebSocket rule. When a WebSocket endpoint has GraphQL operation policy, use GraphQL rules for client messages instead of a raw `WEBSOCKET_TEXT` allow rule. + #### Deny Rule Objects Blocks specific operations on endpoints that otherwise have broad access. Deny rules are evaluated after allow rules and take precedence: if a request matches any deny rule, it is blocked regardless of what the allow rules or access preset permit. 
@@ -263,9 +306,33 @@ endpoints: path: "/repos/*/rulesets" ``` -##### GraphQL Deny Rule (`protocol: graphql`) +##### WebSocket Deny Rule (`protocol: websocket`) + +WebSocket deny rules use the same field names as WebSocket allow rules, but they appear directly under each `deny_rules` entry instead of under an `allow` wrapper. + +| Field | Type | Required | Description | +|---|---|---|---| +| `method` | string | Yes | `GET` denies matching upgrade handshakes, `WEBSOCKET_TEXT` denies matching client text messages after upgrade, and `*` matches both inspected actions. | +| `path` | string | Yes | URL path pattern from the original upgrade request. Same glob syntax as allow rules. | +| `query` | map | No | Query parameter matchers from the original upgrade request. Same syntax as allow rule `query`. | + +Example WebSocket deny rules: + +```yaml showLineNumbers={false} +endpoints: + - host: realtime.example.com + port: 443 + protocol: websocket + enforcement: enforce + access: read-write + deny_rules: + - method: WEBSOCKET_TEXT + path: "/v1/admin/**" +``` + +##### GraphQL Deny Rule (`protocol: graphql` or GraphQL-over-WebSocket) -GraphQL deny rules use the same field names as GraphQL allow rules, but they appear directly under each `deny_rules` entry instead of under an `allow` wrapper. +GraphQL deny rules use the same field names as GraphQL allow rules, but they appear directly under each `deny_rules` entry instead of under an `allow` wrapper. On WebSocket GraphQL endpoints, they apply only to classified GraphQL operation messages; protocol lifecycle messages such as `connection_init`, `ping`, `pong`, and `complete` are allowed as WebSocket control-plane messages and are not payload-logged. 
| Field | Type | Required | Description | |---|---|---|---| diff --git a/docs/sandboxes/policies.mdx b/docs/sandboxes/policies.mdx index 4e0aa4357..fb0b04cfe 100644 --- a/docs/sandboxes/policies.mdx +++ b/docs/sandboxes/policies.mdx @@ -49,14 +49,14 @@ network_policies: Static sections are locked at sandbox creation. Changing them requires destroying and recreating the sandbox. Dynamic sections can be updated on a running sandbox with `openshell policy update` for incremental merges or `openshell policy set` for full replacement, and take effect without restarting. When a hot reload changes rules on an active HTTP L7 endpoint, existing keep-alive tunnels are closed before forwarding another parsed request. Credential-injection-only HTTP passthrough tunnels use the same reload boundary. Most HTTP clients reconnect automatically, and the next request is evaluated against the current policy. -Raw streams are connection-scoped and outside L7 live-reload guarantees. This includes `tls: skip`, non-HTTP TCP payloads, HTTP upgrades such as WebSocket, and long-lived response streams such as SSE. A reload applies to the next connection or next parsed HTTP request; it does not interrupt an already-forwarded raw stream. +Raw streams are connection-scoped and outside L7 live-reload guarantees. This includes `tls: skip`, non-HTTP TCP payloads, HTTP upgrades such as WebSocket, and long-lived response streams such as SSE. A reload applies to the next connection or next parsed HTTP request; it does not interrupt an already-forwarded raw stream. Use `protocol: websocket` when policy should stay attached to the RFC 6455 upgrade and client text messages after the allowed upgrade. Add `websocket_credential_rewrite: true` only when the relay should rewrite credential placeholders in client-to-server WebSocket text messages. Add `request_body_credential_rewrite: true` only on inspected REST endpoints that need OpenShell to rewrite placeholders in supported text request bodies. 
| Section | Type | Description | |---|---|---| | `filesystem_policy` | Static | Controls which directories the agent can access on disk. Paths are split into `read_only` and `read_write` lists. Any path not listed in either list is inaccessible. Set `include_workdir: true` to automatically add the agent's working directory to `read_write`. [Landlock LSM](https://docs.kernel.org/security/landlock.html) enforces these restrictions at the kernel level. | | `landlock` | Static | Configures Landlock LSM enforcement behavior. Set `compatibility` to `best_effort` (skip individual inaccessible paths while applying remaining rules) or `hard_requirement` (fail if any path is inaccessible or the required kernel ABI is unavailable). See the [Policy Schema Reference](/reference/policy-schema#landlock) for the full behavior table. | | `process` | Static | Sets the OS-level identity for the agent process. `run_as_user` and `run_as_group` default to `sandbox`. Root (`root` or `0`) is rejected. The agent also runs with seccomp filters that block dangerous system calls. | -| `network_policies` | Dynamic | Controls network access for ordinary outbound traffic from the sandbox. Each block has a name, a list of endpoints (host, port, protocol, and optional rules), and a list of binaries allowed to use those endpoints.
Every outbound connection except `https://inference.local` goes through the proxy, which queries the [policy engine](/about/how-it-works#core-components) with the destination and calling binary. A connection is allowed only when both match an entry in the same policy block.
For endpoints with `protocol: rest`, the proxy auto-detects TLS and terminates it so each HTTP request can be checked against that endpoint's `rules` (method and path).
Endpoints without `protocol` allow the TCP stream through without inspecting payloads.
If no endpoint matches, the connection is denied. Configure managed inference separately through [Inference Routing](/sandboxes/inference-routing). | +| `network_policies` | Dynamic | Controls network access for ordinary outbound traffic from the sandbox. Each block has a name, a list of endpoints (host, port, protocol, and optional rules), and a list of binaries allowed to use those endpoints.
Every outbound connection except `https://inference.local` goes through the proxy, which queries the [policy engine](/about/how-it-works#core-components) with the destination and calling binary. A connection is allowed only when both match an entry in the same policy block.
For endpoints with `protocol: rest`, the proxy auto-detects TLS and terminates it so each HTTP request can be checked against that endpoint's `rules` (method and path). For endpoints with `protocol: websocket`, the proxy validates the RFC 6455 upgrade and evaluates `GET` rules for the handshake plus either `WEBSOCKET_TEXT` rules for raw client text messages or GraphQL operation rules for GraphQL-over-WebSocket messages. Set `websocket_credential_rewrite: true` only when a WebSocket or REST compatibility endpoint must keep placeholder credentials in sandbox-owned text frames and resolve them at the OpenShell relay boundary.
Endpoints without `protocol` allow the TCP stream through without inspecting payloads.
If no endpoint matches, the connection is denied. Configure managed inference separately through [Inference Routing](/sandboxes/inference-routing). | ## Baseline Filesystem Paths @@ -123,7 +123,7 @@ The following steps outline the hot-reload policy update workflow. openshell logs --tail --source sandbox ``` -3. For additive network changes, use `openshell policy update`. This is the fastest path for adding endpoints, binaries, or REST allow/deny rules without replacing the full policy. The full option and format reference is in [Incremental Policy Updates](#incremental-policy-updates). +3. For additive network changes, use `openshell policy update`. This is the fastest path for adding endpoints, binaries, or REST and WebSocket allow/deny rules without replacing the full policy. The full option and format reference is in [Incremental Policy Updates](#incremental-policy-updates). ```shell openshell policy update \ @@ -136,7 +136,7 @@ The following steps outline the hot-reload policy update workflow. --wait ``` - `--add-allow` and `--add-deny` currently target existing `protocol: rest` endpoints only. If you pass multiple update flags in one command, OpenShell applies them as one atomic merge batch and persists at most one new revision. + `--add-allow` and `--add-deny` target existing `protocol: rest` or `protocol: websocket` endpoints. If you pass multiple update flags in one command, OpenShell applies them as one atomic merge batch and persists at most one new revision. 4. For larger edits, pull the current policy and edit the YAML directly. Strip the metadata header (Version, Hash, Status) before reusing the file. @@ -165,7 +165,7 @@ Use `openshell policy update` when you want to merge network policy changes into `openshell policy update` is useful when you want to: - add a new endpoint for an existing binary without touching other policy sections. -- add a few REST allow or deny rules after you see a blocked request in the logs. 
+- add a few REST or WebSocket allow/deny rules after you see a blocked request in the logs. - remove one endpoint or one named rule without rewriting the rest of the file. - preview a merged result locally with `--dry-run` before you send it to the gateway. @@ -173,15 +173,15 @@ Use `openshell policy set` instead when you want to replace the full policy, upd ### Update Commands -The incremental update surface is split into endpoint-level operations and REST rule-level operations. +The incremental update surface is split into endpoint-level operations and method/path rule-level operations for REST and WebSocket endpoints. | Flag | What it changes | Typical use | |---|---|---| -| `--add-endpoint ` | Creates or merges a network rule and endpoint. | Allow a new host and port, optionally with `access`, `protocol`, `enforcement`, and binaries. | +| `--add-endpoint ` | Creates or merges a network rule and endpoint. | Allow a new host and port, optionally with `access`, `protocol`, `enforcement`, endpoint options, and binaries. | | `--remove-endpoint ` | Removes one host and port match from the current policy. | Drop a stale endpoint or remove one port from a multi-port endpoint. | | `--remove-rule ` | Deletes a named `network_policies` entry. | Remove a whole rule by name when you no longer need it. | -| `--add-allow ` | Appends REST allow rules to an existing endpoint. | Permit one additional method and path on a REST API that is already configured. | -| `--add-deny ` | Appends REST deny rules to an existing endpoint. | Block a sensitive REST path under an endpoint that is otherwise allowed. | +| `--add-allow ` | Appends method/path allow rules to an existing REST or WebSocket endpoint. | Permit one additional REST method/path or WebSocket `WEBSOCKET_TEXT` path on an API that is already configured. | +| `--add-deny ` | Appends method/path deny rules to an existing REST or WebSocket endpoint. 
| Block a sensitive REST path or WebSocket text-message path under an endpoint that is otherwise allowed. | | `--binary ` | Adds binaries to every `--add-endpoint` rule in the same command. | Bind a new endpoint to one or more executables. | | `--rule-name ` | Overrides the generated rule name. | Keep a stable human-chosen rule name when adding exactly one endpoint. | | `--dry-run` | Shows the merged policy locally and does not call the gateway. | Review the result before persisting it. | @@ -194,17 +194,18 @@ The incremental update surface is split into endpoint-level operations and REST `--add-endpoint` works at the endpoint and rule level. It creates a new `network_policies` entry when needed, or merges into an existing rule that already covers the same host and port. Use it when you are defining where traffic may go and which binaries may send it. -`--add-allow` and `--add-deny` work at the REST request level. They do not create binaries, and they do not create a new endpoint. They modify an existing endpoint that already has `protocol: rest`. +`--add-allow` and `--add-deny` work at the method/path rule level. They do not create binaries, and they do not create a new endpoint. They modify an existing endpoint that already has `protocol: rest` or `protocol: websocket`. This is the practical difference: - Use `--add-endpoint` to say "allow this binary to reach `api.github.com:443`." - Use `--add-allow` to say "for that existing REST endpoint, also allow `POST /repos/*/issues`." - Use `--add-deny` to say "for that existing REST endpoint, explicitly deny `POST /admin/**`." +- Use `--add-allow` to say "for that existing WebSocket endpoint, also allow client text messages on `/v1/realtime/**`." -In the first pass of this feature: +Current constraints: -- `--add-allow` and `--add-deny` only work on `protocol: rest` endpoints. +- `--add-allow` and `--add-deny` work on `protocol: rest` and `protocol: websocket` endpoints. 
- `--add-deny` requires the endpoint to already have an allow base, either an `access` preset or explicit allow `rules`. - `protocol: sql` is not a practical incremental workflow today. OpenShell does not do full SQL parsing, and SQL enforcement is not meaningfully supported yet. @@ -213,7 +214,7 @@ In the first pass of this feature: `--add-endpoint` uses this format: ```text -host:port[:access[:protocol[:enforcement]]] +host:port[:access[:protocol[:enforcement[:options]]]] ``` Each segment has a fixed meaning: @@ -222,9 +223,10 @@ Each segment has a fixed meaning: |---|---|---| | `host` | Yes | Destination hostname. | | `port` | Yes | Destination port, `1` through `65535`. | -| `access` | No | Access preset for L7 endpoints: `read-only`, `read-write`, or `full`. Incremental updates currently expand presets for REST-shaped access. | -| `protocol` | No | L7 inspection mode: `rest` or `sql`. `sql` is audit-only and not a recommended workflow today. | +| `access` | No | Access preset for L7 endpoints: `read-only`, `read-write`, or `full`. Incremental updates expand presets into protocol-specific method/path rules for REST and WebSocket endpoints. | +| `protocol` | No | L7 inspection mode: `rest`, `websocket`, or `sql`. `sql` is audit-only and not a recommended workflow today. | | `enforcement` | No | Enforcement mode for inspected traffic: `enforce` or `audit`. | +| `options` | No | Comma-separated endpoint options. Use `websocket-credential-rewrite` with `protocol: websocket` or REST compatibility endpoints that perform a WebSocket upgrade. Use `request-body-credential-rewrite` only with `protocol: rest`. | Examples: @@ -232,19 +234,29 @@ Examples: |---|---| | `pypi.org:443` | Add a plain L4 endpoint. The proxy allows the TCP stream and does not inspect HTTP requests. | | `api.github.com:443:read-only:rest:enforce` | Add a REST endpoint with the `read-only` preset expanded by the policy engine into GET, HEAD, and OPTIONS access. 
| +| `api.example.com:443:read-write:rest:enforce:request-body-credential-rewrite` | Add a REST endpoint that rewrites credential placeholders in supported text request bodies. | +| `realtime.example.com:443:read-write:websocket:enforce` | Add a WebSocket endpoint with the `read-write` preset expanded by the policy engine into the upgrade `GET` and client `WEBSOCKET_TEXT` access. | +| `realtime.example.com:443:read-write:websocket:enforce:websocket-credential-rewrite` | Add a WebSocket endpoint that rewrites `openshell:resolve:env:*` placeholders in client text frames after an allowed upgrade. | -If you set `protocol: rest`, you also need an allow shape. With incremental updates, that means you should provide an `access` preset on `--add-endpoint`, then use `--add-allow` or `--add-deny` to refine REST endpoints later. +If you set `protocol: rest` or `protocol: websocket`, you also need an allow shape. With incremental updates, that means you should provide an `access` preset on `--add-endpoint`, then use `--add-allow` or `--add-deny` to refine method/path rules later. + +Use the `websocket-credential-rewrite` endpoint option with `protocol: websocket` when the sandbox should send credential placeholders in client text frames and have OpenShell resolve them after the allowed upgrade. The option can also be used with `protocol: rest` compatibility endpoints that perform a WebSocket upgrade. It is rejected for plain L4 or `protocol: sql` endpoints. + +Use the `request-body-credential-rewrite` endpoint option with `protocol: rest` when an API expects OpenShell-managed credentials in UTF-8 JSON, form, or text request bodies. OpenShell buffers up to 256 KiB, rewrites recognized credential placeholders, updates `Content-Length`, and rejects unresolved placeholders instead of forwarding them. The option is rejected for WebSocket, GraphQL, SQL, and plain L4 endpoints. 
+ +Credential rewrite recognizes the canonical `openshell:resolve:env:KEY` placeholder form and whole-token provider-shaped aliases such as `provider-OPENSHELL-RESOLVE-ENV-API_TOKEN` when the referenced environment key exists in the configured provider credentials. For example: - `api.github.com:443:read-only:rest` is valid. +- `realtime.example.com:443:read-write:websocket` is valid. - `api.github.com:443::rest` is invalid. It does not mean "allow all traffic." An L7 endpoint with `protocol` but no `access` or `rules` is rejected when the policy loads. -When you pass multiple `--add-endpoint` flags in one command, every `--binary` value applies to every added endpoint in that command. If different endpoints need different binaries, use separate `policy update` commands. +Endpoint options belong to the individual `--add-endpoint` spec. When you pass multiple `--add-endpoint` flags in one command, every `--binary` value applies to every added endpoint in that command. If different endpoints need different binaries, use separate `policy update` commands. If you do not pass `--rule-name`, OpenShell generates one from the host and port, such as `allow_api_github_com_443`. -### REST Rule Specs +### Method/Path Rule Specs `--add-allow` and `--add-deny` use this format: @@ -252,7 +264,7 @@ If you do not pass `--rule-name`, OpenShell generates one from the host and port host:port:METHOD:path_glob ``` -This string identifies an existing REST endpoint and the request pattern you want to add. +This string identifies an existing REST or WebSocket endpoint and the request pattern you want to add. In shell commands, quote the full `SPEC` when it contains `*` or `**` so your shell passes it literally instead of expanding it as a local file glob. @@ -260,8 +272,8 @@ In shell commands, quote the full `SPEC` when it contains `*` or `**` so your sh |---|---| | `host` | Existing endpoint host. | | `port` | Existing endpoint port. | -| `METHOD` | HTTP method. 
The CLI normalizes it to uppercase. | -| `path_glob` | URL path glob. It must start with `/`, or be `**`, or start with `**/`. | +| `METHOD` | HTTP method for REST endpoints, or `GET` / `WEBSOCKET_TEXT` for WebSocket endpoints. The CLI normalizes it to uppercase. | +| `path_glob` | URL path glob. For WebSocket text messages, this still matches the upgraded request path, not message payload content. It must start with `/`, or be `**`, or start with `**/`. | This example: @@ -283,11 +295,11 @@ Path globs follow the same semantics as YAML allow and deny rules: - `/repos/*/issues` matches one repository owner or name segment in the middle. - `/repos/**` matches everything under `/repos/`. -The rule-level commands only modify method and path constraints. They do not change binaries, hostnames, ports, or protocol settings. +The rule-level commands only modify method and path constraints. They do not change binaries, hostnames, ports, protocol settings, or WebSocket message payload matching. ### Common Workflows -Use these patterns as starting points when you decide whether to update an endpoint or append REST rules. +Use these patterns as starting points when you decide whether to update an endpoint or append REST/WebSocket rules. #### Add a new L4 endpoint @@ -302,7 +314,7 @@ openshell policy update demo \ --wait ``` -This creates or merges endpoint entries and binds them to the listed binaries. It does not create per-path REST rules. +This creates or merges endpoint entries and binds them to the listed binaries. It does not create inspected method/path rules. #### Create a REST endpoint with a base allow set @@ -341,6 +353,31 @@ openshell policy update demo \ This adds a deny rule to the existing REST endpoint. The endpoint must already have an allow base. +#### Create a WebSocket endpoint with a base allow set + +Use `--add-endpoint` with `protocol: websocket` when the destination is an RFC 6455 WebSocket API. 
+ +```shell +openshell policy update demo \ + --add-endpoint realtime.example.com:443:read-write:websocket:enforce:websocket-credential-rewrite \ + --binary /usr/bin/node \ + --wait +``` + +This creates a WebSocket endpoint and sets its base allow behavior through the `read-write` access preset. For WebSocket endpoints, `read-write` expands to the upgrade `GET` and client `WEBSOCKET_TEXT` messages on the upgraded request path. The rewrite option lets the sandbox send `openshell:resolve:env:*` placeholders in client text frames; OpenShell resolves them before forwarding to the upstream service. + +#### Add a WebSocket text-message deny rule + +Use `WEBSOCKET_TEXT` when you want to refine client-to-server text-frame policy without matching message payload content. + +```shell +openshell policy update demo \ + --add-deny 'realtime.example.com:443:WEBSOCKET_TEXT:/v1/admin/**' \ + --wait +``` + +This adds a deny rule to the existing WebSocket endpoint. The path glob matches the WebSocket upgrade path. + #### Remove one endpoint or rule Use `--remove-endpoint` to remove one host and port pair, or `--remove-rule` to delete the whole named rule. @@ -379,7 +416,7 @@ The CLI validates the argument shapes before it sends the request. The gateway t - a required segment is missing. - a port is outside `1` through `65535`. - `--add-allow` or `--add-deny` points at an endpoint that does not exist. -- `--add-allow` or `--add-deny` targets a non-REST endpoint. +- `--add-allow` or `--add-deny` targets an endpoint that is neither REST nor WebSocket. - `--add-deny` targets an endpoint that has no base allow set. ## Global Policy Override @@ -415,7 +452,7 @@ When triaging denied requests, check: - Destination host and port to confirm which endpoint is missing. - Calling binary path to confirm which `binaries` entry needs to be added or adjusted. -- HTTP method and path (for REST endpoints) to confirm which `rules` entry needs to be added or adjusted. 
+- HTTP method and path for REST endpoints, or `GET` / `WEBSOCKET_TEXT` and the upgraded request path for WebSocket endpoints, to confirm which `rules` entry needs to be added or adjusted. Then push the updated policy as described above. @@ -427,7 +464,7 @@ openshell policy update --add-allow 'api.github.com:443:GET:/repos/**' -- ## Examples -Add these blocks to the `network_policies` section of your sandbox policy. Apply simple endpoints and REST rule additions with `openshell policy update`, or apply any complete YAML block with `openshell policy set --policy --wait`. +Add these blocks to the `network_policies` section of your sandbox policy. Apply simple endpoints and REST/WebSocket rule additions with `openshell policy update`, or apply any complete YAML block with `openshell policy set --policy --wait`. Use **Simple endpoint** for host-level allowlists and **Granular rules** for method/path control. @@ -447,7 +484,7 @@ Allow `pip install` and `uv pip install` to reach PyPI: - { path: /usr/local/bin/uv } ``` -Endpoints without `protocol` use TCP passthrough, where the proxy allows the stream without inspecting payloads. If the stream is HTTP and TLS is auto-terminated, the proxy can still rewrite configured credential placeholders and closes keep-alive passthrough tunnels on policy reload before forwarding another request. +Endpoints without `protocol` use TCP passthrough, where the proxy allows the stream without inspecting payloads. If the stream is HTTP and TLS is auto-terminated, the proxy can still rewrite configured credential placeholders and closes keep-alive passthrough tunnels on policy reload before forwarding another request. WebSocket text-frame policy requires an explicit `protocol: websocket` endpoint. WebSocket payload credential rewrite can also be enabled on a `protocol: rest` compatibility endpoint with `websocket_credential_rewrite: true`. 
REST request body credential rewrite requires an inspected `protocol: rest` endpoint with `request_body_credential_rewrite: true`. @@ -505,7 +542,7 @@ For an end-to-end walkthrough that combines this policy with a GitHub credential - { path: /usr/bin/gh } ``` -Endpoints with `protocol: rest` enable HTTP request inspection. Endpoints with `protocol: graphql` parse GraphQL-over-HTTP payloads before evaluating rules. The endpoint-level `path` field lets both protocols share `api.github.com:443` without treating GraphQL payloads as plain REST `POST /graphql` requests. +Endpoints with `protocol: rest` enable HTTP request inspection and can opt in to supported text request body credential rewrite. Endpoints with `protocol: websocket` validate WebSocket upgrades and inspect client text messages on the upgraded request path. WebSocket endpoints can also classify GraphQL-over-WebSocket operation messages with the same operation rules used by GraphQL-over-HTTP. Endpoints with `protocol: graphql` parse GraphQL-over-HTTP payloads before evaluating rules. The endpoint-level `path` field lets these protocols share `api.github.com:443` without treating GraphQL payloads as plain REST `POST /graphql` requests. @@ -570,6 +607,36 @@ For allow rules, every selected root field in an operation must match one of the Hash-only persisted queries cannot be classified from the request alone. OpenShell denies them unless the endpoint uses `persisted_queries: allow_registered` and provides a trusted `graphql_persisted_queries` entry keyed by hash or saved-query ID. +### GraphQL-over-WebSocket matching + +Some APIs carry GraphQL operations over RFC 6455 WebSockets, commonly for subscriptions and realtime updates. Configure these as `protocol: websocket`, allow the upgrade with a normal `GET` rule, then add GraphQL operation rules for client operation messages. OpenShell recognizes modern `graphql-transport-ws` `subscribe` messages and legacy `graphql-ws` `start` messages. 
+ +```yaml showLineNumbers={false} + realtime_graphql: + name: realtime_graphql + endpoints: + - host: realtime.example.com + port: 443 + path: "/graphql" + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/graphql" + - allow: + operation_type: subscription + fields: [messageAdded] + - allow: + operation_type: query + fields: [viewer] + websocket_credential_rewrite: true + binaries: + - { path: /usr/bin/node } +``` + +When a WebSocket endpoint has GraphQL operation policy, client operation messages are fail-closed on malformed JSON, unsupported message types, parse errors, unregistered hash-only persisted queries, or unallowed operations. Use GraphQL operation rules for client messages rather than a raw `WEBSOCKET_TEXT` allow rule. Protocol lifecycle messages such as `connection_init`, `ping`, `pong`, and `complete` are allowed without payload logging; if `websocket_credential_rewrite: true` is set, placeholders inside those text messages are resolved before forwarding. + ### GraphQL service policy shapes GraphQL field names are application-specific, so treat these as starting shapes to review against the actual app schema: diff --git a/docs/security/best-practices.mdx b/docs/security/best-practices.mdx index 227bb4567..f51781144 100644 --- a/docs/security/best-practices.mdx +++ b/docs/security/best-practices.mdx @@ -97,9 +97,9 @@ The `protocol` field on an endpoint controls whether the proxy inspects individu | Aspect | Detail | |---|---| | Default | Endpoints without a `protocol` field use L4-only enforcement: the proxy checks host, port, and binary, then relays the TCP stream without inspecting payloads. | -| What you can change | Add `protocol: rest` to enable per-request HTTP method/path inspection, or `protocol: graphql` to inspect GraphQL operation type, operation name, and root fields. Pair either protocol with `rules` or access presets (`full`, `read-only`, `read-write`). 
| +| What you can change | Add `protocol: rest` to enable per-request HTTP method/path inspection, `protocol: websocket` to inspect RFC 6455 upgrade handshakes and client text messages, or `protocol: graphql` to inspect GraphQL-over-HTTP operation type, operation name, and root fields. WebSocket endpoints can also use GraphQL operation rules for GraphQL-over-WebSocket messages. Pair inspected protocols with `rules` or access presets (`full`, `read-only`, `read-write`). REST endpoints that need credential placeholders in supported text request bodies can set `request_body_credential_rewrite: true`. | | Risk if relaxed | L4-only endpoints allow the agent to send any data through the tunnel after the initial connection is permitted. The proxy cannot see HTTP methods, paths, or GraphQL operations. Adding `access: full` with L7 inspection enables observability but permits all inspected actions. | -| Recommendation | Use `protocol: rest` with specific `rules` for APIs where intent is encoded in method and path. Use `protocol: graphql` for GraphQL APIs where destructive operations are body-encoded. Prefer `access: read-only` or explicit allowlists, and deny hash-only persisted queries unless you maintain a trusted registry. Omit `protocol` for non-HTTP protocols (WebSocket, gRPC streaming). | +| Recommendation | Use `protocol: rest` with specific `rules` for APIs where intent is encoded in method and path. Add `request_body_credential_rewrite: true` only for REST APIs that require OpenShell-managed credentials in UTF-8 JSON, form, or text request bodies. Use `protocol: graphql` for GraphQL-over-HTTP APIs where destructive operations are body-encoded. Use `protocol: websocket` for RFC 6455 endpoints, with explicit `GET` and `WEBSOCKET_TEXT` rules for raw text protocols or explicit GraphQL operation rules for GraphQL-over-WebSocket. Prefer `access: read-only` or explicit allowlists, and deny hash-only persisted queries unless you maintain a trusted registry. 
Omit `protocol` for non-HTTP protocols. For WebSocket endpoints that must carry placeholder credentials in client text frames, add `websocket_credential_rewrite: true`. | ### Enforcement Mode (`audit` vs `enforce`) @@ -283,8 +283,8 @@ The following patterns weaken security without providing meaningful benefit. | Mistake | Why it matters | What to do instead | |---------|---------------|-------------------| -| Omitting `protocol: rest` on REST API endpoints | Without `protocol: rest`, the proxy uses L4-only enforcement. It allows the TCP stream through after checking host, port, and binary, but cannot inspect individual HTTP requests. | Add `protocol: rest` with specific `rules` to enable per-request method and path control. | -| Using `access: full` when finer rules would suffice | `access: full` with `protocol: rest` enables inspection but allows all HTTP methods and paths. | Use `access: read-only` or explicit `rules` to restrict what the agent can do at the HTTP level. | +| Omitting an inspected protocol on REST or WebSocket API endpoints | Without `protocol: rest` or `protocol: websocket`, the proxy uses L4-only enforcement. It allows the TCP stream through after checking host, port, and binary, but cannot inspect individual HTTP requests or WebSocket text messages. | Add `protocol: rest` or `protocol: websocket` with specific `rules` to enable method and path control. | +| Using `access: full` when finer rules would suffice | `access: full` with `protocol: rest` or `protocol: websocket` enables inspection but allows all methods and paths for that protocol. | Use `access: read-only` or explicit `rules` to restrict what the agent can do at the L7 level. | | Adding endpoints permanently when operator approval would suffice | Adding endpoints to the policy YAML makes them permanently reachable across all instances. | Use operator approval. Approved endpoints persist within the sandbox instance but reset on re-creation. 
| | Using broad binary globs | A glob like `/**` allows any binary to reach the endpoint, defeating binary-scoped enforcement. | Scope globs to specific directories (for example, `/sandbox/.vscode-server/**`). | | Skipping TLS termination on HTTPS APIs | Setting `tls: skip` disables credential injection and L7 inspection. | Use the default auto-detect behavior unless the upstream requires client-certificate mTLS. | diff --git a/e2e/rust/Cargo.toml b/e2e/rust/Cargo.toml index 0da2e417b..73b723718 100644 --- a/e2e/rust/Cargo.toml +++ b/e2e/rust/Cargo.toml @@ -41,6 +41,11 @@ name = "gateway_resume" path = "tests/gateway_resume.rs" required-features = ["e2e-docker"] +[[test]] +name = "websocket_conformance" +path = "tests/websocket_conformance.rs" +required-features = ["e2e-docker"] + [[test]] name = "user_namespaces" path = "tests/user_namespaces.rs" diff --git a/e2e/rust/tests/websocket_conformance.rs b/e2e/rust/tests/websocket_conformance.rs new file mode 100644 index 000000000..16606448c --- /dev/null +++ b/e2e/rust/tests/websocket_conformance.rs @@ -0,0 +1,376 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#![cfg(feature = "e2e")] + +//! E2E regression: WebSocket credential placeholders are resolved on the real +//! Docker-backed sandbox path after an RFC 6455 upgrade. +//! +//! The sandbox process sends its provider-managed placeholder in a masked text +//! frame. The local upstream only reports whether it saw the real secret and +//! whether any placeholder survived; it never echoes payload bytes, placeholder +//! text, or secret material back into test output. 
+ +use std::io::Write; +use std::process::Stdio; +use std::sync::Mutex; + +use openshell_e2e::harness::binary::openshell_cmd; +use openshell_e2e::harness::container::ContainerHttpServer; +use openshell_e2e::harness::sandbox::SandboxGuard; +use tempfile::NamedTempFile; + +const PROVIDER_NAME: &str = "e2e-websocket-conformance"; +const TEST_SERVER_ALIAS: &str = "websocket-conformance.openshell.test"; +const TEST_SECRET: &str = "sk-e2e-websocket-conformance-secret"; +const TOKEN_ENV: &str = "WS_E2E_TOKEN"; +const PLACEHOLDER_PREFIX: &str = "openshell:resolve:env:"; +static PROVIDER_LOCK: Mutex<()> = Mutex::new(()); + +async fn run_cli(args: &[&str]) -> Result { + let mut cmd = openshell_cmd(); + cmd.args(args).stdout(Stdio::piped()).stderr(Stdio::piped()); + + let output = cmd + .output() + .await + .map_err(|e| format!("failed to spawn openshell {}: {e}", args.join(" ")))?; + + let stdout = String::from_utf8_lossy(&output.stdout).to_string(); + let stderr = String::from_utf8_lossy(&output.stderr).to_string(); + let combined = format!("{stdout}{stderr}"); + + if !output.status.success() { + return Err(format!( + "openshell {} failed (exit {:?}):\n{combined}", + args.join(" "), + output.status.code() + )); + } + + Ok(combined) +} + +async fn delete_provider(name: &str) { + let mut cmd = openshell_cmd(); + cmd.arg("provider") + .arg("delete") + .arg(name) + .stdout(Stdio::null()) + .stderr(Stdio::null()); + let _ = cmd.status().await; +} + +async fn create_generic_provider(name: &str) -> Result { + let credential = format!("{TOKEN_ENV}={TEST_SECRET}"); + run_cli(&[ + "provider", + "create", + "--name", + name, + "--type", + "generic", + "--credential", + &credential, + ]) + .await +} + +async fn start_websocket_probe_server() -> Result { + let script = format!( + r#" +import base64 +import hashlib +import json +import socketserver +import struct + +GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" +SECRET = {secret:?} +PLACEHOLDER_PREFIX = {placeholder_prefix:?} + +def 
recv_until(sock, marker): + data = b"" + while marker not in data: + chunk = sock.recv(4096) + if not chunk: + break + data += chunk + return data + +def read_exact(sock, size): + data = b"" + while len(data) < size: + chunk = sock.recv(size - len(data)) + if not chunk: + raise EOFError("unexpected end of websocket frame") + data += chunk + return data + +def read_frame(sock): + header = read_exact(sock, 2) + first, second = header[0], header[1] + length = second & 0x7F + if length == 126: + length = struct.unpack("!H", read_exact(sock, 2))[0] + elif length == 127: + length = struct.unpack("!Q", read_exact(sock, 8))[0] + mask = read_exact(sock, 4) if second & 0x80 else b"" + payload = read_exact(sock, length) + if mask: + payload = bytes(byte ^ mask[index % 4] for index, byte in enumerate(payload)) + return first, payload + +def send_text(sock, payload): + data = payload.encode("utf-8") + if len(data) < 126: + header = bytes([0x81, len(data)]) + elif len(data) <= 0xFFFF: + header = bytes([0x81, 126]) + struct.pack("!H", len(data)) + else: + header = bytes([0x81, 127]) + struct.pack("!Q", len(data)) + sock.sendall(header + data) + +def header_value(request, name): + prefix = name.lower() + ":" + for line in request.split("\r\n"): + if line.lower().startswith(prefix): + return line.split(":", 1)[1].strip() + return "" + +class Handler(socketserver.BaseRequestHandler): + def handle(self): + request_bytes = recv_until(self.request, b"\r\n\r\n") + request = request_bytes.decode("iso-8859-1", "replace") + if "upgrade: websocket" not in request.lower(): + self.request.sendall( + b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: close\r\n\r\nok" + ) + return + + key = header_value(request, "Sec-WebSocket-Key") + accept = base64.b64encode(hashlib.sha1((key + GUID).encode("ascii")).digest()).decode("ascii") + response = ( + "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + f"Sec-WebSocket-Accept: {{accept}}\r\n" + "\r\n" + ) 
+ self.request.sendall(response.encode("ascii")) + + _, payload = read_frame(self.request) + text = payload.decode("utf-8", "replace") + result = {{ + "saw_placeholder": PLACEHOLDER_PREFIX in text, + "saw_secret": SECRET in text, + }} + send_text(self.request, json.dumps(result, sort_keys=True)) + +class Server(socketserver.ThreadingTCPServer): + allow_reuse_address = True + +Server(("0.0.0.0", 8000), Handler).serve_forever() +"#, + secret = TEST_SECRET, + placeholder_prefix = PLACEHOLDER_PREFIX, + ); + + ContainerHttpServer::start_python(TEST_SERVER_ALIAS, &script).await +} + +fn write_websocket_policy(host: &str, port: u16) -> Result { + let mut file = NamedTempFile::new().map_err(|e| format!("create temp policy file: {e}"))?; + let policy = format!( + r#"version: 1 + +filesystem_policy: + include_workdir: true + read_only: + - /usr + - /lib + - /proc + - /dev/urandom + - /app + - /etc + - /var/log + read_write: + - /sandbox + - /tmp + - /dev/null + +landlock: + compatibility: best_effort + +process: + run_as_user: sandbox + run_as_group: sandbox + +network_policies: + websocket_conformance: + name: websocket_conformance + endpoints: + - host: {host} + port: {port} + protocol: websocket + enforcement: enforce + access: read-write + websocket_credential_rewrite: true + allowed_ips: + - "10.0.0.0/8" + - "172.0.0.0/8" + - "192.168.0.0/16" + - "fc00::/7" + binaries: + - path: /usr/bin/python* + - path: /usr/local/bin/python* + - path: /sandbox/.uv/python/*/bin/python* +"# + ); + file.write_all(policy.as_bytes()) + .map_err(|e| format!("write temp policy file: {e}"))?; + file.flush() + .map_err(|e| format!("flush temp policy file: {e}"))?; + Ok(file) +} + +fn websocket_client_script(host: &str, port: u16) -> String { + format!( + r#" +import base64 +import json +import os +import socket +import struct + +HOST = {host:?} +PORT = {port} +TOKEN_ENV = {token_env:?} + +def recv_until(sock, marker): + data = b"" + while marker not in data: + chunk = sock.recv(4096) + if not 
chunk: + break + data += chunk + return data + +def read_exact(sock, size): + data = b"" + while len(data) < size: + chunk = sock.recv(size - len(data)) + if not chunk: + raise EOFError("unexpected end of websocket frame") + data += chunk + return data + +def masked_text_frame(payload): + data = payload.encode("utf-8") + mask = os.urandom(4) + if len(data) < 126: + header = bytes([0x81, 0x80 | len(data)]) + elif len(data) <= 0xFFFF: + header = bytes([0x81, 0x80 | 126]) + struct.pack("!H", len(data)) + else: + header = bytes([0x81, 0x80 | 127]) + struct.pack("!Q", len(data)) + masked = bytes(byte ^ mask[index % 4] for index, byte in enumerate(data)) + return header + mask + masked + +def read_frame(sock): + first, second = read_exact(sock, 2) + length = second & 0x7F + if length == 126: + length = struct.unpack("!H", read_exact(sock, 2))[0] + elif length == 127: + length = struct.unpack("!Q", read_exact(sock, 8))[0] + mask = read_exact(sock, 4) if second & 0x80 else b"" + payload = read_exact(sock, length) + if mask: + payload = bytes(byte ^ mask[index % 4] for index, byte in enumerate(payload)) + return first, payload + +token = os.environ[TOKEN_ENV] +payload = json.dumps({{"authorization": "Bearer " + token}}, sort_keys=True) +key = base64.b64encode(os.urandom(16)).decode("ascii") + +with socket.create_connection((HOST, PORT), timeout=20) as sock: + request = ( + f"GET /ws HTTP/1.1\r\n" + f"Host: {{HOST}}:{{PORT}}\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + f"Sec-WebSocket-Key: {{key}}\r\n" + "Sec-WebSocket-Version: 13\r\n" + "\r\n" + ) + sock.sendall(request.encode("ascii")) + response = recv_until(sock, b"\r\n\r\n").decode("iso-8859-1", "replace") + if not response.startswith("HTTP/1.1 101"): + raise RuntimeError("websocket upgrade failed") + sock.sendall(masked_text_frame(payload)) + _, response_payload = read_frame(sock) + print(response_payload.decode("utf-8")) +"#, + host = host, + port = port, + token_env = TOKEN_ENV, + ) +} + 
+#[tokio::test] +async fn websocket_text_placeholder_is_rewritten_in_docker_sandbox() { + let _provider_lock = PROVIDER_LOCK + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + + delete_provider(PROVIDER_NAME).await; + create_generic_provider(PROVIDER_NAME) + .await + .expect("create generic provider"); + + let result = async { + let server = start_websocket_probe_server().await?; + let policy = write_websocket_policy(&server.host, server.port)?; + let policy_path = policy + .path() + .to_str() + .ok_or_else(|| "temp policy path should be utf-8".to_string())? + .to_string(); + let script = websocket_client_script(&server.host, server.port); + + SandboxGuard::create(&[ + "--policy", + &policy_path, + "--provider", + PROVIDER_NAME, + "--", + "python3", + "-c", + &script, + ]) + .await + } + .await; + + delete_provider(PROVIDER_NAME).await; + + let guard = result.expect("sandbox create"); + assert!( + guard + .create_output + .contains(r#"{"saw_placeholder": false, "saw_secret": true}"#), + "expected upstream to see only the resolved secret marker:\n{}", + guard.create_output + ); + assert!( + !guard.create_output.contains(TEST_SECRET), + "test output should not expose the raw WebSocket credential:\n{}", + guard.create_output + ); + assert!( + !guard.create_output.contains(PLACEHOLDER_PREFIX), + "test output should not expose unresolved credential placeholders:\n{}", + guard.create_output + ); +} diff --git a/proto/sandbox.proto b/proto/sandbox.proto index f7df5945e..b40d95cb1 100644 --- a/proto/sandbox.proto +++ b/proto/sandbox.proto @@ -70,7 +70,7 @@ message NetworkEndpoint { // Single port (backwards compat). Use `ports` for multiple ports. // Mutually exclusive with `ports` — if both are set, `ports` takes precedence. uint32 port = 2; - // Application protocol for L7 inspection: "rest", "graphql", "sql", or "" (L4-only). + // Application protocol for L7 inspection: "rest", "websocket", "graphql", "sql", or "" (L4-only). 
string protocol = 3; // TLS handling: "terminate" or "passthrough" (default). string tls = 4; @@ -116,6 +116,14 @@ message NetworkEndpoint { // protocol "rest" when both surfaces live under api.example.com:443. // Empty means all paths. string path = 15; + // When true on a "rest" endpoint, OpenShell rewrites credential placeholders + // inside client-to-server WebSocket text messages after an allowed HTTP 101 + // upgrade. Defaults to false. + bool websocket_credential_rewrite = 16; + // When true on a "rest" endpoint, OpenShell rewrites credential placeholders + // inside supported textual HTTP request bodies before forwarding upstream. + // Defaults to false. + bool request_body_credential_rewrite = 17; } // Trusted GraphQL operation classification. diff --git a/tasks/test.toml b/tasks/test.toml index bf5741c72..a1f4c6429 100644 --- a/tasks/test.toml +++ b/tasks/test.toml @@ -34,6 +34,13 @@ run = [ "e2e/with-docker-gateway.sh cargo test --manifest-path e2e/rust/Cargo.toml --features e2e-docker", ] +["e2e:websocket-conformance"] +description = "Run focused WebSocket conformance e2e tests against a Docker-backed gateway" +run = [ + "cargo build -p openshell-cli --features openshell-core/dev-settings", + "e2e/with-docker-gateway.sh cargo test --manifest-path e2e/rust/Cargo.toml --features e2e-docker --test websocket_conformance", +] + ["e2e:python"] description = "Run Python e2e tests against a Docker-backed gateway (E2E_PARALLEL=N or 'auto'; default 5)" depends = ["python:proto"]