diff --git a/api/src/auth.rs b/api/src/auth.rs new file mode 100644 index 00000000..d3445e79 --- /dev/null +++ b/api/src/auth.rs @@ -0,0 +1,119 @@ +/// Authentication credential for Trunk API requests. +/// +/// Two flavors: +/// * `Token` — an org API token. Sent on every endpoint via the `x-api-token` +/// header. Used by regular CI on the upstream repo where repo secrets are +/// available. +/// * `PublicRepoId` — a non-secret per-repo identifier (the 8-character value +/// from the Trunk settings UI). Sent via the `X-Trunk-Public-Repo-Id` header +/// on the two endpoints that accept it (`createBundleUpload` and +/// `getQuarantineConfig`). Used on fork-PR runs where secrets are +/// unavailable. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum TrunkApiCredential { + Token(String), + PublicRepoId(String), +} + +pub const PUBLIC_REPO_ID_HEADER: &str = "x-trunk-public-repo-id"; +pub const API_TOKEN_HEADER: &str = "x-api-token"; + +impl TrunkApiCredential { + /// Resolve a credential from a token and a public-repo-id, with token-first + /// ordering. Empty / whitespace-only strings count as absent. + /// + /// Token-first matters: it preserves existing behaviour for non-fork CI + /// runs that have an org token configured, and only falls back to the + /// public-id when the token is genuinely absent. + pub fn resolve(token: Option<&str>, public_repo_id: Option<&str>) -> Option { + let cleaned_token = token.map(str::trim).filter(|s| !s.is_empty()); + if let Some(token) = cleaned_token { + return Some(Self::Token(token.to_string())); + } + let cleaned_public_id = public_repo_id.map(str::trim).filter(|s| !s.is_empty()); + if let Some(public_id) = cleaned_public_id { + return Some(Self::PublicRepoId(public_id.to_string())); + } + None + } + + pub fn header_name(&self) -> &'static str { + match self { + Self::Token(_) => API_TOKEN_HEADER, + Self::PublicRepoId(_) => PUBLIC_REPO_ID_HEADER, + } + } + + pub fn header_value(&self) -> &str { + match self { + Self::Token(value) | Self::PublicRepoId(value) => value, + } + } + + pub fn is_token(&self) -> bool { + matches!(self, Self::Token(_)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn token_only_returns_token() { + let cred = TrunkApiCredential::resolve(Some("abc"), None).unwrap(); + assert_eq!(cred, TrunkApiCredential::Token("abc".to_string())); + } + + #[test] + fn public_id_only_returns_public_id() { + let cred = TrunkApiCredential::resolve(None, Some("abcd1234")).unwrap(); + assert_eq!( + cred, + TrunkApiCredential::PublicRepoId("abcd1234".to_string()) + ); + } + + #[test] + fn token_wins_when_both_present() { + let cred = TrunkApiCredential::resolve(Some("the-token"), Some("abcd1234")).unwrap(); + assert_eq!(cred, TrunkApiCredential::Token("the-token".to_string())); + } + + #[test] + fn empty_token_falls_through_to_public_id() { + let cred = TrunkApiCredential::resolve(Some(""), Some("abcd1234")).unwrap(); + assert_eq!( + cred, + TrunkApiCredential::PublicRepoId("abcd1234".to_string()) + ); + } + + #[test] + fn whitespace_token_falls_through_to_public_id() { + let cred = TrunkApiCredential::resolve(Some(" "), Some("abcd1234")).unwrap(); + assert_eq!( + cred, + TrunkApiCredential::PublicRepoId("abcd1234".to_string()) + ); + } + + #[test] + fn neither_returns_none() { + assert!(TrunkApiCredential::resolve(None, None).is_none()); + assert!(TrunkApiCredential::resolve(Some(""), Some("")).is_none()); + assert!(TrunkApiCredential::resolve(Some(" "), Some(" ")).is_none()); + } + + #[test] + fn header_name_matches_credential_kind() { + assert_eq!( + TrunkApiCredential::Token("x".into()).header_name(), + "x-api-token" + ); + assert_eq!( + TrunkApiCredential::PublicRepoId("x".into()).header_name(), + "x-trunk-public-repo-id" + ); + } +} diff --git a/api/src/client.rs b/api/src/client.rs index a09b7692..d69863d8 100644 --- a/api/src/client.rs +++ b/api/src/client.rs @@ -9,6 +9,7 @@ use reqwest::{Client, Response, StatusCode, header}; use serde::de::DeserializeOwned; use tokio::fs; +use crate::auth::TrunkApiCredential; use crate::call_api::CallApi; use crate::message; @@ -49,7 +50,10 @@ pub struct ApiClient { pub telemetry_host: String, s3_client: Client, trunk_api_client: Client, - telemetry_client: Client, + /// `None` when the CLI is running with only a public-repo-id (fork PR + /// path) — the telemetry endpoint does not accept that header, so we + /// short-circuit telemetry sends rather than fire 401s. + telemetry_client: Option, version_path_prefix: String, org_url_slug: String, render_sender: Option>, @@ -66,20 +70,27 @@ impl ApiClient { const TRUNK_API_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30); // This should always be fast const TRUNK_TELEMETRY_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(1); - const TRUNK_API_TOKEN_HEADER: &'static str = "x-api-token"; - pub fn new>( - api_token: T, - org_url_slug: T, + pub fn new>( + auth: TrunkApiCredential, + org_url_slug: S, render_sender: Option>, ) -> anyhow::Result { let org_url_slug = String::from(org_url_slug.as_ref()); - let api_token = api_token.as_ref(); - if api_token.trim().is_empty() { - return Err(anyhow::anyhow!("Trunk API token is required.")); + + let auth_header_value = HeaderValue::from_str(auth.header_value()).map_err(|_| { + anyhow::Error::msg(if auth.is_token() { + "Trunk API token is not ASCII" + } else { + "Trunk public repo id is not ASCII" + }) + })?; + + if !auth.is_token() { + tracing::info!( + "Using X-Trunk-Public-Repo-Id auth (TRUNK_API_TOKEN not set; assuming fork PR)" + ); } - let api_token_header_value = HeaderValue::from_str(api_token) - .map_err(|_| anyhow::Error::msg("Trunk API token is not ASCII"))?; let api_host = get_api_host(); tracing::debug!("Using public api address {}", api_host); @@ -101,26 +112,33 @@ impl ApiClient { header::CONTENT_TYPE, HeaderValue::from_static("application/json"), ); - trunk_api_client_default_headers - .append(Self::TRUNK_API_TOKEN_HEADER, api_token_header_value.clone()); + trunk_api_client_default_headers.append(auth.header_name(), auth_header_value.clone()); let trunk_api_client = Client::builder() .timeout(Self::TRUNK_API_TIMEOUT) .default_headers(trunk_api_client_default_headers) .build()?; - let mut telemetry_client_default_headers = HeaderMap::new(); - telemetry_client_default_headers.append( - header::CONTENT_TYPE, - HeaderValue::from_static("application/x-protobuf"), - ); - telemetry_client_default_headers - .append(Self::TRUNK_API_TOKEN_HEADER, api_token_header_value); + // The telemetry endpoint does not accept the public-repo-id auth path, + // so when we only have a public-repo-id we skip building the telemetry + // client and short-circuit telemetry sends in `telemetry_upload_metrics`. + let telemetry_client = if auth.is_token() { + let mut telemetry_client_default_headers = HeaderMap::new(); + telemetry_client_default_headers.append( + header::CONTENT_TYPE, + HeaderValue::from_static("application/x-protobuf"), + ); + telemetry_client_default_headers.append(auth.header_name(), auth_header_value); - let telemetry_client = Client::builder() - .timeout(Self::TRUNK_TELEMETRY_TIMEOUT) - .default_headers(telemetry_client_default_headers) - .build()?; + Some( + Client::builder() + .timeout(Self::TRUNK_TELEMETRY_TIMEOUT) + .default_headers(telemetry_client_default_headers) + .build()?, + ) + } else { + None + }; let mut s3_client_default_headers = HeaderMap::new(); s3_client_default_headers.append( @@ -273,13 +291,16 @@ impl ApiClient { &self, request: &message::TelemetryUploadMetricsRequest, ) -> anyhow::Result<()> { + let Some(telemetry_client) = self.telemetry_client.clone() else { + tracing::debug!("Skipping telemetry upload: no API token configured"); + return Ok(()); + }; CallApi { action: || async { if std::env::var("DISABLE_TELEMETRY").is_ok() { return Ok(()); } - let response = self - .telemetry_client + let response = telemetry_client .post(format!( "{}{}/flakytests-cli/upload-metrics", self.telemetry_host, self.version_path_prefix @@ -461,8 +482,13 @@ mod tests { use test_utils::mock_server::MockServerBuilder; use super::ApiClient; + use crate::auth::TrunkApiCredential; use crate::message; + fn mock_token_auth() -> TrunkApiCredential { + TrunkApiCredential::Token(String::from("mock-token")) + } + #[tokio::test(start_paused = true)] async fn does_not_retry_on_ok_501() { let mut mock_server_builder = MockServerBuilder::new(); @@ -486,7 +512,7 @@ mod tests { let state = mock_server_builder.spawn_mock_server().await; let mut api_client = - ApiClient::new(String::from("mock-token"), String::from("mock-org"), None).unwrap(); + ApiClient::new(mock_token_auth(), String::from("mock-org"), None).unwrap(); api_client.api_host.clone_from(&state.host); assert!( @@ -534,7 +560,7 @@ mod tests { let state = mock_server_builder.spawn_mock_server().await; let mut api_client = - ApiClient::new(String::from("mock-token"), String::from("mock-org"), None).unwrap(); + ApiClient::new(mock_token_auth(), String::from("mock-org"), None).unwrap(); api_client.api_host.clone_from(&state.host); assert!( @@ -577,7 +603,7 @@ mod tests { let state = mock_server_builder.spawn_mock_server().await; let mut api_client = - ApiClient::new(String::from("mock-token"), String::from("mock-org"), None).unwrap(); + ApiClient::new(mock_token_auth(), String::from("mock-org"), None).unwrap(); api_client.api_host.clone_from(&state.host); assert!( diff --git a/api/src/lib.rs b/api/src/lib.rs index ee4fdaf2..b59393e2 100644 --- a/api/src/lib.rs +++ b/api/src/lib.rs @@ -1,3 +1,4 @@ +pub mod auth; mod call_api; pub mod client; pub mod message; diff --git a/cli/src/upload_command.rs b/cli/src/upload_command.rs index 41b2805b..57ae8376 100644 --- a/cli/src/upload_command.rs +++ b/cli/src/upload_command.rs @@ -3,6 +3,7 @@ use std::env; use std::path::PathBuf; use std::sync::mpsc::Sender; +use api::auth::TrunkApiCredential; use api::client::{ApiClient, ApiErrorEndpoint}; use api::{client::get_api_host, urls::url_for_test_case}; use bundle::{BundleMeta, BundlerUtil, Test, unzip_tarball}; @@ -81,11 +82,17 @@ pub struct UploadArgs { pub test_collection_short_id: Option, #[arg( long, - required = true, env = constants::TRUNK_API_TOKEN_ENV, - help = "Organization token. Defaults to TRUNK_API_TOKEN env var." + help = "Organization token. Defaults to TRUNK_API_TOKEN env var. Required unless --public-repo-id is set (e.g. for fork PRs).", + default_value = "" )] pub token: String, + #[arg( + long, + env = constants::TRUNK_PUBLIC_REPO_ID_ENV, + help = "Non-secret per-repo identifier from the Trunk settings page. Used in place of --token on workflows triggered from forked pull requests where the org API token is unavailable." + )] + pub public_repo_id: Option, #[arg( long, env = constants::TRUNK_REPO_ROOT_ENV, @@ -396,7 +403,16 @@ pub async fn run_upload( ); } - let api_client = ApiClient::new(&upload_args.token, &upload_args.org_url_slug, render_sender)?; + let auth = TrunkApiCredential::resolve( + Some(upload_args.token.as_str()), + upload_args.public_repo_id.as_deref(), + ) + .ok_or_else(|| { + anyhow::anyhow!( + "Authentication required: set --token / TRUNK_API_TOKEN, or --public-repo-id / TRUNK_PUBLIC_REPO_ID for fork-PR runs." + ) + })?; + let api_client = ApiClient::new(auth, &upload_args.org_url_slug, render_sender)?; let PreTestContext { mut meta, diff --git a/constants/src/lib.rs b/constants/src/lib.rs index 5d9f7661..ddd68140 100644 --- a/constants/src/lib.rs +++ b/constants/src/lib.rs @@ -26,6 +26,7 @@ pub const TRUNK_API_CLIENT_RETRY_COUNT_ENV: &str = "TRUNK_API_CLIENT_RETRY_COUNT // Trunk CLI environment variable names for configuration overrides pub const TRUNK_API_TOKEN_ENV: &str = "TRUNK_API_TOKEN"; +pub const TRUNK_PUBLIC_REPO_ID_ENV: &str = "TRUNK_PUBLIC_REPO_ID"; pub const TRUNK_ORG_URL_SLUG_ENV: &str = "TRUNK_ORG_URL_SLUG"; pub const TRUNK_TEST_COLLECTION_SHORT_ID_ENV: &str = "TRUNK_TEST_COLLECTION_SHORT_ID"; pub const TRUNK_REPO_ROOT_ENV: &str = "TRUNK_REPO_ROOT"; diff --git a/test_report/src/report.rs b/test_report/src/report.rs index a3254d4b..8c3797b1 100644 --- a/test_report/src/report.rs +++ b/test_report/src/report.rs @@ -6,7 +6,7 @@ use std::{ time::{Duration, SystemTime}, }; -use api::{client::ApiClient, message}; +use api::{auth::TrunkApiCredential, client::ApiClient, message}; use bundle::BundleMetaDebugProps; use bundle::Test; use chrono::prelude::*; @@ -304,16 +304,21 @@ impl MutTestReport { file: Option, ) -> IsQuarantinedResult { let token = self.get_token(); + let public_repo_id = self.get_public_repo_id(); let org_url_slug = self.get_org_url_slug(); - if token.is_empty() { - tracing::warn!("Not checking quarantine status because TRUNK_API_TOKEN is empty"); + let Some(auth) = + TrunkApiCredential::resolve(Some(token.as_str()), Some(public_repo_id.as_str())) + else { + tracing::warn!( + "Not checking quarantine status because neither TRUNK_API_TOKEN nor TRUNK_PUBLIC_REPO_ID is set" + ); return IsQuarantinedResult::default(); - } + }; if org_url_slug.is_empty() { tracing::warn!("Not checking quarantine status because TRUNK_ORG_URL_SLUG is empty"); return IsQuarantinedResult::default(); } - let api_client = ApiClient::new(token, org_url_slug.clone(), None); + let api_client = ApiClient::new(auth, org_url_slug.clone(), None); let use_uncloned_repo = env::var(constants::TRUNK_USE_UNCLONED_REPO_ENV) .ok() .and_then(|v| v.parse().ok()) @@ -526,6 +531,10 @@ impl MutTestReport { env::var(constants::TRUNK_API_TOKEN_ENV).unwrap_or_default() } + fn get_public_repo_id(&self) -> String { + env::var(constants::TRUNK_PUBLIC_REPO_ID_ENV).unwrap_or_default() + } + // sends out to the trunk api pub fn publish(&self) -> bool { let release_name = format!("rspec-flaky-tests@{}", env!("CARGO_PKG_VERSION")); @@ -538,9 +547,14 @@ impl MutTestReport { } let token = self.get_token(); + let public_repo_id = self.get_public_repo_id(); let org_url_slug = self.get_org_url_slug(); - if token.is_empty() { - tracing::warn!("Not publishing results because TRUNK_API_TOKEN is empty"); + if TrunkApiCredential::resolve(Some(token.as_str()), Some(public_repo_id.as_str())) + .is_none() + { + tracing::warn!( + "Not publishing results because neither TRUNK_API_TOKEN nor TRUNK_PUBLIC_REPO_ID is set" + ); return false; } if org_url_slug.is_empty() { @@ -603,6 +617,9 @@ impl MutTestReport { env::var(constants::TRUNK_REPO_ROOT_ENV).ok(), true, ); + if !public_repo_id.is_empty() { + upload_args.public_repo_id = Some(public_repo_id); + } // Read additional environment variables using constants upload_args.repo_url = env::var(constants::TRUNK_REPO_URL_ENV).ok(); diff --git a/test_report/tests/report.rs b/test_report/tests/report.rs index 1622dd56..22d07036 100644 --- a/test_report/tests/report.rs +++ b/test_report/tests/report.rs @@ -6,10 +6,10 @@ use bundle::{BundleMeta, FileSetType, Test}; use constants::{ TRUNK_ALLOW_EMPTY_TEST_RESULTS_ENV, TRUNK_API_TOKEN_ENV, TRUNK_CODEOWNERS_PATH_ENV, TRUNK_DISABLE_QUARANTINING_ENV, TRUNK_DRY_RUN_ENV, TRUNK_ORG_URL_SLUG_ENV, TRUNK_PR_NUMBER_ENV, - TRUNK_PUBLIC_API_ADDRESS_ENV, TRUNK_QUARANTINED_TESTS_DISK_CACHE_TTL_SECS_ENV, - TRUNK_REPO_HEAD_AUTHOR_NAME_ENV, TRUNK_REPO_HEAD_BRANCH_ENV, TRUNK_REPO_HEAD_COMMIT_EPOCH_ENV, - TRUNK_REPO_HEAD_SHA_ENV, TRUNK_REPO_ROOT_ENV, TRUNK_REPO_URL_ENV, TRUNK_USE_UNCLONED_REPO_ENV, - TRUNK_VARIANT_ENV, + TRUNK_PUBLIC_API_ADDRESS_ENV, TRUNK_PUBLIC_REPO_ID_ENV, + TRUNK_QUARANTINED_TESTS_DISK_CACHE_TTL_SECS_ENV, TRUNK_REPO_HEAD_AUTHOR_NAME_ENV, + TRUNK_REPO_HEAD_BRANCH_ENV, TRUNK_REPO_HEAD_COMMIT_EPOCH_ENV, TRUNK_REPO_HEAD_SHA_ENV, + TRUNK_REPO_ROOT_ENV, TRUNK_REPO_URL_ENV, TRUNK_USE_UNCLONED_REPO_ENV, TRUNK_VARIANT_ENV, }; use context::repo::RepoUrlParts; use prost::Message; @@ -34,6 +34,7 @@ fn cleanup_env_vars() { unsafe { env::remove_var(TRUNK_PUBLIC_API_ADDRESS_ENV); env::remove_var(TRUNK_API_TOKEN_ENV); + env::remove_var(TRUNK_PUBLIC_REPO_ID_ENV); env::remove_var(TRUNK_ORG_URL_SLUG_ENV); env::remove_var(TRUNK_REPO_ROOT_ENV); env::remove_var(TRUNK_REPO_URL_ENV);