diff --git a/changes/11733.feature.md b/changes/11733.feature.md new file mode 100644 index 00000000000..07d0ef56154 --- /dev/null +++ b/changes/11733.feature.md @@ -0,0 +1 @@ +Record the client IP address on every `login_history` entry and expose it as `clientIp` on `LoginHistoryV2` (null for system-driven events such as eviction/expiration). diff --git a/docs/manager/graphql-reference/supergraph.graphql b/docs/manager/graphql-reference/supergraph.graphql index 68ed385c512..421831e78b3 100644 --- a/docs/manager/graphql-reference/supergraph.graphql +++ b/docs/manager/graphql-reference/supergraph.graphql @@ -8729,6 +8729,11 @@ input LoginHistoryFilter domainName: StringFilter = null result: LoginHistoryResultFilter = null createdAt: DateTimeFilter = null + + """ + Added in UNRELEASED. Filter by the client IP recorded on the login_history event. + """ + clientIp: StringFilter = null AND: [LoginHistoryFilter!] = null OR: [LoginHistoryFilter!] = null NOT: [LoginHistoryFilter!] = null @@ -8749,6 +8754,7 @@ enum LoginHistoryOrderField CREATED_AT @join__enumValue(graph: STRAWBERRY) RESULT @join__enumValue(graph: STRAWBERRY) DOMAIN_NAME @join__enumValue(graph: STRAWBERRY) + CLIENT_IP @join__enumValue(graph: STRAWBERRY) } """Added in 26.4.2. Filter for login attempt result field.""" @@ -8790,6 +8796,11 @@ type LoginHistoryV2 implements Node """Timestamp when the login attempt occurred.""" createdAt: DateTime! + """ + Added in UNRELEASED. Client IP that initiated the event. Null for system-driven events (e.g. eviction, expiration). + """ + clientIp: String + """Added in 26.4.3. The user who attempted to log in.""" user: UserV2 diff --git a/docs/manager/graphql-reference/v2-schema.graphql b/docs/manager/graphql-reference/v2-schema.graphql index 199e2f10a96..4ffe57de923 100644 --- a/docs/manager/graphql-reference/v2-schema.graphql +++ b/docs/manager/graphql-reference/v2-schema.graphql @@ -5675,6 +5675,11 @@ input LoginHistoryFilter { domainName: StringFilter = null result: LoginHistoryResultFilter = null createdAt: DateTimeFilter = null + + """ + Added in UNRELEASED. Filter by the client IP recorded on the login_history event. + """ + clientIp: StringFilter = null AND: [LoginHistoryFilter!] = null OR: [LoginHistoryFilter!] = null NOT: [LoginHistoryFilter!] = null @@ -5691,6 +5696,7 @@ enum LoginHistoryOrderField { CREATED_AT RESULT DOMAIN_NAME + CLIENT_IP } """Added in 26.4.2. Filter for login attempt result field.""" @@ -5727,6 +5733,11 @@ type LoginHistoryV2 implements Node { """Timestamp when the login attempt occurred.""" createdAt: DateTime! + """ + Added in UNRELEASED. Client IP that initiated the event. Null for system-driven events (e.g. eviction, expiration). + """ + clientIp: String + """Added in 26.4.3. The user who attempted to log in.""" user: UserV2 diff --git a/src/ai/backend/common/dto/manager/v2/login_history/request.py b/src/ai/backend/common/dto/manager/v2/login_history/request.py index 391b326c759..11fcb230934 100644 --- a/src/ai/backend/common/dto/manager/v2/login_history/request.py +++ b/src/ai/backend/common/dto/manager/v2/login_history/request.py @@ -43,6 +43,9 @@ class LoginHistoryFilter(BaseRequestModel): created_at: DateTimeFilter | None = Field( default=None, description="Filter history by created_at datetime" ) + client_ip: StringFilter | None = Field( + default=None, description="Client IP filter (source IP recorded on the event)" + ) AND: list[LoginHistoryFilter] | None = Field( default=None, description="All conditions must match" ) diff --git a/src/ai/backend/common/dto/manager/v2/login_history/response.py b/src/ai/backend/common/dto/manager/v2/login_history/response.py index 8fbf7a6e703..cdb0dbe8624 100644 --- a/src/ai/backend/common/dto/manager/v2/login_history/response.py +++ b/src/ai/backend/common/dto/manager/v2/login_history/response.py @@ -28,6 +28,10 @@ class LoginHistoryNode(BaseResponseModel): fail_reason: str | None = Field( default=None, description="Detailed reason for the login failure" ) + client_ip: str | None = Field( + default=None, + description="Client IP that initiated the event. Null for system-driven events (e.g. eviction, expiration).", + ) created_at: datetime = Field(description="Timestamp when the login attempt occurred") diff --git a/src/ai/backend/common/dto/manager/v2/login_history/types.py b/src/ai/backend/common/dto/manager/v2/login_history/types.py index 0e59e2c2a85..7b2e2671446 100644 --- a/src/ai/backend/common/dto/manager/v2/login_history/types.py +++ b/src/ai/backend/common/dto/manager/v2/login_history/types.py @@ -36,3 +36,4 @@ class LoginHistoryOrderField(StrEnum): CREATED_AT = "created_at" RESULT = "result" DOMAIN_NAME = "domain_name" + CLIENT_IP = "client_ip" diff --git a/src/ai/backend/manager/api/adapters/login_history/adapter.py b/src/ai/backend/manager/api/adapters/login_history/adapter.py index fdd90b6c62e..4799e2f92ff 100644 --- a/src/ai/backend/manager/api/adapters/login_history/adapter.py +++ b/src/ai/backend/manager/api/adapters/login_history/adapter.py @@ -132,6 +132,17 @@ def _convert_filter(self, f: LoginHistoryFilter) -> list[QueryCondition]: ) if condition is not None: conditions.append(condition) + if f.client_ip is not None: + condition = self.convert_string_filter( + f.client_ip, + contains_factory=LoginHistoryConditions.by_client_ip_contains, + equals_factory=LoginHistoryConditions.by_client_ip_equals, + starts_with_factory=LoginHistoryConditions.by_client_ip_starts_with, + ends_with_factory=LoginHistoryConditions.by_client_ip_ends_with, + in_factory=LoginHistoryConditions.by_client_ip_in, + ) + if condition is not None: + conditions.append(condition) if f.AND: for sub_filter in f.AND: conditions.extend(self._convert_filter(sub_filter)) @@ -172,6 +183,8 @@ def _convert_orders(orders: list[LoginHistoryOrder]) -> list[QueryOrder]: result.append(LoginHistoryOrders.result(ascending)) case LoginHistoryOrderField.DOMAIN_NAME: result.append(LoginHistoryOrders.domain_name(ascending)) + case LoginHistoryOrderField.CLIENT_IP: + result.append(LoginHistoryOrders.client_ip(ascending)) return result @staticmethod @@ -182,5 +195,6 @@ def _data_to_node(data: LoginHistoryData) -> LoginHistoryNode: domain_name=data.domain_name, result=LoginAttemptResult(data.result.value), fail_reason=data.fail_reason, + client_ip=data.client_ip, created_at=data.created_at, ) diff --git a/src/ai/backend/manager/api/gql/login_history/types/filter.py b/src/ai/backend/manager/api/gql/login_history/types/filter.py index 84304f216b7..5db70f63f8b 100644 --- a/src/ai/backend/manager/api/gql/login_history/types/filter.py +++ b/src/ai/backend/manager/api/gql/login_history/types/filter.py @@ -8,9 +8,11 @@ LoginHistoryFilter, LoginHistoryResultFilter, ) +from ai.backend.common.meta.meta import NEXT_RELEASE_VERSION from ai.backend.manager.api.gql.base import DateTimeFilter, StringFilter from ai.backend.manager.api.gql.decorators import ( BackendAIGQLMeta, + gql_added_field, gql_field, gql_pydantic_input, ) @@ -48,6 +50,13 @@ class LoginHistoryFilterGQL(PydanticInputMixin[LoginHistoryFilter]): domain_name: StringFilter | None = None result: LoginHistoryResultFilterGQL | None = None created_at: DateTimeFilter | None = None + client_ip: StringFilter | None = gql_added_field( + BackendAIGQLMeta( + added_version=NEXT_RELEASE_VERSION, + description="Filter by the client IP recorded on the login_history event.", + ), + default=None, + ) AND: list[Self] | None = None OR: list[Self] | None = None diff --git a/src/ai/backend/manager/api/gql/login_history/types/node.py b/src/ai/backend/manager/api/gql/login_history/types/node.py index 66243657442..bc6fc9c74b5 100644 --- a/src/ai/backend/manager/api/gql/login_history/types/node.py +++ b/src/ai/backend/manager/api/gql/login_history/types/node.py @@ -12,6 +12,7 @@ from strawberry.relay import Connection, Edge, NodeID from ai.backend.common.dto.manager.v2.login_history.response import LoginHistoryNode +from ai.backend.common.meta.meta import NEXT_RELEASE_VERSION from ai.backend.manager.api.gql.decorators import ( BackendAIGQLMeta, gql_added_field, @@ -63,6 +64,13 @@ class LoginHistoryV2GQL(PydanticNodeMixin[LoginHistoryNode]): fail_reason: str | None = gql_field(description="Detailed reason for the login failure.") created_at: datetime = gql_field(description="Timestamp when the login attempt occurred.") + client_ip: str | None = gql_added_field( + BackendAIGQLMeta( + added_version=NEXT_RELEASE_VERSION, + description="Client IP that initiated the event. Null for system-driven events (e.g. eviction, expiration).", + ) + ) + @gql_added_field( BackendAIGQLMeta( added_version="26.4.3", diff --git a/src/ai/backend/manager/api/gql/login_history/types/order.py b/src/ai/backend/manager/api/gql/login_history/types/order.py index e684dff2084..5e3fe3b560b 100644 --- a/src/ai/backend/manager/api/gql/login_history/types/order.py +++ b/src/ai/backend/manager/api/gql/login_history/types/order.py @@ -25,6 +25,7 @@ class LoginHistoryOrderFieldGQL(StrEnum): CREATED_AT = "created_at" RESULT = "result" DOMAIN_NAME = "domain_name" + CLIENT_IP = "client_ip" @gql_pydantic_input( diff --git a/src/ai/backend/manager/api/rest/auth/handler.py b/src/ai/backend/manager/api/rest/auth/handler.py index 49f5041eef3..bec1c977bed 100644 --- a/src/ai/backend/manager/api/rest/auth/handler.py +++ b/src/ai/backend/manager/api/rest/auth/handler.py @@ -151,6 +151,7 @@ async def authorize( stoken=params.stoken, otp=params.otp, client_type_id=params.client_type_id, + client_ip=extract_client_ip(ctx.request), force=params.force, ) result = await self._auth.authorize.wait_for_complete(action) @@ -179,7 +180,10 @@ async def authorize( async def logout(self, body: BodyParam[LogoutRequest], ctx: RequestCtx) -> APIResponse: params = body.parsed log.info("AUTH.LOGOUT(session_token:{}...)", params.session_token[:8]) - action = LogoutAction(session_token=params.session_token) + action = LogoutAction( + session_token=params.session_token, + client_ip=extract_client_ip(ctx.request), + ) await self._auth.logout.wait_for_complete(action) return APIResponse.build(HTTPStatus.OK, LogoutResponse()) diff --git a/src/ai/backend/manager/data/auth/login_session_types.py b/src/ai/backend/manager/data/auth/login_session_types.py index 265e30feec6..6ead39a9833 100644 --- a/src/ai/backend/manager/data/auth/login_session_types.py +++ b/src/ai/backend/manager/data/auth/login_session_types.py @@ -26,4 +26,5 @@ class LoginHistoryData: domain_name: str result: LoginAttemptResult fail_reason: str | None + client_ip: str | None created_at: datetime diff --git a/src/ai/backend/manager/models/alembic/versions/a1b3e7c2d4f5_add_client_ip_to_login_history.py b/src/ai/backend/manager/models/alembic/versions/a1b3e7c2d4f5_add_client_ip_to_login_history.py new file mode 100644 index 00000000000..bbbe0e3462b --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/a1b3e7c2d4f5_add_client_ip_to_login_history.py @@ -0,0 +1,28 @@ +"""add client_ip to login_history + +Revision ID: a1b3e7c2d4f5 +Revises: b8a85c96607c +Create Date: 2026-05-20 + +""" + +# Part of: 26.5.0 + +import sqlalchemy as sa +from alembic import op + +revision = "a1b3e7c2d4f5" +down_revision = "b8a85c96607c" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "login_history", + sa.Column("client_ip", sa.String(length=45), nullable=True), + ) + + +def downgrade() -> None: + op.drop_column("login_history", "client_ip") diff --git a/src/ai/backend/manager/models/login_session/row.py b/src/ai/backend/manager/models/login_session/row.py index 6f841e782da..e0acb0cd8a2 100644 --- a/src/ai/backend/manager/models/login_session/row.py +++ b/src/ai/backend/manager/models/login_session/row.py @@ -105,6 +105,9 @@ class LoginHistoryRow(Base): # type: ignore[misc] index=True, ) fail_reason: Mapped[str | None] = mapped_column("fail_reason", sa.Text, nullable=True) + # 45 = INET6_ADDRSTRLEN - 1: the longest possible textual representation of an + # IP address is an IPv4-mapped IPv6 form like "0000:0000:0000:0000:0000:ffff:255.255.255.255". + client_ip: Mapped[str | None] = mapped_column("client_ip", sa.String(45), nullable=True) created_at: Mapped[datetime] = mapped_column( "created_at", sa.DateTime(timezone=True), @@ -126,5 +129,6 @@ def to_data(self) -> LoginHistoryData: domain_name=self.domain_name, result=self.result, fail_reason=self.fail_reason, + client_ip=self.client_ip, created_at=self.created_at, ) diff --git a/src/ai/backend/manager/repositories/auth/db_source/db_source.py b/src/ai/backend/manager/repositories/auth/db_source/db_source.py index 36dbff059f4..d3570bc9ee0 100644 --- a/src/ai/backend/manager/repositories/auth/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/auth/db_source/db_source.py @@ -371,6 +371,7 @@ async def _record_login_history( domain_name: str, result: LoginAttemptResult, fail_reason: str | None, + client_ip: str | None, ) -> None: """Insert a login history record (internal, within an existing connection).""" await conn.execute( @@ -379,6 +380,7 @@ async def _record_login_history( domain_name=domain_name, result=result, fail_reason=fail_reason, + client_ip=client_ip, ) ) @@ -389,6 +391,7 @@ async def record_login_history( domain_name: str, result: LoginAttemptResult, fail_reason: str | None = None, + client_ip: str | None = None, ) -> None: """Insert a login history record (public, manages its own transaction).""" async with self._db.begin_session() as db_session: @@ -398,6 +401,7 @@ async def record_login_history( domain_name=domain_name, result=result, fail_reason=fail_reason, + client_ip=client_ip, ) ) @@ -465,6 +469,7 @@ async def delete_sessions_by_tokens( self, session_tokens: list[str], result: LoginAttemptResult, + client_ip: str | None = None, ) -> None: """Delete the given login sessions and record history for each. @@ -483,11 +488,12 @@ async def delete_sessions_by_tokens( .cte("deleted") ) insert_query = lh.insert().from_select( - ["user_id", "domain_name", "result"], + ["user_id", "domain_name", "result", "client_ip"], sa.select( deleted.c.user_id, users.c.domain_name, sa.literal(result.value).label("result"), + sa.literal(client_ip).label("client_ip"), ).select_from(deleted.join(users, deleted.c.user_id == users.c.uuid)), ) await conn.execute(insert_query) @@ -501,6 +507,7 @@ async def create_login_session( domain_name: str, *, login_client_type_id: UUID | None = None, + client_ip: str | None = None, ) -> LoginSessionCreationResult: """Create a new active login session and record a successful login history entry. @@ -523,7 +530,12 @@ async def create_login_session( # Record successful login in the same transaction. await self._record_login_history( - conn, user_id, domain_name, LoginAttemptResult.SUCCESS, fail_reason=None + conn, + user_id, + domain_name, + LoginAttemptResult.SUCCESS, + fail_reason=None, + client_ip=client_ip, ) await conn.commit() @@ -608,6 +620,7 @@ async def delete_session_by_token( self, session_token: str, result: LoginAttemptResult, + client_ip: str | None = None, ) -> None: """Delete a single login session by its token and record history. @@ -624,11 +637,12 @@ async def delete_session_by_token( .cte("deleted") ) insert_query = lh.insert().from_select( - ["user_id", "domain_name", "result"], + ["user_id", "domain_name", "result", "client_ip"], sa.select( deleted.c.user_id, users.c.domain_name, sa.literal(result.value).label("result"), + sa.literal(client_ip).label("client_ip"), ).select_from(deleted.join(users, deleted.c.user_id == users.c.uuid)), ) await conn.execute(insert_query) @@ -640,6 +654,7 @@ async def delete_sessions_by_user( user_id: UUID, domain_name: str, result: LoginAttemptResult, + client_ip: str | None = None, ) -> list[str]: """Delete all login sessions for a user, record history, return tokens. @@ -661,6 +676,7 @@ async def delete_sessions_by_user( "user_id": user_id, "domain_name": domain_name, "result": result, + "client_ip": client_ip, } for _ in deleted_tokens ], @@ -721,6 +737,7 @@ async def delete_session_by_id( self, session_id: UUID, result: LoginAttemptResult, + client_ip: str | None = None, ) -> str: """Delete a login session by its ID, record history, return session_token. @@ -754,6 +771,7 @@ async def delete_session_by_id( user_id=row.user_id, domain_name=domain_name, result=result, + client_ip=client_ip, ) ) await conn.commit() diff --git a/src/ai/backend/manager/repositories/auth/options.py b/src/ai/backend/manager/repositories/auth/options.py index ea104081e7e..31f1c4b9aca 100644 --- a/src/ai/backend/manager/repositories/auth/options.py +++ b/src/ai/backend/manager/repositories/auth/options.py @@ -315,6 +315,62 @@ def inner() -> sa.sql.expression.ColumnElement[bool]: by_domain_name_in = staticmethod(make_string_in_factory(LoginHistoryRow.domain_name)) + # --- client_ip string filters --- + + @staticmethod + def by_client_ip_contains(spec: StringMatchSpec) -> QueryCondition: + def inner() -> sa.sql.expression.ColumnElement[bool]: + if spec.case_insensitive: + condition = LoginHistoryRow.client_ip.ilike(f"%{spec.value}%") + else: + condition = LoginHistoryRow.client_ip.like(f"%{spec.value}%") + if spec.negated: + condition = sa.not_(condition) + return condition + + return inner + + @staticmethod + def by_client_ip_equals(spec: StringMatchSpec) -> QueryCondition: + def inner() -> sa.sql.expression.ColumnElement[bool]: + if spec.case_insensitive: + condition = sa.func.lower(LoginHistoryRow.client_ip) == spec.value.lower() + else: + condition = LoginHistoryRow.client_ip == spec.value + if spec.negated: + condition = sa.not_(condition) + return condition + + return inner + + @staticmethod + def by_client_ip_starts_with(spec: StringMatchSpec) -> QueryCondition: + def inner() -> sa.sql.expression.ColumnElement[bool]: + if spec.case_insensitive: + condition = LoginHistoryRow.client_ip.ilike(f"{spec.value}%") + else: + condition = LoginHistoryRow.client_ip.like(f"{spec.value}%") + if spec.negated: + condition = sa.not_(condition) + return condition + + return inner + + @staticmethod + def by_client_ip_ends_with(spec: StringMatchSpec) -> QueryCondition: + def inner() -> sa.sql.expression.ColumnElement[bool]: + if spec.case_insensitive: + condition = LoginHistoryRow.client_ip.ilike(f"%{spec.value}") + else: + condition = LoginHistoryRow.client_ip.like(f"%{spec.value}") + if spec.negated: + condition = sa.not_(condition) + return condition + + return inner + + by_client_ip_in = staticmethod(make_string_in_factory(LoginHistoryRow.client_ip)) + # --- created_at datetime filters --- @staticmethod @@ -385,3 +441,9 @@ def domain_name(ascending: bool = True) -> QueryOrder: if ascending: return LoginHistoryRow.domain_name.asc() return LoginHistoryRow.domain_name.desc() + + @staticmethod + def client_ip(ascending: bool = True) -> QueryOrder: + if ascending: + return LoginHistoryRow.client_ip.asc() + return LoginHistoryRow.client_ip.desc() diff --git a/src/ai/backend/manager/repositories/auth/repository.py b/src/ai/backend/manager/repositories/auth/repository.py index 20da8f1ff0b..bef6e0ba883 100644 --- a/src/ai/backend/manager/repositories/auth/repository.py +++ b/src/ai/backend/manager/repositories/auth/repository.py @@ -120,19 +120,24 @@ async def create_login_session( domain_name: str, *, login_client_type_id: UUID | None = None, + client_ip: str | None = None, ) -> LoginSessionCreationResult: return await self._db_source.create_login_session( user_id, access_key, domain_name, login_client_type_id=login_client_type_id, + client_ip=client_ip, ) @auth_repository_resilience.apply() async def delete_login_sessions_by_tokens( - self, session_tokens: list[str], result: LoginAttemptResult + self, + session_tokens: list[str], + result: LoginAttemptResult, + client_ip: str | None = None, ) -> None: - await self._db_source.delete_sessions_by_tokens(session_tokens, result) + await self._db_source.delete_sessions_by_tokens(session_tokens, result, client_ip) @auth_repository_resilience.apply() async def check_credential_without_migration( @@ -166,15 +171,24 @@ async def get_active_session_tokens( @auth_repository_resilience.apply() async def delete_login_session_by_token( - self, session_token: str, result: LoginAttemptResult + self, + session_token: str, + result: LoginAttemptResult, + client_ip: str | None = None, ) -> None: - await self._db_source.delete_session_by_token(session_token, result) + await self._db_source.delete_session_by_token(session_token, result, client_ip) @auth_repository_resilience.apply() async def delete_user_login_sessions( - self, user_id: UUID, domain_name: str, result: LoginAttemptResult + self, + user_id: UUID, + domain_name: str, + result: LoginAttemptResult, + client_ip: str | None = None, ) -> list[str]: - return await self._db_source.delete_sessions_by_user(user_id, domain_name, result) + return await self._db_source.delete_sessions_by_user( + user_id, domain_name, result, client_ip + ) @auth_repository_resilience.apply() async def admin_search_login_sessions( @@ -196,9 +210,14 @@ async def get_login_session_by_id(self, session_id: UUID) -> LoginSessionData: return await self._db_source.fetch_login_session_by_id(session_id) @auth_repository_resilience.apply() - async def delete_login_session_by_id(self, session_id: UUID, result: LoginAttemptResult) -> str: + async def delete_login_session_by_id( + self, + session_id: UUID, + result: LoginAttemptResult, + client_ip: str | None = None, + ) -> str: """Delete a login session, record history, and return its session_token.""" - return await self._db_source.delete_session_by_id(session_id, result) + return await self._db_source.delete_session_by_id(session_id, result, client_ip) @auth_repository_resilience.apply() async def record_login_history( @@ -207,8 +226,11 @@ async def record_login_history( domain_name: str, result: LoginAttemptResult, fail_reason: str | None = None, + client_ip: str | None = None, ) -> None: - await self._db_source.record_login_history(user_id, domain_name, result, fail_reason) + await self._db_source.record_login_history( + user_id, domain_name, result, fail_reason, client_ip + ) # --- Login History --- diff --git a/src/ai/backend/manager/services/auth/actions/authorize.py b/src/ai/backend/manager/services/auth/actions/authorize.py index 49913db1daa..4f437427566 100644 --- a/src/ai/backend/manager/services/auth/actions/authorize.py +++ b/src/ai/backend/manager/services/auth/actions/authorize.py @@ -21,6 +21,7 @@ class AuthorizeAction(AuthAction): stoken: str | None otp: str | None client_type_id: UUID | None + client_ip: str | None = None force: bool = False @override diff --git a/src/ai/backend/manager/services/auth/actions/logout.py b/src/ai/backend/manager/services/auth/actions/logout.py index 0c18e515caf..64975843935 100644 --- a/src/ai/backend/manager/services/auth/actions/logout.py +++ b/src/ai/backend/manager/services/auth/actions/logout.py @@ -9,6 +9,7 @@ @dataclass class LogoutAction(AuthAction): session_token: str + client_ip: str | None = None @override def entity_id(self) -> str | None: diff --git a/src/ai/backend/manager/services/auth/service.py b/src/ai/backend/manager/services/auth/service.py index 54cb241e525..285773468a5 100644 --- a/src/ai/backend/manager/services/auth/service.py +++ b/src/ai/backend/manager/services/auth/service.py @@ -15,6 +15,7 @@ LoginSessionInner, LoginSessionTokenData, ) +from ai.backend.common.contexts.client_ip import current_client_ip from ai.backend.common.dto.manager.auth.types import AuthTokenType from ai.backend.common.exception import InvalidAPIParameters, UserResourcePolicyNotFound from ai.backend.common.plugin.hook import ALL_COMPLETED, FIRST_COMPLETED, PASSED, HookPluginContext @@ -216,6 +217,7 @@ async def authorize(self, action: AuthorizeAction) -> AuthorizeActionResult: user.uuid, action.domain_name, _classify_failure(e), + action.client_ip, ) raise @@ -375,6 +377,7 @@ async def _create_login_session( access_key=keypair_row.access_key, domain_name=action.domain_name, login_client_type_id=login_client_type_id, + client_ip=action.client_ip, ) if tokens_to_invalidate: @@ -421,9 +424,12 @@ async def _record_login_failure( user_uuid: uuid.UUID, domain_name: str, result: LoginAttemptResult, + client_ip: str | None, ) -> None: try: - await self._auth_repository.record_login_history(user_uuid, domain_name, result) + await self._auth_repository.record_login_history( + user_uuid, domain_name, result, client_ip=client_ip + ) except Exception: log.warning("Failed to record login history: {} for user {}", result, user_uuid) @@ -541,7 +547,7 @@ async def signup(self, action: SignupAction) -> SignupActionResult: async def logout(self, action: LogoutAction) -> LogoutActionResult: await self._auth_repository.delete_login_session_by_token( - action.session_token, LoginAttemptResult.LOGOUT + action.session_token, LoginAttemptResult.LOGOUT, action.client_ip ) await self._valkey_session_client.delete_login_session(action.session_token) return LogoutActionResult(success=True) @@ -550,7 +556,7 @@ async def admin_revoke_login_session( self, action: AdminRevokeLoginSessionAction ) -> RevokeLoginSessionActionResult: session_token = await self._auth_repository.delete_login_session_by_id( - action.session_id, LoginAttemptResult.REVOKED_BY_ADMIN + action.session_id, LoginAttemptResult.REVOKED_BY_ADMIN, current_client_ip() ) await self._valkey_session_client.delete_login_session(session_token) return RevokeLoginSessionActionResult(success=True) @@ -562,7 +568,7 @@ async def my_revoke_login_session( if session_data.user_id != action.user_id: raise GenericForbidden("You can only revoke your own login sessions.") session_token = await self._auth_repository.delete_login_session_by_id( - action.session_id, LoginAttemptResult.REVOKED_BY_USER + action.session_id, LoginAttemptResult.REVOKED_BY_USER, current_client_ip() ) await self._valkey_session_client.delete_login_session(session_token) return RevokeLoginSessionActionResult(success=True) @@ -583,7 +589,7 @@ async def signout(self, action: SignoutAction) -> SignoutActionResult: action.password, ) deleted_tokens = await self._auth_repository.delete_user_login_sessions( - action.user_id, action.domain_name, LoginAttemptResult.LOGOUT + action.user_id, action.domain_name, LoginAttemptResult.LOGOUT, current_client_ip() ) for token in deleted_tokens: await self._valkey_session_client.delete_login_session(token) diff --git a/tests/unit/manager/repositories/auth/test_login_history_client_ip.py b/tests/unit/manager/repositories/auth/test_login_history_client_ip.py new file mode 100644 index 00000000000..276d553fcba --- /dev/null +++ b/tests/unit/manager/repositories/auth/test_login_history_client_ip.py @@ -0,0 +1,305 @@ +"""Tests that verify ``client_ip`` is persisted on ``login_history`` rows. + +The login_history flow is driven by ``AuthDBSource`` — every public mutation +method that writes a history row (``create_login_session``, +``record_login_history``, ``delete_session_by_token``, ``delete_session_by_id``, +``delete_sessions_by_user``, ``delete_sessions_by_tokens``) accepts a +``client_ip`` kwarg and we want to keep the contract that the value flows +into the column verbatim (or stays NULL for system-driven events). +""" + +from __future__ import annotations + +import uuid +from collections.abc import AsyncGenerator +from dataclasses import dataclass + +import pytest +import sqlalchemy as sa + +from ai.backend.common.types import ResourceSlot +from ai.backend.manager.models.domain import DomainRow +from ai.backend.manager.models.keypair import KeyPairRow +from ai.backend.manager.models.login_client_type.row import LoginClientTypeRow +from ai.backend.manager.models.login_session.enums import LoginAttemptResult +from ai.backend.manager.models.login_session.row import LoginHistoryRow, LoginSessionRow +from ai.backend.manager.models.resource_policy import ( + KeyPairResourcePolicyRow, + UserResourcePolicyRow, +) +from ai.backend.manager.models.user import UserRole, UserRow +from ai.backend.manager.models.utils import ExtendedAsyncSAEngine +from ai.backend.manager.repositories.auth.db_source.db_source import AuthDBSource +from ai.backend.testutils.db import with_tables + + +@dataclass +class SampleUser: + user_id: uuid.UUID + domain_name: str + access_key: str + + +class TestLoginHistoryClientIP: + """Every login_history-writing path on ``AuthDBSource`` records ``client_ip``.""" + + @pytest.fixture + async def db( + self, database_connection: ExtendedAsyncSAEngine + ) -> AsyncGenerator[ExtendedAsyncSAEngine, None]: + async with with_tables( + database_connection, + [ + DomainRow, + UserResourcePolicyRow, + KeyPairResourcePolicyRow, + UserRow, + KeyPairRow, + LoginClientTypeRow, + LoginSessionRow, + LoginHistoryRow, + ], + ): + yield database_connection + + @pytest.fixture + async def auth_db_source(self, db: ExtendedAsyncSAEngine) -> AuthDBSource: + return AuthDBSource(db) + + @pytest.fixture + def client_ip(self) -> str: + return "1.2.3.4" + + @pytest.fixture + async def sample(self, db: ExtendedAsyncSAEngine) -> AsyncGenerator[SampleUser, None]: + domain_name = f"test-domain-{uuid.uuid4()}" + user_uuid = uuid.uuid4() + email = f"test-{uuid.uuid4()}@example.com" + access_key = f"AKIA{uuid.uuid4().hex[:16]}" + + async with db.begin_session() as db_sess: + db_sess.add( + DomainRow( + name=domain_name, + description="test", + is_active=True, + total_resource_slots=ResourceSlot(), + allowed_vfolder_hosts={}, + allowed_docker_registries=[], + ) + ) + db_sess.add( + UserResourcePolicyRow( + name="test-user-policy", + max_vfolder_count=10, + max_quota_scope_size=-1, + max_session_count_per_model_session=10, + max_customized_image_count=10, + ) + ) + db_sess.add( + KeyPairResourcePolicyRow( + name="test-keypair-policy", + max_concurrent_sessions=10, + max_concurrent_sftp_sessions=2, + max_containers_per_session=10, + idle_timeout=3600, + ) + ) + await db_sess.flush() + + user_row = UserRow( + uuid=user_uuid, + username=email, + email=email, + password=None, + domain_name=domain_name, + role=UserRole.USER, + resource_policy="test-user-policy", + need_password_change=False, + ) + db_sess.add(user_row) + await db_sess.flush() + + keypair = KeyPairRow( + access_key=access_key, + secret_key="test_secret_key", + user_id=email, + user=user_uuid, + is_active=True, + resource_policy="test-keypair-policy", + ) + db_sess.add(keypair) + await db_sess.flush() + user_row.main_access_key = access_key + await db_sess.commit() + + yield SampleUser(user_id=user_uuid, domain_name=domain_name, access_key=access_key) + + @staticmethod + async def _fetch_client_ips( + db: ExtendedAsyncSAEngine, + user_id: uuid.UUID, + result: LoginAttemptResult, + ) -> list[str | None]: + async with db.begin_readonly() as conn: + rows = await conn.execute( + sa.select(LoginHistoryRow.__table__.c.client_ip) + .where(LoginHistoryRow.__table__.c.user_id == user_id) + .where(LoginHistoryRow.__table__.c.result == result) + ) + return [row.client_ip for row in rows] + + async def test_create_login_session_records_client_ip( + self, + auth_db_source: AuthDBSource, + db: ExtendedAsyncSAEngine, + sample: SampleUser, + client_ip: str, + ) -> None: + await auth_db_source.create_login_session( + user_id=sample.user_id, + access_key=sample.access_key, + domain_name=sample.domain_name, + client_ip=client_ip, + ) + ips = await self._fetch_client_ips(db, sample.user_id, LoginAttemptResult.SUCCESS) + assert ips == [client_ip] + + async def test_record_login_history_records_client_ip( + self, + auth_db_source: AuthDBSource, + db: ExtendedAsyncSAEngine, + sample: SampleUser, + client_ip: str, + ) -> None: + await auth_db_source.record_login_history( + user_id=sample.user_id, + domain_name=sample.domain_name, + result=LoginAttemptResult.FAILED_INVALID_CREDENTIALS, + client_ip=client_ip, + ) + ips = await self._fetch_client_ips( + db, sample.user_id, LoginAttemptResult.FAILED_INVALID_CREDENTIALS + ) + assert ips == [client_ip] + + async def test_delete_session_by_token_records_client_ip( + self, + auth_db_source: AuthDBSource, + db: ExtendedAsyncSAEngine, + sample: SampleUser, + client_ip: str, + ) -> None: + session = await auth_db_source.create_login_session( + user_id=sample.user_id, + access_key=sample.access_key, + domain_name=sample.domain_name, + client_ip=client_ip, + ) + await auth_db_source.delete_session_by_token( + session.session_token, + LoginAttemptResult.LOGOUT, + client_ip=client_ip, + ) + ips = await self._fetch_client_ips(db, sample.user_id, LoginAttemptResult.LOGOUT) + assert ips == [client_ip] + + async def test_delete_session_by_id_records_client_ip( + self, + auth_db_source: AuthDBSource, + db: ExtendedAsyncSAEngine, + sample: SampleUser, + client_ip: str, + ) -> None: + session = await auth_db_source.create_login_session( + user_id=sample.user_id, + access_key=sample.access_key, + domain_name=sample.domain_name, + client_ip=client_ip, + ) + async with db.begin_readonly() as conn: + session_id = await conn.scalar( + sa.select(LoginSessionRow.__table__.c.id).where( + LoginSessionRow.__table__.c.session_token == session.session_token + ) + ) + assert session_id is not None + await auth_db_source.delete_session_by_id( + session_id, + LoginAttemptResult.REVOKED_BY_ADMIN, + client_ip=client_ip, + ) + ips = await self._fetch_client_ips(db, sample.user_id, LoginAttemptResult.REVOKED_BY_ADMIN) + assert ips == [client_ip] + + async def test_delete_sessions_by_user_records_client_ip( + self, + auth_db_source: AuthDBSource, + db: ExtendedAsyncSAEngine, + sample: SampleUser, + client_ip: str, + ) -> None: + for _ in range(2): + await auth_db_source.create_login_session( + user_id=sample.user_id, + access_key=sample.access_key, + domain_name=sample.domain_name, + client_ip=client_ip, + ) + await auth_db_source.delete_sessions_by_user( + user_id=sample.user_id, + domain_name=sample.domain_name, + result=LoginAttemptResult.LOGOUT, + client_ip=client_ip, + ) + ips = await self._fetch_client_ips(db, sample.user_id, LoginAttemptResult.LOGOUT) + assert ips == [client_ip, client_ip] + + async def test_eviction_writes_separate_history_row_without_dropping_login_ip( + self, + auth_db_source: AuthDBSource, + db: ExtendedAsyncSAEngine, + sample: SampleUser, + client_ip: str, + ) -> None: + """Eviction is a system-driven event that appends a NEW ``login_history`` row. + + - The original ``SUCCESS`` row's ``client_ip`` is preserved (login_history is + append-only; deleting a ``login_session`` does not touch existing history rows). + - The new ``EVICTED`` row carries ``client_ip = NULL`` because no client + initiated the eviction — the service calls ``delete_sessions_by_tokens`` + without a ``client_ip`` argument for system-driven results + (``EVICTED`` / ``EXPIRED``). + """ + session = await auth_db_source.create_login_session( + user_id=sample.user_id, + access_key=sample.access_key, + domain_name=sample.domain_name, + client_ip=client_ip, + ) + await auth_db_source.delete_sessions_by_tokens( + [session.session_token], + LoginAttemptResult.EVICTED, + ) + + # Original SUCCESS row keeps its client_ip. + success_ips = await self._fetch_client_ips(db, sample.user_id, LoginAttemptResult.SUCCESS) + assert success_ips == [client_ip] + # New EVICTED row is system-driven, so client_ip is NULL. + evicted_ips = await self._fetch_client_ips(db, sample.user_id, LoginAttemptResult.EVICTED) + assert evicted_ips == [None] + + async def test_client_ip_defaults_to_null_when_not_provided( + self, + auth_db_source: AuthDBSource, + db: ExtendedAsyncSAEngine, + sample: SampleUser, + ) -> None: + await auth_db_source.create_login_session( + user_id=sample.user_id, + access_key=sample.access_key, + domain_name=sample.domain_name, + ) + ips = await self._fetch_client_ips(db, sample.user_id, LoginAttemptResult.SUCCESS) + assert ips == [None]