diff --git a/changes/2723.feature.md b/changes/2723.feature.md new file mode 100644 index 0000000000..a10da59f5e --- /dev/null +++ b/changes/2723.feature.md @@ -0,0 +1 @@ +Allow filter and order in endpointlist gql request. diff --git a/src/ai/backend/manager/models/endpoint.py b/src/ai/backend/manager/models/endpoint.py index 2534bd0ad5..9791109fa6 100644 --- a/src/ai/backend/manager/models/endpoint.py +++ b/src/ai/backend/manager/models/endpoint.py @@ -63,6 +63,8 @@ gql_mutation_wrapper, ) from .image import ImageNode, ImageRefType, ImageRow +from .minilang.ordering import OrderSpecItem, QueryOrderParser +from .minilang.queryfilter import FieldSpecItem, QueryFilterParser from .resource_policy import keypair_resource_policies from .routing import RouteStatus, Routing from .scaling_group import scaling_groups @@ -832,10 +834,27 @@ async def load_count( result = await conn.execute(query) return result.scalar() + _queryfilter_fieldspec: Mapping[str, FieldSpecItem] = { + "name": ("endpoints_name", None), + "model": ("endpoints_model", None), + "domain": ("endpoints_domain", None), + "url": ("endpoints_url", None), + "created_user_email": ("users_email", None), + } + + _queryorder_colmap: Mapping[str, OrderSpecItem] = { + "name": ("endpoints_name", None), + "created_at": ("endpoints_created_at", None), + "model": ("endpoints_model", None), + "domain": ("endpoints_domain", None), + "url": ("endpoints_url", None), + "created_user_email": ("users_email", None), + } + @classmethod async def load_slice( cls, - ctx, # ctx: GraphQueryContext, + ctx, #: GraphQueryContext, # ctx: GraphQueryContext, limit: int, offset: int, *, @@ -847,19 +866,19 @@ async def load_slice( ) -> Sequence["Endpoint"]: query = ( sa.select(EndpointRow) + .select_from( + sa.join( + EndpointRow, + UserRow, + EndpointRow.created_user == UserRow.uuid, + isouter=True, + ) + ) .limit(limit) .offset(offset) .options(selectinload(EndpointRow.image_row).selectinload(ImageRow.aliases)) .options(selectinload(EndpointRow.routings)) - .options(selectinload(EndpointRow.created_user_row)) .options(selectinload(EndpointRow.session_owner_row)) - .order_by(sa.desc(EndpointRow.created_at)) - .filter( - EndpointRow.lifecycle_stage.in_([ - EndpointLifecycle.CREATED, - EndpointLifecycle.DESTROYING, - ]) - ) ) if project is not None: query = query.where(EndpointRow.project == project) @@ -867,16 +886,18 @@ async def load_slice( query = query.where(EndpointRow.domain == domain_name) if user_uuid is not None: query = query.where(EndpointRow.session_owner == user_uuid) - """ + if filter is not None: - parser = QueryFilterParser(cls._queryfilter_fieldspec) - query = parser.append_filter(query, filter) + filter_parser = QueryFilterParser(cls._queryfilter_fieldspec) + query = filter_parser.append_filter(query, filter) if order is not None: - parser = QueryOrderParser(cls._queryorder_colmap) - query = parser.append_ordering(query, order) - """ - async with ctx.db.begin_readonly_session() as session: - result = await session.execute(query) + order_parser = QueryOrderParser(cls._queryorder_colmap) + query = order_parser.append_ordering(query, order) + else: + query = query.order_by(sa.desc(EndpointRow.created_at)) + + async with ctx.db.begin_readonly_session() as db_session: + result = await db_session.execute(query) return [await cls.from_row(ctx, row) for row in result.scalars().all()] @classmethod