diff --git a/Dockerfile.connector b/Dockerfile.connector index bed0fa0160..7eea1d42b9 100644 --- a/Dockerfile.connector +++ b/Dockerfile.connector @@ -7,8 +7,15 @@ ARG VERSION RUN apt-get update && apt-get install build-essential -y \ && pip install uv - -RUN uv pip install --system --no-cache-dir -U flytekit[connector]==$VERSION \ +# Pin pendulum<3.0: Apache Airflow (via flytekitplugins-airflow) imports +# pendulum.tz.timezone() at module load time (airflow/settings.py). +# Pendulum 3.x changed the tz API, causing the connector to crash on startup: +# airflow/settings.py → TIMEZONE = pendulum.tz.timezone("UTC") → AttributeError +# Without this pin, uv resolves to pendulum 3.x which breaks the import chain: +# pyflyte serve connector → load_implicit_plugins → airflow → pendulum → crash +RUN uv pip install --system --no-cache-dir -U \ + "pendulum>=2.0.0,<3.0" \ + flytekit[connector]==$VERSION \ flytekitplugins-airflow==$VERSION \ flytekitplugins-bigquery==$VERSION \ flytekitplugins-k8sdataservice==$VERSION \ @@ -28,4 +35,4 @@ ARG VERSION RUN uv pip install --system --no-cache-dir -U \ flytekitplugins-mmcloud==$VERSION \ - flytekitplugins-spark==$VERSION + flytekitplugins-spark==$VERSION \ No newline at end of file diff --git a/plugins/flytekit-spark/flytekitplugins/spark/connector.py b/plugins/flytekit-spark/flytekitplugins/spark/connector.py index 895c7d153d..628ef057f8 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/connector.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/connector.py @@ -19,6 +19,7 @@ DATABRICKS_API_ENDPOINT = "/api/2.1/jobs" DEFAULT_DATABRICKS_INSTANCE_ENV_KEY = "FLYTE_DATABRICKS_INSTANCE" +DEFAULT_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER_ENV_KEY = "FLYTE_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER" @dataclass @@ -27,36 +28,221 @@ class DatabricksJobMetadata(ResourceMeta): run_id: str +def _is_serverless_config(databricks_job: dict) -> bool: + """ + Detect if the configuration is for serverless compute. + Serverless is indicated by having environment_key or environments without cluster config. + """ + # Check if cluster config keys exist (even empty dict counts as cluster config) + has_cluster_config = "existing_cluster_id" in databricks_job or "new_cluster" in databricks_job + has_serverless_config = bool(databricks_job.get("environment_key") or databricks_job.get("environments")) + return not has_cluster_config and has_serverless_config + + +def _configure_serverless(databricks_job: dict, envs: dict) -> str: + """ + Configure serverless compute settings and return the environment_key to use. + + Databricks serverless requires the 'environments' array to be defined in the job + submission. This function ensures the environments array exists and injects + Flyte environment variables. 
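+
+    Illustrative sketch of the in-place transformation (the input values here are
+    hypothetical; the ``"client": "1"`` default comes from the code below)::
+
+        databricks_job = {"environment_key": "default"}
+        _configure_serverless(databricks_job, {"FOO": "bar"})
+        # databricks_job is mutated to roughly:
+        # {"environments": [{"environment_key": "default",
+        #                    "spec": {"client": "1",
+        #                             "environment_vars": {"FOO": "bar"}}}]}
+        # and the returned environment_key is "default".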
+ + Args: + databricks_job: The databricks job configuration dict + envs: Environment variables to inject + + Returns: + The environment_key to use for the task + """ + environment_key = databricks_job.get("environment_key", "default") + environments = databricks_job.get("environments", []) + + # Check if environment already exists in the array + env_exists = any(env.get("environment_key") == environment_key for env in environments) + + if not env_exists: + # Create the environment entry - Databricks serverless requires environments + # to be defined in the job submission (not externally pre-configured) + new_env = { + "environment_key": environment_key, + "spec": { + "client": "1", # Required: Databricks serverless client version + } + } + environments.append(new_env) + databricks_job["environments"] = environments + + # Inject Flyte environment variables into the environment spec + for env in environments: + if env.get("environment_key") == environment_key: + spec = env.setdefault("spec", {}) + existing_env_vars = spec.get("environment_vars", {}) + # Merge Flyte env vars with any existing ones (Flyte vars take precedence) + merged_env_vars = {**existing_env_vars, **{k: v for k, v in envs.items()}} + spec["environment_vars"] = merged_env_vars + break + + # Remove environment_key from top level (it's now in the task definition) + databricks_job.pop("environment_key", None) + + return environment_key + + +def _configure_classic_cluster(databricks_job: dict, custom: dict, container, envs: dict) -> None: + """ + Configure classic compute (existing cluster or new cluster). + + Args: + databricks_job: The databricks job configuration dict + custom: The custom config from task template + container: The container config from task template + envs: Environment variables to inject + """ + if databricks_job.get("existing_cluster_id") is not None: + # Using an existing cluster, no additional configuration needed + return + + new_cluster = databricks_job.get("new_cluster") + if new_cluster is None: + raise ValueError( + "Either existing_cluster_id, new_cluster, environment_key, or environments must be specified" + ) + + if not new_cluster.get("docker_image"): + new_cluster["docker_image"] = {"url": container.image} + if not new_cluster.get("spark_conf"): + new_cluster["spark_conf"] = custom.get("sparkConf", {}) + if not new_cluster.get("spark_env_vars"): + new_cluster["spark_env_vars"] = {k: v for k, v in envs.items()} + else: + new_cluster["spark_env_vars"].update({k: v for k, v in envs.items()}) + + def _get_databricks_job_spec(task_template: TaskTemplate) -> dict: custom = task_template.custom container = task_template.container envs = task_template.container.env envs[FLYTE_FAIL_ON_ERROR] = "true" databricks_job = custom["databricksConf"] - if databricks_job.get("existing_cluster_id") is None: - new_cluster = databricks_job.get("new_cluster") - if new_cluster is None: - raise ValueError("Either existing_cluster_id or new_cluster must be specified") - if not new_cluster.get("docker_image"): - new_cluster["docker_image"] = {"url": container.image} - if not new_cluster.get("spark_conf"): - new_cluster["spark_conf"] = custom.get("sparkConf", {}) - if not new_cluster.get("spark_env_vars"): - new_cluster["spark_env_vars"] = {k: v for k, v in envs.items()} + + # Check if this is a notebook task + notebook_path = custom.get("notebookPath") + notebook_base_parameters = custom.get("notebookBaseParameters", {}) + + # Determine compute mode and configure accordingly + is_serverless = 
_is_serverless_config(databricks_job) + + if notebook_path: + # Notebook task - runs a Databricks notebook + notebook_task = { + "notebook_path": notebook_path, + } + if notebook_base_parameters: + notebook_task["base_parameters"] = notebook_base_parameters + + # Check if notebook should be sourced from git + user_git_source = databricks_job.get("git_source") + if user_git_source: + notebook_task["source"] = "GIT" + # Set git_source at job level + databricks_job["git_source"] = user_git_source + + if is_serverless: + # Serverless notebook task + environment_key = _configure_serverless(databricks_job, envs) + + task_def = { + "task_key": "flyte_notebook_task", + "notebook_task": notebook_task, + "environment_key": environment_key, + } + databricks_job["tasks"] = [task_def] + databricks_job.pop("environment_key", None) else: - new_cluster["spark_env_vars"].update({k: v for k, v in envs.items()}) - # https://docs.databricks.com/api/workspace/jobs/submit - databricks_job["spark_python_task"] = { - "python_file": "flytekitplugins/databricks/entrypoint.py", - "source": "GIT", - "parameters": container.args, - } - databricks_job["git_source"] = { + # Classic compute notebook task + _configure_classic_cluster(databricks_job, custom, container, envs) + databricks_job["notebook_task"] = notebook_task + + # Clean up git_source from databricks_job if it was there (already set at job level) + databricks_job.pop("git_source", None) + if user_git_source: + databricks_job["git_source"] = user_git_source + + return databricks_job + + # Python file task (original behavior) + # Allow custom git_source and python_file override from user config + user_git_source = databricks_job.get("git_source") + user_python_file = databricks_job.get("python_file") + + # Default entrypoints from the flytetools repo. + # Both classic and serverless use the same repo; only the python_file differs. 
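+    # For reference, the serverless branch below submits roughly this payload
+    # (illustrative sketch, values abbreviated):
+    #   {"tasks": [{"task_key": "flyte_task",
+    #               "spark_python_task": {"python_file": "<serverless entrypoint>",
+    #                                     "source": "GIT",
+    #                                     "parameters": [...container args...]},
+    #               "environment_key": "<key>"}],
+    #    "environments": [{"environment_key": "<key>",
+    #                      "spec": {"client": "1", "environment_vars": {...}}}],
+    #    "git_source": {"git_url": "https://github.com/flyteorg/flytetools", ...}}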
+ default_git_source = { "git_url": "https://github.com/flyteorg/flytetools", "git_provider": "gitHub", - # https://github.com/flyteorg/flytetools/commit/572298df1f971fb58c258398bd70a6372f811c96 "git_commit": "572298df1f971fb58c258398bd70a6372f811c96", } + default_classic_python_file = "flytekitplugins/databricks/entrypoint.py" + default_serverless_python_file = "flytekitplugins/databricks/entrypoint_serverless.py" + + if is_serverless: + # Serverless compute - use flytetools serverless entrypoint by default + git_source = user_git_source if user_git_source else default_git_source + python_file = user_python_file if user_python_file else default_serverless_python_file + + # Serverless requires multi-task format with tasks array + environment_key = _configure_serverless(databricks_job, envs) + + # Build parameters list - append credential provider if specified + # This allows the entrypoint to receive the credential provider via command line + # We append at the END to avoid breaking pyflyte-fast-execute which must be first + parameters = list(container.args) if container.args else [] + + # Resolve service credential provider: task config > env var + service_credential_provider = custom.get( + "databricksServiceCredentialProvider", + os.getenv(DEFAULT_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER_ENV_KEY) + ) + if service_credential_provider: + # Append as a special argument that the entrypoint will parse and remove + parameters.append(f"--flyte-credential-provider={service_credential_provider}") + + spark_python_task = { + "python_file": python_file, + "source": "GIT", + "parameters": parameters, + } + + # Build the task definition for serverless + task_def = { + "task_key": "flyte_task", + "spark_python_task": spark_python_task, + "environment_key": environment_key, + } + + # Add tasks array for serverless (required by Databricks API) + databricks_job["tasks"] = [task_def] + + # Remove environment_key from top level (it's now in the task) + databricks_job.pop("environment_key", None) + else: + # Classic compute - use flytetools entrypoint by default + git_source = user_git_source if user_git_source else default_git_source + python_file = user_python_file if user_python_file else default_classic_python_file + + spark_python_task = { + "python_file": python_file, + "source": "GIT", + "parameters": container.args, + } + + _configure_classic_cluster(databricks_job, custom, container, envs) + databricks_job["spark_python_task"] = spark_python_task + + # Set git_source (remove from user config if it was there to avoid duplication) + databricks_job.pop("git_source", None) + databricks_job.pop("python_file", None) + databricks_job["git_source"] = git_source return databricks_job diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 5801d24fde..c8b09447dd 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -73,17 +73,123 @@ def __post_init__(self): class DatabricksV2(Spark): """ Use this to configure a Databricks task. Task's marked with this will automatically execute - natively onto databricks platform as a distributed execution of spark + natively onto databricks platform as a distributed execution of spark. + + Supports both classic compute (clusters) and serverless compute. Args: databricks_conf: Databricks job configuration compliant with API version 2.1, supporting 2.0 use cases. 
- For the configuration structure, visit here.https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure - For updates in API 2.1, refer to: https://docs.databricks.com/en/workflows/jobs/jobs-api-updates.html + For the configuration structure, visit: https://docs.databricks.com/dev-tools/api/2.0/jobs.html#request-structure + For updates in API 2.1, refer to: https://docs.databricks.com/en/workflows/jobs/jobs-api-updates.html databricks_instance: Domain name of your deployment. Use the form .cloud.databricks.com. + + Compute Modes: + The connector auto-detects the compute mode based on the databricks_conf contents: + + 1. Classic Compute (existing cluster): + Provide `existing_cluster_id` in databricks_conf. + + 2. Classic Compute (new cluster): + Provide `new_cluster` configuration in databricks_conf. + + 3. Serverless Compute (pre-configured environment): + Provide `environment_key` referencing a pre-configured environment in Databricks. + Do not include `existing_cluster_id` or `new_cluster`. + + 4. Serverless Compute (inline environment spec): + Provide `environments` array with environment specifications. + Optionally include `environment_key` to specify which environment to use. + Do not include `existing_cluster_id` or `new_cluster`. + + Example - Classic Compute with new cluster:: + + DatabricksV2( + databricks_conf={ + "run_name": "my-spark-job", + "new_cluster": { + "spark_version": "13.3.x-scala2.12", + "node_type_id": "m5.xlarge", + "num_workers": 2, + }, + }, + databricks_instance="my-workspace.cloud.databricks.com", + ) + + Example - Serverless Compute with pre-configured environment:: + + DatabricksV2( + databricks_conf={ + "run_name": "my-serverless-job", + "environment_key": "my-preconfigured-env", + }, + databricks_instance="my-workspace.cloud.databricks.com", + ) + + Example - Serverless Compute with inline environment spec:: + + DatabricksV2( + databricks_conf={ + "run_name": "my-serverless-job", + "environment_key": "default", + "environments": [{ + "environment_key": "default", + "spec": { + "client": "1", + "dependencies": ["pandas==2.0.0", "numpy==1.24.0"], + } + }], + }, + databricks_instance="my-workspace.cloud.databricks.com", + ) + + Note: + Serverless compute has certain limitations compared to classic compute: + - Only Python and SQL are supported (no Scala or R) + - Only Spark Connect APIs are supported (no RDD APIs) + - Must use Unity Catalog for external data sources + - No support for compute-scoped init scripts or libraries + For full details, see: https://docs.databricks.com/en/compute/serverless/limitations.html + + Serverless Entrypoint: + Both classic and serverless use the same ``flytetools`` repo for their entrypoints. + Classic uses ``flytekitplugins/databricks/entrypoint.py`` and serverless uses + ``flytekitplugins/databricks/entrypoint_serverless.py``. No additional configuration needed. + + To override the default, provide ``git_source`` and ``python_file`` in ``databricks_conf``. + + AWS Credentials for Serverless: + Databricks serverless does not provide AWS credentials via instance metadata. + To access S3 (for Flyte data), configure a Databricks Service Credential. + + The provider name is resolved in this order: + 1. ``databricks_service_credential_provider`` in the task config (per-task override) + 2. 
``FLYTE_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER`` environment variable on the connector (default for all tasks) + + The entrypoint will use this to obtain AWS credentials via: + dbutils.credentials.getServiceCredentialsProvider(provider_name) + + Notebook Support: + To run a Databricks notebook instead of a Python file, set `notebook_path`. + Parameters can be passed via `notebook_base_parameters`. + + Example - Running a notebook:: + + DatabricksV2( + databricks_conf={ + "run_name": "my-notebook-job", + "new_cluster": {...}, + }, + databricks_instance="my-workspace.cloud.databricks.com", + notebook_path="/Users/user@example.com/my-notebook", + notebook_base_parameters={"param1": "value1"}, + ) """ databricks_conf: Optional[Dict[str, Union[str, dict]]] = None - databricks_instance: Optional[str] = None + databricks_instance: Optional[str] = None # Falls back to FLYTE_DATABRICKS_INSTANCE env var + databricks_service_credential_provider: Optional[str] = None # Falls back to FLYTE_DATABRICKS_SERVICE_CREDENTIAL_PROVIDER env var + notebook_path: Optional[str] = None # Path to Databricks notebook (e.g., "/Users/user@example.com/notebook") + notebook_base_parameters: Optional[Dict[str, str]] = None # Parameters to pass to the notebook # This method does not reset the SparkSession since it's a bit hard to handle multiple @@ -187,7 +293,20 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: job._databricks_conf = cfg.databricks_conf job._databricks_instance = cfg.databricks_instance - return MessageToDict(job.to_flyte_idl()) + # Serialize to dict + custom_dict = MessageToDict(job.to_flyte_idl()) + + # Add DatabricksV2-specific fields (not part of protobuf) + if isinstance(self.task_config, DatabricksV2): + cfg = cast(DatabricksV2, self.task_config) + if cfg.databricks_service_credential_provider: + custom_dict['databricksServiceCredentialProvider'] = cfg.databricks_service_credential_provider + if cfg.notebook_path: + custom_dict['notebookPath'] = cfg.notebook_path + if cfg.notebook_base_parameters: + custom_dict['notebookBaseParameters'] = cfg.notebook_base_parameters + + return custom_dict def to_k8s_pod(self, pod_template: Optional[PodTemplate] = None) -> Optional[K8sPod]: """ @@ -210,10 +329,105 @@ def to_k8s_pod(self, pod_template: Optional[PodTemplate] = None) -> Optional[K8s return K8sPod.from_pod_template(pod_template) + def _is_databricks_serverless(self) -> bool: + """ + Detect if we're running in Databricks serverless environment. + + Serverless uses Spark Connect and requires different SparkSession handling. + """ + # Check for explicit serverless markers set by our entrypoint + if os.environ.get("DATABRICKS_SERVERLESS") == "true": + return True + if os.environ.get("SPARK_CONNECT_MODE") == "true": + return True + + # Check for Databricks serverless indicators + # 1. DATABRICKS_RUNTIME_VERSION exists (Databricks environment) + # 2. 
No SPARK_HOME (serverless doesn't have traditional Spark) + is_databricks = "DATABRICKS_RUNTIME_VERSION" in os.environ + + # Additional check: if using DatabricksV2 with serverless config + if isinstance(self.task_config, DatabricksV2): + conf = self.task_config.databricks_conf or {} + has_serverless_config = ( + "environment_key" in conf or + "environments" in conf + ) and "new_cluster" not in conf and "existing_cluster_id" not in conf + if has_serverless_config: + return True + + return is_databricks and "SPARK_HOME" not in os.environ + + def _get_databricks_serverless_spark_session(self): + """ + Get SparkSession in Databricks serverless environment. + + The entrypoint injects the SparkSession into: + 1. Custom module '_flyte_spark_session' in sys.modules (most reliable) + 2. builtins.spark (backup) + + Returns: + SparkSession or None if not available + """ + import sys + + # Method 1: Try custom module (most reliable - survives module reloads) + try: + if '_flyte_spark_session' in sys.modules: + spark_module = sys.modules['_flyte_spark_session'] + if hasattr(spark_module, 'spark') and spark_module.spark is not None: + logger.info(f"Got SparkSession from _flyte_spark_session module") + return spark_module.spark + except Exception as e: + logger.debug(f"Could not get spark from _flyte_spark_session: {e}") + + # Method 2: Try builtins (backup location) + try: + import builtins + if hasattr(builtins, 'spark') and builtins.spark is not None: + logger.info(f"Got SparkSession from builtins") + return builtins.spark + except Exception as e: + logger.debug(f"Could not get spark from builtins: {e}") + + # Method 3: Try __main__ module + try: + import __main__ + if hasattr(__main__, 'spark') and __main__.spark is not None: + logger.info(f"Got SparkSession from __main__") + return __main__.spark + except Exception as e: + logger.debug(f"Could not get spark from __main__: {e}") + + # Method 4: Try active session + try: + from pyspark.sql import SparkSession + active = SparkSession.getActiveSession() + if active: + logger.info(f"Got active SparkSession") + return active + except Exception as e: + logger.debug(f"Could not get active SparkSession: {e}") + + logger.warning("Could not obtain SparkSession in serverless environment") + return None + def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: import pyspark as _pyspark ctx = FlyteContextManager.current_context() + + # Databricks serverless uses Spark Connect - SparkSession is pre-configured + if self._is_databricks_serverless(): + logger.info("Detected Databricks serverless environment - using pre-configured SparkSession") + self.sess = self._get_databricks_serverless_spark_session() + + if self.sess is None: + logger.warning("No SparkSession available - task will run without Spark") + + return user_params.builder().add_attr("SPARK_SESSION", self.sess).build() + + # Standard Spark session creation for non-serverless environments sess_builder = _pyspark.sql.SparkSession.builder.appName(f"FlyteSpark: {user_params.execution_id}") if not (ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION): # If either of above cases is not true, then we are in local execution of this task diff --git a/plugins/flytekit-spark/tests/test_connector.py b/plugins/flytekit-spark/tests/test_connector.py index 5136d39ce8..b2828de9fa 100644 --- a/plugins/flytekit-spark/tests/test_connector.py +++ b/plugins/flytekit-spark/tests/test_connector.py @@ -8,8 +8,15 @@ from flyteidl.core.execution_pb2 import TaskExecution 
from flytekit.core.constants import FLYTE_FAIL_ON_ERROR -from flytekitplugins.spark.connector import DATABRICKS_API_ENDPOINT, DatabricksJobMetadata, get_header, \ - _get_databricks_job_spec, DEFAULT_DATABRICKS_INSTANCE_ENV_KEY +from flytekitplugins.spark.connector import ( + DATABRICKS_API_ENDPOINT, + DatabricksJobMetadata, + get_header, + _get_databricks_job_spec, + _is_serverless_config, + _configure_serverless, + DEFAULT_DATABRICKS_INSTANCE_ENV_KEY, +) from flytekit.extend.backend.base_agent import AgentRegistry from flytekit.interfaces.cli_identifiers import Identifier @@ -192,3 +199,475 @@ async def test_agent_create_with_default_instance(task_template: TaskTemplate): assert res == databricks_metadata mock.patch.stopall() + + +# ==================== Serverless Compute Tests ==================== + + +@pytest.fixture(scope="function") +def serverless_task_template_with_env_key() -> TaskTemplate: + """Task template configured for serverless with pre-configured environment_key.""" + task_id = Identifier( + resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version" + ) + task_metadata = task.TaskMetadata( + True, + task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timedelta(days=1), + literals.RetryStrategy(3), + True, + "0.1.1b0", + "This is deprecated!", + True, + "A", + (), + ) + task_config = { + "sparkConf": {}, + "mainApplicationFile": "dbfs:/entrypoint.py", + "databricksConf": { + "run_name": "flytekit serverless job", + "environment_key": "my-preconfigured-env", + "timeout_seconds": 3600, + "git_source": { + "git_url": "https://github.com/test-org/test-repo", + "git_provider": "gitHub", + "git_branch": "main", + }, + "python_file": "entrypoint_serverless.py", + } + } + container = Container( + image="flyteorg/flytekit:databricks-0.18.0-py3.7", + command=[], + args=["pyflyte-execute", "--inputs", "s3://my-s3-bucket"], + resources=Resources(requests=[], limits=[]), + env={"foo": "bar"}, + config={}, + ) + + return TaskTemplate( + id=task_id, + custom=task_config, + metadata=task_metadata, + container=container, + interface=None, + type="spark", + ) + + +@pytest.fixture(scope="function") +def serverless_task_template_with_inline_env() -> TaskTemplate: + """Task template configured for serverless with inline environments spec.""" + task_id = Identifier( + resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version" + ) + task_metadata = task.TaskMetadata( + True, + task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timedelta(days=1), + literals.RetryStrategy(3), + True, + "0.1.1b0", + "This is deprecated!", + True, + "A", + (), + ) + task_config = { + "sparkConf": {}, + "mainApplicationFile": "dbfs:/entrypoint.py", + "databricksConf": { + "run_name": "flytekit serverless job with inline env", + "environment_key": "default", + "environments": [{ + "environment_key": "default", + "spec": { + "client": "1", + "dependencies": ["pandas==2.0.0"], + } + }], + "timeout_seconds": 3600, + "git_source": { + "git_url": "https://github.com/test-org/test-repo", + "git_provider": "gitHub", + "git_branch": "main", + }, + "python_file": "entrypoint_serverless.py", + } + } + container = Container( + image="flyteorg/flytekit:databricks-0.18.0-py3.7", + command=[], + args=["pyflyte-execute", "--inputs", "s3://my-s3-bucket"], + resources=Resources(requests=[], limits=[]), + env={"foo": "bar", "MY_VAR": "my_value"}, + config={}, + ) + + return TaskTemplate( 
+ id=task_id, + custom=task_config, + metadata=task_metadata, + container=container, + interface=None, + type="spark", + ) + + +@pytest.fixture(scope="function") +def serverless_task_template_no_git_source() -> TaskTemplate: + """Task template for serverless without git_source - relies on connector env vars.""" + task_id = Identifier( + resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version" + ) + task_metadata = task.TaskMetadata( + True, + task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timedelta(days=1), + literals.RetryStrategy(3), + True, + "0.1.1b0", + "This is deprecated!", + True, + "A", + (), + ) + task_config = { + "sparkConf": {}, + "mainApplicationFile": "dbfs:/entrypoint.py", + "databricksConf": { + "run_name": "flytekit serverless job - no git source", + "environment_key": "default", + "environments": [{ + "environment_key": "default", + "spec": { + "client": "4", + } + }], + "timeout_seconds": 3600, + } + } + container = Container( + image="flyteorg/flytekit:databricks-0.18.0-py3.7", + command=[], + args=["pyflyte-execute", "--inputs", "s3://my-s3-bucket"], + resources=Resources(requests=[], limits=[]), + env={"foo": "bar"}, + config={}, + ) + + return TaskTemplate( + id=task_id, + custom=task_config, + metadata=task_metadata, + container=container, + interface=None, + type="spark", + ) + + +@pytest.fixture(scope="function") +def invalid_task_template_no_compute() -> TaskTemplate: + """Task template with no cluster or environment config - should fail.""" + task_id = Identifier( + resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version" + ) + task_metadata = task.TaskMetadata( + True, + task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timedelta(days=1), + literals.RetryStrategy(3), + True, + "0.1.1b0", + "This is deprecated!", + True, + "A", + (), + ) + task_config = { + "sparkConf": {}, + "mainApplicationFile": "dbfs:/entrypoint.py", + "databricksConf": { + "run_name": "invalid job - no compute config", + "timeout_seconds": 3600, + } + } + container = Container( + image="flyteorg/flytekit:databricks-0.18.0-py3.7", + command=[], + args=["pyflyte-execute", "--inputs", "s3://my-s3-bucket"], + resources=Resources(requests=[], limits=[]), + env={"foo": "bar"}, + config={}, + ) + + return TaskTemplate( + id=task_id, + custom=task_config, + metadata=task_metadata, + container=container, + interface=None, + type="spark", + ) + + +def test_is_serverless_config_detection(): + """Test the serverless configuration detection logic.""" + # Classic compute with existing_cluster_id + assert _is_serverless_config({"existing_cluster_id": "abc123"}) is False + + # Classic compute with new_cluster + assert _is_serverless_config({"new_cluster": {"spark_version": "13.3"}}) is False + + # Serverless with environment_key only + assert _is_serverless_config({"environment_key": "my-env"}) is True + + # Serverless with environments array + assert _is_serverless_config({"environments": [{"environment_key": "default"}]}) is True + + # Serverless with both environment_key and environments + assert _is_serverless_config({ + "environment_key": "default", + "environments": [{"environment_key": "default"}] + }) is True + + # No compute config at all + assert _is_serverless_config({"run_name": "test"}) is False + + # Has cluster AND environment (cluster takes precedence, not serverless) + assert _is_serverless_config({ + "new_cluster": 
{"spark_version": "13.3"}, + "environment_key": "my-env" + }) is False + + +def test_configure_serverless_with_env_key_only(): + """Test serverless configuration with environment_key only (no environments array).""" + databricks_job = {"environment_key": "my-env"} + envs = {"FOO": "bar", FLYTE_FAIL_ON_ERROR: "true"} + + result_key = _configure_serverless(databricks_job, envs) + + assert result_key == "my-env" + # Databricks serverless requires environments array - it should be auto-created + assert "environments" in databricks_job + assert len(databricks_job["environments"]) == 1 + assert databricks_job["environments"][0]["environment_key"] == "my-env" + # Environment variables should be injected + env_vars = databricks_job["environments"][0]["spec"]["environment_vars"] + assert env_vars["FOO"] == "bar" + assert env_vars[FLYTE_FAIL_ON_ERROR] == "true" + # environment_key should be removed from top level + assert "environment_key" not in databricks_job + + +def test_configure_serverless_with_inline_env(): + """Test serverless configuration with inline environment spec.""" + databricks_job = { + "environment_key": "default", + "environments": [{ + "environment_key": "default", + "spec": { + "client": "1", + "dependencies": ["pandas==2.0.0"], + } + }] + } + envs = {"FOO": "bar", FLYTE_FAIL_ON_ERROR: "true"} + + result_key = _configure_serverless(databricks_job, envs) + + assert result_key == "default" + # Environment variables should be injected + env_vars = databricks_job["environments"][0]["spec"]["environment_vars"] + assert env_vars["FOO"] == "bar" + assert env_vars[FLYTE_FAIL_ON_ERROR] == "true" + # environment_key should be removed from top level + assert "environment_key" not in databricks_job + + +def test_configure_serverless_creates_default_env(): + """Test that serverless creates a default environment when no environment specified.""" + databricks_job = {} # No environment_key or environments + envs = {"FOO": "bar"} + + result_key = _configure_serverless(databricks_job, envs) + + assert result_key == "default" + assert len(databricks_job["environments"]) == 1 + assert databricks_job["environments"][0]["environment_key"] == "default" + # Should have env vars injected + assert databricks_job["environments"][0]["spec"]["environment_vars"]["FOO"] == "bar" + + +def test_get_databricks_job_spec_serverless_with_env_key(serverless_task_template_with_env_key: TaskTemplate): + """Test job spec generation for serverless with environment_key only.""" + serverless_task_template_with_env_key.custom["databricksInstance"] = "test-account.cloud.databricks.com" + + spec = _get_databricks_job_spec(serverless_task_template_with_env_key) + + # Serverless uses multi-task format with tasks array + assert "tasks" in spec + assert len(spec["tasks"]) == 1 + + task_def = spec["tasks"][0] + assert task_def["task_key"] == "flyte_task" + assert task_def["environment_key"] == "my-preconfigured-env" + assert "spark_python_task" in task_def + + # Databricks serverless requires environments array - should be auto-created + assert "environments" in spec + assert len(spec["environments"]) == 1 + assert spec["environments"][0]["environment_key"] == "my-preconfigured-env" + + # Should NOT have spark_python_task at top level for serverless + assert "spark_python_task" not in spec + + # Should NOT have environment_key at top level (moved to task) + assert "environment_key" not in spec + + # Should NOT have cluster config + assert "new_cluster" not in spec + assert "existing_cluster_id" not in spec + + # Should have 
git_source + assert "git_source" in spec + + +def test_get_databricks_job_spec_serverless_with_inline_env(serverless_task_template_with_inline_env: TaskTemplate): + """Test job spec generation for serverless with inline environment spec.""" + serverless_task_template_with_inline_env.custom["databricksInstance"] = "test-account.cloud.databricks.com" + + spec = _get_databricks_job_spec(serverless_task_template_with_inline_env) + + # Serverless uses multi-task format with tasks array + assert "tasks" in spec + assert len(spec["tasks"]) == 1 + + task_def = spec["tasks"][0] + assert task_def["task_key"] == "flyte_task" + assert task_def["environment_key"] == "default" + assert "spark_python_task" in task_def + + # Should have environments array with injected env vars + assert "environments" in spec + env_vars = spec["environments"][0]["spec"]["environment_vars"] + assert env_vars["foo"] == "bar" + assert env_vars["MY_VAR"] == "my_value" + assert env_vars[FLYTE_FAIL_ON_ERROR] == "true" + + # Should NOT have cluster config + assert "new_cluster" not in spec + assert "existing_cluster_id" not in spec + + +def test_get_databricks_job_spec_error_no_compute(invalid_task_template_no_compute: TaskTemplate): + """Test that job spec generation fails when no compute config is provided.""" + with pytest.raises(ValueError) as exc_info: + _get_databricks_job_spec(invalid_task_template_no_compute) + + assert "existing_cluster_id" in str(exc_info.value) + assert "new_cluster" in str(exc_info.value) + assert "environment_key" in str(exc_info.value) + assert "environments" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_databricks_agent_serverless(serverless_task_template_with_env_key: TaskTemplate): + """Test the full agent flow with serverless compute.""" + import copy + agent = AgentRegistry.get_agent("spark") + + serverless_task_template_with_env_key.custom["databricksInstance"] = "test-account.cloud.databricks.com" + + # Generate spec BEFORE agent.create() mutates the template in-place + spec_copy = copy.deepcopy(serverless_task_template_with_env_key) + spec = _get_databricks_job_spec(spec_copy) + + # Verify serverless config uses multi-task format with environments + assert "tasks" in spec + task_def = spec["tasks"][0] + assert task_def["task_key"] == "flyte_task" + assert task_def["environment_key"] == "my-preconfigured-env" + assert "spark_python_task" in task_def + assert "environments" in spec # Required for serverless + assert "new_cluster" not in spec + + mocked_token = "mocked_databricks_token" + mocked_context = mock.patch("flytekit.current_context", autospec=True).start() + mocked_context.return_value.secrets.get.return_value = mocked_token + + databricks_metadata = DatabricksJobMetadata( + databricks_instance="test-account.cloud.databricks.com", + run_id="456", + ) + + mock_create_response = {"run_id": "456"} + mock_get_response = { + "job_id": "2", + "run_id": "456", + "state": {"life_cycle_state": "TERMINATED", "result_state": "SUCCESS", "state_message": "OK"}, + } + + create_url = f"https://test-account.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/submit" + get_url = f"https://test-account.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/get?run_id=456" + + with aioresponses() as mocked: + mocked.post(create_url, status=http.HTTPStatus.OK, payload=mock_create_response) + res = await agent.create(serverless_task_template_with_env_key, None) + assert res == databricks_metadata + + mocked.get(get_url, status=http.HTTPStatus.OK, payload=mock_get_response) + resource = 
await agent.get(databricks_metadata) + assert resource.phase == TaskExecution.SUCCEEDED + + mock.patch.stopall() + + +# ==================== Default Serverless Entrypoint Tests ==================== + + +def test_serverless_default_entrypoint_from_flytetools(serverless_task_template_no_git_source: TaskTemplate): + """Test that serverless uses the default flytetools entrypoint when no git_source in task config.""" + spec = _get_databricks_job_spec(serverless_task_template_no_git_source) + + # Should use the same flytetools repo as classic + assert spec["git_source"]["git_url"] == "https://github.com/flyteorg/flytetools" + assert spec["git_source"]["git_provider"] == "gitHub" + assert "git_commit" in spec["git_source"] + + # Should use the serverless-specific python_file + task_def = spec["tasks"][0] + assert task_def["spark_python_task"]["python_file"] == "flytekitplugins/databricks/entrypoint_serverless.py" + + # Should still be valid serverless format + assert "environments" in spec + assert "new_cluster" not in spec + + +def test_serverless_task_git_source_overrides_default(serverless_task_template_with_env_key: TaskTemplate): + """Test that task-level git_source takes precedence over the flytetools default.""" + spec = _get_databricks_job_spec(serverless_task_template_with_env_key) + + # Should use the task-level git_source, NOT the flytetools default + assert spec["git_source"]["git_url"] == "https://github.com/test-org/test-repo" + assert spec["git_source"]["git_branch"] == "main" + + # Should use the task-level python_file + task_def = spec["tasks"][0] + assert task_def["spark_python_task"]["python_file"] == "entrypoint_serverless.py" + + +def test_classic_and_serverless_use_same_repo(task_template: TaskTemplate, serverless_task_template_no_git_source: TaskTemplate): + """Test that both classic and serverless default to the same flytetools repo.""" + classic_spec = _get_databricks_job_spec(task_template) + serverless_spec = _get_databricks_job_spec(serverless_task_template_no_git_source) + + # Same repo + assert classic_spec["git_source"]["git_url"] == serverless_spec["git_source"]["git_url"] + # Same commit + assert classic_spec["git_source"]["git_commit"] == serverless_spec["git_source"]["git_commit"] + # Different python_file + assert classic_spec["spark_python_task"]["python_file"] == "flytekitplugins/databricks/entrypoint.py" + assert serverless_spec["tasks"][0]["spark_python_task"]["python_file"] == "flytekitplugins/databricks/entrypoint_serverless.py" diff --git a/plugins/flytekit-spark/tests/test_spark_task.py b/plugins/flytekit-spark/tests/test_spark_task.py index 7198a4dec0..76c229766c 100644 --- a/plugins/flytekit-spark/tests/test_spark_task.py +++ b/plugins/flytekit-spark/tests/test_spark_task.py @@ -1,3 +1,4 @@ +import os import os.path from unittest import mock @@ -493,3 +494,216 @@ def my_spark(a: str) -> int: configs = my_spark.sess.sparkContext.getConf().getAll() assert ("spark.driver.memory", "1000M") in configs assert ("spark.app.name", "FlyteSpark: ex:local:local:local") in configs + + +# ==================== Serverless Detection Tests ==================== + + +def test_databricks_v2_serverless_detection_with_env_var(reset_spark_session): + """Test that serverless is detected when DATABRICKS_SERVERLESS env var is set.""" + databricks_conf = { + "run_name": "test", + "new_cluster": {"spark_version": "13.3.x-scala2.12"}, # Has cluster config + } + + @task( + task_config=DatabricksV2( + databricks_conf=databricks_conf, + 
databricks_instance="test.cloud.databricks.com", + ) + ) + def my_task(a: int) -> int: + return a + + # Without env var, should NOT be serverless (has new_cluster) + assert my_task._is_databricks_serverless() is False + + # With env var set, should BE serverless + os.environ["DATABRICKS_SERVERLESS"] = "true" + try: + assert my_task._is_databricks_serverless() is True + finally: + del os.environ["DATABRICKS_SERVERLESS"] + + +def test_databricks_v2_serverless_detection_with_config(reset_spark_session): + """Test that serverless is detected based on DatabricksV2 config.""" + # Serverless config: has environment_key, no cluster config + serverless_conf = { + "run_name": "serverless-test", + "environment_key": "my-env", + } + + @task( + task_config=DatabricksV2( + databricks_conf=serverless_conf, + databricks_instance="test.cloud.databricks.com", + ) + ) + def serverless_task(a: int) -> int: + return a + + # Should detect serverless from config + assert serverless_task._is_databricks_serverless() is True + + # Classic config: has new_cluster + classic_conf = { + "run_name": "classic-test", + "new_cluster": {"spark_version": "13.3.x-scala2.12"}, + } + + @task( + task_config=DatabricksV2( + databricks_conf=classic_conf, + databricks_instance="test.cloud.databricks.com", + ) + ) + def classic_task(a: int) -> int: + return a + + # Should NOT detect serverless + assert classic_task._is_databricks_serverless() is False + + +def test_databricks_v2_serverless_detection_with_environments_array(reset_spark_session): + """Test serverless detection with inline environments array.""" + serverless_conf = { + "run_name": "serverless-inline", + "environments": [{ + "environment_key": "default", + "spec": {"client": "1", "dependencies": ["pandas"]} + }], + } + + @task( + task_config=DatabricksV2( + databricks_conf=serverless_conf, + databricks_instance="test.cloud.databricks.com", + ) + ) + def serverless_task(a: int) -> int: + return a + + assert serverless_task._is_databricks_serverless() is True + + +def test_databricks_v2_classic_not_detected_as_serverless(reset_spark_session): + """Test that classic compute is not incorrectly detected as serverless.""" + # Classic with existing_cluster_id + existing_cluster_conf = { + "run_name": "existing-cluster", + "existing_cluster_id": "abc-123", + } + + @task( + task_config=DatabricksV2( + databricks_conf=existing_cluster_conf, + databricks_instance="test.cloud.databricks.com", + ) + ) + def existing_cluster_task(a: int) -> int: + return a + + assert existing_cluster_task._is_databricks_serverless() is False + + # Classic with new_cluster AND environment_key (cluster takes precedence) + mixed_conf = { + "run_name": "mixed", + "new_cluster": {"spark_version": "13.3.x-scala2.12"}, + "environment_key": "my-env", # Should be ignored + } + + @task( + task_config=DatabricksV2( + databricks_conf=mixed_conf, + databricks_instance="test.cloud.databricks.com", + ) + ) + def mixed_task(a: int) -> int: + return a + + assert mixed_task._is_databricks_serverless() is False + + +def test_databricks_v2_service_credential_provider(): + """Test that service credential provider is properly serialized.""" + serverless_conf = { + "run_name": "serverless-with-creds", + "environment_key": "my-env", + } + + @task( + task_config=DatabricksV2( + databricks_conf=serverless_conf, + databricks_instance="test.cloud.databricks.com", + databricks_service_credential_provider="my-credential-provider", + ) + ) + def task_with_creds(a: int) -> int: + return a + + default_img = Image(name="default", 
fqn="test", tag="tag") + settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env={}, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + + custom = task_with_creds.get_custom(settings) + assert custom.get("databricksServiceCredentialProvider") == "my-credential-provider" + + +def test_databricks_v2_no_service_credential_provider(): + """Test that custom dict doesn't have credential provider when not set.""" + serverless_conf = { + "run_name": "serverless-no-creds", + "environment_key": "my-env", + } + + @task( + task_config=DatabricksV2( + databricks_conf=serverless_conf, + databricks_instance="test.cloud.databricks.com", + ) + ) + def task_no_creds(a: int) -> int: + return a + + default_img = Image(name="default", fqn="test", tag="tag") + settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env={}, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + + custom = task_no_creds.get_custom(settings) + assert "databricksServiceCredentialProvider" not in custom + + +def test_spark_classic_not_affected_by_serverless_code(reset_spark_session): + """Test that regular Spark tasks (non-Databricks) are not affected by serverless code.""" + @task( + task_config=Spark( + spark_conf={"spark.driver.memory": "512M"}, + ) + ) + def spark_task(a: int) -> int: + return a + + # Regular Spark task should NOT be detected as serverless + assert spark_task._is_databricks_serverless() is False + + # pre_execute should work normally + pb = ExecutionParameters.new_builder() + pb.working_dir = "/tmp" + pb.execution_id = "ex:local:local:local" + p = pb.build() + new_p = spark_task.pre_execute(p) + + assert new_p is not None + assert new_p.has_attr("SPARK_SESSION") + assert spark_task.sess is not None