13 changes: 10 additions & 3 deletions Dockerfile.connector
@@ -7,8 +7,15 @@ ARG VERSION

RUN apt-get update && apt-get install build-essential -y \
&& pip install uv

-RUN uv pip install --system --no-cache-dir -U flytekit[connector]==$VERSION \
# Pin pendulum<3.0: Apache Airflow (via flytekitplugins-airflow) imports
# pendulum.tz.timezone() at module load time (airflow/settings.py).
# Pendulum 3.x changed the tz API, causing the connector to crash on startup:
# airflow/settings.py → TIMEZONE = pendulum.tz.timezone("UTC") → AttributeError
# Without this pin, uv resolves to pendulum 3.x, which breaks the import chain:
# pyflyte serve connector → load_implicit_plugins → airflow → pendulum → crash
RUN uv pip install --system --no-cache-dir -U \
"pendulum>=2.0.0,<3.0" \
flytekit[connector]==$VERSION \
flytekitplugins-airflow==$VERSION \
flytekitplugins-bigquery==$VERSION \
flytekitplugins-k8sdataservice==$VERSION \
@@ -28,4 +35,4 @@ ARG VERSION

RUN uv pip install --system --no-cache-dir -U \
flytekitplugins-mmcloud==$VERSION \
-    flytekitplugins-spark==$VERSION
flytekitplugins-spark==$VERSION
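
For context on the pendulum pin above, a minimal reproduction sketch of the failure described in the Dockerfile comment (illustrative only, not part of the image build):

import pendulum

# airflow/settings.py resolves Airflow's default timezone at import time with this call.
# It works on pendulum 2.x; per the comment above, pendulum 3.x changed the tz API, so the
# call fails, `import airflow` aborts, and `pyflyte serve connector` crashes while implicitly
# loading the airflow plugin.
TIMEZONE = pendulum.tz.timezone("UTC")
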
224 changes: 205 additions & 19 deletions plugins/flytekit-spark/flytekitplugins/spark/connector.py
@@ -19,6 +19,7 @@

DATABRICKS_API_ENDPOINT = "/api/2.1/jobs"
DEFAULT_DATABRICKS_INSTANCE_ENV_KEY = "FLYTE_DATABRICKS_INSTANCE"
DEFAULT_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER_ENV_KEY = "FLYTE_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER"


@dataclass
@@ -27,36 +28,221 @@ class DatabricksJobMetadata(ResourceMeta):
run_id: str


def _is_serverless_config(databricks_job: dict) -> bool:
"""
Detect if the configuration is for serverless compute.
Serverless is indicated by having environment_key or environments without cluster config.
"""
# Check if cluster config keys exist (even empty dict counts as cluster config)
has_cluster_config = "existing_cluster_id" in databricks_job or "new_cluster" in databricks_job
has_serverless_config = bool(databricks_job.get("environment_key") or databricks_job.get("environments"))
return not has_cluster_config and has_serverless_config
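
# Illustrative sketch (hypothetical payloads, not part of the module): how the detector above
# classifies a databricksConf. Cluster keys win, and an empty cluster dict still counts as
# cluster config.
assert _is_serverless_config({"environments": [{"environment_key": "default"}]}) is True
assert _is_serverless_config({"environment_key": "default"}) is True
assert _is_serverless_config({"new_cluster": {}}) is False
assert _is_serverless_config({"existing_cluster_id": "abc-123"}) is False
assert _is_serverless_config({}) is False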


def _configure_serverless(databricks_job: dict, envs: dict) -> str:
"""
Configure serverless compute settings and return the environment_key to use.

Databricks serverless requires the 'environments' array to be defined in the job
submission. This function ensures the environments array exists and injects
Flyte environment variables.

Args:
databricks_job: The databricks job configuration dict
envs: Environment variables to inject

Returns:
The environment_key to use for the task
"""
environment_key = databricks_job.get("environment_key", "default")
environments = databricks_job.get("environments", [])

# Check if environment already exists in the array
env_exists = any(env.get("environment_key") == environment_key for env in environments)

if not env_exists:
# Create the environment entry - Databricks serverless requires environments
# to be defined in the job submission (not externally pre-configured)
new_env = {
"environment_key": environment_key,
"spec": {
"client": "1", # Required: Databricks serverless client version
}
}
environments.append(new_env)
databricks_job["environments"] = environments

# Inject Flyte environment variables into the environment spec
for env in environments:
if env.get("environment_key") == environment_key:
spec = env.setdefault("spec", {})
existing_env_vars = spec.get("environment_vars", {})
# Merge Flyte env vars with any existing ones (Flyte vars take precedence)
merged_env_vars = {**existing_env_vars, **{k: v for k, v in envs.items()}}
spec["environment_vars"] = merged_env_vars
break

# Remove environment_key from top level (it's now in the task definition)
databricks_job.pop("environment_key", None)

return environment_key
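
# Illustrative sketch (hypothetical values, not part of the module): the in-place mutation the
# helper above performs on a minimal serverless config.
_example_job = {"environment_key": "default"}
_example_key = _configure_serverless(_example_job, {"FLYTE_INTERNAL_EXECUTION_ID": "f3a1b2"})
assert _example_key == "default"
assert _example_job == {
    "environments": [
        {
            "environment_key": "default",
            "spec": {"client": "1", "environment_vars": {"FLYTE_INTERNAL_EXECUTION_ID": "f3a1b2"}},
        }
    ]
}  # the top-level "environment_key" has been popped; it is re-attached per task later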


def _configure_classic_cluster(databricks_job: dict, custom: dict, container, envs: dict) -> None:
"""
Configure classic compute (existing cluster or new cluster).

Args:
databricks_job: The databricks job configuration dict
custom: The custom config from task template
container: The container config from task template
envs: Environment variables to inject
"""
if databricks_job.get("existing_cluster_id") is not None:
# Using an existing cluster, no additional configuration needed
return

new_cluster = databricks_job.get("new_cluster")
if new_cluster is None:
raise ValueError(
"Either existing_cluster_id, new_cluster, environment_key, or environments must be specified"
)

if not new_cluster.get("docker_image"):
new_cluster["docker_image"] = {"url": container.image}
if not new_cluster.get("spark_conf"):
new_cluster["spark_conf"] = custom.get("sparkConf", {})
if not new_cluster.get("spark_env_vars"):
new_cluster["spark_env_vars"] = {k: v for k, v in envs.items()}
else:
new_cluster["spark_env_vars"].update({k: v for k, v in envs.items()})


def _get_databricks_job_spec(task_template: TaskTemplate) -> dict:
custom = task_template.custom
container = task_template.container
envs = task_template.container.env
envs[FLYTE_FAIL_ON_ERROR] = "true"
databricks_job = custom["databricksConf"]
if databricks_job.get("existing_cluster_id") is None:
new_cluster = databricks_job.get("new_cluster")
if new_cluster is None:
raise ValueError("Either existing_cluster_id or new_cluster must be specified")
if not new_cluster.get("docker_image"):
new_cluster["docker_image"] = {"url": container.image}
if not new_cluster.get("spark_conf"):
new_cluster["spark_conf"] = custom.get("sparkConf", {})
if not new_cluster.get("spark_env_vars"):
new_cluster["spark_env_vars"] = {k: v for k, v in envs.items()}

# Check if this is a notebook task
notebook_path = custom.get("notebookPath")
notebook_base_parameters = custom.get("notebookBaseParameters", {})

# Determine compute mode and configure accordingly
is_serverless = _is_serverless_config(databricks_job)

if notebook_path:
# Notebook task - runs a Databricks notebook
notebook_task = {
"notebook_path": notebook_path,
}
if notebook_base_parameters:
notebook_task["base_parameters"] = notebook_base_parameters

# Check if notebook should be sourced from git
user_git_source = databricks_job.get("git_source")
if user_git_source:
notebook_task["source"] = "GIT"
# Set git_source at job level
databricks_job["git_source"] = user_git_source

if is_serverless:
# Serverless notebook task
environment_key = _configure_serverless(databricks_job, envs)

task_def = {
"task_key": "flyte_notebook_task",
"notebook_task": notebook_task,
"environment_key": environment_key,
}
databricks_job["tasks"] = [task_def]
databricks_job.pop("environment_key", None)
else:
new_cluster["spark_env_vars"].update({k: v for k, v in envs.items()})
# https://docs.databricks.com/api/workspace/jobs/submit
databricks_job["spark_python_task"] = {
"python_file": "flytekitplugins/databricks/entrypoint.py",
"source": "GIT",
"parameters": container.args,
}
databricks_job["git_source"] = {
# Classic compute notebook task
_configure_classic_cluster(databricks_job, custom, container, envs)
databricks_job["notebook_task"] = notebook_task

# Clean up git_source from databricks_job if it was there (already set at job level)
databricks_job.pop("git_source", None)
if user_git_source:
databricks_job["git_source"] = user_git_source

return databricks_job

# Python file task (original behavior)
# Allow custom git_source and python_file override from user config
user_git_source = databricks_job.get("git_source")
user_python_file = databricks_job.get("python_file")

# Default entrypoints from the flytetools repo.
# Both classic and serverless use the same repo; only the python_file differs.
default_git_source = {
"git_url": "https://github.com/flyteorg/flytetools",
"git_provider": "gitHub",
# https://github.com/flyteorg/flytetools/commit/572298df1f971fb58c258398bd70a6372f811c96
"git_commit": "572298df1f971fb58c258398bd70a6372f811c96",
}
default_classic_python_file = "flytekitplugins/databricks/entrypoint.py"
default_serverless_python_file = "flytekitplugins/databricks/entrypoint_serverless.py"

if is_serverless:
# Serverless compute - use flytetools serverless entrypoint by default
git_source = user_git_source if user_git_source else default_git_source
python_file = user_python_file if user_python_file else default_serverless_python_file

# Serverless requires multi-task format with tasks array
environment_key = _configure_serverless(databricks_job, envs)

# Build parameters list - append credential provider if specified
# This allows the entrypoint to receive the credential provider via command line
# We append at the END to avoid breaking pyflyte-fast-execute, which must come first
parameters = list(container.args) if container.args else []

# Resolve service credential provider: task config > env var
service_credential_provider = custom.get(
"databricksServiceCredentialProvider",
os.getenv(DEFAULT_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER_ENV_KEY)
)
if service_credential_provider:
# Append as a special argument that the entrypoint will parse and remove
parameters.append(f"--flyte-credential-provider={service_credential_provider}")

spark_python_task = {
"python_file": python_file,
"source": "GIT",
"parameters": parameters,
}

# Build the task definition for serverless
task_def = {
"task_key": "flyte_task",
"spark_python_task": spark_python_task,
"environment_key": environment_key,
}

# Add tasks array for serverless (required by Databricks API)
databricks_job["tasks"] = [task_def]

# Remove environment_key from top level (it's now in the task)
databricks_job.pop("environment_key", None)
else:
# Classic compute - use flytetools entrypoint by default
git_source = user_git_source if user_git_source else default_git_source
python_file = user_python_file if user_python_file else default_classic_python_file

spark_python_task = {
"python_file": python_file,
"source": "GIT",
"parameters": container.args,
}

_configure_classic_cluster(databricks_job, custom, container, envs)
databricks_job["spark_python_task"] = spark_python_task

# Set git_source (remove from user config if it was there to avoid duplication)
databricks_job.pop("git_source", None)
databricks_job.pop("python_file", None)
databricks_job["git_source"] = git_source

return databricks_job
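
# Illustrative sketch (not part of the module): roughly the spec the function above returns for a
# serverless Python-file task when the user supplies neither git_source nor python_file. Values
# are abbreviated; anything else already present in databricksConf is passed through as-is.
#
#   {
#       "environments": [
#           {"environment_key": "default",
#            "spec": {"client": "1", "environment_vars": {"FLYTE_FAIL_ON_ERROR": "true", ...}}}
#       ],
#       "tasks": [
#           {"task_key": "flyte_task",
#            "spark_python_task": {
#                "python_file": "flytekitplugins/databricks/entrypoint_serverless.py",
#                "source": "GIT",
#                "parameters": [...container.args, "--flyte-credential-provider=<provider>"],
#            },
#            "environment_key": "default"}
#       ],
#       "git_source": {
#           "git_url": "https://github.com/flyteorg/flytetools",
#           "git_provider": "gitHub",
#           "git_commit": "572298df1f971fb58c258398bd70a6372f811c96",
#       },
#   }
#
# The --flyte-credential-provider argument is only appended when databricksServiceCredentialProvider
# is set in the task config or the FLYTE_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER env var is present.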
