Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions src/ai/backend/manager/repositories/vfolder/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@
execute_rbac_revoker,
)
from ai.backend.manager.repositories.base.updater import Updater, execute_updater
from ai.backend.manager.repositories.permission_controller.role_manager import (
RoleManager,
UserSystemRoleSpec,
)

from ai.backend.manager.repositories.vfolder.creators import VFolderCreatorSpec
from ai.backend.manager.repositories.vfolder.types import (
BulkVFolderPurgeResult,
Expand Down Expand Up @@ -172,9 +177,11 @@ class _VFolderWithLinkedModelCards:

class VfolderRepository:
_db: ExtendedAsyncSAEngine
_role_manager: RoleManager

def __init__(self, db: ExtendedAsyncSAEngine) -> None:
self._db = db
self._role_manager = RoleManager()

@vfolder_repository_resilience.apply()
async def get_by_id_validated(
Expand Down Expand Up @@ -1135,7 +1142,8 @@ async def _get_user_role_id(self, session: SASession, user_id: uuid.UUID) -> uui
Get the system role_id associated with a user.

Looks up the UserRoleRow joined with RoleRow where the role source is SYSTEM.
Raises ObjectNotFound if the user's system role is not found.
If no SYSTEM role exists for the user, creates one as a fallback within the
same transaction and returns its role_id.
"""
stmt = (
sa.select(UserRoleRow.role_id)
Expand All @@ -1148,9 +1156,13 @@ async def _get_user_role_id(self, session: SASession, user_id: uuid.UUID) -> uui
)
)
result = await session.scalar(stmt)
if result is None:
raise ObjectNotFound(object_name="user system role", extra_msg=str(user_id))
return result
if result is not None:
return result

# Fallback: create a system role for this user and return its id
spec = UserSystemRoleSpec(user_id=user_id)
role = await self._role_manager.create_system_role(session, spec)
return role.id
Comment on lines 1150 to +1165
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This approach doesn't seem appropriate. It would make more sense for the operation to fail if it can't retrieve the data from get; creating and passing the user's role when it can't be retrieved is very tricky.


def _get_vfolder_scope(self, vfolder: VFolderData) -> ScopeId:
"""Determine scope from vfolder ownership."""
Expand Down
129 changes: 129 additions & 0 deletions tests/unit/manager/repositories/vfolder/test_vfolder_invitations.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,3 +523,132 @@ async def test_get_vfolder_invitations_username_null(

assert len(results) == 1
assert results[0].inviter_username is None


class TestGetUserRoleIdFallback:
"""Fallback system role creation when no SYSTEM role exists for a user (BA-6253)."""

@pytest.fixture
async def db_with_cleanup(
self, database_connection: ExtendedAsyncSAEngine
) -> AsyncGenerator[ExtendedAsyncSAEngine, None]:
async with with_tables(
database_connection,
[
DomainRow,
UserResourcePolicyRow,
KeyPairResourcePolicyRow,
UserRow,
KeyPairRow,
VFolderRow,
VFolderInvitationRow,
],
):
yield database_connection

@pytest.fixture
async def sample_domain(
self,
domain_factory: DomainFactory,
db_with_cleanup: ExtendedAsyncSAEngine,
) -> DomainFixtureData:
return await domain_factory(db_with_cleanup)

async def _create_user(
self,
db: ExtendedAsyncSAEngine,
*,
domain_name: str,
resource_policy: str,
email: str,
username: str | None,
) -> UserRow:
user_uuid = uuid.uuid4()
async with db.begin_session() as session:
user = UserRow(
uuid=user_uuid,
domain_name=domain_name,
email=email,
username=username,
password=None,
full_name="Test User",
is_active=True,
status=UserStatus.ACTIVE,
role=UserRole.USER,
resource_policy=resource_policy,
created_at=datetime.now(UTC),
)
session.add(user)
async with db.begin_session() as session:
return (await session.get(UserRow, user_uuid)) # type: ignore[return-value]

@pytest.fixture
async def user_resource_policy(self, db_with_cleanup: ExtendedAsyncSAEngine) -> str:
policy_name = f"user-policy-{uuid.uuid4().hex[:8]}"
async with db_with_cleanup.begin_session() as session:
policy = UserResourcePolicyRow(
name=policy_name,
max_vfolder_count=0,
max_quota_scope_size=-1,
max_session_count_per_model_session=10,
max_customized_image_count=10,
)
session.add(policy)
return policy_name

async def test_existing_role_returned_when_present(
self,
db_with_cleanup: ExtendedAsyncSAEngine,
sample_domain: DomainFixtureData,
user_resource_policy: str,
) -> None:
"""When user already has a SYSTEM role, _get_user_role_id returns it without creating a new one."""
repo = VfolderRepository(db_with_cleanup)

user = await self._create_user(
db_with_cleanup,
domain_name=sample_domain.name,
resource_policy=user_resource_policy,
email="existing-role@test.com",
username="existing_role_user",
)

async with db_with_cleanup.begin_session() as session:
role_id = await repo._get_user_role_id(session, user.uuid)
assert role_id is not None

# Should be the same role on subsequent calls
role_id_2 = await repo._get_user_role_id(session, user.uuid)
assert role_id == role_id_2

async def test_fallback_creates_system_role_when_missing(
self,
db_with_cleanup: ExtendedAsyncSAEngine,
sample_domain: DomainFixtureData,
user_resource_policy: str,
) -> None:
"""When user has no SYSTEM role, _get_user_role_id creates one and returns its id."""
from ai.backend.manager.models.rbac_models.role import RoleRow

repo = VfolderRepository(db_with_cleanup)

user = await self._create_user(
db_with_cleanup,
domain_name=sample_domain.name,
resource_policy=user_resource_policy,
email="no-role@test.com",
username="no_role_user",
)

async with db_with_cleanup.begin_session() as session:
role_id = await repo._get_user_role_id(session, user.uuid)
assert role_id is not None

# Verify the role was actually created and is a SYSTEM role
role_row = await session.get(RoleRow, role_id)
assert role_row is not None
assert role_row.source.value == "SYSTEM"

# Second call should return the same role id (no duplicate created)
role_id_2 = await repo._get_user_role_id(session, user.uuid)
assert role_id == role_id_2
Loading