feat: Add image_node and vfolder_node fields to ComputeSession schema
fregataa committed Oct 29, 2024
1 parent cff51b8 commit 96883da
Showing 8 changed files with 131 additions and 52 deletions.
1 change: 1 addition & 0 deletions changes/2987.feature.md
@@ -0,0 +1 @@
Add image_node and vfolder_node fields to ComputeSession schema
39 changes: 17 additions & 22 deletions src/ai/backend/manager/api/schema.graphql
@@ -1045,9 +1045,17 @@ type ComputeSessionNode implements Node {
vfolder_mounts: [String]
occupied_slots: JSONString
requested_slots: JSONString

"""Added in 24.12.0."""
image_references: [String]

Check notice on line 1050 in src/ai/backend/manager/api/schema.graphql (GitHub Actions / GraphQL Inspector): Field 'image_references' was added to object type 'ComputeSessionNode'

"""Added in 24.12.0."""
vfolder_nodes: [VirtualFolderNode]

Check notice on line 1053 in src/ai/backend/manager/api/schema.graphql (GitHub Actions / GraphQL Inspector): Field 'vfolder_nodes' was added to object type 'ComputeSessionNode'

num_queries: BigInt
inference_metrics: JSONString
kernel_nodes(filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): KernelConnection

"""Added in 24.9.0."""
kernel_nodes: [KernelNode]

Check notice on line 1058 in src/ai/backend/manager/api/schema.graphql (GitHub Actions / GraphQL Inspector): Field 'ComputeSessionNode.kernel_nodes' has description 'Added in 24.9.0.'

Check failure on line 1058 in src/ai/backend/manager/api/schema.graphql (GitHub Actions / GraphQL Inspector): Field 'ComputeSessionNode.kernel_nodes' changed type from 'KernelConnection' to '[KernelNode]'

Check failure on line 1058 in src/ai/backend/manager/api/schema.graphql (GitHub Actions / GraphQL Inspector): Argument 'filter: String' was removed from field 'ComputeSessionNode.kernel_nodes'. Removing a field argument is a breaking change because it will cause existing queries that use this argument to error.

The same check failure on line 1058 is reported for each of the other removed arguments: 'order: String', 'offset: Int', 'before: String', 'after: String', 'first: Int', and 'last: Int'.

"""Added in 24.09.0."""
dependents(filter: String, order: String, offset: Int, before: String, after: String, first: Int, last: Int): ComputeSessionConnection
@@ -1064,27 +1072,6 @@ Added in 24.09.0. One of ['read_attribute', 'update_attribute', 'delete_session'
"""
scalar SessionPermissionValueField

"""Added in 24.09.0."""
type KernelConnection {
"""Pagination data for this connection."""
pageInfo: PageInfo!

"""Contains the nodes in this connection."""
edges: [KernelEdge]!

"""Total count of the GQL nodes of the query."""
count: Int
}

"""Added in 24.09.0. A Relay edge containing a `Kernel` and its cursor."""
type KernelEdge {
"""The item at the end of the edge"""
node: KernelNode

"""A cursor for use in pagination"""
cursor: String!
}

"""Added in 24.09.0."""
type KernelNode implements Node {
"""The ID of the object"""
@@ -1098,6 +1085,14 @@ type KernelNode implements Node {
cluster_hostname: String
session_id: UUID
image: ImageNode

"""Added in 24.12.0."""
image_reference: String

Check notice on line 1090 in src/ai/backend/manager/api/schema.graphql (GitHub Actions / GraphQL Inspector): Field 'image_reference' was added to object type 'KernelNode'

"""
Added in 24.12.0. The architecture that the image of this kernel requires
"""
architecture: String

Check notice on line 1095 in src/ai/backend/manager/api/schema.graphql (GitHub Actions / GraphQL Inspector): Field 'architecture' was added to object type 'KernelNode'

status: String
status_changed: DateTime
status_info: String
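
The `kernel_nodes` change flagged by GraphQL Inspector in this schema diff replaces a Relay-style `KernelConnection` (with `edges` and `node`) with a plain `[KernelNode]` list. A client that has to work against managers on both sides of this commit could normalize the two response shapes. The following is only a minimal sketch under that assumption; the helper name `kernel_nodes_from_session` and the sample payloads are hypothetical and not part of the repository.

```python
from typing import Any


def kernel_nodes_from_session(session: dict[str, Any]) -> list[dict[str, Any]]:
    """Return kernel node dicts from a ComputeSessionNode payload of either shape."""
    raw = session.get("kernel_nodes")
    if raw is None:
        return []
    if isinstance(raw, dict):
        # Old shape: a KernelConnection object, nodes live under edges[].node.
        edges = raw.get("edges") or []
        return [edge["node"] for edge in edges if edge and edge.get("node")]
    # New shape: a plain list of KernelNode objects (entries may be null).
    return [node for node in raw if node is not None]


# Hypothetical sample payloads for the two shapes.
old_payload = {"kernel_nodes": {"edges": [{"node": {"id": "k1"}}, {"node": None}], "count": 2}}
new_payload = {"kernel_nodes": [{"id": "k1"}, None]}
```

Note that queries still passing `filter`, `order`, or pagination arguments to `kernel_nodes` would fail validation against the new schema regardless of how the response is parsed, so the query text itself also has to change.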
2 changes: 1 addition & 1 deletion src/ai/backend/manager/models/endpoint.py
@@ -793,7 +793,7 @@ async def from_row(
return cls(
endpoint_id=row.id,
# image="", # deprecated, row.image_object.name,
image_object=ImageNode.from_row(row.image_row),
image_object=ImageNode.from_row(ctx, row.image_row),
domain=row.domain,
project=row.project,
resource_group=row.resource_group,
45 changes: 37 additions & 8 deletions src/ai/backend/manager/models/gql_models/image.py
@@ -9,6 +9,7 @@
AsyncIterator,
List,
Optional,
Self,
overload,
)
from uuid import UUID
@@ -27,12 +28,13 @@
ImageAlias,
)
from ai.backend.logging import BraceStyleAdapter
from ai.backend.manager.models.base import batch_multiresult_in_scalar_stream
from ai.backend.manager.models.container_registry import ContainerRegistryRow, ContainerRegistryType

from ...api.exceptions import ImageNotFound, ObjectNotFound
from ...defs import DEFAULT_IMAGE_ARCH
from ..base import set_if_set
from ..gql_relay import AsyncNode
from ..gql_relay import AsyncNode, Connection
from ..image import (
ImageAliasRow,
ImageIdentifier,
@@ -330,16 +332,37 @@ class Meta:
graphene.String, description="Added in 24.03.4. The array of image aliases."
)

@classmethod
async def batch_load_by_name_and_arch(
cls,
graph_ctx: GraphQueryContext,
name_and_arch: Sequence[tuple[str, str]],
) -> Sequence[Sequence[ImageNode]]:
query = (
sa.select(ImageRow)
.where(sa.tuple_(ImageRow.name, ImageRow.architecture).in_(name_and_arch))
.options(selectinload(ImageRow.aliases))
)
async with graph_ctx.db.begin_readonly_session() as db_session:
return await batch_multiresult_in_scalar_stream(
graph_ctx,
db_session,
query,
cls,
name_and_arch,
lambda row: (row.name, row.architecture),
)
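
The body of `batch_multiresult_in_scalar_stream` is not shown in this diff. Judging from its call sites here (a key list plus a key-extracting lambda, returning `Sequence[Sequence[...]]`), it presumably groups the fetched rows into one list per requested key, in the order the keys were given. The sketch below illustrates that grouping contract only; `group_rows_by_key` and the sample rows are hypothetical stand-ins, not the repository's helper.

```python
from collections.abc import Callable, Hashable, Iterable, Sequence
from typing import TypeVar

K = TypeVar("K", bound=Hashable)
R = TypeVar("R")


def group_rows_by_key(
    keys: Sequence[K],
    rows: Iterable[R],
    key_of: Callable[[R], K],
) -> list[list[R]]:
    """Return one list of rows per key, aligned with the order of `keys`."""
    buckets: dict[K, list[R]] = {key: [] for key in keys}
    for row in rows:
        bucket = buckets.get(key_of(row))
        if bucket is not None:  # ignore rows that match no requested key
            bucket.append(row)
    return [buckets[key] for key in keys]


# Hypothetical rows keyed by (name, architecture), as in the loader above.
rows = [
    {"name": "python", "architecture": "x86_64"},
    {"name": "python", "architecture": "aarch64"},
]
keys = [("python", "aarch64"), ("cuda", "x86_64"), ("python", "x86_64")]
grouped = group_rows_by_key(keys, rows, lambda r: (r["name"], r["architecture"]))
```

A key with no matching row gets an empty list rather than being dropped, which is what lets a caller index the result by key position.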

@overload
@classmethod
def from_row(cls, row: ImageRow) -> ImageNode: ...
def from_row(cls, graph_ctx: GraphQueryContext, row: ImageRow) -> Self: ...

@overload
@classmethod
def from_row(cls, row: None) -> None: ...
def from_row(cls, graph_ctx: GraphQueryContext, row: None) -> None: ...

@classmethod
def from_row(cls, row: ImageRow | None) -> ImageNode | None:
def from_row(cls, graph_ctx: GraphQueryContext, row: ImageRow | None) -> Self | None:
if row is None:
return None
return cls(
@@ -401,7 +424,13 @@ async def get_node(cls, info: graphene.ResolveInfo, id: str) -> ImageNode:
image_row = await db_session.scalar(query)
if image_row is None:
raise ValueError(f"Image not found (id: {image_id})")
return cls.from_row(image_row)
return cls.from_row(graph_ctx, image_row)


class ImageConnection(Connection):
class Meta:
node = ImageNode
description = "Added in 24.12.0."


class ForgetImageById(graphene.Mutation):
@@ -453,7 +482,7 @@ async def mutate(
):
return ForgetImageById(ok=False, msg="Forbidden")
await session.delete(image_row)
return ForgetImageById(ok=True, msg="", image=ImageNode.from_row(image_row))
return ForgetImageById(ok=True, msg="", image=ImageNode.from_row(ctx, image_row))


class ForgetImage(graphene.Mutation):
@@ -500,7 +529,7 @@ async def mutate(
):
return ForgetImage(ok=False, msg="Forbidden")
await session.delete(image_row)
return ForgetImage(ok=True, msg="", image=ImageNode.from_row(image_row))
return ForgetImage(ok=True, msg="", image=ImageNode.from_row(ctx, image_row))


class UntagImageFromRegistry(graphene.Mutation):
@@ -566,7 +595,7 @@ async def mutate(
scanner = HarborRegistry_v2(ctx.db, image_row.image_ref.registry, registry_info)
await scanner.untag(image_row.image_ref)

return UntagImageFromRegistry(ok=True, msg="", image=ImageNode.from_row(image_row))
return UntagImageFromRegistry(ok=True, msg="", image=ImageNode.from_row(ctx, image_row))


class PreloadImage(graphene.Mutation):
27 changes: 22 additions & 5 deletions src/ai/backend/manager/models/gql_models/kernel.py
@@ -4,7 +4,9 @@
from typing import (
TYPE_CHECKING,
Any,
Optional,
Self,
cast,
)

import graphene
@@ -14,7 +16,7 @@

from ai.backend.common import msgpack, redis_helper
from ai.backend.common.types import KernelId, SessionId
from ai.backend.manager.models.base import batch_multiresult_in_session
from ai.backend.manager.models.base import batch_multiresult_in_scalar_stream

from ..gql_relay import AsyncNode, Connection
from ..kernel import KernelRow, KernelStatus
@@ -45,6 +47,10 @@ class Meta:

# image
image = graphene.Field(ImageNode)
image_reference = graphene.String(description="Added in 24.12.0.")
architecture = graphene.String(
description="Added in 24.12.0. The architecture that the image of this kernel requires"
)

# status
status = graphene.String()
@@ -72,11 +78,9 @@ async def batch_load_by_session_id(
graph_ctx: GraphQueryContext,
session_ids: Sequence[SessionId],
) -> Sequence[Sequence[Self]]:
from ..kernel import kernels

async with graph_ctx.db.begin_readonly_session() as db_sess:
query = sa.select(kernels).where(kernels.c.session_id.in_(session_ids))
return await batch_multiresult_in_session(
query = sa.select(KernelRow).where(KernelRow.session_id.in_(session_ids))
return await batch_multiresult_in_scalar_stream(
graph_ctx,
db_sess,
query,
@@ -102,6 +106,8 @@ def from_row(cls, ctx: GraphQueryContext, row: KernelRow) -> Self:
local_rank=row.local_rank,
cluster_role=row.cluster_role,
session_id=row.session_id,
architecture=row.architecture,
image_reference=row.image,
status=row.status,
status_changed=row.status_changed,
status_info=row.status_info,
@@ -118,6 +124,17 @@ def from_row(cls, ctx: GraphQueryContext, row: KernelRow) -> Self:
preopen_ports=row.preopen_ports,
)

async def resolve_image(self, info: graphene.ResolveInfo) -> Optional[ImageNode]:
graph_ctx: GraphQueryContext = info.context
loader = graph_ctx.dataloader_manager.get_loader_by_func(
graph_ctx, ImageNode.batch_load_by_name_and_arch
)
images = cast(list[ImageNode], await loader.load((self.image_reference, self.architecture)))
try:
return images[0]
except IndexError:
return None
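
`resolve_image` above relies on a dataloader so that many kernels resolving `image` in one GraphQL request share a single batched query, and then takes the first match or `None`. The repository's `dataloader_manager` is not shown in this diff, so the following is only a rough sketch of the batching idea: `TinyLoader`, `fake_batch_load_images`, and the sample keys are hypothetical, and the real loader may behave differently.

```python
import asyncio
from collections.abc import Awaitable, Callable, Hashable, Sequence
from typing import Any, Optional


class TinyLoader:
    """Hypothetical minimal loader: coalesces load() calls made in one loop tick."""

    def __init__(
        self, batch_fn: Callable[[Sequence[Hashable]], Awaitable[Sequence[Any]]]
    ) -> None:
        self._batch_fn = batch_fn
        self._pending: list[tuple[Hashable, asyncio.Future]] = []
        self._task: Optional[asyncio.Task] = None
        self.batch_calls = 0

    def load(self, key: Hashable) -> asyncio.Future:
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        if not self._pending:
            # First key queued: schedule one dispatch for everything that follows.
            loop.call_soon(self._start_dispatch)
        self._pending.append((key, future))
        return future

    def _start_dispatch(self) -> None:
        self._task = asyncio.get_running_loop().create_task(self._dispatch())

    async def _dispatch(self) -> None:
        pending, self._pending = self._pending, []
        keys = list(dict.fromkeys(key for key, _ in pending))  # de-duplicate, keep order
        self.batch_calls += 1
        try:
            results = await self._batch_fn(keys)
        except Exception as exc:
            for _, future in pending:
                future.set_exception(exc)
            return
        by_key = dict(zip(keys, results))
        for key, future in pending:
            future.set_result(by_key[key])


async def fake_batch_load_images(keys: Sequence[Hashable]) -> list[list[str]]:
    # Stand-in for a batch function that returns one list of matches per key.
    known = {("python:3.11", "x86_64"): ["image-node-1"]}
    return [known.get(key, []) for key in keys]


async def resolve_image(loader: TinyLoader, image_reference: str, architecture: str):
    images = await loader.load((image_reference, architecture))
    return images[0] if images else None  # first match, or None when nothing matched


async def demo():
    loader = TinyLoader(fake_batch_load_images)
    resolved = await asyncio.gather(
        resolve_image(loader, "python:3.11", "x86_64"),
        resolve_image(loader, "python:3.11", "x86_64"),
        resolve_image(loader, "missing", "aarch64"),
    )
    return resolved, loader.batch_calls


resolved, batch_calls = asyncio.run(demo())
```

In this sketch the three concurrent resolvers trigger one call to the batch function, which is the property that makes resolving `image` through a loader cheaper than issuing one query per kernel.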

async def resolve_live_stat(self, info: graphene.ResolveInfo) -> dict[str, Any] | None:
graph_ctx: GraphQueryContext = info.context
loader = graph_ctx.dataloader_manager.get_loader_by_func(
45 changes: 31 additions & 14 deletions src/ai/backend/manager/models/gql_models/session.py
@@ -19,7 +19,7 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from ai.backend.common.types import ClusterMode, SessionId, SessionResult
from ai.backend.common.types import ClusterMode, SessionId, SessionResult, VFolderMount
from ai.backend.manager.idle import ReportInfo

from ..base import (
@@ -54,7 +54,8 @@
)
from ..user import UserRole
from ..utils import execute_with_txn_retry
from .kernel import KernelConnection, KernelNode
from .kernel import KernelNode
from .vfolder import VirtualFolderNode

if TYPE_CHECKING:
from ..gql import GraphQueryContext
@@ -191,14 +192,23 @@ class Meta:
vfolder_mounts = graphene.List(lambda: graphene.String)
occupied_slots = graphene.JSONString()
requested_slots = graphene.JSONString()
image_references = graphene.List(
lambda: graphene.String,
description="Added in 24.12.0.",
)
vfolder_nodes = graphene.List(
lambda: VirtualFolderNode,
description="Added in 24.12.0.",
)

# statistics
num_queries = BigInt()
inference_metrics = graphene.JSONString()

# relations
kernel_nodes = PaginatedConnectionField(
KernelConnection,
kernel_nodes = graphene.List(
lambda: KernelNode,
description="Added in 24.9.0.",
)
dependents = PaginatedConnectionField(
"ai.backend.manager.models.gql_models.session.ComputeSessionConnection",
@@ -259,6 +269,7 @@ def from_row(
vfolder_mounts=row.vfolder_mounts,
occupied_slots=row.occupying_slots.to_json(),
requested_slots=row.requested_slots.to_json(),
image_references=row.images,
# statistics
num_queries=row.num_queries,
)
@@ -272,20 +283,27 @@ async def resolve_idle_checks(self, info: graphene.ResolveInfo) -> dict[str, Any
)
return await loader.load(self.row_id)

async def resolve_vfolder_nodes(
self,
info: graphene.ResolveInfo,
) -> list[VirtualFolderNode]:
ctx: GraphQueryContext = info.context
vfolder_mounts = cast(list[VFolderMount], self.vfolder_mounts)
_folder_ids = [vf_mount.vfid.folder_id for vf_mount in vfolder_mounts]
loader = ctx.dataloader_manager.get_loader_by_func(ctx, VirtualFolderNode.batch_load_by_id)
result = cast(list[list[VirtualFolderNode]], await loader.load_many(_folder_ids))
try:
return result[0]
except IndexError:
return []
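
In `resolve_vfolder_nodes`, `loader.load_many(_folder_ids)` presumably yields one list per folder ID, matching the `Sequence[Sequence[Self]]` return type of `VirtualFolderNode.batch_load_by_id`. If so, `result[0]` holds only the nodes for the first mounted folder. Should the intent be to return nodes for every mount, the per-folder lists could be flattened instead. This is a sketch under that assumption; `flatten_batches` and the sample data are hypothetical.

```python
from itertools import chain
from typing import TypeVar

T = TypeVar("T")


def flatten_batches(batches: list[list[T]]) -> list[T]:
    """Concatenate per-key result lists into one flat list, preserving order."""
    return list(chain.from_iterable(batches))


# Hypothetical load_many() result: one list per requested folder ID.
per_folder = [["vfolder-a"], [], ["vfolder-c"]]
all_nodes = flatten_batches(per_folder)
first_only = per_folder[0]  # what indexing with [0] would return
```

Flattening also handles the empty case without a `try`/`except IndexError`, since flattening an empty list simply returns `[]`.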

async def resolve_kernel_nodes(
self,
info: graphene.ResolveInfo,
) -> ConnectionResolverResult[KernelNode]:
) -> list[KernelNode]:
ctx: GraphQueryContext = info.context
loader = ctx.dataloader_manager.get_loader(ctx, "KernelNode.by_session_id")
kernels = await loader.load(self.row_id)
return ConnectionResolverResult(
kernels,
None,
None,
None,
total_count=len(kernels),
)
return await loader.load(self.row_id)

async def resolve_dependees(
self,
@@ -489,7 +507,6 @@ async def get_accessible_connection(
before=before,
last=last,
)
query = query.options(selectinload(SessionRow.kernels))
async with graph_ctx.db.connect() as db_conn:
user = graph_ctx.user
client_ctx = ClientContext(
22 changes: 21 additions & 1 deletion src/ai/backend/manager/models/gql_models/vfolder.py
@@ -1,7 +1,7 @@
from __future__ import annotations

import uuid
from collections.abc import Iterable, Mapping
from collections.abc import Iterable, Mapping, Sequence
from datetime import datetime
from typing import (
TYPE_CHECKING,
@@ -25,6 +25,7 @@
VFolderID,
VFolderUsageMode,
)
from ai.backend.manager.models.base import batch_multiresult_in_scalar_stream

from ...api.exceptions import (
VFolderOperationFailed,
@@ -216,6 +217,25 @@ def from_row(
result.permissions = [] if permissions is None else permissions
return result

@classmethod
async def batch_load_by_id(
cls,
graph_ctx: GraphQueryContext,
folder_ids: Sequence[uuid.UUID],
) -> Sequence[Sequence[Self]]:
query = (
sa.select(VFolderRow)
.where(VFolderRow.id.in_(folder_ids))
.options(
joinedload(VFolderRow.user_row),
joinedload(VFolderRow.group_row),
)
)
async with graph_ctx.db.begin_readonly_session() as db_session:
return await batch_multiresult_in_scalar_stream(
graph_ctx, db_session, query, cls, folder_ids, lambda row: row.id
)

@classmethod
async def get_node(cls, info: graphene.ResolveInfo, id: str) -> Self:
graph_ctx: GraphQueryContext = info.context
2 changes: 1 addition & 1 deletion src/ai/backend/manager/models/kernel.py
@@ -919,7 +919,7 @@ def parse_row(cls, ctx: GraphQueryContext, row: KernelRow) -> Mapping[str, Any]:
"session_id": row.session_id,
# image
"image": row.image,
"image_object": ImageNode.from_row(row.image_row),
"image_object": ImageNode.from_row(ctx, row.image_row),
"architecture": row.architecture,
"registry": row.registry,
# status