Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make gateway creation async #1236

Merged
merged 5 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/dstack/_internal/cli/utils/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ def print_gateways_table(gateways: List[Gateway], verbose: bool = False):
table.add_column("ADDRESS")
table.add_column("DOMAIN")
table.add_column("DEFAULT")
table.add_column("STATUS")
if verbose:
table.add_column("INSTANCE ID")
table.add_column("ERROR")
table.add_column("CREATED")
table.add_column("INSTANCE_ID")

gateways = sorted(gateways, key=lambda g: g.backend)
for backend, backend_gateways in itertools.groupby(gateways, key=lambda g: g.backend):
Expand All @@ -30,10 +32,12 @@ def print_gateways_table(gateways: List[Gateway], verbose: bool = False):
gateway.ip_address,
gateway.wildcard_domain,
"✓" if gateway.default else "",
gateway.status,
]
if verbose:
renderables.append(gateway.instance_id)
renderables.append(gateway.status_message)
renderables.append(pretty_date(gateway.created_at))
renderables.append(gateway.instance_id)
table.add_row(*renderables)
console.print(table)
console.print()
4 changes: 2 additions & 2 deletions src/dstack/_internal/core/backends/kubernetes/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
get_api_from_config_data,
get_cluster_public_ip,
)
from dstack._internal.core.errors import ComputeError, GatewayError
from dstack._internal.core.errors import ComputeError
from dstack._internal.core.models.backends.base import BackendType

# TODO: update import as KNOWN_GPUS becomes public
Expand Down Expand Up @@ -286,7 +286,7 @@ def create_gateway(
)
if hostname is None:
self.terminate_instance(instance_name, region="-")
raise GatewayError(
raise ComputeError(
"Failed to get gateway hostname. "
"Ensure the Kubernetes cluster supports Load Balancer services."
)
Expand Down
10 changes: 10 additions & 0 deletions src/dstack/_internal/core/models/gateways.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
from enum import Enum
from typing import Optional, Union

from pydantic import Field
Expand All @@ -8,6 +9,13 @@
from dstack._internal.core.models.common import CoreModel


class GatewayStatus(str, Enum):
    """Provisioning status of a gateway.

    The ``str`` mixin makes every member an instance of ``str`` that compares
    equal to its lowercase value (e.g. ``GatewayStatus.RUNNING == "running"``).
    """

    SUBMITTED = "submitted"
    PROVISIONING = "provisioning"
    RUNNING = "running"
    FAILED = "failed"


class GatewayConfiguration(CoreModel):
type: Literal["gateway"] = "gateway"
name: Annotated[Optional[str], Field(description="The gateway name")] = None
Expand Down Expand Up @@ -40,6 +48,8 @@ class Gateway(CoreModel):
default: bool
created_at: datetime.datetime
backend: BackendType
status: GatewayStatus
status_message: Optional[str]
configuration: GatewayConfiguration


Expand Down
2 changes: 1 addition & 1 deletion src/dstack/_internal/core/models/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class DockerConfig(CoreModel):
class InstanceConfiguration(CoreModel):
project_name: str
instance_name: str # unique in pool
instance_id: str
instance_id: Optional[str]
ssh_keys: List[SSHKey]
job_docker_config: Optional[DockerConfig]
user: str # dstack user name
Expand Down
8 changes: 6 additions & 2 deletions src/dstack/_internal/server/background/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.interval import IntervalTrigger

from dstack._internal.server.background.tasks.process_gateways import process_gateways
from dstack._internal.server.background.tasks.process_gateways import (
process_gateways_connections,
process_submitted_gateways,
)
from dstack._internal.server.background.tasks.process_instances import (
process_instances,
terminate_idle_instances,
Expand All @@ -27,6 +30,7 @@ def start_background_tasks() -> AsyncIOScheduler:
_scheduler.add_job(process_instances, IntervalTrigger(seconds=10))
_scheduler.add_job(terminate_idle_instances, IntervalTrigger(seconds=10))
_scheduler.add_job(process_runs, IntervalTrigger(seconds=1))
_scheduler.add_job(process_gateways, IntervalTrigger(seconds=15))
_scheduler.add_job(process_gateways_connections, IntervalTrigger(seconds=15))
_scheduler.add_job(process_submitted_gateways, IntervalTrigger(seconds=10), max_instances=5)
_scheduler.start()
return _scheduler
139 changes: 136 additions & 3 deletions src/dstack/_internal/server/background/tasks/process_gateways.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,154 @@
import asyncio
from uuid import UUID

from dstack._internal.core.errors import SSHError
from dstack._internal.server.services.gateways import GatewayConnection, gateway_connections_pool
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload

from dstack._internal.core.errors import BackendError, BackendNotAvailable, SSHError
from dstack._internal.core.models.gateways import GatewayStatus
from dstack._internal.server.db import get_session_ctx
from dstack._internal.server.models import GatewayModel
from dstack._internal.server.services import backends as backends_services
from dstack._internal.server.services import gateways as gateways_services
from dstack._internal.server.services.gateways import (
PROCESSING_GATEWAYS_IDS,
PROCESSING_GATEWAYS_LOCK,
GatewayConnection,
create_gateway_compute,
gateway_connections_pool,
)
from dstack._internal.utils.common import get_current_datetime
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)


async def process_gateways():
async def process_gateways_connections():
    """Run `_process_connection` concurrently for every pooled gateway connection."""
    # TODO(egor-s): distribute the load evenly
    connections = await gateway_connections_pool.all()
    checks = [_process_connection(conn) for conn in connections]
    await asyncio.gather(*checks)


async def process_submitted_gateways():
    """Claim one SUBMITTED gateway and process it.

    Selects the SUBMITTED gateway with the oldest ``last_processed_at`` whose id
    is not already in ``PROCESSING_GATEWAYS_IDS``. The id is registered in that
    set while ``PROCESSING_GATEWAYS_LOCK`` is held, and is always removed again
    once processing finishes, whether it succeeded or raised.
    Returns without doing anything when no such gateway exists.
    """
    async with get_session_ctx() as session:
        async with PROCESSING_GATEWAYS_LOCK:
            query = (
                select(GatewayModel)
                .where(
                    GatewayModel.status == GatewayStatus.SUBMITTED,
                    GatewayModel.id.not_in(PROCESSING_GATEWAYS_IDS),
                )
                .order_by(GatewayModel.last_processed_at.asc())
                .limit(1)
            )
            res = await session.execute(query)
            gateway_model = res.scalar()
            if gateway_model is None:
                return

            # Read the id once while the session is still open; only the plain
            # value is used after the session context exits.
            gateway_id = gateway_model.id
            PROCESSING_GATEWAYS_IDS.add(gateway_id)

    try:
        await _process_gateway(gateway_id=gateway_id)
    finally:
        PROCESSING_GATEWAYS_IDS.remove(gateway_id)


async def _process_connection(conn: GatewayConnection):
    """Check (or restart) one gateway connection, then try to collect its stats.

    An `SSHError` from either step is logged and swallowed, so it does not
    propagate to the caller.
    """
    try:
        await conn.check_or_restart()
        await conn.try_collect_stats()
    except SSHError as ssh_error:
        logger.error("Connection to gateway %s failed: %s", conn.ip_address, ssh_error)


async def _process_gateway(gateway_id: UUID):
    """Load the gateway by id in a fresh session and run submitted-gateway processing.

    The project relationship is eagerly joined because
    `_process_submitted_gateway` reads ``gateway_model.project``.
    ``scalar_one()`` raises if the row does not exist.
    """
    async with get_session_ctx() as session:
        query = (
            select(GatewayModel)
            .where(GatewayModel.id == gateway_id)
            .options(joinedload(GatewayModel.project))
        )
        res = await session.execute(query)
        gateway_model = res.scalar_one()
        await _process_submitted_gateway(
            session=session,
            gateway_model=gateway_model,
        )


async def _process_submitted_gateway(session: AsyncSession, gateway_model: GatewayModel):
    """Provision, connect to, and configure a submitted gateway.

    Steps, each committing its outcome to ``session``:

    1. Resolve the project backend for the configured backend type.
    2. Create the gateway compute and move the gateway to PROVISIONING.
    3. Connect to the gateway (with retries).
    4. Configure the gateway and move it to RUNNING.

    A failure at any step marks the gateway FAILED with a status message and
    returns; no exception from those steps is re-raised by this function.
    """
    logger.info("Started gateway %s provisioning", gateway_model.name)
    configuration = gateways_services.get_gateway_configuration(gateway_model)
    try:
        (
            backend_model,
            backend,
        ) = await backends_services.get_project_backend_with_model_by_type_or_error(
            project=gateway_model.project, backend_type=configuration.backend
        )
    except BackendNotAvailable:
        _mark_gateway_failed(gateway_model, "Backend not available")
        await session.commit()
        return

    try:
        gateway_model.gateway_compute = await create_gateway_compute(
            backend_compute=backend.compute(),
            project_name=gateway_model.project.name,
            configuration=configuration,
            backend_id=backend_model.id,
        )
        session.add(gateway_model)
        gateway_model.status = GatewayStatus.PROVISIONING
        await session.commit()
        await session.refresh(gateway_model)
    except BackendError as e:
        logger.info(
            "Failed to create gateway compute for gateway %s: %s", gateway_model.name, repr(e)
        )
        # Prefer the error's first argument as the message; fall back to its
        # repr when the error carries no arguments.
        if len(e.args) > 0:
            status_message = str(e.args[0])
        else:
            status_message = f"Backend error: {repr(e)}"
        _mark_gateway_failed(gateway_model, status_message)
        await session.commit()
        return
    except Exception as e:
        logger.exception(
            "Got exception when creating gateway compute for gateway %s", gateway_model.name
        )
        _mark_gateway_failed(gateway_model, f"Unexpected error: {repr(e)}")
        await session.commit()
        return

    connection = await gateways_services.connect_to_gateway_with_retry(
        gateway_model.gateway_compute
    )
    if connection is None:
        _mark_gateway_failed(gateway_model, "Failed to connect to gateway")
        gateway_model.gateway_compute.deleted = True
        await session.commit()
        return

    try:
        await gateways_services.configure_gateway(connection)
    except Exception:
        logger.exception("Failed to configure gateway %s", gateway_model.name)
        _mark_gateway_failed(gateway_model, "Failed to configure gateway")
        # Remove the pooled connection and mark the compute inactive.
        await gateway_connections_pool.remove(gateway_model.gateway_compute.ip_address)
        gateway_model.gateway_compute.active = False
        await session.commit()
        return

    gateway_model.status = GatewayStatus.RUNNING
    gateway_model.last_processed_at = get_current_datetime()
    await session.commit()


def _mark_gateway_failed(gateway_model: GatewayModel, status_message: str) -> None:
    """Set FAILED status, the given message, and a fresh ``last_processed_at`` (no commit)."""
    gateway_model.status = GatewayStatus.FAILED
    gateway_model.status_message = status_message
    gateway_model.last_processed_at = get_current_datetime()
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Add fields for async gateway creation

Revision ID: c154eece89da
Revises: 58aa5162dcc3
Create Date: 2024-05-16 14:18:29.044545

"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "c154eece89da"
down_revision = "58aa5162dcc3"
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Add the columns used by async gateway creation and backfill existing rows.

    The new columns are first added as nullable, then existing rows are
    backfilled, and only then are ``active``, ``status`` and
    ``last_processed_at`` made NOT NULL. ``status_message`` stays nullable.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table("gateway_computes", schema=None) as batch_op:
        batch_op.add_column(sa.Column("active", sa.Boolean(), nullable=True))

    with op.batch_alter_table("gateways", schema=None) as batch_op:
        batch_op.add_column(
            sa.Column(
                "status",
                sa.Enum("SUBMITTED", "PROVISIONING", "RUNNING", "FAILED", name="gatewaystatus"),
                nullable=True,
            )
        )
        batch_op.add_column(sa.Column("status_message", sa.Text(), nullable=True))
        batch_op.add_column(sa.Column("last_processed_at", sa.DateTime(), nullable=True))

    # Backfill: a compute is active unless it is deleted; pre-existing gateways
    # are recorded as RUNNING and take created_at as their last processing time.
    op.execute(
        sa.sql.text("UPDATE gateway_computes SET active = NOT deleted WHERE active is NULL")
    )
    op.execute(sa.sql.text("UPDATE gateways SET status = 'RUNNING' WHERE status is NULL"))
    op.execute(
        sa.sql.text(
            "UPDATE gateways SET last_processed_at = created_at WHERE last_processed_at is NULL"
        )
    )

    # Every row now has a value, so the columns can be made NOT NULL.
    with op.batch_alter_table("gateway_computes", schema=None) as batch_op:
        batch_op.alter_column("active", nullable=False)

    with op.batch_alter_table("gateways", schema=None) as batch_op:
        batch_op.alter_column("status", nullable=False)
        batch_op.alter_column("last_processed_at", nullable=False)

    # ### end Alembic commands ###


def downgrade() -> None:
    """Drop the columns added by `upgrade` from ``gateways`` and ``gateway_computes``."""
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table("gateways", schema=None) as batch_op:
        batch_op.drop_column("last_processed_at")
        batch_op.drop_column("status_message")
        batch_op.drop_column("status")

    with op.batch_alter_table("gateway_computes", schema=None) as batch_op:
        batch_op.drop_column("active")

    # ### end Alembic commands ###
6 changes: 6 additions & 0 deletions src/dstack/_internal/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from sqlalchemy_utils import UUIDType

from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.gateways import GatewayStatus
from dstack._internal.core.models.profiles import (
DEFAULT_POOL_TERMINATION_IDLE_TIME,
TerminationPolicy,
Expand Down Expand Up @@ -232,6 +233,9 @@ class GatewayModel(BaseModel):
wildcard_domain: Mapped[str] = mapped_column(String(100), nullable=True)
configuration: Mapped[Optional[str]] = mapped_column(Text)
created_at: Mapped[datetime] = mapped_column(DateTime, default=get_current_datetime)
status: Mapped[GatewayStatus] = mapped_column(Enum(GatewayStatus))
status_message: Mapped[Optional[str]] = mapped_column(Text)
last_processed_at: Mapped[datetime] = mapped_column(DateTime)

project_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("projects.id", ondelete="CASCADE"))
project: Mapped["ProjectModel"] = relationship(foreign_keys=[project_id])
Expand Down Expand Up @@ -268,6 +272,8 @@ class GatewayComputeModel(BaseModel):
ssh_private_key: Mapped[str] = mapped_column(Text)
ssh_public_key: Mapped[str] = mapped_column(Text)

# `active` means the server should maintain a connection to the gateway.
active: Mapped[bool] = mapped_column(Boolean, default=True)
deleted: Mapped[bool] = mapped_column(Boolean, server_default=false())


Expand Down
Loading
Loading