diff --git a/src/ucode/databricks.py b/src/ucode/databricks.py index b346cef..5ed537c 100644 --- a/src/ucode/databricks.py +++ b/src/ucode/databricks.py @@ -278,9 +278,11 @@ def run( ) -def build_databricks_cli_env(workspace: str) -> dict[str, str]: +def build_databricks_cli_env(workspace: str, profile: str | None = None) -> dict[str, str]: env = os.environ.copy() env["DATABRICKS_HOST"] = workspace + if profile is None: + env.pop("DATABRICKS_CONFIG_PROFILE", None) return env @@ -385,7 +387,7 @@ def has_valid_databricks_auth(workspace: str, profile: str | None = None) -> boo # to disambiguate without --profile, so resolve it from the host here. profile = profile or find_profile_name_for_host(workspace) try: - env = build_databricks_cli_env(workspace) + env = build_databricks_cli_env(workspace, profile) result = run( [ "databricks", @@ -492,7 +494,7 @@ def run_databricks_login(workspace: str, profile: str | None = None) -> None: workspace, *_profile_args(profile_name), ] - run(cmd, env=build_databricks_cli_env(workspace), timeout=300) + run(cmd, env=build_databricks_cli_env(workspace, profile_name), timeout=300) except subprocess.CalledProcessError as exc: raise RuntimeError("`databricks auth login` failed.") from exc except subprocess.TimeoutExpired as exc: @@ -531,7 +533,7 @@ def get_databricks_token( # See has_valid_databricks_auth: resolve the profile from the host when # the caller didn't supply one, so duplicate-host cfgs don't break us. profile = profile or find_profile_name_for_host(workspace) - env = build_databricks_cli_env(workspace) + env = build_databricks_cli_env(workspace, profile) cmd = [ "databricks", "auth", @@ -595,12 +597,13 @@ def _fetch() -> str: token = _fetch() if not token: + profile_name = profile or find_profile_name_for_host(workspace) stale_profile_hint = "" - if profile: + if profile_name: stale_profile_hint = ( " The saved Databricks CLI profile may be stale or invalid. Try:\n" - f" databricks auth logout --profile {profile}\n" - f" databricks auth login --host {workspace} --profile {profile}" + f" databricks auth logout --profile {profile_name}\n" + f" databricks auth login --host {workspace} --profile {profile_name}" ) raise RuntimeError( f"Databricks CLI returned no access token for {workspace}. " diff --git a/tests/test_databricks.py b/tests/test_databricks.py index 806026f..49306cd 100644 --- a/tests/test_databricks.py +++ b/tests/test_databricks.py @@ -51,6 +51,22 @@ def test_sets_databricks_host(self): env = build_databricks_cli_env(WS) assert env["DATABRICKS_HOST"] == WS + def test_strips_ambient_profile_without_explicit_profile(self, monkeypatch): + monkeypatch.setenv("DATABRICKS_CONFIG_PROFILE", "other-workspace") + + env = build_databricks_cli_env(WS) + + assert env["DATABRICKS_HOST"] == WS + assert "DATABRICKS_CONFIG_PROFILE" not in env + + def test_preserves_ambient_profile_with_explicit_profile(self, monkeypatch): + monkeypatch.setenv("DATABRICKS_CONFIG_PROFILE", "other-workspace") + + env = build_databricks_cli_env(WS, profile="stablebox") + + assert env["DATABRICKS_HOST"] == WS + assert env["DATABRICKS_CONFIG_PROFILE"] == "other-workspace" + class TestBuildToolBaseUrl: def test_codex(self): @@ -275,6 +291,36 @@ def test_returns_token_on_success(self, tmp_path, monkeypatch): token = get_databricks_token(WS) assert token == "good-token" + def test_strips_ambient_profile_when_profile_not_provided(self, tmp_path, monkeypatch): + profile_log = tmp_path / "profile" + env = self._fake_databricks( + tmp_path, + f'printf "%s" "${{DATABRICKS_CONFIG_PROFILE:-}}" > {profile_log}\n' + 'echo \'{"access_token": "good-token", "token_type": "Bearer"}\'', + ) + env["DATABRICKS_CONFIG_PROFILE"] = "other-workspace" + monkeypatch.setattr("os.environ", env) + + token = get_databricks_token(WS) + + assert token == "good-token" + assert profile_log.read_text() == "" + + def test_has_valid_auth_strips_ambient_profile_without_explicit_profile( + self, tmp_path, monkeypatch + ): + profile_log = tmp_path / "profile" + env = self._fake_databricks( + tmp_path, + f'printf "%s" "${{DATABRICKS_CONFIG_PROFILE:-}}" > {profile_log}\n' + 'echo \'{"access_token": "good-token", "token_type": "Bearer"}\'', + ) + env["DATABRICKS_CONFIG_PROFILE"] = "other-workspace" + monkeypatch.setattr("os.environ", env) + + assert db_mod.has_valid_databricks_auth(WS) + assert profile_log.read_text() == "" + def test_reauths_and_retries_when_token_empty(self, tmp_path, monkeypatch): call_count = tmp_path / "calls" call_count.write_text("0")