Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions src/ucode/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,11 @@ def run(
)


def build_databricks_cli_env(workspace: str) -> dict[str, str]:
def build_databricks_cli_env(workspace: str, profile: str | None = None) -> dict[str, str]:
env = os.environ.copy()
env["DATABRICKS_HOST"] = workspace
if profile is None:
env.pop("DATABRICKS_CONFIG_PROFILE", None)
return env


Expand Down Expand Up @@ -385,7 +387,7 @@ def has_valid_databricks_auth(workspace: str, profile: str | None = None) -> boo
# to disambiguate without --profile, so resolve it from the host here.
profile = profile or find_profile_name_for_host(workspace)
try:
env = build_databricks_cli_env(workspace)
env = build_databricks_cli_env(workspace, profile)
result = run(
[
"databricks",
Expand Down Expand Up @@ -492,7 +494,7 @@ def run_databricks_login(workspace: str, profile: str | None = None) -> None:
workspace,
*_profile_args(profile_name),
]
run(cmd, env=build_databricks_cli_env(workspace), timeout=300)
run(cmd, env=build_databricks_cli_env(workspace, profile_name), timeout=300)
except subprocess.CalledProcessError as exc:
raise RuntimeError("`databricks auth login` failed.") from exc
except subprocess.TimeoutExpired as exc:
Expand Down Expand Up @@ -531,7 +533,7 @@ def get_databricks_token(
# See has_valid_databricks_auth: resolve the profile from the host when
# the caller didn't supply one, so duplicate-host cfgs don't break us.
profile = profile or find_profile_name_for_host(workspace)
env = build_databricks_cli_env(workspace)
env = build_databricks_cli_env(workspace, profile)
cmd = [
"databricks",
"auth",
Expand Down Expand Up @@ -595,12 +597,13 @@ def _fetch() -> str:
token = _fetch()

if not token:
profile_name = profile or find_profile_name_for_host(workspace)
stale_profile_hint = ""
if profile:
if profile_name:
stale_profile_hint = (
" The saved Databricks CLI profile may be stale or invalid. Try:\n"
f" databricks auth logout --profile {profile}\n"
f" databricks auth login --host {workspace} --profile {profile}"
f" databricks auth logout --profile {profile_name}\n"
f" databricks auth login --host {workspace} --profile {profile_name}"
)
raise RuntimeError(
f"Databricks CLI returned no access token for {workspace}. "
Expand Down
46 changes: 46 additions & 0 deletions tests/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,22 @@ def test_sets_databricks_host(self):
env = build_databricks_cli_env(WS)
assert env["DATABRICKS_HOST"] == WS

def test_strips_ambient_profile_without_explicit_profile(self, monkeypatch):
monkeypatch.setenv("DATABRICKS_CONFIG_PROFILE", "other-workspace")

env = build_databricks_cli_env(WS)

assert env["DATABRICKS_HOST"] == WS
assert "DATABRICKS_CONFIG_PROFILE" not in env

def test_preserves_ambient_profile_with_explicit_profile(self, monkeypatch):
monkeypatch.setenv("DATABRICKS_CONFIG_PROFILE", "other-workspace")

env = build_databricks_cli_env(WS, profile="stablebox")

assert env["DATABRICKS_HOST"] == WS
assert env["DATABRICKS_CONFIG_PROFILE"] == "other-workspace"


class TestBuildToolBaseUrl:
def test_codex(self):
Expand Down Expand Up @@ -275,6 +291,36 @@ def test_returns_token_on_success(self, tmp_path, monkeypatch):
token = get_databricks_token(WS)
assert token == "good-token"

def test_strips_ambient_profile_when_profile_not_provided(self, tmp_path, monkeypatch):
profile_log = tmp_path / "profile"
env = self._fake_databricks(
tmp_path,
f'printf "%s" "${{DATABRICKS_CONFIG_PROFILE:-}}" > {profile_log}\n'
'echo \'{"access_token": "good-token", "token_type": "Bearer"}\'',
)
env["DATABRICKS_CONFIG_PROFILE"] = "other-workspace"
monkeypatch.setattr("os.environ", env)

token = get_databricks_token(WS)

assert token == "good-token"
assert profile_log.read_text() == ""

def test_has_valid_auth_strips_ambient_profile_without_explicit_profile(
self, tmp_path, monkeypatch
):
profile_log = tmp_path / "profile"
env = self._fake_databricks(
tmp_path,
f'printf "%s" "${{DATABRICKS_CONFIG_PROFILE:-}}" > {profile_log}\n'
'echo \'{"access_token": "good-token", "token_type": "Bearer"}\'',
)
env["DATABRICKS_CONFIG_PROFILE"] = "other-workspace"
monkeypatch.setattr("os.environ", env)

assert db_mod.has_valid_databricks_auth(WS)
assert profile_log.read_text() == ""

def test_reauths_and_retries_when_token_empty(self, tmp_path, monkeypatch):
call_count = tmp_path / "calls"
call_count.write_text("0")
Expand Down
Loading