Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pagination Hook #3494

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions backend/ee/onyx/server/reporting/usage_export_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ee.onyx.server.reporting.usage_export_models import UserSkeleton
from onyx.auth.schemas import UserStatus
from onyx.configs.constants import FileOrigin
from onyx.db.users import list_users
from onyx.db.users import get_all_users
from onyx.file_store.constants import MAX_IN_MEMORY_SIZE
from onyx.file_store.file_store import FileStore
from onyx.file_store.file_store import get_default_file_store
Expand Down Expand Up @@ -86,7 +86,7 @@ def generate_user_report(
csvwriter = csv.writer(temp_file, delimiter=",")
csvwriter.writerow(["user_id", "status"])

users = list_users(db_session)
users = get_all_users(db_session)
for user in users:
user_skeleton = UserSkeleton(
user_id=str(user.id),
Expand Down
78 changes: 76 additions & 2 deletions backend/onyx/db/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
from fastapi import HTTPException
from fastapi_users.password import PasswordHelper
from sqlalchemy import func
from sqlalchemy import not_
from sqlalchemy import select
from sqlalchemy.orm import Session

from onyx.auth.schemas import UserRole
from onyx.auth.schemas import UserStatus
from onyx.db.api_key import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.db.models import User


Expand Down Expand Up @@ -83,7 +86,7 @@ def validate_user_role_update(requested_role: UserRole, current_role: UserRole)
)


def list_users(
def get_all_users(
db_session: Session, email_filter_string: str = "", include_external: bool = False
) -> Sequence[User]:
"""List all users. No pagination as of now, as the # of users
Expand All @@ -103,13 +106,84 @@ def list_users(
return db_session.scalars(stmt).unique().all()


def _get_accepted_user_where_clause(
email_filter_string: str = "",
status_filter: UserStatus | None = None,
roles_filter: list[UserRole] = [],
include_external: bool = False,
) -> list:
# Init where clause and remove any users with API email domains
where_clause = [not_(User.email.endswith(DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN))]

# Exclude external permissioned users
if not include_external:
where_clause.append(User.role != UserRole.EXT_PERM_USER)

if email_filter_string:
where_clause.append(User.email.ilike(f"%{email_filter_string}%"))

if roles_filter:
where_clause.append(User.role.in_(roles_filter))

# When status_filter = "live" the inner condition evaluates True, if status_filter is "deactivated" we get False.
# so if status_filter = "live" we select only active users, and if status_filter = "deactivated" we select only inactive users
if status_filter:
where_clause.append(User.is_active == (status_filter == UserStatus.LIVE))

return where_clause


def get_page_of_filtered_users(
db_session: Session,
page_size: int,
page: int,
email_filter_string: str = "",
status_filter: UserStatus | None = None,
roles_filter: list[UserRole] = [],
include_external: bool = False,
) -> Sequence[User]:
users_stmt = select(User)

where_clause = _get_accepted_user_where_clause(
email_filter_string=email_filter_string,
status_filter=status_filter,
roles_filter=roles_filter,
include_external=include_external,
)
# Apply pagination
users_stmt = users_stmt.offset((page - 1) * page_size).limit(page_size)
# Apply filtering
users_stmt = users_stmt.where(*where_clause)

return db_session.scalars(users_stmt).unique().all()


def get_total_filtered_users_count(
db_session: Session,
email_filter_string: str = "",
status_filter: UserStatus | None = None,
roles_filter: list[UserRole] = [],
include_external: bool = False,
) -> int:
where_clause = _get_accepted_user_where_clause(
email_filter_string=email_filter_string,
status_filter=status_filter,
roles_filter=roles_filter,
include_external=include_external,
)
total_count_stmt = select(func.count()).select_from(User)
# Apply filtering
total_count_stmt = total_count_stmt.where(*where_clause)

return db_session.scalar(total_count_stmt)


def get_user_by_email(email: str, db_session: Session) -> User | None:
user = (
db_session.query(User)
.filter(func.lower(User.email) == func.lower(email))
.first()
)

return user


Expand Down
16 changes: 9 additions & 7 deletions backend/onyx/server/documents/cc_pair.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import math
from datetime import datetime
from http import HTTPStatus

Expand Down Expand Up @@ -48,7 +47,8 @@
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
from onyx.server.documents.models import ConnectorCredentialPairMetadata
from onyx.server.documents.models import DocumentSyncStatus
from onyx.server.documents.models import PaginatedIndexAttempts
from onyx.server.documents.models import IndexAttemptSnapshot
from onyx.server.documents.models import PaginatedReturn
from onyx.server.models import StatusResponse
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
Expand All @@ -64,7 +64,7 @@ def get_cc_pair_index_attempts(
page_size: int = Query(10, ge=1, le=1000),
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> PaginatedIndexAttempts:
) -> PaginatedReturn[IndexAttemptSnapshot]:
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id, db_session, user, get_editable=False
)
Expand All @@ -82,10 +82,12 @@ def get_cc_pair_index_attempts(
page=page,
page_size=page_size,
)
return PaginatedIndexAttempts.from_models(
index_attempt_models=index_attempts,
page=page,
total_pages=math.ceil(total_count / page_size),
return PaginatedReturn(
items=[
IndexAttemptSnapshot.from_index_attempt_db_model(index_attempt)
for index_attempt in index_attempts
],
total_items=total_count,
)


Expand Down
35 changes: 16 additions & 19 deletions backend/onyx/server/documents/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from datetime import datetime
from typing import Any
from typing import Generic
from typing import TypeVar
from uuid import UUID

from pydantic import BaseModel
Expand All @@ -19,6 +21,8 @@
from onyx.db.models import IndexAttemptError as DbIndexAttemptError
from onyx.db.models import IndexingStatus
from onyx.db.models import TaskStatus
from onyx.server.models import FullUserSnapshot
from onyx.server.models import InvitedUserSnapshot
from onyx.server.utils import mask_credential_dict


Expand Down Expand Up @@ -201,26 +205,19 @@ def from_db_model(cls, error: DbIndexAttemptError) -> "IndexAttemptError":
)


class PaginatedIndexAttempts(BaseModel):
index_attempts: list[IndexAttemptSnapshot]
page: int
total_pages: int
# These are the types currently supported by the pagination hook
# More api endpoints can be refactored and be added here for use with the pagination hook
PaginatedType = TypeVar(
"PaginatedType",
IndexAttemptSnapshot,
FullUserSnapshot,
InvitedUserSnapshot,
)

@classmethod
def from_models(
cls,
index_attempt_models: list[IndexAttempt],
page: int,
total_pages: int,
) -> "PaginatedIndexAttempts":
return cls(
index_attempts=[
IndexAttemptSnapshot.from_index_attempt_db_model(index_attempt_model)
for index_attempt_model in index_attempt_models
],
page=page,
total_pages=total_pages,
)

class PaginatedReturn(BaseModel, Generic[PaginatedType]):
items: list[PaginatedType]
total_items: int


class CCPairFullInfo(BaseModel):
Expand Down
77 changes: 72 additions & 5 deletions backend/onyx/server/manage/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from fastapi import Body
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi import Request
from psycopg2.errors import UniqueViolation
from pydantic import BaseModel
Expand Down Expand Up @@ -46,10 +47,13 @@
from onyx.db.models import SamlAccount
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.users import get_all_users
from onyx.db.users import get_page_of_filtered_users
from onyx.db.users import get_total_filtered_users_count
from onyx.db.users import get_user_by_email
from onyx.db.users import list_users
from onyx.db.users import validate_user_role_update
from onyx.key_value_store.factory import get_kv_store
from onyx.server.documents.models import PaginatedReturn
from onyx.server.manage.models import AllUsersResponse
from onyx.server.manage.models import AutoScrollRequest
from onyx.server.manage.models import UserByEmail
Expand All @@ -67,10 +71,8 @@
from shared_configs.configs import MULTI_TENANT

logger = setup_logger()

router = APIRouter()


USERS_PAGE_SIZE = 10


Expand Down Expand Up @@ -115,6 +117,71 @@ def set_user_role(
db_session.commit()


@router.get("/manage/users/accepted")
def list_accepted_users(
q: str | None = Query(default=None),
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1, le=1000),
roles: list[UserRole] = Query(default=[]),
status: UserStatus | None = Query(default=None),
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> PaginatedReturn[FullUserSnapshot]:
if not q:
q = ""

filtered_accepted_users = get_page_of_filtered_users(
db_session=db_session,
page_size=page_size,
page=page,
email_filter_string=q,
status_filter=status,
roles_filter=roles,
)

total_accepted_users_count = get_total_filtered_users_count(
db_session=db_session,
email_filter_string=q,
status_filter=status,
roles_filter=roles,
)

if not filtered_accepted_users:
logger.info("No users found")
return PaginatedReturn(
items=[],
total_items=0,
)

return PaginatedReturn(
items=[
FullUserSnapshot.from_user_model(user) for user in filtered_accepted_users
],
total_items=total_accepted_users_count,
)


@router.get("/manage/users/invited")
def list_invited_users(
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1, le=1000),
user: User | None = Depends(current_curator_or_admin_user),
) -> PaginatedReturn[InvitedUserSnapshot]:
invited_emails = get_invited_users()

total_count = len(invited_emails)
start_idx = (page - 1) * page_size
end_idx = start_idx + page_size

return PaginatedReturn(
items=[
InvitedUserSnapshot(email=email)
for email in invited_emails[start_idx:end_idx]
],
total_items=total_count,
)


@router.get("/manage/users")
def list_all_users(
q: str | None = None,
Expand All @@ -129,7 +196,7 @@ def list_all_users(

users = [
user
for user in list_users(db_session, email_filter_string=q)
for user in get_all_users(db_session, email_filter_string=q)
if not is_api_key_email_address(user.email)
]

Expand Down Expand Up @@ -449,7 +516,7 @@ def list_all_users_basic_info(
_: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[MinimalUserSnapshot]:
users = list_users(db_session)
users = get_all_users(db_session)
return [MinimalUserSnapshot(id=user.id, email=user.email) for user in users]


Expand Down
10 changes: 10 additions & 0 deletions backend/onyx/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from onyx.auth.schemas import UserRole
from onyx.auth.schemas import UserStatus
from onyx.db.models import User


DataT = TypeVar("DataT")
Expand Down Expand Up @@ -37,6 +38,15 @@ class FullUserSnapshot(BaseModel):
role: UserRole
status: UserStatus

@classmethod
def from_user_model(cls, user: User) -> "FullUserSnapshot":
return cls(
id=user.id,
email=user.email,
role=user.role,
status=UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED,
)


class InvitedUserSnapshot(BaseModel):
email: str
Expand Down
Loading
Loading