Skip to content

Commit

Permalink
update image query functions
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa committed Nov 15, 2024
1 parent cb149cf commit cc689b4
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 30 deletions.
8 changes: 2 additions & 6 deletions src/ai/backend/agent/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,9 +503,7 @@ async def _pull(reporter: ProgressReporter) -> None:
)
if need_to_pull:
await self.agent.produce_event(
ImagePullStartedEvent(
image=str(img_ref),
)
ImagePullStartedEvent(image=str(img_ref), agent_id=self.agent.id)
)
image_pull_timeout = cast(
Optional[float], self.local_config["agent"]["api"]["pull-timeout"]
Expand All @@ -514,9 +512,7 @@ async def _pull(reporter: ProgressReporter) -> None:
img_ref, img_conf["registry"], timeout=image_pull_timeout
)
await self.agent.produce_event(
ImagePullFinishedEvent(
image=str(img_ref),
)
ImagePullFinishedEvent(image=str(img_ref), agent_id=self.agent.id)
)

task_id = await bgtask_mgr.start(_pull)
Expand Down
17 changes: 16 additions & 1 deletion src/ai/backend/common/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
from dataclasses import dataclass
from pathlib import Path, PurePath
from typing import (
TYPE_CHECKING,
Final,
Iterable,
Mapping,
NamedTuple,
Optional,
Self,
)

import aiohttp
Expand All @@ -31,6 +33,9 @@
from .service_ports import parse_service_ports
from .utils import is_ip_address_format, join_non_empty

if TYPE_CHECKING:
from .types import ImageConfig

__all__ = (
"arch_name_aliases",
"default_registry",
Expand Down Expand Up @@ -379,6 +384,16 @@ class ImageRef:
architecture: str
is_local: bool

@classmethod
def from_image_config(cls, config: ImageConfig) -> Self:
return cls.from_image_str(
config["canonical"],
config["project"],
config["registry"]["name"],
is_local=config["is_local"],
architecture=config["architecture"],
)

@classmethod
def from_image_str(
cls,
Expand All @@ -388,7 +403,7 @@ def from_image_str(
*,
architecture: str = "x86_64",
is_local: bool = False,
) -> ImageRef:
) -> Self:
"""
Parse the image reference string and return an ImageRef object from the string.
"""
Expand Down
1 change: 1 addition & 0 deletions src/ai/backend/common/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,6 +1022,7 @@ class ImageRegistry(TypedDict):

class ImageConfig(TypedDict):
canonical: str
project: Optional[str]
architecture: str
digest: str
repo_digest: Optional[str]
Expand Down
8 changes: 7 additions & 1 deletion src/ai/backend/manager/models/container_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import yarl
from graphql import Undefined, UndefinedType
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import load_only
from sqlalchemy.orm import load_only, relationship
from sqlalchemy.orm.exc import NoResultFound

from ai.backend.common.exception import UnknownImageRegistry
Expand Down Expand Up @@ -82,6 +82,12 @@ class ContainerRegistryRow(Base):
)
extra = sa.Column("extra", sa.JSON, nullable=True, default=None)

image_rows = relationship(
"ImageRow",
back_populates="registry_row",
primaryjoin="ContainerRegistryRow.id == foreign(ImageRow.registry_id)",
)

@classmethod
async def get(
cls,
Expand Down
7 changes: 7 additions & 0 deletions src/ai/backend/manager/models/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,12 @@ class ImageRow(Base):
# sessions = relationship("SessionRow", back_populates="image_row")
endpoints = relationship("EndpointRow", back_populates="image_row")

registry_row = relationship(
"ContainerRegistryRow",
back_populates="image_rows",
primaryjoin="ContainerRegistryRow.id == foreign(ImageRow.registry_id)",
)

def __init__(
self,
name,
Expand Down Expand Up @@ -559,6 +565,7 @@ async def bulk_get_image_configs(

image_conf: ImageConfig = {
"architecture": ref.architecture,
"project": resolved_image_info.project,
"canonical": ref.canonical,
"is_local": resolved_image_info.image_ref.is_local,
"digest": resolved_image_info.trimmed_digest,
Expand Down
55 changes: 33 additions & 22 deletions src/ai/backend/manager/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1369,25 +1369,26 @@ async def _check_and_pull_in_one_agent(
self,
agent_alloc_ctx: AgentAllocationContext,
kernel_agent_bindings: Sequence[KernelAgentBinding],
image_configs: Mapping[str, ImageConfig],
) -> dict[str, uuid.UUID]:
image_configs: Mapping[ImageIdentifier, ImageConfig],
) -> dict[ImageIdentifier, uuid.UUID]:
"""
Return {str(ImageRef): bgtask_id}
Return {ImageIdentifier(): bgtask_id}
"""
assert agent_alloc_ctx.agent_id is not None

result: dict[str, uuid.UUID] = {}
result: dict[ImageIdentifier, uuid.UUID] = {}
async with self.agent_cache.rpc_context(
agent_alloc_ctx.agent_id,
) as rpc:
for img, conf in image_configs.items():
resp = cast(dict[str, str], await rpc.call.check_and_pull(conf))
resp = await rpc.call.check_and_pull(conf)
resp = cast(dict[str, str], resp)
bgtask_id = resp["bgtask_id"]
result[img] = uuid.UUID(bgtask_id)

return result

async def check_before_start(
async def check_and_pull_images(
self,
scheduled_session: SessionRow,
) -> None:
Expand All @@ -1404,18 +1405,7 @@ async def check_before_start(
for k in scheduled_session.kernels
]

# Aggregate image registry information
_image_refs: set[ImageRef] = set([item.kernel.image_ref for item in kernel_agent_bindings])
auto_pull = cast(str, self.shared_config["docker"]["image"]["auto_pull"])
async with self.db.connect() as db_conn:
configs = await bulk_get_image_configs(
_image_refs,
AutoPullBehavior(auto_pull),
db=self.db,
db_conn=db_conn,
etcd=self.shared_config.etcd,
)
img_ref_to_conf_map = {ImageRef.from_image_config(item): item for item in configs}

def _keyfunc(binding: KernelAgentBinding) -> AgentId:
if binding.agent_alloc_ctx.agent_id is None:
Expand All @@ -1435,12 +1425,33 @@ def _keyfunc(binding: KernelAgentBinding) -> AgentId:
items: list[KernelAgentBinding] = [*group_iterator]
# Within a group, agent_alloc_ctx are same.
agent_alloc_ctx = items[0].agent_alloc_ctx
_filtered_imgs: set[ImageRef] = {binding.kernel.image_ref for binding in items}
_img_conf_map = {
str(img): conf
for img, conf in img_ref_to_conf_map.items()
if img in _filtered_imgs
_filtered_imgs: set[ImageRef] = {
binding.kernel.image_ref
for binding in items
if binding.kernel.image_ref is not None
}
_img_conf_map: dict[ImageIdentifier, ImageConfig] = {}
for binding in items:
img_ref = binding.kernel.image_ref
img_row = binding.kernel.image_row
registry_row = img_row.registry_row
if img_ref is not None:
_img_conf_map[ImageIdentifier(str(img_ref), img_row.architecture)] = {
"architecture": img_row.architecture,
"project": img_row.project,
"canonical": img_ref.canonical,
"is_local": img_row.is_local,
"digest": img_row.trimmed_digest,
"labels": img_row.labels,
"repo_digest": None,
"registry": {
"name": img_ref.registry,
"url": registry_row.url,
"username": registry_row.username,
"password": registry_row.password,
},
"auto_pull": auto_pull,
}
tg.create_task(
self._check_and_pull_in_one_agent(agent_alloc_ctx, items, _img_conf_map)
)
Expand Down

0 comments on commit cc689b4

Please sign in to comment.