diff --git a/src/air/__init__.py b/src/air/__init__.py index 85b087a1b..2f12c7967 100644 --- a/src/air/__init__.py +++ b/src/air/__init__.py @@ -23,7 +23,10 @@ AirForm as AirForm, to_form as to_form, ) -from .middleware import SessionMiddleware as SessionMiddleware +from .middleware import ( + CSRFMiddleware as CSRFMiddleware, + SessionMiddleware as SessionMiddleware, +) from .models import ( AirModel as AirModel, ) diff --git a/src/air/csrf.py b/src/air/csrf.py new file mode 100644 index 000000000..91f04caf1 --- /dev/null +++ b/src/air/csrf.py @@ -0,0 +1,66 @@ +"""Shared CSRF utilities for Air middleware and template helpers.""" + +from __future__ import annotations + +import secrets +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from starlette.requests import Request + from starlette.types import Scope + +DEFAULT_CSRF_COOKIE_NAME = "air_csrf_token" +DEFAULT_CSRF_HEADER_NAME = "X-CSRF-Token" +DEFAULT_CSRF_FORM_FIELD_NAME = "csrf_token" + +CSRF_STATE_TOKEN_KEY = "csrf_token" +CSRF_STATE_COOKIE_NAME_KEY = "csrf_cookie_name" +CSRF_STATE_FORM_FIELD_NAME_KEY = "csrf_form_field_name" + + +def new_csrf_token() -> str: + """Generate a secure CSRF token.""" + return secrets.token_urlsafe(32) + + +def set_csrf_state(scope: Scope, *, token: str, cookie_name: str, form_field_name: str) -> None: + """Attach CSRF data to request state for downstream consumers.""" + state = scope.setdefault("state", {}) + state[CSRF_STATE_TOKEN_KEY] = token + state[CSRF_STATE_COOKIE_NAME_KEY] = cookie_name + state[CSRF_STATE_FORM_FIELD_NAME_KEY] = form_field_name + + +def get_csrf_cookie_name_from_request(request: Request) -> str: + """Get the active CSRF cookie name for a request.""" + cookie_name = getattr(request.state, CSRF_STATE_COOKIE_NAME_KEY, None) + if isinstance(cookie_name, str): + return cookie_name + return DEFAULT_CSRF_COOKIE_NAME + + +def get_csrf_form_field_name_from_request(request: Request) -> str: + """Get the active CSRF form field name for a request.""" + form_field_name = getattr(request.state, CSRF_STATE_FORM_FIELD_NAME_KEY, None) + if isinstance(form_field_name, str): + return form_field_name + return DEFAULT_CSRF_FORM_FIELD_NAME + + +def get_csrf_token_from_request(request: Request) -> str: + """Resolve CSRF token from request state or cookie. + + Raises: + RuntimeError: If token is unavailable, typically because CSRFMiddleware is not installed. + """ + token = getattr(request.state, CSRF_STATE_TOKEN_KEY, None) + if isinstance(token, str): + return token + + cookie_name = get_csrf_cookie_name_from_request(request) + cookie_token = request.cookies.get(cookie_name) + if isinstance(cookie_token, str): + return cookie_token + + msg = "CSRF token not found on request. Add air.CSRFMiddleware before using csrf helpers." + raise RuntimeError(msg) diff --git a/src/air/middleware.py b/src/air/middleware.py index aa173cb6e..0e5f7cf7f 100644 --- a/src/air/middleware.py +++ b/src/air/middleware.py @@ -11,7 +11,223 @@ Background tasks run _after_ middleware. """ +from __future__ import annotations + +import hmac +from http.cookies import SimpleCookie +from typing import TYPE_CHECKING, Literal + +from starlette.datastructures import Headers, MutableHeaders from starlette.middleware.sessions import SessionMiddleware as StarletteSessionMiddleware +from starlette.requests import Request +from starlette.responses import PlainTextResponse, Response + +from .csrf import ( + DEFAULT_CSRF_COOKIE_NAME, + DEFAULT_CSRF_FORM_FIELD_NAME, + DEFAULT_CSRF_HEADER_NAME, + new_csrf_token, + set_csrf_state, +) + +_SAFE_HTTP_METHODS = {"GET", "HEAD", "OPTIONS", "TRACE"} + +if TYPE_CHECKING: + from collections.abc import Iterable + + from starlette.types import ASGIApp, Message, Receive, Scope, Send + + +def _cookie_value(cookie_header: str | None, name: str) -> str | None: + if cookie_header is None: + return None + cookie = SimpleCookie() + cookie.load(cookie_header) + morsel = cookie.get(name) + if morsel is None: + return None + return morsel.value + + +async def _read_body(receive: Receive) -> bytes: + chunks: list[bytes] = [] + more_body = True + + while more_body: + message = await receive() + if message["type"] == "http.disconnect": + break + if message["type"] != "http.request": + continue + chunks.append(message.get("body", b"")) + more_body = message.get("more_body", False) + + return b"".join(chunks) + + +def _receive_from_body(body: bytes) -> Receive: + has_emitted = False + + async def receive() -> Message: + nonlocal has_emitted + if has_emitted: + return {"type": "http.request", "body": b"", "more_body": False} + has_emitted = True + return {"type": "http.request", "body": body, "more_body": False} + + return receive + + +class CSRFMiddleware: + """Validate CSRF tokens using the double-submit cookie pattern. + + This middleware is intended for form workflows where templates render a hidden token input. + For Jinja templates, use the built-in globals: + + - `csrf_token()` to read the request token + - `csrf_input()` to render `` + + TODO: Make this a default middleware always installed by Air + + Example: + + import air + + app = air.Air() + app.add_middleware(air.CSRFMiddleware) + + jinja = air.JinjaRenderer(directory="templates") + + + @app.get("/contact") + def contact(request: air.Request): + return jinja(request, "contact.html") + + """ + + def __init__( + self, + app: ASGIApp, + *, + cookie_name: str = DEFAULT_CSRF_COOKIE_NAME, + header_name: str = DEFAULT_CSRF_HEADER_NAME, + form_field_name: str = DEFAULT_CSRF_FORM_FIELD_NAME, + safe_methods: Iterable[str] = _SAFE_HTTP_METHODS, + exempt_paths: Iterable[str] | None = None, + cookie_path: str = "/", + cookie_domain: str | None = None, + cookie_secure: bool = False, + cookie_httponly: bool = True, + cookie_samesite: Literal["lax", "strict", "none"] = "lax", + cookie_max_age: int | None = None, + error_message: str = "CSRF verification failed.", + ) -> None: + self.app = app + self.cookie_name = cookie_name + self.header_name = header_name + self.form_field_name = form_field_name + self.safe_methods = {method.upper() for method in safe_methods} + self.exempt_paths = set(exempt_paths or ()) + self.cookie_path = cookie_path + self.cookie_domain = cookie_domain + self.cookie_secure = cookie_secure + self.cookie_httponly = cookie_httponly + self.cookie_samesite = cookie_samesite + self.cookie_max_age = cookie_max_age + self.error_message = error_message + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + method = scope["method"].upper() + path = scope["path"] + headers = Headers(scope=scope) + cookie_token = _cookie_value(headers.get("cookie"), self.cookie_name) + + if method in self.safe_methods or path in self.exempt_paths: + token = cookie_token or new_csrf_token() + set_csrf_state(scope, token=token, cookie_name=self.cookie_name, form_field_name=self.form_field_name) + await self._call_next_with_optional_cookie( + scope, + receive, + send, + set_cookie=(cookie_token is None), + token=token, + ) + return + + token = cookie_token + if token is None: + await PlainTextResponse(self.error_message, status_code=403)(scope, receive, send) + return + + submitted_token = headers.get(self.header_name) + downstream_receive = receive + + if submitted_token is None: + buffered_body = await _read_body(receive) + submitted_token = await self._extract_form_token(scope, headers, buffered_body) + downstream_receive = _receive_from_body(buffered_body) + + if submitted_token is None or not hmac.compare_digest(submitted_token, token): + await PlainTextResponse(self.error_message, status_code=403)(scope, receive, send) + return + + set_csrf_state(scope, token=token, cookie_name=self.cookie_name, form_field_name=self.form_field_name) + await self.app(scope, downstream_receive, send) + + async def _call_next_with_optional_cookie( + self, + scope: Scope, + receive: Receive, + send: Send, + *, + set_cookie: bool, + token: str, + ) -> None: + if not set_cookie: + await self.app(scope, receive, send) + return + + cookie_builder = Response() + cookie_builder.set_cookie( + key=self.cookie_name, + value=token, + max_age=self.cookie_max_age, + path=self.cookie_path, + domain=self.cookie_domain, + secure=self.cookie_secure, + httponly=self.cookie_httponly, + samesite=self.cookie_samesite, + ) + set_cookie_header = cookie_builder.headers["set-cookie"] + + async def send_with_cookie(message: Message) -> None: + if message["type"] == "http.response.start": + mutable_headers = MutableHeaders(scope=message) + mutable_headers.append("set-cookie", set_cookie_header) + await send(message) + + await self.app(scope, receive, send_with_cookie) + + async def _extract_form_token(self, scope: Scope, headers: Headers, body: bytes) -> str | None: + content_type = headers.get("content-type", "").lower() + if not content_type.startswith(("application/x-www-form-urlencoded", "multipart/form-data")): + return None + + request = Request(scope, receive=_receive_from_body(body)) + + try: + form_data = await request.form() + except (RuntimeError, TypeError, ValueError): + return None + + token = form_data.get(self.form_field_name) + if isinstance(token, str): + return token + return None class SessionMiddleware(StarletteSessionMiddleware): diff --git a/src/air/templating.py b/src/air/templating.py index ae5b0f23f..2f295dfa2 100644 --- a/src/air/templating.py +++ b/src/air/templating.py @@ -11,10 +11,15 @@ import jinja2 from fastapi.templating import Jinja2Templates +from markupsafe import Markup, escape from starlette.requests import Request as StarletteRequest from starlette.responses import HTMLResponse from starlette.templating import _TemplateResponse +from .csrf import ( + get_csrf_form_field_name_from_request, + get_csrf_token_from_request, +) from .requests import Request from .tags.models.base import BaseTag from .tags.utils import SafeStr @@ -36,6 +41,33 @@ def _jinja_context_item(item: Any) -> Any: return item +def _template_request(context: jinja2.runtime.Context) -> StarletteRequest: + request = context.get("request") + if not isinstance(request, StarletteRequest): + msg = "CSRF template helpers require a request in template context." + raise TypeError(msg) + return request + + +@jinja2.pass_context +def _jinja_csrf_token(context: jinja2.runtime.Context) -> str: + request = _template_request(context) + return get_csrf_token_from_request(request) + + +@jinja2.pass_context +def _jinja_csrf_input(context: jinja2.runtime.Context) -> Markup: + request = _template_request(context) + token = get_csrf_token_from_request(request) + field_name = get_csrf_form_field_name_from_request(request) + return Markup(f'') + + +def _configure_csrf_template_helpers(templates: Jinja2Templates) -> None: + templates.env.globals.setdefault("csrf_token", _jinja_csrf_token) + templates.env.globals.setdefault("csrf_input", _jinja_csrf_input) + + class JinjaRenderer: """Template renderer to make Jinja easier in Air. @@ -86,6 +118,7 @@ def __init__( ) -> None: """Initialize with template directory path""" self.templates = Jinja2Templates(directory=directory, context_processors=context_processors, env=env) + _configure_csrf_template_helpers(self.templates) @overload def __call__( @@ -204,6 +237,7 @@ def __init__( ) -> None: """Initialize with template directory path""" self.templates = Jinja2Templates(directory=directory, context_processors=context_processors, env=env) + _configure_csrf_template_helpers(self.templates) self.package = package def __call__( diff --git a/tests/templates/csrf_form.html b/tests/templates/csrf_form.html new file mode 100644 index 000000000..6007f594c --- /dev/null +++ b/tests/templates/csrf_form.html @@ -0,0 +1,6 @@ +
+{{ csrf_token() }}
diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 27c25b77e..aec75b038 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -1,7 +1,7 @@ from fastapi.testclient import TestClient from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request -from starlette.responses import Response +from starlette.responses import HTMLResponse, Response import air @@ -80,3 +80,66 @@ async def home(request: air.Request, timestamp: int) -> air.Html | air.Children: response = client.get("/check") assert response.status_code == 200 assert "654321" in response.text + + +def _build_csrf_app() -> air.Air: + app = air.Air() + app.add_middleware(air.CSRFMiddleware) + jinja = air.JinjaRenderer(directory="tests/templates") + + @app.get("/form") + def form(request: air.Request) -> HTMLResponse: + return jinja(request, name="csrf_form.html") + + @app.post("/submit") + async def submit(request: air.Request) -> air.P: + form_data = await request.form() + return air.P(f"Hello {form_data.get('name', '')}") + + return app + + +def test_csrf_middleware_accepts_valid_form_token() -> None: + app = _build_csrf_app() + client = TestClient(app) + + form_response = client.get("/form") + token = form_response.cookies.get("air_csrf_token") + assert token is not None + + response = client.post("/submit", data={"csrf_token": token, "name": "Cheddar"}) + assert response.status_code == 200 + assert "Hello Cheddar" in response.text + + +def test_csrf_middleware_rejects_missing_form_token() -> None: + app = _build_csrf_app() + client = TestClient(app) + + client.get("/form") + response = client.post("/submit", data={"name": "Brie"}) + assert response.status_code == 403 + assert response.text == "CSRF verification failed." + + +def test_csrf_middleware_rejects_invalid_form_token() -> None: + app = _build_csrf_app() + client = TestClient(app) + + client.get("/form") + response = client.post("/submit", data={"csrf_token": "bad-token", "name": "Gouda"}) + assert response.status_code == 403 + assert response.text == "CSRF verification failed." + + +def test_csrf_middleware_accepts_header_token() -> None: + app = _build_csrf_app() + client = TestClient(app) + + form_response = client.get("/form") + token = form_response.cookies.get("air_csrf_token") + assert token is not None + + response = client.post("/submit", data={"name": "Swiss"}, headers={"X-CSRF-Token": token}) + assert response.status_code == 200 + assert "Hello Swiss" in response.text diff --git a/tests/test_templating.py b/tests/test_templating.py index 4efba6234..a2ac91c47 100644 --- a/tests/test_templating.py +++ b/tests/test_templating.py @@ -189,6 +189,25 @@ def test_jinja_renderer_with_env() -> None: assert jinja.templates is not None +def test_jinja_csrf_helpers_render_token_and_hidden_input() -> None: + app = Air() + app.add_middleware(air.CSRFMiddleware) + jinja = JinjaRenderer(directory="tests/templates") + + @app.get("/form") + def form(request: Request) -> HTMLResponse: + return jinja(request, name="csrf_form.html") + + client = TestClient(app) + response = client.get("/form") + + assert response.status_code == 200 + token = response.cookies.get("air_csrf_token") + assert token is not None + assert f'name="csrf_token" value="{token}"' in response.text + assert f'{token}
' in response.text + + def test_renderer() -> None: """Test the Renderer class.""" app = air.Air() diff --git a/uv.lock b/uv.lock index c78088195..20f8c5f25 100644 --- a/uv.lock +++ b/uv.lock @@ -130,14 +130,14 @@ dev = [ { name = "pytest-github-actions-annotate-failures", specifier = ">=0.3.0" }, { name = "pytest-pretty", specifier = ">=1.3.0" }, { name = "ruff", specifier = ">=0.15.0" }, - { name = "ty", specifier = ">=0.0.14" }, + { name = "ty", specifier = ">=0.0.19" }, { name = "types-lxml", specifier = ">=2026.1.1" }, { name = "types-markdown", specifier = ">=3.10.2.20260211" }, { name = "types-pygments", specifier = ">=2.19.0.20251121" }, ] devtools = [ { name = "ruff", specifier = ">=0.15.0" }, - { name = "ty", specifier = ">=0.0.14" }, + { name = "ty", specifier = ">=0.0.19" }, ] docs = [ { name = "click", specifier = ">=8.3.1" },