-
-
Notifications
You must be signed in to change notification settings - Fork 98
CSRF functionality for templates #1087
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| """Shared CSRF utilities for Air middleware and template helpers.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import secrets | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| if TYPE_CHECKING: | ||
| from starlette.requests import Request | ||
| from starlette.types import Scope | ||
|
|
||
| DEFAULT_CSRF_COOKIE_NAME = "air_csrf_token" | ||
| DEFAULT_CSRF_HEADER_NAME = "X-CSRF-Token" | ||
| DEFAULT_CSRF_FORM_FIELD_NAME = "csrf_token" | ||
|
|
||
| CSRF_STATE_TOKEN_KEY = "csrf_token" | ||
| CSRF_STATE_COOKIE_NAME_KEY = "csrf_cookie_name" | ||
| CSRF_STATE_FORM_FIELD_NAME_KEY = "csrf_form_field_name" | ||
|
|
||
|
|
||
| def new_csrf_token() -> str: | ||
| """Generate a secure CSRF token.""" | ||
| return secrets.token_urlsafe(32) | ||
|
|
||
|
|
||
| def set_csrf_state(scope: Scope, *, token: str, cookie_name: str, form_field_name: str) -> None: | ||
| """Attach CSRF data to request state for downstream consumers.""" | ||
| state = scope.setdefault("state", {}) | ||
| state[CSRF_STATE_TOKEN_KEY] = token | ||
| state[CSRF_STATE_COOKIE_NAME_KEY] = cookie_name | ||
| state[CSRF_STATE_FORM_FIELD_NAME_KEY] = form_field_name | ||
|
|
||
|
|
||
| def get_csrf_cookie_name_from_request(request: Request) -> str: | ||
| """Get the active CSRF cookie name for a request.""" | ||
| cookie_name = getattr(request.state, CSRF_STATE_COOKIE_NAME_KEY, None) | ||
| if isinstance(cookie_name, str): | ||
| return cookie_name | ||
| return DEFAULT_CSRF_COOKIE_NAME | ||
|
|
||
|
|
||
| def get_csrf_form_field_name_from_request(request: Request) -> str: | ||
| """Get the active CSRF form field name for a request.""" | ||
| form_field_name = getattr(request.state, CSRF_STATE_FORM_FIELD_NAME_KEY, None) | ||
| if isinstance(form_field_name, str): | ||
| return form_field_name | ||
| return DEFAULT_CSRF_FORM_FIELD_NAME | ||
|
|
||
|
|
||
| def get_csrf_token_from_request(request: Request) -> str: | ||
| """Resolve CSRF token from request state or cookie. | ||
|
|
||
| Raises: | ||
| RuntimeError: If token is unavailable, typically because CSRFMiddleware is not installed. | ||
| """ | ||
| token = getattr(request.state, CSRF_STATE_TOKEN_KEY, None) | ||
| if isinstance(token, str): | ||
| return token | ||
|
|
||
| cookie_name = get_csrf_cookie_name_from_request(request) | ||
| cookie_token = request.cookies.get(cookie_name) | ||
| if isinstance(cookie_token, str): | ||
| return cookie_token | ||
|
|
||
| msg = "CSRF token not found on request. Add air.CSRFMiddleware before using csrf helpers." | ||
| raise RuntimeError(msg) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,7 +11,223 @@ | |
| Background tasks run _after_ middleware. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import hmac | ||
| from http.cookies import SimpleCookie | ||
| from typing import TYPE_CHECKING, Literal | ||
|
|
||
| from starlette.datastructures import Headers, MutableHeaders | ||
| from starlette.middleware.sessions import SessionMiddleware as StarletteSessionMiddleware | ||
| from starlette.requests import Request | ||
| from starlette.responses import PlainTextResponse, Response | ||
|
|
||
| from .csrf import ( | ||
| DEFAULT_CSRF_COOKIE_NAME, | ||
| DEFAULT_CSRF_FORM_FIELD_NAME, | ||
| DEFAULT_CSRF_HEADER_NAME, | ||
| new_csrf_token, | ||
| set_csrf_state, | ||
| ) | ||
|
|
||
| _SAFE_HTTP_METHODS = {"GET", "HEAD", "OPTIONS", "TRACE"} | ||
|
|
||
| if TYPE_CHECKING: | ||
| from collections.abc import Iterable | ||
|
|
||
| from starlette.types import ASGIApp, Message, Receive, Scope, Send | ||
|
|
||
|
|
||
| def _cookie_value(cookie_header: str | None, name: str) -> str | None: | ||
| if cookie_header is None: | ||
| return None | ||
| cookie = SimpleCookie() | ||
| cookie.load(cookie_header) | ||
| morsel = cookie.get(name) | ||
| if morsel is None: | ||
| return None | ||
| return morsel.value | ||
|
|
||
|
|
||
| async def _read_body(receive: Receive) -> bytes: | ||
| chunks: list[bytes] = [] | ||
| more_body = True | ||
|
|
||
| while more_body: | ||
| message = await receive() | ||
| if message["type"] == "http.disconnect": | ||
| break | ||
| if message["type"] != "http.request": | ||
| continue | ||
| chunks.append(message.get("body", b"")) | ||
| more_body = message.get("more_body", False) | ||
|
|
||
| return b"".join(chunks) | ||
|
|
||
|
|
||
| def _receive_from_body(body: bytes) -> Receive: | ||
| has_emitted = False | ||
|
|
||
| async def receive() -> Message: | ||
| nonlocal has_emitted | ||
| if has_emitted: | ||
| return {"type": "http.request", "body": b"", "more_body": False} | ||
| has_emitted = True | ||
| return {"type": "http.request", "body": body, "more_body": False} | ||
|
|
||
| return receive | ||
|
|
||
|
|
||
| class CSRFMiddleware: | ||
| """Validate CSRF tokens using the double-submit cookie pattern. | ||
|
|
||
| This middleware is intended for form workflows where templates render a hidden token input. | ||
| For Jinja templates, use the built-in globals: | ||
|
|
||
| - `csrf_token()` to read the request token | ||
| - `csrf_input()` to render `<input type="hidden" ...>` | ||
|
|
||
| TODO: Make this a default middleware always installed by Air | ||
|
|
||
| Example: | ||
|
|
||
| import air | ||
|
|
||
| app = air.Air() | ||
| app.add_middleware(air.CSRFMiddleware) | ||
|
|
||
| jinja = air.JinjaRenderer(directory="templates") | ||
|
|
||
|
|
||
| @app.get("/contact") | ||
| def contact(request: air.Request): | ||
| return jinja(request, "contact.html") | ||
|
|
||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| app: ASGIApp, | ||
| *, | ||
| cookie_name: str = DEFAULT_CSRF_COOKIE_NAME, | ||
| header_name: str = DEFAULT_CSRF_HEADER_NAME, | ||
| form_field_name: str = DEFAULT_CSRF_FORM_FIELD_NAME, | ||
| safe_methods: Iterable[str] = _SAFE_HTTP_METHODS, | ||
| exempt_paths: Iterable[str] | None = None, | ||
| cookie_path: str = "/", | ||
| cookie_domain: str | None = None, | ||
| cookie_secure: bool = False, | ||
| cookie_httponly: bool = True, | ||
| cookie_samesite: Literal["lax", "strict", "none"] = "lax", | ||
| cookie_max_age: int | None = None, | ||
| error_message: str = "CSRF verification failed.", | ||
| ) -> None: | ||
| self.app = app | ||
| self.cookie_name = cookie_name | ||
| self.header_name = header_name | ||
| self.form_field_name = form_field_name | ||
| self.safe_methods = {method.upper() for method in safe_methods} | ||
| self.exempt_paths = set(exempt_paths or ()) | ||
| self.cookie_path = cookie_path | ||
| self.cookie_domain = cookie_domain | ||
| self.cookie_secure = cookie_secure | ||
| self.cookie_httponly = cookie_httponly | ||
| self.cookie_samesite = cookie_samesite | ||
| self.cookie_max_age = cookie_max_age | ||
| self.error_message = error_message | ||
|
|
||
| async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: | ||
| if scope["type"] != "http": | ||
| await self.app(scope, receive, send) | ||
| return | ||
|
|
||
| method = scope["method"].upper() | ||
| path = scope["path"] | ||
| headers = Headers(scope=scope) | ||
| cookie_token = _cookie_value(headers.get("cookie"), self.cookie_name) | ||
|
|
||
| if method in self.safe_methods or path in self.exempt_paths: | ||
| token = cookie_token or new_csrf_token() | ||
| set_csrf_state(scope, token=token, cookie_name=self.cookie_name, form_field_name=self.form_field_name) | ||
| await self._call_next_with_optional_cookie( | ||
| scope, | ||
| receive, | ||
| send, | ||
| set_cookie=(cookie_token is None), | ||
| token=token, | ||
| ) | ||
| return | ||
|
|
||
| token = cookie_token | ||
| if token is None: | ||
| await PlainTextResponse(self.error_message, status_code=403)(scope, receive, send) | ||
| return | ||
|
|
||
| submitted_token = headers.get(self.header_name) | ||
| downstream_receive = receive | ||
|
|
||
| if submitted_token is None: | ||
| buffered_body = await _read_body(receive) | ||
| submitted_token = await self._extract_form_token(scope, headers, buffered_body) | ||
| downstream_receive = _receive_from_body(buffered_body) | ||
|
Comment on lines
+169
to
+172
|
||
|
|
||
| if submitted_token is None or not hmac.compare_digest(submitted_token, token): | ||
| await PlainTextResponse(self.error_message, status_code=403)(scope, receive, send) | ||
| return | ||
|
|
||
| set_csrf_state(scope, token=token, cookie_name=self.cookie_name, form_field_name=self.form_field_name) | ||
| await self.app(scope, downstream_receive, send) | ||
|
|
||
| async def _call_next_with_optional_cookie( | ||
| self, | ||
| scope: Scope, | ||
| receive: Receive, | ||
| send: Send, | ||
| *, | ||
| set_cookie: bool, | ||
| token: str, | ||
| ) -> None: | ||
| if not set_cookie: | ||
| await self.app(scope, receive, send) | ||
| return | ||
|
|
||
| cookie_builder = Response() | ||
| cookie_builder.set_cookie( | ||
| key=self.cookie_name, | ||
| value=token, | ||
| max_age=self.cookie_max_age, | ||
| path=self.cookie_path, | ||
| domain=self.cookie_domain, | ||
| secure=self.cookie_secure, | ||
| httponly=self.cookie_httponly, | ||
| samesite=self.cookie_samesite, | ||
| ) | ||
| set_cookie_header = cookie_builder.headers["set-cookie"] | ||
|
|
||
| async def send_with_cookie(message: Message) -> None: | ||
| if message["type"] == "http.response.start": | ||
| mutable_headers = MutableHeaders(scope=message) | ||
| mutable_headers.append("set-cookie", set_cookie_header) | ||
| await send(message) | ||
|
|
||
| await self.app(scope, receive, send_with_cookie) | ||
|
|
||
| async def _extract_form_token(self, scope: Scope, headers: Headers, body: bytes) -> str | None: | ||
| content_type = headers.get("content-type", "").lower() | ||
| if not content_type.startswith(("application/x-www-form-urlencoded", "multipart/form-data")): | ||
| return None | ||
|
|
||
| request = Request(scope, receive=_receive_from_body(body)) | ||
|
|
||
| try: | ||
| form_data = await request.form() | ||
| except (RuntimeError, TypeError, ValueError): | ||
| return None | ||
|
|
||
| token = form_data.get(self.form_field_name) | ||
| if isinstance(token, str): | ||
| return token | ||
| return None | ||
|
|
||
|
|
||
| class SessionMiddleware(StarletteSessionMiddleware): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| <form method="post" action="/submit"> | ||
| {{ csrf_input() }} | ||
| <input type="text" name="name" /> | ||
| <button type="submit">Submit</button> | ||
| </form> | ||
| <p id="token">{{ csrf_token() }}</p> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SameSite=Nonecookies are rejected by modern browsers unless they also have theSecureattribute. Since this middleware exposescookie_samesiteandcookie_secureindependently, settingcookie_samesite="none"withcookie_secure=Falsewill silently break CSRF (cookie never stored/sent). Consider validating these options in__init__and raising aValueError(or auto-enablingcookie_secure) whencookie_samesite == "none"andcookie_secureis false.