Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/air/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
AirForm as AirForm,
to_form as to_form,
)
from .middleware import SessionMiddleware as SessionMiddleware
from .middleware import (
CSRFMiddleware as CSRFMiddleware,
SessionMiddleware as SessionMiddleware,
)
from .models import (
AirModel as AirModel,
)
Expand Down
66 changes: 66 additions & 0 deletions src/air/csrf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Shared CSRF utilities for Air middleware and template helpers."""

from __future__ import annotations

import secrets
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from starlette.requests import Request
from starlette.types import Scope

DEFAULT_CSRF_COOKIE_NAME = "air_csrf_token"
DEFAULT_CSRF_HEADER_NAME = "X-CSRF-Token"
DEFAULT_CSRF_FORM_FIELD_NAME = "csrf_token"

CSRF_STATE_TOKEN_KEY = "csrf_token"
CSRF_STATE_COOKIE_NAME_KEY = "csrf_cookie_name"
CSRF_STATE_FORM_FIELD_NAME_KEY = "csrf_form_field_name"


def new_csrf_token() -> str:
"""Generate a secure CSRF token."""
return secrets.token_urlsafe(32)


def set_csrf_state(scope: Scope, *, token: str, cookie_name: str, form_field_name: str) -> None:
"""Attach CSRF data to request state for downstream consumers."""
state = scope.setdefault("state", {})
state[CSRF_STATE_TOKEN_KEY] = token
state[CSRF_STATE_COOKIE_NAME_KEY] = cookie_name
state[CSRF_STATE_FORM_FIELD_NAME_KEY] = form_field_name


def get_csrf_cookie_name_from_request(request: Request) -> str:
"""Get the active CSRF cookie name for a request."""
cookie_name = getattr(request.state, CSRF_STATE_COOKIE_NAME_KEY, None)
if isinstance(cookie_name, str):
return cookie_name
return DEFAULT_CSRF_COOKIE_NAME


def get_csrf_form_field_name_from_request(request: Request) -> str:
"""Get the active CSRF form field name for a request."""
form_field_name = getattr(request.state, CSRF_STATE_FORM_FIELD_NAME_KEY, None)
if isinstance(form_field_name, str):
return form_field_name
return DEFAULT_CSRF_FORM_FIELD_NAME


def get_csrf_token_from_request(request: Request) -> str:
"""Resolve CSRF token from request state or cookie.

Raises:
RuntimeError: If token is unavailable, typically because CSRFMiddleware is not installed.
"""
token = getattr(request.state, CSRF_STATE_TOKEN_KEY, None)
if isinstance(token, str):
return token

cookie_name = get_csrf_cookie_name_from_request(request)
cookie_token = request.cookies.get(cookie_name)
if isinstance(cookie_token, str):
return cookie_token

msg = "CSRF token not found on request. Add air.CSRFMiddleware before using csrf helpers."
raise RuntimeError(msg)
216 changes: 216 additions & 0 deletions src/air/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,223 @@
Background tasks run _after_ middleware.
"""

from __future__ import annotations

import hmac
from http.cookies import SimpleCookie
from typing import TYPE_CHECKING, Literal

from starlette.datastructures import Headers, MutableHeaders
from starlette.middleware.sessions import SessionMiddleware as StarletteSessionMiddleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response

from .csrf import (
DEFAULT_CSRF_COOKIE_NAME,
DEFAULT_CSRF_FORM_FIELD_NAME,
DEFAULT_CSRF_HEADER_NAME,
new_csrf_token,
set_csrf_state,
)

_SAFE_HTTP_METHODS = {"GET", "HEAD", "OPTIONS", "TRACE"}

if TYPE_CHECKING:
from collections.abc import Iterable

from starlette.types import ASGIApp, Message, Receive, Scope, Send


def _cookie_value(cookie_header: str | None, name: str) -> str | None:
if cookie_header is None:
return None
cookie = SimpleCookie()
cookie.load(cookie_header)
morsel = cookie.get(name)
if morsel is None:
return None
return morsel.value


async def _read_body(receive: Receive) -> bytes:
chunks: list[bytes] = []
more_body = True

while more_body:
message = await receive()
if message["type"] == "http.disconnect":
break
if message["type"] != "http.request":
continue
chunks.append(message.get("body", b""))
more_body = message.get("more_body", False)

return b"".join(chunks)


def _receive_from_body(body: bytes) -> Receive:
has_emitted = False

async def receive() -> Message:
nonlocal has_emitted
if has_emitted:
return {"type": "http.request", "body": b"", "more_body": False}
has_emitted = True
return {"type": "http.request", "body": body, "more_body": False}

return receive


class CSRFMiddleware:
"""Validate CSRF tokens using the double-submit cookie pattern.

This middleware is intended for form workflows where templates render a hidden token input.
For Jinja templates, use the built-in globals:

- `csrf_token()` to read the request token
- `csrf_input()` to render `<input type="hidden" ...>`

TODO: Make this a default middleware always installed by Air

Example:

import air

app = air.Air()
app.add_middleware(air.CSRFMiddleware)

jinja = air.JinjaRenderer(directory="templates")


@app.get("/contact")
def contact(request: air.Request):
return jinja(request, "contact.html")

"""

def __init__(
self,
app: ASGIApp,
*,
cookie_name: str = DEFAULT_CSRF_COOKIE_NAME,
header_name: str = DEFAULT_CSRF_HEADER_NAME,
form_field_name: str = DEFAULT_CSRF_FORM_FIELD_NAME,
safe_methods: Iterable[str] = _SAFE_HTTP_METHODS,
exempt_paths: Iterable[str] | None = None,
cookie_path: str = "/",
cookie_domain: str | None = None,
cookie_secure: bool = False,
cookie_httponly: bool = True,
cookie_samesite: Literal["lax", "strict", "none"] = "lax",
cookie_max_age: int | None = None,
error_message: str = "CSRF verification failed.",
) -> None:
Copy link

Copilot AI Mar 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SameSite=None cookies are rejected by modern browsers unless they also have the Secure attribute. Since this middleware exposes cookie_samesite and cookie_secure independently, setting cookie_samesite="none" with cookie_secure=False will silently break CSRF (cookie never stored/sent). Consider validating these options in __init__ and raising a ValueError (or auto-enabling cookie_secure) when cookie_samesite == "none" and cookie_secure is false.

Suggested change
) -> None:
) -> None:
if cookie_samesite == "none" and not cookie_secure:
raise ValueError(
'CSRFMiddleware: cookie_samesite="none" requires cookie_secure=True '
"for modern browsers to accept the CSRF cookie."
)

Copilot uses AI. Check for mistakes.
self.app = app
self.cookie_name = cookie_name
self.header_name = header_name
self.form_field_name = form_field_name
self.safe_methods = {method.upper() for method in safe_methods}
self.exempt_paths = set(exempt_paths or ())
self.cookie_path = cookie_path
self.cookie_domain = cookie_domain
self.cookie_secure = cookie_secure
self.cookie_httponly = cookie_httponly
self.cookie_samesite = cookie_samesite
self.cookie_max_age = cookie_max_age
self.error_message = error_message

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return

method = scope["method"].upper()
path = scope["path"]
headers = Headers(scope=scope)
cookie_token = _cookie_value(headers.get("cookie"), self.cookie_name)

if method in self.safe_methods or path in self.exempt_paths:
token = cookie_token or new_csrf_token()
set_csrf_state(scope, token=token, cookie_name=self.cookie_name, form_field_name=self.form_field_name)
await self._call_next_with_optional_cookie(
scope,
receive,
send,
set_cookie=(cookie_token is None),
token=token,
)
return

token = cookie_token
if token is None:
await PlainTextResponse(self.error_message, status_code=403)(scope, receive, send)
return

submitted_token = headers.get(self.header_name)
downstream_receive = receive

if submitted_token is None:
buffered_body = await _read_body(receive)
submitted_token = await self._extract_form_token(scope, headers, buffered_body)
downstream_receive = _receive_from_body(buffered_body)
Comment on lines +169 to +172
Copy link

Copilot AI Mar 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CSRFMiddleware buffers the entire request body into memory when the CSRF header is missing (_read_body + _extract_form_token). For large multipart/form-data submissions (e.g., file uploads), this can cause unbounded memory usage and makes a straightforward DoS vector. Consider enforcing a maximum buffered size (via Content-Length + incremental limit), or requiring the header token for multipart requests and skipping body buffering in that case, returning 413/403 when the body is too large to safely buffer.

Copilot uses AI. Check for mistakes.

if submitted_token is None or not hmac.compare_digest(submitted_token, token):
await PlainTextResponse(self.error_message, status_code=403)(scope, receive, send)
return

set_csrf_state(scope, token=token, cookie_name=self.cookie_name, form_field_name=self.form_field_name)
await self.app(scope, downstream_receive, send)

async def _call_next_with_optional_cookie(
self,
scope: Scope,
receive: Receive,
send: Send,
*,
set_cookie: bool,
token: str,
) -> None:
if not set_cookie:
await self.app(scope, receive, send)
return

cookie_builder = Response()
cookie_builder.set_cookie(
key=self.cookie_name,
value=token,
max_age=self.cookie_max_age,
path=self.cookie_path,
domain=self.cookie_domain,
secure=self.cookie_secure,
httponly=self.cookie_httponly,
samesite=self.cookie_samesite,
)
set_cookie_header = cookie_builder.headers["set-cookie"]

async def send_with_cookie(message: Message) -> None:
if message["type"] == "http.response.start":
mutable_headers = MutableHeaders(scope=message)
mutable_headers.append("set-cookie", set_cookie_header)
await send(message)

await self.app(scope, receive, send_with_cookie)

async def _extract_form_token(self, scope: Scope, headers: Headers, body: bytes) -> str | None:
content_type = headers.get("content-type", "").lower()
if not content_type.startswith(("application/x-www-form-urlencoded", "multipart/form-data")):
return None

request = Request(scope, receive=_receive_from_body(body))

try:
form_data = await request.form()
except (RuntimeError, TypeError, ValueError):
return None

token = form_data.get(self.form_field_name)
if isinstance(token, str):
return token
return None


class SessionMiddleware(StarletteSessionMiddleware):
Expand Down
34 changes: 34 additions & 0 deletions src/air/templating.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,15 @@

import jinja2
from fastapi.templating import Jinja2Templates
from markupsafe import Markup, escape
from starlette.requests import Request as StarletteRequest
from starlette.responses import HTMLResponse
from starlette.templating import _TemplateResponse

from .csrf import (
get_csrf_form_field_name_from_request,
get_csrf_token_from_request,
)
from .requests import Request
from .tags.models.base import BaseTag
from .tags.utils import SafeStr
Expand All @@ -36,6 +41,33 @@ def _jinja_context_item(item: Any) -> Any:
return item


def _template_request(context: jinja2.runtime.Context) -> StarletteRequest:
request = context.get("request")
if not isinstance(request, StarletteRequest):
msg = "CSRF template helpers require a request in template context."
raise TypeError(msg)
return request


@jinja2.pass_context
def _jinja_csrf_token(context: jinja2.runtime.Context) -> str:
request = _template_request(context)
return get_csrf_token_from_request(request)


@jinja2.pass_context
def _jinja_csrf_input(context: jinja2.runtime.Context) -> Markup:
request = _template_request(context)
token = get_csrf_token_from_request(request)
field_name = get_csrf_form_field_name_from_request(request)
return Markup(f'<input type="hidden" name="{escape(field_name)}" value="{escape(token)}">')


def _configure_csrf_template_helpers(templates: Jinja2Templates) -> None:
templates.env.globals.setdefault("csrf_token", _jinja_csrf_token)
templates.env.globals.setdefault("csrf_input", _jinja_csrf_input)


class JinjaRenderer:
"""Template renderer to make Jinja easier in Air.

Expand Down Expand Up @@ -86,6 +118,7 @@ def __init__(
) -> None:
"""Initialize with template directory path"""
self.templates = Jinja2Templates(directory=directory, context_processors=context_processors, env=env)
_configure_csrf_template_helpers(self.templates)

@overload
def __call__(
Expand Down Expand Up @@ -204,6 +237,7 @@ def __init__(
) -> None:
"""Initialize with template directory path"""
self.templates = Jinja2Templates(directory=directory, context_processors=context_processors, env=env)
_configure_csrf_template_helpers(self.templates)
self.package = package

def __call__(
Expand Down
6 changes: 6 additions & 0 deletions tests/templates/csrf_form.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
<form method="post" action="/submit">
{{ csrf_input() }}
<input type="text" name="name" />
<button type="submit">Submit</button>
</form>
<p id="token">{{ csrf_token() }}</p>
Loading