Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions api/src/auth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/// Authentication credential for Trunk API requests.
///
/// Two flavors:
/// * `Token` — an org API token. Sent on every endpoint via the `x-api-token`
/// header. Used by regular CI on the upstream repo where repo secrets are
/// available.
/// * `PublicRepoId` — a non-secret per-repo identifier (the 8-character value
/// from the Trunk settings UI). Sent via the `X-Trunk-Public-Repo-Id` header
/// on the two endpoints that accept it (`createBundleUpload` and
/// `getQuarantineConfig`). Used on fork-PR runs where secrets are
/// unavailable.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TrunkApiCredential {
Token(String),
PublicRepoId(String),
}

pub const PUBLIC_REPO_ID_HEADER: &str = "x-trunk-public-repo-id";
pub const API_TOKEN_HEADER: &str = "x-api-token";

impl TrunkApiCredential {
/// Resolve a credential from a token and a public-repo-id, with token-first
/// ordering. Empty / whitespace-only strings count as absent.
///
/// Token-first matters: it preserves existing behaviour for non-fork CI
/// runs that have an org token configured, and only falls back to the
/// public-id when the token is genuinely absent.
pub fn resolve(token: Option<&str>, public_repo_id: Option<&str>) -> Option<Self> {
let cleaned_token = token.map(str::trim).filter(|s| !s.is_empty());
if let Some(token) = cleaned_token {
return Some(Self::Token(token.to_string()));
}
let cleaned_public_id = public_repo_id.map(str::trim).filter(|s| !s.is_empty());
if let Some(public_id) = cleaned_public_id {
return Some(Self::PublicRepoId(public_id.to_string()));
}
None
}

pub fn header_name(&self) -> &'static str {
match self {
Self::Token(_) => API_TOKEN_HEADER,
Self::PublicRepoId(_) => PUBLIC_REPO_ID_HEADER,
}
}

pub fn header_value(&self) -> &str {
match self {
Self::Token(value) | Self::PublicRepoId(value) => value,
}
}

pub fn is_token(&self) -> bool {
matches!(self, Self::Token(_))
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn token_only_returns_token() {
let cred = TrunkApiCredential::resolve(Some("abc"), None).unwrap();
assert_eq!(cred, TrunkApiCredential::Token("abc".to_string()));
}

#[test]
fn public_id_only_returns_public_id() {
let cred = TrunkApiCredential::resolve(None, Some("abcd1234")).unwrap();
assert_eq!(
cred,
TrunkApiCredential::PublicRepoId("abcd1234".to_string())
);
}

#[test]
fn token_wins_when_both_present() {
let cred = TrunkApiCredential::resolve(Some("the-token"), Some("abcd1234")).unwrap();
assert_eq!(cred, TrunkApiCredential::Token("the-token".to_string()));
}

#[test]
fn empty_token_falls_through_to_public_id() {
let cred = TrunkApiCredential::resolve(Some(""), Some("abcd1234")).unwrap();
assert_eq!(
cred,
TrunkApiCredential::PublicRepoId("abcd1234".to_string())
);
}

#[test]
fn whitespace_token_falls_through_to_public_id() {
let cred = TrunkApiCredential::resolve(Some(" "), Some("abcd1234")).unwrap();
assert_eq!(
cred,
TrunkApiCredential::PublicRepoId("abcd1234".to_string())
);
}

#[test]
fn neither_returns_none() {
assert!(TrunkApiCredential::resolve(None, None).is_none());
assert!(TrunkApiCredential::resolve(Some(""), Some("")).is_none());
assert!(TrunkApiCredential::resolve(Some(" "), Some(" ")).is_none());
}

#[test]
fn header_name_matches_credential_kind() {
assert_eq!(
TrunkApiCredential::Token("x".into()).header_name(),
"x-api-token"
);
assert_eq!(
TrunkApiCredential::PublicRepoId("x".into()).header_name(),
"x-trunk-public-repo-id"
);
}
}
82 changes: 54 additions & 28 deletions api/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use reqwest::{Client, Response, StatusCode, header};
use serde::de::DeserializeOwned;
use tokio::fs;

use crate::auth::TrunkApiCredential;
use crate::call_api::CallApi;
use crate::message;

Expand Down Expand Up @@ -49,7 +50,10 @@ pub struct ApiClient {
pub telemetry_host: String,
s3_client: Client,
trunk_api_client: Client,
telemetry_client: Client,
/// `None` when the CLI is running with only a public-repo-id (fork PR
/// path) — the telemetry endpoint does not accept that header, so we
/// short-circuit telemetry sends rather than fire 401s.
telemetry_client: Option<Client>,
version_path_prefix: String,
org_url_slug: String,
render_sender: Option<Sender<DisplayMessage>>,
Expand All @@ -66,20 +70,27 @@ impl ApiClient {
const TRUNK_API_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
// This should always be fast
const TRUNK_TELEMETRY_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(1);
const TRUNK_API_TOKEN_HEADER: &'static str = "x-api-token";

pub fn new<T: AsRef<str>>(
api_token: T,
org_url_slug: T,
pub fn new<S: AsRef<str>>(
auth: TrunkApiCredential,
org_url_slug: S,
render_sender: Option<Sender<DisplayMessage>>,
) -> anyhow::Result<ApiClient> {
let org_url_slug = String::from(org_url_slug.as_ref());
let api_token = api_token.as_ref();
if api_token.trim().is_empty() {
return Err(anyhow::anyhow!("Trunk API token is required."));

let auth_header_value = HeaderValue::from_str(auth.header_value()).map_err(|_| {
anyhow::Error::msg(if auth.is_token() {
"Trunk API token is not ASCII"
} else {
"Trunk public repo id is not ASCII"
})
})?;

if !auth.is_token() {
tracing::info!(
"Using X-Trunk-Public-Repo-Id auth (TRUNK_API_TOKEN not set; assuming fork PR)"
);
}
let api_token_header_value = HeaderValue::from_str(api_token)
.map_err(|_| anyhow::Error::msg("Trunk API token is not ASCII"))?;

let api_host = get_api_host();
tracing::debug!("Using public api address {}", api_host);
Expand All @@ -101,26 +112,33 @@ impl ApiClient {
header::CONTENT_TYPE,
HeaderValue::from_static("application/json"),
);
trunk_api_client_default_headers
.append(Self::TRUNK_API_TOKEN_HEADER, api_token_header_value.clone());
trunk_api_client_default_headers.append(auth.header_name(), auth_header_value.clone());

let trunk_api_client = Client::builder()
.timeout(Self::TRUNK_API_TIMEOUT)
.default_headers(trunk_api_client_default_headers)
.build()?;

let mut telemetry_client_default_headers = HeaderMap::new();
telemetry_client_default_headers.append(
header::CONTENT_TYPE,
HeaderValue::from_static("application/x-protobuf"),
);
telemetry_client_default_headers
.append(Self::TRUNK_API_TOKEN_HEADER, api_token_header_value);
// The telemetry endpoint does not accept the public-repo-id auth path,
// so when we only have a public-repo-id we skip building the telemetry
// client and short-circuit telemetry sends in `telemetry_upload_metrics`.
let telemetry_client = if auth.is_token() {
let mut telemetry_client_default_headers = HeaderMap::new();
telemetry_client_default_headers.append(
header::CONTENT_TYPE,
HeaderValue::from_static("application/x-protobuf"),
);
telemetry_client_default_headers.append(auth.header_name(), auth_header_value);

let telemetry_client = Client::builder()
.timeout(Self::TRUNK_TELEMETRY_TIMEOUT)
.default_headers(telemetry_client_default_headers)
.build()?;
Some(
Client::builder()
.timeout(Self::TRUNK_TELEMETRY_TIMEOUT)
.default_headers(telemetry_client_default_headers)
.build()?,
)
} else {
None
};

let mut s3_client_default_headers = HeaderMap::new();
s3_client_default_headers.append(
Expand Down Expand Up @@ -273,13 +291,16 @@ impl ApiClient {
&self,
request: &message::TelemetryUploadMetricsRequest,
) -> anyhow::Result<()> {
let Some(telemetry_client) = self.telemetry_client.clone() else {
tracing::debug!("Skipping telemetry upload: no API token configured");
return Ok(());
};
CallApi {
action: || async {
if std::env::var("DISABLE_TELEMETRY").is_ok() {
return Ok(());
}
let response = self
.telemetry_client
let response = telemetry_client
.post(format!(
"{}{}/flakytests-cli/upload-metrics",
self.telemetry_host, self.version_path_prefix
Expand Down Expand Up @@ -461,8 +482,13 @@ mod tests {
use test_utils::mock_server::MockServerBuilder;

use super::ApiClient;
use crate::auth::TrunkApiCredential;
use crate::message;

fn mock_token_auth() -> TrunkApiCredential {
TrunkApiCredential::Token(String::from("mock-token"))
}

#[tokio::test(start_paused = true)]
async fn does_not_retry_on_ok_501() {
let mut mock_server_builder = MockServerBuilder::new();
Expand All @@ -486,7 +512,7 @@ mod tests {
let state = mock_server_builder.spawn_mock_server().await;

let mut api_client =
ApiClient::new(String::from("mock-token"), String::from("mock-org"), None).unwrap();
ApiClient::new(mock_token_auth(), String::from("mock-org"), None).unwrap();
api_client.api_host.clone_from(&state.host);

assert!(
Expand Down Expand Up @@ -534,7 +560,7 @@ mod tests {
let state = mock_server_builder.spawn_mock_server().await;

let mut api_client =
ApiClient::new(String::from("mock-token"), String::from("mock-org"), None).unwrap();
ApiClient::new(mock_token_auth(), String::from("mock-org"), None).unwrap();
api_client.api_host.clone_from(&state.host);

assert!(
Expand Down Expand Up @@ -577,7 +603,7 @@ mod tests {
let state = mock_server_builder.spawn_mock_server().await;

let mut api_client =
ApiClient::new(String::from("mock-token"), String::from("mock-org"), None).unwrap();
ApiClient::new(mock_token_auth(), String::from("mock-org"), None).unwrap();
api_client.api_host.clone_from(&state.host);

assert!(
Expand Down
1 change: 1 addition & 0 deletions api/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod auth;
mod call_api;
pub mod client;
pub mod message;
Expand Down
22 changes: 19 additions & 3 deletions cli/src/upload_command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::env;
use std::path::PathBuf;
use std::sync::mpsc::Sender;

use api::auth::TrunkApiCredential;
use api::client::{ApiClient, ApiErrorEndpoint};
use api::{client::get_api_host, urls::url_for_test_case};
use bundle::{BundleMeta, BundlerUtil, Test, unzip_tarball};
Expand Down Expand Up @@ -81,11 +82,17 @@ pub struct UploadArgs {
pub test_collection_short_id: Option<String>,
#[arg(
long,
required = true,
env = constants::TRUNK_API_TOKEN_ENV,
help = "Organization token. Defaults to TRUNK_API_TOKEN env var."
help = "Organization token. Defaults to TRUNK_API_TOKEN env var. Required unless --public-repo-id is set (e.g. for fork PRs).",
default_value = ""
)]
pub token: String,
#[arg(
long,
env = constants::TRUNK_PUBLIC_REPO_ID_ENV,
help = "Non-secret per-repo identifier from the Trunk settings page. Used in place of --token on workflows triggered from forked pull requests where the org API token is unavailable."
)]
pub public_repo_id: Option<String>,
#[arg(
long,
env = constants::TRUNK_REPO_ROOT_ENV,
Expand Down Expand Up @@ -396,7 +403,16 @@ pub async fn run_upload(
);
}

let api_client = ApiClient::new(&upload_args.token, &upload_args.org_url_slug, render_sender)?;
let auth = TrunkApiCredential::resolve(
Some(upload_args.token.as_str()),
upload_args.public_repo_id.as_deref(),
)
.ok_or_else(|| {
anyhow::anyhow!(
"Authentication required: set --token / TRUNK_API_TOKEN, or --public-repo-id / TRUNK_PUBLIC_REPO_ID for fork-PR runs."
)
})?;
let api_client = ApiClient::new(auth, &upload_args.org_url_slug, render_sender)?;

let PreTestContext {
mut meta,
Expand Down
1 change: 1 addition & 0 deletions constants/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub const TRUNK_API_CLIENT_RETRY_COUNT_ENV: &str = "TRUNK_API_CLIENT_RETRY_COUNT

// Trunk CLI environment variable names for configuration overrides
pub const TRUNK_API_TOKEN_ENV: &str = "TRUNK_API_TOKEN";
pub const TRUNK_PUBLIC_REPO_ID_ENV: &str = "TRUNK_PUBLIC_REPO_ID";
pub const TRUNK_ORG_URL_SLUG_ENV: &str = "TRUNK_ORG_URL_SLUG";
pub const TRUNK_TEST_COLLECTION_SHORT_ID_ENV: &str = "TRUNK_TEST_COLLECTION_SHORT_ID";
pub const TRUNK_REPO_ROOT_ENV: &str = "TRUNK_REPO_ROOT";
Expand Down
Loading
Loading