From eece1045b223d263c00099f6dd5565e32fe5842e Mon Sep 17 00:00:00 2001 From: "Alexander Nicholson 4584443+DragonStuff@users.noreply.github.com" <4584443+DragonStuff@users.noreply.github.com> Date: Sat, 18 Oct 2025 13:16:20 +0900 Subject: [PATCH 1/5] feature: implement replica groups in service configurations type: service port: 8000 commands: ["python app.py"] replica_groups: - name: l40s-gpu replicas: 1..3 # autoscalable resources: gpu: L40S - name: h100-gpu replicas: 2 # fixed regions: [us-east] resources: gpu: H100 - Added the ability to define multiple replica groups with distinct configurations, including resource requirements and autoscaling behavior. - Updated relevant documentation to reflect the new replica groups feature. - Enhanced CLI output to display job plans with group names for better clarity. - Ensured backward compatibility by excluding replica groups from JSON when not set. - Added tests to validate the functionality and backward compatibility of replica groups. This change allows for more flexible service configurations, enabling users to manage different types of resources and scaling strategies within a single service. --- contributing/AUTOSCALING.md | 2 + contributing/RUNS-AND-JOBS.md | 2 +- docs/docs/concepts/services.md | 60 +++ docs/docs/reference/dstack.yml/service.md | 16 + src/dstack/_internal/cli/utils/run.py | 184 ++++++-- .../_internal/core/compatibility/runs.py | 3 + .../_internal/core/models/configurations.py | 91 +++- src/dstack/_internal/core/models/runs.py | 40 +- .../server/background/tasks/process_runs.py | 5 +- .../tasks/process_submitted_jobs.py | 49 +- ...3d4e5f6_add_jobmodel_replica_group_name.py | 27 ++ src/dstack/_internal/server/models.py | 1 + .../server/services/jobs/__init__.py | 17 +- .../services/jobs/configurators/base.py | 43 +- src/dstack/_internal/server/services/runs.py | 277 +++++++++-- .../server/services/services/autoscalers.py | 22 +- src/dstack/_internal/server/testing/common.py | 2 + .../cli/utils/test_run_plan_display.py | 424 +++++++++++++++++ .../core/models/test_replica_groups.py | 437 ++++++++++++++++++ .../core/test_backward_compatibility.py | 130 ++++++ .../services/test_get_plan_replica_groups.py | 214 +++++++++ .../services/test_replica_groups_scaling.py | 389 ++++++++++++++++ 22 files changed, 2330 insertions(+), 105 deletions(-) create mode 100644 src/dstack/_internal/server/migrations/versions/a1b2c3d4e5f6_add_jobmodel_replica_group_name.py create mode 100644 src/tests/_internal/cli/utils/test_run_plan_display.py create mode 100644 src/tests/_internal/core/models/test_replica_groups.py create mode 100644 src/tests/_internal/core/test_backward_compatibility.py create mode 100644 src/tests/_internal/server/services/test_get_plan_replica_groups.py create mode 100644 src/tests/_internal/server/services/test_replica_groups_scaling.py diff --git a/contributing/AUTOSCALING.md b/contributing/AUTOSCALING.md index 7fa987aaf..60a33aec1 100644 --- a/contributing/AUTOSCALING.md +++ b/contributing/AUTOSCALING.md @@ -11,6 +11,8 @@ - STEP 7: `scale_run_replicas` terminates or starts replicas. - `SUBMITTED` and `PROVISIONING` replicas get terminated before `RUNNING`. - Replicas are terminated by descending `replica_num` and launched by ascending `replica_num`. + - For services with `replica_groups`, only groups with autoscaling ranges (min != max) participate in scaling. + - Scale operations respect per-group minimum and maximum constraints. ## RPSAutoscaler diff --git a/contributing/RUNS-AND-JOBS.md b/contributing/RUNS-AND-JOBS.md index b2c0430af..18544caa5 100644 --- a/contributing/RUNS-AND-JOBS.md +++ b/contributing/RUNS-AND-JOBS.md @@ -13,7 +13,7 @@ Runs are created from run configurations. There are three types of run configura 2. `task` — runs the user's bash script until completion. 3. `service` — runs the user's bash script and exposes a port through [dstack-proxy](PROXY.md). -A run can spawn one or multiple jobs, depending on the configuration. A task that specifies multiple `nodes` spawns a job for every node (a multi-node task). A service that specifies multiple `replicas` spawns a job for every replica. A job submission is always assigned to one particular instance. If a job fails and the configuration allows retrying, the server creates a new job submission for the job. +A run can spawn one or multiple jobs, depending on the configuration. A task that specifies multiple `nodes` spawns a job for every node (a multi-node task). A service that specifies multiple `replicas` or `replica_groups` spawns a job for every replica. Each job in a replica group is tagged with `replica_group_name` to track which group it belongs to. A job submission is always assigned to one particular instance. If a job fails and the configuration allows retrying, the server creates a new job submission for the job. ## Run's Lifecycle diff --git a/docs/docs/concepts/services.md b/docs/docs/concepts/services.md index cb2649e00..383738538 100644 --- a/docs/docs/concepts/services.md +++ b/docs/docs/concepts/services.md @@ -160,6 +160,66 @@ Setting the minimum number of replicas to `0` allows the service to scale down t > The `scaling` property requires creating a [gateway](gateways.md). +### Replica Groups (Advanced) + +For advanced use cases, you can define multiple **replica groups** with different instance types, resources, and configurations within a single service. This is useful when you want to: + +- Run different GPU types in the same service (e.g., H100 for primary, RTX5090 for overflow) +- Configure different backends or regions per replica type +- Set different autoscaling behavior per group + +
+ +```yaml +type: service +name: llama31-service + +python: 3.12 +env: + - HF_TOKEN +commands: + - uv pip install vllm + - vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct --max-model-len 4096 +port: 8000 + +# Define multiple replica groups with different configurations +replica_groups: + - name: primary + replicas: 1 # Always 1 H100 (fixed) + resources: + gpu: H100:1 + backends: [aws] + regions: [us-west-2] + + - name: overflow + replicas: 0..5 # Autoscales 0-5 RTX5090s + resources: + gpu: RTX5090:1 + backends: [runpod] + +scaling: + metric: rps + target: 10 +``` + +
+ +In this example: + +- The `primary` group always runs 1 H100 replica on AWS (fixed, never scaled) +- The `overflow` group scales 0-5 RTX5090 replicas on RunPod based on load +- Scale operations only affect groups with autoscaling ranges (min != max) + +Each replica group can override any [profile parameter](../reference/profiles.yml.md) including `backends`, `regions`, `instance_types`, `spot_policy`, etc. Group-level settings override service-level settings. + +> **Note:** When using `replica_groups`, you cannot use the simple `replicas` field. They are mutually exclusive. + +**When to use replica groups:** + +- You need different GPU types in the same service +- Different replicas should run in different regions or clouds +- Some replicas should be fixed while others autoscale + ### Model If the service is running a chat model with an OpenAI-compatible interface, diff --git a/docs/docs/reference/dstack.yml/service.md b/docs/docs/reference/dstack.yml/service.md index 8d89b2d57..40509332e 100644 --- a/docs/docs/reference/dstack.yml/service.md +++ b/docs/docs/reference/dstack.yml/service.md @@ -10,6 +10,22 @@ The `service` configuration type allows running [services](../../concepts/servic type: required: true +### `replica_groups` + +Define multiple replica groups with different configurations within a single service. + +> **Note:** Cannot be used together with `replicas`. + +#### `replica_groups[n]` + +#SCHEMA# dstack._internal.core.models.configurations.ReplicaGroup + overrides: + show_root_heading: false + type: + required: true + +Each replica group inherits from [ProfileParams](../profiles.yml.md) and can override any profile parameter including `backends`, `regions`, `instance_types`, `spot_policy`, etc. + ### `model` { data-toc-label="model" } === "OpenAI" diff --git a/src/dstack/_internal/cli/utils/run.py b/src/dstack/_internal/cli/utils/run.py index 58497c084..ffcbef4f5 100644 --- a/src/dstack/_internal/cli/utils/run.py +++ b/src/dstack/_internal/cli/utils/run.py @@ -119,7 +119,32 @@ def th(s: str) -> str: if include_run_properties: props.add_row(th("Configuration"), run_spec.configuration_path) props.add_row(th("Type"), run_spec.configuration.type) - props.add_row(th("Resources"), pretty_req) + + from dstack._internal.core.models.configurations import ServiceConfiguration + + if ( + include_run_properties + and isinstance(run_spec.configuration, ServiceConfiguration) + and run_spec.configuration.replica_groups + ): + groups_info = [] + for group in run_spec.configuration.replica_groups: + group_parts = [f"[cyan]{group.name}[/cyan]"] + + if group.replicas.min == group.replicas.max: + group_parts.append(f"×{group.replicas.max}") + else: + group_parts.append(f"×{group.replicas.min}..{group.replicas.max}") + group_parts.append("[dim](autoscalable)[/dim]") + + group_parts.append(f"[dim]({group.resources.pretty_format()})[/dim]") + + groups_info.append(" ".join(group_parts)) + + props.add_row(th("Replica groups"), "\n".join(groups_info)) + else: + props.add_row(th("Resources"), pretty_req) + props.add_row(th("Spot policy"), spot_policy) props.add_row(th("Max price"), max_price) if include_run_properties: @@ -138,45 +163,130 @@ def th(s: str) -> str: offers.add_column("INSTANCE TYPE", style="grey58", no_wrap=True, ratio=2) offers.add_column("PRICE", style="grey58", ratio=1) offers.add_column() + + # For replica groups, show offers from all job plans + if len(run_plan.job_plans) > 1: + # Multiple jobs - aggregate offers from all groups + all_offers = [] + groups_with_no_offers = [] + total_offers_count = 0 + + for jp in run_plan.job_plans: + group_name = jp.job_spec.replica_group_name or "default" + if jp.total_offers == 0: + groups_with_no_offers.append(group_name) + for offer in jp.offers[:max_offers] if max_offers else jp.offers: + all_offers.append((group_name, offer)) + total_offers_count += jp.total_offers + + # Sort by price + all_offers.sort(key=lambda x: x[1].price) + if max_offers: + all_offers = all_offers[:max_offers] + + # Show groups with no offers FIRST + for group_name in groups_with_no_offers: + offers.add_row( + "", + f"[cyan]{group_name}[/cyan]:", + "[red]No matching instance offers available.[/red]\n" + "Possible reasons: https://dstack.ai/docs/guides/troubleshooting/#no-offers", + "", + "", + "", + style="secondary", + ) + + # Then show groups with offers + for i, (group_name, offer) in enumerate(all_offers, start=1): + r = offer.instance.resources - job_plan.offers = job_plan.offers[:max_offers] if max_offers else job_plan.offers + availability = "" + if offer.availability in { + InstanceAvailability.NOT_AVAILABLE, + InstanceAvailability.NO_QUOTA, + InstanceAvailability.IDLE, + InstanceAvailability.BUSY, + }: + availability = offer.availability.value.replace("_", " ").lower() + instance = offer.instance.name + if offer.total_blocks > 1: + instance += f" ({offer.blocks}/{offer.total_blocks})" + + # Add group name prefix for multi-group display + backend_display = f"[cyan]{group_name}[/cyan]: {offer.backend.replace('remote', 'ssh')} ({offer.region})" + + offers.add_row( + f"{i}", + backend_display, + r.pretty_format(include_spot=True), + instance, + f"${offer.price:.4f}".rstrip("0").rstrip("."), + availability, + style=None if i == 1 or not include_run_properties else "secondary", + ) + + if total_offers_count > len(all_offers): + offers.add_row("", "...", style="secondary") + else: + # Single job - original logic + job_plan.offers = job_plan.offers[:max_offers] if max_offers else job_plan.offers - for i, offer in enumerate(job_plan.offers, start=1): - r = offer.instance.resources + for i, offer in enumerate(job_plan.offers, start=1): + r = offer.instance.resources - availability = "" - if offer.availability in { - InstanceAvailability.NOT_AVAILABLE, - InstanceAvailability.NO_QUOTA, - InstanceAvailability.IDLE, - InstanceAvailability.BUSY, - }: - availability = offer.availability.value.replace("_", " ").lower() - instance = offer.instance.name - if offer.total_blocks > 1: - instance += f" ({offer.blocks}/{offer.total_blocks})" - offers.add_row( - f"{i}", - offer.backend.replace("remote", "ssh") + " (" + offer.region + ")", - r.pretty_format(include_spot=True), - instance, - f"${offer.price:.4f}".rstrip("0").rstrip("."), - availability, - style=None if i == 1 or not include_run_properties else "secondary", - ) - if job_plan.total_offers > len(job_plan.offers): - offers.add_row("", "...", style="secondary") + availability = "" + if offer.availability in { + InstanceAvailability.NOT_AVAILABLE, + InstanceAvailability.NO_QUOTA, + InstanceAvailability.IDLE, + InstanceAvailability.BUSY, + }: + availability = offer.availability.value.replace("_", " ").lower() + instance = offer.instance.name + if offer.total_blocks > 1: + instance += f" ({offer.blocks}/{offer.total_blocks})" + offers.add_row( + f"{i}", + offer.backend.replace("remote", "ssh") + " (" + offer.region + ")", + r.pretty_format(include_spot=True), + instance, + f"${offer.price:.4f}".rstrip("0").rstrip("."), + availability, + style=None if i == 1 or not include_run_properties else "secondary", + ) + if job_plan.total_offers > len(job_plan.offers): + offers.add_row("", "...", style="secondary") console.print(props) console.print() - if len(job_plan.offers) > 0: + + # Check if we have offers to display + has_offers = False + if len(run_plan.job_plans) > 1: + has_offers = any(len(jp.offers) > 0 for jp in run_plan.job_plans) + else: + has_offers = len(job_plan.offers) > 0 + + if has_offers: console.print(offers) - if job_plan.total_offers > len(job_plan.offers): - console.print( - f"[secondary] Shown {len(job_plan.offers)} of {job_plan.total_offers} offers, " - f"${job_plan.max_price:3f}".rstrip("0").rstrip(".") - + "max[/]" - ) + # Show summary for multi-job plans + if len(run_plan.job_plans) > 1: + if total_offers_count > len(all_offers): + max_price_overall = max((jp.max_price for jp in run_plan.job_plans if jp.max_price), default=None) + if max_price_overall: + console.print( + f"[secondary] Shown {len(all_offers)} of {total_offers_count} offers, " + f"${max_price_overall:3f}".rstrip("0").rstrip(".") + + " max[/]" + ) + else: + if job_plan.total_offers > len(job_plan.offers): + console.print( + f"[secondary] Shown {len(job_plan.offers)} of {job_plan.total_offers} offers, " + f"${job_plan.max_price:3f}".rstrip("0").rstrip(".") + + " max[/]" + ) console.print() else: console.print(NO_OFFERS_WARNING) @@ -233,8 +343,14 @@ def get_runs_table( if verbose and latest_job_submission.inactivity_secs: inactive_for = format_duration_multiunit(latest_job_submission.inactivity_secs) status += f" (inactive for {inactive_for})" + + job_name_parts = [f" replica={job.job_spec.replica_num}"] + if job.job_spec.replica_group_name: + job_name_parts.append(f"[cyan]group={job.job_spec.replica_group_name}[/cyan]") + job_name_parts.append(f"job={job.job_spec.job_num}") + job_row: Dict[Union[str, int], Any] = { - "NAME": f" replica={job.job_spec.replica_num} job={job.job_spec.job_num}" + "NAME": " ".join(job_name_parts) + ( f" deployment={latest_job_submission.deployment_num}" if show_deployment_num diff --git a/src/dstack/_internal/core/compatibility/runs.py b/src/dstack/_internal/core/compatibility/runs.py index a7a9d40a3..3c4b9c5c9 100644 --- a/src/dstack/_internal/core/compatibility/runs.py +++ b/src/dstack/_internal/core/compatibility/runs.py @@ -151,6 +151,9 @@ def get_run_spec_excludes(run_spec: RunSpec) -> IncludeExcludeDictType: configuration_excludes["schedule"] = True if profile is not None and profile.schedule is None: profile_excludes.add("schedule") + # Exclude replica_groups for backward compatibility with older servers + if isinstance(configuration, ServiceConfiguration) and configuration.replica_groups is None: + configuration_excludes["replica_groups"] = True configuration_excludes["repos"] = True if configuration_excludes: diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 6fe8132de..30e060207 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -685,6 +685,52 @@ class TaskConfiguration( type: Literal["task"] = "task" +class ReplicaGroupConfig(ProfileParamsConfig): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + ProfileParamsConfig.schema_extra(schema) + add_extra_schema_types( + schema["properties"]["replicas"], + extra_types=[{"type": "integer"}, {"type": "string"}], + ) + + +class ReplicaGroup(ProfileParams, generate_dual_core_model(ReplicaGroupConfig)): + """ + A replica group defines a set of service replicas with specific resource requirements + and provisioning parameters. + """ + + name: Annotated[str, Field(description="Group name (must be unique within the service)")] + replicas: Annotated[ + Range[int], + Field( + description="Number of replicas. Can be a fixed number (e.g., `2`) or a range (`1..3`). " + "If it's a range, the group can be autoscaled" + ), + ] + resources: Annotated[ + ResourcesSpec, + Field(description="Resource requirements for replicas in this group"), + ] + + @validator("name") + def validate_name(cls, v): + if not v or not v.strip(): + raise ValueError("Group name cannot be empty") + return v + + @validator("replicas") + def convert_replicas(cls, v: Range[int]) -> Range[int]: + if v.max is None: + raise ValueError("The maximum number of replicas is required") + if v.min is None: + v.min = 0 + if v.min < 0: + raise ValueError("The minimum number of replicas must be greater than or equal to 0") + return v + + class ServiceConfigurationParamsConfig(CoreConfig): @staticmethod def schema_extra(schema: Dict[str, Any]): @@ -754,6 +800,13 @@ class ServiceConfigurationParams(CoreModel): list[ProbeConfig], Field(description="List of probes used to determine job health"), ] = [] + replica_groups: Annotated[ + Optional[List[ReplicaGroup]], + Field( + description="Define multiple replica groups with different configurations. " + "Cannot be used together with 'replicas'" + ), + ] = None @validator("port") def convert_port(cls, v) -> PortMapping: @@ -789,14 +842,48 @@ def validate_gateway( ) return v + @root_validator() + def validate_replica_groups_xor_replicas(cls, values): + replica_groups = values.get("replica_groups") + replicas = values.get("replicas") + + # Check if user specified both + has_groups = replica_groups is not None + has_replicas = replicas != Range[int](min=1, max=1) + + if has_groups and has_replicas: + raise ValueError("Cannot specify both 'replicas' and 'replica_groups'") + + if has_groups: + # Validate unique names + names = [g.name for g in replica_groups] + if len(names) != len(set(names)): + raise ValueError("Replica group names must be unique") + + # Validate at least one group + if not replica_groups: + raise ValueError("replica_groups cannot be empty") + + return values + @root_validator() def validate_scaling(cls, values): scaling = values.get("scaling") replicas = values.get("replicas") - if replicas and replicas.min != replicas.max and not scaling: + replica_groups = values.get("replica_groups") + + if replica_groups: + # Check if any group has a range + has_range = any(g.replicas.min != g.replicas.max for g in replica_groups) + if has_range and not scaling: + raise ValueError( + "When any replica group has a range, 'scaling' must be specified" + ) + elif replicas and replicas.min != replicas.max and not scaling: raise ValueError("When you set `replicas` to a range, ensure to specify `scaling`.") - if replicas and replicas.min == replicas.max and scaling: + elif replicas and replicas.min == replicas.max and scaling: raise ValueError("To use `scaling`, `replicas` must be set to a range.") + return values @validator("rate_limits") diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index 0a5b174d2..7efa83902 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -1,11 +1,14 @@ from datetime import datetime, timedelta from enum import Enum -from typing import Any, Dict, List, Literal, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional from urllib.parse import urlparse from pydantic import UUID4, Field, root_validator from typing_extensions import Annotated +if TYPE_CHECKING: + from dstack._internal.core.models.configurations import ReplicaGroup, ServiceConfiguration + from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import ( ApplyAction, @@ -247,6 +250,7 @@ class ProbeSpec(CoreModel): class JobSpec(CoreModel): replica_num: int = 0 # default value for backward compatibility + replica_group_name: Optional[str] = None job_num: int job_name: str jobs_per_replica: int = 1 # default value for backward compatibility @@ -618,3 +622,37 @@ def get_service_port(job_spec: JobSpec, configuration: ServiceConfiguration) -> if job_spec.service_port is None: return configuration.port.container_port return job_spec.service_port + + +def get_normalized_replica_groups(configuration: "ServiceConfiguration") -> List["ReplicaGroup"]: + """ + Normalize service configuration to replica groups. + Converts legacy replicas field to a single "default" group for backward compatibility. + """ + from dstack._internal.core.models.configurations import ReplicaGroup + + if configuration.replica_groups: + return configuration.replica_groups + + return [ + ReplicaGroup( + name="default", + replicas=configuration.replicas, + resources=configuration.resources, + backends=configuration.backends, + regions=configuration.regions, + availability_zones=configuration.availability_zones, + instance_types=configuration.instance_types, + reservation=configuration.reservation, + spot_policy=configuration.spot_policy, + retry=configuration.retry, + max_duration=configuration.max_duration, + stop_duration=configuration.stop_duration, + max_price=configuration.max_price, + creation_policy=configuration.creation_policy, + idle_duration=configuration.idle_duration, + utilization_policy=configuration.utilization_policy, + startup_order=configuration.startup_order, + stop_criteria=configuration.stop_criteria, + ) + ] diff --git a/src/dstack/_internal/server/background/tasks/process_runs.py b/src/dstack/_internal/server/background/tasks/process_runs.py index df1cce72f..1bd30fb61 100644 --- a/src/dstack/_internal/server/background/tasks/process_runs.py +++ b/src/dstack/_internal/server/background/tasks/process_runs.py @@ -189,7 +189,10 @@ async def _process_pending_run(session: AsyncSession, run_model: RunModel): run_model.desired_replica_count = 1 if run.run_spec.configuration.type == "service": - run_model.desired_replica_count = run.run_spec.configuration.replicas.min or 0 + from dstack._internal.core.models.runs import get_normalized_replica_groups + + normalized_groups = get_normalized_replica_groups(run.run_spec.configuration) + run_model.desired_replica_count = sum(g.replicas.min or 0 for g in normalized_groups) await update_service_desired_replica_count( session, run_model, diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index 2814840b5..3c740e6a0 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -609,11 +609,14 @@ def _get_nodes_required_num_for_run(run_spec: RunSpec) -> int: nodes_required_num = 1 if run_spec.configuration.type == "task": nodes_required_num = run_spec.configuration.nodes - elif ( - run_spec.configuration.type == "service" - and run_spec.configuration.replicas.min is not None - ): - nodes_required_num = run_spec.configuration.replicas.min + elif run_spec.configuration.type == "service": + # Use groups if present + if run_spec.configuration.replica_groups: + from dstack._internal.core.models.runs import get_normalized_replica_groups + groups = get_normalized_replica_groups(run_spec.configuration) + nodes_required_num = sum(g.replicas.min or 0 for g in groups) + elif run_spec.configuration.replicas.min is not None: + nodes_required_num = run_spec.configuration.replicas.min return nodes_required_num @@ -728,6 +731,34 @@ async def _assign_job_to_fleet_instance( return instance +def _get_profile_for_job(run_spec: RunSpec, job: Job) -> Profile: + """Get merged profile with group overrides for this job.""" + from dstack._internal.core.models.profiles import Profile, ProfileParams + from dstack._internal.core.models.runs import get_normalized_replica_groups + + base_profile = run_spec.merged_profile + + group_name = job.job_spec.replica_group_name + if not group_name or run_spec.configuration.type != "service": + return base_profile + + # Find the group + normalized_groups = get_normalized_replica_groups(run_spec.configuration) + group = next((g for g in normalized_groups if g.name == group_name), None) + + if not group: + return base_profile + + # Merge: group overrides base + merged = Profile.parse_obj(base_profile.dict()) + for field_name in ProfileParams.__fields__: + group_value = getattr(group, field_name, None) + if group_value is not None: + setattr(merged, field_name, group_value) + + return merged + + async def _run_job_on_new_instance( project: ProjectModel, job_model: JobModel, @@ -741,8 +772,8 @@ async def _run_job_on_new_instance( ) -> Optional[tuple[JobProvisioningData, InstanceOfferWithAvailability, Profile, Requirements]]: if volumes is None: volumes = [] - profile = run.run_spec.merged_profile - requirements = job.job_spec.requirements + profile = _get_profile_for_job(run.run_spec, job) + requirements = job.job_spec.requirements # Already has group resources baked in fleet = None if fleet_model is not None: fleet = fleet_model_to_fleet(fleet_model) @@ -822,7 +853,9 @@ def _get_run_profile_and_requirements_in_fleet( run_spec: RunSpec, fleet: Fleet, ) -> tuple[Profile, Requirements]: - profile = combine_fleet_and_run_profiles(fleet.spec.merged_profile, run_spec.merged_profile) + # Use group-merged profile instead of run-level + job_profile = _get_profile_for_job(run_spec, job) + profile = combine_fleet_and_run_profiles(fleet.spec.merged_profile, job_profile) if profile is None: raise ValueError("Cannot combine fleet profile") fleet_requirements = get_fleet_requirements(fleet.spec) diff --git a/src/dstack/_internal/server/migrations/versions/a1b2c3d4e5f6_add_jobmodel_replica_group_name.py b/src/dstack/_internal/server/migrations/versions/a1b2c3d4e5f6_add_jobmodel_replica_group_name.py new file mode 100644 index 000000000..8e8a1eed0 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/a1b2c3d4e5f6_add_jobmodel_replica_group_name.py @@ -0,0 +1,27 @@ +"""Add JobModel.replica_group_name + +Revision ID: a1b2c3d4e5f6 +Revises: ff1d94f65b08 +Create Date: 2025-10-17 00:00:00.000000 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "a1b2c3d4e5f6" +down_revision = "ff1d94f65b08" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + with op.batch_alter_table("jobs", schema=None) as batch_op: + batch_op.add_column(sa.Column("replica_group_name", sa.String(), nullable=True)) + + +def downgrade() -> None: + with op.batch_alter_table("jobs", schema=None) as batch_op: + batch_op.drop_column("replica_group_name") + diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 31f44d369..d23a99e28 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -437,6 +437,7 @@ class JobModel(BaseModel): instance: Mapped[Optional["InstanceModel"]] = relationship(back_populates="jobs") used_instance_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUIDType(binary=False)) replica_num: Mapped[int] = mapped_column(Integer) + replica_group_name: Mapped[Optional[str]] = mapped_column(String, nullable=True) deployment_num: Mapped[int] = mapped_column(Integer) job_runtime_data: Mapped[Optional[str]] = mapped_column(Text) probes: Mapped[list["ProbeModel"]] = relationship( diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index cbb089b2c..1f9f66580 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -1,9 +1,12 @@ import itertools import json from datetime import timedelta -from typing import Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple from uuid import UUID +if TYPE_CHECKING: + from dstack._internal.core.models.configurations import ReplicaGroup + import requests from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -67,7 +70,10 @@ async def get_jobs_from_run_spec( - run_spec: RunSpec, secrets: Dict[str, str], replica_num: int + run_spec: RunSpec, + secrets: Dict[str, str], + replica_num: int, + replica_group: Optional["ReplicaGroup"] = None, ) -> List[Job]: return [ Job(job_spec=s, job_submissions=[]) @@ -75,14 +81,19 @@ async def get_jobs_from_run_spec( run_spec=run_spec, secrets=secrets, replica_num=replica_num, + replica_group=replica_group, ) ] async def get_job_specs_from_run_spec( - run_spec: RunSpec, secrets: Dict[str, str], replica_num: int + run_spec: RunSpec, + secrets: Dict[str, str], + replica_num: int, + replica_group: Optional["ReplicaGroup"] = None, ) -> List[JobSpec]: job_configurator = _get_job_configurator(run_spec=run_spec, secrets=secrets) + job_configurator.replica_group = replica_group job_specs = await job_configurator.get_job_specs(replica_num=replica_num) return job_specs diff --git a/src/dstack/_internal/server/services/jobs/configurators/base.py b/src/dstack/_internal/server/services/jobs/configurators/base.py index 02cdc70b3..b79c2a15e 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/base.py +++ b/src/dstack/_internal/server/services/jobs/configurators/base.py @@ -3,7 +3,11 @@ import threading from abc import ABC, abstractmethod from pathlib import PurePosixPath -from typing import Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional + +if TYPE_CHECKING: + from dstack._internal.core.models.configurations import ReplicaGroup + from dstack._internal.core.models.profiles import Profile from cachetools import TTLCache, cached @@ -84,6 +88,7 @@ class JobConfigurator(ABC): _image_config: Optional[ImageConfig] = None # JobSSHKey should be shared for all jobs in a replica for inter-node communication. _job_ssh_key: Optional[JobSSHKey] = None + replica_group: Optional["ReplicaGroup"] = None def __init__( self, @@ -146,6 +151,7 @@ async def _get_job_spec( ) -> JobSpec: job_spec = JobSpec( replica_num=replica_num, # TODO(egor-s): add to env variables in the runner + replica_group_name=self.replica_group.name if self.replica_group else None, job_num=job_num, job_name=f"{self.run_spec.run_name}-{job_num}-{replica_num}", jobs_per_replica=jobs_per_replica, @@ -295,13 +301,40 @@ def _utilization_policy(self) -> Optional[UtilizationPolicy]: def _registry_auth(self) -> Optional[RegistryAuth]: return self.run_spec.configuration.registry_auth + def _get_merged_profile(self) -> "Profile": + """Get profile with group overrides applied.""" + from dstack._internal.core.models.profiles import Profile, ProfileParams + + base = self.run_spec.merged_profile + + if not self.replica_group: + return base + + # Clone and apply group overrides + merged = Profile.parse_obj(base.dict()) + for field_name in ProfileParams.__fields__: + group_value = getattr(self.replica_group, field_name, None) + if group_value is not None: + setattr(merged, field_name, group_value) + + return merged + def _requirements(self) -> Requirements: - spot_policy = self._spot_policy() + # Use group resources if available, else fall back to config + if self.replica_group: + resources = self.replica_group.resources + else: + resources = self.run_spec.configuration.resources + + # Get merged profile for spot/price/reservation + profile = self._get_merged_profile() + spot_policy = profile.spot_policy or SpotPolicy.ONDEMAND + return Requirements( - resources=self.run_spec.configuration.resources, - max_price=self.run_spec.merged_profile.max_price, + resources=resources, + max_price=profile.max_price, spot=None if spot_policy == SpotPolicy.AUTO else (spot_policy == SpotPolicy.SPOT), - reservation=self.run_spec.merged_profile.reservation, + reservation=profile.reservation, ) def _retry(self) -> Optional[Retry]: diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 25ac750aa..f20d42ada 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -340,11 +340,34 @@ async def get_plan( action = ApplyAction.UPDATE secrets = await get_project_secrets_mapping(session=session, project=project) - jobs = await get_jobs_from_run_spec( - run_spec=effective_run_spec, - secrets=secrets, - replica_num=0, - ) + + # For services with replica groups, create jobs for all groups during planning + jobs = [] + if ( + effective_run_spec.configuration.type == "service" + and effective_run_spec.configuration.replica_groups + ): + from dstack._internal.core.models.runs import get_normalized_replica_groups + + normalized_groups = get_normalized_replica_groups(effective_run_spec.configuration) + replica_num = 0 + for group in normalized_groups: + # Create one job per group for planning (minimum replicas) + group_jobs = await get_jobs_from_run_spec( + run_spec=effective_run_spec, + secrets=secrets, + replica_num=replica_num, + replica_group=group, + ) + jobs.extend(group_jobs) + replica_num += 1 + else: + # Legacy: single job for planning + jobs = await get_jobs_from_run_spec( + run_spec=effective_run_spec, + secrets=secrets, + replica_num=0, + ) volumes = await get_job_configured_volumes( session=session, @@ -362,10 +385,15 @@ async def get_plan( ) effective_run_spec.run_name = "dry-run" # will regenerate jobs on submission - # Get offers once for all jobs - offers = [] - if creation_policy == CreationPolicy.REUSE_OR_CREATE: - offers = await get_offers_by_requirements( + # Check if all jobs have identical requirements (optimization for single-type jobs) + all_requirements_identical = all( + job.job_spec.requirements == jobs[0].job_spec.requirements for job in jobs + ) + + # Get offers once if all jobs are identical, otherwise get per-job + shared_offers = [] + if creation_policy == CreationPolicy.REUSE_OR_CREATE and all_requirements_identical: + shared_offers = await get_offers_by_requirements( project=project, profile=profile, requirements=jobs[0].job_spec.requirements, @@ -379,8 +407,43 @@ async def get_plan( job_plans = [] for job in jobs: job_offers: List[InstanceOfferWithAvailability] = [] - job_offers.extend(pool_offers) - job_offers.extend(offer for _, offer in offers) + + # Filter pool offers to match this job's GPU requirements + gpu_req = None + if job.job_spec.requirements.resources and job.job_spec.requirements.resources.gpu: + gpu_req = job.job_spec.requirements.resources.gpu.name + + matching_pool_offers = [] + for pool_offer in pool_offers: + offer_gpus = pool_offer.instance.resources.gpus + if offer_gpus and gpu_req: + # Check if offer's GPU matches job's requirement + offer_gpu_names = [gpu.name for gpu in offer_gpus] + if any(req_gpu in offer_gpu_names for req_gpu in gpu_req): + matching_pool_offers.append(pool_offer) + elif not gpu_req: + # No GPU requirement, include all pool offers + matching_pool_offers.append(pool_offer) + + job_offers.extend(matching_pool_offers) + + # Use shared offers if all jobs are identical, otherwise fetch per-job + if shared_offers: + job_offers.extend(offer for _, offer in shared_offers) + elif creation_policy == CreationPolicy.REUSE_OR_CREATE: + # Fetch offers specific to this job's requirements + job_specific_offers = await get_offers_by_requirements( + project=project, + profile=profile, + requirements=job.job_spec.requirements, + exclude_not_available=False, + multinode=job.job_spec.jobs_per_replica > 1, + volumes=volumes, + privileged=job.job_spec.privileged, + instance_mounts=check_run_spec_requires_instance_mounts(effective_run_spec), + ) + job_offers.extend(offer for _, offer in job_specific_offers) + job_offers.sort(key=lambda offer: not offer.availability.is_available()) job_spec = job.job_spec @@ -557,19 +620,47 @@ async def submit_run( if run_spec.configuration.type == "service": await services.register_service(session, run_model, run_spec) - for replica_num in range(initial_replicas): - jobs = await get_jobs_from_run_spec( - run_spec=run_spec, - secrets=secrets, - replica_num=replica_num, - ) - for job in jobs: - job_model = create_job_model_for_new_submission( - run_model=run_model, - job=job, - status=JobStatus.SUBMITTED, + from dstack._internal.core.models.runs import get_normalized_replica_groups + + normalized_groups = get_normalized_replica_groups(run_spec.configuration) + + # Set initial desired count (sum of all group minimums) + run_model.desired_replica_count = sum(g.replicas.min or 0 for g in normalized_groups) + + # Create jobs by iterating over groups + replica_num = 0 # Global counter across all groups + for group in normalized_groups: + group_min = group.replicas.min or 0 + for _ in range(group_min): + jobs = await get_jobs_from_run_spec( + run_spec=run_spec, + secrets=secrets, + replica_num=replica_num, + replica_group=group, # Pass group context + ) + for job in jobs: + job_model = create_job_model_for_new_submission( + run_model=run_model, + job=job, + status=JobStatus.SUBMITTED, + ) + session.add(job_model) + replica_num += 1 + else: + # Non-service runs (tasks, dev environments) + for replica_num in range(initial_replicas): + jobs = await get_jobs_from_run_spec( + run_spec=run_spec, + secrets=secrets, + replica_num=replica_num, ) - session.add(job_model) + for job in jobs: + job_model = create_job_model_for_new_submission( + run_model=run_model, + job=job, + status=JobStatus.SUBMITTED, + ) + session.add(job_model) await session.commit() await session.refresh(run_model) @@ -591,6 +682,7 @@ def create_job_model_for_new_submission( job_num=job.job_spec.job_num, job_name=f"{job.job_spec.job_name}", replica_num=job.job_spec.replica_num, + replica_group_name=job.job_spec.replica_group_name, deployment_num=run_model.deployment_num, submission_num=len(job.job_submissions), submitted_at=now, @@ -1017,10 +1109,19 @@ def _validate_run_spec_and_set_defaults(user: UserModel, run_spec: RunSpec): f"Maximum utilization_policy.time_window is {settings.SERVER_METRICS_RUNNING_TTL_SECONDS}s" ) if isinstance(run_spec.configuration, ServiceConfiguration): - if run_spec.merged_profile.schedule and run_spec.configuration.replicas.min == 0: - raise ServerClientError( - "Scheduled services with autoscaling to zero are not supported" - ) + # Check all groups for min=0 with schedule + if run_spec.merged_profile.schedule: + if run_spec.configuration.replica_groups: + from dstack._internal.core.models.runs import get_normalized_replica_groups + groups = get_normalized_replica_groups(run_spec.configuration) + if any(g.replicas.min == 0 for g in groups): + raise ServerClientError( + "Scheduled services with autoscaling to zero are not supported" + ) + elif run_spec.configuration.replicas.min == 0: + raise ServerClientError( + "Scheduled services with autoscaling to zero are not supported" + ) if len(run_spec.configuration.probes) > settings.MAX_PROBES_PER_JOB: raise ServerClientError( f"Cannot configure more than {settings.MAX_PROBES_PER_JOB} probes" @@ -1226,38 +1327,112 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica active_replicas.sort(key=lambda r: (r[1], -r[0], r[2])) run_spec = RunSpec.__response__.parse_raw(run_model.run_spec) + from dstack._internal.core.models.runs import get_normalized_replica_groups + + normalized_groups = ( + get_normalized_replica_groups(run_spec.configuration) + if run_spec.configuration.type == "service" + else [] + ) + if replicas_diff < 0: - for _, _, _, replica_jobs in reversed(active_replicas[-abs(replicas_diff) :]): - # scale down the less important replicas first + # SCALE DOWN: Only terminate from autoscalable groups while respecting group minimums + autoscalable_groups = {g.name for g in normalized_groups if g.replicas.min != g.replicas.max} + + # Count replicas per group + group_counts = {} + for _, _, _, replica_jobs in active_replicas: + if replica_jobs: + group_name = replica_jobs[0].replica_group_name or "default" + group_counts[group_name] = group_counts.get(group_name, 0) + 1 + + # Get group minimums + group_mins = {g.name: g.replicas.min for g in normalized_groups} + + # Terminate from end (reversed), but skip if group not autoscalable or at minimum + terminated_count = 0 + for _, _, _, replica_jobs in reversed(active_replicas): + if terminated_count >= abs(replicas_diff): + break + + if not replica_jobs: + continue + + group_name = replica_jobs[0].replica_group_name or "default" + + # Skip if not autoscalable + if normalized_groups and group_name not in autoscalable_groups: + continue + + # Skip if at minimum + current_count = group_counts.get(group_name, 0) + min_count = group_mins.get(group_name, 0) + if current_count <= min_count: + continue + + # Terminate this replica for job in replica_jobs: - if job.status.is_finished() or job.status == JobStatus.TERMINATING: - continue - job.status = JobStatus.TERMINATING - job.termination_reason = JobTerminationReason.SCALED_DOWN - # background task will process the job later + if not job.status.is_finished() and job.status != JobStatus.TERMINATING: + job.status = JobStatus.TERMINATING + job.termination_reason = JobTerminationReason.SCALED_DOWN + + group_counts[group_name] -= 1 + terminated_count += 1 else: + # SCALE UP: Choose from autoscalable groups + autoscalable_groups = [g for g in normalized_groups if g.replicas.min != g.replicas.max] + + if normalized_groups and not autoscalable_groups: + # No autoscalable groups, cannot scale + logger.info("%s: no autoscalable groups available for scaling up", fmt(run_model)) + return + + # Count current replicas per group to respect maximums + group_counts = {} + for _, _, _, replica_jobs in active_replicas: + if replica_jobs: + group_name = replica_jobs[0].replica_group_name or "default" + group_counts[group_name] = group_counts.get(group_name, 0) + 1 + + # Filter groups that haven't reached maximum + eligible_groups = [ + g for g in autoscalable_groups if group_counts.get(g.name, 0) < (g.replicas.max or float("inf")) + ] if normalized_groups else normalized_groups + + if normalized_groups and not eligible_groups: + # All groups at maximum + logger.info("%s: all autoscalable groups at maximum capacity", fmt(run_model)) + return + scheduled_replicas = 0 - # rerun inactive replicas + # Reuse inactive replicas first (existing logic) for _, _, _, replica_jobs in inactive_replicas: if scheduled_replicas == replicas_diff: break - await retry_run_replica_jobs(session, run_model, replica_jobs, only_failed=False) - scheduled_replicas += 1 + # Only reuse if from autoscalable group + if replica_jobs: + group_name = replica_jobs[0].replica_group_name or "default" + if not normalized_groups or group_name in {g.name for g in autoscalable_groups}: + await retry_run_replica_jobs(session, run_model, replica_jobs, only_failed=False) + scheduled_replicas += 1 - secrets = await get_project_secrets_mapping( - session=session, - project=run_model.project, - ) + # Create new replicas for remaining diff + secrets = await get_project_secrets_mapping(session=session, project=run_model.project) + + for _ in range(replicas_diff - scheduled_replicas): + # Pick group for new replica + # v1: Simple heuristic - pick first eligible group (round-robin in future) + selected_group = eligible_groups[0] if eligible_groups else None + + replica_num = len(active_replicas) + scheduled_replicas - for replica_num in range( - len(active_replicas) + scheduled_replicas, len(active_replicas) + replicas_diff - ): # FIXME: Handle getting image configuration errors or skip it. jobs = await get_jobs_from_run_spec( run_spec=run_spec, secrets=secrets, replica_num=replica_num, + replica_group=selected_group, ) for job in jobs: job_model = create_job_model_for_new_submission( @@ -1267,6 +1442,20 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica ) session.add(job_model) + # Update count + if selected_group: + group_counts[selected_group.name] = group_counts.get(selected_group.name, 0) + 1 + scheduled_replicas += 1 + + # Remove from eligible if at max + if group_counts[selected_group.name] >= (selected_group.replicas.max or float("inf")): + eligible_groups = [g for g in eligible_groups if g.name != selected_group.name] + if not eligible_groups: + logger.info("%s: all eligible groups reached maximum capacity", fmt(run_model)) + break + else: + scheduled_replicas += 1 + async def retry_run_replica_jobs( session: AsyncSession, run_model: RunModel, latest_jobs: List[JobModel], *, only_failed: bool diff --git a/src/dstack/_internal/server/services/services/autoscalers.py b/src/dstack/_internal/server/services/services/autoscalers.py index cd6d06e58..379d2e920 100644 --- a/src/dstack/_internal/server/services/services/autoscalers.py +++ b/src/dstack/_internal/server/services/services/autoscalers.py @@ -120,18 +120,28 @@ def get_desired_count( def get_service_scaler(conf: ServiceConfiguration) -> BaseServiceScaler: - assert conf.replicas.min is not None - assert conf.replicas.max is not None + # Compute bounds from groups if present + if conf.replica_groups: + from dstack._internal.core.models.runs import get_normalized_replica_groups + groups = get_normalized_replica_groups(conf) + min_replicas = sum(g.replicas.min or 0 for g in groups) + max_replicas = sum(g.replicas.max or 0 for g in groups) + else: + assert conf.replicas.min is not None + assert conf.replicas.max is not None + min_replicas = conf.replicas.min + max_replicas = conf.replicas.max + if conf.scaling is None: return ManualScaler( - min_replicas=conf.replicas.min, - max_replicas=conf.replicas.max, + min_replicas=min_replicas, + max_replicas=max_replicas, ) if conf.scaling.metric == "rps": return RPSAutoscaler( # replicas count validated by configuration model - min_replicas=conf.replicas.min, - max_replicas=conf.replicas.max, + min_replicas=min_replicas, + max_replicas=max_replicas, target=conf.scaling.target, scale_up_delay=conf.scaling.scale_up_delay, scale_down_delay=conf.scaling.scale_down_delay, diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 0a8adaa42..f7be289d3 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -345,6 +345,7 @@ async def create_job( instance: Optional[InstanceModel] = None, job_num: int = 0, replica_num: int = 0, + replica_group_name: Optional[str] = None, deployment_num: Optional[int] = None, instance_assigned: bool = False, disconnected_at: Optional[datetime] = None, @@ -365,6 +366,7 @@ async def create_job( job_num=job_num, job_name=run.run_name + f"-{job_num}-{replica_num}", replica_num=replica_num, + replica_group_name=replica_group_name, deployment_num=deployment_num, submission_num=submission_num, submitted_at=submitted_at, diff --git a/src/tests/_internal/cli/utils/test_run_plan_display.py b/src/tests/_internal/cli/utils/test_run_plan_display.py new file mode 100644 index 000000000..7719b6e4e --- /dev/null +++ b/src/tests/_internal/cli/utils/test_run_plan_display.py @@ -0,0 +1,424 @@ +"""Test CLI display of run plans with replica groups.""" + +from io import StringIO +from unittest.mock import MagicMock, patch + +import pytest + +from dstack._internal.cli.utils.run import print_run_plan +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.configurations import ServiceConfiguration +from dstack._internal.core.models.instances import ( + Gpu, + InstanceAvailability, + InstanceType, + Resources, +) +from dstack._internal.core.models.profiles import Profile +from dstack._internal.core.models.repos import LocalRunRepoData +from dstack._internal.core.models.resources import ResourcesSpec, Range +from dstack._internal.core.models.runs import ( + ApplyAction, + InstanceOfferWithAvailability, + JobPlan, + JobSpec, + Requirements, + RunPlan, + RunSpec, +) + + +def create_test_offer( + backend: BackendType, + gpu_name: str, + price: float, + region: str = "us-east", + availability: InstanceAvailability = InstanceAvailability.AVAILABLE, +) -> InstanceOfferWithAvailability: + """Helper to create test offers.""" + return InstanceOfferWithAvailability( + backend=backend, + instance=InstanceType( + name=f"{gpu_name.lower()}-instance", + resources=Resources( + cpus=8, + memory_mib=16384, + gpus=[Gpu(name=gpu_name, memory_mib=40960)], + spot=False, + ), + ), + region=region, + price=price, + availability=availability, + ) + + +class TestReplicaGroupsDisplayInCLI: + """Test that replica groups are properly displayed in CLI output.""" + + def test_multiple_replica_groups_show_group_names(self, capsys): + """CLI should prefix offers with group names when multiple job plans exist.""" + # Create a service with 2 replica groups + config = ServiceConfiguration( + type="service", + port=8000, + commands=["echo test"], + replica_groups=[ + { + "name": "l40s-group", + "replicas": "1", + "resources": {"gpu": {"name": "L40S", "count": 1}}, + }, + { + "name": "a100-group", + "replicas": "1", + "resources": {"gpu": {"name": "A100", "count": 1}}, + }, + ], + ) + + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data=LocalRunRepoData(repo_dir="/tmp"), + configuration=config, + configuration_path=".dstack.yml", + profile=Profile(backends=[BackendType.VASTAI]), + ) + + # Create job plans for each group + l40s_offer = create_test_offer(BackendType.VASTAI, "L40S", 0.50) + a100_offer = create_test_offer(BackendType.VASTAI, "A100", 1.20) + + job_plan_l40s = JobPlan( + job_spec=JobSpec( + replica_num=0, + replica_group_name="l40s-group", + job_num=0, + job_name="test-job-0", + image_name="dstackai/base", + commands=["echo test"], + env={}, + working_dir="/workflow", + requirements=Requirements( + resources=ResourcesSpec(gpu={"name": "L40S", "count": 1}) + ), + ), + offers=[l40s_offer], + total_offers=1, + max_price=0.50, + ) + + job_plan_a100 = JobPlan( + job_spec=JobSpec( + replica_num=1, + replica_group_name="a100-group", + job_num=1, + job_name="test-job-1", + image_name="dstackai/base", + commands=["echo test"], + env={}, + working_dir="/workflow", + requirements=Requirements( + resources=ResourcesSpec(gpu={"name": "A100", "count": 1}) + ), + ), + offers=[a100_offer], + total_offers=1, + max_price=1.20, + ) + + run_plan = RunPlan( + project_name="test-project", + user="test-user", + run_spec=run_spec, + effective_run_spec=run_spec, + job_plans=[job_plan_l40s, job_plan_a100], + current_resource=None, + action=ApplyAction.CREATE, + ) + + # Print the plan + print_run_plan(run_plan, max_offers=10, include_run_properties=True) + + # Capture output + captured = capsys.readouterr() + output = captured.out + + # Verify group names are in the output + assert "l40s-group" in output, "l40s-group name should appear in output" + assert "a100-group" in output, "a100-group name should appear in output" + + # Verify both GPU types are shown + assert "L40S" in output or "l40s" in output.lower() + assert "A100" in output or "a100" in output.lower() + + def test_single_job_plan_no_group_prefix(self, capsys): + """CLI should NOT prefix offers when only one job plan exists (legacy).""" + config = ServiceConfiguration( + type="service", + port=8000, + commands=["echo test"], + replicas=Range[int](min=1, max=1), + resources=ResourcesSpec(gpu={"name": "V100", "count": 1}), + ) + + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data=LocalRunRepoData(repo_dir="/tmp"), + configuration=config, + configuration_path=".dstack.yml", + profile=Profile(backends=[BackendType.AWS]), + ) + + v100_offer = create_test_offer(BackendType.AWS, "V100", 0.80) + + job_plan = JobPlan( + job_spec=JobSpec( + replica_num=0, + replica_group_name="default", + job_num=0, + job_name="test-job-0", + image_name="dstackai/base", + commands=["echo test"], + env={}, + working_dir="/workflow", + requirements=Requirements( + resources=ResourcesSpec(gpu={"name": "V100", "count": 1}) + ), + ), + offers=[v100_offer], + total_offers=1, + max_price=0.80, + ) + + run_plan = RunPlan( + project_name="test-project", + user="test-user", + run_spec=run_spec, + effective_run_spec=run_spec, + job_plans=[job_plan], + current_resource=None, + action=ApplyAction.CREATE, + ) + + # Print the plan + print_run_plan(run_plan, max_offers=10, include_run_properties=True) + + # Capture output + captured = capsys.readouterr() + output = captured.out + + # Verify NO group prefix (legacy mode) + assert "default:" not in output, "Legacy mode should not show group prefix" + # But should show backend normally + assert "aws" in output.lower() + + def test_replica_groups_offers_sorted_by_price(self, capsys): + """Offers from multiple groups should be sorted by price across all groups.""" + config = ServiceConfiguration( + type="service", + port=8000, + commands=["echo test"], + replica_groups=[ + { + "name": "expensive-group", + "replicas": "1", + "resources": {"gpu": {"name": "H100", "count": 1}}, + }, + { + "name": "cheap-group", + "replicas": "1", + "resources": {"gpu": {"name": "T4", "count": 1}}, + }, + ], + ) + + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data=LocalRunRepoData(repo_dir="/tmp"), + configuration=config, + configuration_path=".dstack.yml", + profile=Profile(backends=[BackendType.AWS]), + ) + + # Expensive offer + h100_offer = create_test_offer(BackendType.AWS, "H100", 3.00) + # Cheap offer + t4_offer = create_test_offer(BackendType.AWS, "T4", 0.30) + + job_plan_expensive = JobPlan( + job_spec=JobSpec( + replica_num=0, + replica_group_name="expensive-group", + job_num=0, + job_name="test-job-0", + image_name="dstackai/base", + commands=["echo test"], + env={}, + working_dir="/workflow", + requirements=Requirements( + resources=ResourcesSpec(gpu={"name": "H100", "count": 1}) + ), + ), + offers=[h100_offer], + total_offers=1, + max_price=3.00, + ) + + job_plan_cheap = JobPlan( + job_spec=JobSpec( + replica_num=1, + replica_group_name="cheap-group", + job_num=1, + job_name="test-job-1", + image_name="dstackai/base", + commands=["echo test"], + env={}, + working_dir="/workflow", + requirements=Requirements(resources=ResourcesSpec(gpu={"name": "T4", "count": 1})), + ), + offers=[t4_offer], + total_offers=1, + max_price=0.30, + ) + + run_plan = RunPlan( + project_name="test-project", + user="test-user", + run_spec=run_spec, + effective_run_spec=run_spec, + job_plans=[job_plan_expensive, job_plan_cheap], + current_resource=None, + action=ApplyAction.CREATE, + ) + + # Print the plan + print_run_plan(run_plan, max_offers=10, include_run_properties=True) + + # Capture output + captured = capsys.readouterr() + output = captured.out + + # Split output to find the offers table (after the header section) + lines = output.split("\n") + + # Find lines that contain both a number and a group name (these are offer rows) + offer_rows = [ + line for line in lines + if ("cheap-group:" in line or "expensive-group:" in line) and line.strip().startswith(("1", "2", "3")) + ] + + # The first offer row should be cheap-group (lower price) + assert len(offer_rows) >= 2, "Should have at least 2 offer rows" + assert "cheap-group:" in offer_rows[0], "First offer should be cheap-group (sorted by price)" + assert "expensive-group:" in offer_rows[1], "Second offer should be expensive-group" + assert "$0.3" in output # Price displayed as $0.3 + assert "$3" in output # Price displayed as $3 + + def test_replica_group_with_no_offers_shows_message(self, capsys): + """Replica groups with no available offers should show a message.""" + config = ServiceConfiguration( + type="service", + port=8000, + commands=["echo test"], + replica_groups=[ + { + "name": "available-group", + "replicas": "1", + "resources": {"gpu": {"name": "L40S", "count": 1}}, + }, + { + "name": "unavailable-group", + "replicas": "1", + "resources": {"gpu": {"name": "A100", "count": 1}}, + }, + ], + ) + + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data=LocalRunRepoData(repo_dir="/tmp"), + configuration=config, + configuration_path=".dstack.yml", + profile=Profile(backends=[BackendType.VASTAI]), + ) + + # One group has offers, another doesn't + l40s_offer = create_test_offer(BackendType.VASTAI, "L40S", 0.50) + + job_plan_with_offers = JobPlan( + job_spec=JobSpec( + replica_num=0, + replica_group_name="available-group", + job_num=0, + job_name="test-job-0", + image_name="dstackai/base", + commands=["echo test"], + env={}, + working_dir="/workflow", + requirements=Requirements( + resources=ResourcesSpec(gpu={"name": "L40S", "count": 1}) + ), + ), + offers=[l40s_offer], + total_offers=1, + max_price=0.50, + ) + + job_plan_no_offers = JobPlan( + job_spec=JobSpec( + replica_num=1, + replica_group_name="unavailable-group", + job_num=1, + job_name="test-job-1", + image_name="dstackai/base", + commands=["echo test"], + env={}, + working_dir="/workflow", + requirements=Requirements( + resources=ResourcesSpec(gpu={"name": "A100", "count": 1}) + ), + ), + offers=[], # No offers + total_offers=0, + max_price=0.0, + ) + + run_plan = RunPlan( + project_name="test-project", + user="test-user", + run_spec=run_spec, + effective_run_spec=run_spec, + job_plans=[job_plan_with_offers, job_plan_no_offers], + current_resource=None, + action=ApplyAction.CREATE, + ) + + # Print the plan + print_run_plan(run_plan, max_offers=10, include_run_properties=True) + + # Capture output + captured = capsys.readouterr() + output = captured.out + + # Verify available group shows offer + assert "available-group:" in output + assert "L40S" in output + + # Verify unavailable group shows the standard "no offers" message + # (Message may be wrapped across lines in table display) + assert "unavailable-group:" in output + assert "No matching instance" in output + assert "offers available" in output + assert "Possible reasons:" in output + assert "dstack.ai/docs" in output # URL may be truncated in table + + # Verify unavailable group appears BEFORE available group (at top) + unavailable_pos = output.find("unavailable-group:") + available_pos = output.find("available-group:") + assert unavailable_pos < available_pos, "Group with no offers should appear first" + diff --git a/src/tests/_internal/core/models/test_replica_groups.py b/src/tests/_internal/core/models/test_replica_groups.py new file mode 100644 index 000000000..dc8de87cd --- /dev/null +++ b/src/tests/_internal/core/models/test_replica_groups.py @@ -0,0 +1,437 @@ +"""Tests for Named Replica Groups functionality""" +import pytest + +from dstack._internal.core.errors import ConfigurationError +from dstack._internal.core.models.configurations import ( + ServiceConfiguration, + parse_run_configuration, +) +from dstack._internal.core.models.resources import CPUSpec, GPUSpec, Range, ResourcesSpec +from dstack._internal.core.models.runs import get_normalized_replica_groups + + +class TestReplicaGroupConfiguration: + """Test replica group configuration parsing and validation""" + + def test_basic_replica_groups(self): + """Test basic replica groups configuration""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + "replica_groups": [ + { + "name": "h100-group", + "replicas": 1, + "resources": {"gpu": "H100:1"}, + }, + { + "name": "rtx5090-group", + "replicas": 2, + "resources": {"gpu": "RTX5090:1"}, + }, + ], + } + + parsed = parse_run_configuration(conf) + assert isinstance(parsed, ServiceConfiguration) + assert parsed.replica_groups is not None + assert len(parsed.replica_groups) == 2 + + # Check first group + assert parsed.replica_groups[0].name == "h100-group" + assert parsed.replica_groups[0].replicas == Range(min=1, max=1) + assert parsed.replica_groups[0].resources.gpu.name == ["H100"] + + # Check second group + assert parsed.replica_groups[1].name == "rtx5090-group" + assert parsed.replica_groups[1].replicas == Range(min=2, max=2) + assert parsed.replica_groups[1].resources.gpu.name == ["RTX5090"] + + def test_replica_groups_with_ranges(self): + """Test replica groups with autoscaling ranges""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + "replica_groups": [ + { + "name": "fixed-group", + "replicas": 1, + "resources": {"gpu": "H100:1"}, + }, + { + "name": "scalable-group", + "replicas": "1..3", # Range + "resources": {"gpu": "RTX5090:1"}, + }, + ], + "scaling": { + "metric": "rps", + "target": 10, + }, + } + + parsed = parse_run_configuration(conf) + assert parsed.replica_groups is not None + assert len(parsed.replica_groups) == 2 + + # Fixed group + assert parsed.replica_groups[0].replicas == Range(min=1, max=1) + + # Scalable group + assert parsed.replica_groups[1].replicas == Range(min=1, max=3) + + def test_replica_groups_with_profile_params(self): + """Test replica groups can override profile parameters""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + # Service-level settings + "backends": ["aws"], + "regions": ["us-west-2"], + "replica_groups": [ + { + "name": "aws-group", + "replicas": 1, + "resources": {"gpu": "H100:1"}, + # Inherits backends/regions from service + }, + { + "name": "runpod-group", + "replicas": 1, + "resources": {"gpu": "RTX5090:1"}, + # Override backends + "backends": ["runpod"], + "regions": ["eu-west-1"], + }, + ], + } + + parsed = parse_run_configuration(conf) + + # First group inherits from service (doesn't specify backends/regions) + assert parsed.replica_groups[0].backends is None + assert parsed.replica_groups[0].regions is None + + # Second group overrides + assert parsed.replica_groups[1].backends == ["runpod"] + assert parsed.replica_groups[1].regions == ["eu-west-1"] + + def test_replica_groups_xor_replicas(self): + """Test that replica_groups and replicas are mutually exclusive""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + "replicas": 2, # Old format + "replica_groups": [ # New format + { + "name": "group1", + "replicas": 1, + "resources": {"gpu": "H100:1"}, + } + ], + } + + with pytest.raises( + ConfigurationError, + match="Cannot specify both 'replicas' and 'replica_groups'", + ): + parse_run_configuration(conf) + + def test_replica_groups_unique_names(self): + """Test that replica group names must be unique""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + "replica_groups": [ + { + "name": "group1", + "replicas": 1, + "resources": {"gpu": "H100:1"}, + }, + { + "name": "group1", # Duplicate! + "replicas": 1, + "resources": {"gpu": "RTX5090:1"}, + }, + ], + } + + with pytest.raises( + ConfigurationError, + match="Replica group names must be unique", + ): + parse_run_configuration(conf) + + def test_replica_groups_empty_name(self): + """Test that replica group names cannot be empty""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + "replica_groups": [ + { + "name": "", # Empty name + "replicas": 1, + "resources": {"gpu": "H100:1"}, + } + ], + } + + with pytest.raises( + ConfigurationError, + match="Group name cannot be empty", + ): + parse_run_configuration(conf) + + def test_replica_groups_range_requires_scaling(self): + """Test that replica ranges require scaling configuration""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + "replica_groups": [ + { + "name": "scalable-group", + "replicas": "1..3", + "resources": {"gpu": "RTX5090:1"}, + } + ], + # Missing scaling! + } + + with pytest.raises( + ConfigurationError, + match="When any replica group has a range, 'scaling' must be specified", + ): + parse_run_configuration(conf) + + def test_replica_groups_cannot_be_empty(self): + """Test that replica_groups list cannot be empty""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + "replica_groups": [], # Empty list + } + + with pytest.raises( + ConfigurationError, + match="replica_groups cannot be empty", + ): + parse_run_configuration(conf) + + +class TestReplicaGroupNormalization: + """Test get_normalized_replica_groups helper""" + + def test_normalize_new_format(self): + """Test normalization with replica_groups format""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + "replica_groups": [ + { + "name": "group1", + "replicas": 1, + "resources": {"gpu": "H100:1"}, + }, + { + "name": "group2", + "replicas": 2, + "resources": {"gpu": "RTX5090:1"}, + }, + ], + } + + parsed = parse_run_configuration(conf) + normalized = get_normalized_replica_groups(parsed) + + assert len(normalized) == 2 + assert normalized[0].name == "group1" + assert normalized[1].name == "group2" + + def test_normalize_legacy_format(self): + """Test normalization converts legacy replicas to default group""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + "replicas": 3, + "resources": {"gpu": "H100:1"}, + "backends": ["aws"], + "regions": ["us-west-2"], + } + + parsed = parse_run_configuration(conf) + normalized = get_normalized_replica_groups(parsed) + + # Should create single "default" group + assert len(normalized) == 1 + assert normalized[0].name == "default" + assert normalized[0].replicas == Range(min=3, max=3) + assert normalized[0].resources.gpu.name == ["H100"] + + # Should inherit profile params + assert normalized[0].backends == ["aws"] + assert normalized[0].regions == ["us-west-2"] + + def test_normalize_legacy_with_range(self): + """Test normalization with legacy autoscaling""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + "replicas": "1..5", + "resources": {"gpu": "RTX5090:1"}, + "scaling": { + "metric": "rps", + "target": 10, + }, + } + + parsed = parse_run_configuration(conf) + normalized = get_normalized_replica_groups(parsed) + + assert len(normalized) == 1 + assert normalized[0].name == "default" + assert normalized[0].replicas == Range(min=1, max=5) + + +class TestReplicaGroupAutoscaling: + """Test autoscaling behavior with replica groups""" + + def test_autoscalable_group_detection(self): + """Test identifying which groups are autoscalable""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + "replica_groups": [ + { + "name": "fixed", + "replicas": 1, + "resources": {"gpu": "H100:1"}, + }, + { + "name": "scalable", + "replicas": "1..3", + "resources": {"gpu": "RTX5090:1"}, + }, + ], + "scaling": { + "metric": "rps", + "target": 10, + }, + } + + parsed = parse_run_configuration(conf) + + # Fixed group: min == max + assert parsed.replica_groups[0].replicas.min == parsed.replica_groups[0].replicas.max + + # Scalable group: min != max + assert parsed.replica_groups[1].replicas.min != parsed.replica_groups[1].replicas.max + + def test_multiple_autoscalable_groups(self): + """Test multiple groups can be autoscalable""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + "replica_groups": [ + { + "name": "scalable-1", + "replicas": "1..3", + "resources": {"gpu": "H100:1"}, + }, + { + "name": "scalable-2", + "replicas": "2..5", + "resources": {"gpu": "RTX5090:1"}, + }, + ], + "scaling": { + "metric": "rps", + "target": 10, + }, + } + + parsed = parse_run_configuration(conf) + + # Both are autoscalable + assert parsed.replica_groups[0].replicas.min != parsed.replica_groups[0].replicas.max + assert parsed.replica_groups[1].replicas.min != parsed.replica_groups[1].replicas.max + + +class TestBackwardCompatibility: + """Test backward compatibility with existing configurations""" + + def test_legacy_service_config(self): + """Test that legacy service configs still work""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + "replicas": 2, + "resources": {"gpu": "A100:1"}, + } + + parsed = parse_run_configuration(conf) + + # Should parse successfully + assert isinstance(parsed, ServiceConfiguration) + assert parsed.replicas == Range(min=2, max=2) + assert parsed.replica_groups is None # Not using new format + + def test_legacy_autoscaling_config(self): + """Test legacy autoscaling configurations""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + "replicas": "0..5", + "resources": {"gpu": "A100:1"}, + "scaling": { + "metric": "rps", + "target": 10, + }, + } + + parsed = parse_run_configuration(conf) + + # Should parse successfully + assert parsed.replicas == Range(min=0, max=5) + assert parsed.scaling is not None + + def test_normalization_preserves_all_profile_params(self): + """Test that normalization copies all ProfileParams fields""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + "replicas": 1, + "resources": {"gpu": "H100:1"}, + "backends": ["aws"], + "regions": ["us-east-1"], + "instance_types": ["p4d.24xlarge"], + "spot_policy": "spot", + "max_price": 10.0, + } + + parsed = parse_run_configuration(conf) + normalized = get_normalized_replica_groups(parsed) + + # Check all fields are copied + group = normalized[0] + assert group.backends == ["aws"] + assert group.regions == ["us-east-1"] + assert group.instance_types == ["p4d.24xlarge"] + assert group.spot_policy == "spot" + assert group.max_price == 10.0 + diff --git a/src/tests/_internal/core/test_backward_compatibility.py b/src/tests/_internal/core/test_backward_compatibility.py new file mode 100644 index 000000000..6bb2ec8b3 --- /dev/null +++ b/src/tests/_internal/core/test_backward_compatibility.py @@ -0,0 +1,130 @@ +"""Test backward compatibility for replica_groups with older servers.""" + +import pytest + +from dstack._internal.core.compatibility.runs import get_get_plan_excludes, get_run_spec_excludes +from dstack._internal.core.models.configurations import ServiceConfiguration +from dstack._internal.core.models.repos import LocalRunRepoData +from dstack._internal.core.models.runs import RunSpec +from dstack._internal.server.schemas.runs import GetRunPlanRequest + + +class TestReplicaGroupsBackwardCompatibility: + """Test that replica_groups field is excluded when None for backward compatibility.""" + + def test_replica_groups_excluded_when_none(self): + """replica_groups should be excluded from JSON when None.""" + config = ServiceConfiguration( + type="service", + port=8000, + commands=["echo test"], + replicas={"min": 1, "max": 1}, + ) + + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data=LocalRunRepoData(repo_dir="/tmp"), + configuration=config, + profile=None, + ) + + # Get excludes + excludes = get_run_spec_excludes(run_spec) + + # replica_groups should be in excludes + assert "configuration" in excludes + assert "replica_groups" in excludes["configuration"] + assert excludes["configuration"]["replica_groups"] is True + + def test_replica_groups_not_excluded_when_set(self): + """replica_groups should NOT be excluded when set.""" + config = ServiceConfiguration( + type="service", + port=8000, + commands=["echo test"], + replica_groups=[ + { + "name": "gpu-group", + "replicas": "1", + "resources": {"gpu": {"name": "A100"}}, + } + ], + scaling={"metric": "rps", "target": 10}, + ) + + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data=LocalRunRepoData(repo_dir="/tmp"), + configuration=config, + profile=None, + ) + + # Get excludes + excludes = get_run_spec_excludes(run_spec) + + # replica_groups should NOT be in excludes (or be False) + if "configuration" in excludes and "replica_groups" in excludes["configuration"]: + assert excludes["configuration"]["replica_groups"] is not True + + def test_get_plan_request_serialization_without_replica_groups(self): + """GetRunPlanRequest should not include replica_groups in JSON when None.""" + config = ServiceConfiguration( + type="service", + port=8000, + commands=["echo test"], + replicas={"min": 1, "max": 1}, + ) + + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data=LocalRunRepoData(repo_dir="/tmp"), + configuration=config, + profile=None, + ) + + request = GetRunPlanRequest(run_spec=run_spec, max_offers=None) + excludes = get_get_plan_excludes(request) + + # Serialize with excludes + json_str = request.json(exclude=excludes) + + # replica_groups should not appear in JSON + assert "replica_groups" not in json_str + + def test_get_plan_request_serialization_with_replica_groups(self): + """GetRunPlanRequest should include replica_groups in JSON when set.""" + config = ServiceConfiguration( + type="service", + port=8000, + commands=["echo test"], + replica_groups=[ + { + "name": "gpu-group", + "replicas": "1", + "resources": {"gpu": {"name": "A100"}}, + } + ], + scaling={"metric": "rps", "target": 10}, + ) + + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data=LocalRunRepoData(repo_dir="/tmp"), + configuration=config, + profile=None, + ) + + request = GetRunPlanRequest(run_spec=run_spec, max_offers=None) + excludes = get_get_plan_excludes(request) + + # Serialize with excludes + json_str = request.json(exclude=excludes) + + # replica_groups SHOULD appear in JSON + assert "replica_groups" in json_str + assert "gpu-group" in json_str + diff --git a/src/tests/_internal/server/services/test_get_plan_replica_groups.py b/src/tests/_internal/server/services/test_get_plan_replica_groups.py new file mode 100644 index 000000000..8d322312d --- /dev/null +++ b/src/tests/_internal/server/services/test_get_plan_replica_groups.py @@ -0,0 +1,214 @@ +"""Test get_plan() offer fetching logic for replica groups.""" + +import pytest + +from dstack._internal.core.models.resources import ResourcesSpec +from dstack._internal.core.models.runs import Requirements + + +class TestGetPlanOfferFetchingLogic: + """Test the logic for determining when to fetch offers per-job vs. shared.""" + + def test_requirements_equality_check(self): + """Test that Requirements objects can be compared for equality.""" + # Identical requirements + req1 = Requirements( + resources=ResourcesSpec(gpu={"name": "A100", "count": 1}), + ) + req2 = Requirements( + resources=ResourcesSpec(gpu={"name": "A100", "count": 1}), + ) + assert req1 == req2 + + # Different GPU names + req3 = Requirements( + resources=ResourcesSpec(gpu={"name": "H100", "count": 1}), + ) + assert req1 != req3 + + # Different GPU counts + req4 = Requirements( + resources=ResourcesSpec(gpu={"name": "A100", "count": 2}), + ) + assert req1 != req4 + + def test_identical_requirements_detection_logic(self): + """Test logic for detecting when all jobs have identical requirements.""" + # Simulate job specs with requirements + class MockJobSpec: + def __init__(self, gpu_name: str, gpu_count: int = 1): + self.requirements = Requirements( + resources=ResourcesSpec(gpu={"name": gpu_name, "count": gpu_count}), + ) + + class MockJob: + def __init__(self, gpu_name: str, gpu_count: int = 1): + self.job_spec = MockJobSpec(gpu_name, gpu_count) + + # Test 1: All identical + jobs = [MockJob("A100"), MockJob("A100"), MockJob("A100")] + all_requirements_identical = all( + job.job_spec.requirements == jobs[0].job_spec.requirements for job in jobs + ) + assert all_requirements_identical is True + + # Test 2: Different GPU types + jobs = [MockJob("A100"), MockJob("H100"), MockJob("A100")] + all_requirements_identical = all( + job.job_spec.requirements == jobs[0].job_spec.requirements for job in jobs + ) + assert all_requirements_identical is False + + # Test 3: Different GPU counts + jobs = [MockJob("A100", 1), MockJob("A100", 2)] + all_requirements_identical = all( + job.job_spec.requirements == jobs[0].job_spec.requirements for job in jobs + ) + assert all_requirements_identical is False + + # Test 4: Single job (always identical) + jobs = [MockJob("V100")] + all_requirements_identical = all( + job.job_spec.requirements == jobs[0].job_spec.requirements for job in jobs + ) + assert all_requirements_identical is True + + def test_offer_fetch_decision_logic(self): + """Test the decision logic for when to use shared vs per-job offer fetching.""" + + class MockJobSpec: + def __init__(self, gpu_name: str): + self.requirements = Requirements( + resources=ResourcesSpec(gpu={"name": gpu_name, "count": 1}), + ) + + class MockJob: + def __init__(self, group_name: str, gpu_name: str): + self.job_spec = MockJobSpec(gpu_name) + self.job_spec.replica_group_name = group_name + + # Scenario 1: Replica groups with different GPUs -> per-job fetch + jobs = [ + MockJob("l40s-group", "L40S"), + MockJob("rtx4080-group", "RTX4080"), + ] + all_identical = all( + job.job_spec.requirements == jobs[0].job_spec.requirements for job in jobs + ) + assert ( + all_identical is False + ), "Different GPU types should trigger per-job offer fetch" + + # Scenario 2: Replica groups with same GPU -> shared fetch (optimization) + jobs = [ + MockJob("group-1", "A100"), + MockJob("group-2", "A100"), + ] + all_identical = all( + job.job_spec.requirements == jobs[0].job_spec.requirements for job in jobs + ) + assert all_identical is True, "Identical GPUs should use shared offer fetch" + + # Scenario 3: Legacy replicas (same requirements) -> shared fetch + jobs = [ + MockJob("default", "V100"), + MockJob("default", "V100"), + ] + all_identical = all( + job.job_spec.requirements == jobs[0].job_spec.requirements for job in jobs + ) + assert ( + all_identical is True + ), "Legacy replicas with same GPU should use shared fetch" + + # Scenario 4: Mixed groups (2 same + 1 different) -> per-job fetch + jobs = [ + MockJob("a100-group-1", "A100"), + MockJob("h100-group", "H100"), + MockJob("a100-group-2", "A100"), + ] + all_identical = all( + job.job_spec.requirements == jobs[0].job_spec.requirements for job in jobs + ) + assert ( + all_identical is False + ), "Mix of different GPUs should trigger per-job fetch for all" + + +class TestReplicaGroupOfferSearchIntegration: + """Integration tests for replica group offer search behavior.""" + + def test_different_gpu_types_creates_different_requirements(self): + """Different replica group GPU types should create different Requirements objects.""" + # This tests the data model behavior that get_plan() relies on + req_l40s = Requirements( + resources=ResourcesSpec( + gpu={"name": "L40S", "count": 1}, + ) + ) + + req_rtx4080 = Requirements( + resources=ResourcesSpec( + gpu={"name": "RTX4080", "count": 1}, + ) + ) + + # These should NOT be equal + assert req_l40s != req_rtx4080 + + # Verify the GPU names are different + assert req_l40s.resources.gpu.name != req_rtx4080.resources.gpu.name + + def test_identical_gpu_types_creates_identical_requirements(self): + """Identical replica group GPU types should create equal Requirements objects.""" + req_a = Requirements( + resources=ResourcesSpec( + gpu={"name": "A100", "count": 1}, + ) + ) + + req_b = Requirements( + resources=ResourcesSpec( + gpu={"name": "A100", "count": 1}, + ) + ) + + # These SHOULD be equal (enables optimization) + assert req_a == req_b + + def test_requirements_with_different_memory(self): + """Requirements with different GPU memory should not be equal.""" + req_16gb = Requirements( + resources=ResourcesSpec( + gpu={"name": "A100", "memory": "16GB", "count": 1}, + ) + ) + + req_40gb = Requirements( + resources=ResourcesSpec( + gpu={"name": "A100", "memory": "40GB", "count": 1}, + ) + ) + + # Different memory specifications + assert req_16gb != req_40gb + + def test_requirements_with_different_cpu_specs(self): + """Requirements with different CPU specs should not be equal.""" + req_low_cpu = Requirements( + resources=ResourcesSpec( + cpu={"min": 2}, + gpu={"name": "A100", "count": 1}, + ) + ) + + req_high_cpu = Requirements( + resources=ResourcesSpec( + cpu={"min": 16}, + gpu={"name": "A100", "count": 1}, + ) + ) + + # Different CPU requirements + assert req_low_cpu != req_high_cpu + diff --git a/src/tests/_internal/server/services/test_replica_groups_scaling.py b/src/tests/_internal/server/services/test_replica_groups_scaling.py new file mode 100644 index 000000000..de40e5fad --- /dev/null +++ b/src/tests/_internal/server/services/test_replica_groups_scaling.py @@ -0,0 +1,389 @@ +"""Integration tests for replica groups scaling functionality""" +from typing import List + +import pytest +from pydantic import parse_obj_as +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from dstack._internal.core.models.configurations import ReplicaGroup, ScalingSpec, ServiceConfiguration +from dstack._internal.core.models.profiles import Profile +from dstack._internal.core.models.resources import GPUSpec, Range, ResourcesSpec +from dstack._internal.core.models.runs import JobStatus, JobTerminationReason +from dstack._internal.server.models import JobModel, RunModel +from dstack._internal.server.services.runs import scale_run_replicas +from dstack._internal.server.testing.common import ( + create_job, + create_project, + create_repo, + create_run, + create_user, + get_run_spec, +) + +pytestmark = pytest.mark.usefixtures("image_config_mock") + + +async def scale_wrapper(session: AsyncSession, run: RunModel, diff: int): + """Wrapper that handles commit and refresh like existing tests""" + await scale_run_replicas(session, run, diff) + await session.commit() + await session.refresh(run) + + +async def make_run_with_groups( + session: AsyncSession, + groups_config: List[dict], # List of {name, replicas_range, gpu, initial_jobs} +) -> RunModel: + """Helper to create a run with replica groups""" + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + + # Build replica groups + replica_groups = [] + for group_cfg in groups_config: + replica_groups.append( + ReplicaGroup( + name=group_cfg["name"], + replicas=parse_obj_as(Range[int], group_cfg["replicas_range"]), + resources=ResourcesSpec( + gpu=GPUSpec(name=[group_cfg["gpu"]], count=1) + ), + ) + ) + + profile = Profile(name="test-profile") + run_spec = get_run_spec( + repo_id=repo.name, + run_name="test-run", + profile=profile, + configuration=ServiceConfiguration( + commands=["python app.py"], + port=8000, + replica_groups=replica_groups, + scaling=ScalingSpec(metric="rps", target=10), + ), + ) + + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="test-run", + run_spec=run_spec, + ) + + # Create initial jobs + replica_num = 0 + for group_cfg in groups_config: + for job_status in group_cfg.get("initial_jobs", []): + job = await create_job( + session=session, + run=run, + status=job_status, + replica_num=replica_num, + replica_group_name=group_cfg["name"], + ) + run.jobs.append(job) + replica_num += 1 + + await session.commit() + + # Reload with jobs and project + res = await session.execute( + select(RunModel) + .where(RunModel.id == run.id) + .options(selectinload(RunModel.jobs), selectinload(RunModel.project)) + ) + return res.scalar_one() + + +class TestReplicaGroupsScaleDown: + """Test scaling down with replica groups""" + + @pytest.mark.asyncio + async def test_scale_down_only_from_autoscalable_groups(self, session: AsyncSession): + """Test that scale down only affects autoscalable groups""" + run = await make_run_with_groups( + session, + [ + { + "name": "fixed-h100", + "replicas_range": "1..1", # Fixed + "gpu": "H100", + "initial_jobs": [JobStatus.RUNNING], + }, + { + "name": "scalable-rtx", + "replicas_range": "1..3", # Autoscalable + "gpu": "RTX5090", + "initial_jobs": [JobStatus.RUNNING, JobStatus.RUNNING], + }, + ], + ) + + # Scale down by 1 (should only affect scalable group) + await scale_wrapper(session, run, -1) + + # Check: fixed group should still have 1 running job + fixed_jobs = [j for j in run.jobs if j.replica_group_name == "fixed-h100"] + assert len(fixed_jobs) == 1 + assert fixed_jobs[0].status == JobStatus.RUNNING + + # Check: scalable group should have 1 terminated, 1 running + scalable_jobs = [j for j in run.jobs if j.replica_group_name == "scalable-rtx"] + assert len(scalable_jobs) == 2 + terminating = [j for j in scalable_jobs if j.status == JobStatus.TERMINATING] + assert len(terminating) == 1 + assert terminating[0].termination_reason == JobTerminationReason.SCALED_DOWN + + @pytest.mark.asyncio + async def test_scale_down_respects_group_minimums(self, session: AsyncSession): + """Test that scale down respects each group's minimum""" + run = await make_run_with_groups( + session, + [ + { + "name": "group-a", + "replicas_range": "1..3", # Min=1 + "gpu": "H100", + "initial_jobs": [JobStatus.RUNNING], # At minimum + }, + { + "name": "group-b", + "replicas_range": "2..5", # Min=2 + "gpu": "RTX5090", + "initial_jobs": [JobStatus.RUNNING, JobStatus.RUNNING, JobStatus.RUNNING], + }, + ], + ) + + # Try to scale down by 2 + await scale_wrapper(session, run, -2) + + # Group A should still have 1 (at minimum) + group_a_jobs = [j for j in run.jobs if j.replica_group_name == "group-a"] + assert len([j for j in group_a_jobs if j.status == JobStatus.RUNNING]) == 1 + + # Group B should have terminated 1 (3 -> 2, which is minimum) + group_b_jobs = [j for j in run.jobs if j.replica_group_name == "group-b"] + terminating = [j for j in group_b_jobs if j.status == JobStatus.TERMINATING] + assert len(terminating) == 1 + + @pytest.mark.asyncio + async def test_scale_down_all_groups_fixed(self, session: AsyncSession): + """Test scaling down when all groups are fixed (should not terminate anything)""" + run = await make_run_with_groups( + session, + [ + { + "name": "fixed-1", + "replicas_range": "1..1", + "gpu": "H100", + "initial_jobs": [JobStatus.RUNNING], + }, + { + "name": "fixed-2", + "replicas_range": "2..2", + "gpu": "RTX5090", + "initial_jobs": [JobStatus.RUNNING, JobStatus.RUNNING], + }, + ], + ) + + initial_count = len(run.jobs) + + # Try to scale down + await scale_wrapper(session, run, -1) + + # No jobs should be terminated (all groups are fixed) + assert len(run.jobs) == initial_count + assert all(j.status == JobStatus.RUNNING for j in run.jobs) + + +class TestReplicaGroupsScaleUp: + """Test scaling up with replica groups""" + + @pytest.mark.asyncio + async def test_scale_up_selects_autoscalable_group(self, session: AsyncSession): + """Test that scale up only creates jobs in autoscalable groups""" + run = await make_run_with_groups( + session, + [ + { + "name": "fixed-h100", + "replicas_range": "1..1", # Fixed + "gpu": "H100", + "initial_jobs": [JobStatus.RUNNING], + }, + { + "name": "scalable-rtx", + "replicas_range": "1..3", # Autoscalable + "gpu": "RTX5090", + "initial_jobs": [JobStatus.RUNNING], + }, + ], + ) + + initial_count = len(run.jobs) + + # Scale up by 1 + await scale_wrapper(session, run, 1) + + # Should have one more job + assert len(run.jobs) == initial_count + 1 + + # New job should be in scalable group + new_jobs = [j for j in run.jobs if j.replica_num == initial_count] + assert len(new_jobs) == 1 + assert new_jobs[0].replica_group_name == "scalable-rtx" + assert new_jobs[0].status == JobStatus.SUBMITTED + + @pytest.mark.asyncio + async def test_scale_up_respects_group_maximums(self, session: AsyncSession): + """Test that scale up respects group maximums""" + run = await make_run_with_groups( + session, + [ + { + "name": "small-group", + "replicas_range": "1..2", # Max=2 + "gpu": "H100", + "initial_jobs": [JobStatus.RUNNING, JobStatus.RUNNING], # At max + }, + { + "name": "large-group", + "replicas_range": "1..5", # Max=5 + "gpu": "RTX5090", + "initial_jobs": [JobStatus.RUNNING], + }, + ], + ) + + # Try to scale up by 2 + await scale_wrapper(session, run, 2) + + # Small group should still have 2 (at max) + small_jobs = [j for j in run.jobs if j.replica_group_name == "small-group"] + assert len(small_jobs) == 2 + + # Large group should have grown by 2 + large_jobs = [j for j in run.jobs if j.replica_group_name == "large-group"] + assert len(large_jobs) == 3 + + @pytest.mark.asyncio + async def test_scale_up_no_autoscalable_groups(self, session: AsyncSession): + """Test scale up does nothing when no autoscalable groups exist""" + run = await make_run_with_groups( + session, + [ + { + "name": "fixed-1", + "replicas_range": "1..1", + "gpu": "H100", + "initial_jobs": [JobStatus.RUNNING], + }, + { + "name": "fixed-2", + "replicas_range": "2..2", + "gpu": "RTX5090", + "initial_jobs": [JobStatus.RUNNING, JobStatus.RUNNING], + }, + ], + ) + + initial_count = len(run.jobs) + + # Try to scale up + await scale_wrapper(session, run, 2) + + # Should not have added any jobs + assert len(run.jobs) == initial_count + + @pytest.mark.asyncio + async def test_scale_up_all_groups_at_max(self, session: AsyncSession): + """Test scale up when all autoscalable groups are at maximum""" + run = await make_run_with_groups( + session, + [ + { + "name": "group-a", + "replicas_range": "1..2", + "gpu": "H100", + "initial_jobs": [JobStatus.RUNNING, JobStatus.RUNNING], # At max + }, + { + "name": "group-b", + "replicas_range": "1..3", + "gpu": "RTX5090", + "initial_jobs": [JobStatus.RUNNING, JobStatus.RUNNING, JobStatus.RUNNING], # At max + }, + ], + ) + + initial_count = len(run.jobs) + + # Try to scale up + await scale_wrapper(session, run, 1) + + # Should not have added any jobs (all at max) + assert len(run.jobs) == initial_count + + +class TestReplicaGroupsBackwardCompatibility: + """Test backward compatibility with legacy configs""" + + @pytest.mark.asyncio + async def test_legacy_config_scaling(self, session: AsyncSession): + """Test scaling works with legacy replicas configuration""" + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + + # Use legacy format (no replica_groups) + profile = Profile(name="test-profile") + run_spec = get_run_spec( + repo_id=repo.name, + run_name="test-run", + profile=profile, + configuration=ServiceConfiguration( + commands=["python app.py"], + port=8000, + replicas=parse_obj_as(Range[int], "1..3"), # Legacy format + scaling=ScalingSpec(metric="rps", target=10), + ), + ) + + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="test-run", + run_spec=run_spec, + ) + + # Add initial job (no group name) + job = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + replica_num=0, + replica_group_name=None, # Legacy jobs have no group + ) + run.jobs.append(job) + await session.commit() + + # Scale up should work + await scale_wrapper(session, run, 1) + + # Should have 2 jobs now + assert len(run.jobs) == 2 + + # New job should have "default" group name or None + new_job = [j for j in run.jobs if j.replica_num == 1][0] + assert new_job.replica_group_name in [None, "default"] + From 5192caf2eba44b5e157da6ba5698ac9b2b031077 Mon Sep 17 00:00:00 2001 From: "Alexander Nicholson 4584443+DragonStuff@users.noreply.github.com" <4584443+DragonStuff@users.noreply.github.com> Date: Sat, 18 Oct 2025 13:43:32 +0900 Subject: [PATCH 2/5] chore: run linter and fix failing tests --- src/dstack/_internal/cli/utils/run.py | 38 ++++---- .../_internal/core/models/configurations.py | 14 +-- .../server/background/tasks/process_runs.py | 2 + src/dstack/_internal/server/services/runs.py | 94 +++++++++++++------ .../server/services/services/autoscalers.py | 2 +- .../cli/utils/test_run_plan_display.py | 13 +-- .../core/models/test_replica_groups.py | 66 ++++++------- .../core/test_backward_compatibility.py | 29 +++--- .../_internal/server/routers/test_runs.py | 2 + .../services/test_get_plan_replica_groups.py | 1 - .../services/test_replica_groups_scaling.py | 76 ++++++++------- 11 files changed, 190 insertions(+), 147 deletions(-) diff --git a/src/dstack/_internal/cli/utils/run.py b/src/dstack/_internal/cli/utils/run.py index ffcbef4f5..c0cbaf815 100644 --- a/src/dstack/_internal/cli/utils/run.py +++ b/src/dstack/_internal/cli/utils/run.py @@ -119,9 +119,9 @@ def th(s: str) -> str: if include_run_properties: props.add_row(th("Configuration"), run_spec.configuration_path) props.add_row(th("Type"), run_spec.configuration.type) - + from dstack._internal.core.models.configurations import ServiceConfiguration - + if ( include_run_properties and isinstance(run_spec.configuration, ServiceConfiguration) @@ -130,21 +130,21 @@ def th(s: str) -> str: groups_info = [] for group in run_spec.configuration.replica_groups: group_parts = [f"[cyan]{group.name}[/cyan]"] - + if group.replicas.min == group.replicas.max: group_parts.append(f"×{group.replicas.max}") else: group_parts.append(f"×{group.replicas.min}..{group.replicas.max}") group_parts.append("[dim](autoscalable)[/dim]") - + group_parts.append(f"[dim]({group.resources.pretty_format()})[/dim]") - + groups_info.append(" ".join(group_parts)) - + props.add_row(th("Replica groups"), "\n".join(groups_info)) else: props.add_row(th("Resources"), pretty_req) - + props.add_row(th("Spot policy"), spot_policy) props.add_row(th("Max price"), max_price) if include_run_properties: @@ -163,14 +163,14 @@ def th(s: str) -> str: offers.add_column("INSTANCE TYPE", style="grey58", no_wrap=True, ratio=2) offers.add_column("PRICE", style="grey58", ratio=1) offers.add_column() - + # For replica groups, show offers from all job plans if len(run_plan.job_plans) > 1: # Multiple jobs - aggregate offers from all groups all_offers = [] groups_with_no_offers = [] total_offers_count = 0 - + for jp in run_plan.job_plans: group_name = jp.job_spec.replica_group_name or "default" if jp.total_offers == 0: @@ -178,12 +178,12 @@ def th(s: str) -> str: for offer in jp.offers[:max_offers] if max_offers else jp.offers: all_offers.append((group_name, offer)) total_offers_count += jp.total_offers - + # Sort by price all_offers.sort(key=lambda x: x[1].price) if max_offers: all_offers = all_offers[:max_offers] - + # Show groups with no offers FIRST for group_name in groups_with_no_offers: offers.add_row( @@ -196,7 +196,7 @@ def th(s: str) -> str: "", style="secondary", ) - + # Then show groups with offers for i, (group_name, offer) in enumerate(all_offers, start=1): r = offer.instance.resources @@ -212,10 +212,10 @@ def th(s: str) -> str: instance = offer.instance.name if offer.total_blocks > 1: instance += f" ({offer.blocks}/{offer.total_blocks})" - + # Add group name prefix for multi-group display backend_display = f"[cyan]{group_name}[/cyan]: {offer.backend.replace('remote', 'ssh')} ({offer.region})" - + offers.add_row( f"{i}", backend_display, @@ -225,7 +225,7 @@ def th(s: str) -> str: availability, style=None if i == 1 or not include_run_properties else "secondary", ) - + if total_offers_count > len(all_offers): offers.add_row("", "...", style="secondary") else: @@ -260,14 +260,14 @@ def th(s: str) -> str: console.print(props) console.print() - + # Check if we have offers to display has_offers = False if len(run_plan.job_plans) > 1: has_offers = any(len(jp.offers) > 0 for jp in run_plan.job_plans) else: has_offers = len(job_plan.offers) > 0 - + if has_offers: console.print(offers) # Show summary for multi-job plans @@ -343,12 +343,12 @@ def get_runs_table( if verbose and latest_job_submission.inactivity_secs: inactive_for = format_duration_multiunit(latest_job_submission.inactivity_secs) status += f" (inactive for {inactive_for})" - + job_name_parts = [f" replica={job.job_spec.replica_num}"] if job.job_spec.replica_group_name: job_name_parts.append(f"[cyan]group={job.job_spec.replica_group_name}[/cyan]") job_name_parts.append(f"job={job.job_spec.job_num}") - + job_row: Dict[Union[str, int], Any] = { "NAME": " ".join(job_name_parts) + ( diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 30e060207..cfdd8cfe1 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -846,24 +846,24 @@ def validate_gateway( def validate_replica_groups_xor_replicas(cls, values): replica_groups = values.get("replica_groups") replicas = values.get("replicas") - + # Check if user specified both has_groups = replica_groups is not None has_replicas = replicas != Range[int](min=1, max=1) - + if has_groups and has_replicas: raise ValueError("Cannot specify both 'replicas' and 'replica_groups'") - + if has_groups: # Validate unique names names = [g.name for g in replica_groups] if len(names) != len(set(names)): raise ValueError("Replica group names must be unique") - + # Validate at least one group if not replica_groups: raise ValueError("replica_groups cannot be empty") - + return values @root_validator() @@ -871,7 +871,7 @@ def validate_scaling(cls, values): scaling = values.get("scaling") replicas = values.get("replicas") replica_groups = values.get("replica_groups") - + if replica_groups: # Check if any group has a range has_range = any(g.replicas.min != g.replicas.max for g in replica_groups) @@ -883,7 +883,7 @@ def validate_scaling(cls, values): raise ValueError("When you set `replicas` to a range, ensure to specify `scaling`.") elif replicas and replicas.min == replicas.max and scaling: raise ValueError("To use `scaling`, `replicas` must be set to a range.") - + return values @validator("rate_limits") diff --git a/src/dstack/_internal/server/background/tasks/process_runs.py b/src/dstack/_internal/server/background/tasks/process_runs.py index 1bd30fb61..9a879c0fc 100644 --- a/src/dstack/_internal/server/background/tasks/process_runs.py +++ b/src/dstack/_internal/server/background/tasks/process_runs.py @@ -484,6 +484,7 @@ async def _handle_run_replicas( session, run_model, replicas_diff=max_replica_count - non_terminated_replica_count, + allow_exceeding_max=True, # Allow exceeding max for rolling deployments ) replicas_to_stop_count = 0 @@ -510,6 +511,7 @@ async def _handle_run_replicas( session, run_model, replicas_diff=-replicas_to_stop_count, + allow_exceeding_max=True, # Allow terminating out-of-date replicas during rolling deployment ) diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index f20d42ada..33834cbb9 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -340,7 +340,7 @@ async def get_plan( action = ApplyAction.UPDATE secrets = await get_project_secrets_mapping(session=session, project=project) - + # For services with replica groups, create jobs for all groups during planning jobs = [] if ( @@ -407,12 +407,12 @@ async def get_plan( job_plans = [] for job in jobs: job_offers: List[InstanceOfferWithAvailability] = [] - + # Filter pool offers to match this job's GPU requirements gpu_req = None if job.job_spec.requirements.resources and job.job_spec.requirements.resources.gpu: gpu_req = job.job_spec.requirements.resources.gpu.name - + matching_pool_offers = [] for pool_offer in pool_offers: offer_gpus = pool_offer.instance.resources.gpus @@ -424,9 +424,9 @@ async def get_plan( elif not gpu_req: # No GPU requirement, include all pool offers matching_pool_offers.append(pool_offer) - + job_offers.extend(matching_pool_offers) - + # Use shared offers if all jobs are identical, otherwise fetch per-job if shared_offers: job_offers.extend(offer for _, offer in shared_offers) @@ -443,7 +443,7 @@ async def get_plan( instance_mounts=check_run_spec_requires_instance_mounts(effective_run_spec), ) job_offers.extend(offer for _, offer in job_specific_offers) - + job_offers.sort(key=lambda offer: not offer.availability.is_available()) job_spec = job.job_spec @@ -1287,7 +1287,21 @@ async def process_terminating_run(session: AsyncSession, run_model: RunModel): ) -async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replicas_diff: int): +async def scale_run_replicas( + session: AsyncSession, + run_model: RunModel, + replicas_diff: int, + allow_exceeding_max: bool = False, +): + """ + Scale run replicas up or down. + + Args: + session: Database session + run_model: The run to scale + replicas_diff: Number of replicas to add (positive) or remove (negative) + allow_exceeding_max: If True, allow scaling beyond configured max (for rolling deployments) + """ if replicas_diff == 0: # nothing to do return @@ -1349,9 +1363,10 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica # Get group minimums group_mins = {g.name: g.replicas.min for g in normalized_groups} - # Terminate from end (reversed), but skip if group not autoscalable or at minimum + # Terminate from end (reversed) + # For rolling deployments (allow_exceeding_max), prioritize terminating out-of-date replicas terminated_count = 0 - for _, _, _, replica_jobs in reversed(active_replicas): + for _, is_out_of_date, _, replica_jobs in reversed(active_replicas): if terminated_count >= abs(replicas_diff): break @@ -1360,7 +1375,19 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica group_name = replica_jobs[0].replica_group_name or "default" - # Skip if not autoscalable + # For rolling deployment, allow terminating any out-of-date replica + if allow_exceeding_max and is_out_of_date: + # Terminate this replica (out-of-date during rolling deployment) + for job in replica_jobs: + if not job.status.is_finished() and job.status != JobStatus.TERMINATING: + job.status = JobStatus.TERMINATING + job.termination_reason = JobTerminationReason.SCALED_DOWN + + group_counts[group_name] -= 1 + terminated_count += 1 + continue + + # For normal scaling, skip if not autoscalable if normalized_groups and group_name not in autoscalable_groups: continue @@ -1379,29 +1406,42 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica group_counts[group_name] -= 1 terminated_count += 1 else: - # SCALE UP: Choose from autoscalable groups - autoscalable_groups = [g for g in normalized_groups if g.replicas.min != g.replicas.max] - - if normalized_groups and not autoscalable_groups: - # No autoscalable groups, cannot scale - logger.info("%s: no autoscalable groups available for scaling up", fmt(run_model)) - return - - # Count current replicas per group to respect maximums + # SCALE UP + # Count current replicas per group group_counts = {} for _, _, _, replica_jobs in active_replicas: if replica_jobs: group_name = replica_jobs[0].replica_group_name or "default" group_counts[group_name] = group_counts.get(group_name, 0) + 1 - # Filter groups that haven't reached maximum - eligible_groups = [ - g for g in autoscalable_groups if group_counts.get(g.name, 0) < (g.replicas.max or float("inf")) - ] if normalized_groups else normalized_groups + # First, identify groups below minimum (need to scale regardless of autoscalability) + below_min_groups = [ + g for g in normalized_groups + if group_counts.get(g.name, 0) < (g.replicas.min or 0) + ] + + # Then, identify autoscalable groups that can scale beyond minimum + autoscalable_groups = [ + g for g in normalized_groups + if g.replicas.min != g.replicas.max and ( + allow_exceeding_max or group_counts.get(g.name, 0) < (g.replicas.max or float("inf")) + ) + ] + + # Eligible groups are: below-min groups + autoscalable groups + eligible_groups = [] + if below_min_groups: + eligible_groups.extend(below_min_groups) + elif autoscalable_groups: + # Only use autoscalable groups if no groups are below minimum + eligible_groups.extend(autoscalable_groups) + elif allow_exceeding_max and normalized_groups: + # For rolling deployments, allow exceeding max even for fixed groups + eligible_groups.extend(normalized_groups) if normalized_groups and not eligible_groups: - # All groups at maximum - logger.info("%s: all autoscalable groups at maximum capacity", fmt(run_model)) + # All groups at their limits + logger.info("%s: all replica groups at their limits (min/max)", fmt(run_model)) return scheduled_replicas = 0 @@ -1410,10 +1450,10 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica for _, _, _, replica_jobs in inactive_replicas: if scheduled_replicas == replicas_diff: break - # Only reuse if from autoscalable group + # Only reuse if from eligible group if replica_jobs: group_name = replica_jobs[0].replica_group_name or "default" - if not normalized_groups or group_name in {g.name for g in autoscalable_groups}: + if not normalized_groups or group_name in {g.name for g in eligible_groups}: await retry_run_replica_jobs(session, run_model, replica_jobs, only_failed=False) scheduled_replicas += 1 diff --git a/src/dstack/_internal/server/services/services/autoscalers.py b/src/dstack/_internal/server/services/services/autoscalers.py index 379d2e920..93744bd2e 100644 --- a/src/dstack/_internal/server/services/services/autoscalers.py +++ b/src/dstack/_internal/server/services/services/autoscalers.py @@ -131,7 +131,7 @@ def get_service_scaler(conf: ServiceConfiguration) -> BaseServiceScaler: assert conf.replicas.max is not None min_replicas = conf.replicas.min max_replicas = conf.replicas.max - + if conf.scaling is None: return ManualScaler( min_replicas=min_replicas, diff --git a/src/tests/_internal/cli/utils/test_run_plan_display.py b/src/tests/_internal/cli/utils/test_run_plan_display.py index 7719b6e4e..f3e5babf8 100644 --- a/src/tests/_internal/cli/utils/test_run_plan_display.py +++ b/src/tests/_internal/cli/utils/test_run_plan_display.py @@ -1,9 +1,6 @@ """Test CLI display of run plans with replica groups.""" -from io import StringIO -from unittest.mock import MagicMock, patch -import pytest from dstack._internal.cli.utils.run import print_run_plan from dstack._internal.core.models.backends.base import BackendType @@ -16,7 +13,7 @@ ) from dstack._internal.core.models.profiles import Profile from dstack._internal.core.models.repos import LocalRunRepoData -from dstack._internal.core.models.resources import ResourcesSpec, Range +from dstack._internal.core.models.resources import Range, ResourcesSpec from dstack._internal.core.models.runs import ( ApplyAction, InstanceOfferWithAvailability, @@ -304,13 +301,13 @@ def test_replica_groups_offers_sorted_by_price(self, capsys): # Split output to find the offers table (after the header section) lines = output.split("\n") - + # Find lines that contain both a number and a group name (these are offer rows) offer_rows = [ - line for line in lines + line for line in lines if ("cheap-group:" in line or "expensive-group:" in line) and line.strip().startswith(("1", "2", "3")) ] - + # The first offer row should be cheap-group (lower price) assert len(offer_rows) >= 2, "Should have at least 2 offer rows" assert "cheap-group:" in offer_rows[0], "First offer should be cheap-group (sorted by price)" @@ -416,7 +413,7 @@ def test_replica_group_with_no_offers_shows_message(self, capsys): assert "offers available" in output assert "Possible reasons:" in output assert "dstack.ai/docs" in output # URL may be truncated in table - + # Verify unavailable group appears BEFORE available group (at top) unavailable_pos = output.find("unavailable-group:") available_pos = output.find("available-group:") diff --git a/src/tests/_internal/core/models/test_replica_groups.py b/src/tests/_internal/core/models/test_replica_groups.py index dc8de87cd..2b38a17b0 100644 --- a/src/tests/_internal/core/models/test_replica_groups.py +++ b/src/tests/_internal/core/models/test_replica_groups.py @@ -6,7 +6,7 @@ ServiceConfiguration, parse_run_configuration, ) -from dstack._internal.core.models.resources import CPUSpec, GPUSpec, Range, ResourcesSpec +from dstack._internal.core.models.resources import Range from dstack._internal.core.models.runs import get_normalized_replica_groups @@ -32,17 +32,17 @@ def test_basic_replica_groups(self): }, ], } - + parsed = parse_run_configuration(conf) assert isinstance(parsed, ServiceConfiguration) assert parsed.replica_groups is not None assert len(parsed.replica_groups) == 2 - + # Check first group assert parsed.replica_groups[0].name == "h100-group" assert parsed.replica_groups[0].replicas == Range(min=1, max=1) assert parsed.replica_groups[0].resources.gpu.name == ["H100"] - + # Check second group assert parsed.replica_groups[1].name == "rtx5090-group" assert parsed.replica_groups[1].replicas == Range(min=2, max=2) @@ -71,14 +71,14 @@ def test_replica_groups_with_ranges(self): "target": 10, }, } - + parsed = parse_run_configuration(conf) assert parsed.replica_groups is not None assert len(parsed.replica_groups) == 2 - + # Fixed group assert parsed.replica_groups[0].replicas == Range(min=1, max=1) - + # Scalable group assert parsed.replica_groups[1].replicas == Range(min=1, max=3) @@ -108,13 +108,13 @@ def test_replica_groups_with_profile_params(self): }, ], } - + parsed = parse_run_configuration(conf) - + # First group inherits from service (doesn't specify backends/regions) assert parsed.replica_groups[0].backends is None assert parsed.replica_groups[0].regions is None - + # Second group overrides assert parsed.replica_groups[1].backends == ["runpod"] assert parsed.replica_groups[1].regions == ["eu-west-1"] @@ -134,7 +134,7 @@ def test_replica_groups_xor_replicas(self): } ], } - + with pytest.raises( ConfigurationError, match="Cannot specify both 'replicas' and 'replica_groups'", @@ -160,7 +160,7 @@ def test_replica_groups_unique_names(self): }, ], } - + with pytest.raises( ConfigurationError, match="Replica group names must be unique", @@ -181,7 +181,7 @@ def test_replica_groups_empty_name(self): } ], } - + with pytest.raises( ConfigurationError, match="Group name cannot be empty", @@ -203,7 +203,7 @@ def test_replica_groups_range_requires_scaling(self): ], # Missing scaling! } - + with pytest.raises( ConfigurationError, match="When any replica group has a range, 'scaling' must be specified", @@ -218,7 +218,7 @@ def test_replica_groups_cannot_be_empty(self): "port": 8000, "replica_groups": [], # Empty list } - + with pytest.raises( ConfigurationError, match="replica_groups cannot be empty", @@ -248,10 +248,10 @@ def test_normalize_new_format(self): }, ], } - + parsed = parse_run_configuration(conf) normalized = get_normalized_replica_groups(parsed) - + assert len(normalized) == 2 assert normalized[0].name == "group1" assert normalized[1].name == "group2" @@ -267,16 +267,16 @@ def test_normalize_legacy_format(self): "backends": ["aws"], "regions": ["us-west-2"], } - + parsed = parse_run_configuration(conf) normalized = get_normalized_replica_groups(parsed) - + # Should create single "default" group assert len(normalized) == 1 assert normalized[0].name == "default" assert normalized[0].replicas == Range(min=3, max=3) assert normalized[0].resources.gpu.name == ["H100"] - + # Should inherit profile params assert normalized[0].backends == ["aws"] assert normalized[0].regions == ["us-west-2"] @@ -294,10 +294,10 @@ def test_normalize_legacy_with_range(self): "target": 10, }, } - + parsed = parse_run_configuration(conf) normalized = get_normalized_replica_groups(parsed) - + assert len(normalized) == 1 assert normalized[0].name == "default" assert normalized[0].replicas == Range(min=1, max=5) @@ -329,12 +329,12 @@ def test_autoscalable_group_detection(self): "target": 10, }, } - + parsed = parse_run_configuration(conf) - + # Fixed group: min == max assert parsed.replica_groups[0].replicas.min == parsed.replica_groups[0].replicas.max - + # Scalable group: min != max assert parsed.replica_groups[1].replicas.min != parsed.replica_groups[1].replicas.max @@ -361,9 +361,9 @@ def test_multiple_autoscalable_groups(self): "target": 10, }, } - + parsed = parse_run_configuration(conf) - + # Both are autoscalable assert parsed.replica_groups[0].replicas.min != parsed.replica_groups[0].replicas.max assert parsed.replica_groups[1].replicas.min != parsed.replica_groups[1].replicas.max @@ -381,9 +381,9 @@ def test_legacy_service_config(self): "replicas": 2, "resources": {"gpu": "A100:1"}, } - + parsed = parse_run_configuration(conf) - + # Should parse successfully assert isinstance(parsed, ServiceConfiguration) assert parsed.replicas == Range(min=2, max=2) @@ -402,9 +402,9 @@ def test_legacy_autoscaling_config(self): "target": 10, }, } - + parsed = parse_run_configuration(conf) - + # Should parse successfully assert parsed.replicas == Range(min=0, max=5) assert parsed.scaling is not None @@ -423,10 +423,10 @@ def test_normalization_preserves_all_profile_params(self): "spot_policy": "spot", "max_price": 10.0, } - + parsed = parse_run_configuration(conf) normalized = get_normalized_replica_groups(parsed) - + # Check all fields are copied group = normalized[0] assert group.backends == ["aws"] diff --git a/src/tests/_internal/core/test_backward_compatibility.py b/src/tests/_internal/core/test_backward_compatibility.py index 6bb2ec8b3..e5c6f985d 100644 --- a/src/tests/_internal/core/test_backward_compatibility.py +++ b/src/tests/_internal/core/test_backward_compatibility.py @@ -1,6 +1,5 @@ """Test backward compatibility for replica_groups with older servers.""" -import pytest from dstack._internal.core.compatibility.runs import get_get_plan_excludes, get_run_spec_excludes from dstack._internal.core.models.configurations import ServiceConfiguration @@ -20,7 +19,7 @@ def test_replica_groups_excluded_when_none(self): commands=["echo test"], replicas={"min": 1, "max": 1}, ) - + run_spec = RunSpec( run_name="test-run", repo_id="test-repo", @@ -28,10 +27,10 @@ def test_replica_groups_excluded_when_none(self): configuration=config, profile=None, ) - + # Get excludes excludes = get_run_spec_excludes(run_spec) - + # replica_groups should be in excludes assert "configuration" in excludes assert "replica_groups" in excludes["configuration"] @@ -52,7 +51,7 @@ def test_replica_groups_not_excluded_when_set(self): ], scaling={"metric": "rps", "target": 10}, ) - + run_spec = RunSpec( run_name="test-run", repo_id="test-repo", @@ -60,10 +59,10 @@ def test_replica_groups_not_excluded_when_set(self): configuration=config, profile=None, ) - + # Get excludes excludes = get_run_spec_excludes(run_spec) - + # replica_groups should NOT be in excludes (or be False) if "configuration" in excludes and "replica_groups" in excludes["configuration"]: assert excludes["configuration"]["replica_groups"] is not True @@ -76,7 +75,7 @@ def test_get_plan_request_serialization_without_replica_groups(self): commands=["echo test"], replicas={"min": 1, "max": 1}, ) - + run_spec = RunSpec( run_name="test-run", repo_id="test-repo", @@ -84,13 +83,13 @@ def test_get_plan_request_serialization_without_replica_groups(self): configuration=config, profile=None, ) - + request = GetRunPlanRequest(run_spec=run_spec, max_offers=None) excludes = get_get_plan_excludes(request) - + # Serialize with excludes json_str = request.json(exclude=excludes) - + # replica_groups should not appear in JSON assert "replica_groups" not in json_str @@ -109,7 +108,7 @@ def test_get_plan_request_serialization_with_replica_groups(self): ], scaling={"metric": "rps", "target": 10}, ) - + run_spec = RunSpec( run_name="test-run", repo_id="test-repo", @@ -117,13 +116,13 @@ def test_get_plan_request_serialization_with_replica_groups(self): configuration=config, profile=None, ) - + request = GetRunPlanRequest(run_spec=run_spec, max_offers=None) excludes = get_get_plan_excludes(request) - + # Serialize with excludes json_str = request.json(exclude=excludes) - + # replica_groups SHOULD appear in JSON assert "replica_groups" in json_str assert "gpu-group" in json_str diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index f4e481f53..42758111c 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -233,6 +233,7 @@ def get_dev_env_run_plan_dict( "privileged": True if docker else privileged, "job_name": f"{run_name}-0-0", "replica_num": 0, + "replica_group_name": None, "job_num": 0, "jobs_per_replica": 1, "single_branch": False, @@ -441,6 +442,7 @@ def get_dev_env_run_dict( "privileged": True if docker else privileged, "job_name": f"{run_name}-0-0", "replica_num": 0, + "replica_group_name": None, "job_num": 0, "jobs_per_replica": 1, "single_branch": False, diff --git a/src/tests/_internal/server/services/test_get_plan_replica_groups.py b/src/tests/_internal/server/services/test_get_plan_replica_groups.py index 8d322312d..b45a9c42b 100644 --- a/src/tests/_internal/server/services/test_get_plan_replica_groups.py +++ b/src/tests/_internal/server/services/test_get_plan_replica_groups.py @@ -1,6 +1,5 @@ """Test get_plan() offer fetching logic for replica groups.""" -import pytest from dstack._internal.core.models.resources import ResourcesSpec from dstack._internal.core.models.runs import Requirements diff --git a/src/tests/_internal/server/services/test_replica_groups_scaling.py b/src/tests/_internal/server/services/test_replica_groups_scaling.py index de40e5fad..44a9e7084 100644 --- a/src/tests/_internal/server/services/test_replica_groups_scaling.py +++ b/src/tests/_internal/server/services/test_replica_groups_scaling.py @@ -7,11 +7,15 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload -from dstack._internal.core.models.configurations import ReplicaGroup, ScalingSpec, ServiceConfiguration +from dstack._internal.core.models.configurations import ( + ReplicaGroup, + ScalingSpec, + ServiceConfiguration, +) from dstack._internal.core.models.profiles import Profile from dstack._internal.core.models.resources import GPUSpec, Range, ResourcesSpec from dstack._internal.core.models.runs import JobStatus, JobTerminationReason -from dstack._internal.server.models import JobModel, RunModel +from dstack._internal.server.models import RunModel from dstack._internal.server.services.runs import scale_run_replicas from dstack._internal.server.testing.common import ( create_job, @@ -40,7 +44,7 @@ async def make_run_with_groups( project = await create_project(session=session) user = await create_user(session=session) repo = await create_repo(session=session, project_id=project.id) - + # Build replica groups replica_groups = [] for group_cfg in groups_config: @@ -53,7 +57,7 @@ async def make_run_with_groups( ), ) ) - + profile = Profile(name="test-profile") run_spec = get_run_spec( repo_id=repo.name, @@ -66,7 +70,7 @@ async def make_run_with_groups( scaling=ScalingSpec(metric="rps", target=10), ), ) - + run = await create_run( session=session, project=project, @@ -75,7 +79,7 @@ async def make_run_with_groups( run_name="test-run", run_spec=run_spec, ) - + # Create initial jobs replica_num = 0 for group_cfg in groups_config: @@ -89,9 +93,9 @@ async def make_run_with_groups( ) run.jobs.append(job) replica_num += 1 - + await session.commit() - + # Reload with jobs and project res = await session.execute( select(RunModel) @@ -124,15 +128,15 @@ async def test_scale_down_only_from_autoscalable_groups(self, session: AsyncSess }, ], ) - + # Scale down by 1 (should only affect scalable group) await scale_wrapper(session, run, -1) - + # Check: fixed group should still have 1 running job fixed_jobs = [j for j in run.jobs if j.replica_group_name == "fixed-h100"] assert len(fixed_jobs) == 1 assert fixed_jobs[0].status == JobStatus.RUNNING - + # Check: scalable group should have 1 terminated, 1 running scalable_jobs = [j for j in run.jobs if j.replica_group_name == "scalable-rtx"] assert len(scalable_jobs) == 2 @@ -160,14 +164,14 @@ async def test_scale_down_respects_group_minimums(self, session: AsyncSession): }, ], ) - + # Try to scale down by 2 await scale_wrapper(session, run, -2) - + # Group A should still have 1 (at minimum) group_a_jobs = [j for j in run.jobs if j.replica_group_name == "group-a"] assert len([j for j in group_a_jobs if j.status == JobStatus.RUNNING]) == 1 - + # Group B should have terminated 1 (3 -> 2, which is minimum) group_b_jobs = [j for j in run.jobs if j.replica_group_name == "group-b"] terminating = [j for j in group_b_jobs if j.status == JobStatus.TERMINATING] @@ -193,12 +197,12 @@ async def test_scale_down_all_groups_fixed(self, session: AsyncSession): }, ], ) - + initial_count = len(run.jobs) - + # Try to scale down await scale_wrapper(session, run, -1) - + # No jobs should be terminated (all groups are fixed) assert len(run.jobs) == initial_count assert all(j.status == JobStatus.RUNNING for j in run.jobs) @@ -227,15 +231,15 @@ async def test_scale_up_selects_autoscalable_group(self, session: AsyncSession): }, ], ) - + initial_count = len(run.jobs) - + # Scale up by 1 await scale_wrapper(session, run, 1) - + # Should have one more job assert len(run.jobs) == initial_count + 1 - + # New job should be in scalable group new_jobs = [j for j in run.jobs if j.replica_num == initial_count] assert len(new_jobs) == 1 @@ -262,14 +266,14 @@ async def test_scale_up_respects_group_maximums(self, session: AsyncSession): }, ], ) - + # Try to scale up by 2 await scale_wrapper(session, run, 2) - + # Small group should still have 2 (at max) small_jobs = [j for j in run.jobs if j.replica_group_name == "small-group"] assert len(small_jobs) == 2 - + # Large group should have grown by 2 large_jobs = [j for j in run.jobs if j.replica_group_name == "large-group"] assert len(large_jobs) == 3 @@ -294,12 +298,12 @@ async def test_scale_up_no_autoscalable_groups(self, session: AsyncSession): }, ], ) - + initial_count = len(run.jobs) - + # Try to scale up await scale_wrapper(session, run, 2) - + # Should not have added any jobs assert len(run.jobs) == initial_count @@ -323,12 +327,12 @@ async def test_scale_up_all_groups_at_max(self, session: AsyncSession): }, ], ) - + initial_count = len(run.jobs) - + # Try to scale up await scale_wrapper(session, run, 1) - + # Should not have added any jobs (all at max) assert len(run.jobs) == initial_count @@ -342,7 +346,7 @@ async def test_legacy_config_scaling(self, session: AsyncSession): project = await create_project(session=session) user = await create_user(session=session) repo = await create_repo(session=session, project_id=project.id) - + # Use legacy format (no replica_groups) profile = Profile(name="test-profile") run_spec = get_run_spec( @@ -356,7 +360,7 @@ async def test_legacy_config_scaling(self, session: AsyncSession): scaling=ScalingSpec(metric="rps", target=10), ), ) - + run = await create_run( session=session, project=project, @@ -365,7 +369,7 @@ async def test_legacy_config_scaling(self, session: AsyncSession): run_name="test-run", run_spec=run_spec, ) - + # Add initial job (no group name) job = await create_job( session=session, @@ -376,13 +380,13 @@ async def test_legacy_config_scaling(self, session: AsyncSession): ) run.jobs.append(job) await session.commit() - + # Scale up should work await scale_wrapper(session, run, 1) - + # Should have 2 jobs now assert len(run.jobs) == 2 - + # New job should have "default" group name or None new_job = [j for j in run.jobs if j.replica_num == 1][0] assert new_job.replica_group_name in [None, "default"] From 570f83356f5ad6933d2d6cde77092cb14b576be9 Mon Sep 17 00:00:00 2001 From: "Alexander Nicholson 4584443+DragonStuff@users.noreply.github.com" <4584443+DragonStuff@users.noreply.github.com> Date: Sat, 18 Oct 2025 13:48:38 +0900 Subject: [PATCH 3/5] chore: run pre-commit and fix failures --- src/dstack/_internal/cli/utils/run.py | 4 ++- .../_internal/core/models/configurations.py | 4 +-- .../tasks/process_submitted_jobs.py | 1 + ...3d4e5f6_add_jobmodel_replica_group_name.py | 1 - src/dstack/_internal/server/services/runs.py | 29 +++++++++++++------ .../server/services/services/autoscalers.py | 1 + .../cli/utils/test_run_plan_display.py | 13 +++++---- .../core/models/test_replica_groups.py | 2 +- .../core/test_backward_compatibility.py | 2 -- .../services/test_get_plan_replica_groups.py | 15 +++------- .../services/test_replica_groups_scaling.py | 12 ++++---- 11 files changed, 45 insertions(+), 39 deletions(-) diff --git a/src/dstack/_internal/cli/utils/run.py b/src/dstack/_internal/cli/utils/run.py index c0cbaf815..31485756c 100644 --- a/src/dstack/_internal/cli/utils/run.py +++ b/src/dstack/_internal/cli/utils/run.py @@ -273,7 +273,9 @@ def th(s: str) -> str: # Show summary for multi-job plans if len(run_plan.job_plans) > 1: if total_offers_count > len(all_offers): - max_price_overall = max((jp.max_price for jp in run_plan.job_plans if jp.max_price), default=None) + max_price_overall = max( + (jp.max_price for jp in run_plan.job_plans if jp.max_price), default=None + ) if max_price_overall: console.print( f"[secondary] Shown {len(all_offers)} of {total_offers_count} offers, " diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index cfdd8cfe1..f687aec30 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -876,9 +876,7 @@ def validate_scaling(cls, values): # Check if any group has a range has_range = any(g.replicas.min != g.replicas.max for g in replica_groups) if has_range and not scaling: - raise ValueError( - "When any replica group has a range, 'scaling' must be specified" - ) + raise ValueError("When any replica group has a range, 'scaling' must be specified") elif replicas and replicas.min != replicas.max and not scaling: raise ValueError("When you set `replicas` to a range, ensure to specify `scaling`.") elif replicas and replicas.min == replicas.max and scaling: diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index 3c740e6a0..5a7f6875e 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -613,6 +613,7 @@ def _get_nodes_required_num_for_run(run_spec: RunSpec) -> int: # Use groups if present if run_spec.configuration.replica_groups: from dstack._internal.core.models.runs import get_normalized_replica_groups + groups = get_normalized_replica_groups(run_spec.configuration) nodes_required_num = sum(g.replicas.min or 0 for g in groups) elif run_spec.configuration.replicas.min is not None: diff --git a/src/dstack/_internal/server/migrations/versions/a1b2c3d4e5f6_add_jobmodel_replica_group_name.py b/src/dstack/_internal/server/migrations/versions/a1b2c3d4e5f6_add_jobmodel_replica_group_name.py index 8e8a1eed0..a1d9e3eaf 100644 --- a/src/dstack/_internal/server/migrations/versions/a1b2c3d4e5f6_add_jobmodel_replica_group_name.py +++ b/src/dstack/_internal/server/migrations/versions/a1b2c3d4e5f6_add_jobmodel_replica_group_name.py @@ -24,4 +24,3 @@ def upgrade() -> None: def downgrade() -> None: with op.batch_alter_table("jobs", schema=None) as batch_op: batch_op.drop_column("replica_group_name") - diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 33834cbb9..895155b33 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -1113,6 +1113,7 @@ def _validate_run_spec_and_set_defaults(user: UserModel, run_spec: RunSpec): if run_spec.merged_profile.schedule: if run_spec.configuration.replica_groups: from dstack._internal.core.models.runs import get_normalized_replica_groups + groups = get_normalized_replica_groups(run_spec.configuration) if any(g.replicas.min == 0 for g in groups): raise ServerClientError( @@ -1351,7 +1352,9 @@ async def scale_run_replicas( if replicas_diff < 0: # SCALE DOWN: Only terminate from autoscalable groups while respecting group minimums - autoscalable_groups = {g.name for g in normalized_groups if g.replicas.min != g.replicas.max} + autoscalable_groups = { + g.name for g in normalized_groups if g.replicas.min != g.replicas.max + } # Count replicas per group group_counts = {} @@ -1416,15 +1419,17 @@ async def scale_run_replicas( # First, identify groups below minimum (need to scale regardless of autoscalability) below_min_groups = [ - g for g in normalized_groups - if group_counts.get(g.name, 0) < (g.replicas.min or 0) + g for g in normalized_groups if group_counts.get(g.name, 0) < (g.replicas.min or 0) ] # Then, identify autoscalable groups that can scale beyond minimum autoscalable_groups = [ - g for g in normalized_groups - if g.replicas.min != g.replicas.max and ( - allow_exceeding_max or group_counts.get(g.name, 0) < (g.replicas.max or float("inf")) + g + for g in normalized_groups + if g.replicas.min != g.replicas.max + and ( + allow_exceeding_max + or group_counts.get(g.name, 0) < (g.replicas.max or float("inf")) ) ] @@ -1454,7 +1459,9 @@ async def scale_run_replicas( if replica_jobs: group_name = replica_jobs[0].replica_group_name or "default" if not normalized_groups or group_name in {g.name for g in eligible_groups}: - await retry_run_replica_jobs(session, run_model, replica_jobs, only_failed=False) + await retry_run_replica_jobs( + session, run_model, replica_jobs, only_failed=False + ) scheduled_replicas += 1 # Create new replicas for remaining diff @@ -1488,10 +1495,14 @@ async def scale_run_replicas( scheduled_replicas += 1 # Remove from eligible if at max - if group_counts[selected_group.name] >= (selected_group.replicas.max or float("inf")): + if group_counts[selected_group.name] >= ( + selected_group.replicas.max or float("inf") + ): eligible_groups = [g for g in eligible_groups if g.name != selected_group.name] if not eligible_groups: - logger.info("%s: all eligible groups reached maximum capacity", fmt(run_model)) + logger.info( + "%s: all eligible groups reached maximum capacity", fmt(run_model) + ) break else: scheduled_replicas += 1 diff --git a/src/dstack/_internal/server/services/services/autoscalers.py b/src/dstack/_internal/server/services/services/autoscalers.py index 93744bd2e..6573d1f7f 100644 --- a/src/dstack/_internal/server/services/services/autoscalers.py +++ b/src/dstack/_internal/server/services/services/autoscalers.py @@ -123,6 +123,7 @@ def get_service_scaler(conf: ServiceConfiguration) -> BaseServiceScaler: # Compute bounds from groups if present if conf.replica_groups: from dstack._internal.core.models.runs import get_normalized_replica_groups + groups = get_normalized_replica_groups(conf) min_replicas = sum(g.replicas.min or 0 for g in groups) max_replicas = sum(g.replicas.max or 0 for g in groups) diff --git a/src/tests/_internal/cli/utils/test_run_plan_display.py b/src/tests/_internal/cli/utils/test_run_plan_display.py index f3e5babf8..e5255c106 100644 --- a/src/tests/_internal/cli/utils/test_run_plan_display.py +++ b/src/tests/_internal/cli/utils/test_run_plan_display.py @@ -1,7 +1,5 @@ """Test CLI display of run plans with replica groups.""" - - from dstack._internal.cli.utils.run import print_run_plan from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.configurations import ServiceConfiguration @@ -304,13 +302,17 @@ def test_replica_groups_offers_sorted_by_price(self, capsys): # Find lines that contain both a number and a group name (these are offer rows) offer_rows = [ - line for line in lines - if ("cheap-group:" in line or "expensive-group:" in line) and line.strip().startswith(("1", "2", "3")) + line + for line in lines + if ("cheap-group:" in line or "expensive-group:" in line) + and line.strip().startswith(("1", "2", "3")) ] # The first offer row should be cheap-group (lower price) assert len(offer_rows) >= 2, "Should have at least 2 offer rows" - assert "cheap-group:" in offer_rows[0], "First offer should be cheap-group (sorted by price)" + assert "cheap-group:" in offer_rows[0], ( + "First offer should be cheap-group (sorted by price)" + ) assert "expensive-group:" in offer_rows[1], "Second offer should be expensive-group" assert "$0.3" in output # Price displayed as $0.3 assert "$3" in output # Price displayed as $3 @@ -418,4 +420,3 @@ def test_replica_group_with_no_offers_shows_message(self, capsys): unavailable_pos = output.find("unavailable-group:") available_pos = output.find("available-group:") assert unavailable_pos < available_pos, "Group with no offers should appear first" - diff --git a/src/tests/_internal/core/models/test_replica_groups.py b/src/tests/_internal/core/models/test_replica_groups.py index 2b38a17b0..7c0a35a08 100644 --- a/src/tests/_internal/core/models/test_replica_groups.py +++ b/src/tests/_internal/core/models/test_replica_groups.py @@ -1,4 +1,5 @@ """Tests for Named Replica Groups functionality""" + import pytest from dstack._internal.core.errors import ConfigurationError @@ -434,4 +435,3 @@ def test_normalization_preserves_all_profile_params(self): assert group.instance_types == ["p4d.24xlarge"] assert group.spot_policy == "spot" assert group.max_price == 10.0 - diff --git a/src/tests/_internal/core/test_backward_compatibility.py b/src/tests/_internal/core/test_backward_compatibility.py index e5c6f985d..ca49f12a3 100644 --- a/src/tests/_internal/core/test_backward_compatibility.py +++ b/src/tests/_internal/core/test_backward_compatibility.py @@ -1,6 +1,5 @@ """Test backward compatibility for replica_groups with older servers.""" - from dstack._internal.core.compatibility.runs import get_get_plan_excludes, get_run_spec_excludes from dstack._internal.core.models.configurations import ServiceConfiguration from dstack._internal.core.models.repos import LocalRunRepoData @@ -126,4 +125,3 @@ def test_get_plan_request_serialization_with_replica_groups(self): # replica_groups SHOULD appear in JSON assert "replica_groups" in json_str assert "gpu-group" in json_str - diff --git a/src/tests/_internal/server/services/test_get_plan_replica_groups.py b/src/tests/_internal/server/services/test_get_plan_replica_groups.py index b45a9c42b..84a549656 100644 --- a/src/tests/_internal/server/services/test_get_plan_replica_groups.py +++ b/src/tests/_internal/server/services/test_get_plan_replica_groups.py @@ -1,6 +1,5 @@ """Test get_plan() offer fetching logic for replica groups.""" - from dstack._internal.core.models.resources import ResourcesSpec from dstack._internal.core.models.runs import Requirements @@ -33,6 +32,7 @@ def test_requirements_equality_check(self): def test_identical_requirements_detection_logic(self): """Test logic for detecting when all jobs have identical requirements.""" + # Simulate job specs with requirements class MockJobSpec: def __init__(self, gpu_name: str, gpu_count: int = 1): @@ -94,9 +94,7 @@ def __init__(self, group_name: str, gpu_name: str): all_identical = all( job.job_spec.requirements == jobs[0].job_spec.requirements for job in jobs ) - assert ( - all_identical is False - ), "Different GPU types should trigger per-job offer fetch" + assert all_identical is False, "Different GPU types should trigger per-job offer fetch" # Scenario 2: Replica groups with same GPU -> shared fetch (optimization) jobs = [ @@ -116,9 +114,7 @@ def __init__(self, group_name: str, gpu_name: str): all_identical = all( job.job_spec.requirements == jobs[0].job_spec.requirements for job in jobs ) - assert ( - all_identical is True - ), "Legacy replicas with same GPU should use shared fetch" + assert all_identical is True, "Legacy replicas with same GPU should use shared fetch" # Scenario 4: Mixed groups (2 same + 1 different) -> per-job fetch jobs = [ @@ -129,9 +125,7 @@ def __init__(self, group_name: str, gpu_name: str): all_identical = all( job.job_spec.requirements == jobs[0].job_spec.requirements for job in jobs ) - assert ( - all_identical is False - ), "Mix of different GPUs should trigger per-job fetch for all" + assert all_identical is False, "Mix of different GPUs should trigger per-job fetch for all" class TestReplicaGroupOfferSearchIntegration: @@ -210,4 +204,3 @@ def test_requirements_with_different_cpu_specs(self): # Different CPU requirements assert req_low_cpu != req_high_cpu - diff --git a/src/tests/_internal/server/services/test_replica_groups_scaling.py b/src/tests/_internal/server/services/test_replica_groups_scaling.py index 44a9e7084..9cea6193b 100644 --- a/src/tests/_internal/server/services/test_replica_groups_scaling.py +++ b/src/tests/_internal/server/services/test_replica_groups_scaling.py @@ -1,4 +1,5 @@ """Integration tests for replica groups scaling functionality""" + from typing import List import pytest @@ -52,9 +53,7 @@ async def make_run_with_groups( ReplicaGroup( name=group_cfg["name"], replicas=parse_obj_as(Range[int], group_cfg["replicas_range"]), - resources=ResourcesSpec( - gpu=GPUSpec(name=[group_cfg["gpu"]], count=1) - ), + resources=ResourcesSpec(gpu=GPUSpec(name=[group_cfg["gpu"]], count=1)), ) ) @@ -323,7 +322,11 @@ async def test_scale_up_all_groups_at_max(self, session: AsyncSession): "name": "group-b", "replicas_range": "1..3", "gpu": "RTX5090", - "initial_jobs": [JobStatus.RUNNING, JobStatus.RUNNING, JobStatus.RUNNING], # At max + "initial_jobs": [ + JobStatus.RUNNING, + JobStatus.RUNNING, + JobStatus.RUNNING, + ], # At max }, ], ) @@ -390,4 +393,3 @@ async def test_legacy_config_scaling(self, session: AsyncSession): # New job should have "default" group name or None new_job = [j for j in run.jobs if j.replica_num == 1][0] assert new_job.replica_group_name in [None, "default"] - From 0d75f3341ae44e99a8e005c6bccc291cc64c66dd Mon Sep 17 00:00:00 2001 From: "Alexander Nicholson 4584443+DragonStuff@users.noreply.github.com" <4584443+DragonStuff@users.noreply.github.com> Date: Sun, 19 Oct 2025 01:15:53 +0900 Subject: [PATCH 4/5] fix: replica to replica_groups migration tests and cli improvements --- src/dstack/_internal/cli/utils/run.py | 67 ++++- .../server/background/tasks/process_runs.py | 29 ++ .../tasks/process_submitted_jobs.py | 49 ++++ src/dstack/_internal/server/services/runs.py | 49 +++- .../cli/utils/test_run_plan_display.py | 217 ++++++++++++++ .../test_replica_groups_profile_overrides.py | 265 ++++++++++++++++++ .../services/test_replica_groups_update.py | 172 ++++++++++++ 7 files changed, 833 insertions(+), 15 deletions(-) create mode 100644 src/tests/_internal/server/services/test_replica_groups_profile_overrides.py create mode 100644 src/tests/_internal/server/services/test_replica_groups_update.py diff --git a/src/dstack/_internal/cli/utils/run.py b/src/dstack/_internal/cli/utils/run.py index 31485756c..d3706e970 100644 --- a/src/dstack/_internal/cli/utils/run.py +++ b/src/dstack/_internal/cli/utils/run.py @@ -166,23 +166,64 @@ def th(s: str) -> str: # For replica groups, show offers from all job plans if len(run_plan.job_plans) > 1: - # Multiple jobs - aggregate offers from all groups - all_offers = [] + # Multiple jobs - ensure fair representation of all groups groups_with_no_offers = [] + groups_with_offers = {} total_offers_count = 0 + # Collect offers per group for jp in run_plan.job_plans: group_name = jp.job_spec.replica_group_name or "default" if jp.total_offers == 0: groups_with_no_offers.append(group_name) - for offer in jp.offers[:max_offers] if max_offers else jp.offers: - all_offers.append((group_name, offer)) + else: + groups_with_offers[group_name] = jp.offers total_offers_count += jp.total_offers - # Sort by price - all_offers.sort(key=lambda x: x[1].price) - if max_offers: - all_offers = all_offers[:max_offers] + # Strategy: Show at least min_per_group offers from each group, then fill with cheapest + num_groups = len(groups_with_offers) + if num_groups > 0 and max_offers: + min_per_group = max( + 1, max_offers // (num_groups * 2) + ) # At least 1, aim for ~half distribution + remaining_slots = max_offers + else: + min_per_group = None + remaining_slots = None + + selected_offers = [] + + # First pass: Take min_per_group from each group (cheapest from each) + if min_per_group: + for group_name, group_offers in groups_with_offers.items(): + sorted_group_offers = sorted(group_offers, key=lambda x: x.price) + take_count = min(min_per_group, len(sorted_group_offers), remaining_slots) + for offer in sorted_group_offers[:take_count]: + selected_offers.append((group_name, offer)) + remaining_slots -= take_count + + # Second pass: Fill remaining slots with cheapest offers globally + if remaining_slots and remaining_slots > 0: + all_remaining = [] + for group_name, group_offers in groups_with_offers.items(): + sorted_group_offers = sorted(group_offers, key=lambda x: x.price) + # Skip offers already selected + for offer in sorted_group_offers[min_per_group:]: + all_remaining.append((group_name, offer)) + + # Sort remaining by price and take the cheapest + all_remaining.sort(key=lambda x: x[1].price) + selected_offers.extend(all_remaining[:remaining_slots]) + + # If no max_offers limit, show all + if not max_offers: + selected_offers = [] + for group_name, group_offers in groups_with_offers.items(): + for offer in group_offers: + selected_offers.append((group_name, offer)) + + # Sort final selection by price for display + selected_offers.sort(key=lambda x: x[1].price) # Show groups with no offers FIRST for group_name in groups_with_no_offers: @@ -197,8 +238,8 @@ def th(s: str) -> str: style="secondary", ) - # Then show groups with offers - for i, (group_name, offer) in enumerate(all_offers, start=1): + # Then show selected offers + for i, (group_name, offer) in enumerate(selected_offers, start=1): r = offer.instance.resources availability = "" @@ -226,7 +267,7 @@ def th(s: str) -> str: style=None if i == 1 or not include_run_properties else "secondary", ) - if total_offers_count > len(all_offers): + if total_offers_count > len(selected_offers): offers.add_row("", "...", style="secondary") else: # Single job - original logic @@ -272,13 +313,13 @@ def th(s: str) -> str: console.print(offers) # Show summary for multi-job plans if len(run_plan.job_plans) > 1: - if total_offers_count > len(all_offers): + if total_offers_count > len(selected_offers): max_price_overall = max( (jp.max_price for jp in run_plan.job_plans if jp.max_price), default=None ) if max_price_overall: console.print( - f"[secondary] Shown {len(all_offers)} of {total_offers_count} offers, " + f"[secondary] Shown {len(selected_offers)} of {total_offers_count} offers, " f"${max_price_overall:3f}".rstrip("0").rstrip(".") + " max[/]" ) diff --git a/src/dstack/_internal/server/background/tasks/process_runs.py b/src/dstack/_internal/server/background/tasks/process_runs.py index 9a879c0fc..35fc949b0 100644 --- a/src/dstack/_internal/server/background/tasks/process_runs.py +++ b/src/dstack/_internal/server/background/tasks/process_runs.py @@ -521,20 +521,49 @@ async def _update_jobs_to_new_deployment_in_place( """ Bump deployment_num for jobs that do not require redeployment. """ + from dstack._internal.core.models.runs import get_normalized_replica_groups + secrets = await get_project_secrets_mapping( session=session, project=run_model.project, ) + + # Get replica groups from the new run_spec for matching + replica_group = None + if run_spec.configuration.type == "service": + normalized_groups = get_normalized_replica_groups(run_spec.configuration) + else: + normalized_groups = [] + for replica_num, job_models in group_jobs_by_replica_latest(run_model.jobs): if all(j.status.is_finished() for j in job_models): continue if all(j.deployment_num == run_model.deployment_num for j in job_models): continue + + # Determine which replica group this job belongs to + # Use the old job's replica_group_name to find the matching group in new spec + old_job_spec = JobSpec.__response__.parse_raw(job_models[0].job_spec_data) + if old_job_spec.replica_group_name and normalized_groups: + replica_group = next( + (g for g in normalized_groups if g.name == old_job_spec.replica_group_name), + None, + ) + if replica_group is None: + logger.warning( + "Replica group '%s' from old job not found in new run_spec. " + "Job will use base configuration.", + old_job_spec.replica_group_name, + ) + else: + replica_group = None + # FIXME: Handle getting image configuration errors or skip it. new_job_specs = await get_job_specs_from_run_spec( run_spec=run_spec, secrets=secrets, replica_num=replica_num, + replica_group=replica_group, ) assert len(new_job_specs) == len(job_models), ( "Changing the number of jobs within a replica is not yet supported" diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index 5a7f6875e..df05e76a8 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -748,6 +748,11 @@ def _get_profile_for_job(run_spec: RunSpec, job: Job) -> Profile: group = next((g for g in normalized_groups if g.name == group_name), None) if not group: + logger.warning( + "Replica group '%s' not found in run_spec. Available groups: %s", + group_name, + [g.name for g in normalized_groups], + ) return base_profile # Merge: group overrides base @@ -757,6 +762,17 @@ def _get_profile_for_job(run_spec: RunSpec, job: Job) -> Profile: if group_value is not None: setattr(merged, field_name, group_value) + logger.debug( + "Profile for group '%s': regions=%s, backends=%s, spot_policy=%s (base had: regions=%s, backends=%s, spot_policy=%s)", + group_name, + merged.regions, + [b.value for b in merged.backends] if merged.backends else None, + merged.spot_policy, + base_profile.regions, + [b.value for b in base_profile.backends] if base_profile.backends else None, + base_profile.spot_policy, + ) + return merged @@ -775,6 +791,29 @@ async def _run_job_on_new_instance( volumes = [] profile = _get_profile_for_job(run.run_spec, job) requirements = job.job_spec.requirements # Already has group resources baked in + + # Debug logging for replica groups + replica_group_name = job.job_spec.replica_group_name + if replica_group_name: + logger.debug( + "%s: Provisioning replica group '%s' with profile: regions=%s, backends=%s, spot_policy=%s", + fmt(job_model), + replica_group_name, + profile.regions, + [b.value for b in profile.backends] if profile.backends else None, + profile.spot_policy, + ) + gpu_req = ( + requirements.resources.gpu.name + if requirements.resources and requirements.resources.gpu + else None + ) + logger.debug( + "%s: GPU requirements for group '%s': %s", + fmt(job_model), + replica_group_name, + gpu_req, + ) fleet = None if fleet_model is not None: fleet = fleet_model_to_fleet(fleet_model) @@ -804,6 +843,16 @@ async def _run_job_on_new_instance( privileged=job.job_spec.privileged, instance_mounts=check_run_spec_requires_instance_mounts(run.run_spec), ) + + # Debug logging for offers + if replica_group_name and len(offers) > 0: + logger.debug( + "%s: Got %d offers for group '%s'. First 3: %s", + fmt(job_model), + len(offers), + replica_group_name, + [f"{o.instance.name} ({o.backend.value}/{o.region})" for _, o in offers[:3]], + ) # Limit number of offers tried to prevent long-running processing # in case all offers fail. for backend, offer in offers[: settings.MAX_OFFERS_TRIED]: diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 895155b33..fbe6f1d6c 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -30,6 +30,8 @@ ) from dstack._internal.core.models.profiles import ( CreationPolicy, + Profile, + ProfileParams, RetryEvent, ) from dstack._internal.core.models.repos.virtual import DEFAULT_VIRTUAL_REPO_ID, VirtualRunRepoData @@ -302,6 +304,45 @@ async def get_run_by_id( return run_model_to_run(run_model, return_in_api=True) +def _get_job_profile(run_spec: RunSpec, replica_group_name: Optional[str]) -> Profile: + """ + Get the profile for a job, including replica group overrides if applicable. + + Args: + run_spec: The run specification + replica_group_name: Name of the replica group, or None for legacy jobs + + Returns: + Profile with replica group overrides applied + """ + base_profile = run_spec.merged_profile + + # If no replica group, return base profile + if not replica_group_name: + return base_profile + + # Find the replica group + if run_spec.configuration.type != "service": + return base_profile + + from dstack._internal.core.models.runs import get_normalized_replica_groups + + normalized_groups = get_normalized_replica_groups(run_spec.configuration) + replica_group = next((g for g in normalized_groups if g.name == replica_group_name), None) + + if not replica_group: + return base_profile + + # Clone base profile and apply group overrides + merged = Profile.parse_obj(base_profile.dict()) + for field_name in ProfileParams.__fields__: + group_value = getattr(replica_group, field_name, None) + if group_value is not None: + setattr(merged, field_name, group_value) + + return merged + + async def get_plan( session: AsyncSession, project: ProjectModel, @@ -427,14 +468,17 @@ async def get_plan( job_offers.extend(matching_pool_offers) + # Get the correct profile for this job (with replica group overrides if applicable) + job_profile = _get_job_profile(effective_run_spec, job.job_spec.replica_group_name) + # Use shared offers if all jobs are identical, otherwise fetch per-job if shared_offers: job_offers.extend(offer for _, offer in shared_offers) elif creation_policy == CreationPolicy.REUSE_OR_CREATE: - # Fetch offers specific to this job's requirements + # Fetch offers specific to this job's requirements with job-specific profile job_specific_offers = await get_offers_by_requirements( project=project, - profile=profile, + profile=job_profile, requirements=job.job_spec.requirements, exclude_not_available=False, multinode=job.job_spec.jobs_per_replica > 1, @@ -1160,6 +1204,7 @@ def _validate_run_spec_and_set_defaults(user: UserModel, run_spec: RunSpec): "service": [ # in-place "replicas", + "replica_groups", # Named replica groups (mutually exclusive with replicas) "scaling", # rolling deployment # NOTE: keep this list in sync with the "Rolling deployment" section in services.md diff --git a/src/tests/_internal/cli/utils/test_run_plan_display.py b/src/tests/_internal/cli/utils/test_run_plan_display.py index e5255c106..98d115343 100644 --- a/src/tests/_internal/cli/utils/test_run_plan_display.py +++ b/src/tests/_internal/cli/utils/test_run_plan_display.py @@ -420,3 +420,220 @@ def test_replica_group_with_no_offers_shows_message(self, capsys): unavailable_pos = output.find("unavailable-group:") available_pos = output.find("available-group:") assert unavailable_pos < available_pos, "Group with no offers should appear first" + + +class TestReplicaGroupsFairOfferDistribution: + """Test that CLI displays offers from all replica groups fairly.""" + + def test_all_groups_represented_in_display(self, capsys): + """Test that offers from all replica groups are shown when max_offers is set.""" + # Create offers for three groups with different price ranges + h100_offers = [ + create_test_offer(BackendType.AWS, "H100", 3.0, region="us-east"), + create_test_offer(BackendType.AWS, "H100", 3.5, region="us-west"), + create_test_offer(BackendType.GCP, "H100", 4.0, region="eu-west"), + ] + + rtx5090_offers = [ + create_test_offer(BackendType.VASTAI, "RTX5090", 0.5, region="us"), + create_test_offer(BackendType.VASTAI, "RTX5090", 0.6, region="eu"), + ] + + a100_offers = [ + create_test_offer(BackendType.AWS, "A100", 2.0, region="us-east"), + create_test_offer(BackendType.GCP, "A100", 2.2, region="eu-west"), + ] + + # Create job plans for each group + job_plan_h100 = JobPlan( + job_spec=JobSpec( + replica_num=0, + replica_group_name="h100-group", + job_num=0, + job_name="test-job-0", + image_name="dstackai/base", + commands=["echo test"], + env={}, + working_dir="/workflow", + requirements=Requirements( + resources=ResourcesSpec(gpu={"name": "H100", "count": 1}) + ), + ), + offers=h100_offers, + total_offers=len(h100_offers), + max_price=4.0, + ) + + job_plan_rtx5090 = JobPlan( + job_spec=JobSpec( + replica_num=1, + replica_group_name="rtx5090-group", + job_num=1, + job_name="test-job-1", + image_name="dstackai/base", + commands=["echo test"], + env={}, + working_dir="/workflow", + requirements=Requirements( + resources=ResourcesSpec(gpu={"name": "RTX5090", "count": 1}) + ), + ), + offers=rtx5090_offers, + total_offers=len(rtx5090_offers), + max_price=0.6, + ) + + job_plan_a100 = JobPlan( + job_spec=JobSpec( + replica_num=2, + replica_group_name="a100-group", + job_num=2, + job_name="test-job-2", + image_name="dstackai/base", + commands=["echo test"], + env={}, + working_dir="/workflow", + requirements=Requirements( + resources=ResourcesSpec(gpu={"name": "A100", "count": 1}) + ), + ), + offers=a100_offers, + total_offers=len(a100_offers), + max_price=2.2, + ) + + config = ServiceConfiguration( + type="service", + port=8000, + commands=["echo test"], + ) + + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data=LocalRunRepoData(repo_dir="/tmp"), + configuration=config, + configuration_path=".dstack.yml", + profile=Profile(backends=[BackendType.AWS]), + ) + + run_plan = RunPlan( + project_name="test-project", + user="test-user", + run_spec=run_spec, + effective_run_spec=run_spec, + job_plans=[job_plan_h100, job_plan_rtx5090, job_plan_a100], + current_resource=None, + action=ApplyAction.CREATE, + ) + + # Print with max_offers=5 (should show at least 1 from each group) + print_run_plan(run_plan, max_offers=5, include_run_properties=True) + + captured = capsys.readouterr() + output = captured.out + + # Verify all three groups appear in the output + assert "h100-group:" in output, "H100 group should be displayed" + assert "rtx5090-group:" in output, "RTX5090 group should be displayed" + assert "a100-group:" in output, "A100 group should be displayed" + + # Verify GPUs are shown + assert "H100" in output + assert "RTX5090" in output + assert "A100" in output + + def test_fair_distribution_with_limited_slots(self, capsys): + """Test that when max_offers is limited, all groups get fair representation.""" + # Group 1: Many cheap offers + cheap_offers = [ + create_test_offer(BackendType.VASTAI, "RTX5090", 0.4 + i * 0.1, region="us") + for i in range(10) + ] + + # Group 2: Few expensive offers + expensive_offers = [ + create_test_offer(BackendType.AWS, "H100", 3.0 + i * 0.5, region="us-east") + for i in range(3) + ] + + job_plan_cheap = JobPlan( + job_spec=JobSpec( + replica_num=0, + replica_group_name="cheap-group", + job_num=0, + job_name="test-job-0", + image_name="dstackai/base", + commands=["echo test"], + env={}, + working_dir="/workflow", + requirements=Requirements( + resources=ResourcesSpec(gpu={"name": "RTX5090", "count": 1}) + ), + ), + offers=cheap_offers, + total_offers=len(cheap_offers), + max_price=1.4, + ) + + job_plan_expensive = JobPlan( + job_spec=JobSpec( + replica_num=1, + replica_group_name="expensive-group", + job_num=1, + job_name="test-job-1", + image_name="dstackai/base", + commands=["echo test"], + env={}, + working_dir="/workflow", + requirements=Requirements( + resources=ResourcesSpec(gpu={"name": "H100", "count": 1}) + ), + ), + offers=expensive_offers, + total_offers=len(expensive_offers), + max_price=4.0, + ) + + config = ServiceConfiguration( + type="service", + port=8000, + commands=["echo test"], + ) + + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data=LocalRunRepoData(repo_dir="/tmp"), + configuration=config, + configuration_path=".dstack.yml", + profile=Profile(backends=[BackendType.VASTAI]), + ) + + run_plan = RunPlan( + project_name="test-project", + user="test-user", + run_spec=run_spec, + effective_run_spec=run_spec, + job_plans=[job_plan_cheap, job_plan_expensive], + current_resource=None, + action=ApplyAction.CREATE, + ) + + # Print with max_offers=4 (should show at least 1 from each group) + print_run_plan(run_plan, max_offers=4, include_run_properties=True) + + captured = capsys.readouterr() + output = captured.out + + # Both groups should be represented + assert "cheap-group:" in output + assert "expensive-group:" in output + + # Count occurrences (rough check - both should appear) + cheap_count = output.count("cheap-group:") + expensive_count = output.count("expensive-group:") + + # Both should have at least one offer shown + assert cheap_count >= 1, "Cheap group should have at least one offer" + assert expensive_count >= 1, "Expensive group should have at least one offer" diff --git a/src/tests/_internal/server/services/test_replica_groups_profile_overrides.py b/src/tests/_internal/server/services/test_replica_groups_profile_overrides.py new file mode 100644 index 000000000..65bf86ccf --- /dev/null +++ b/src/tests/_internal/server/services/test_replica_groups_profile_overrides.py @@ -0,0 +1,265 @@ +"""Tests for replica group profile overrides (regions, spot_policy, etc.)""" + +import pytest +from pydantic import parse_obj_as + +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.configurations import ( + ReplicaGroup, + ScalingSpec, + ServiceConfiguration, +) +from dstack._internal.core.models.profiles import Profile, SpotPolicy +from dstack._internal.core.models.resources import GPUSpec, Range, ResourcesSpec +from dstack._internal.core.models.runs import RunSpec +from dstack._internal.server.services.runs import _get_job_profile + +pytestmark = pytest.mark.usefixtures("image_config_mock") + + +def test_spot_policy_override_per_group(): + """Test that each replica group can have its own spot_policy.""" + # Create a service with different spot policies per group + config = ServiceConfiguration( + commands=["echo hello"], + port=8000, + replica_groups=[ + ReplicaGroup( + name="on-demand-group", + replicas=parse_obj_as(Range[int], 1), + resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "H100:1")), + spot_policy=SpotPolicy.ONDEMAND, + ), + ReplicaGroup( + name="spot-group", + replicas=parse_obj_as(Range[int], 1), + resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "RTX5090:1")), + spot_policy=SpotPolicy.SPOT, + ), + ReplicaGroup( + name="auto-group", + replicas=parse_obj_as(Range[int], 1), + resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "A100:1")), + spot_policy=SpotPolicy.AUTO, + ), + ], + scaling=ScalingSpec(metric="rps", target=10), + ) + + profile = Profile(name="test-profile", spot_policy=SpotPolicy.AUTO) # base policy + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data={"repo_type": "local", "repo_dir": "/repo"}, + configuration_path="dstack.yaml", + configuration=config, + profile=profile, + ssh_key_pub="ssh_key", + ) + + # Test on-demand group + on_demand_profile = _get_job_profile(run_spec, "on-demand-group") + assert on_demand_profile.spot_policy == SpotPolicy.ONDEMAND + + # Test spot group + spot_profile = _get_job_profile(run_spec, "spot-group") + assert spot_profile.spot_policy == SpotPolicy.SPOT + + # Test auto group + auto_profile = _get_job_profile(run_spec, "auto-group") + assert auto_profile.spot_policy == SpotPolicy.AUTO + + # Test legacy (no group) uses base profile + legacy_profile = _get_job_profile(run_spec, None) + assert legacy_profile.spot_policy == SpotPolicy.AUTO + + +def test_regions_override_per_group(): + """Test that each replica group can have its own regions.""" + config = ServiceConfiguration( + commands=["echo hello"], + port=8000, + replica_groups=[ + ReplicaGroup( + name="us-group", + replicas=parse_obj_as(Range[int], 1), + resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "H100:1")), + regions=["us-east-1", "us-west-2"], + ), + ReplicaGroup( + name="eu-group", + replicas=parse_obj_as(Range[int], 1), + resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "RTX5090:1")), + regions=["eu-west-1", "eu-central-1"], + ), + ReplicaGroup( + name="asia-group", + replicas=parse_obj_as(Range[int], 1), + resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "A100:1")), + regions=["ap-northeast-1"], + ), + ], + scaling=ScalingSpec(metric="rps", target=10), + ) + + profile = Profile(name="test-profile", regions=["us-east-1"]) # base regions + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data={"repo_type": "local", "repo_dir": "/repo"}, + configuration_path="dstack.yaml", + configuration=config, + profile=profile, + ssh_key_pub="ssh_key", + ) + + # Test US group + us_profile = _get_job_profile(run_spec, "us-group") + assert us_profile.regions == ["us-east-1", "us-west-2"] + + # Test EU group + eu_profile = _get_job_profile(run_spec, "eu-group") + assert eu_profile.regions == ["eu-west-1", "eu-central-1"] + + # Test Asia group + asia_profile = _get_job_profile(run_spec, "asia-group") + assert asia_profile.regions == ["ap-northeast-1"] + + # Test legacy (no group) uses base profile + legacy_profile = _get_job_profile(run_spec, None) + assert legacy_profile.regions == ["us-east-1"] + + +def test_backends_override_per_group(): + """Test that each replica group can have its own backends.""" + config = ServiceConfiguration( + commands=["echo hello"], + port=8000, + replica_groups=[ + ReplicaGroup( + name="aws-group", + replicas=parse_obj_as(Range[int], 1), + resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "H100:1")), + backends=[BackendType.AWS], + ), + ReplicaGroup( + name="vastai-group", + replicas=parse_obj_as(Range[int], 1), + resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "RTX5090:1")), + backends=[BackendType.VASTAI], + ), + ], + scaling=ScalingSpec(metric="rps", target=10), + ) + + profile = Profile( + name="test-profile", + backends=[BackendType.AWS, BackendType.GCP], # base backends + ) + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data={"repo_type": "local", "repo_dir": "/repo"}, + configuration_path="dstack.yaml", + configuration=config, + profile=profile, + ssh_key_pub="ssh_key", + ) + + # Test AWS group + aws_profile = _get_job_profile(run_spec, "aws-group") + assert aws_profile.backends == [BackendType.AWS] + + # Test VastAI group + vastai_profile = _get_job_profile(run_spec, "vastai-group") + assert vastai_profile.backends == [BackendType.VASTAI] + + # Test legacy (no group) uses base profile + legacy_profile = _get_job_profile(run_spec, None) + assert legacy_profile.backends == [BackendType.AWS, BackendType.GCP] + + +def test_multiple_profile_overrides_per_group(): + """Test that a replica group can override multiple profile parameters at once.""" + config = ServiceConfiguration( + commands=["echo hello"], + port=8000, + replica_groups=[ + ReplicaGroup( + name="specialized-group", + replicas=parse_obj_as(Range[int], 1), + resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "H100:1")), + regions=["us-west-2"], + backends=[BackendType.AWS], + spot_policy=SpotPolicy.ONDEMAND, + max_price=5.0, + ), + ], + scaling=ScalingSpec(metric="rps", target=10), + ) + + profile = Profile( + name="test-profile", + regions=["us-east-1"], + backends=[BackendType.GCP], + spot_policy=SpotPolicy.SPOT, + max_price=1.0, + ) + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data={"repo_type": "local", "repo_dir": "/repo"}, + configuration_path="dstack.yaml", + configuration=config, + profile=profile, + ssh_key_pub="ssh_key", + ) + + specialized_profile = _get_job_profile(run_spec, "specialized-group") + assert specialized_profile.regions == ["us-west-2"] + assert specialized_profile.backends == [BackendType.AWS] + assert specialized_profile.spot_policy == SpotPolicy.ONDEMAND + assert specialized_profile.max_price == 5.0 + + +def test_partial_profile_override(): + """Test that only specified profile parameters are overridden, others inherit from base.""" + config = ServiceConfiguration( + commands=["echo hello"], + port=8000, + replica_groups=[ + ReplicaGroup( + name="partial-group", + replicas=parse_obj_as(Range[int], 1), + resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "H100:1")), + regions=["us-west-2"], # Only override regions + # spot_policy, backends, max_price should inherit from base + ), + ], + scaling=ScalingSpec(metric="rps", target=10), + ) + + profile = Profile( + name="test-profile", + regions=["us-east-1"], + backends=[BackendType.GCP, BackendType.AWS], + spot_policy=SpotPolicy.SPOT, + max_price=2.5, + ) + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data={"repo_type": "local", "repo_dir": "/repo"}, + configuration_path="dstack.yaml", + configuration=config, + profile=profile, + ssh_key_pub="ssh_key", + ) + + partial_profile = _get_job_profile(run_spec, "partial-group") + # Overridden + assert partial_profile.regions == ["us-west-2"] + # Inherited from base + assert partial_profile.backends == [BackendType.GCP, BackendType.AWS] + assert partial_profile.spot_policy == SpotPolicy.SPOT + assert partial_profile.max_price == 2.5 diff --git a/src/tests/_internal/server/services/test_replica_groups_update.py b/src/tests/_internal/server/services/test_replica_groups_update.py new file mode 100644 index 000000000..1ae7148cc --- /dev/null +++ b/src/tests/_internal/server/services/test_replica_groups_update.py @@ -0,0 +1,172 @@ +"""Tests for updating services with replica groups.""" + +from pydantic import parse_obj_as + +from dstack._internal.core.models.configurations import ( + ReplicaGroup, + ScalingSpec, + ServiceConfiguration, +) +from dstack._internal.core.models.profiles import Profile, SpotPolicy +from dstack._internal.core.models.resources import GPUSpec, Range, ResourcesSpec +from dstack._internal.core.models.runs import RunSpec +from dstack._internal.server.services.runs import _check_can_update_run_spec + + +def test_can_update_from_replicas_to_replica_groups(): + """Test that we can update a service from simple replicas to replica_groups.""" + # Old config with simple replicas + old_config = ServiceConfiguration( + commands=["echo hello"], + port=8000, + replicas=2, + resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "H100:1")), + ) + + old_run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data={"repo_type": "local", "repo_dir": "/repo"}, + configuration_path="dstack.yaml", + configuration=old_config, + profile=Profile(name="test-profile"), + ssh_key_pub="ssh_key", + ) + + # New config with replica_groups + new_config = ServiceConfiguration( + commands=["echo hello"], + port=8000, + replica_groups=[ + ReplicaGroup( + name="h100-group", + replicas=parse_obj_as(Range[int], 1), + resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "H100:1")), + regions=["us-east-1"], + ), + ReplicaGroup( + name="rtx5090-group", + replicas=parse_obj_as(Range[int], "0..3"), + resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "RTX5090:1")), + regions=["jp-japan"], + ), + ], + scaling=ScalingSpec(metric="rps", target=10), + ) + + new_run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data={"repo_type": "local", "repo_dir": "/repo"}, + configuration_path="dstack.yaml", + configuration=new_config, + profile=Profile(name="test-profile"), + ssh_key_pub="ssh_key", + ) + + # This should NOT raise an error + _check_can_update_run_spec(old_run_spec, new_run_spec) + + +def test_can_update_from_replica_groups_to_replicas(): + """Test that we can update a service from replica_groups back to simple replicas.""" + # Old config with replica_groups + old_config = ServiceConfiguration( + commands=["echo hello"], + port=8000, + replica_groups=[ + ReplicaGroup( + name="h100-group", + replicas=parse_obj_as(Range[int], 1), + resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "H100:1")), + ), + ], + ) + + old_run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data={"repo_type": "local", "repo_dir": "/repo"}, + configuration_path="dstack.yaml", + configuration=old_config, + profile=Profile(name="test-profile"), + ssh_key_pub="ssh_key", + ) + + # New config with simple replicas + new_config = ServiceConfiguration( + commands=["echo hello"], + port=8000, + replicas=2, + resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "A100:1")), + ) + + new_run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data={"repo_type": "local", "repo_dir": "/repo"}, + configuration_path="dstack.yaml", + configuration=new_config, + profile=Profile(name="test-profile"), + ssh_key_pub="ssh_key", + ) + + # This should NOT raise an error + _check_can_update_run_spec(old_run_spec, new_run_spec) + + +def test_can_update_replica_groups(): + """Test that we can update replica_groups in place.""" + # Old config + old_config = ServiceConfiguration( + commands=["echo hello"], + port=8000, + replica_groups=[ + ReplicaGroup( + name="gpu-group", + replicas=parse_obj_as(Range[int], "1..3"), + resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "H100:1")), + regions=["us-east-1"], + ), + ], + scaling=ScalingSpec(metric="rps", target=10), + ) + + old_run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data={"repo_type": "local", "repo_dir": "/repo"}, + configuration_path="dstack.yaml", + configuration=old_config, + profile=Profile(name="test-profile"), + ssh_key_pub="ssh_key", + ) + + # New config with different replica_groups + new_config = ServiceConfiguration( + commands=["echo hello"], + port=8000, + replica_groups=[ + ReplicaGroup( + name="gpu-group", + replicas=parse_obj_as(Range[int], "2..5"), # Changed range + resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "A100:1")), # Changed GPU + regions=["us-west-2"], # Changed region + spot_policy=SpotPolicy.SPOT, # Added spot policy + ), + ], + scaling=ScalingSpec(metric="rps", target=20), # Changed target + ) + + new_run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data={"repo_type": "local", "repo_dir": "/repo"}, + configuration_path="dstack.yaml", + configuration=new_config, + profile=Profile(name="test-profile"), + ssh_key_pub="ssh_key", + ) + + # This should NOT raise an error (replica_groups + resources + scaling are all updatable) + _check_can_update_run_spec(old_run_spec, new_run_spec) From f67b38f9664bdf2faf083dae52230cafde05f444 Mon Sep 17 00:00:00 2001 From: "Alexander Nicholson 4584443+DragonStuff@users.noreply.github.com" <4584443+DragonStuff@users.noreply.github.com> Date: Sun, 19 Oct 2025 02:44:54 +0900 Subject: [PATCH 5/5] feat: enhance replica group handling and migration for legacy jobs - Implemented migration for legacy jobs that lack a replica_group_name, ensuring they are correctly assigned to the appropriate replica groups. - Updated CLI output to display group-specific properties such as spot policy, regions, and backends for better clarity. - Enhanced tests to validate the migration process and ensure that jobs are correctly assigned to their respective groups. - Improved handling of pool offers to accommodate multiple jobs in replica groups, ensuring all GPU types are considered. This update improves the robustness of the service configuration and enhances user experience by providing clearer information in the CLI. --- src/dstack/_internal/cli/utils/run.py | 31 ++- .../server/background/tasks/process_runs.py | 68 +++++ .../tasks/process_submitted_jobs.py | 43 +++- src/dstack/_internal/server/services/runs.py | 58 ++++- src/dstack/_internal/server/testing/common.py | 13 +- .../cli/utils/test_run_plan_display.py | 117 +++++++++ .../tasks/test_migrate_legacy_jobs.py | 242 ++++++++++++++++++ 7 files changed, 549 insertions(+), 23 deletions(-) create mode 100644 src/tests/_internal/server/background/tasks/test_migrate_legacy_jobs.py diff --git a/src/dstack/_internal/cli/utils/run.py b/src/dstack/_internal/cli/utils/run.py index d3706e970..b61bf94fb 100644 --- a/src/dstack/_internal/cli/utils/run.py +++ b/src/dstack/_internal/cli/utils/run.py @@ -122,31 +122,52 @@ def th(s: str) -> str: from dstack._internal.core.models.configurations import ServiceConfiguration - if ( + has_replica_groups = ( include_run_properties and isinstance(run_spec.configuration, ServiceConfiguration) and run_spec.configuration.replica_groups - ): + ) + + if has_replica_groups: groups_info = [] for group in run_spec.configuration.replica_groups: group_parts = [f"[cyan]{group.name}[/cyan]"] + # Replica count if group.replicas.min == group.replicas.max: group_parts.append(f"×{group.replicas.max}") else: group_parts.append(f"×{group.replicas.min}..{group.replicas.max}") group_parts.append("[dim](autoscalable)[/dim]") + # Resources group_parts.append(f"[dim]({group.resources.pretty_format()})[/dim]") + # Group-specific overrides + overrides = [] + if group.spot_policy is not None: + overrides.append(f"spot={group.spot_policy.value}") + if group.regions: + regions_str = ",".join(group.regions[:2]) # Show first 2 + if len(group.regions) > 2: + regions_str += f",+{len(group.regions) - 2}" + overrides.append(f"regions={regions_str}") + if group.backends: + backends_str = ",".join([b.value for b in group.backends[:2]]) + if len(group.backends) > 2: + backends_str += f",+{len(group.backends) - 2}" + overrides.append(f"backends={backends_str}") + + if overrides: + group_parts.append(f"[dim]({'; '.join(overrides)})[/dim]") + groups_info.append(" ".join(group_parts)) props.add_row(th("Replica groups"), "\n".join(groups_info)) else: props.add_row(th("Resources"), pretty_req) - - props.add_row(th("Spot policy"), spot_policy) - props.add_row(th("Max price"), max_price) + props.add_row(th("Spot policy"), spot_policy) + props.add_row(th("Max price"), max_price) if include_run_properties: props.add_row(th("Retry policy"), retry) props.add_row(th("Creation policy"), creation_policy) diff --git a/src/dstack/_internal/server/background/tasks/process_runs.py b/src/dstack/_internal/server/background/tasks/process_runs.py index 35fc949b0..0be74df83 100644 --- a/src/dstack/_internal/server/background/tasks/process_runs.py +++ b/src/dstack/_internal/server/background/tasks/process_runs.py @@ -156,6 +156,10 @@ async def _process_run(session: AsyncSession, run_model: RunModel): ) run_model = res.unique().scalar_one() logger.debug("%s: processing run", fmt(run_model)) + + # Migrate legacy jobs without replica_group_name (one-time fix) + await _migrate_legacy_job_replica_groups(session, run_model) + try: if run_model.status == RunStatus.PENDING: await _process_pending_run(session, run_model) @@ -176,6 +180,70 @@ async def _process_run(session: AsyncSession, run_model: RunModel): await session.commit() +async def _migrate_legacy_job_replica_groups(session: AsyncSession, run_model: RunModel): + """ + Migrate jobs from old runs that don't have replica_group_name set. + This fixes jobs created before the replica_groups feature was added. + """ + run_spec = RunSpec.__response__.parse_raw(run_model.run_spec) + + # Only migrate service runs with replica_groups + if run_spec.configuration.type != "service": + return + + # Check if run uses replica_groups + if not getattr(run_spec.configuration, "replica_groups", None): + return + + from dstack._internal.core.models.runs import get_normalized_replica_groups + + normalized_groups = get_normalized_replica_groups(run_spec.configuration) + + # Check if any jobs need migration + needs_migration = any(job.replica_group_name is None for job in run_model.jobs) + + if not needs_migration: + return + + logger.info( + "%s: Migrating legacy jobs to assign replica_group_name", + fmt(run_model), + ) + + # Build a map of replica_num -> group_name based on how jobs were originally created + replica_num_to_group = {} + current_replica_num = 0 + + for group in normalized_groups: + group_min = group.replicas.min or 0 + for _ in range(group_min): + replica_num_to_group[current_replica_num] = group.name + current_replica_num += 1 + + # Update jobs + migrated_count = 0 + for job in run_model.jobs: + if job.replica_group_name is None: + expected_group = replica_num_to_group.get(job.replica_num) + if expected_group: + job.replica_group_name = expected_group + migrated_count += 1 + logger.info( + "%s: Migrated job replica_num=%d to group '%s'", + fmt(run_model), + job.replica_num, + expected_group, + ) + + if migrated_count > 0: + await session.commit() + logger.info( + "%s: Migrated %d job(s) to replica groups", + fmt(run_model), + migrated_count, + ) + + async def _process_pending_run(session: AsyncSession, run_model: RunModel): """Jobs are not created yet""" run = run_model_to_run(run_model) diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index df05e76a8..ea0422808 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -740,11 +740,23 @@ def _get_profile_for_job(run_spec: RunSpec, job: Job) -> Profile: base_profile = run_spec.merged_profile group_name = job.job_spec.replica_group_name + logger.info( + "Getting profile for job %s: replica_group_name=%s, config_type=%s", + job.job_spec.job_name, + group_name, + run_spec.configuration.type, + ) + if not group_name or run_spec.configuration.type != "service": + logger.info("Using base profile (no group_name or not a service)") return base_profile # Find the group normalized_groups = get_normalized_replica_groups(run_spec.configuration) + logger.info( + "Normalized groups: %s", + [f"{g.name} (regions={g.regions})" for g in normalized_groups], + ) group = next((g for g in normalized_groups if g.name == group_name), None) if not group: @@ -762,7 +774,7 @@ def _get_profile_for_job(run_spec: RunSpec, job: Job) -> Profile: if group_value is not None: setattr(merged, field_name, group_value) - logger.debug( + logger.info( "Profile for group '%s': regions=%s, backends=%s, spot_policy=%s (base had: regions=%s, backends=%s, spot_policy=%s)", group_name, merged.regions, @@ -832,6 +844,21 @@ async def _run_job_on_new_instance( multinode = job.job_spec.jobs_per_replica > 1 or ( fleet is not None and fleet.spec.configuration.placement == InstanceGroupPlacement.CLUSTER ) + + # Log the requirements and profile being used + gpu_requirement = ( + requirements.resources.gpu.name + if requirements.resources and requirements.resources.gpu + else None + ) + logger.info( + "%s: Fetching offers with GPU=%s, regions=%s, backends=%s", + fmt(job_model), + gpu_requirement, + profile.regions, + [b.value for b in profile.backends] if profile.backends else None, + ) + offers = await get_offers_by_requirements( project=project, profile=profile, @@ -845,14 +872,12 @@ async def _run_job_on_new_instance( ) # Debug logging for offers - if replica_group_name and len(offers) > 0: - logger.debug( - "%s: Got %d offers for group '%s'. First 3: %s", - fmt(job_model), - len(offers), - replica_group_name, - [f"{o.instance.name} ({o.backend.value}/{o.region})" for _, o in offers[:3]], - ) + logger.info( + "%s: Got %d offers. First 3: %s", + fmt(job_model), + len(offers), + [f"{o.instance.name} ({o.backend.value}/{o.region})" for _, o in offers[:3]], + ) # Limit number of offers tried to prevent long-running processing # in case all offers fail. for backend, offer in offers[: settings.MAX_OFFERS_TRIED]: diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index fbe6f1d6c..ffc7cbc14 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -417,13 +417,36 @@ async def get_plan( job_num=0, ) - pool_offers = await _get_pool_offers( - session=session, - project=project, - run_spec=effective_run_spec, - job=jobs[0], - volumes=volumes, - ) + # For replica groups, we need pool offers for all GPU types, not just the first job's type + # So we fetch pool offers for each job separately and aggregate them + all_pool_offers = [] + if len(jobs) > 1: + # Multiple jobs (likely replica groups) - get pool offers per job to include all GPU types + for job in jobs: + job_pool_offers = await _get_pool_offers( + session=session, + project=project, + run_spec=effective_run_spec, + job=job, + volumes=volumes, + ) + all_pool_offers.extend(job_pool_offers) + # Deduplicate by (backend, instance_name, region) tuple + seen_offers = set() + pool_offers = [] + for offer in all_pool_offers: + offer_key = (offer.backend, offer.instance.name, offer.region) + if offer_key not in seen_offers: + seen_offers.add(offer_key) + pool_offers.append(offer) + else: + pool_offers = await _get_pool_offers( + session=session, + project=project, + run_spec=effective_run_spec, + job=jobs[0], + volumes=volumes, + ) effective_run_spec.run_name = "dry-run" # will regenerate jobs on submission # Check if all jobs have identical requirements (optimization for single-type jobs) @@ -1561,10 +1584,29 @@ async def retry_run_replica_jobs( session=session, project=run_model.project, ) + + # Determine which replica group this job belongs to + run_spec = RunSpec.__response__.parse_raw(run_model.run_spec) + replica_group = None + if run_spec.configuration.type == "service" and latest_jobs: + from dstack._internal.core.models.runs import get_normalized_replica_groups + + group_name = latest_jobs[0].replica_group_name + if group_name: + normalized_groups = get_normalized_replica_groups(run_spec.configuration) + replica_group = next((g for g in normalized_groups if g.name == group_name), None) + if replica_group: + logger.info( + "%s: retrying job from replica group '%s'", + fmt(run_model), + replica_group.name, + ) + new_jobs = await get_jobs_from_run_spec( - run_spec=RunSpec.__response__.parse_raw(run_model.run_spec), + run_spec=run_spec, secrets=secrets, replica_num=latest_jobs[0].replica_num, + replica_group=replica_group, ) assert len(new_jobs) == len(latest_jobs), ( "Changing the number of jobs within a replica is not yet supported" diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index f7be289d3..6b1d42a10 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -354,8 +354,19 @@ async def create_job( if deployment_num is None: deployment_num = run.deployment_num run_spec = RunSpec.parse_raw(run.run_spec) + + # Look up replica group if specified + replica_group = None + if replica_group_name and run_spec.configuration.type == "service": + from dstack._internal.core.models.runs import get_normalized_replica_groups + + normalized_groups = get_normalized_replica_groups(run_spec.configuration) + replica_group = next((g for g in normalized_groups if g.name == replica_group_name), None) + job_spec = ( - await get_job_specs_from_run_spec(run_spec=run_spec, secrets={}, replica_num=replica_num) + await get_job_specs_from_run_spec( + run_spec=run_spec, secrets={}, replica_num=replica_num, replica_group=replica_group + ) )[0] job_spec.job_num = job_num job = JobModel( diff --git a/src/tests/_internal/cli/utils/test_run_plan_display.py b/src/tests/_internal/cli/utils/test_run_plan_display.py index 98d115343..8ec6bb4fb 100644 --- a/src/tests/_internal/cli/utils/test_run_plan_display.py +++ b/src/tests/_internal/cli/utils/test_run_plan_display.py @@ -637,3 +637,120 @@ def test_fair_distribution_with_limited_slots(self, capsys): # Both should have at least one offer shown assert cheap_count >= 1, "Cheap group should have at least one offer" assert expensive_count >= 1, "Expensive group should have at least one offer" + + +class TestReplicaGroupsProfileOverridesDisplay: + """Test that CLI correctly displays profile overrides for replica groups.""" + + def test_shows_group_specific_spot_policy_and_regions(self, capsys): + """Test that group-specific spot_policy, regions, backends are displayed.""" + + from dstack._internal.core.models.backends.base import BackendType + + config = ServiceConfiguration( + type="service", + port=8000, + commands=["echo test"], + replica_groups=[ + { + "name": "h100-group", + "replicas": "1", + "resources": {"gpu": {"name": "H100", "count": 1}}, + "spot_policy": "spot", + "regions": ["us-east-1", "us-west-2"], + "backends": ["aws"], + }, + { + "name": "rtx5090-group", + "replicas": "0..5", + "resources": {"gpu": {"name": "RTX5090", "count": 1}}, + "spot_policy": "on-demand", + "regions": ["jp-japan"], + "backends": ["vastai", "runpod"], + }, + ], + scaling={"metric": "rps", "target": 10}, + ) + + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data=LocalRunRepoData(repo_dir="/tmp"), + configuration=config, + configuration_path=".dstack.yml", + profile=Profile(backends=[BackendType.AWS]), + ) + + # Create job plans + job_plan_h100 = JobPlan( + job_spec=JobSpec( + replica_num=0, + replica_group_name="h100-group", + job_num=0, + job_name="test-job-0", + image_name="dstackai/base", + commands=["echo test"], + env={}, + working_dir="/workflow", + requirements=Requirements( + resources=ResourcesSpec(gpu={"name": "H100", "count": 1}) + ), + ), + offers=[create_test_offer(BackendType.AWS, "H100", 3.0)], + total_offers=1, + max_price=3.0, + ) + + job_plan_rtx = JobPlan( + job_spec=JobSpec( + replica_num=1, + replica_group_name="rtx5090-group", + job_num=1, + job_name="test-job-1", + image_name="dstackai/base", + commands=["echo test"], + env={}, + working_dir="/workflow", + requirements=Requirements( + resources=ResourcesSpec(gpu={"name": "RTX5090", "count": 1}) + ), + ), + offers=[create_test_offer(BackendType.VASTAI, "RTX5090", 0.5)], + total_offers=1, + max_price=0.5, + ) + + run_plan = RunPlan( + project_name="test-project", + user="test-user", + run_spec=run_spec, + effective_run_spec=run_spec, + job_plans=[job_plan_h100, job_plan_rtx], + current_resource=None, + action=ApplyAction.CREATE, + ) + + # Print the plan + print_run_plan(run_plan, max_offers=10, include_run_properties=True) + + # Capture output + captured = capsys.readouterr() + output = captured.out + + # Verify group-specific overrides are shown + assert "h100-group" in output + assert "spot=spot" in output # H100 group's spot policy + assert "regions=us-east-1,us-west-2" in output # H100 group's regions + assert "backends=aws" in output # H100 group's backend + + assert "rtx5090-group" in output + assert "spot=on-demand" in output # RTX5090 group's spot policy + assert "regions=jp-japan" in output # RTX5090 group's region + assert "backends=vastai,runpod" in output # RTX5090 group's backends + + # Verify service-level "Spot policy" row is NOT shown (misleading with groups) + lines = output.split("\n") + spot_policy_lines = [line for line in lines if line.strip().startswith("Spot policy")] + assert len(spot_policy_lines) == 0, ( + "Service-level 'Spot policy' should not be shown with replica_groups" + ) diff --git a/src/tests/_internal/server/background/tasks/test_migrate_legacy_jobs.py b/src/tests/_internal/server/background/tasks/test_migrate_legacy_jobs.py new file mode 100644 index 000000000..3e9e2e3d5 --- /dev/null +++ b/src/tests/_internal/server/background/tasks/test_migrate_legacy_jobs.py @@ -0,0 +1,242 @@ +""" +Tests for migrating legacy jobs without replica_group_name. +""" + +import pytest + +from dstack._internal.core.models.configurations import ( + ServiceConfiguration, +) +from dstack._internal.core.models.profiles import Profile +from dstack._internal.core.models.resources import Range, ResourcesSpec +from dstack._internal.core.models.runs import JobStatus, RunSpec +from dstack._internal.server.background.tasks.process_runs import ( + _migrate_legacy_job_replica_groups, +) +from dstack._internal.server.testing.common import ( + create_job, + create_project, + create_repo, + create_run, + create_user, +) + + +class TestMigrateLegacyJobs: + @pytest.mark.asyncio + async def test_migrates_jobs_without_replica_group_name( + self, test_db, session, socket_enabled + ): + """Test that jobs without replica_group_name get migrated correctly.""" + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + repo = await create_repo(session=session, project_id=project.id) + + # Create a run with replica_groups configuration + service_config = ServiceConfiguration( + replica_groups=[ + { + "name": "h100-gpu", + "replicas": Range(min=1, max=1), + "resources": ResourcesSpec(gpu="H100:1"), + }, + { + "name": "rtx5090-gpu", + "replicas": Range(min=1, max=1), + "resources": ResourcesSpec(gpu="RTX5090:1"), + }, + ], + commands=["echo hello"], + port=8000, + ) + + run_spec = RunSpec( + run_name="test-service", + repo_id="test-repo", + configuration=service_config, + merged_profile=Profile(), + ) + + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="test-service", + run_spec=run_spec, + ) + + # Create jobs WITHOUT replica_group_name (simulating old code) + job1 = await create_job( + session=session, + run=run, + replica_num=0, + replica_group_name=None, # Old job without group + status=JobStatus.RUNNING, + ) + + job2 = await create_job( + session=session, + run=run, + replica_num=1, + replica_group_name=None, # Old job without group + status=JobStatus.RUNNING, + ) + + # Verify jobs have no group + assert job1.replica_group_name is None + assert job2.replica_group_name is None + + # Refresh run to load jobs relationship + await session.refresh(run, ["jobs"]) + + # Run migration + await _migrate_legacy_job_replica_groups(session, run) + await session.refresh(job1) + await session.refresh(job2) + + # Verify jobs now have correct groups + assert job1.replica_group_name == "h100-gpu" + assert job2.replica_group_name == "rtx5090-gpu" + + @pytest.mark.asyncio + async def test_skips_already_migrated_jobs(self, test_db, session): + """Test that jobs with replica_group_name are not re-migrated.""" + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + repo = await create_repo(session=session, project_id=project.id) + + service_config = ServiceConfiguration( + replica_groups=[ + { + "name": "gpu-group", + "replicas": Range(min=1, max=1), + "resources": ResourcesSpec(gpu="A100:1"), + }, + ], + commands=["echo hello"], + port=8000, + ) + + run_spec = RunSpec( + run_name="test-service", + repo_id="test-repo", + configuration=service_config, + merged_profile=Profile(), + ) + + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="test-service", + run_spec=run_spec, + ) + + # Create job WITH replica_group_name (already migrated) + job = await create_job( + session=session, + run=run, + replica_num=0, + replica_group_name="gpu-group", + status=JobStatus.RUNNING, + ) + + original_group = job.replica_group_name + + # Run migration (should be a no-op) + await _migrate_legacy_job_replica_groups(session, run) + await session.refresh(job) + + # Verify group unchanged + assert job.replica_group_name == original_group + + @pytest.mark.asyncio + async def test_skips_non_service_runs(self, test_db, session): + """Test that non-service runs are skipped.""" + from dstack._internal.core.models.configurations import TaskConfiguration + + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + repo = await create_repo(session=session, project_id=project.id) + + task_config = TaskConfiguration( + commands=["echo hello"], + ) + + run_spec = RunSpec( + run_name="test-task", + repo_id="test-repo", + configuration=task_config, + merged_profile=Profile(), + ) + + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="test-task", + run_spec=run_spec, + ) + + job = await create_job( + session=session, + run=run, + replica_num=0, + replica_group_name=None, + status=JobStatus.RUNNING, + ) + + # Run migration (should skip task runs) + await _migrate_legacy_job_replica_groups(session, run) + await session.refresh(job) + + # Verify no change + assert job.replica_group_name is None + + @pytest.mark.asyncio + async def test_skips_legacy_replicas_config(self, test_db, session): + """Test that runs using legacy 'replicas' (not replica_groups) are skipped.""" + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + repo = await create_repo(session=session, project_id=project.id) + + # Use legacy replicas configuration + service_config = ServiceConfiguration( + replicas=Range(min=2, max=2), + commands=["echo hello"], + port=8000, + ) + + run_spec = RunSpec( + run_name="test-service", + repo_id="test-repo", + configuration=service_config, + merged_profile=Profile(), + ) + + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="test-service", + run_spec=run_spec, + ) + + job = await create_job( + session=session, + run=run, + replica_num=0, + replica_group_name=None, + status=JobStatus.RUNNING, + ) + + # Run migration (should skip legacy replicas) + await _migrate_legacy_job_replica_groups(session, run) + await session.refresh(job) + + # Verify no change (legacy replicas don't use replica_group_name) + assert job.replica_group_name is None