diff --git a/src/ai/backend/manager/repositories/vfolder/repository.py b/src/ai/backend/manager/repositories/vfolder/repository.py index 4e39defe82c..832a7fd4054 100644 --- a/src/ai/backend/manager/repositories/vfolder/repository.py +++ b/src/ai/backend/manager/repositories/vfolder/repository.py @@ -139,6 +139,11 @@ execute_rbac_revoker, ) from ai.backend.manager.repositories.base.updater import Updater, execute_updater +from ai.backend.manager.repositories.permission_controller.role_manager import ( + RoleManager, + UserSystemRoleSpec, +) + from ai.backend.manager.repositories.vfolder.creators import VFolderCreatorSpec from ai.backend.manager.repositories.vfolder.types import ( BulkVFolderPurgeResult, @@ -172,9 +177,11 @@ class _VFolderWithLinkedModelCards: class VfolderRepository: _db: ExtendedAsyncSAEngine + _role_manager: RoleManager def __init__(self, db: ExtendedAsyncSAEngine) -> None: self._db = db + self._role_manager = RoleManager() @vfolder_repository_resilience.apply() async def get_by_id_validated( @@ -1135,7 +1142,8 @@ async def _get_user_role_id(self, session: SASession, user_id: uuid.UUID) -> uui Get the system role_id associated with a user. Looks up the UserRoleRow joined with RoleRow where the role source is SYSTEM. - Raises ObjectNotFound if the user's system role is not found. + If no SYSTEM role exists for the user, creates one as a fallback within the + same transaction and returns its role_id. """ stmt = ( sa.select(UserRoleRow.role_id) @@ -1148,9 +1156,13 @@ async def _get_user_role_id(self, session: SASession, user_id: uuid.UUID) -> uui ) ) result = await session.scalar(stmt) - if result is None: - raise ObjectNotFound(object_name="user system role", extra_msg=str(user_id)) - return result + if result is not None: + return result + + # Fallback: create a system role for this user and return its id + spec = UserSystemRoleSpec(user_id=user_id) + role = await self._role_manager.create_system_role(session, spec) + return role.id def _get_vfolder_scope(self, vfolder: VFolderData) -> ScopeId: """Determine scope from vfolder ownership.""" diff --git a/tests/unit/manager/repositories/vfolder/test_vfolder_invitations.py b/tests/unit/manager/repositories/vfolder/test_vfolder_invitations.py index 7649049afd6..f51ba01a303 100644 --- a/tests/unit/manager/repositories/vfolder/test_vfolder_invitations.py +++ b/tests/unit/manager/repositories/vfolder/test_vfolder_invitations.py @@ -523,3 +523,132 @@ async def test_get_vfolder_invitations_username_null( assert len(results) == 1 assert results[0].inviter_username is None + + +class TestGetUserRoleIdFallback: + """Fallback system role creation when no SYSTEM role exists for a user (BA-6253).""" + + @pytest.fixture + async def db_with_cleanup( + self, database_connection: ExtendedAsyncSAEngine + ) -> AsyncGenerator[ExtendedAsyncSAEngine, None]: + async with with_tables( + database_connection, + [ + DomainRow, + UserResourcePolicyRow, + KeyPairResourcePolicyRow, + UserRow, + KeyPairRow, + VFolderRow, + VFolderInvitationRow, + ], + ): + yield database_connection + + @pytest.fixture + async def sample_domain( + self, + domain_factory: DomainFactory, + db_with_cleanup: ExtendedAsyncSAEngine, + ) -> DomainFixtureData: + return await domain_factory(db_with_cleanup) + + async def _create_user( + self, + db: ExtendedAsyncSAEngine, + *, + domain_name: str, + resource_policy: str, + email: str, + username: str | None, + ) -> UserRow: + user_uuid = uuid.uuid4() + async with db.begin_session() as session: + user = UserRow( + uuid=user_uuid, + domain_name=domain_name, + email=email, + username=username, + password=None, + full_name="Test User", + is_active=True, + status=UserStatus.ACTIVE, + role=UserRole.USER, + resource_policy=resource_policy, + created_at=datetime.now(UTC), + ) + session.add(user) + async with db.begin_session() as session: + return (await session.get(UserRow, user_uuid)) # type: ignore[return-value] + + @pytest.fixture + async def user_resource_policy(self, db_with_cleanup: ExtendedAsyncSAEngine) -> str: + policy_name = f"user-policy-{uuid.uuid4().hex[:8]}" + async with db_with_cleanup.begin_session() as session: + policy = UserResourcePolicyRow( + name=policy_name, + max_vfolder_count=0, + max_quota_scope_size=-1, + max_session_count_per_model_session=10, + max_customized_image_count=10, + ) + session.add(policy) + return policy_name + + async def test_existing_role_returned_when_present( + self, + db_with_cleanup: ExtendedAsyncSAEngine, + sample_domain: DomainFixtureData, + user_resource_policy: str, + ) -> None: + """When user already has a SYSTEM role, _get_user_role_id returns it without creating a new one.""" + repo = VfolderRepository(db_with_cleanup) + + user = await self._create_user( + db_with_cleanup, + domain_name=sample_domain.name, + resource_policy=user_resource_policy, + email="existing-role@test.com", + username="existing_role_user", + ) + + async with db_with_cleanup.begin_session() as session: + role_id = await repo._get_user_role_id(session, user.uuid) + assert role_id is not None + + # Should be the same role on subsequent calls + role_id_2 = await repo._get_user_role_id(session, user.uuid) + assert role_id == role_id_2 + + async def test_fallback_creates_system_role_when_missing( + self, + db_with_cleanup: ExtendedAsyncSAEngine, + sample_domain: DomainFixtureData, + user_resource_policy: str, + ) -> None: + """When user has no SYSTEM role, _get_user_role_id creates one and returns its id.""" + from ai.backend.manager.models.rbac_models.role import RoleRow + + repo = VfolderRepository(db_with_cleanup) + + user = await self._create_user( + db_with_cleanup, + domain_name=sample_domain.name, + resource_policy=user_resource_policy, + email="no-role@test.com", + username="no_role_user", + ) + + async with db_with_cleanup.begin_session() as session: + role_id = await repo._get_user_role_id(session, user.uuid) + assert role_id is not None + + # Verify the role was actually created and is a SYSTEM role + role_row = await session.get(RoleRow, role_id) + assert role_row is not None + assert role_row.source.value == "SYSTEM" + + # Second call should return the same role id (no duplicate created) + role_id_2 = await repo._get_user_role_id(session, user.uuid) + assert role_id == role_id_2