diff --git a/Cargo.lock b/Cargo.lock index 103499498..7e8b1c283 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1245,6 +1245,7 @@ dependencies = [ "etherparse", "expect-test", "futures", + "hex", "hostname 0.4.1", "http-body-util", "hyper 1.6.0", @@ -1273,10 +1274,12 @@ dependencies = [ "reqwest", "rstest", "rustls-cng", + "rustls-native-certs", "serde", "serde-querystring", "serde_json", "serde_urlencoded", + "sha2", "smol_str", "sysevent", "sysevent-codes", @@ -6215,6 +6218,8 @@ dependencies = [ "tempfile", "test-utils", "tokio 1.46.1", + "tokio-tungstenite 0.26.2", + "typed-builder", ] [[package]] diff --git a/devolutions-gateway/Cargo.toml b/devolutions-gateway/Cargo.toml index 1c273719e..280442959 100644 --- a/devolutions-gateway/Cargo.toml +++ b/devolutions-gateway/Cargo.toml @@ -75,6 +75,9 @@ zeroize = { version = "1.8", features = ["derive"] } multibase = "0.9" argon2 = { version = "0.5", features = ["std"] } x509-cert = { version = "0.2", default-features = false, features = ["std"] } +sha2 = "0.10" +hex = "0.4" +rustls-native-certs = "0.8" # Logging tracing = "0.1" diff --git a/devolutions-gateway/src/api/fwd.rs b/devolutions-gateway/src/api/fwd.rs index 9f285cec3..ac3927190 100644 --- a/devolutions-gateway/src/api/fwd.rs +++ b/devolutions-gateway/src/api/fwd.rs @@ -243,12 +243,12 @@ where if with_tls { trace!("Establishing TLS connection with server"); - // Establish TLS connection with server - - let server_stream = crate::tls::connect(selected_target.host().to_owned(), server_stream) - .await - .context("TLS connect") - .map_err(ForwardError::BadGateway)?; + // Establish TLS connection with server. + let server_stream = + crate::tls::safe_connect(selected_target.host().to_owned(), server_stream, claims.cert_thumb256) + .await + .context("TLS connect") + .map_err(ForwardError::BadGateway)?; info!("WebSocket-TLS forwarding"); diff --git a/devolutions-gateway/src/api/webapp.rs b/devolutions-gateway/src/api/webapp.rs index 2ef980cfb..efe830a38 100644 --- a/devolutions-gateway/src/api/webapp.rs +++ b/devolutions-gateway/src/api/webapp.rs @@ -339,6 +339,7 @@ pub(crate) async fn sign_session_token( jet_reuse: ReconnectionPolicy::Disallowed, exp, jti, + cert_thumb256: None, } .pipe(serde_json::to_value) .map(|mut claims| { diff --git a/devolutions-gateway/src/rd_clean_path.rs b/devolutions-gateway/src/rd_clean_path.rs index 056166021..4be8ae37e 100644 --- a/devolutions-gateway/src/rd_clean_path.rs +++ b/devolutions-gateway/src/rd_clean_path.rs @@ -256,7 +256,7 @@ async fn process_cleanpath( // Establish TLS connection with server - let server_stream = crate::tls::connect(selected_target.host().to_owned(), server_stream) + let server_stream = crate::tls::dangerous_connect(selected_target.host().to_owned(), server_stream) .await .map_err(|source| CleanPathError::TlsHandshake { source, diff --git a/devolutions-gateway/src/rdp_proxy.rs b/devolutions-gateway/src/rdp_proxy.rs index 8d5a40c8a..fce931d80 100644 --- a/devolutions-gateway/src/rdp_proxy.rs +++ b/devolutions-gateway/src/rdp_proxy.rs @@ -87,7 +87,7 @@ where // -- Perform the TLS upgrading for both the client and the server, effectively acting as a man-in-the-middle -- // let client_tls_upgrade_fut = tls_conf.acceptor.accept(client_stream); - let server_tls_upgrade_fut = crate::tls::connect(server_dns_name.clone(), server_stream); + let server_tls_upgrade_fut = crate::tls::dangerous_connect(server_dns_name.clone(), server_stream); let (client_stream, server_stream) = tokio::join!(client_tls_upgrade_fut, server_tls_upgrade_fut); @@ -510,7 +510,7 @@ async fn get_cached_gateway_public_key( async fn retrieve_gateway_public_key(hostname: String, acceptor: tokio_rustls::TlsAcceptor) -> anyhow::Result> { let (client_side, server_side) = tokio::io::duplex(4096); - let connect_fut = crate::tls::connect(hostname, client_side); + let connect_fut = crate::tls::dangerous_connect(hostname, client_side); let accept_fut = acceptor.accept(server_side); let (connect_res, _) = tokio::join!(connect_fut, accept_fut); diff --git a/devolutions-gateway/src/tls.rs b/devolutions-gateway/src/tls.rs index 7288859dc..94d6b41b0 100644 --- a/devolutions-gateway/src/tls.rs +++ b/devolutions-gateway/src/tls.rs @@ -1,9 +1,12 @@ +use std::collections::HashMap; use std::io; use std::sync::{Arc, LazyLock}; use anyhow::Context as _; +use parking_lot::Mutex; use tokio_rustls::client::TlsStream; use tokio_rustls::rustls::{self, pki_types}; +use x509_cert::der::Decode as _; static DEFAULT_CIPHER_SUITES: &[rustls::SupportedCipherSuite] = rustls::crypto::ring::DEFAULT_CIPHER_SUITES; @@ -15,7 +18,7 @@ static DEFAULT_CIPHER_SUITES: &[rustls::SupportedCipherSuite] = rustls::crypto:: // // We’ll reuse the same TLS client config for all proxy-based TLS connections. // (TlsConnector is just a wrapper around the config providing the `connect` method.) -static TLS_CONNECTOR: LazyLock = LazyLock::new(|| { +static DANGEROUS_TLS_CONNECTOR: LazyLock = LazyLock::new(|| { let mut config = rustls::client::ClientConfig::builder() .dangerous() .with_custom_certificate_verifier(Arc::new(danger::NoCertificateVerification)) @@ -31,7 +34,27 @@ static TLS_CONNECTOR: LazyLock = LazyLock::new(|| { tokio_rustls::TlsConnector::from(Arc::new(config)) }); -pub async fn connect(dns_name: String, stream: IO) -> io::Result> +static NATIVE_ROOTS_VERIFIER: LazyLock> = + LazyLock::new(|| Arc::new(NativeRootsVerifier::new())); + +static SAFE_TLS_CONNECTOR: LazyLock = LazyLock::new(|| { + let mut config = rustls::client::ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(NATIVE_ROOTS_VERIFIER.clone()) + .with_no_client_auth(); + + config.resumption = rustls::client::Resumption::disabled(); + + tokio_rustls::TlsConnector::from(Arc::new(config)) +}); + +// Cache for thumbprint-anchored TLS connectors to avoid recreating them for each connection. +// The rustls documentation recommends creating ClientConfig once per process rather than per connection. +static THUMBPRINT_ANCHORED_CONNECTORS: LazyLock< + Mutex>, +> = LazyLock::new(|| Mutex::new(HashMap::new())); + +pub async fn dangerous_connect(dns_name: String, stream: IO) -> io::Result> where IO: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, { @@ -39,7 +62,68 @@ where let dns_name = pki_types::ServerName::try_from(dns_name).map_err(io::Error::other)?; - let mut tls_stream = TLS_CONNECTOR.connect(dns_name, stream).await?; + let mut tls_stream = DANGEROUS_TLS_CONNECTOR.connect(dns_name, stream).await?; + + // > To keep it simple and correct, [TlsStream] will behave like `BufWriter`. + // > For `TlsStream`, this means that data written by `poll_write` + // > is not guaranteed to be written to `TcpStream`. + // > You must call `poll_flush` to ensure that it is written to `TcpStream`. + // + // source: https://docs.rs/tokio-rustls/latest/tokio_rustls/#why-do-i-need-to-call-poll_flush + tls_stream.flush().await?; + + Ok(tls_stream) +} + +/// Connect to a TLS server with optional certificate thumbprint anchoring. +/// +/// # Thumbprint Anchoring Behavior +/// +/// When `cert_thumb256` is provided: +/// - If thumbprint matches: Accept immediately, bypassing ALL certificate checks (expiration, key usage, trust chain) +/// - If thumbprint doesn't match: Fall back to standard TLS verification +/// +/// This is an escape hatch for certificate issues, NOT for long-term use. +pub async fn safe_connect( + dns_name: String, + stream: IO, + cert_thumb256: Option, +) -> io::Result> +where + IO: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, +{ + use tokio::io::AsyncWriteExt as _; + + let server_name = pki_types::ServerName::try_from(dns_name.clone()).map_err(io::Error::other)?; + + // Get the appropriate connector, using cache for thumbprint-anchored ones. + let connector = if let Some(thumbprint) = cert_thumb256 { + // Check the cache first. + let mut cache = THUMBPRINT_ANCHORED_CONNECTORS.lock(); + + // Clone existing connector or create and cache a new one. + cache + .entry(thumbprint.clone()) + .or_insert_with(|| { + debug!(%thumbprint, "Creating new thumbprint-anchored TLS connector"); + + let verifier = Arc::new(ThumbprintAnchoredVerifier::new(thumbprint)); + + let mut config = rustls::client::ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(verifier) + .with_no_client_auth(); + + config.resumption = rustls::client::Resumption::disabled(); + + tokio_rustls::TlsConnector::from(Arc::new(config)) + }) + .clone() + } else { + SAFE_TLS_CONNECTOR.clone() + }; + + let mut tls_stream = connector.connect(server_name, stream).await?; // > To keep it simple and correct, [TlsStream] will behave like `BufWriter`. // > For `TlsStream`, this means that data written by `poll_write` @@ -480,6 +564,284 @@ pub fn check_certificate(cert: &[u8], at: time::OffsetDateTime) -> anyhow::Resul }) } +/// Standard certificate verifier using native roots based on the [`WebPkiServerVerifier`]. +/// +/// This verifier attempts normal TLS verification using system roots. +/// If verification fails, certificate details are logged. +#[derive(Debug)] +pub struct NativeRootsVerifier { + inner: rustls::client::WebPkiServerVerifier, +} + +impl NativeRootsVerifier { + pub fn new() -> Self { + // Create a standard verifier using platform native certificate store. + let mut root_store = rustls::RootCertStore::empty(); + + // Load certificates from the platform native certificate store. + let result = rustls_native_certs::load_native_certs(); + + for error in result.errors { + warn!(error = %error, "Error when loading native certificate"); + } + + let mut added_count = 0; + + for cert in result.certs { + if root_store.add(cert).is_ok() { + added_count += 1; + } + } + + if added_count == 0 { + warn!("No valid certificates found in platform native certificate store"); + } else { + debug!(count = added_count, "Loaded native certificates"); + } + + let webpki_server_verifier = rustls::client::WebPkiServerVerifier::builder(Arc::new(root_store)) + .build() + .expect("failed to build WebPkiServerVerifier; this should not fail"); + + Self { + inner: Arc::into_inner(webpki_server_verifier).expect("exactly one strong reference at this point"), + } + } +} + +impl rustls::client::danger::ServerCertVerifier for NativeRootsVerifier { + fn verify_server_cert( + &self, + end_entity: &pki_types::CertificateDer<'_>, + intermediates: &[pki_types::CertificateDer<'_>], + server_name: &pki_types::ServerName<'_>, + ocsp_response: &[u8], + now: pki_types::UnixTime, + ) -> Result { + match self + .inner + .verify_server_cert(end_entity, intermediates, server_name, ocsp_response, now) + { + Ok(verified) => Ok(verified), + Err(verification_error) => { + // Compute SHA-256 thumbprint of the certificate. + let thumbprint = thumbprint::compute_sha256_thumbprint(end_entity); + + // Extract certificate details. + let cert_info = extract_cert_info(end_entity); + + error!( + cert_subject = %cert_info.subject, + cert_issuer = %cert_info.issuer, + not_before = %cert_info.not_before, + not_after = %cert_info.not_after, + san = %cert_info.sans, + reason = %verification_error, + sha256_thumb = %thumbprint, + hint = "PASTE_THIS_THUMBPRINT_IN_RDM_CONNECTION", + "Invalid peer certificate" + ); + + Err(verification_error) + } + } + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &pki_types::CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + self.inner.verify_tls12_signature(message, cert, dss) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &pki_types::CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + self.inner.verify_tls13_signature(message, cert, dss) + } + + fn supported_verify_schemes(&self) -> Vec { + self.inner.supported_verify_schemes() + } +} + +impl Default for NativeRootsVerifier { + fn default() -> Self { + Self::new() + } +} + +/// Certificate verifier that supports thumbprint anchoring. +/// +/// This verifier accepts the certificate if the provided thumbprint matches the leaf certificate, +/// otherwise normal TLS verification using system roots is performed. +/// +/// ## Security Warning +/// +/// When thumbprint matches, this bypasses ALL standard TLS verification: +/// - Certificate expiration dates are NOT checked +/// - Key usage extensions are NOT validated +/// - Certificate chain trust is NOT verified +/// - Hostname matching is NOT performed +/// +/// This is an **escape hatch** for users with certificate issues, NOT a long-term solution. +/// Users should resolve certificate problems and remove thumbprint configuration ASAP. +#[derive(Debug)] +pub struct ThumbprintAnchoredVerifier { + expected_thumbprint: thumbprint::Sha256Thumbprint, +} + +impl ThumbprintAnchoredVerifier { + pub fn new(expected_thumbprint: thumbprint::Sha256Thumbprint) -> Self { + Self { expected_thumbprint } + } +} + +impl rustls::client::danger::ServerCertVerifier for ThumbprintAnchoredVerifier { + fn verify_server_cert( + &self, + end_entity: &pki_types::CertificateDer<'_>, + intermediates: &[pki_types::CertificateDer<'_>], + server_name: &pki_types::ServerName<'_>, + ocsp_response: &[u8], + now: pki_types::UnixTime, + ) -> Result { + // Compute SHA-256 thumbprint of the certificate. + let actual_thumbprint = thumbprint::compute_sha256_thumbprint(end_entity); + + // Thumbprint matches, accept immediately. + // SECURITY: This bypasses ALL certificate validation checks when thumbprint matches. + // No validation of: expiration, key usage, hostname, or trust chain. + if actual_thumbprint == self.expected_thumbprint { + info!( + sha256_thumb = %actual_thumbprint, + "Accepting TLS connection via certificate thumbprint anchor (bypassing standard validation)" + ); + + return Ok(rustls::client::danger::ServerCertVerified::assertion()); + } + + // Otherwise, try the normal verification. + NATIVE_ROOTS_VERIFIER.verify_server_cert(end_entity, intermediates, server_name, ocsp_response, now) + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &pki_types::CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + NATIVE_ROOTS_VERIFIER.verify_tls12_signature(message, cert, dss) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &pki_types::CertificateDer<'_>, + dss: &rustls::DigitallySignedStruct, + ) -> Result { + NATIVE_ROOTS_VERIFIER.verify_tls13_signature(message, cert, dss) + } + + fn supported_verify_schemes(&self) -> Vec { + NATIVE_ROOTS_VERIFIER.supported_verify_schemes() + } +} + +struct CertInfo { + subject: String, + issuer: String, + not_before: x509_cert::time::Time, + not_after: x509_cert::time::Time, + sans: String, +} + +fn extract_cert_info(cert_der: &[u8]) -> CertInfo { + use bytes::Buf as _; + use std::fmt::Write as _; + + match x509_cert::Certificate::from_der(cert_der) { + Ok(cert) => { + let subject = cert.tbs_certificate.subject.to_string(); + let issuer = cert.tbs_certificate.issuer.to_string(); + let not_before = cert.tbs_certificate.validity.not_before; + let not_after = cert.tbs_certificate.validity.not_after; + + let mut sans = String::new(); + let mut first = true; + + if let Some(extensions) = cert.tbs_certificate.extensions { + for ext in extensions { + if let Ok(san) = x509_cert::ext::pkix::SubjectAltName::from_der(ext.extn_value.as_bytes()) { + for name in san.0 { + if first { + first = false; + } else { + let _ = write!(sans, ","); + } + + match name { + x509_cert::ext::pkix::name::GeneralName::OtherName(other_name) => { + let _ = write!(sans, "{}", other_name.type_id); + } + x509_cert::ext::pkix::name::GeneralName::Rfc822Name(name) => { + let _ = write!(sans, "{}", name.as_str()); + } + x509_cert::ext::pkix::name::GeneralName::DnsName(name) => { + let _ = write!(sans, "{}", name.as_str()); + } + x509_cert::ext::pkix::name::GeneralName::DirectoryName(rdn_sequence) => { + let _ = write!(sans, "{rdn_sequence}"); + } + x509_cert::ext::pkix::name::GeneralName::EdiPartyName(_) => { + let _ = write!(sans, ""); + } + x509_cert::ext::pkix::name::GeneralName::UniformResourceIdentifier(uri) => { + let _ = write!(sans, "{}", uri.as_str()); + } + x509_cert::ext::pkix::name::GeneralName::IpAddress(octet_string) => { + if let Ok(ip) = octet_string.as_bytes().try_get_u128() { + let ip = std::net::Ipv6Addr::from_bits(ip); + let _ = write!(sans, "{ip}"); + } else if let Ok(ip) = octet_string.as_bytes().try_get_u32() { + let ip = std::net::Ipv4Addr::from_bits(ip); + let _ = write!(sans, "{ip}"); + } else { + let _ = write!(sans, ""); + } + } + x509_cert::ext::pkix::name::GeneralName::RegisteredId(object_identifier) => { + let _ = write!(sans, "{object_identifier}"); + } + } + } + } + } + } + + CertInfo { + subject, + issuer, + not_before, + not_after, + sans, + } + } + Err(_) => CertInfo { + subject: "".to_owned(), + issuer: "".to_owned(), + not_before: x509_cert::time::Time::INFINITY, + not_after: x509_cert::time::Time::INFINITY, + sans: "".to_owned(), + }, + } +} + pub mod sanity { use tokio_rustls::rustls; @@ -585,3 +947,119 @@ pub mod danger { } } } + +pub mod thumbprint { + use core::fmt; + + // SHA-256 thumbprint should be exactly 64 hex characters (32 bytes). + const EXPECTED_SHA256_LENGTH: usize = 64; + + /// Normalized SHA-256 Thumbprint. + #[derive(Debug, Clone, PartialEq, Eq, Hash)] + pub struct Sha256Thumbprint( + /// INVARIANT: 64-character, lowercased hex with no separator. + String, + ); + + impl fmt::Display for Sha256Thumbprint { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } + } + + impl Sha256Thumbprint { + pub fn as_str(&self) -> &str { + &self.0 + } + } + + #[derive(Debug, thiserror::Error)] + #[error( + "certificate thumbprint has unexpected length: expected {EXPECTED_SHA256_LENGTH} hex characters (SHA-256), got {actual_length}; \ + this may indicate a SHA-1 thumbprint (40 chars) or incorrect format" + )] + pub struct ThumbprintLengthError { + actual_length: usize, + } + + /// Normalize thumbprint to lowercase hex with no separators. + /// + /// Validates that the resulting thumbprint has the expected length for SHA-256 (64 hex chars). + pub fn normalize_sha256_thumbprint(thumb: &str) -> Result { + let normalized = thumb + .chars() + .filter(|c| c.is_ascii_hexdigit()) + .map(|mut c| { + c.make_ascii_lowercase(); + c + }) + .collect::(); + + if normalized.len() != EXPECTED_SHA256_LENGTH { + return Err(ThumbprintLengthError { + actual_length: normalized.len(), + }); + } + + Ok(Sha256Thumbprint(normalized)) + } + + /// Compute SHA-256 thumbprint of certificate DER bytes. + pub fn compute_sha256_thumbprint(cert_der: &[u8]) -> Sha256Thumbprint { + use sha2::{Digest, Sha256}; + let hash = Sha256::digest(cert_der); + Sha256Thumbprint(hex::encode(hash)) + } + + #[cfg(test)] + mod tests { + #![allow(clippy::unwrap_used, reason = "allowed in tests")] + + use rstest::rstest; + + use super::*; + + #[rstest] + #[case("3a7fb2c45e8d9f1a2b3c4d5e6f7a8b9cadbecfd0e1f2031425364758697a8b9c")] + #[case("3A7FB2C45E8D9F1A2B3C4D5E6F7A8B9CADBECFD0E1F2031425364758697A8B9C")] + #[case("3a7Fb2C45E8d9f1a2b3c4D5E6f7a8b9CAdbecfd0E1f2031425364758697A8b9c")] + #[case("3A 7F B2 C4 5E 8D 9F 1A 2B 3C 4D 5E 6F 7A 8B 9C AD BE CF D0 E1 F2 03 14 25 36 47 58 69 7A 8B 9C")] + #[case("3A:7F:B2:C4:5E:8D:9F:1A:2B:3C:4D:5E:6F:7A:8B:9C:AD:BE:CF:D0:E1:F2:03:14:25:36:47:58:69:7A:8B:9C")] + #[case("3a:7F-B2.C4_5E:8d:9f_1a-2b:3c-4d.5e:6f:7a:8b:9c.ad:be:cf:d0.e1:f2:03:14:25-36-47-58_69-7A:8B:9C")] + fn test_normalize_thumbprint(#[case] input: &str) { + assert_eq!( + normalize_sha256_thumbprint(input).unwrap().as_str(), + "3a7fb2c45e8d9f1a2b3c4d5e6f7a8b9cadbecfd0e1f2031425364758697a8b9c" + ); + } + + #[test] + fn test_compute_sha256_thumbprint() { + // Test with known input. + let test_data = b"Hello, World!"; + let thumbprint = compute_sha256_thumbprint(test_data); + + // Expected SHA-256 of "Hello, World!". + let expected = "dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f"; + assert_eq!(thumbprint.as_str(), expected); + + // Test output format (lowercase hex, no separators). + assert!( + thumbprint + .as_str() + .chars() + .all(|c| c.is_ascii_hexdigit() && !c.is_uppercase()) + ); + assert_eq!(thumbprint.as_str().len(), 64); // SHA-256 is 32 bytes = 64 hex chars. + } + + #[test] + fn test_compute_sha256_thumbprint_deterministic() { + // Same input should always produce same thumbprint. + let test_data = b"test certificate data"; + let thumbprint1 = compute_sha256_thumbprint(test_data); + let thumbprint2 = compute_sha256_thumbprint(test_data); + assert_eq!(thumbprint1, thumbprint2); + } + } +} diff --git a/devolutions-gateway/src/token.rs b/devolutions-gateway/src/token.rs index fb234bec3..534667da9 100644 --- a/devolutions-gateway/src/token.rs +++ b/devolutions-gateway/src/token.rs @@ -18,6 +18,7 @@ use uuid::Uuid; use crate::recording::ActiveRecordings; use crate::session::DisconnectedInfo; use crate::target_addr::TargetAddr; +use crate::tls::thumbprint::Sha256Thumbprint; pub const MAX_SUBKEY_TOKEN_VALIDITY_DURATION_SECS: i64 = 60 * 60 * 2; // 2 hours @@ -420,6 +421,9 @@ pub struct AssociationTokenClaims { /// Unique ID for this token pub jti: Uuid, + + /// Optional SHA-256 thumbprint of target server certificate (for anchored TLS validation) + pub cert_thumb256: Option, } // ----- scope claims ----- // @@ -1304,6 +1308,8 @@ mod serde_impl { jet_reuse: ReconnectionPolicy, exp: i64, jti: Uuid, + #[serde(default)] + cert_thumb256: Option, } #[derive(Deserialize)] @@ -1411,6 +1417,7 @@ mod serde_impl { jet_reuse: self.jet_reuse, exp: self.exp, jti: self.jti, + cert_thumb256: self.cert_thumb256.as_ref().map(|thumb| SmolStr::new(thumb.as_str())), } .serialize(serializer) } @@ -1454,6 +1461,12 @@ mod serde_impl { jet_reuse: claims.jet_reuse, exp: claims.exp, jti: claims.jti, + cert_thumb256: claims + .cert_thumb256 + .as_deref() + .map(crate::tls::thumbprint::normalize_sha256_thumbprint) + .transpose() + .map_err(de::Error::custom)?, }) } } diff --git a/testsuite/Cargo.toml b/testsuite/Cargo.toml index b87bd3d2d..012c1c138 100644 --- a/testsuite/Cargo.toml +++ b/testsuite/Cargo.toml @@ -26,6 +26,8 @@ serde_json = "1" serde = { version = "1", features = ["derive"] } tempfile = "3" tokio = { version = "1", features = ["rt-multi-thread", "macros", "time", "net", "process"] } +typed-builder = "0.21" +tokio-tungstenite = { version = "0.26", features = ["rustls-tls-native-roots"] } [dev-dependencies] mcp-proxy.path = "../crates/mcp-proxy" diff --git a/testsuite/src/cli.rs b/testsuite/src/cli.rs index 36dc86c2a..526107648 100644 --- a/testsuite/src/cli.rs +++ b/testsuite/src/cli.rs @@ -32,6 +32,36 @@ pub fn jetsocat_tokio_cmd() -> tokio::process::Command { cmd } +static DGW_BIN_PATH: LazyLock = LazyLock::new(|| { + escargot::CargoBuild::new() + .manifest_path("../devolutions-gateway/Cargo.toml") + .bin("devolutions-gateway") + .current_release() + .current_target() + .run() + .expect("build Devolutions Gateway") + .path() + .to_path_buf() +}); + +pub fn dgw_assert_cmd() -> assert_cmd::Command { + let mut cmd = assert_cmd::Command::new(&*DGW_BIN_PATH); + cmd.env("RUST_BACKTRACE", "0"); + cmd +} + +pub fn dgw_cmd() -> std::process::Command { + let mut cmd = std::process::Command::new(&*DGW_BIN_PATH); + cmd.env("RUST_BACKTRACE", "0"); + cmd +} + +pub fn dgw_tokio_cmd() -> tokio::process::Command { + let mut cmd = tokio::process::Command::new(&*DGW_BIN_PATH); + cmd.env("RUST_BACKTRACE", "0"); + cmd +} + pub fn assert_stderr_eq(output: &assert_cmd::assert::Assert, expected: expect_test::Expect) { let stderr = std::str::from_utf8(&output.get_output().stderr).unwrap(); expected.assert_eq(stderr); diff --git a/testsuite/src/dgw_config.rs b/testsuite/src/dgw_config.rs new file mode 100644 index 000000000..840125dce --- /dev/null +++ b/testsuite/src/dgw_config.rs @@ -0,0 +1,113 @@ +use core::fmt; +use std::path::Path; + +use anyhow::Context as _; +use tempfile::TempDir; +use typed_builder::TypedBuilder; + +pub struct VerbosityProfile(&'static str); + +impl VerbosityProfile { + pub const DEFAULT: Self = Self("Default"); + pub const DEBUG: Self = Self("Debug"); + pub const TLS: Self = Self("Tls"); + pub const ALL: Self = Self("All"); + pub const QUIET: Self = Self("Quiet"); +} + +impl fmt::Display for VerbosityProfile { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +#[derive(TypedBuilder)] +pub struct DgwConfig { + #[builder(default, setter(into))] + tcp_port: Option, + #[builder(default, setter(into))] + http_port: Option, + #[builder(default = false)] + disable_token_validation: bool, + #[builder(default = VerbosityProfile::DEFAULT)] + verbosity_profile: VerbosityProfile, +} + +fn find_unused_port() -> u16 { + std::net::TcpListener::bind("127.0.0.1:0") + .unwrap() + .local_addr() + .unwrap() + .port() +} + +impl DgwConfig { + pub fn init(self) -> anyhow::Result { + DgwConfigHandle::init(self) + } +} + +pub struct DgwConfigHandle { + tempdir: TempDir, + tcp_port: u16, + http_port: u16, +} + +impl DgwConfigHandle { + pub fn init(config: DgwConfig) -> anyhow::Result { + let DgwConfig { + tcp_port, + http_port, + disable_token_validation, + verbosity_profile, + } = config; + + let tempdir = tempfile::tempdir().context("create tempdir")?; + let config_path = tempdir.path().join("gateway.json"); + + let tcp_port = tcp_port.unwrap_or_else(find_unused_port); + let http_port = http_port.unwrap_or_else(find_unused_port); + + let config = format!( + "{{ + \"ProvisionerPublicKeyData\": {{ + \"Value\": \"mMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA4vuqLOkl1pWobt6su1XO9VskgCAwevEGs6kkNjJQBwkGnPKYLmNF1E/af1yCocfVn/OnPf9e4x+lXVyZ6LMDJxFxu+axdgOq3Ld392J1iAEbfvwlyRFnEXFOJNyylqg3bY6LvnWHL/XZczVdMD9xYfq2sO9bg3xjRW4s7r9EEYOFjqVT3VFznH9iWJVtcSEKukmS/3uKoO6lGhacvu0HgjXXdgq0R8zvR4XRJ9Fcnf0f9Ypoc+i6L80NVjrRCeVOH+Ld/2fA9bocpfLarcVqG3RjS+qgOtpyCc0jWVFF4zaGQ7LUDFkEIYILkICeMMn2ll29hmZNzsJzZJ9s6NocgQIDAQAB\" + }}, + \"Listeners\": [ + {{ + \"InternalUrl\": \"tcp://127.0.0.1:{tcp_port}\", + \"ExternalUrl\": \"tcp://127.0.0.1:{tcp_port}\" + }}, + {{ + \"InternalUrl\": \"http://127.0.0.1:{http_port}\", + \"ExternalUrl\": \"http://127.0.0.1:{http_port}\" + }} + ], + \"VerbosityProfile\": \"{verbosity_profile}\", + \"__debug__\": {{ + \"disable_token_validation\": {disable_token_validation} + }} +}}" + ); + + std::fs::write(&config_path, config).with_context(|| format!("write config into {}", config_path.display()))?; + + Ok(Self { + tempdir, + tcp_port, + http_port, + }) + } + + pub fn config_dir(&self) -> &Path { + self.tempdir.path() + } + + pub fn tcp_port(&self) -> u16 { + self.tcp_port + } + + pub fn http_port(&self) -> u16 { + self.http_port + } +} diff --git a/testsuite/src/lib.rs b/testsuite/src/lib.rs index a22da9370..56ae206ba 100644 --- a/testsuite/src/lib.rs +++ b/testsuite/src/lib.rs @@ -5,5 +5,6 @@ #![allow(clippy::unwrap_used, reason = "test infrastructure can panic on errors")] pub mod cli; +pub mod dgw_config; pub mod mcp_client; pub mod mcp_server; diff --git a/testsuite/tests/cli/dgw/mod.rs b/testsuite/tests/cli/dgw/mod.rs new file mode 100644 index 000000000..78a7b4096 --- /dev/null +++ b/testsuite/tests/cli/dgw/mod.rs @@ -0,0 +1 @@ +mod tls_anchoring; diff --git a/testsuite/tests/cli/dgw/tls_anchoring.rs b/testsuite/tests/cli/dgw/tls_anchoring.rs new file mode 100644 index 000000000..1c53ce4b1 --- /dev/null +++ b/testsuite/tests/cli/dgw/tls_anchoring.rs @@ -0,0 +1,117 @@ +use anyhow::Context as _; +use rstest::rstest; +use testsuite::cli::dgw_tokio_cmd; +use testsuite::dgw_config::{DgwConfig, DgwConfigHandle}; +use tokio::process::Child; + +async fn start_gateway() -> anyhow::Result<(DgwConfigHandle, Child)> { + let config_handle = DgwConfig::builder() + .disable_token_validation(true) + .build() + .init() + .context("init config")?; + + // Start a Devolutions Gateway instance. + let process = dgw_tokio_cmd() + .env("DGATEWAY_CONFIG_PATH", config_handle.config_dir()) + .kill_on_drop(true) + .stdout(std::process::Stdio::piped()) + .spawn() + .context("failed to start Devolutions Gateway")?; + + // Give the server a moment to start. + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + Ok((config_handle, process)) +} + +/// Perform a WebSocket connection on the /jet/fwd/tcp endpoint. +async fn websocket_connect(port: u16, token: &str, session_id: &str) -> anyhow::Result<()> { + let url = format!("ws://127.0.0.1:{port}/jet/fwd/tls/{session_id}?token={token}"); + + // Try to connect with a timeout + let (_ws_stream, response) = + tokio::time::timeout(std::time::Duration::from_secs(5), tokio_tungstenite::connect_async(url)) + .await + .context("timeout")? + .context("websocket connection")?; + + println!("WebSocket connected successfully: {response:?}"); + + // Give the server a moment to perform the connection with the remote server. + tokio::time::sleep(std::time::Duration::from_secs(5)).await; + + Ok(()) +} + +#[derive(Debug, PartialEq)] +enum TlsOutcome { + Failed, + Succeeded, +} + +async fn read_until_tls_done(mut logs: impl tokio::io::AsyncRead + Unpin) -> anyhow::Result { + use tokio::io::AsyncReadExt as _; + + let mut buf = Vec::new(); + + loop { + let n = logs.read_buf(&mut buf).await.context("read_buf")?; + + if n == 0 { + anyhow::bail!("eof"); + } + + let logs = String::from_utf8_lossy(&buf); + + if logs.contains("PASTE_THIS_THUMBPRINT_IN_RDM_CONNECTION") { + return Ok(TlsOutcome::Failed); + } else if logs.contains("WebSocket-TLS forwarding") { + return Ok(TlsOutcome::Succeeded); + } + } +} + +#[rstest] +#[case::self_signed_correct_thumb(token::SELF_SIGNED_WITH_CORRECT_THUMB, TlsOutcome::Succeeded)] +#[case::self_signed_wrong_thumb(token::SELF_SIGNED_WITH_WRONG_THUMB, TlsOutcome::Failed)] +#[case::self_signed_no_thumb(token::SELF_SIGNED_NO_THUMB, TlsOutcome::Failed)] +#[case::valid_cert_no_thumb(token::VALID_CERT_NO_THUMB, TlsOutcome::Succeeded)] +#[tokio::test] +async fn test(#[case] token: &str, #[case] expected_outcome: TlsOutcome) -> anyhow::Result<()> { + let (config_handle, mut process) = start_gateway().await?; + + let stdout = process.stdout.take().unwrap(); + + let connect_fut = websocket_connect(config_handle.http_port(), token, token::SESSION_ID); + let read_fut = read_until_tls_done(stdout); + + tokio::select! { + res = connect_fut => { + res.context("websocket connect")?; + anyhow::bail!("expected read future to terminate before connect future"); + } + res = read_fut => { + let outcome = res.context("read")?; + assert_eq!(outcome, expected_outcome); + } + } + + Ok(()) +} + +mod token { + pub(super) const SESSION_ID: &str = "897fd399-540c-4be3-84a1-47c73f68c7a4"; + + /// Token with correct thumbprint for self-signed.badssl.com + pub(super) const SELF_SIGNED_WITH_CORRECT_THUMB: &str = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImN0eSI6IkFTU09DSUFUSU9OIn0.eyJjZXJ0X3RodW1iMjU2IjoiMzkxYTIyOGUyZjQ4NjA2NDQwNTkyNjU1ODEzNTAxNThmNTUyMTNkODc0YzVmYmY1NzFjZThiZTYyYmZlY2Y1NCIsImRzdF9oc3QiOiJzZWxmLXNpZ25lZC5iYWRzc2wuY29tOjQ0MyIsImV4cCI6MTc2MjkzNzI5OCwiamV0X2FpZCI6Ijg5N2ZkMzk5LTU0MGMtNGJlMy04NGExLTQ3YzczZjY4YzdhNCIsImpldF9hcCI6InVua25vd24iLCJqZXRfY20iOiJmd2QiLCJqZXRfcmVjIjoibm9uZSIsImp0aSI6IjgwYTcxN2JmLTZlMzItNGEyMi05Yjk3LTVlYzFkNzk1YjVlMSIsIm5iZiI6MTc2MjkzNjM5OH0.ZHVtbXlfc2lnbmF0dXJl"; + + /// Token with wrong thumbprint for self-signed.badssl.com + pub(super) const SELF_SIGNED_WITH_WRONG_THUMB: &str = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImN0eSI6IkFTU09DSUFUSU9OIn0.eyJjZXJ0X3RodW1iMjU2IjoiYTkxYTIyODIyZjQ4NjA2NDQwNTkyNjU1ODExMTExNThmNTUyMTNkODc0YzVmYmY1NzFjZThiZTYzYmZlY2Y1NCIsImRzdF9oc3QiOiJzZWxmLXNpZ25lZC5iYWRzc2wuY29tOjQ0MyIsImV4cCI6MTc2MjkzODI5MywiamV0X2FpZCI6Ijg5N2ZkMzk5LTU0MGMtNGJlMy04NGExLTQ3YzczZjY4YzdhNCIsImpldF9hcCI6InVua25vd24iLCJqZXRfY20iOiJmd2QiLCJqZXRfcmVjIjoibm9uZSIsImp0aSI6IjRlMjZhNjM2LTA0MjUtNDNlMy1iMGZmLWYzZDk1ODhjZWY4YSIsIm5iZiI6MTc2MjkzNzM5M30.ZHVtbXlfc2lnbmF0dXJl"; + + /// Token without thumbprint for self-signed.badssl.com + pub(super) const SELF_SIGNED_NO_THUMB: &str = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImN0eSI6IkFTU09DSUFUSU9OIn0.eyJkc3RfaHN0Ijoic2VsZi1zaWduZWQuYmFkc3NsLmNvbTo0NDMiLCJleHAiOjE3NjI5Mzc0ODAsImpldF9haWQiOiI4OTdmZDM5OS01NDBjLTRiZTMtODRhMS00N2M3M2Y2OGM3YTQiLCJqZXRfYXAiOiJ1bmtub3duIiwiamV0X2NtIjoiZndkIiwiamV0X3JlYyI6Im5vbmUiLCJqdGkiOiI0ODdjZThiNS1lY2ZmLTRlY2QtYWE3ZC0wNTJkNThlM2U2YjEiLCJuYmYiOjE3NjI5MzY1ODB9.ZHVtbXlfc2lnbmF0dXJl"; + + /// Token without thumbprint for badssl.com (valid cert) + pub(super) const VALID_CERT_NO_THUMB: &str = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImN0eSI6IkFTU09DSUFUSU9OIn0.eyJkc3RfaHN0IjoiYmFkc3NsLmNvbTo0NDMiLCJleHAiOjE3NjI5Mzc1MjEsImpldF9haWQiOiI4OTdmZDM5OS01NDBjLTRiZTMtODRhMS00N2M3M2Y2OGM3YTQiLCJqZXRfYXAiOiJ1bmtub3duIiwiamV0X2NtIjoiZndkIiwiamV0X3JlYyI6Im5vbmUiLCJqdGkiOiI4YWUzMzkxNS00ZDNlLTQyYmItODBkNi0yYjQzYjIyN2QzYTQiLCJuYmYiOjE3NjI5MzY2MjF9.ZHVtbXlfc2lnbmF0dXJl"; +} diff --git a/testsuite/tests/cli/mod.rs b/testsuite/tests/cli/mod.rs index c3448dcd3..c895a483c 100644 --- a/testsuite/tests/cli/mod.rs +++ b/testsuite/tests/cli/mod.rs @@ -1 +1,2 @@ +mod dgw; mod jetsocat; diff --git a/tools/tokengen/src/lib.rs b/tools/tokengen/src/lib.rs index efc0ffbbb..3fdb532ee 100644 --- a/tools/tokengen/src/lib.rs +++ b/tools/tokengen/src/lib.rs @@ -29,7 +29,10 @@ pub struct AssociationClaims<'a> { pub jet_gw_id: Option, #[serde(skip_serializing_if = "Option::is_none")] pub jet_reuse: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub dst_hst: Option<&'a str>, + #[serde(skip_serializing_if = "Option::is_none")] + pub cert_thumb256: Option<&'a str>, #[serde(flatten)] pub creds: Option>, } @@ -215,6 +218,7 @@ pub enum SubCommandArgs { jet_aid: Option, jet_rec: bool, jet_reuse: Option, + cert_thumb256: Option, }, Rendezvous { jet_ap: Option, @@ -288,6 +292,7 @@ pub fn generate_token( jet_aid, jet_rec, jet_reuse, + cert_thumb256, } => { let claims = AssociationClaims { exp, @@ -305,6 +310,7 @@ pub fn generate_token( jet_ttl, jet_gw_id, jet_reuse, + cert_thumb256: cert_thumb256.as_deref(), creds: None, }; ("ASSOCIATION", serde_json::to_value(claims)?) @@ -329,6 +335,7 @@ pub fn generate_token( jet_ttl: None, jet_gw_id, jet_reuse: None, + cert_thumb256: None, creds: Some(CredsClaims { prx_usr: &prx_usr, prx_pwd: &prx_pwd, @@ -359,6 +366,7 @@ pub fn generate_token( jet_ttl: None, jet_gw_id, jet_reuse: None, + cert_thumb256: None, creds: None, }; ("ASSOCIATION", serde_json::to_value(claims)?) diff --git a/tools/tokengen/src/main.rs b/tools/tokengen/src/main.rs index cc15bb95b..c080b8ba1 100644 --- a/tools/tokengen/src/main.rs +++ b/tools/tokengen/src/main.rs @@ -50,6 +50,7 @@ fn sign( jet_aid, jet_rec, jet_reuse, + cert_thumb256, } => SubCommandArgs::Forward { dst_hst, jet_ap, @@ -57,6 +58,7 @@ fn sign( jet_aid, jet_rec, jet_reuse, + cert_thumb256, }, SignSubCommand::Rendezvous { jet_ap, @@ -189,6 +191,8 @@ enum SignSubCommand { jet_rec: bool, #[clap(long)] jet_reuse: Option, + #[clap(long)] + cert_thumb256: Option, }, Rendezvous { #[clap(long)] diff --git a/tools/tokengen/src/server/server_impl.rs b/tools/tokengen/src/server/server_impl.rs index 04742e3c3..400ba8bf9 100644 --- a/tools/tokengen/src/server/server_impl.rs +++ b/tools/tokengen/src/server/server_impl.rs @@ -102,6 +102,7 @@ pub(crate) async fn forward_handler( jet_aid: request.jet_aid, jet_rec: request.jet_rec, jet_reuse: request.jet_reuse, + cert_thumb256: request.cert_thumb256, }, ) .await @@ -282,6 +283,7 @@ pub(crate) struct ForwardRequest { jet_aid: Option, jet_rec: bool, jet_reuse: Option, + cert_thumb256: Option, } #[derive(Deserialize)]