From 00aec78630b39374a9acd494c7ffa360119bcde5 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 6 May 2024 18:17:01 +0500 Subject: [PATCH] Optimize ProjectModel loading (#1199) --- src/dstack/_internal/server/app.py | 8 ++++---- src/dstack/_internal/server/db.py | 2 +- src/dstack/_internal/server/models.py | 6 ++---- src/dstack/_internal/server/services/pools.py | 19 ++++++++----------- src/dstack/_internal/server/settings.py | 18 +++++++++++------- .../tasks/test_process_submitted_jobs.py | 16 +++++++++++++--- 6 files changed, 39 insertions(+), 30 deletions(-) diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index 648ed4177..63df8f8d0 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -31,10 +31,10 @@ from dstack._internal.server.services.users import get_or_create_admin_user from dstack._internal.server.settings import ( DEFAULT_PROJECT_NAME, - DSTACK_DO_NOT_UPDATE_DEFAULT_PROJECT, - DSTACK_UPDATE_DEFAULT_PROJECT, + DO_NOT_UPDATE_DEFAULT_PROJECT, SERVER_CONFIG_FILE_PATH, SERVER_URL, + UPDATE_DEFAULT_PROJECT, ) from dstack._internal.server.utils.logging import configure_logging from dstack._internal.server.utils.routers import ( @@ -109,8 +109,8 @@ async def lifespan(app: FastAPI): project_name=DEFAULT_PROJECT_NAME, url=SERVER_URL, token=admin.token, - default=DSTACK_UPDATE_DEFAULT_PROJECT, - no_default=DSTACK_DO_NOT_UPDATE_DEFAULT_PROJECT, + default=UPDATE_DEFAULT_PROJECT, + no_default=DO_NOT_UPDATE_DEFAULT_PROJECT, ) if settings.SERVER_BUCKET is not None: init_default_storage() diff --git a/src/dstack/_internal/server/db.py b/src/dstack/_internal/server/db.py index dec0a7bd7..b1ad11f82 100644 --- a/src/dstack/_internal/server/db.py +++ b/src/dstack/_internal/server/db.py @@ -14,7 +14,7 @@ class Database: def __init__(self, url: str): self.url = url - self.engine = create_async_engine(self.url, echo=False) + self.engine = create_async_engine(self.url, echo=settings.SQL_ECHO_ENABLED) self.session_maker = sessionmaker( bind=self.engine, expire_on_commit=False, class_=AsyncSession ) diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 941f6f905..1615f4a99 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -97,9 +97,7 @@ class ProjectModel(BaseModel): default_pool_id: Mapped[Optional[UUIDType]] = mapped_column( ForeignKey("pools.id", use_alter=True, ondelete="SET NULL"), nullable=True ) - default_pool: Mapped[Optional["PoolModel"]] = relationship( - foreign_keys=[default_pool_id], lazy="selectin" - ) + default_pool: Mapped[Optional["PoolModel"]] = relationship(foreign_keys=[default_pool_id]) class MemberModel(BaseModel): @@ -350,5 +348,5 @@ class InstanceModel(BaseModel): # current job job_id: Mapped[Optional[uuid.UUID]] = mapped_column(ForeignKey("jobs.id")) - job: Mapped[Optional["JobModel"]] = relationship(back_populates="instance", lazy="immediate") + job: Mapped[Optional["JobModel"]] = relationship(back_populates="instance", lazy="joined") last_job_processed_at: Mapped[Optional[datetime]] = mapped_column(DateTime) diff --git a/src/dstack/_internal/server/services/pools.py b/src/dstack/_internal/server/services/pools.py index c05da6ba4..267c5cb98 100644 --- a/src/dstack/_internal/server/services/pools.py +++ b/src/dstack/_internal/server/services/pools.py @@ -62,20 +62,12 @@ async def get_pool( return res.one_or_none() -async def get_named_or_default_pool( - session: AsyncSession, project: ProjectModel, pool_name: Optional[str] -) -> Optional[PoolModel]: - if pool_name is not None: - return await get_pool(session, project, pool_name) - return project.default_pool - - async def get_or_create_pool_by_name( session: AsyncSession, project: ProjectModel, pool_name: Optional[str] ) -> PoolModel: if pool_name is None: - if project.default_pool is not None: - return project.default_pool + if project.default_pool_id is not None: + return await get_default_pool_or_error(session, project) default_pool = await get_pool(session, project, DEFAULT_POOL_NAME) if default_pool is not None: await set_default_pool(session, project, DEFAULT_POOL_NAME) @@ -87,6 +79,11 @@ async def get_or_create_pool_by_name( return await create_pool(session, project, pool_name) +async def get_default_pool_or_error(session: AsyncSession, project: ProjectModel) -> PoolModel: + res = await session.execute(select(PoolModel).where(PoolModel.id == project.default_pool_id)) + return res.scalar_one() + + async def create_pool(session: AsyncSession, project: ProjectModel, name: str) -> PoolModel: pool = await get_pool(session, project, name) if pool is not None: @@ -98,7 +95,7 @@ async def create_pool(session: AsyncSession, project: ProjectModel, name: str) - session.add(pool) await session.commit() await session.refresh(pool) - if project.default_pool is None: + if project.default_pool_id is None: await set_default_pool(session, project, pool.name) return pool diff --git a/src/dstack/_internal/server/settings.py b/src/dstack/_internal/server/settings.py index b69b0c4df..847a13040 100644 --- a/src/dstack/_internal/server/settings.py +++ b/src/dstack/_internal/server/settings.py @@ -27,19 +27,12 @@ SERVER_CONFIG_DISABLED = os.getenv("DSTACK_SERVER_CONFIG_DISABLED") is not None SERVER_CONFIG_ENABLED = not SERVER_CONFIG_DISABLED -LOCAL_BACKEND_ENABLED = os.getenv("DSTACK_LOCAL_BACKEND_ENABLED") is not None SERVER_BUCKET = os.getenv("DSTACK_SERVER_BUCKET") SERVER_BUCKET_REGION = os.getenv("DSTACK_SERVER_BUCKET_REGION", "eu-west-1") DEFAULT_PROJECT_NAME = "main" -DSTACK_UPDATE_DEFAULT_PROJECT = os.getenv("DSTACK_UPDATE_DEFAULT_PROJECT") is not None -DSTACK_DO_NOT_UPDATE_DEFAULT_PROJECT = ( - os.getenv("DSTACK_DO_NOT_UPDATE_DEFAULT_PROJECT") is not None -) -SKIP_GATEWAY_UPDATE = bool(os.getenv("DSTACK_SKIP_GATEWAY_UPDATE", None)) - SENTRY_DSN = os.getenv("DSTACK_SENTRY_DSN") SENTRY_TRACES_SAMPLE_RATE = float(os.getenv("DSTACK_SENTRY_TRACES_SAMPLE_RATE", 0.1)) @@ -51,3 +44,14 @@ ACME_EAB_HMAC_KEY = os.getenv("DSTACK_ACME_EAB_HMAC_KEY") USER_PROJECT_DEFAULT_QUOTA = int(os.getenv("DSTACK_USER_PROJECT_DEFAULT_QUOTA", 10)) + + +# Development settings + +SQL_ECHO_ENABLED = os.getenv("DSTACK_SQL_ECHO_ENABLED") is not None + +LOCAL_BACKEND_ENABLED = os.getenv("DSTACK_LOCAL_BACKEND_ENABLED") is not None + +UPDATE_DEFAULT_PROJECT = os.getenv("DSTACK_UPDATE_DEFAULT_PROJECT") is not None +DO_NOT_UPDATE_DEFAULT_PROJECT = os.getenv("DSTACK_DO_NOT_UPDATE_DEFAULT_PROJECT") is not None +SKIP_GATEWAY_UPDATE = bool(os.getenv("DSTACK_SKIP_GATEWAY_UPDATE", None)) diff --git a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py index 948851bb6..33b0c6ef1 100644 --- a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py @@ -21,7 +21,7 @@ JobTerminationReason, ) from dstack._internal.server.background.tasks.process_submitted_jobs import process_submitted_jobs -from dstack._internal.server.models import JobModel +from dstack._internal.server.models import JobModel, ProjectModel from dstack._internal.server.services.pools import ( get_or_create_pool_by_name, ) @@ -117,7 +117,12 @@ async def test_provisiones_job(self, test_db, session: AsyncSession): assert job is not None assert job.status == JobStatus.PROVISIONING - await session.refresh(project) + res = await session.execute( + select(ProjectModel) + .where(ProjectModel.id == project.id) + .options(joinedload(ProjectModel.default_pool)) + ) + project = res.scalar_one() assert project.default_pool.name == DEFAULT_POOL_NAME instance_offer = InstanceOfferWithAvailability.parse_raw( @@ -165,7 +170,12 @@ async def test_fails_job_when_no_capacity(self, test_db, session: AsyncSession): assert job.status == JobStatus.TERMINATING assert job.termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY - await session.refresh(project) + res = await session.execute( + select(ProjectModel) + .where(ProjectModel.id == project.id) + .options(joinedload(ProjectModel.default_pool)) + ) + project = res.scalar_one() assert not project.default_pool.instances @pytest.mark.asyncio