Skip to content

Commit

Permalink
Support AWS placement groups for cluster fleets (#1725)
Browse files Browse the repository at this point in the history
* Support aws placement groups for cluster fleets

* Test placement group processing

* Adapt placement groups to work with dstack sky

* Update docs

* Fix tests
  • Loading branch information
r4victor authored Sep 25, 2024
1 parent 42b9bcb commit c187166
Show file tree
Hide file tree
Showing 20 changed files with 611 additions and 29 deletions.
9 changes: 6 additions & 3 deletions docs/docs/concepts/fleets.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,10 @@ are both acceptable).
to the specified parameters.

!!! info "Network"
Set `placement` to `cluster` if the nodes should be interconnected (e.g. if you'd like to use them for multi-node tasks).
In that case, `dstack` will provision all nodes in the same backend and region.
Set `placement` to `cluster` if the nodes should be interconnected
(e.g. if you'd like to use them for [multi-node tasks](reference/dstack.yml/task.md#distributed-tasks)).
    In that case, `dstack` will provision all nodes in the same backend and region and configure optimal
    inter-node connectivity using availability zones, placement groups, etc.

Note that cloud fleets aren't supported for the `kubernetes`, `vastai`, and `runpod` backends.

Expand Down Expand Up @@ -120,7 +122,8 @@ are both acceptable).
```

!!! info "Network"
Set `placement` to `cluster` if the hosts are interconnected (e.g. if you'd like to use them for multi-node tasks).
Set `placement` to `cluster` if the hosts are interconnected
(e.g. if you'd like to use them for [multi-node tasks](reference/dstack.yml/task.md#distributed-tasks)).
In that case, by default, `dstack` will automatically detect the private network.
You can specify the [`network`](../reference/dstack.yml/fleet.md#network) parameter manually.

Expand Down
3 changes: 3 additions & 0 deletions src/dstack/_internal/core/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
BackendType.OCI,
BackendType.TENSORDOCK,
]
BACKENDS_WITH_PLACEMENT_GROUPS_SUPPORT = [
BackendType.AWS,
]
BACKENDS_WITH_GATEWAY_SUPPORT = [
BackendType.AWS,
BackendType.AZURE,
Expand Down
43 changes: 40 additions & 3 deletions src/dstack/_internal/core/backends/aws/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
get_user_data,
)
from dstack._internal.core.backends.base.offers import get_catalog_offers
from dstack._internal.core.errors import ComputeError, NoCapacityError
from dstack._internal.core.errors import ComputeError, NoCapacityError, PlacementGroupInUseError
from dstack._internal.core.models.backends.aws import AWSAccessKeyCreds
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import CoreModel, is_core_model_instance
Expand All @@ -31,6 +31,7 @@
InstanceOfferWithAvailability,
SSHKey,
)
from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData
from dstack._internal.core.models.resources import Memory, Range
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
from dstack._internal.core.models.volumes import (
Expand Down Expand Up @@ -116,7 +117,7 @@ def terminate_instance(
ec2_client.terminate_instances(InstanceIds=[instance_id])
except botocore.exceptions.ClientError as e:
if e.response["Error"]["Code"] == "InvalidInstanceID.NotFound":
pass
logger.debug("Skipping instance %s termination. Instance not found.", instance_id)
else:
raise e

Expand Down Expand Up @@ -181,6 +182,7 @@ def create_instance(
spot=instance_offer.instance.resources.spot,
subnet_id=subnet_id,
allocate_public_ip=allocate_public_ip,
placement_group_name=instance_config.placement_group_name,
)
)
instance = response[0]
Expand Down Expand Up @@ -239,6 +241,41 @@ def run_job(
instance_config.availability_zone = volume.provisioning_data.availability_zone
return self.create_instance(instance_offer, instance_config)

def create_placement_group(
    self,
    placement_group: PlacementGroup,
) -> PlacementGroupProvisioningData:
    """
    Create an EC2 placement group in the region given by the group's configuration.

    The strategy (e.g. "cluster") comes from the configuration's
    `placement_strategy` enum value.
    """
    region = placement_group.configuration.region
    group_name = placement_group.name
    ec2_client = self.session.client("ec2", region_name=region)
    logger.debug("Creating placement group %s...", group_name)
    ec2_client.create_placement_group(
        GroupName=group_name,
        Strategy=placement_group.configuration.placement_strategy.value,
    )
    logger.debug("Created placement group %s", group_name)
    return PlacementGroupProvisioningData(
        backend=BackendType.AWS,
        backend_data=None,
    )

def delete_placement_group(
    self,
    placement_group: PlacementGroup,
):
    """
    Delete an EC2 placement group.

    Deletion is idempotent: a missing group is logged and ignored.

    Raises:
        PlacementGroupInUseError: if the group still contains instances.
    """
    ec2_client = self.session.client("ec2", region_name=placement_group.configuration.region)
    logger.debug("Deleting placement group %s...", placement_group.name)
    try:
        ec2_client.delete_placement_group(GroupName=placement_group.name)
    except botocore.exceptions.ClientError as e:
        # Look up the error code once instead of re-reading the response dict.
        error_code = e.response["Error"]["Code"]
        if error_code == "InvalidPlacementGroup.Unknown":
            # Already gone — treat deletion as successful.
            logger.debug("Placement group %s not found", placement_group.name)
            return
        if error_code == "InvalidPlacementGroup.InUse":
            logger.debug("Placement group %s is in use", placement_group.name)
            # Chain the original ClientError so the AWS error details
            # are preserved in tracebacks (the original raise dropped them).
            raise PlacementGroupInUseError() from e
        raise
    logger.debug("Deleted placement group %s", placement_group.name)

def create_gateway(
self,
configuration: GatewayComputeConfiguration,
Expand Down Expand Up @@ -372,7 +409,7 @@ def create_gateway(

def terminate_gateway(
self,
instance_id,
instance_id: str,
configuration: GatewayComputeConfiguration,
backend_data: Optional[str] = None,
):
Expand Down
9 changes: 8 additions & 1 deletion src/dstack/_internal/core/backends/aws/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,9 @@ def create_instances_struct(
spot: bool,
subnet_id: Optional[str] = None,
allocate_public_ip: bool = True,
placement_group_name: Optional[str] = None,
) -> Dict[str, Any]:
struct = dict(
struct: Dict[str, Any] = dict(
BlockDeviceMappings=[
{
"DeviceName": "/dev/sda1",
Expand Down Expand Up @@ -151,6 +152,12 @@ def create_instances_struct(
]
else:
struct["SecurityGroupIds"] = [security_group_id]

if placement_group_name is not None:
struct["Placement"] = {
"GroupName": placement_group_name,
}

return struct


Expand Down
24 changes: 22 additions & 2 deletions src/dstack/_internal/core/backends/base/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
InstanceConfiguration,
InstanceOfferWithAvailability,
)
from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
from dstack._internal.core.models.volumes import (
Volume,
Expand Down Expand Up @@ -62,8 +63,8 @@ def terminate_instance(
backend_data: Optional[str] = None,
) -> None:
"""
Terminates an instance by `instance_id`. If instance does not exist,
it should not raise errors but return silently.
Terminates an instance by `instance_id`.
If the instance does not exist, it should not raise errors but return silently.
"""
pass

Expand Down Expand Up @@ -95,6 +96,25 @@ def update_provisioning_data(
"""
pass

def create_placement_group(
    self,
    placement_group: PlacementGroup,
) -> PlacementGroupProvisioningData:
    """
    Creates a placement group.

    Backends that support placement groups must override this method.
    The default implementation raises NotImplementedError.
    """
    raise NotImplementedError()

def delete_placement_group(
    self,
    placement_group: PlacementGroup,
):
    """
    Deletes a placement group.
    If the group does not exist, it should not raise errors but return silently.

    Backends that support placement groups must override this method.
    The default implementation raises NotImplementedError.
    """
    raise NotImplementedError()

def create_gateway(
self,
configuration: GatewayComputeConfiguration,
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/core/backends/remote/provisioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import paramiko
from gpuhunt import correct_gpu_memory_gib

# FIXME: ProvisioningError is a subclass of ComputeError and should not be used outside of Compute
from dstack._internal.core.errors import ProvisioningError
from dstack._internal.core.models.instances import (
Disk,
Expand Down
4 changes: 4 additions & 0 deletions src/dstack/_internal/core/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ class ComputeResourceNotFoundError(ComputeError):
pass


class PlacementGroupInUseError(ComputeError):
    """Raised when a placement group cannot be deleted because it still has instances."""

    pass


class CLIError(DstackError):
    """Base error for failures reported by the dstack CLI."""

    pass

Expand Down
9 changes: 5 additions & 4 deletions src/dstack/_internal/core/models/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,13 @@ class DockerConfig(CoreModel):

class InstanceConfiguration(CoreModel):
project_name: str
instance_name: str # unique in pool
instance_id: Optional[str] = None
ssh_keys: List[SSHKey]
job_docker_config: Optional[DockerConfig]
instance_name: str
user: str # dstack user name
ssh_keys: List[SSHKey]
instance_id: Optional[str] = None
availability_zone: Optional[str] = None
placement_group_name: Optional[str] = None
job_docker_config: Optional[DockerConfig] # FIXME: cannot find any usages – remove?

def get_public_keys(self) -> List[str]:
return [ssh_key.public.strip() for ssh_key in self.ssh_keys]
Expand Down
27 changes: 27 additions & 0 deletions src/dstack/_internal/core/models/placement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from enum import Enum
from typing import Optional

from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import CoreModel


class PlacementStrategy(str, Enum):
    """Placement strategy for grouping instances (maps to cloud placement group strategies)."""

    CLUSTER = "cluster"  # pack instances close together for low-latency interconnect


class PlacementGroupConfiguration(CoreModel):
    """User/server-side specification of a placement group to be provisioned."""

    # Backend the group is requested for (may differ from the provisioning backend)
    backend: BackendType
    # Region in which the group should be created
    region: str
    # Strategy used when creating the group (e.g. cluster)
    placement_strategy: PlacementStrategy


class PlacementGroupProvisioningData(CoreModel):
    """Backend-reported data about an actually provisioned placement group."""

    backend: BackendType  # can be different from configuration backend
    # Opaque backend-specific payload, if any
    backend_data: Optional[str] = None


class PlacementGroup(CoreModel):
    """A placement group tracked by the server, with its requested configuration
    and, once provisioned, the backend-reported provisioning data."""

    name: str
    project_name: str
    configuration: PlacementGroupConfiguration
    # None until the backend has actually created the group
    provisioning_data: Optional[PlacementGroupProvisioningData] = None
4 changes: 4 additions & 0 deletions src/dstack/_internal/server/background/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from dstack._internal.server.background.tasks.process_instances import (
process_instances,
)
from dstack._internal.server.background.tasks.process_placement_groups import (
process_placement_groups,
)
from dstack._internal.server.background.tasks.process_running_jobs import process_running_jobs
from dstack._internal.server.background.tasks.process_runs import process_runs
from dstack._internal.server.background.tasks.process_submitted_jobs import process_submitted_jobs
Expand Down Expand Up @@ -45,5 +48,6 @@ def start_background_tasks() -> AsyncIOScheduler:
process_submitted_volumes, IntervalTrigger(seconds=10, jitter=2), max_instances=5
)
_scheduler.add_job(process_fleets, IntervalTrigger(seconds=15))
_scheduler.add_job(process_placement_groups, IntervalTrigger(seconds=30))
_scheduler.start()
return _scheduler
15 changes: 13 additions & 2 deletions src/dstack/_internal/server/background/tasks/process_fleets.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from sqlalchemy import select
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload

from dstack._internal.core.models.fleets import FleetStatus
from dstack._internal.server.db import get_session_ctx
from dstack._internal.server.models import FleetModel
from dstack._internal.server.models import FleetModel, PlacementGroupModel
from dstack._internal.server.services.fleets import (
fleet_model_to_fleet,
is_fleet_empty,
Expand Down Expand Up @@ -73,5 +73,16 @@ async def _autodelete_fleet(session: AsyncSession, fleet_model: FleetModel):
fleet_model.status = FleetStatus.TERMINATED
fleet_model.deleted = True
fleet_model.last_processed_at = get_current_datetime()
await _mark_placement_groups_as_ready_for_deletion(session=session, fleet_model=fleet_model)
await session.commit()
logger.info("Fleet %s deleted", fleet_model.name)


async def _mark_placement_groups_as_ready_for_deletion(
    session: AsyncSession, fleet_model: FleetModel
):
    """Flag all of the fleet's placement groups so the background
    process_placement_groups task picks them up for deletion."""
    stmt = (
        update(PlacementGroupModel)
        .where(PlacementGroupModel.fleet_id == fleet_model.id)
        .values(fleet_deleted=True)
    )
    await session.execute(stmt)
Loading

0 comments on commit c187166

Please sign in to comment.