diff --git a/.gitignore b/.gitignore index 4419047d6b..eeeeeb87bf 100644 --- a/.gitignore +++ b/.gitignore @@ -26,3 +26,7 @@ go.work.sum # Serena .serena/ + +# Local agent / planning artifacts (not for public commits) +.claude/ +docs/superpowers/ \ No newline at end of file diff --git a/Makefile b/Makefile index f7777d71f7..12ec52c48c 100644 --- a/Makefile +++ b/Makefile @@ -116,7 +116,7 @@ generate: make generate-go generate-go: - rm -rf router/gen && buf generate --path proto/wg/cosmo/node --path proto/wg/cosmo/common --path proto/wg/cosmo/graphqlmetrics --template buf.router.go.gen.yaml + rm -rf router/gen && buf generate --path proto/wg/cosmo/node --path proto/wg/cosmo/common --path proto/wg/cosmo/graphqlmetrics --path proto/wg/cosmo/code_mode/yoko/v1 --template buf.router.go.gen.yaml rm -rf graphqlmetrics/gen && buf generate --path proto/wg/cosmo/graphqlmetrics --path proto/wg/cosmo/common --template buf.graphqlmetrics.go.gen.yaml rm -rf connect-go/wg && buf generate --path proto/wg/cosmo/platform --path proto/wg/cosmo/notifications --path proto/wg/cosmo/common --path proto/wg/cosmo/node --template buf.connect-go.go.gen.yaml @@ -187,6 +187,45 @@ docker-build-minikube: docker-build-local run-subgraphs-local: cd demo && go run cmd/all/main.go +CODE_MODE_GOCACHE ?= /tmp/cosmo-code-mode-go-build-cache + +.PHONY: code-mode-demo code-mode-demo-down code-mode-connect-demo code-mode-connect-demo-down + +# Local Code Mode demo: small federation (employees, family, availability, +# mood) + Yoko mock + Cosmo Router with Code Mode and named operations. +# Router GraphQL on :3002, MCP on :5027. Full instructions, prerequisites +# (codex CLI on PATH), and tear-down: demo/code-mode/README.md. +code-mode-demo: + mkdir -p $(CODE_MODE_GOCACHE) + GOCACHE=$(CODE_MODE_GOCACHE) $(MAKE) -C router build + GOCACHE=$(CODE_MODE_GOCACHE) $(MAKE) -C demo/code-mode build-yoko + GOCACHE=$(CODE_MODE_GOCACHE) $(MAKE) -C demo/code-mode build-stdio-proxy + $(MAKE) -C demo/code-mode compose + ./demo/code-mode/start.sh + +# Tear down anything left behind by code-mode-demo. +code-mode-demo-down: + ./demo/code-mode/start.sh --down + +# Runs the code-mode router from source against the yoko Connect supergraph +# (plugins + composed config live in $(YOKO_DIR)). Uses different ports than +# code-mode-demo (router 3012, MCP 5037, yoko-mock 5038) so both can run at +# the same time. Set YOKO_DIR to your local yoko checkout, e.g. +# `make code-mode-connect-demo YOKO_DIR=/path/to/yoko`. +# Full instructions and prerequisites: demo/code-mode-connect/README.md. +YOKO_DIR ?= + +code-mode-connect-demo: + @if [ -z "$(YOKO_DIR)" ]; then echo "YOKO_DIR is required (path to your yoko checkout). See demo/code-mode-connect/README.md" >&2; exit 1; fi + mkdir -p $(CODE_MODE_GOCACHE) + GOCACHE=$(CODE_MODE_GOCACHE) $(MAKE) -C router build + GOCACHE=$(CODE_MODE_GOCACHE) $(MAKE) -C demo/code-mode build-yoko + YOKO_DIR=$(YOKO_DIR) ./demo/code-mode-connect/start.sh + +# Tear down anything left behind by code-mode-connect-demo. +code-mode-connect-demo-down: + ./demo/code-mode-connect/start.sh --down + sync-go-workspace: cd router && go mod tidy cd demo && make bump-deps diff --git a/demo/code-mode-connect/README.md b/demo/code-mode-connect/README.md new file mode 100644 index 0000000000..cdc494e5ec --- /dev/null +++ b/demo/code-mode-connect/README.md @@ -0,0 +1,56 @@ +# Code Mode Connect Demo + +This demo runs the Code Mode router against an external `yoko` Connect supergraph instead of the local employees federation used by `make code-mode-demo`. +It is useful when you want to exercise Code Mode against a richer set of plugins (Pylon, Linear, PostHog, Circleback, Slack, Notion) served by the `yoko` project. + +It is designed to coexist with `make code-mode-demo`: it uses different ports (router 3012, MCP 5037, yoko-mock 5038), so both demos can run side-by-side. + +## Prerequisites + +- A local checkout of the `yoko` Connect supergraph project (separate repository). + Inside that checkout you must already have built the plugins and composed the supergraph so that the directory contains: + - `config.json` — the composed router config for the yoko supergraph. + - `plugins/` — the plugin binaries the router will load. +- Go (toolchain matching the repo `go.mod`). +- The `codex` CLI on `PATH`, authenticated. The Yoko mock shells out to `codex` for query generation. + +## Run + +From the repository root, set `YOKO_DIR` to your local yoko checkout and run: + +```sh +make code-mode-connect-demo YOKO_DIR=/path/to/yoko +``` + +`YOKO_DIR` is required. +The target fails fast with a clear error if it is missing or if the directory does not contain `config.json`. + +What the target does: + +1. Builds `router/router`. +2. Builds `demo/code-mode/yoko-mock/yoko-mock`. +3. Starts `yoko-mock` on `localhost:5038`. +4. Starts the router with `YOKO_DIR` as its working directory and `demo/code-mode-connect/router-config.yaml` as its config. + The router resolves `config.json` and `plugins/` relative to that CWD, which is why `YOKO_DIR` must be a real composed yoko checkout. + +Expected ports: + +- Router GraphQL: `http://localhost:3012/graphql` +- Code Mode MCP: `http://127.0.0.1:5037/mcp` +- Yoko mock: `http://localhost:5038` + +## Tearing down + +Press Ctrl-C in the foreground terminal. +If anything is left behind, run: + +```sh +make code-mode-connect-demo-down +``` + +The process logs for background services are written to `/tmp/cosmo-code-mode-connect-demo-logs`. + +## Auth headers + +`router-config.yaml` propagates the auth headers expected by the yoko plugins (`X-Pylon-Token`, `X-Linear-Token`, `X-Posthog-Token`, `X-Circleback-Token`, `X-Slack-Token`, `X-Notion-Token`, etc.). +Provide values for these on the request side when calling the router so the plugins can reach their upstream services. diff --git a/demo/code-mode-connect/router-config.yaml b/demo/code-mode-connect/router-config.yaml new file mode 100644 index 0000000000..b2e102fe00 --- /dev/null +++ b/demo/code-mode-connect/router-config.yaml @@ -0,0 +1,82 @@ +version: "1" + +# Different ports than demo/code-mode/router-config.yaml so both demos can run +# side-by-side. See demo/code-mode-connect/start.sh for the matching yoko-mock +# port. +listen_addr: "localhost:3012" +graphql_path: "/graphql" +playground_enabled: false +json_log: false +log_level: info +dev_mode: true +router_registration: false + +# These paths are resolved relative to the router's CWD. start.sh runs the +# router from inside the yoko project dir, so "config.json" and "plugins" are +# the composed supergraph and the plugin binaries that ship with that repo. +execution_config: + file: + path: "config.json" + watch: false + +plugins: + enabled: true + path: "plugins" + +# Header propagation for the yoko plugins. Mirrors yoko/config.yaml so the +# plugins receive the same auth headers when the code-mode router fronts them. +headers: + all: + request: + - op: propagate + named: X-Pylon-Token + - op: propagate + named: X-Linear-Token + - op: propagate + named: X-Linear-Auth-Scheme + - op: propagate + named: X-Posthog-Token + - op: propagate + named: X-Posthog-Host + - op: propagate + named: X-Posthog-Project-Id + - op: propagate + named: X-Circleback-Token + - op: propagate + named: X-Slack-Token + - op: propagate + named: X-Notion-Token + +graphql_metrics: + enabled: false + +telemetry: + tracing: + enabled: false + metrics: + otlp: + enabled: false + prometheus: + enabled: false + +mcp: + enabled: false + graph_name: code-mode-connect-demo + router_url: http://localhost:3012/graphql + session: + stateless: false + code_mode: + enabled: true + server: + # IPv4-only bind, see demo/code-mode/router-config.yaml for the why. + listen_addr: 127.0.0.1:5037 + require_mutation_approval: true + sandbox: + timeout: 180s + query_generation: + enabled: true + endpoint: http://localhost:5038 + timeout: 180s + execute_timeout: 180s + named_ops: + enabled: true diff --git a/demo/code-mode-connect/start.sh b/demo/code-mode-connect/start.sh new file mode 100755 index 0000000000..7379fa85d3 --- /dev/null +++ b/demo/code-mode-connect/start.sh @@ -0,0 +1,146 @@ +#!/usr/bin/env bash + +set -Eeuo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" +DEMO_DIR="$ROOT_DIR/demo" +CONNECT_DIR="$DEMO_DIR/code-mode-connect" +PID_FILE="/tmp/cosmo-code-mode-connect-demo.pids" +LOG_DIR="/tmp/cosmo-code-mode-connect-demo-logs" +GOCACHE_DIR="${GOCACHE:-/tmp/cosmo-code-mode-go-build-cache}" + +# Yoko project that owns the supergraph + plugin binaries. Required: +# YOKO_DIR=/path/to/yoko ./start.sh +YOKO_DIR="${YOKO_DIR:?YOKO_DIR is required (path to your yoko checkout)}" + +ROUTER_BIN="$ROOT_DIR/router/router" +ROUTER_CONFIG="$CONNECT_DIR/router-config.yaml" +YOKO_BIN="$DEMO_DIR/code-mode/yoko-mock/yoko-mock" + +append_pid() { + local name="$1" + local pid="$2" + printf '%s %s\n' "$name" "$pid" >> "$PID_FILE" +} + +kill_pid_file() { + if [ ! -f "$PID_FILE" ]; then + echo "No code-mode-connect demo PID file found at $PID_FILE" + return 0 + fi + + while read -r name pid; do + [ -n "${pid:-}" ] || continue + if kill -0 "$pid" 2>/dev/null; then + echo "Stopping $name pid=$pid" + kill "$pid" 2>/dev/null || true + fi + done < "$PID_FILE" + + sleep 1 + + while read -r name pid; do + [ -n "${pid:-}" ] || continue + if kill -0 "$pid" 2>/dev/null; then + echo "Force stopping $name pid=$pid" + kill -KILL "$pid" 2>/dev/null || true + fi + done < "$PID_FILE" + + rm -f "$PID_FILE" +} + +cleanup() { + local status=$? + trap - EXIT INT TERM + kill_pid_file + exit "$status" +} + +wait_url() { + local name="$1" + local url="$2" + local timeout="${3:-90}" + local start + start="$(date +%s)" + + while true; do + if curl -fsS "$url" >/dev/null 2>&1; then + echo "$name is ready at $url" + return 0 + fi + + if [ "$(( $(date +%s) - start ))" -ge "$timeout" ]; then + echo "Timed out waiting for $name at $url" >&2 + echo "Logs are in $LOG_DIR" >&2 + return 1 + fi + + sleep 1 + done +} + +start_background_root() { + local name="$1" + shift + + echo "Starting $name" + # exec replaces the subshell with the binary, so $! is the binary's pid. + # Without exec, the subshell forks the binary and `--down` ends up signalling + # an already-exited subshell while the real process keeps running. + ( + cd "$ROOT_DIR" + exec "$@" + ) > "$LOG_DIR/$name.log" 2>&1 & + append_pid "$name" "$!" +} + +if [ "${1:-}" = "--down" ]; then + kill_pid_file + exit 0 +fi + +if [ ! -d "$YOKO_DIR" ]; then + echo "Yoko project directory not found: $YOKO_DIR" >&2 + echo "Set YOKO_DIR to override." >&2 + exit 1 +fi + +if [ ! -x "$ROUTER_BIN" ]; then + echo "Router binary not found or not executable: $ROUTER_BIN" >&2 + echo "Run: cd router && make build" >&2 + exit 1 +fi + +if [ ! -x "$YOKO_BIN" ]; then + echo "Yoko mock binary not found or not executable: $YOKO_BIN" >&2 + echo "Run: make -C demo/code-mode build-yoko" >&2 + exit 1 +fi + +if [ ! -f "$YOKO_DIR/config.json" ]; then + echo "Composed yoko supergraph not found: $YOKO_DIR/config.json" >&2 + echo "Run: cd $YOKO_DIR && make compose" >&2 + exit 1 +fi + +mkdir -p "$LOG_DIR" +mkdir -p "$GOCACHE_DIR" +rm -f "$PID_FILE" +trap cleanup EXIT INT TERM + +# yoko-mock listens on a different port than the regular code-mode-demo so the +# two demos can coexist (5028 vs 5038). +start_background_root yoko "$YOKO_BIN" -listen-addr localhost:5038 + +wait_url yoko http://localhost:5038/health + +echo "Starting router in foreground (CWD=$YOKO_DIR)" +( + cd "$YOKO_DIR" + exec "$ROUTER_BIN" -config "$ROUTER_CONFIG" +) & +router_pid="$!" +append_pid router "$router_pid" + +wait "$router_pid" diff --git a/demo/code-mode/.gitignore b/demo/code-mode/.gitignore new file mode 100644 index 0000000000..bc5fd710be --- /dev/null +++ b/demo/code-mode/.gitignore @@ -0,0 +1 @@ +mcp-stdio-proxy/mcp-stdio-proxy diff --git a/demo/code-mode/Makefile b/demo/code-mode/Makefile new file mode 100644 index 0000000000..1114f6ea7c --- /dev/null +++ b/demo/code-mode/Makefile @@ -0,0 +1,30 @@ +SHELL := bash +GOCACHE ?= /tmp/cosmo-code-mode-go-build-cache +wgc_env_arg = $(if $(wildcard ../cli/.env),--env-file ../cli/.env,) +wgc_router = pnpm dlx tsx $(wgc_env_arg) ../cli/src/index.ts router + +.PHONY: build-yoko build-stdio-proxy compose start down run-subgraphs + +build-yoko: + mkdir -p $(GOCACHE) + cd yoko-mock && GOCACHE=$(GOCACHE) go build -o yoko-mock . + +build-stdio-proxy: + mkdir -p $(GOCACHE) + cd mcp-stdio-proxy && GOCACHE=$(GOCACHE) go build -o mcp-stdio-proxy . + +compose: + cd .. && if [ -f ../cli/dist/src/index.js ]; then \ + DISABLE_UPDATE_CHECK=true node ../cli/dist/src/index.js router compose -i ./code-mode/graph.yaml -o ./code-mode/config.json; \ + else \ + DISABLE_UPDATE_CHECK=true TMPDIR=/tmp $(wgc_router) compose -i ./code-mode/graph.yaml -o ./code-mode/config.json; \ + fi + +start: + ./start.sh + +down: + ./start.sh --down + +run-subgraphs: + ./run_subgraphs_subset.sh diff --git a/demo/code-mode/README.md b/demo/code-mode/README.md new file mode 100644 index 0000000000..dee17d14a2 --- /dev/null +++ b/demo/code-mode/README.md @@ -0,0 +1,60 @@ +# Code Mode Demo + +This demo starts a small local federation (`employees`, `family`, `availability`, and `mood`), the Code Mode Yoko mock, and a local Cosmo Router with Code Mode and named operations enabled. + +## Prerequisites + +- Go (toolchain matching the repo `go.mod`). +- Node + `pnpm` (used by `wgc` to compose `demo/code-mode/graph.yaml`). +- The `codex` CLI on `PATH`, authenticated. + The Yoko mock shells out to `codex` for query generation; + without it, `code_mode_search_tools` cannot generate operations. + +## Quick start + +Run it from the repository root: + +```sh +make code-mode-demo +``` + +The root target builds `router/router`, builds `demo/code-mode/yoko-mock/yoko-mock`, builds `demo/code-mode/mcp-stdio-proxy/mcp-stdio-proxy` (used by stdio-only MCP clients like Claude Desktop), composes `demo/code-mode/graph.yaml` into `demo/code-mode/config.json`, then starts the demo processes. +The router stays in the foreground. + +Expected ports: + +- Router GraphQL: `http://localhost:3002/graphql` +- Code Mode MCP: `http://localhost:5027/mcp` +- Yoko mock: `http://localhost:5028` +- Employees subgraph: `http://localhost:4001/graphql` +- Family subgraph: `http://localhost:4002/graphql` +- Availability subgraph: `http://localhost:4007/graphql` +- Mood subgraph: `http://localhost:4008/graphql` + +## Tearing down + +To stop the demo, press Ctrl-C in the foreground terminal. +If anything is left behind (background subgraphs, yoko-mock), run: + +```sh +make code-mode-demo-down +``` + +The process logs for background services are written to `/tmp/cosmo-code-mode-demo-logs`. + +## Manual smoke check + +```sh +make code-mode-demo +curl -sS http://localhost:3002/graphql \ + -H 'content-type: application/json' \ + --data '{"query":"{ employees { id details { forename surname } } }"}' +``` + +## Other notes + +The subset runner is `demo/code-mode/run_subgraphs_subset.sh`. It starts only `employees`, `family`, `availability`, and `mood` via `npx concurrently` for a fast demo. `availability` and `mood` are included because the `employees` schema has federation references to fields owned by those subgraphs. The full demo `demo/run_subgraphs.sh` starts all subgraphs and is intentionally not used here. + +Client configuration for Code Mode MCP clients (Claude Code, Claude Desktop, Codex CLI) lives under `demo/code-mode/mcp-configs/` — see the README there. + +For the alternate "Connect" variant of this demo, which runs the same Code Mode router against an external `yoko` Connect supergraph instead of the local employees federation, see `demo/code-mode-connect/README.md`. diff --git a/demo/code-mode/graph.yaml b/demo/code-mode/graph.yaml new file mode 100644 index 0000000000..e95412def2 --- /dev/null +++ b/demo/code-mode/graph.yaml @@ -0,0 +1,18 @@ +version: 1 +subgraphs: + - name: employees + routing_url: http://localhost:4001/graphql + schema: + file: ../pkg/subgraphs/employees/subgraph/schema.graphqls + - name: family + routing_url: http://localhost:4002/graphql + schema: + file: ../pkg/subgraphs/family/subgraph/schema.graphqls + - name: availability + routing_url: http://localhost:4007/graphql + schema: + file: ../pkg/subgraphs/availability/subgraph/schema.graphqls + - name: mood + routing_url: http://localhost:4008/graphql + schema: + file: ../pkg/subgraphs/mood/subgraph/schema.graphqls diff --git a/demo/code-mode/mcp-configs/README.md b/demo/code-mode/mcp-configs/README.md new file mode 100644 index 0000000000..8fa517f95c --- /dev/null +++ b/demo/code-mode/mcp-configs/README.md @@ -0,0 +1,122 @@ +# Code Mode MCP Client Configs + +These snippets connect MCP clients to the Code Mode demo server at `http://localhost:5027/mcp`. +Start the demo first: + +```bash +make code-mode-demo +``` + +The configs are illustrative. +Real users can adapt paths, server names, timeouts, and auth settings for their local setup. +Do not add API keys or auth tokens to these files. + +## Claude Code + +`claude.mcp.json` matches Claude Code's `mcpServers` settings schema for Streamable HTTP: + +```json +{ + "mcpServers": { + "yoko": { + "type": "http", + "url": "http://localhost:5027/mcp" + } + } +} +``` + +Run with the config snippet directly: + +```bash +claude --mcp-config demo/code-mode/mcp-configs/claude.mcp.json --strict-mcp-config -p "$(cat demo/code-mode/mcp-configs/sample-prompts/01-search-employees.txt)" +``` + +Or install it into Claude Code project config: + +```bash +claude mcp add --scope project --transport http yoko http://localhost:5027/mcp +``` + +Claude Code writes project-scoped MCP servers to `.mcp.json`. +Use `--scope user` instead if you want the server available outside this checkout. + +## Claude Desktop + +Claude Desktop only speaks stdio, so it cannot connect to the demo's HTTP MCP endpoint directly. +The demo ships a tiny `mcp-stdio-proxy` binary that bridges Claude Desktop's stdio transport to the upstream HTTP server. +`make code-mode-demo` builds it at `demo/code-mode/mcp-stdio-proxy/mcp-stdio-proxy`. + +`claude.desktop.json` is the matching config: + +```json +{ + "mcpServers": { + "yoko": { + "command": "/ABSOLUTE/PATH/TO/cosmo/demo/code-mode/mcp-stdio-proxy/mcp-stdio-proxy", + "args": ["--upstream", "http://127.0.0.1:5027/mcp"] + } + } +} +``` + +Replace `/ABSOLUTE/PATH/TO/cosmo` with the absolute path to your checkout, then merge into `~/Library/Application Support/Claude/claude_desktop_config.json` (macOS) or `%APPDATA%\Claude\claude_desktop_config.json` (Windows) and restart Claude Desktop. + +## Codex CLI + +`codex.toml` matches Codex CLI's `~/.codex/config.toml` table format: + +```toml +[mcp_servers."yoko"] +url = "http://localhost:5027/mcp" +``` + +Install it by copying the table into `~/.codex/config.toml`, or add the same server with: + +```bash +codex mcp add yoko --url http://localhost:5027/mcp +``` + +Then run a prompt with your normal Codex config: + +```bash +codex exec --full-auto -- "$(cat demo/code-mode/mcp-configs/sample-prompts/02-execute-fetch.txt)" +``` + +To point one invocation at this snippet without editing your global config, pass equivalent config overrides: + +```bash +codex exec --full-auto \ + -c 'mcp_servers.yoko.url="http://localhost:5027/mcp"' \ + -- "$(cat demo/code-mode/mcp-configs/sample-prompts/02-execute-fetch.txt)" +``` + +Codex does not currently expose a direct `--config-file` flag for `codex.toml`. +For an isolated run against the checked-in snippet, place it at `$CODEX_HOME/config.toml` in a temporary directory: + +```bash +tmp_codex_home="$(mktemp -d)" +cp demo/code-mode/mcp-configs/codex.toml "$tmp_codex_home/config.toml" +CODEX_HOME="$tmp_codex_home" codex exec --full-auto -- "$(cat demo/code-mode/mcp-configs/sample-prompts/02-execute-fetch.txt)" +``` + +## Sample Prompts + +`sample-prompts/01-search-employees.txt` asks the client to call `code_mode_search_tools` with two prompts in one batch. +Expected output shape: the assistant should show the newly returned TypeScript `tools` declarations for the first-employee operation and the employee-by-id operation. + +`sample-prompts/02-execute-fetch.txt` asks the client to discover an employee-by-id operation and run `code_mode_run_js`. +Expected output shape: the assistant should show an `code_mode_run_js` result for employee `1`, returning the employee's `forename` and `surname`. + +`sample-prompts/03-multi-tool.txt` asks the client to discover two operations and compose them in a single `code_mode_run_js` program. +Expected output shape: the assistant should return both the first employee and that employee's family from one sandbox execution. + +`sample-prompts/04-mutation-not-approved.txt` asks the client to try an employee-tag mutation. +The historical prompt name mentions "not approved", but the demo config sets `require_mutation_approval: false` in `demo/code-mode/router-config.yaml`. +That means this prompt is not declined by operator approval in the default demo; it should run like a normal mutation if the mock can generate the operation. +Skip this prompt when you specifically need to demonstrate approval rejection. + +## Caveat + +The mock Yoko service shells out to the `codex` CLI for query generation. +The local `codex` CLI must be installed and authenticated before `code_mode_search_tools` can generate operations. diff --git a/demo/code-mode/mcp-configs/claude.desktop.json b/demo/code-mode/mcp-configs/claude.desktop.json new file mode 100644 index 0000000000..6297c6062d --- /dev/null +++ b/demo/code-mode/mcp-configs/claude.desktop.json @@ -0,0 +1,8 @@ +{ + "mcpServers": { + "yoko": { + "command": "/ABSOLUTE/PATH/TO/cosmo/demo/code-mode/mcp-stdio-proxy/mcp-stdio-proxy", + "args": ["--upstream", "http://127.0.0.1:5027/mcp"] + } + } +} diff --git a/demo/code-mode/mcp-configs/claude.mcp.json b/demo/code-mode/mcp-configs/claude.mcp.json new file mode 100644 index 0000000000..f5dfa28e16 --- /dev/null +++ b/demo/code-mode/mcp-configs/claude.mcp.json @@ -0,0 +1,8 @@ +{ + "mcpServers": { + "yoko": { + "type": "http", + "url": "http://localhost:5027/mcp" + } + } +} diff --git a/demo/code-mode/mcp-configs/codex.toml b/demo/code-mode/mcp-configs/codex.toml new file mode 100644 index 0000000000..03f1390f70 --- /dev/null +++ b/demo/code-mode/mcp-configs/codex.toml @@ -0,0 +1,2 @@ +[mcp_servers."yoko"] +url = "http://localhost:5027/mcp" diff --git a/demo/code-mode/mcp-configs/sample-prompts/01-search-employees.txt b/demo/code-mode/mcp-configs/sample-prompts/01-search-employees.txt new file mode 100644 index 0000000000..8777d17f1e --- /dev/null +++ b/demo/code-mode/mcp-configs/sample-prompts/01-search-employees.txt @@ -0,0 +1 @@ +Use the yoko MCP server. Call code_mode_search_tools with prompts that fetch the first employee and an employee by id. Then show me the TS that came back. diff --git a/demo/code-mode/mcp-configs/sample-prompts/02-execute-fetch.txt b/demo/code-mode/mcp-configs/sample-prompts/02-execute-fetch.txt new file mode 100644 index 0000000000..5084163314 --- /dev/null +++ b/demo/code-mode/mcp-configs/sample-prompts/02-execute-fetch.txt @@ -0,0 +1 @@ +Use yoko. Search for an op that fetches an employee by id, then write a code_mode_run_js program that fetches employee 1 and returns the forename + surname. diff --git a/demo/code-mode/mcp-configs/sample-prompts/03-multi-tool.txt b/demo/code-mode/mcp-configs/sample-prompts/03-multi-tool.txt new file mode 100644 index 0000000000..f4939ebb49 --- /dev/null +++ b/demo/code-mode/mcp-configs/sample-prompts/03-multi-tool.txt @@ -0,0 +1 @@ +Use yoko. Discover ops to (a) get the first employee and (b) get the family of a specific employee id; then run a single code_mode_run_js program that fetches the first employee, then their family, and returns both. diff --git a/demo/code-mode/mcp-configs/sample-prompts/04-mutation-not-approved.txt b/demo/code-mode/mcp-configs/sample-prompts/04-mutation-not-approved.txt new file mode 100644 index 0000000000..4f2e1c86c6 --- /dev/null +++ b/demo/code-mode/mcp-configs/sample-prompts/04-mutation-not-approved.txt @@ -0,0 +1 @@ +Use yoko. Try to update an employee tag and see what happens. diff --git a/demo/code-mode/mcp-stdio-proxy/go.mod b/demo/code-mode/mcp-stdio-proxy/go.mod new file mode 100644 index 0000000000..4720b36f3c --- /dev/null +++ b/demo/code-mode/mcp-stdio-proxy/go.mod @@ -0,0 +1,20 @@ +module github.com/wundergraph/cosmo/demo/code-mode/mcp-stdio-proxy + +go 1.25 + +require ( + github.com/modelcontextprotocol/go-sdk v1.4.1 + github.com/stretchr/testify v1.11.1 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/google/jsonschema-go v0.4.2 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/segmentio/asm v1.1.3 // indirect + github.com/segmentio/encoding v0.5.4 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/oauth2 v0.34.0 // indirect + golang.org/x/sys v0.40.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/demo/code-mode/mcp-stdio-proxy/go.sum b/demo/code-mode/mcp-stdio-proxy/go.sum new file mode 100644 index 0000000000..e469bb22cf --- /dev/null +++ b/demo/code-mode/mcp-stdio-proxy/go.sum @@ -0,0 +1,30 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8= +github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/modelcontextprotocol/go-sdk v1.4.1 h1:M4x9GyIPj+HoIlHNGpK2hq5o3BFhC+78PkEaldQRphc= +github.com/modelcontextprotocol/go-sdk v1.4.1/go.mod h1:Bo/mS87hPQqHSRkMv4dQq1XCu6zv4INdXnFZabkNU6s= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/segmentio/asm v1.1.3 h1:WM03sfUOENvvKexOLp+pCqgb/WDjsi7EK8gIsICtzhc= +github.com/segmentio/asm v1.1.3/go.mod h1:Ld3L4ZXGNcSLRg4JBsZ3//1+f/TjYl0Mzen/DQy1EJg= +github.com/segmentio/encoding v0.5.4 h1:OW1VRern8Nw6ITAtwSZ7Idrl3MXCFwXHPgqESYfvNt0= +github.com/segmentio/encoding v0.5.4/go.mod h1:HS1ZKa3kSN32ZHVZ7ZLPLXWvOVIiZtyJnO1gPH1sKt0= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw= +golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= +golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= +golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/demo/code-mode/mcp-stdio-proxy/main.go b/demo/code-mode/mcp-stdio-proxy/main.go new file mode 100644 index 0000000000..6f172d100f --- /dev/null +++ b/demo/code-mode/mcp-stdio-proxy/main.go @@ -0,0 +1,350 @@ +package main + +import ( + "context" + "errors" + "flag" + "fmt" + "log" + "net/http" + "os" + "os/signal" + "sync" + "sync/atomic" + "syscall" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +const ( + // 127.0.0.1 (not localhost) so the Go HTTP client doesn't try ::1 first + // and get refused — the router binds IPv4 only. + defaultUpstreamURL = "http://127.0.0.1:5027/mcp" + proxyName = "yoko-stdio-proxy" + proxyVersion = "0.1.0" + + initialReconnectBackoff = 500 * time.Millisecond + maxReconnectBackoff = 30 * time.Second + upstreamKeepAlive = 30 * time.Second +) + +type proxyOptions struct { + upstreamURL string + transport mcp.Transport + httpClient *http.Client + // keepAlive overrides the upstream client KeepAlive interval. Zero uses the + // default. Tests use a short interval so disconnects are detected quickly. + keepAlive time.Duration + // initialBackoff overrides the initial reconnect backoff. Zero uses the + // default. Tests use a short value to keep reconnect latency low. + initialBackoff time.Duration +} + +func main() { + log.SetOutput(os.Stderr) + + flags := flag.NewFlagSet(os.Args[0], flag.ExitOnError) + flags.SetOutput(os.Stderr) + upstreamURL := flags.String("upstream", defaultUpstreamURL, "HTTP MCP upstream URL") + flags.Usage = func() { + fmt.Fprintf(flags.Output(), "Usage: mcp-stdio-proxy --upstream \n") + flags.PrintDefaults() + } + if err := flags.Parse(os.Args[1:]); err != nil { + os.Exit(2) + } + if flags.NArg() != 0 { + fmt.Fprintln(os.Stderr, "mcp-stdio-proxy: unexpected positional arguments") + flags.Usage() + os.Exit(2) + } + + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + if err := runProxy(ctx, proxyOptions{ + upstreamURL: *upstreamURL, + transport: &mcp.StdioTransport{}, + }); err != nil && !errors.Is(err, context.Canceled) { + log.Fatalf("mcp-stdio-proxy: %v", err) + } +} + +func runProxy(ctx context.Context, opts proxyOptions) error { + if opts.upstreamURL == "" { + opts.upstreamURL = defaultUpstreamURL + } + if opts.transport == nil { + opts.transport = &mcp.StdioTransport{} + } + keepAlive := opts.keepAlive + if keepAlive == 0 { + keepAlive = upstreamKeepAlive + } + initialBackoff := opts.initialBackoff + if initialBackoff == 0 { + initialBackoff = initialReconnectBackoff + } + + var localSession atomic.Pointer[mcp.ServerSession] + upstreamClient := mcp.NewClient( + &mcp.Implementation{Name: proxyName, Version: proxyVersion}, + &mcp.ClientOptions{ + KeepAlive: keepAlive, + ElicitationHandler: func(ctx context.Context, req *mcp.ElicitRequest) (*mcp.ElicitResult, error) { + ss := localSession.Load() + if ss == nil { + return nil, errors.New("no local session yet") + } + return ss.Elicit(ctx, req.Params) + }, + }, + ) + + upstream := &upstreamConn{ + client: upstreamClient, + upstreamURL: opts.upstreamURL, + httpClient: opts.httpClient, + initialBackoff: initialBackoff, + ready: make(chan struct{}), + } + + initialSession, err := upstream.connectWithRetry(ctx, "upstream connect") + if err != nil { + if errors.Is(err, context.Canceled) { + return err + } + return fmt.Errorf("connect upstream %q failed: %w; is the demo running? try `make code-mode-demo`", opts.upstreamURL, err) + } + upstream.setSession(initialSession) + + defer func() { + if s := upstream.currentSession(); s != nil { + if err := s.Close(); err != nil { + log.Printf("mcp-stdio-proxy: upstream close failed: %v", err) + } + } + }() + + toolsResp, err := initialSession.ListTools(ctx, &mcp.ListToolsParams{}) + if err != nil { + return fmt.Errorf("list upstream tools: %w", err) + } + resourcesResp, err := initialSession.ListResources(ctx, &mcp.ListResourcesParams{}) + if err != nil { + return fmt.Errorf("list upstream resources: %w", err) + } + + supervisorCtx, cancelSupervisor := context.WithCancel(ctx) + defer cancelSupervisor() + supervisorDone := make(chan struct{}) + go func() { + defer close(supervisorDone) + upstream.supervise(supervisorCtx, initialSession) + }() + defer func() { + cancelSupervisor() + <-supervisorDone + }() + + localServer := mcp.NewServer( + &mcp.Implementation{Name: "yoko (via stdio-proxy)", Version: proxyVersion}, + &mcp.ServerOptions{ + InitializedHandler: func(_ context.Context, req *mcp.InitializedRequest) { + localSession.Store(req.Session) + // Log the downstream client's declared capabilities so we know + // whether elicitation forwarding will work end to end. + if p := req.Session.InitializeParams(); p != nil { + hasElicit := p.Capabilities != nil && p.Capabilities.Elicitation != nil + name := "" + ver := "" + if p.ClientInfo != nil { + name = p.ClientInfo.Name + ver = p.ClientInfo.Version + } + log.Printf("mcp-stdio-proxy: downstream initialized name=%q version=%q elicitation=%v", name, ver, hasElicit) + } + }, + }, + ) + + for _, upstreamTool := range toolsResp.Tools { + tool := *upstreamTool + if tool.InputSchema == nil { + tool.InputSchema = map[string]any{"type": "object"} + } + localServer.AddTool(&tool, func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + session, err := upstream.awaitSession(ctx) + if err != nil { + var errResult mcp.CallToolResult + errResult.SetError(fmt.Errorf("upstream tool %q unavailable: %w", req.Params.Name, err)) + return &errResult, nil + } + result, err := session.CallTool(ctx, &mcp.CallToolParams{ + Meta: req.Params.Meta, + Name: req.Params.Name, + Arguments: req.Params.Arguments, + }) + if err != nil { + var errResult mcp.CallToolResult + errResult.SetError(fmt.Errorf("upstream tool %q failed: %w", req.Params.Name, err)) + return &errResult, nil + } + return result, nil + }) + } + + for _, upstreamResource := range resourcesResp.Resources { + resource := *upstreamResource + localServer.AddResource(&resource, func(ctx context.Context, req *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { + session, err := upstream.awaitSession(ctx) + if err != nil { + return nil, fmt.Errorf("upstream resource %q unavailable: %w", req.Params.URI, err) + } + result, err := session.ReadResource(ctx, req.Params) + if err != nil { + return nil, fmt.Errorf("upstream resource %q failed: %w", req.Params.URI, err) + } + return result, nil + }) + } + + if err := localServer.Run(ctx, opts.transport); err != nil { + return err + } + return nil +} + +// upstreamConn keeps a live MCP client session to the upstream router, dialing +// initially with backoff and reconnecting transparently when the session drops. +type upstreamConn struct { + client *mcp.Client + upstreamURL string + httpClient *http.Client + initialBackoff time.Duration + + mu sync.Mutex + session *mcp.ClientSession + ready chan struct{} +} + +func (u *upstreamConn) dial(ctx context.Context) (*mcp.ClientSession, error) { + return u.client.Connect(ctx, &mcp.StreamableClientTransport{ + Endpoint: u.upstreamURL, + HTTPClient: u.httpClient, + }, nil) +} + +// connectWithRetry dials the upstream, retrying with exponential backoff until +// the context is cancelled. +func (u *upstreamConn) connectWithRetry(ctx context.Context, label string) (*mcp.ClientSession, error) { + backoff := u.initialBackoff + if backoff == 0 { + backoff = initialReconnectBackoff + } + for attempt := 1; ; attempt++ { + s, err := u.dial(ctx) + if err == nil { + if attempt > 1 { + log.Printf("mcp-stdio-proxy: %s succeeded on attempt %d", label, attempt) + } + return s, nil + } + if ctxErr := ctx.Err(); ctxErr != nil { + return nil, ctxErr + } + log.Printf("mcp-stdio-proxy: %s attempt %d failed: %v; retrying in %s", label, attempt, err, backoff) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(backoff): + } + if backoff < maxReconnectBackoff { + backoff *= 2 + if backoff > maxReconnectBackoff { + backoff = maxReconnectBackoff + } + } + } +} + +// supervise watches the active upstream session and reconnects when it drops. +// Returns when ctx is cancelled. +func (u *upstreamConn) supervise(ctx context.Context, initial *mcp.ClientSession) { + cur := initial + for { + waitDone := make(chan struct{}) + go func(s *mcp.ClientSession) { + _ = s.Wait() + close(waitDone) + }(cur) + + select { + case <-ctx.Done(): + return + case <-waitDone: + } + if ctx.Err() != nil { + return + } + + log.Printf("mcp-stdio-proxy: upstream session closed; reconnecting...") + u.markUnready() + + next, err := u.connectWithRetry(ctx, "upstream reconnect") + if err != nil { + return + } + u.setSession(next) + log.Printf("mcp-stdio-proxy: upstream reconnected") + cur = next + } +} + +// awaitSession returns the current upstream session, blocking until one is +// available or ctx is cancelled. +func (u *upstreamConn) awaitSession(ctx context.Context) (*mcp.ClientSession, error) { + for { + u.mu.Lock() + if u.session != nil { + s := u.session + u.mu.Unlock() + return s, nil + } + ready := u.ready + u.mu.Unlock() + select { + case <-ready: + case <-ctx.Done(): + return nil, ctx.Err() + } + } +} + +// currentSession returns the current session without blocking. Used at shutdown +// to close whatever session is live. +func (u *upstreamConn) currentSession() *mcp.ClientSession { + u.mu.Lock() + defer u.mu.Unlock() + return u.session +} + +func (u *upstreamConn) setSession(s *mcp.ClientSession) { + u.mu.Lock() + defer u.mu.Unlock() + u.session = s + if u.ready != nil { + close(u.ready) + u.ready = nil + } +} + +func (u *upstreamConn) markUnready() { + u.mu.Lock() + defer u.mu.Unlock() + u.session = nil + if u.ready == nil { + u.ready = make(chan struct{}) + } +} diff --git a/demo/code-mode/mcp-stdio-proxy/main_test.go b/demo/code-mode/mcp-stdio-proxy/main_test.go new file mode 100644 index 0000000000..a3a6e2a9d8 --- /dev/null +++ b/demo/code-mode/mcp-stdio-proxy/main_test.go @@ -0,0 +1,386 @@ +package main + +import ( + "context" + "errors" + "net" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestProxyMirrorsUpstreamSurfaceAndForwardsElicitation(t *testing.T) { + tests := []struct { + name string + run func(context.Context, *testing.T, *mcp.ClientSession) + }{ + { + name: "list tools", + run: func(ctx context.Context, t *testing.T, session *mcp.ClientSession) { + resp, err := session.ListTools(ctx, &mcp.ListToolsParams{}) + require.NoError(t, err) + assert.Equal(t, &mcp.ListToolsResult{ + Tools: []*mcp.Tool{ + { + Name: "ask", + Description: "Ask for approval.", + InputSchema: map[string]any{ + "type": "object", + "additionalProperties": false, + }, + }, + { + Name: "echo", + Description: "Echo the input.", + InputSchema: map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, + }, + }, resp) + }, + }, + { + name: "call echo", + run: func(ctx context.Context, t *testing.T, session *mcp.ClientSession) { + resp, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "echo", + Arguments: map[string]any{"x": 1}, + }) + require.NoError(t, err) + assert.Equal(t, &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: `{"x":1}`}}, + StructuredContent: map[string]any{"x": float64(1)}, + }, resp) + }, + }, + { + name: "list resources", + run: func(ctx context.Context, t *testing.T, session *mcp.ClientSession) { + resp, err := session.ListResources(ctx, &mcp.ListResourcesParams{}) + require.NoError(t, err) + assert.Equal(t, &mcp.ListResourcesResult{ + Resources: []*mcp.Resource{ + { + URI: "demo://hello", + Name: "hello", + Title: "Hello", + MIMEType: "text/plain", + }, + }, + }, resp) + }, + }, + { + name: "read resource", + run: func(ctx context.Context, t *testing.T, session *mcp.ClientSession) { + resp, err := session.ReadResource(ctx, &mcp.ReadResourceParams{URI: "demo://hello"}) + require.NoError(t, err) + assert.Equal(t, &mcp.ReadResourceResult{ + Contents: []*mcp.ResourceContents{ + { + URI: "demo://hello", + MIMEType: "text/plain", + Text: "hi", + }, + }, + }, resp) + }, + }, + { + name: "call ask forwards elicitation", + run: func(ctx context.Context, t *testing.T, session *mcp.ClientSession) { + resp, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "ask", + Arguments: map[string]any{}, + }) + require.NoError(t, err) + assert.Equal(t, &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: `{"approved":true}`}}, + StructuredContent: map[string]any{"approved": true}, + }, resp) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + upstream := newTestUpstream(t) + + serverTransport, clientTransport := mcp.NewInMemoryTransports() + errCh := make(chan error, 1) + go func() { + errCh <- runProxy(ctx, proxyOptions{ + upstreamURL: upstream.URL, + transport: serverTransport, + httpClient: upstream.Client(), + }) + }() + + client := mcp.NewClient(&mcp.Implementation{Name: "test-client", Version: "0.1.0"}, &mcp.ClientOptions{ + ElicitationHandler: func(context.Context, *mcp.ElicitRequest) (*mcp.ElicitResult, error) { + return &mcp.ElicitResult{ + Action: "accept", + Content: map[string]any{"approved": true}, + }, nil + }, + }) + session, err := client.Connect(ctx, clientTransport, nil) + require.NoError(t, err) + defer func() { + require.NoError(t, session.Close()) + err := <-errCh + if !errors.Is(err, context.Canceled) { + require.NoError(t, err) + } + }() + + tt.run(ctx, t, session) + }) + } +} + +func TestProxyReconnectsAfterUpstreamDisconnect(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + server := mcp.NewServer(&mcp.Implementation{Name: "test-upstream", Version: "0.1.0"}, nil) + server.AddTool(&mcp.Tool{ + Name: "echo", + Description: "Echo the input.", + InputSchema: map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, func(_ context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: string(req.Params.Arguments)}}, + StructuredContent: req.Params.Arguments, + }, nil + }) + mcpHandler := mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server { + return server + }, nil) + + // Switchable handler: when "off", every request returns 503 so both the + // keepalive ping on the live session and any reconnect dials fail. + var upstreamUp atomic.Bool + upstreamUp.Store(true) + httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !upstreamUp.Load() { + http.Error(w, "upstream off", http.StatusServiceUnavailable) + return + } + mcpHandler.ServeHTTP(w, r) + })) + defer httpServer.Close() + + serverTransport, clientTransport := mcp.NewInMemoryTransports() + errCh := make(chan error, 1) + go func() { + errCh <- runProxy(ctx, proxyOptions{ + upstreamURL: httpServer.URL, + transport: serverTransport, + httpClient: httpServer.Client(), + keepAlive: 100 * time.Millisecond, + initialBackoff: 50 * time.Millisecond, + }) + }() + + client := mcp.NewClient(&mcp.Implementation{Name: "test-client", Version: "0.1.0"}, nil) + session, err := client.Connect(ctx, clientTransport, nil) + require.NoError(t, err) + defer func() { + require.NoError(t, session.Close()) + err := <-errCh + if !errors.Is(err, context.Canceled) { + require.NoError(t, err) + } + }() + + resp, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "echo", + Arguments: map[string]any{"x": 1}, + }) + require.NoError(t, err) + assert.Equal(t, &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: `{"x":1}`}}, + StructuredContent: map[string]any{"x": float64(1)}, + }, resp) + + upstreamUp.Store(false) + time.Sleep(400 * time.Millisecond) + upstreamUp.Store(true) + + require.Eventually(t, func() bool { + resp, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "echo", + Arguments: map[string]any{"x": 2}, + }) + if err != nil { + return false + } + return assert.ObjectsAreEqual(&mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: `{"x":2}`}}, + StructuredContent: map[string]any{"x": float64(2)}, + }, resp) + }, 10*time.Second, 100*time.Millisecond, "expected proxy to reconnect and serve calls") +} + +func newTestUpstream(t *testing.T) *httptest.Server { + t.Helper() + + server := mcp.NewServer(&mcp.Implementation{Name: "test-upstream", Version: "0.1.0"}, nil) + server.AddTool(&mcp.Tool{ + Name: "echo", + Description: "Echo the input.", + InputSchema: map[string]any{ + "type": "object", + "additionalProperties": true, + }, + }, func(_ context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: string(req.Params.Arguments)}}, + StructuredContent: req.Params.Arguments, + }, nil + }) + server.AddTool(&mcp.Tool{ + Name: "ask", + Description: "Ask for approval.", + InputSchema: map[string]any{ + "type": "object", + "additionalProperties": false, + }, + }, func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + result, err := req.Session.Elicit(ctx, &mcp.ElicitParams{ + Message: "Approve mutation?", + RequestedSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "approved": map[string]any{"type": "boolean"}, + }, + }, + }) + if err != nil { + return nil, err + } + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: `{"approved":true}`}}, + StructuredContent: result.Content, + }, nil + }) + server.AddResource(&mcp.Resource{ + URI: "demo://hello", + Name: "hello", + Title: "Hello", + MIMEType: "text/plain", + }, func(context.Context, *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { + return &mcp.ReadResourceResult{ + Contents: []*mcp.ResourceContents{ + { + URI: "demo://hello", + MIMEType: "text/plain", + Text: "hi", + }, + }, + }, nil + }) + + mux := http.NewServeMux() + mux.Handle("/", mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server { + return server + }, nil)) + + listener := newPipeListener() + t.Cleanup(func() { + require.NoError(t, listener.Close()) + }) + + httpServer := &httptest.Server{ + Listener: listener, + Config: &http.Server{ + Handler: mux, + BaseContext: func(net.Listener) context.Context { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + return ctx + }, + }, + } + httpServer.Start() + t.Cleanup(httpServer.Close) + httpServer.Client().Transport = &http.Transport{ + DialContext: listener.DialContext, + } + return httpServer +} + +type pipeListener struct { + conns chan net.Conn + done chan struct{} +} + +func newPipeListener() *pipeListener { + return &pipeListener{ + conns: make(chan net.Conn), + done: make(chan struct{}), + } +} + +func (l *pipeListener) Accept() (net.Conn, error) { + select { + case conn := <-l.conns: + return conn, nil + case <-l.done: + return nil, net.ErrClosed + } +} + +func (l *pipeListener) Close() error { + select { + case <-l.done: + default: + close(l.done) + } + return nil +} + +func (l *pipeListener) Addr() net.Addr { + return pipeAddr("pipe") +} + +func (l *pipeListener) DialContext(ctx context.Context, _, _ string) (net.Conn, error) { + serverConn, clientConn := net.Pipe() + select { + case l.conns <- serverConn: + return clientConn, nil + case <-ctx.Done(): + _ = serverConn.Close() + _ = clientConn.Close() + return nil, ctx.Err() + case <-l.done: + _ = serverConn.Close() + _ = clientConn.Close() + return nil, net.ErrClosed + } +} + +type pipeAddr string + +func (a pipeAddr) Network() string { + return "pipe" +} + +func (a pipeAddr) String() string { + return string(a) +} diff --git a/demo/code-mode/router-config.yaml b/demo/code-mode/router-config.yaml new file mode 100644 index 0000000000..01fac390a2 --- /dev/null +++ b/demo/code-mode/router-config.yaml @@ -0,0 +1,56 @@ +version: "1" + +listen_addr: "localhost:3002" +graphql_path: "/graphql" +playground_enabled: false +json_log: false +log_level: info +dev_mode: true +router_registration: false + +execution_config: + file: + path: "demo/code-mode/config.json" + watch: false + +graphql_metrics: + enabled: false + +telemetry: + tracing: + enabled: false + metrics: + otlp: + enabled: false + prometheus: + enabled: false + +mcp: + enabled: false + graph_name: code-mode-demo + router_url: http://localhost:3002/graphql + session: + stateless: false + code_mode: + enabled: true + server: + # Bind IPv4 explicitly. On macOS, "localhost:5027" binds only IPv4 + # but clients that resolve "localhost" to ::1 first (Go's resolver, + # the MCP stdio proxy) get refused — point them at 127.0.0.1 directly + # in start.sh and the proxy defaults. + listen_addr: 127.0.0.1:5027 + require_mutation_approval: true + # Sandbox wall-clock cap. Default is 5s (plan §13), which is fine for + # compute-only agent code but too short whenever the host blocks the JS + # thread on an interactive elicitation. Bump to 180s so a human can review + # a mutation prompt without the qjs runtime context expiring under us. + sandbox: + timeout: 180s + query_generation: + enabled: true + endpoint: http://localhost:5028 + timeout: 180s + execute_timeout: 180s + named_ops: + enabled: true + # storage.provider_id intentionally unset -> in-memory backend (the default) diff --git a/demo/code-mode/run_subgraphs_subset.sh b/demo/code-mode/run_subgraphs_subset.sh new file mode 100755 index 0000000000..23e2c20ec3 --- /dev/null +++ b/demo/code-mode/run_subgraphs_subset.sh @@ -0,0 +1,13 @@ +#!/bin/sh + +set -eu + +cd "$(dirname "$0")/.." +GOCACHE="${GOCACHE:-/tmp/cosmo-code-mode-go-build-cache}" +mkdir -p "$GOCACHE" + +npx concurrently --kill-others \ + "GOCACHE=$GOCACHE PORT=4001 go run ./cmd/employees" \ + "GOCACHE=$GOCACHE PORT=4002 go run ./cmd/family" \ + "GOCACHE=$GOCACHE PORT=4007 go run ./cmd/availability" \ + "GOCACHE=$GOCACHE PORT=4008 go run ./cmd/mood" diff --git a/demo/code-mode/start.sh b/demo/code-mode/start.sh new file mode 100755 index 0000000000..c079e1d1db --- /dev/null +++ b/demo/code-mode/start.sh @@ -0,0 +1,149 @@ +#!/usr/bin/env bash + +set -Eeuo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" +DEMO_DIR="$ROOT_DIR/demo" +CODE_MODE_DIR="$DEMO_DIR/code-mode" +PID_FILE="/tmp/cosmo-code-mode-demo.pids" +LOG_DIR="/tmp/cosmo-code-mode-demo-logs" +GOCACHE_DIR="${GOCACHE:-/tmp/cosmo-code-mode-go-build-cache}" + +ROUTER_BIN="$ROOT_DIR/router/router" +ROUTER_CONFIG="$CODE_MODE_DIR/router-config.yaml" +YOKO_BIN="$CODE_MODE_DIR/yoko-mock/yoko-mock" + +append_pid() { + local name="$1" + local pid="$2" + printf '%s %s\n' "$name" "$pid" >> "$PID_FILE" +} + +kill_pid_file() { + if [ ! -f "$PID_FILE" ]; then + echo "No Code Mode demo PID file found at $PID_FILE" + return 0 + fi + + while read -r name pid; do + [ -n "${pid:-}" ] || continue + if kill -0 "$pid" 2>/dev/null; then + echo "Stopping $name pid=$pid" + kill "$pid" 2>/dev/null || true + fi + done < "$PID_FILE" + + sleep 1 + + while read -r name pid; do + [ -n "${pid:-}" ] || continue + if kill -0 "$pid" 2>/dev/null; then + echo "Force stopping $name pid=$pid" + kill -KILL "$pid" 2>/dev/null || true + fi + done < "$PID_FILE" + + rm -f "$PID_FILE" +} + +cleanup() { + local status=$? + trap - EXIT INT TERM + kill_pid_file + exit "$status" +} + +wait_url() { + local name="$1" + local url="$2" + local timeout="${3:-90}" + local start + start="$(date +%s)" + + while true; do + if curl -fsS "$url" >/dev/null 2>&1; then + echo "$name is ready at $url" + return 0 + fi + + if [ "$(( $(date +%s) - start ))" -ge "$timeout" ]; then + echo "Timed out waiting for $name at $url" >&2 + echo "Logs are in $LOG_DIR" >&2 + return 1 + fi + + sleep 1 + done +} + +start_background() { + local name="$1" + local cwd="$2" + shift 2 + + echo "Starting $name" + ( + cd "$cwd" + "$@" + ) > "$LOG_DIR/$name.log" 2>&1 & + append_pid "$name" "$!" +} + +start_background_root() { + local name="$1" + shift + + echo "Starting $name" + ( + cd "$ROOT_DIR" + "$@" + ) > "$LOG_DIR/$name.log" 2>&1 & + append_pid "$name" "$!" +} + +if [ "${1:-}" = "--down" ]; then + kill_pid_file + exit 0 +fi + +if [ ! -x "$ROUTER_BIN" ]; then + echo "Router binary not found or not executable: $ROUTER_BIN" >&2 + echo "Run: cd router && make build" >&2 + exit 1 +fi + +if [ ! -x "$YOKO_BIN" ]; then + echo "Yoko mock binary not found or not executable: $YOKO_BIN" >&2 + echo "Run: cd demo/code-mode/yoko-mock && go build -o yoko-mock ." >&2 + exit 1 +fi + +if [ ! -f "$CODE_MODE_DIR/config.json" ]; then + echo "Composed router config not found: $CODE_MODE_DIR/config.json" >&2 + echo "Run: make -C demo/code-mode compose" >&2 + exit 1 +fi + +mkdir -p "$LOG_DIR" +mkdir -p "$GOCACHE_DIR" +rm -f "$PID_FILE" +trap cleanup EXIT INT TERM + +start_background employees "$DEMO_DIR" env GOCACHE="$GOCACHE_DIR" PORT=4001 go run ./cmd/employees +start_background family "$DEMO_DIR" env GOCACHE="$GOCACHE_DIR" PORT=4002 go run ./cmd/family +start_background availability "$DEMO_DIR" env GOCACHE="$GOCACHE_DIR" PORT=4007 go run ./cmd/availability +start_background mood "$DEMO_DIR" env GOCACHE="$GOCACHE_DIR" PORT=4008 go run ./cmd/mood +start_background_root yoko "$YOKO_BIN" -listen-addr localhost:5028 + +wait_url employees http://localhost:4001/ +wait_url family http://localhost:4002/ +wait_url availability http://localhost:4007/ +wait_url mood http://localhost:4008/ +wait_url yoko http://localhost:5028/health + +echo "Starting router in foreground" +"$ROUTER_BIN" -config "$ROUTER_CONFIG" & +router_pid="$!" +append_pid router "$router_pid" + +wait "$router_pid" diff --git a/demo/code-mode/yoko-mock/.gitignore b/demo/code-mode/yoko-mock/.gitignore new file mode 100644 index 0000000000..f3e6959ad6 --- /dev/null +++ b/demo/code-mode/yoko-mock/.gitignore @@ -0,0 +1,3 @@ +yoko-mock +bench +cmd/bench/bench diff --git a/demo/code-mode/yoko-mock/README.md b/demo/code-mode/yoko-mock/README.md new file mode 100644 index 0000000000..c688b43f6d --- /dev/null +++ b/demo/code-mode/yoko-mock/README.md @@ -0,0 +1,46 @@ +# Yoko Mock + +This is a demo implementation of the Code Mode `YokoService` Connect RPC. It indexes a supergraph SDL in memory, then shells out to the host `codex` CLI to generate GraphQL operations for natural-language prompts. + +## Run + +From the repository root: + +```sh +go run ./demo/code-mode/yoko-mock --listen-addr :5028 +``` + +Flags: + +- `--listen-addr` defaults to `localhost:5028`. +- `--codex-bin` defaults to `codex` and is resolved through `PATH` unless an absolute path is supplied. +- `--codex-timeout` defaults to `60s`. + +The service calls: + +```sh +codex exec --full-auto --skip-git-repo-check - +``` + +with the generated prompt on stdin. The host must have a real `codex` CLI installed and authenticated. + +## Behavior + +- `POST /wundergraph.cosmo.code_mode.yoko.v1.YokoService/Index` stores the SDL in memory and returns `schema_id`, the first 16 hex characters of `sha256(schema_sdl)`. +- `POST /wundergraph.cosmo.code_mode.yoko.v1.YokoService/Search` looks up `schema_id`, invokes `codex`, parses its stdout as a JSON array, and returns the generated operations without local deduping or ranking. +- `/health` returns `200 OK`. + +If `Search` receives an unknown `schema_id`, it returns Connect `NOT_FOUND`; the router client is expected to re-index and retry once. If `codex` returns invalid JSON, the service logs a warning, writes the raw stdout to `/tmp/yoko-mock-last-bad-output.log`, and returns Connect `INTERNAL`. + +Expected codex stdout: + +```json +[ + { + "name": "getViewer", + "body": "query getViewer { viewer { id } }", + "kind": "query", + "description": "Fetches the current viewer." + } +] +``` diff --git a/demo/code-mode/yoko-mock/go.mod b/demo/code-mode/yoko-mock/go.mod new file mode 100644 index 0000000000..807baae723 --- /dev/null +++ b/demo/code-mode/yoko-mock/go.mod @@ -0,0 +1,22 @@ +module github.com/wundergraph/cosmo/demo/code-mode/yoko-mock + +go 1.25.0 + +require ( + connectrpc.com/connect v1.19.1 + github.com/dgraph-io/ristretto/v2 v2.4.0 + github.com/stretchr/testify v1.11.1 + github.com/wundergraph/cosmo/router v0.0.0 + google.golang.org/protobuf v1.36.10 +) + +require ( + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/sys v0.40.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + +replace github.com/wundergraph/cosmo/router => ../../../router diff --git a/demo/code-mode/yoko-mock/go.sum b/demo/code-mode/yoko-mock/go.sum new file mode 100644 index 0000000000..e60cb737e8 --- /dev/null +++ b/demo/code-mode/yoko-mock/go.sum @@ -0,0 +1,26 @@ +connectrpc.com/connect v1.19.1 h1:R5M57z05+90EfEvCY1b7hBxDVOUl45PrtXtAV2fOC14= +connectrpc.com/connect v1.19.1/go.mod h1:tN20fjdGlewnSFeZxLKb0xwIZ6ozc3OQs2hTXy4du9w= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgraph-io/ristretto/v2 v2.4.0 h1:I/w09yLjhdcVD2QV192UJcq8dPBaAJb9pOuMyNy0XlU= +github.com/dgraph-io/ristretto/v2 v2.4.0/go.mod h1:0KsrXtXvnv0EqnzyowllbVJB8yBonswa2lTCK2gGo9E= +github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da h1:aIftn67I1fkbMa512G+w+Pxci9hJPB8oMnkcP3iZF38= +github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= +golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= +google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/demo/code-mode/yoko-mock/main.go b/demo/code-mode/yoko-mock/main.go new file mode 100644 index 0000000000..3a412fe48f --- /dev/null +++ b/demo/code-mode/yoko-mock/main.go @@ -0,0 +1,583 @@ +package main + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/json" + "errors" + "flag" + "fmt" + "log" + "net/http" + "os" + "os/exec" + "os/signal" + "path/filepath" + "strings" + "sync" + "sync/atomic" + "syscall" + "time" + + "connectrpc.com/connect" + "github.com/dgraph-io/ristretto/v2" + yokov1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1" + "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1/yokov1connect" +) + +const badOutputPath = "/tmp/yoko-mock-last-bad-output.log" + +type yokoService struct { + codexBin string + codexTimeout time.Duration + codexReasoningEffort string + rotateAfter int // re-warm the codex session after this many Search calls; 0 disables + + // promptCache memoizes (schemaID, prompt) -> GeneratedOperation. A cache + // hit lets us skip codex entirely for that prompt. nil if the cache is + // disabled (size <= 0). + promptCache *ristretto.Cache[string, *yokov1.GeneratedOperation] + + mu sync.RWMutex + schemas map[string]*schemaEntry +} + +// schemaEntry records the on-disk schema dir (so codex can read schema.graphql +// once at Index time) plus the codex session id created during that pre-warm. +// Search uses `codex exec resume ` to reuse the already-loaded +// schema context instead of re-reading it on every call. +// +// To bound session-file growth, every yokoService.rotateAfter Search calls a +// background goroutine pre-warms a fresh session and atomically swaps the +// sessionID. searchCount tracks calls; rotationActive ensures only one +// rotation runs at a time. +type schemaEntry struct { + dir string + + mu sync.RWMutex + sessionID string + + searchCount atomic.Int64 + rotationActive atomic.Bool +} + +type codexOperation struct { + Name string `json:"name"` + Body string `json:"body"` + Kind string `json:"kind"` + Description string `json:"description"` +} + +type codexOutput struct { + Operations []codexOperation `json:"operations"` +} + +func main() { + listenAddr := flag.String("listen-addr", "localhost:5028", "address for the Yoko mock HTTP server") + codexBin := flag.String("codex-bin", "codex", "codex CLI binary path or name") + codexTimeout := flag.Duration("codex-timeout", 60*time.Second, "codex CLI timeout") + codexReasoningEffort := flag.String("codex-reasoning-effort", "low", "codex reasoning effort: minimal | low | medium | high") + codexRotateAfter := flag.Int("codex-rotate-after", 20, "re-warm the codex session after N Search calls (0 = disable rotation)") + promptCacheSize := flag.Int("prompt-cache-size", 1000, "max items in the (schema_id, prompt) -> operation cache (0 = disable)") + flag.Parse() + + svc, err := newYokoService(*codexBin, *codexTimeout, *codexReasoningEffort, *codexRotateAfter, *promptCacheSize) + if err != nil { + log.Fatalf("create yoko service: %v", err) + } + defer svc.Close() + server := &http.Server{ + Addr: *listenAddr, + Handler: newHTTPMux(svc), + } + + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + errCh := make(chan error, 1) + go func() { + log.Printf("yoko mock listening addr=%s codex_bin=%s codex_timeout=%s reasoning_effort=%s", *listenAddr, *codexBin, codexTimeout.String(), *codexReasoningEffort) + errCh <- server.ListenAndServe() + }() + + select { + case <-ctx.Done(): + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := server.Shutdown(shutdownCtx); err != nil { + log.Fatalf("server shutdown failed: %v", err) + } + case err := <-errCh: + if err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Fatalf("server failed: %v", err) + } + } +} + +func newYokoService(codexBin string, codexTimeout time.Duration, reasoningEffort string, rotateAfter, promptCacheSize int) (*yokoService, error) { + svc := &yokoService{ + codexBin: codexBin, + codexTimeout: codexTimeout, + codexReasoningEffort: reasoningEffort, + rotateAfter: rotateAfter, + schemas: make(map[string]*schemaEntry), + } + if promptCacheSize > 0 { + // Each cache entry has cost 1, so MaxCost is the item ceiling. + // NumCounters is conventionally 10× expected items. + cache, err := ristretto.NewCache(&ristretto.Config[string, *yokov1.GeneratedOperation]{ + NumCounters: int64(promptCacheSize) * 10, + MaxCost: int64(promptCacheSize), + BufferItems: 64, + }) + if err != nil { + return nil, fmt.Errorf("create prompt cache: %w", err) + } + svc.promptCache = cache + } + return svc, nil +} + +func newHTTPMux(svc *yokoService) *http.ServeMux { + mux := http.NewServeMux() + mux.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("OK\n")) + }) + path, handler := yokov1connect.NewYokoServiceHandler(svc) + mux.Handle(path, handler) + return mux +} + +func (s *yokoService) Index(ctx context.Context, req *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { + schemaSDL := req.Msg.GetSchemaSdl() + id := schemaID(schemaSDL) + + s.mu.Lock() + if existing, ok := s.schemas[id]; ok { + s.mu.Unlock() + existing.mu.RLock() + existingSession := existing.sessionID + existing.mu.RUnlock() + log.Printf("Index schema_id=%s reused dir=%s session_id=%s", id, existing.dir, existingSession) + return connect.NewResponse(&yokov1.IndexResponse{SchemaId: id}), nil + } + s.mu.Unlock() + + dir, err := os.MkdirTemp("", "yoko-schema-"+id+"-") + if err != nil { + return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("create schema temp dir: %w", err)) + } + if err := os.WriteFile(filepath.Join(dir, "schema.graphql"), []byte(schemaSDL), 0o600); err != nil { + _ = os.RemoveAll(dir) + return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("write schema.graphql: %w", err)) + } + + sessionID, err := s.runCodexIndex(ctx, dir) + if err != nil { + _ = os.RemoveAll(dir) + return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("codex pre-warm: %w", err)) + } + + entry := &schemaEntry{dir: dir, sessionID: sessionID} + s.mu.Lock() + s.schemas[id] = entry + s.mu.Unlock() + + log.Printf("Index schema_id=%s schema_sdl_size=%d schema_dir=%s session_id=%s rotate_after=%d", id, len(schemaSDL), dir, sessionID, s.rotateAfter) + return connect.NewResponse(&yokov1.IndexResponse{SchemaId: id}), nil +} + +// Close removes every per-schema temp dir created by Index. Safe to call +// multiple times; subsequent calls are no-ops. Codex session rollout files +// live under ~/.codex/sessions/ and are intentionally left in place — they +// belong to the user's codex install. +func (s *yokoService) Close() { + s.mu.Lock() + defer s.mu.Unlock() + for id, entry := range s.schemas { + if err := os.RemoveAll(entry.dir); err != nil { + log.Printf("Close schema_id=%s dir=%s err=%v", id, entry.dir, err) + continue + } + log.Printf("Close schema_id=%s dir=%s removed", id, entry.dir) + } + s.schemas = nil + if s.promptCache != nil { + s.promptCache.Close() + s.promptCache = nil + } +} + +func (s *yokoService) Search(ctx context.Context, req *connect.Request[yokov1.SearchRequest]) (*connect.Response[yokov1.SearchResponse], error) { + schemaID := req.Msg.GetSchemaId() + prompts := req.Msg.GetPrompts() + + s.mu.RLock() + entry, ok := s.schemas[schemaID] + s.mu.RUnlock() + if !ok { + log.Printf("Search schema_id=%s prompt_count=%d not_found=true", schemaID, len(prompts)) + return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("schema_id %q not found; call Index before Search", schemaID)) + } + + // Bump per-session call counter; if we crossed the threshold and no + // rotation is in flight, kick one off in the background. The CAS makes + // the trigger one-shot until rotation completes and clears the flag. + count := entry.searchCount.Add(1) + if s.rotateAfter > 0 && count >= int64(s.rotateAfter) && entry.rotationActive.CompareAndSwap(false, true) { + go s.rotateSession(schemaID, entry, count) + } + + // Cache lookup: collect cached ops in their original positions, batch + // only the misses to codex. + results := make([]*yokov1.GeneratedOperation, len(prompts)) + missing := make([]string, 0, len(prompts)) + missingIdx := make([]int, 0, len(prompts)) + hits := 0 + for i, p := range prompts { + if op, ok := s.cacheGet(schemaID, p); ok { + results[i] = op + hits++ + } else { + missing = append(missing, p) + missingIdx = append(missingIdx, i) + } + } + + if len(missing) == 0 { + log.Printf("Search schema_id=%s prompt_count=%d cache_hits=%d cache_misses=0 codex_skipped=true", schemaID, len(prompts), hits) + return connect.NewResponse(&yokov1.SearchResponse{Operations: filterNonNil(results)}), nil + } + + entry.mu.RLock() + sessionID := entry.sessionID + entry.mu.RUnlock() + + prompt := buildCodexPrompt(missing) + stdout, err := s.runCodexResume(ctx, sessionID, prompt) + if err != nil { + return nil, connect.NewError(connect.CodeInternal, err) + } + + generated, err := parseCodexOperations(stdout) + if err != nil { + if writeErr := os.WriteFile(badOutputPath, stdout, 0o600); writeErr != nil { + log.Printf("warning: failed to write bad codex output path=%s err=%v", badOutputPath, writeErr) + } + log.Printf("warning: codex output was not valid JSON schema_id=%s prompt_count=%d stdout_size=%d err=%v", schemaID, len(missing), len(stdout), err) + return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("codex output was not valid JSON; raw output saved to %s", badOutputPath)) + } + + // Pair generated ops back into the original prompt slots and cache the + // successful ones. We trust order: codex was instructed to return one + // operation per missing prompt in the same order. If codex returned + // fewer ops than asked, the trailing prompts have no slot filled (and + // don't get cached). + for k, idx := range missingIdx { + if k >= len(generated) { + break + } + op := generated[k] + if op == nil || op.GetBody() == "" { + // Failed prompt — don't cache, leave slot nil (filtered out below). + continue + } + results[idx] = op + s.cachePut(schemaID, missing[k], op) + } + + log.Printf("Search schema_id=%s prompt_count=%d cache_hits=%d cache_misses=%d codex_stdout_size=%d parsed_op_count=%d", schemaID, len(prompts), hits, len(missing), len(stdout), len(generated)) + return connect.NewResponse(&yokov1.SearchResponse{Operations: filterNonNil(results)}), nil +} + +func filterNonNil(ops []*yokov1.GeneratedOperation) []*yokov1.GeneratedOperation { + out := ops[:0] + for _, op := range ops { + if op != nil { + out = append(out, op) + } + } + return out +} + +// cacheKey returns the (schema_id, prompt) lookup key. We include schema_id +// so the same prompt against a different supergraph doesn't return a stale +// operation. +func cacheKey(schemaID, prompt string) string { + return schemaID + "\x00" + prompt +} + +func (s *yokoService) cacheGet(schemaID, prompt string) (*yokov1.GeneratedOperation, bool) { + if s.promptCache == nil { + return nil, false + } + return s.promptCache.Get(cacheKey(schemaID, prompt)) +} + +func (s *yokoService) cachePut(schemaID, prompt string, op *yokov1.GeneratedOperation) { + if s.promptCache == nil { + return + } + s.promptCache.Set(cacheKey(schemaID, prompt), op, 1) +} + +// rotateSession is launched in a goroutine when Search counts cross +// rotateAfter. It pre-warms a fresh codex session against the same on-disk +// schema, then atomically swaps in the new sessionID and resets the search +// counter. While rotation is running, concurrent Search calls keep using the +// old sessionID — they just don't trigger a second rotation. +func (s *yokoService) rotateSession(schemaID string, entry *schemaEntry, triggerCount int64) { + start := time.Now() + log.Printf("rotation kickoff schema_id=%s trigger_count=%d", schemaID, triggerCount) + + ctx, cancel := context.WithTimeout(context.Background(), s.codexTimeout) + defer cancel() + + newSessionID, err := s.runCodexIndex(ctx, entry.dir) + if err != nil { + log.Printf("rotation failed schema_id=%s elapsed=%s err=%v", schemaID, time.Since(start).Round(time.Millisecond), err) + entry.rotationActive.Store(false) + return + } + + entry.mu.Lock() + oldSessionID := entry.sessionID + entry.sessionID = newSessionID + entry.mu.Unlock() + + // Reset count BEFORE clearing rotationActive so a Search arriving in this + // gap can't trigger a second rotation on a freshly-rotated session. + entry.searchCount.Store(0) + entry.rotationActive.Store(false) + + log.Printf("rotation complete schema_id=%s old_session=%s new_session=%s elapsed=%s", schemaID, oldSessionID, newSessionID, time.Since(start).Round(time.Millisecond)) +} + +func schemaID(schemaSDL string) string { + sum := sha256.Sum256([]byte(schemaSDL)) + return fmt.Sprintf("%x", sum)[:16] +} + +const indexCodexPrompt = `Read the COMPLETE content of the file ./schema.graphql in your current working directory using your file-reading tool. Read the ENTIRE file (it is approximately 17KB and 824 lines) — do not truncate, do not skim, do not read only a portion. The file is a federated GraphQL supergraph SDL. + +Once the full schema is loaded into your context, output exactly this JSON object and nothing else: + +{"ready":true} + +Do not include preamble, prose, markdown fences, or commentary.` + +func buildCodexPrompt(prompts []string) string { + var b strings.Builder + b.WriteString("You already loaded the federated GraphQL supergraph SDL from\n") + b.WriteString("./schema.graphql earlier in this session. Use it as the source of\n") + b.WriteString("truth — do not re-read the file.\n\n") + b.WriteString("For each user prompt below, generate ONE corresponding GraphQL\n") + b.WriteString("operation (query or mutation) that fulfills the prompt against\n") + b.WriteString("the schema. Return one operation per prompt, in the same order.\n\n") + b.WriteString("PARAMETERIZATION REQUIREMENT (load-bearing):\n") + b.WriteString("Whenever an argument's value depends on the caller's intent (an id,\n") + b.WriteString("a filter, a name, a tag, a limit, etc.), you MUST declare a GraphQL\n") + b.WriteString("variable for it and reference it via $varName. NEVER inline a literal\n") + b.WriteString("for caller-controlled arguments.\n") + b.WriteString("Example query: query employeeByID($id: Int!) { employee(id: $id) { id details { forename surname } } }\n") + b.WriteString("Example mutation: mutation updateEmployeeTag($id: Int!, $tag: String!) { updateEmployeeTag(id: $id, tag: $tag) { id tag } }\n") + b.WriteString("Only inline a literal when the argument is genuinely fixed by the prompt\n") + b.WriteString("(for example, 'list ALL employees' might pass no args at all). Variable\n") + b.WriteString("types must match the schema, including non-null bangs.\n\n") + b.WriteString("OUTPUT FORMAT (strict, machine-parsed):\n") + b.WriteString("- Output a single JSON object with one key: \"operations\" (array).\n") + b.WriteString("- Each operation has keys: name (camelCase), body (operation\n") + b.WriteString(" source text starting with 'query (...)' or\n") + b.WriteString(" 'mutation (...)' when variables are declared, or\n") + b.WriteString(" 'query { ... }' / 'mutation { ... }' when truly\n") + b.WriteString(" variable-free), kind ('query' or 'mutation'), description\n") + b.WriteString(" (one short sentence).\n") + b.WriteString("- operations.length MUST equal the number of user prompts below,\n") + b.WriteString(" in the same order.\n") + b.WriteString("- No prose, no preamble, no markdown fences.\n\n") + b.WriteString("USER PROMPTS:\n") + for _, prompt := range prompts { + b.WriteString("- ") + b.WriteString(prompt) + b.WriteByte('\n') + } + return b.String() +} + +// runCodexIndex performs the one-time pre-warm: codex reads schema.graphql in +// schemaDir and a session is started. The session id (UUID) is parsed from +// codex's first JSONL event and returned so subsequent Search calls can resume +// the same session. +func (s *yokoService) runCodexIndex(ctx context.Context, schemaDir string) (string, error) { + cmdCtx, cancel := context.WithTimeout(ctx, s.codexTimeout) + defer cancel() + + args := []string{ + "exec", + "--json", + "-s", "read-only", + "--skip-git-repo-check", + "--ignore-user-config", + "--ignore-rules", + "-c", "model_reasoning_effort=" + s.codexReasoningEffort, + "-c", "approval_policy=never", + "-", + } + + start := time.Now() + cmd := exec.CommandContext(cmdCtx, s.codexBin, args...) + cmd.Dir = schemaDir + cmd.Stdin = strings.NewReader(indexCodexPrompt) + + var stderr bytes.Buffer + cmd.Stderr = &stderr + + stdout, err := cmd.Output() + elapsed := time.Since(start) + exitCode := 0 + if err != nil { + exitCode = -1 + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + exitCode = exitErr.ExitCode() + } + } + log.Printf("codex index duration=%s exit_code=%d stdout_prefix=%q stderr_prefix=%q", elapsed.Round(time.Millisecond), exitCode, prefix(stdout, 160), prefix(stderr.Bytes(), 160)) + + if cmdCtx.Err() != nil { + return "", fmt.Errorf("codex index timed out after %s", s.codexTimeout) + } + if err != nil { + return "", fmt.Errorf("codex index failed exit_code=%d stderr=%q: %w", exitCode, prefix(stderr.Bytes(), 300), err) + } + + return parseThreadID(stdout) +} + +// runCodexResume resumes the previously-warmed session and runs the user +// prompts. The agent's last message (a JSON object of operations) is captured +// via `--output-last-message` and returned for parsing. +func (s *yokoService) runCodexResume(ctx context.Context, sessionID, prompt string) ([]byte, error) { + cmdCtx, cancel := context.WithTimeout(ctx, s.codexTimeout) + defer cancel() + + outFile, err := os.CreateTemp("", "yoko-search-out-*.txt") + if err != nil { + return nil, fmt.Errorf("create output temp file: %w", err) + } + outPath := outFile.Name() + _ = outFile.Close() + defer os.Remove(outPath) + + args := []string{ + "exec", "resume", sessionID, + "-o", outPath, + "--skip-git-repo-check", + "--ignore-user-config", + "--ignore-rules", + "-c", "model_reasoning_effort=" + s.codexReasoningEffort, + "-c", "approval_policy=never", + "-", + } + + start := time.Now() + cmd := exec.CommandContext(cmdCtx, s.codexBin, args...) + cmd.Stdin = strings.NewReader(prompt) + + var stderr bytes.Buffer + cmd.Stderr = &stderr + + err = cmd.Run() + elapsed := time.Since(start) + exitCode := 0 + if err != nil { + exitCode = -1 + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + exitCode = exitErr.ExitCode() + } + } + + if cmdCtx.Err() != nil { + return nil, fmt.Errorf("codex resume timed out after %s", s.codexTimeout) + } + if err != nil { + return nil, fmt.Errorf("codex resume failed exit_code=%d stderr=%q: %w", exitCode, prefix(stderr.Bytes(), 300), err) + } + + output, err := os.ReadFile(outPath) + if err != nil { + return nil, fmt.Errorf("read codex last message: %w", err) + } + log.Printf("codex resume duration=%s session_id=%s out_size=%d out_prefix=%q", elapsed.Round(time.Millisecond), sessionID, len(output), prefix(output, 160)) + return output, nil +} + +// parseThreadID reads the first JSONL event from codex stdout and extracts +// the thread/session UUID from a `thread.started` event. +func parseThreadID(stdout []byte) (string, error) { + line, _, _ := bytes.Cut(stdout, []byte("\n")) + var ev struct { + Type string `json:"type"` + ThreadID string `json:"thread_id"` + } + if err := json.Unmarshal(line, &ev); err != nil { + return "", fmt.Errorf("parse thread.started event: %w (line=%q)", err, prefix(line, 200)) + } + if ev.Type != "thread.started" || ev.ThreadID == "" { + return "", fmt.Errorf("expected thread.started event with thread_id, got: %q", prefix(line, 200)) + } + return ev.ThreadID, nil +} + +func parseCodexOperations(stdout []byte) ([]*yokov1.GeneratedOperation, error) { + payload := extractJSONObject(stdout) + var parsed codexOutput + if err := json.Unmarshal(payload, &parsed); err != nil { + return nil, err + } + + ops := make([]*yokov1.GeneratedOperation, 0, len(parsed.Operations)) + for _, op := range parsed.Operations { + ops = append(ops, &yokov1.GeneratedOperation{ + Name: op.Name, + Body: op.Body, + Kind: operationKind(op.Kind), + Description: op.Description, + }) + } + return ops, nil +} + +func operationKind(kind string) yokov1.OperationKind { + switch strings.ToLower(kind) { + case "query": + return yokov1.OperationKind_OPERATION_KIND_QUERY + case "mutation": + return yokov1.OperationKind_OPERATION_KIND_MUTATION + default: + return yokov1.OperationKind_OPERATION_KIND_UNSPECIFIED + } +} + +// extractJSONObject returns the substring from the first '{' to the last '}' +// in stdout. Resume calls don't support --output-schema, so this guards +// against occasional preamble or trailing prose so json.Unmarshal still +// succeeds. +func extractJSONObject(stdout []byte) []byte { + start := bytes.IndexByte(stdout, '{') + end := bytes.LastIndexByte(stdout, '}') + if start < 0 || end < 0 || end < start { + return stdout + } + return stdout[start : end+1] +} + +func prefix(value []byte, max int) string { + if len(value) <= max { + return string(value) + } + return string(value[:max]) +} diff --git a/demo/code-mode/yoko-mock/main_test.go b/demo/code-mode/yoko-mock/main_test.go new file mode 100644 index 0000000000..61b0ea3d4f --- /dev/null +++ b/demo/code-mode/yoko-mock/main_test.go @@ -0,0 +1,169 @@ +package main + +import ( + "context" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "runtime" + "testing" + "time" + + "connectrpc.com/connect" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + yokov1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1" + "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1/yokov1connect" + "google.golang.org/protobuf/proto" +) + +func TestIndexThenSearchReturnsGeneratedOperations(t *testing.T) { + writeFakeCodex(t, + `{"type":"thread.started","thread_id":"fake-thread"}`, + `{"operations":[{"name":"getViewer","body":"query getViewer { viewer { id } }","kind":"query","description":"Fetches the current viewer."}]}`, + ) + client := newTestClient(t) + + indexResp, err := client.Index(context.Background(), connect.NewRequest(&yokov1.IndexRequest{ + SchemaSdl: "type Query { viewer: User } type User { id: ID! }", + })) + require.NoError(t, err) + + searchResp, err := client.Search(context.Background(), connect.NewRequest(&yokov1.SearchRequest{ + SchemaId: indexResp.Msg.GetSchemaId(), + Prompts: []string{"get the viewer"}, + SessionId: "session-1", + })) + require.NoError(t, err) + + expected := &yokov1.SearchResponse{ + Operations: []*yokov1.GeneratedOperation{ + { + Name: "getViewer", + Body: "query getViewer { viewer { id } }", + Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, + Description: "Fetches the current viewer.", + }, + }, + } + assert.Equal(t, normalizeSearchResponse(t, expected), normalizeSearchResponse(t, searchResp.Msg)) +} + +func TestSearchUnknownSchemaIDReturnsNotFound(t *testing.T) { + writeFakeCodex(t, + `{"type":"thread.started","thread_id":"fake-thread"}`, + `{"operations":[]}`, + ) + client := newTestClient(t) + + _, err := client.Search(context.Background(), connect.NewRequest(&yokov1.SearchRequest{ + SchemaId: "unknown", + Prompts: []string{"get the viewer"}, + })) + + require.Error(t, err) + assert.Equal(t, connect.CodeNotFound, connect.CodeOf(err)) +} + +func TestSearchBadJSONReturnsInternal(t *testing.T) { + writeFakeCodex(t, + `{"type":"thread.started","thread_id":"fake-thread"}`, + `not json`, + ) + client := newTestClient(t) + + indexResp, err := client.Index(context.Background(), connect.NewRequest(&yokov1.IndexRequest{ + SchemaSdl: "type Query { viewer: ID! }", + })) + require.NoError(t, err) + + _, err = client.Search(context.Background(), connect.NewRequest(&yokov1.SearchRequest{ + SchemaId: indexResp.Msg.GetSchemaId(), + Prompts: []string{"get the viewer"}, + })) + + require.Error(t, err) + assert.Equal(t, connect.CodeInternal, connect.CodeOf(err)) +} + +func newTestClient(t *testing.T) yokov1connect.YokoServiceClient { + t.Helper() + + svc, err := newYokoService("codex", time.Second, "low", 0, 16) // disable rotation; small cache + require.NoError(t, err) + t.Cleanup(svc.Close) + mux := newHTTPMux(svc) + httpClient := &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + return rec.Result(), nil + })} + + return yokov1connect.NewYokoServiceClient(httpClient, "http://yoko.test") +} + +// writeFakeCodex installs a stub `codex` binary on PATH that mocks both the +// initial `codex exec` (Index pre-warm) and `codex exec resume` (Search) calls. +// The stub detects "resume" in its argv to switch modes. +// +// - indexStdout is printed to stdout for the Index call (e.g. a JSONL line +// like {"type":"thread.started","thread_id":"..."}). +// - resumeMessage is written to the file passed via -o for the Search +// call (codex's --output-last-message contract). +func writeFakeCodex(t *testing.T, indexStdout, resumeMessage string) { + t.Helper() + + dir := t.TempDir() + indexFile := filepath.Join(dir, "index.out") + require.NoError(t, os.WriteFile(indexFile, []byte(indexStdout+"\n"), 0o644)) + resumeFile := filepath.Join(dir, "resume.out") + require.NoError(t, os.WriteFile(resumeFile, []byte(resumeMessage), 0o644)) + + name := "codex" + if runtime.GOOS == "windows" { + name += ".bat" + } + path := filepath.Join(dir, name) + var script string + if runtime.GOOS == "windows" { + // Minimal Windows fallback — only Index path is exercised in CI on Unix. + script = "@echo off\r\ntype \"" + indexFile + "\"\r\n" + } else { + script = "#!/bin/sh\n" + + "is_resume=0\n" + + "out_file=\"\"\n" + + "prev=\"\"\n" + + "for arg in \"$@\"; do\n" + + " if [ \"$prev\" = \"-o\" ]; then out_file=\"$arg\"; fi\n" + + " if [ \"$arg\" = \"resume\" ]; then is_resume=1; fi\n" + + " prev=\"$arg\"\n" + + "done\n" + + "cat >/dev/null\n" + + "if [ \"$is_resume\" = \"1\" ]; then\n" + + " if [ -n \"$out_file\" ]; then cat \"" + resumeFile + "\" > \"$out_file\"; fi\n" + + "else\n" + + " cat \"" + indexFile + "\"\n" + + "fi\n" + } + require.NoError(t, os.WriteFile(path, []byte(script), 0o755)) + t.Setenv("PATH", dir+string(os.PathListSeparator)+os.Getenv("PATH")) +} + +var _ http.Handler = (*http.ServeMux)(nil) + +func normalizeSearchResponse(t *testing.T, resp *yokov1.SearchResponse) *yokov1.SearchResponse { + t.Helper() + + data, err := proto.Marshal(resp) + require.NoError(t, err) + normalized := &yokov1.SearchResponse{} + require.NoError(t, proto.Unmarshal(data, normalized)) + return normalized +} + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} diff --git a/demo/code-mode/yoko-mock/schema.graphql b/demo/code-mode/yoko-mock/schema.graphql new file mode 100644 index 0000000000..ed60f8b07b --- /dev/null +++ b/demo/code-mode/yoko-mock/schema.graphql @@ -0,0 +1,825 @@ +schema { + query: Query +} + +""" +Pylon — customer support tickets, accounts, contacts, surveys. +All types vendor-prefixed with `pylon_` to keep the federated schema collision-free. +""" +scalar pylon_DateTime + +scalar pylon_JSON + +""" +Slack — read shared customer channels, message history, user info. +All types vendor-prefixed with `slack_`. +""" +type Query { + pylon_searchAccounts(input: pylon_SearchAccountsInput!): [pylon_Account!]! + pylon_getAccount(id: ID!): pylon_Account + pylon_searchIssues(input: pylon_SearchIssuesInput!): [pylon_Issue!]! + pylon_getIssue(id: ID!): pylon_Issue + pylon_listIssueMessages(issueId: ID!): [pylon_Message!]! + pylon_listContacts(input: pylon_ListContactsInput!): [pylon_Contact!]! + pylon_listSurveys(limit: Int, cursor: String): [pylon_Survey!]! + pylon_getSurveyResponses(surveyId: ID!): [pylon_SurveyResponse!]! + """ + Canonical custom-field slugs: arr, renewal_date, segment, csm_owner, health_status. + """ + pylon_listAccountCustomFields: [pylon_CustomField!]! + pylon_listUsers: [pylon_User!]! + linear_customer(id: ID!): linear_Customer + linear_customerByDomain(domain: String!): linear_Customer + linear_customers(filter: linear_CustomersFilter): [linear_Customer!]! + linear_customerNeeds(filter: linear_CustomerNeedsFilter): [linear_CustomerNeed!]! + linear_issue(id: ID!): linear_Issue + linear_issues(filter: linear_IssuesFilter): [linear_Issue!]! + linear_team(id: ID!): linear_Team + linear_teams: [linear_Team!]! + linear_cycles(teamId: ID): [linear_Cycle!]! + linear_project(id: ID!): linear_Project + linear_user(id: ID!): linear_User + """ + Run a HogQL query. Two modes are supported: + + 1. **Preset (recommended)** — set `input.preset` to one of the values + of `posthog_HogQLPreset`. These map 1:1 to the curated mock handlers + and are guaranteed to execute. Pass any required parameters via + `input.presetParams` (see each enum value's docs for what it needs). + + 2. **Freeform** — set `input.hogql` to a raw HogQL string. Only the exact + shapes recognised by the preset matchers will succeed; any other query + returns Unimplemented with the supported handler list. + + If both `preset` and `hogql` are set, `preset` takes precedence. + """ + posthog_query(input: posthog_QueryInput!): posthog_QueryResult! + """List groups (customer orgs) for a given group type index.""" + posthog_listGroups(input: posthog_ListGroupsInput!): [posthog_Group!]! + posthog_getGroup(typeIndex: Int!, key: String!): posthog_Group + posthog_listEvents(input: posthog_ListEventsInput!): posthog_ListEventsResult! + """ + List feature flags configured for the project (used for feature-adoption mapping). + """ + posthog_listFeatureFlags(limit: Int): [posthog_FeatureFlag!]! + """Poll an async query result by id.""" + posthog_getAsyncQueryStatus(queryId: String!): posthog_AsyncQueryStatus! + circleback_searchMeetings(input: circleback_SearchMeetingsInput!): [circleback_Meeting!]! + circleback_readMeetings(meetingIds: [ID!]!): [circleback_Meeting!]! + circleback_getTranscripts(meetingIds: [ID!]!): [circleback_Transcript!]! + circleback_searchTranscripts(query: String!, limit: Int): [circleback_TranscriptHit!]! + circleback_searchActionItems(query: String!, limit: Int): [circleback_ActionItem!]! + circleback_findDomains(query: String!): [circleback_Domain!]! + circleback_findProfiles(query: String!): [circleback_Profile!]! + circleback_searchCalendarEvents(input: circleback_SearchCalendarInput!): [circleback_CalendarEvent!]! + circleback_listTags: [String!]! + slack_listChannels(input: slack_ListChannelsInput!): [slack_Channel!]! + slack_getChannel(channelId: ID!): slack_Channel + slack_history(input: slack_HistoryInput!): slack_HistoryResult! + slack_replies(channelId: ID!, threadTs: String!, limit: Int): [slack_Message!]! + slack_userInfo(userId: ID!): slack_User + slack_listUsers(limit: Int, cursor: String): [slack_User!]! + slack_searchMessages(query: String!, count: Int, page: Int): [slack_Message!]! + slack_authTest: slack_AuthTestResult! + notion_search(input: notion_SearchInput!): [notion_SearchResult!]! + notion_getPage(pageId: ID!): notion_Page + notion_getDatabase(databaseId: ID!): notion_Database + notion_queryDataSource(input: notion_QueryDataSourceInput!): notion_QueryDataSourceResult! + notion_getBlockChildren(blockId: ID!, limit: Int, cursor: String): notion_BlockChildrenResult! + notion_listUsers(limit: Int, cursor: String): [notion_User!]! +} + +type pylon_Account { + id: ID! + name: String! + domains: [String!]! + tags: [String!]! + customFields: [pylon_CustomFieldValue!]! +} + +type pylon_CustomField { + slug: String! + label: String! + objectType: String! + type: String! +} + +type pylon_CustomFieldValue { + slug: String! + label: String! + value: String! +} + +type pylon_Issue { + id: ID! + title: String! + state: String! + accountId: ID + priority: pylon_IssuePriority! + number: Int! + assignee: pylon_User + requester: pylon_Contact + tags: [String!]! + latestMessageAt: pylon_DateTime + slaBreached: Boolean! + createdAt: pylon_DateTime! + resolvedAt: pylon_DateTime + firstResponseSeconds: Int + resolutionSeconds: Int + resolutionBreachTime: pylon_DateTime + numberOfTouches: Int + externalIssues: [pylon_ExternalIssueLink!]! + csatResponses: [pylon_SurveyResponse!]! +} + +enum pylon_IssuePriority { + P1 + P2 + P3 + P4 +} + +type pylon_ExternalIssueLink { + source: String! + externalId: String! + url: String +} + +type pylon_Message { + id: ID! + issueId: ID! + authorId: ID + body: String! + createdAt: pylon_DateTime! +} + +type pylon_Contact { + id: ID! + email: String + name: String + accountId: ID +} + +type pylon_User { + id: ID! + name: String! + email: String! +} + +type pylon_Survey { + id: ID! + type: pylon_SurveyType! + name: String! +} + +enum pylon_SurveyType { + CSAT + NPS + CES +} + +type pylon_SurveyResponse { + id: ID! + surveyId: ID! + accountId: ID + contactId: ID + score: Int! + comment: String + createdAt: pylon_DateTime! +} + +input pylon_SearchAccountsInput { + name: String + domain: String + tag: String + limit: Int + cursor: String +} + +input pylon_SearchIssuesInput { + accountId: ID + state: String + createdAfter: pylon_DateTime + createdBefore: pylon_DateTime + resolvedAfter: pylon_DateTime + resolvedBefore: pylon_DateTime + priority: pylon_IssuePriority + tags: [String!] + slaBreached: Boolean + limit: Int +} + +input pylon_ListContactsInput { + accountId: ID + email: String + limit: Int + cursor: String +} + +""" +Linear — engineering issues, projects, and the native Customer entity. +All types vendor-prefixed with `linear_`. +""" +scalar linear_DateTime + +type linear_Customer { + id: ID! + name: String! + domains: [String!]! + externalIds: [String!]! + revenue: Float + size: Int + ownerId: ID + slackChannelId: String +} + +type linear_CustomerNeed { + id: ID! + customerId: ID! + issueId: ID + projectId: ID + important: Boolean! + body: String + createdAt: linear_DateTime! +} + +type linear_Issue { + id: ID! + identifier: String! + title: String! + description: String + priority: Int! + priorityLabel: String! + state: linear_IssueState! + needs: [linear_CustomerNeed!]! + customerTicketCount: Int! + teamId: ID! + assigneeId: ID + cycleId: ID + projectId: ID + labels: [String!]! + url: String! + createdAt: linear_DateTime! + updatedAt: linear_DateTime! + completedAt: linear_DateTime + addedToCycleAt: linear_DateTime +} + +enum linear_IssueState { + TRIAGE + BACKLOG + UNSTARTED + STARTED + COMPLETED + CANCELED + DUPLICATE +} + +type linear_Project { + id: ID! + name: String! + description: String + state: String! + health: linear_ProjectHealth! + progress: Float! + leadId: ID + teamId: ID! + startDate: linear_DateTime + targetDate: linear_DateTime + url: String! +} + +enum linear_ProjectHealth { + ON_TRACK + AT_RISK + OFF_TRACK +} + +type linear_Cycle { + id: ID! + teamId: ID! + number: Int! + name: String + startsAt: linear_DateTime! + endsAt: linear_DateTime! + progress: Float! +} + +type linear_Team { + id: ID! + key: String! + name: String! +} + +type linear_User { + id: ID! + name: String! + email: String! + active: Boolean! +} + +input linear_CustomersFilter { + domain: String + externalId: String + search: String + limit: Int +} + +input linear_CustomerNeedsFilter { + customerId: ID + createdAfter: linear_DateTime + createdBefore: linear_DateTime + important: Boolean + limit: Int +} + +input linear_IssuesFilter { + customerId: ID + teamId: ID + cycleId: ID + priority: Int + state: linear_IssueState + createdAfter: linear_DateTime + createdBefore: linear_DateTime + updatedAfter: linear_DateTime + limit: Int +} + +""" +PostHog — product telemetry queryable via HogQL. +Customers/orgs are modeled as PostHog `groups` (typeIndex 0..4). +All types vendor-prefixed with `posthog_`. + +Mock auth requires both X-Posthog-Token and X-Posthog-Project-Id metadata. +""" +scalar posthog_JSON + +type posthog_QueryResult { + columns: [String!]! + types: [String!]! + rows: [posthog_JSON!]! + hasMore: Boolean! + queryId: String + asyncStatus: posthog_AsyncQueryStatus +} + +type posthog_AsyncQueryStatus { + queryId: String! + state: posthog_AsyncQueryState! + errorMessage: String +} + +enum posthog_AsyncQueryState { + PENDING + RUNNING + COMPLETED + ERROR +} + +input posthog_QueryInput { + """ + Pre-defined query to execute. Recommended path — runs a curated mock + handler and is guaranteed to succeed. Takes precedence over `hogql` + when both are set. + """ + preset: posthog_HogQLPreset + """ + Parameters consumed by the chosen `preset`. See each enum value's docs + for which fields it requires. + """ + presetParams: posthog_HogQLPresetParams + """ + Freeform HogQL string. Only the exact shapes recognised by the preset + matchers will succeed; any other query returns Unimplemented. + Prefer `preset` for new callers. Omit when `preset` is set. + """ + hogql: String + refresh: posthog_RefreshMode + filtersOverride: posthog_JSON +} + +""" +Pre-defined HogQL queries available in the mock. Each value runs a +curated handler over the seeded event data and is guaranteed to execute. +""" +enum posthog_HogQLPreset { + """ + Quarter-over-quarter event count for one company. + Required params: `domain` (e.g. "ebay.com"). + Returns rows of (quarter DateTime, events UInt64). + """ + QOQ_COMPANY + """ + Daily-active-users + per-day event count for one company over the + last 30 days. + Required params: `domain`. + Returns rows of (day Date, dau UInt64, events UInt64). + """ + DAU_TIMESERIES_COMPANY + """ + Accounts whose 30-day event volume dropped >20% versus the prior 30 + days, sorted ascending by pct_change (most-at-risk first). + Required params: none. + Returns rows of (key, recent, prior, delta, pct_change). + """ + AT_RISK_ACCOUNTS + """ + Feature-adoption matrix across the top-10 customers (one row per + customer × feature_slug used). + Required params: none. + Returns rows of (key, feature_slug, uses). + """ + FEATURE_ADOPTION_TOP10 + """ + P95 request latency bucketed hourly for one company in a time window. + Required params: `domain`, `start`, `end` (RFC3339, e.g. + "2026-04-21T14:00:00Z"). + Returns rows of (hour DateTime, p95_ms Float64). + """ + LATENCY_HOURLY_COMPANY + """ + Per-event-name count for one company in a time window. + Required params: `domain`, `start` (RFC3339); optional: `end` (RFC3339). + Returns rows of (event String, count UInt64). + """ + EVENT_BREAKDOWN_WINDOW + """ + Week-over-week event delta across the entire portfolio (per group_0), + sorted by delta descending. + Required params: none. + Returns rows of (key, this_count, prev_count, delta). + """ + WEEKLY_PORTFOLIO_DELTA +} + +""" +Parameters for a `posthog_HogQLPreset` query. Only the fields required +by the chosen preset need to be set; extras are ignored. +""" +input posthog_HogQLPresetParams { + """Group key, typically a customer domain like "ebay.com".""" + domain: String + """ + RFC3339 start timestamp with time component, e.g. "2026-04-21T14:00:00Z". + """ + start: String + """ + RFC3339 end timestamp with time component, e.g. "2026-04-21T15:00:00Z". + """ + end: String +} + +enum posthog_RefreshMode { + BLOCKING + ASYNC + LAZY + FORCE_BLOCKING + FORCE_ASYNC +} + +type posthog_Group { + typeIndex: Int! + key: String! + properties: posthog_JSON! + createdAt: String +} + +input posthog_ListGroupsInput { + typeIndex: Int! + search: String + limit: Int +} + +type posthog_Event { + timestamp: String! + distinctId: String! + event: String! + group0: String! + properties: posthog_JSON! +} + +type posthog_ListEventsResult { + events: [posthog_Event!]! + nextCursor: String + hasMore: Boolean! +} + +input posthog_ListEventsInput { + groupKey: String + eventName: String + startTime: String + endTime: String + limit: Int + cursor: String +} + +type posthog_FeatureFlag { + id: ID! + key: String! + name: String + active: Boolean! + filters: posthog_JSON! + topRolloutPercentage: Float +} + +""" +Circleback — meeting transcripts, summaries, action items. +This subgraph serves deterministic embedded Circleback-style mock fixtures. +All types vendor-prefixed with `circleback_`. +""" +scalar circleback_DateTime + +type circleback_Meeting { + id: ID! + name: String! + createdAt: circleback_DateTime! + duration: Int! + url: String + recordingUrl: String + tags: [String!]! + attendees: [circleback_Attendee!]! + notes: String + actionItems: [circleback_ActionItem!]! + icalUid: String + organizerEmail: String! +} + +type circleback_Attendee { + email: String! + name: String +} + +enum circleback_ActionItemStatus { + PENDING + DONE +} + +type circleback_ActionItem { + id: ID! + meetingId: ID! + meetingName: String + title: String! + description: String + status: circleback_ActionItemStatus! + assignee: circleback_Attendee +} + +type circleback_Transcript { + meetingId: ID! + segments: [circleback_TranscriptSegment!]! +} + +type circleback_TranscriptSegment { + speaker: String! + startMs: Int! + endMs: Int! + text: String! +} + +type circleback_TranscriptHit { + meetingId: ID! + speaker: String! + text: String! + startMs: Int! + score: Float! +} + +type circleback_Domain { + domain: String! + meetingCount: Int! +} + +type circleback_Profile { + email: String! + name: String + domain: String + meetingCount: Int! +} + +type circleback_CalendarEvent { + id: ID! + title: String! + startsAt: circleback_DateTime! + endsAt: circleback_DateTime! + attendees: [circleback_Attendee!]! + icalUid: String! + tags: [String!]! + notes: String + actionItems: [circleback_ActionItem!]! + organizerEmail: String! +} + +input circleback_SearchMeetingsInput { + attendeeEmail: String + attendeeDomain: String + tag: String + keyword: String + startDate: circleback_DateTime + endDate: circleback_DateTime + limit: Int +} + +input circleback_SearchCalendarInput { + query: String + startDate: circleback_DateTime + endDate: circleback_DateTime + limit: Int +} + +type slack_Channel { + id: ID! + name: String! + isPrivate: Boolean! + isArchived: Boolean! + isMember: Boolean! + isShared: Boolean! + isExtShared: Boolean! + created: Int! + creatorId: ID! + topic: String + purpose: String + numMembers: Int +} + +type slack_Message { + ts: String! + channelId: ID! + userId: ID + text: String! + threadTs: String + permalink: String! + username: String + subtype: String + editedTs: String + replyCount: Int + replyUsersCount: Int + latestReplyTs: String + reactions: [slack_Reaction!]! +} + +type slack_Reaction { + name: String! + count: Int! + userIds: [ID!]! +} + +type slack_HistoryResult { + messages: [slack_Message!]! + hasMore: Boolean! + nextCursor: String +} + +type slack_User { + id: ID! + name: String! + realName: String + email: String + title: String + displayName: String + image: String + tz: String + isBot: Boolean! + deleted: Boolean! +} + +type slack_AuthTestResult { + ok: Boolean! + url: String! + team: String! + user: String! + teamId: ID! + userId: ID! + botId: ID +} + +input slack_ListChannelsInput { + types: [slack_ChannelType!] + excludeArchived: Boolean + namePrefix: String + limit: Int +} + +enum slack_ChannelType { + PUBLIC_CHANNEL + PRIVATE_CHANNEL + MPIM + IM +} + +input slack_HistoryInput { + channelId: ID! + oldest: String + latest: String + limit: Int + cursor: String +} + +""" +Notion — internal knowledge base / feature documentation. +All types vendor-prefixed with `notion_`. +""" +scalar notion_JSON + +scalar notion_DateTime + +type notion_Page { + id: ID! + """ + Synthesized convenience field. Notion stores page titles in the title property + (usually properties["Name"].title[0].plain_text), and the resolver derives this value. + """ + title: String! + url: String! + parentId: ID + parentType: notion_ParentType + createdAt: notion_DateTime! + updatedAt: notion_DateTime! + archived: Boolean! + properties: notion_JSON! +} + +type notion_Database { + id: ID! + title: String! + url: String! + createdAt: notion_DateTime! + updatedAt: notion_DateTime! + dataSources: [notion_DataSource!]! +} + +type notion_DataSource { + id: ID! + databaseId: ID! + name: String +} + +type notion_Block { + id: ID! + type: String! + hasChildren: Boolean! + archived: Boolean! + content: notion_JSON! +} + +type notion_BlockChildrenResult { + blocks: [notion_Block!]! + hasMore: Boolean! + nextCursor: String +} + +type notion_QueryDataSourceResult { + pages: [notion_Page!]! + hasMore: Boolean! + nextCursor: String +} + +type notion_SearchResult { + objectType: notion_ObjectType! + id: ID! + title: String! + url: String! +} + +enum notion_ObjectType { + PAGE + DATABASE + DATA_SOURCE +} + +enum notion_ParentType { + PAGE + DATABASE + DATA_SOURCE + WORKSPACE + BLOCK +} + +type notion_User { + id: ID! + name: String! + email: String + type: notion_UserType! +} + +enum notion_UserType { + PERSON + BOT +} + +input notion_SearchInput { + query: String + filterType: notion_ObjectType + limit: Int +} + +input notion_QueryDataSourceInput { + dataSourceId: ID! + """ + Best-effort Notion-style property equals filter. The mock supports Status, + Segment, Health, Slug, and Domain; unsupported filter JSON is ignored. + """ + filter: notion_JSON + sorts: notion_JSON + limit: Int + cursor: String +} \ No newline at end of file diff --git a/demo/pkg/subgraphs/employees/subgraph/schema.resolvers.go b/demo/pkg/subgraphs/employees/subgraph/schema.resolvers.go index 3f0d30d16b..ecbf4d61f4 100644 --- a/demo/pkg/subgraphs/employees/subgraph/schema.resolvers.go +++ b/demo/pkg/subgraphs/employees/subgraph/schema.resolvers.go @@ -61,25 +61,9 @@ func (r *mutationResolver) UpdateEmployeeTag(ctx context.Context, id int, tag st defer r.mux.Unlock() for _, employee := range r.EmployeesData { if id == employee.ID { - details := &model.Details{} - if employee.Details != nil { - details.Forename = employee.Details.Forename - details.Surname = employee.Details.Surname - details.Location = employee.Details.Location - } - return &model.Employee{ - ID: employee.ID, - Details: details, - Tag: tag, - Expertise: employee.Expertise, - Role: employee.Role, - Notes: employee.Notes, - UpdatedAt: time.Now().String(), - StartDate: employee.StartDate, - PrimaryWorkItem: employee.PrimaryWorkItem, - LastWorkReview: employee.LastWorkReview, - WorkSetup: employee.WorkSetup, - }, nil + employee.Tag = tag + employee.UpdatedAt = time.Now().String() + return employee, nil } } return nil, nil diff --git a/proto/wg/cosmo/code_mode/yoko/v1/yoko.proto b/proto/wg/cosmo/code_mode/yoko/v1/yoko.proto new file mode 100644 index 0000000000..3add37193b --- /dev/null +++ b/proto/wg/cosmo/code_mode/yoko/v1/yoko.proto @@ -0,0 +1,84 @@ +syntax = "proto3"; + +package wundergraph.cosmo.code_mode.yoko.v1; + +// Yoko generates GraphQL operations for natural-language prompts. +// +// Two-step flow: +// 1. Index(schema_sdl) -> schema_id. Idempotent: the same SDL +// always returns the same schema_id for as long as the index +// is retained. The router caches schema_id and re-indexes only +// on supergraph reload (or when Search returns NOT_FOUND). +// 2. Search(prompts, schema_id) -> operations. Yoko owns prompt +// fan-out, partial-failure handling, cross-prompt deduplication, +// and ranking. +// +// The router never sends a schema hash. Yoko is the sole authority +// on schema identity; the router only sends raw SDL on Index and an +// opaque id on Search. +service YokoService { + rpc Index(IndexRequest) returns (IndexResponse); + rpc Search(SearchRequest) returns (SearchResponse); +} + +message IndexRequest { + // The supergraph SDL to index. Sent in full on every Index call; + // Yoko deduplicates internally and is free to short-circuit when + // the SDL is already known. + string schema_sdl = 1; +} + +message IndexResponse { + // Opaque, Yoko-assigned identifier for this schema. Stable for as + // long as Yoko retains the index. Subsequent Search calls pass this + // back instead of the full SDL. Idempotent: the same SDL returns + // the same schema_id. + string schema_id = 1; +} + +message SearchRequest { + // Batch of natural-language prompts. Bounded at 20 by the host. + repeated string prompts = 1; + + // Identifier returned by a prior Index call. If Yoko no longer + // recognizes the id (e.g. eviction, restart), it MUST return the + // Connect error code NOT_FOUND; the router re-indexes and retries + // the call exactly once. + string schema_id = 2; + + // Opaque MCP session ID for telemetry correlation only. + // Yoko MUST NOT use this for stateful behavior — sessions are owned + // by the router. + string session_id = 3; +} + +message SearchResponse { + // Operations across all prompts, already deduplicated and ranked. + // Order is significant: earlier entries rank higher and are preferred + // when bundle truncation drops from the tail. + repeated GeneratedOperation operations = 1; +} + +message GeneratedOperation { + // Suggested operation name (camelCase preferred). The host applies + // its own identifier normalization and in-session collision-suffix + // logic on top of this — see §6. + string name = 1; + + // GraphQL operation body (query or mutation source text). + string body = 2; + + // Operation kind. Subscriptions are out of scope; if Yoko returns + // one, the host drops it with a single warn log. + OperationKind kind = 3; + + // Human-readable description, surfaced as JSDoc on the typed + // `tools.` signature in the rendered bundle. + string description = 4; +} + +enum OperationKind { + OPERATION_KIND_UNSPECIFIED = 0; + OPERATION_KIND_QUERY = 1; + OPERATION_KIND_MUTATION = 2; +} diff --git a/router-tests/code_mode_named_ops_test.go b/router-tests/code_mode_named_ops_test.go new file mode 100644 index 0000000000..40ea76a5fb --- /dev/null +++ b/router-tests/code_mode_named_ops_test.go @@ -0,0 +1,621 @@ +package integration + +import ( + "context" + "encoding/json" + "fmt" + "net" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "connectrpc.com/connect" + miniredis "github.com/alicebob/miniredis/v2" + mark3mcp "github.com/mark3labs/mcp-go/mcp" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zapcore" + + "github.com/wundergraph/cosmo/router-tests/freeport" + "github.com/wundergraph/cosmo/router-tests/testenv" + yokov1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1" + yokoconnect "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1/yokov1connect" + nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" + "github.com/wundergraph/cosmo/router/pkg/config" + "github.com/wundergraph/cosmo/router/pkg/controlplane/configpoller" + "github.com/wundergraph/cosmo/router/pkg/routerconfig" +) + +const codeModePersistedOpsURI = "yoko://persisted-ops.d.ts" + +const ( + firstEmployeeOpName = "firstEmployee" + employeeByIDOpName = "employeeByID" + updateTagOpName = "updateEmployeeTag" + + firstEmployeeQuery = `query firstEmployee { firstEmployee { id details { forename surname } } }` + employeeByIDQuery = `query employeeByID($id: Int!) { employee(id: $id) { id details { forename surname } } }` + updateTagMutation = `mutation updateEmployeeTag($id: Int!, $tag: String!) { updateEmployeeTag(id: $id, tag: $tag) { id tag } }` +) + +const firstEmployeeTS = `/** Fetch the first employee. */ +firstEmployee(): R<{ firstEmployee: { id: number; details: { forename: string; surname: string } | null } }>;` + +const employeeByIDTS = `/** Fetch employee by id. */ +employeeByID(vars: { id: number }): R<{ employee: { id: number; details: { forename: string; surname: string } | null } | null }>;` + +const updateTagTS = `/** Update employee tag. */ +updateEmployeeTag(vars: { id: number; tag: string }): R<{ updateEmployeeTag: { id: number; tag: string } | null }>;` + +const twoOpsFragment = firstEmployeeTS + "\n\n" + employeeByIDTS + +// indentBundleEntry mirrors tsgen's behavior: every line of a per-op block +// (JSDoc + signature) is indented by 2 spaces inside the tools object. +func indentBundleEntry(s string) string { + return " " + strings.ReplaceAll(s, "\n", "\n ") +} + +const emptyOpsBundle = `type GraphQLError = { message: string; path?: (string | number)[]; extensions?: Record }; +type R = Promise<{ data: T | null; errors?: GraphQLError[] }>; +// Known limitation: union and interface selections are typed as unknown. + +declare const tools: {}; + +declare function notNull(value: T | null | undefined, message?: string): T; +declare function compact(value: T): T;` + +var firstEmployeeBundle = `type GraphQLError = { message: string; path?: (string | number)[]; extensions?: Record }; +type R = Promise<{ data: T | null; errors?: GraphQLError[] }>; +// Known limitation: union and interface selections are typed as unknown. + +declare const tools: { +` + indentBundleEntry(firstEmployeeTS) + ` +}; + +declare function notNull(value: T | null | undefined, message?: string): T; +declare function compact(value: T): T;` + +var employeeByIDBundle = `type GraphQLError = { message: string; path?: (string | number)[]; extensions?: Record }; +type R = Promise<{ data: T | null; errors?: GraphQLError[] }>; +// Known limitation: union and interface selections are typed as unknown. + +declare const tools: { +` + indentBundleEntry(employeeByIDTS) + ` +}; + +declare function notNull(value: T | null | undefined, message?: string): T; +declare function compact(value: T): T;` + +var twoOpsBundle = `type GraphQLError = { message: string; path?: (string | number)[]; extensions?: Record }; +type R = Promise<{ data: T | null; errors?: GraphQLError[] }>; +// Known limitation: union and interface selections are typed as unknown. + +declare const tools: { +` + indentBundleEntry(firstEmployeeTS) + ` + +` + indentBundleEntry(employeeByIDTS) + ` +}; + +declare function notNull(value: T | null | undefined, message?: string): T; +declare function compact(value: T): T;` + +type codeModeBackend struct { + name string + providerID string + redisURL string +} + +func TestCodeModeNamedOpsMemoryBackendStatefulSearchExecuteAndResource(t *testing.T) { + withCodeModeNamedOps(t, codeModeBackend{name: "memory"}, codeModeNamedOpsOptions{}, func(ctx context.Context, _ string, xEnv *testenv.Environment, yoko *fakeCodeModeYoko, session *mcp.ClientSession) { + searchText := callCodeModeToolText(t, ctx, session, "code_mode_search_tools", map[string]any{ + "prompts": []string{"first employee", "employee by id"}, + }) + assert.Equal(t, twoOpsFragment, searchText) + assert.Equal(t, []*yokov1.IndexRequest{{SchemaSdl: yoko.indexRequests()[0].GetSchemaSdl()}}, yoko.indexRequests()) + assert.Equal(t, []*yokov1.SearchRequest{{ + Prompts: []string{"first employee", "employee by id"}, + SchemaId: "schema-1", + SessionId: yoko.searchRequests()[0].GetSessionId(), + }}, yoko.searchRequests()) + + resource := readPersistedOpsResource(t, ctx, session) + assert.Equal(t, &mcp.ReadResourceResult{Contents: []*mcp.ResourceContents{{ + URI: codeModePersistedOpsURI, + MIMEType: "text/plain", + Text: twoOpsBundle, + }}}, resource) + + executeText := callCodeModeToolText(t, ctx, session, "code_mode_run_js", map[string]any{ + "source": `async () => { return await tools.employeeByID({ id: 1 }); }`, + }) + assert.Equal(t, map[string]any{ + "result": map[string]any{ + "data": map[string]any{ + "employee": map[string]any{ + "id": float64(1), + "details": map[string]any{ + "forename": "Jens", + "surname": "Neuse", + }, + }, + }, + }, + }, decodeJSON(t, executeText)) + + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{Query: `{ employee(id: 1) { id details { forename surname } } }`}) + assert.Equal(t, `{"data":{"employee":{"id":1,"details":{"forename":"Jens","surname":"Neuse"}}}}`, res.Body) + }) +} + +func TestCodeModeNamedOpsConcurrentSessions(t *testing.T) { + withCodeModeNamedOps(t, codeModeBackend{name: "memory"}, codeModeNamedOpsOptions{}, func(ctx context.Context, endpoint string, _ *testenv.Environment, _ *fakeCodeModeYoko, sessionA *mcp.ClientSession) { + searchText := callCodeModeToolText(t, ctx, sessionA, "code_mode_search_tools", map[string]any{ + "prompts": []string{"first employee"}, + }) + assert.Equal(t, firstEmployeeTS, searchText) + + sessionB := newCodeModeMCPClient(t, ctx, endpoint, nil) + resourceA := readPersistedOpsResource(t, ctx, sessionA) + resourceB := readPersistedOpsResource(t, ctx, sessionB) + + assert.Equal(t, &mcp.ReadResourceResult{Contents: []*mcp.ResourceContents{{ + URI: codeModePersistedOpsURI, + MIMEType: "text/plain", + Text: firstEmployeeBundle, + }}}, resourceA) + assert.Equal(t, &mcp.ReadResourceResult{Contents: []*mcp.ResourceContents{{ + URI: codeModePersistedOpsURI, + MIMEType: "text/plain", + Text: emptyOpsBundle, + }}}, resourceB) + }) +} + +func TestCodeModeNamedOpsSchemaReloadEvictsSession(t *testing.T) { + poller := &codeModeConfigPoller{ready: make(chan struct{})} + withCodeModeNamedOps(t, codeModeBackend{name: "memory"}, codeModeNamedOpsOptions{poller: poller}, func(ctx context.Context, _ string, _ *testenv.Environment, _ *fakeCodeModeYoko, session *mcp.ClientSession) { + searchText := callCodeModeToolText(t, ctx, session, "code_mode_search_tools", map[string]any{ + "prompts": []string{"employee by id"}, + }) + assert.Equal(t, employeeByIDTS, searchText) + + <-poller.ready + poller.initConfig.Version = "code-mode-reload" + require.NoError(t, poller.updateConfig(poller.initConfig, "before-code-mode-reload")) + + executeText := callCodeModeToolText(t, ctx, session, "code_mode_run_js", map[string]any{ + "source": `async () => { return await tools.employeeByID({ id: 1 }); }`, + }) + assert.Equal(t, map[string]any{ + "result": nil, + "error": map[string]any{ + "name": "TypeError", + "message": "tools.employeeByID is not a function", + "stack": " at __agentMain (codemode_agent.js:agent.ts:1:34)\n at (codemode_agent.js:73:42)\n at (codemode_agent.js:77:1)\n", + }, + }, decodeJSON(t, executeText)) + }) +} + +func TestCodeModeNamedOpsMutationElicitationRejection(t *testing.T) { + decline := func(context.Context, *mcp.ElicitRequest) (*mcp.ElicitResult, error) { + return &mcp.ElicitResult{Action: "accept", Content: map[string]any{ + "approved": false, + "reason": "policy forbids", + }}, nil + } + withCodeModeNamedOps(t, codeModeBackend{name: "memory"}, codeModeNamedOpsOptions{elicitationHandler: decline}, func(ctx context.Context, _ string, _ *testenv.Environment, _ *fakeCodeModeYoko, session *mcp.ClientSession) { + searchText := callCodeModeToolText(t, ctx, session, "code_mode_search_tools", map[string]any{ + "prompts": []string{"update employee tag"}, + }) + assert.Equal(t, updateTagTS, searchText) + + executeText := callCodeModeToolText(t, ctx, session, "code_mode_run_js", map[string]any{ + "source": `async () => { return await tools.updateEmployeeTag({ id: 1, tag: "x" }); }`, + }) + assert.Equal(t, map[string]any{ + "result": map[string]any{ + "data": nil, + "declined": map[string]any{ + "reason": "policy forbids", + }, + "errors": []any{ + map[string]any{"message": "Mutation declined by operator: policy forbids"}, + }, + }, + }, decodeJSON(t, executeText)) + }) +} + +func TestCodeModeNamedOpsTranspileError(t *testing.T) { + withCodeModeNamedOps(t, codeModeBackend{name: "memory"}, codeModeNamedOpsOptions{}, func(ctx context.Context, _ string, _ *testenv.Environment, _ *fakeCodeModeYoko, session *mcp.ClientSession) { + searchText := callCodeModeToolText(t, ctx, session, "code_mode_search_tools", map[string]any{ + "prompts": []string{"employee by id"}, + }) + assert.Equal(t, employeeByIDTS, searchText) + + executeText := callCodeModeToolText(t, ctx, session, "code_mode_run_js", map[string]any{ + "source": `async () => { let x = ; }`, + }) + assert.Equal(t, map[string]any{ + "result": nil, + "error": map[string]any{ + "name": "TranspileError", + "message": "transpile failed: Unexpected \";\"", + "stack": "", + }, + }, decodeJSON(t, executeText)) + }) +} + +func TestCodeModeNamedOpsListResourcesGating(t *testing.T) { + t.Run("code mode disabled does not advertise persisted ops on main MCP server", func(t *testing.T) { + yoko := newFakeCodeModeYoko() + yokoServer := startFakeCodeModeYoko(t, yoko) + cfg := baseCodeModeTestConfig(t, yokoServer.URL, codeModeBackend{name: "memory"}, codeModeNamedOpsOptions{}) + cfg.MCP.CodeMode.Enabled = false + + testenv.Run(t, cfg, func(t *testing.T, xEnv *testenv.Environment) { + resources, err := xEnv.MCPClient.ListResources(ctxWithTimeout(t), mark3mcp.ListResourcesRequest{}) + require.NoError(t, err) + assert.Equal(t, false, mark3ResourcesContain(resources.Resources, codeModePersistedOpsURI)) + }) + }) + + t.Run("named ops disabled does not advertise persisted ops", func(t *testing.T) { + withCodeModeNamedOps(t, codeModeBackend{name: "memory"}, codeModeNamedOpsOptions{namedOpsEnabled: boolPtr(false)}, func(ctx context.Context, _ string, _ *testenv.Environment, _ *fakeCodeModeYoko, session *mcp.ClientSession) { + resources, err := session.ListResources(ctx, &mcp.ListResourcesParams{}) + require.NoError(t, err) + assert.Equal(t, []*mcp.Resource{}, resources.Resources) + }) + }) + + t.Run("stateless does not advertise persisted ops and warns once", func(t *testing.T) { + withCodeModeNamedOps(t, codeModeBackend{name: "memory"}, codeModeNamedOpsOptions{sessionStateless: boolPtr(true), observeLogs: true}, func(ctx context.Context, _ string, xEnv *testenv.Environment, _ *fakeCodeModeYoko, session *mcp.ClientSession) { + resources, err := session.ListResources(ctx, &mcp.ListResourcesParams{}) + require.NoError(t, err) + assert.Equal(t, []*mcp.Resource{}, resources.Resources) + + logs := xEnv.Observer().FilterMessage("code mode named operations are disabled because MCP session stateless mode is enabled").All() + assert.Equal(t, 1, len(logs)) + }) + }) + + t.Run("all gates on advertises persisted ops and read returns bundle", func(t *testing.T) { + withCodeModeNamedOps(t, codeModeBackend{name: "memory"}, codeModeNamedOpsOptions{}, func(ctx context.Context, _ string, _ *testenv.Environment, _ *fakeCodeModeYoko, session *mcp.ClientSession) { + searchText := callCodeModeToolText(t, ctx, session, "code_mode_search_tools", map[string]any{ + "prompts": []string{"employee by id"}, + }) + assert.Equal(t, employeeByIDTS, searchText) + + resources, err := session.ListResources(ctx, &mcp.ListResourcesParams{}) + require.NoError(t, err) + assert.Equal(t, []*mcp.Resource{{ + URI: codeModePersistedOpsURI, + Name: "persisted-ops.d.ts", + Title: "Persisted operations TypeScript definitions", + Description: "Cumulative TypeScript definitions for the current Code Mode MCP session's named operations.", + MIMEType: "text/plain", + }}, resources.Resources) + + resource := readPersistedOpsResource(t, ctx, session) + assert.Equal(t, &mcp.ReadResourceResult{Contents: []*mcp.ResourceContents{{ + URI: codeModePersistedOpsURI, + MIMEType: "text/plain", + Text: employeeByIDBundle, + }}}, resource) + }) + }) +} + +func TestCodeModeNamedOpsRedisBackendTransparent(t *testing.T) { + redisServer, err := miniredis.Run() + if err != nil { + t.Skipf("miniredis unavailable: %v", err) + } + t.Cleanup(redisServer.Close) + + backend := codeModeBackend{ + name: "redis", + providerID: "code_mode_redis", + redisURL: "redis://" + redisServer.Addr(), + } + withCodeModeNamedOps(t, backend, codeModeNamedOpsOptions{}, func(ctx context.Context, _ string, _ *testenv.Environment, _ *fakeCodeModeYoko, session *mcp.ClientSession) { + searchText := callCodeModeToolText(t, ctx, session, "code_mode_search_tools", map[string]any{ + "prompts": []string{"first employee", "employee by id"}, + }) + assert.Equal(t, twoOpsFragment, searchText) + + resource := readPersistedOpsResource(t, ctx, session) + assert.Equal(t, twoOpsBundle, resource.Contents[0].Text) + + executeText := callCodeModeToolText(t, ctx, session, "code_mode_run_js", map[string]any{ + "source": `async () => { return await tools.employeeByID({ id: 1 }); }`, + }) + assert.Equal(t, map[string]any{ + "result": map[string]any{ + "data": map[string]any{ + "employee": map[string]any{ + "id": float64(1), + "details": map[string]any{ + "forename": "Jens", + "surname": "Neuse", + }, + }, + }, + }, + }, decodeJSON(t, executeText)) + }) +} + +type codeModeNamedOpsOptions struct { + namedOpsEnabled *bool + sessionStateless *bool + observeLogs bool + poller *codeModeConfigPoller + elicitationHandler func(context.Context, *mcp.ElicitRequest) (*mcp.ElicitResult, error) +} + +func withCodeModeNamedOps(t *testing.T, backend codeModeBackend, opts codeModeNamedOpsOptions, f func(context.Context, string, *testenv.Environment, *fakeCodeModeYoko, *mcp.ClientSession)) { + t.Helper() + + yoko := newFakeCodeModeYoko() + yokoServer := startFakeCodeModeYoko(t, yoko) + cfg := baseCodeModeTestConfig(t, yokoServer.URL, backend, opts) + + testenv.Run(t, cfg, func(t *testing.T, xEnv *testenv.Environment) { + ctx := ctxWithTimeout(t) + endpoint := "http://" + cfg.MCP.CodeMode.Server.ListenAddr + "/mcp" + session := newCodeModeMCPClient(t, ctx, endpoint, opts.elicitationHandler) + f(ctx, endpoint, xEnv, yoko, session) + }) +} + +func baseCodeModeTestConfig(t *testing.T, yokoURL string, backend codeModeBackend, opts codeModeNamedOpsOptions) *testenv.Config { + t.Helper() + + ports := freeport.GetN(t, 2) + namedOpsEnabled := true + if opts.namedOpsEnabled != nil { + namedOpsEnabled = *opts.namedOpsEnabled + } + sessionStateless := false + if opts.sessionStateless != nil { + sessionStateless = *opts.sessionStateless + } + + mcpCfg := config.MCPConfiguration{ + Enabled: true, + Server: config.MCPServer{ + ListenAddr: fmt.Sprintf("127.0.0.1:%d", ports[0]), + }, + Session: config.MCPSessionConfig{Stateless: sessionStateless}, + CodeMode: config.MCPCodeModeConfiguration{ + Enabled: true, + RequireMutationApproval: true, + ExecuteTimeout: 30 * time.Second, + MaxResultBytes: 32 << 10, + Server: config.MCPCodeModeServerConfig{ + ListenAddr: fmt.Sprintf("127.0.0.1:%d", ports[1]), + }, + QueryGeneration: config.MCPCodeModeQueryGenConfig{ + Enabled: true, + Endpoint: yokoURL, + Timeout: 5 * time.Second, + }, + NamedOps: config.MCPCodeModeNamedOpsConfig{ + Enabled: namedOpsEnabled, + SessionTTL: 30 * time.Minute, + MaxSessions: 100, + MaxBundleBytes: 256 << 10, + Storage: config.MCPCodeModeNamedOpsStorageConfig{ + ProviderID: backend.providerID, + KeyPrefix: "router_tests_code_mode", + }, + }, + }, + } + + cfg := &testenv.Config{ + MCP: mcpCfg, + MCPOperationsPath: "protocol/testdata/mcp_operations_collision", + CodeModeRedisURL: backend.redisURL, + } + if opts.observeLogs { + cfg.LogObservation = testenv.LogObservationConfig{Enabled: true, LogLevel: zapcore.WarnLevel} + } + if opts.poller != nil { + cfg.RouterConfig = &testenv.RouterConfig{ + ConfigPollerFactory: func(routerConfig *nodev1.RouterConfig) configpoller.ConfigPoller { + opts.poller.initConfig = routerConfig + return opts.poller + }, + } + } + return cfg +} + +func newCodeModeMCPClient(t *testing.T, ctx context.Context, endpoint string, elicitation func(context.Context, *mcp.ElicitRequest) (*mcp.ElicitResult, error)) *mcp.ClientSession { + t.Helper() + + client := mcp.NewClient(&mcp.Implementation{Name: "router-tests", Version: "v0.0.0"}, &mcp.ClientOptions{ + ElicitationHandler: elicitation, + }) + transport := &mcp.StreamableClientTransport{ + Endpoint: endpoint, + DisableStandaloneSSE: true, + MaxRetries: -1, + } + session, err := client.Connect(ctx, transport, nil) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, session.Close()) + }) + return session +} + +func callCodeModeToolText(t *testing.T, ctx context.Context, session *mcp.ClientSession, name string, args map[string]any) string { + t.Helper() + result, err := session.CallTool(ctx, &mcp.CallToolParams{Name: name, Arguments: args}) + require.NoError(t, err) + require.False(t, result.IsError) + require.Len(t, result.Content, 1) + text, ok := result.Content[0].(*mcp.TextContent) + require.True(t, ok) + return text.Text +} + +func readPersistedOpsResource(t *testing.T, ctx context.Context, session *mcp.ClientSession) *mcp.ReadResourceResult { + t.Helper() + result, err := session.ReadResource(ctx, &mcp.ReadResourceParams{URI: codeModePersistedOpsURI}) + require.NoError(t, err) + return result +} + +func decodeJSON(t *testing.T, text string) map[string]any { + t.Helper() + var decoded map[string]any + require.NoError(t, json.Unmarshal([]byte(text), &decoded)) + return decoded +} + +func ctxWithTimeout(t *testing.T) context.Context { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + t.Cleanup(cancel) + return ctx +} + +func boolPtr(v bool) *bool { + return &v +} + +func mark3ResourcesContain(resources []mark3mcp.Resource, uri string) bool { + for _, resource := range resources { + if resource.URI == uri { + return true + } + } + return false +} + +type fakeCodeModeYoko struct { + mu sync.Mutex + indexCounter int + indexRequestLog []*yokov1.IndexRequest + searchRequestLog []*yokov1.SearchRequest + opsByPrompt map[string]*yokov1.GeneratedOperation +} + +func newFakeCodeModeYoko() *fakeCodeModeYoko { + return &fakeCodeModeYoko{ + opsByPrompt: map[string]*yokov1.GeneratedOperation{ + "first employee": { + Name: firstEmployeeOpName, + Body: firstEmployeeQuery, + Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, + Description: "Fetch the first employee.", + }, + "employee by id": { + Name: employeeByIDOpName, + Body: employeeByIDQuery, + Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, + Description: "Fetch employee by id.", + }, + "update employee tag": { + Name: updateTagOpName, + Body: updateTagMutation, + Kind: yokov1.OperationKind_OPERATION_KIND_MUTATION, + Description: "Update employee tag.", + }, + }, + } +} + +func startFakeCodeModeYoko(t *testing.T, svc *fakeCodeModeYoko) *httptest.Server { + t.Helper() + path, handler := yokoconnect.NewYokoServiceHandler(svc) + mux := http.NewServeMux() + mux.Handle(path, handler) + ports := freeport.GetN(t, 1) + listener, err := net.Listen("tcp4", fmt.Sprintf("127.0.0.1:%d", ports[0])) + require.NoError(t, err) + server := httptest.NewUnstartedServer(mux) + server.Listener = listener + server.Start() + t.Cleanup(server.Close) + return server +} + +func (f *fakeCodeModeYoko) Index(_ context.Context, req *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { + f.mu.Lock() + defer f.mu.Unlock() + f.indexCounter++ + f.indexRequestLog = append(f.indexRequestLog, &yokov1.IndexRequest{SchemaSdl: req.Msg.GetSchemaSdl()}) + return connect.NewResponse(&yokov1.IndexResponse{SchemaId: fmt.Sprintf("schema-%d", f.indexCounter)}), nil +} + +func (f *fakeCodeModeYoko) Search(_ context.Context, req *connect.Request[yokov1.SearchRequest]) (*connect.Response[yokov1.SearchResponse], error) { + f.mu.Lock() + defer f.mu.Unlock() + f.searchRequestLog = append(f.searchRequestLog, &yokov1.SearchRequest{ + Prompts: append([]string(nil), req.Msg.GetPrompts()...), + SchemaId: req.Msg.GetSchemaId(), + SessionId: req.Msg.GetSessionId(), + }) + ops := make([]*yokov1.GeneratedOperation, 0, len(req.Msg.GetPrompts())) + for _, prompt := range req.Msg.GetPrompts() { + if op := f.opsByPrompt[prompt]; op != nil { + ops = append(ops, op) + } + } + return connect.NewResponse(&yokov1.SearchResponse{Operations: ops}), nil +} + +func (f *fakeCodeModeYoko) indexRequests() []*yokov1.IndexRequest { + f.mu.Lock() + defer f.mu.Unlock() + out := make([]*yokov1.IndexRequest, 0, len(f.indexRequestLog)) + for _, req := range f.indexRequestLog { + out = append(out, &yokov1.IndexRequest{SchemaSdl: req.GetSchemaSdl()}) + } + return out +} + +func (f *fakeCodeModeYoko) searchRequests() []*yokov1.SearchRequest { + f.mu.Lock() + defer f.mu.Unlock() + out := make([]*yokov1.SearchRequest, 0, len(f.searchRequestLog)) + for _, req := range f.searchRequestLog { + out = append(out, &yokov1.SearchRequest{ + Prompts: append([]string(nil), req.GetPrompts()...), + SchemaId: req.GetSchemaId(), + SessionId: req.GetSessionId(), + }) + } + return out +} + +type codeModeConfigPoller struct { + initConfig *nodev1.RouterConfig + updateConfig func(newConfig *nodev1.RouterConfig, oldVersion string) error + ready chan struct{} + once sync.Once +} + +func (c *codeModeConfigPoller) Subscribe(_ context.Context, handler func(newConfig *nodev1.RouterConfig, oldVersion string) error) { + c.updateConfig = handler + c.once.Do(func() { close(c.ready) }) +} + +func (c *codeModeConfigPoller) GetRouterConfig(_ context.Context) (*routerconfig.Response, error) { + return &routerconfig.Response{Config: c.initConfig}, nil +} + +func (c *codeModeConfigPoller) Stop(_ context.Context) error { + return nil +} diff --git a/router-tests/go.mod b/router-tests/go.mod index 862856a1ae..e08d1f9168 100644 --- a/router-tests/go.mod +++ b/router-tests/go.mod @@ -5,6 +5,7 @@ go 1.25.0 require ( connectrpc.com/connect v1.19.1 github.com/MicahParks/jwkset v0.11.0 + github.com/alicebob/miniredis/v2 v2.34.0 github.com/buger/jsonparser v1.1.2 github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0 github.com/golang-jwt/jwt/v5 v5.3.0 @@ -51,6 +52,7 @@ require ( github.com/KimMachineGun/automemlimit v0.6.1 // indirect github.com/MicahParks/keyfunc/v3 v3.6.2 // indirect github.com/agnivade/levenshtein v1.2.1 // indirect + github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect github.com/andybalholm/brotli v1.1.0 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/benbjohnson/clock v1.3.0 // indirect @@ -74,7 +76,9 @@ require ( github.com/docker/docker-credential-helpers v0.9.3 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect + github.com/evanw/esbuild v0.27.3 // indirect github.com/expr-lang/expr v1.17.7 // indirect + github.com/fastschema/qjs v0.0.6 // indirect github.com/fatih/color v1.18.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-chi/chi/v5 v5.2.2 // indirect @@ -146,6 +150,8 @@ require ( github.com/sosodev/duration v1.3.1 // indirect github.com/spf13/cast v1.7.1 // indirect github.com/stretchr/objx v0.5.2 // indirect + github.com/tdewolff/parse/v2 v2.8.12 // indirect + github.com/tetratelabs/wazero v1.9.0 // indirect github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect @@ -161,6 +167,7 @@ require ( github.com/wundergraph/go-arena v1.1.0 // indirect github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + github.com/yuin/gopher-lua v1.1.1 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 // indirect go.opentelemetry.io/contrib/propagators/b3 v1.23.0 // indirect diff --git a/router-tests/go.sum b/router-tests/go.sum index 28e214572e..85128dd4c4 100644 --- a/router-tests/go.sum +++ b/router-tests/go.sum @@ -87,8 +87,12 @@ github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4 github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/evanw/esbuild v0.27.3 h1:dH/to9tBKybig6hl25hg4SKIWP7U8COdJKbGEwnUkmU= +github.com/evanw/esbuild v0.27.3/go.mod h1:D2vIQZqV/vIf/VRHtViaUtViZmG7o+kKmlBfVQuRi48= github.com/expr-lang/expr v1.17.7 h1:Q0xY/e/2aCIp8g9s/LGvMDCC5PxYlvHgDZRQ4y16JX8= github.com/expr-lang/expr v1.17.7/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4= +github.com/fastschema/qjs v0.0.6 h1:C45KMmQMd21UwsUAmQHxUxiWOfzwTg1GJW0DA0AbFEE= +github.com/fastschema/qjs v0.0.6/go.mod h1:bbg36wxXnx8g0FdKIe5+nCubrQvHa7XEVWqUptjHt/A= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= @@ -333,6 +337,12 @@ github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/tdewolff/parse/v2 v2.8.12 h1:5BBjfaCv482v3nltlS0u6wH1xJaxjR6ofDrWttNvROg= +github.com/tdewolff/parse/v2 v2.8.12/go.mod h1:Hwlni2tiVNKyzR1o6nUs4FOF07URA+JLBLd6dlIXYqo= +github.com/tdewolff/test v1.0.11 h1:FdLbwQVHxqG16SlkGveC0JVyrJN62COWTRyUFzfbtBE= +github.com/tdewolff/test v1.0.11/go.mod h1:XPuWBzvdUzhCuxWO1ojpXsyzsA5bFoS3tO/Q3kFuTG8= +github.com/tetratelabs/wazero v1.9.0 h1:IcZ56OuxrtaEz8UYNRHBrUa9bYeX9oVY93KspZZBf/I= +github.com/tetratelabs/wazero v1.9.0/go.mod h1:TSbcXCfFP0L2FGkRPxHphadXPjo1T6W+CseNNY7EkjM= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= diff --git a/router-tests/testenv/testenv.go b/router-tests/testenv/testenv.go index 8f3c0a1a21..edca6036c8 100644 --- a/router-tests/testenv/testenv.go +++ b/router-tests/testenv/testenv.go @@ -367,6 +367,10 @@ type Config struct { MCP config.MCPConfiguration MCPOperationsPath string MCPAuthToken string // Optional Bearer token for MCP authentication + // CodeModeRedisURL, when paired with MCP.CodeMode.NamedOps.Storage.ProviderID, + // registers a Redis storage provider with that ID so the named-ops backend can + // resolve it from the central provider registry. + CodeModeRedisURL string EnableRedis bool EnableRedisCluster bool Plugins PluginConfig @@ -1520,14 +1524,23 @@ func configureRouter(listenerAddr string, testConfig *Config, routerConfig *node if testConfig.MCPOperationsPath != "" { mcpOperationsPath = testConfig.MCPOperationsPath } - routerOpts = append(routerOpts, core.WithStorageProviders(config.StorageProviders{ + storageProviders := config.StorageProviders{ FileSystem: []config.FileSystemStorageProvider{ { ID: "test", Path: mcpOperationsPath, }, }, - })) + } + // Append a Redis provider for code mode named ops when the test set a + // provider_id and supplied a URL via CodeModeRedisURL. + if id := testConfig.MCP.CodeMode.NamedOps.Storage.ProviderID; id != "" && testConfig.CodeModeRedisURL != "" { + storageProviders.Redis = append(storageProviders.Redis, config.RedisStorageProvider{ + ID: id, + URLs: []string{testConfig.CodeModeRedisURL}, + }) + } + routerOpts = append(routerOpts, core.WithStorageProviders(storageProviders)) testConfig.MCP.Storage.ProviderID = "test" diff --git a/router/core/graph_server.go b/router/core/graph_server.go index 907d5abfbb..830b29e5a5 100644 --- a/router/core/graph_server.go +++ b/router/core/graph_server.go @@ -59,6 +59,7 @@ import ( rtrace "github.com/wundergraph/cosmo/router/pkg/trace" "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astprinter" ) const ( @@ -1399,6 +1400,14 @@ func (s *graphServer) buildGraphMux( return nil, fmt.Errorf("failed to reload MCP server: %w", mErr) } } + if opts.IsBaseGraph() && s.codeModeServer != nil { + sdl, printErr := astprinter.PrintString(executor.ClientSchema) + if printErr != nil { + s.logger.Error("failed to reload MCP server", zap.Error(fmt.Errorf("failed to print Code Mode schema SDL: %w", printErr))) + } else if mErr := s.codeModeServer.Reload(executor.ClientSchema, sdl); mErr != nil { + s.logger.Error("failed to reload MCP server", zap.Error(mErr)) + } + } if s.cacheWarmup != nil && s.cacheWarmup.Enabled { processor := NewCacheWarmupPlanningProcessor(&CacheWarmupPlanningProcessorOptions{ diff --git a/router/core/router.go b/router/core/router.go index cb173417d3..01f5b16601 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -27,6 +27,7 @@ import ( "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/graphqlmetrics/v1/graphqlmetricsv1connect" nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" "github.com/wundergraph/cosmo/router/internal/circuit" + codemodeserver "github.com/wundergraph/cosmo/router/internal/codemode/server" "github.com/wundergraph/cosmo/router/internal/debug" "github.com/wundergraph/cosmo/router/internal/docker" "github.com/wundergraph/cosmo/router/internal/exporter" @@ -947,6 +948,9 @@ func (r *Router) bootstrap(ctx context.Context) error { if err := r.startMCPServer(ctx); err != nil { return err } + if err := r.startCodeModeServer(ctx); err != nil { + return err + } if r.connectRPC.Enabled { r.logger.Debug("ConnectRPC configuration", @@ -1067,87 +1071,241 @@ func (r *Router) bootstrap(ctx context.Context) error { return nil } -// startMCPServer initializes and starts the MCP server if enabled. +// startMCPServer initializes and starts the MCP server(s) if enabled. +// +// When mcp.servers is configured, each entry is mounted as its own collection on +// the shared HTTP listener. When absent, a single implicit collection is synthesized +// from the legacy top-level mcp.* fields and mounted at "/mcp" — preserving +// backwards compatibility. func (r *Router) startMCPServer(ctx context.Context) error { if !r.mcp.Enabled { return nil } - var operationsDir string + entries, err := r.mcp.NormalizeServers() + if err != nil { + return fmt.Errorf("invalid mcp configuration: %w", err) + } + + defaultGraphQLEndpoint := r.graphqlEndpointURL + if r.mcp.RouterURL != "" { + defaultGraphQLEndpoint = r.mcp.RouterURL + } + + handlers := make([]*mcpserver.GraphQLSchemaServer, 0, len(entries)) + for i := range entries { + entry := &entries[i] + handler, hErr := r.buildMCPHandler(ctx, entry, defaultGraphQLEndpoint) + if hErr != nil { + return fmt.Errorf("mcp server %q: %w", entry.Name, hErr) + } + handlers = append(handlers, handler) + } + + multi, err := mcpserver.NewMultiServer(r.mcp.Server.ListenAddr, r.logger, handlers...) + if err != nil { + return fmt.Errorf("failed to create mcp multi-server: %w", err) + } + + if err := multi.Start(); err != nil { + if stopErr := multi.Stop(ctx); stopErr != nil { + r.logger.Warn("Failed to stop MCP server during error cleanup", zap.Error(stopErr)) + } + return fmt.Errorf("failed to start MCP server: %w", err) + } - // If storage provider ID is set, resolve it to a directory path - if r.mcp.Storage.ProviderID != "" { - r.logger.Debug("Resolving storage provider for MCP operations", - zap.String("provider_id", r.mcp.Storage.ProviderID)) + r.mcpServer = multi + return nil +} + +// loadOrIntrospectUpstreamSDL resolves the SDL for an upstream-bound MCP collection. +// Resolution order: +// 1. If schema.file is configured and exists on disk, load it. +// 2. Else introspect the upstream at startup. If schema.file is configured (but missing), +// write the result to disk so subsequent runs are deterministic and offline-safe. +func loadOrIntrospectUpstreamSDL(ctx context.Context, entry *config.MCPServerEntry, logger *zap.Logger) (string, error) { + filePath := entry.Upstream.Schema.File + if filePath != "" { + data, err := os.ReadFile(filePath) + if err == nil { + return string(data), nil + } + if !errors.Is(err, os.ErrNotExist) { + return "", fmt.Errorf("read upstream SDL file %q: %w", filePath, err) + } + } + + logger.Info("introspecting upstream GraphQL endpoint", + zap.String("mcp_server_name", entry.Name), + zap.String("url", entry.Upstream.URL)) + + sdl, err := mcpserver.IntrospectUpstreamSDL(ctx, entry.Upstream.URL, entry.Upstream.Headers) + if err != nil { + return "", fmt.Errorf("introspect upstream %s: %w", entry.Upstream.URL, err) + } - provider, ok := r.providerRegistry.FileSystem(r.mcp.Storage.ProviderID) + if filePath != "" { + if writeErr := os.WriteFile(filePath, []byte(sdl), 0o644); writeErr != nil { + logger.Warn("failed to cache introspected SDL to disk; continuing in-memory only", + zap.String("mcp_server_name", entry.Name), + zap.String("file", filePath), + zap.Error(writeErr)) + } else { + logger.Info("cached introspected SDL", + zap.String("mcp_server_name", entry.Name), + zap.String("file", filePath)) + } + } + + return sdl, nil +} + +// buildMCPHandler constructs a single per-collection MCP handler from an MCPServerEntry. +// Supergraph-bound collections (no Upstream) route to defaultGraphQLEndpoint; upstream-bound +// collections route to the configured Upstream.URL and load their schema from the SDL file. +func (r *Router) buildMCPHandler(ctx context.Context, entry *config.MCPServerEntry, defaultGraphQLEndpoint string) (*mcpserver.GraphQLSchemaServer, error) { + var operationsDir string + if entry.Storage.ProviderID != "" { + provider, ok := r.providerRegistry.FileSystem(entry.Storage.ProviderID) if !ok { - return fmt.Errorf("storage provider with id '%s' for mcp server not found", r.mcp.Storage.ProviderID) + return nil, fmt.Errorf("storage provider with id %q not found", entry.Storage.ProviderID) } - r.logger.Debug("Found file_system storage provider for MCP", - zap.String("id", provider.ID), - zap.String("path", provider.Path)) operationsDir = provider.Path } + endpoint := defaultGraphQLEndpoint + var upstreamSDL string + var upstreamHeaders map[string]string + if entry.Upstream != nil { + endpoint = entry.Upstream.URL + upstreamHeaders = entry.Upstream.Headers + sdl, err := loadOrIntrospectUpstreamSDL(ctx, entry, r.logger) + if err != nil { + return nil, err + } + upstreamSDL = sdl + } + logFields := []zap.Field{ - zap.String("storage_provider_id", r.mcp.Storage.ProviderID), + zap.String("mcp_server_name", entry.Name), + zap.String("mcp_server_path", entry.Path), + zap.String("storage_provider_id", entry.Storage.ProviderID), } - // Initialize the MCP server with the resolved operations directory mcpOpts := []func(*mcpserver.Options){ - mcpserver.WithGraphName(r.mcp.GraphName), + mcpserver.WithGraphName(entry.Name), + mcpserver.WithPath(entry.Path), mcpserver.WithOperationsDir(operationsDir), mcpserver.WithListenAddr(r.mcp.Server.ListenAddr), mcpserver.WithLogger(r.logger.With(logFields...)), - mcpserver.WithExcludeMutations(r.mcp.ExcludeMutations), - mcpserver.WithEnableArbitraryOperations(r.mcp.EnableArbitraryOperations), - mcpserver.WithExposeSchema(r.mcp.ExposeSchema), - mcpserver.WithOmitToolNamePrefix(r.mcp.OmitToolNamePrefix), - mcpserver.WithStateless(r.mcp.Session.Stateless), + mcpserver.WithExcludeMutations(entry.ExcludeMutations), + mcpserver.WithEnableArbitraryOperations(entry.EnableArbitraryOperations), + mcpserver.WithExposeSchema(entry.ExposeSchema), + mcpserver.WithOmitToolNamePrefix(entry.OmitToolNamePrefix), + mcpserver.WithStateless(entry.Session.Stateless), + } + + if upstreamSDL != "" { + mcpOpts = append(mcpOpts, mcpserver.WithUpstreamSchemaSDL(upstreamSDL)) + } + if len(upstreamHeaders) > 0 { + mcpOpts = append(mcpOpts, mcpserver.WithUpstreamHeaders(upstreamHeaders)) + } + if entry.Storage.Watch && operationsDir != "" { + interval := entry.Storage.WatchInterval + if interval <= 0 { + interval = time.Second + } + mcpOpts = append(mcpOpts, mcpserver.WithWatchOperations(true, interval)) } if r.corsOptions != nil { mcpOpts = append(mcpOpts, mcpserver.WithCORS(*r.corsOptions)) } - // Add OAuth configuration if enabled - if r.mcp.OAuth.Enabled { - mcpOpts = append(mcpOpts, mcpserver.WithOAuth(&r.mcp.OAuth)) - + if entry.OAuth.Enabled { + oauthCfg := entry.OAuth + mcpOpts = append(mcpOpts, mcpserver.WithOAuth(&oauthCfg)) if r.mcp.Server.BaseURL != "" { mcpOpts = append(mcpOpts, mcpserver.WithServerBaseURL(r.mcp.Server.BaseURL)) } } - if r.mcp.ResourceDocumentation != "" { - mcpOpts = append(mcpOpts, mcpserver.WithResourceDocumentation(r.mcp.ResourceDocumentation)) + if entry.ResourceDocumentation != "" { + mcpOpts = append(mcpOpts, mcpserver.WithResourceDocumentation(entry.ResourceDocumentation)) } - mcpGraphQLEndpoint := r.graphqlEndpointURL - if r.mcp.RouterURL != "" { - mcpGraphQLEndpoint = r.mcp.RouterURL + return mcpserver.NewGraphQLSchemaServer(ctx, endpoint, mcpOpts...) +} + +// startCodeModeServer initializes and starts the separate Code Mode MCP server if enabled. +func (r *Router) startCodeModeServer(ctx context.Context) error { + var redisProvider *config.RedisStorageProvider + if r.mcp.CodeMode.Enabled && r.mcp.CodeMode.NamedOps.Enabled { + if providerID := r.mcp.CodeMode.NamedOps.Storage.ProviderID; providerID != "" { + provider, ok := r.providerRegistry.Redis(providerID) + if !ok { + return fmt.Errorf("redis storage provider with id '%s' for mcp code_mode named_ops not found", providerID) + } + redisProvider = &provider + } } - mcpss, err := mcpserver.NewGraphQLSchemaServer( - ctx, - mcpGraphQLEndpoint, - mcpOpts..., - ) + cm, err := codemodeserver.BuildFromConfig(codemodeserver.BuildOptions{ + Config: r.mcp.CodeMode, + SessionStateless: r.mcp.Session.Stateless, + RouterGraphQLURL: r.graphqlEndpointURL, + Logger: r.logger, + TracerProvider: r.tracerProvider, + MeterProvider: r.otlpMeterProvider, + RedisProvider: redisProvider, + RedisFactory: func(opts *rd.RedisCloserOptions) (rd.RDCloser, error) { + if opts.Logger == nil { + opts.Logger = r.logger + } + return rd.NewRedisCloser(opts) + }, + }) if err != nil { - return fmt.Errorf("failed to create mcp server: %w", err) + return fmt.Errorf("failed to create code mode MCP server: %w", err) } + r.codeModeServer = cm - if err := mcpss.Start(); err != nil { - // Cleanup the server if Start() fails to prevent resource leaks - if stopErr := mcpss.Stop(ctx); stopErr != nil { - r.logger.Warn("Failed to stop MCP server during error cleanup", zap.Error(stopErr)) - } - return fmt.Errorf("failed to start MCP server: %w", err) + if !r.mcp.CodeMode.Enabled { + return nil } - r.mcpServer = mcpss - return nil + errs := make(chan error, 1) + go func() { + errs <- cm.Start(ctx) + }() + + deadline := time.NewTimer(5 * time.Second) + defer deadline.Stop() + tick := time.NewTicker(10 * time.Millisecond) + defer tick.Stop() + for { + select { + case err := <-errs: + if err != nil { + return fmt.Errorf("failed to start code mode MCP server: %w", err) + } + return nil + case <-ctx.Done(): + return ctx.Err() + case <-deadline.C: + return fmt.Errorf("failed to start code mode MCP server: listener was not bound") + case <-tick.C: + if cm.Addr() != "" { + go func() { + if err := <-errs; err != nil { + r.logger.Error("Code Mode MCP server stopped unexpectedly", zap.Error(err)) + } + }() + return nil + } + } + } } // buildClients initializes the storage clients for persisted operations and router config. @@ -1722,6 +1880,14 @@ func (r *Router) Shutdown(ctx context.Context) error { }) } + if r.codeModeServer != nil { + wg.Go(func() { + if subErr := r.codeModeServer.Stop(ctx); subErr != nil { + err.Append(fmt.Errorf("failed to shutdown code mode MCP server: %w", subErr)) + } + }) + } + if r.connectRPCServer != nil { wg.Go(func() { if subErr := r.connectRPCServer.Stop(ctx); subErr != nil { diff --git a/router/core/router_config.go b/router/core/router_config.go index 9f4b0bf84c..4f8913c1c2 100644 --- a/router/core/router_config.go +++ b/router/core/router_config.go @@ -6,6 +6,7 @@ import ( "time" nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" + codemodeserver "github.com/wundergraph/cosmo/router/internal/codemode/server" "github.com/wundergraph/cosmo/router/internal/graphqlmetrics" "github.com/wundergraph/cosmo/router/internal/persistedoperation" "github.com/wundergraph/cosmo/router/internal/persistedoperation/pqlmanifest" @@ -112,7 +113,8 @@ type Config struct { accessController *AccessController retryOptions retrytransport.RetryOptions redisClient rd.RDCloser - mcpServer *mcpserver.GraphQLSchemaServer + mcpServer *mcpserver.MultiServer + codeModeServer *codemodeserver.Server connectRPCServer *connectrpc.Server processStartTime time.Time developmentMode bool diff --git a/router/gen/proto/wg/cosmo/code_mode/yoko/v1/yoko.pb.go b/router/gen/proto/wg/cosmo/code_mode/yoko/v1/yoko.pb.go new file mode 100644 index 0000000000..c1fb97bbbd --- /dev/null +++ b/router/gen/proto/wg/cosmo/code_mode/yoko/v1/yoko.pb.go @@ -0,0 +1,451 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.10 +// protoc (unknown) +// source: wg/cosmo/code_mode/yoko/v1/yoko.proto + +package yokov1 + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type OperationKind int32 + +const ( + OperationKind_OPERATION_KIND_UNSPECIFIED OperationKind = 0 + OperationKind_OPERATION_KIND_QUERY OperationKind = 1 + OperationKind_OPERATION_KIND_MUTATION OperationKind = 2 +) + +// Enum value maps for OperationKind. +var ( + OperationKind_name = map[int32]string{ + 0: "OPERATION_KIND_UNSPECIFIED", + 1: "OPERATION_KIND_QUERY", + 2: "OPERATION_KIND_MUTATION", + } + OperationKind_value = map[string]int32{ + "OPERATION_KIND_UNSPECIFIED": 0, + "OPERATION_KIND_QUERY": 1, + "OPERATION_KIND_MUTATION": 2, + } +) + +func (x OperationKind) Enum() *OperationKind { + p := new(OperationKind) + *p = x + return p +} + +func (x OperationKind) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (OperationKind) Descriptor() protoreflect.EnumDescriptor { + return file_wg_cosmo_code_mode_yoko_v1_yoko_proto_enumTypes[0].Descriptor() +} + +func (OperationKind) Type() protoreflect.EnumType { + return &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_enumTypes[0] +} + +func (x OperationKind) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use OperationKind.Descriptor instead. +func (OperationKind) EnumDescriptor() ([]byte, []int) { + return file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescGZIP(), []int{0} +} + +type IndexRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // The supergraph SDL to index. Sent in full on every Index call; + // Yoko deduplicates internally and is free to short-circuit when + // the SDL is already known. + SchemaSdl string `protobuf:"bytes,1,opt,name=schema_sdl,json=schemaSdl,proto3" json:"schema_sdl,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *IndexRequest) Reset() { + *x = IndexRequest{} + mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *IndexRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*IndexRequest) ProtoMessage() {} + +func (x *IndexRequest) ProtoReflect() protoreflect.Message { + mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use IndexRequest.ProtoReflect.Descriptor instead. +func (*IndexRequest) Descriptor() ([]byte, []int) { + return file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescGZIP(), []int{0} +} + +func (x *IndexRequest) GetSchemaSdl() string { + if x != nil { + return x.SchemaSdl + } + return "" +} + +type IndexResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Opaque, Yoko-assigned identifier for this schema. Stable for as + // long as Yoko retains the index. Subsequent Search calls pass this + // back instead of the full SDL. Idempotent: the same SDL returns + // the same schema_id. + SchemaId string `protobuf:"bytes,1,opt,name=schema_id,json=schemaId,proto3" json:"schema_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *IndexResponse) Reset() { + *x = IndexResponse{} + mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *IndexResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*IndexResponse) ProtoMessage() {} + +func (x *IndexResponse) ProtoReflect() protoreflect.Message { + mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use IndexResponse.ProtoReflect.Descriptor instead. +func (*IndexResponse) Descriptor() ([]byte, []int) { + return file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescGZIP(), []int{1} +} + +func (x *IndexResponse) GetSchemaId() string { + if x != nil { + return x.SchemaId + } + return "" +} + +type SearchRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Batch of natural-language prompts. Bounded at 20 by the host. + Prompts []string `protobuf:"bytes,1,rep,name=prompts,proto3" json:"prompts,omitempty"` + // Identifier returned by a prior Index call. If Yoko no longer + // recognizes the id (e.g. eviction, restart), it MUST return the + // Connect error code NOT_FOUND; the router re-indexes and retries + // the call exactly once. + SchemaId string `protobuf:"bytes,2,opt,name=schema_id,json=schemaId,proto3" json:"schema_id,omitempty"` + // Opaque MCP session ID for telemetry correlation only. + // Yoko MUST NOT use this for stateful behavior — sessions are owned + // by the router. + SessionId string `protobuf:"bytes,3,opt,name=session_id,json=sessionId,proto3" json:"session_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SearchRequest) Reset() { + *x = SearchRequest{} + mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SearchRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SearchRequest) ProtoMessage() {} + +func (x *SearchRequest) ProtoReflect() protoreflect.Message { + mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SearchRequest.ProtoReflect.Descriptor instead. +func (*SearchRequest) Descriptor() ([]byte, []int) { + return file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescGZIP(), []int{2} +} + +func (x *SearchRequest) GetPrompts() []string { + if x != nil { + return x.Prompts + } + return nil +} + +func (x *SearchRequest) GetSchemaId() string { + if x != nil { + return x.SchemaId + } + return "" +} + +func (x *SearchRequest) GetSessionId() string { + if x != nil { + return x.SessionId + } + return "" +} + +type SearchResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Operations across all prompts, already deduplicated and ranked. + // Order is significant: earlier entries rank higher and are preferred + // when bundle truncation drops from the tail. + Operations []*GeneratedOperation `protobuf:"bytes,1,rep,name=operations,proto3" json:"operations,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SearchResponse) Reset() { + *x = SearchResponse{} + mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SearchResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SearchResponse) ProtoMessage() {} + +func (x *SearchResponse) ProtoReflect() protoreflect.Message { + mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SearchResponse.ProtoReflect.Descriptor instead. +func (*SearchResponse) Descriptor() ([]byte, []int) { + return file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescGZIP(), []int{3} +} + +func (x *SearchResponse) GetOperations() []*GeneratedOperation { + if x != nil { + return x.Operations + } + return nil +} + +type GeneratedOperation struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Suggested operation name (camelCase preferred). The host applies + // its own identifier normalization and in-session collision-suffix + // logic on top of this — see §6. + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + // GraphQL operation body (query or mutation source text). + Body string `protobuf:"bytes,2,opt,name=body,proto3" json:"body,omitempty"` + // Operation kind. Subscriptions are out of scope; if Yoko returns + // one, the host drops it with a single warn log. + Kind OperationKind `protobuf:"varint,3,opt,name=kind,proto3,enum=wundergraph.cosmo.code_mode.yoko.v1.OperationKind" json:"kind,omitempty"` + // Human-readable description, surfaced as JSDoc on the typed + // `tools.` signature in the rendered bundle. + Description string `protobuf:"bytes,4,opt,name=description,proto3" json:"description,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GeneratedOperation) Reset() { + *x = GeneratedOperation{} + mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GeneratedOperation) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GeneratedOperation) ProtoMessage() {} + +func (x *GeneratedOperation) ProtoReflect() protoreflect.Message { + mi := &file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GeneratedOperation.ProtoReflect.Descriptor instead. +func (*GeneratedOperation) Descriptor() ([]byte, []int) { + return file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescGZIP(), []int{4} +} + +func (x *GeneratedOperation) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +func (x *GeneratedOperation) GetBody() string { + if x != nil { + return x.Body + } + return "" +} + +func (x *GeneratedOperation) GetKind() OperationKind { + if x != nil { + return x.Kind + } + return OperationKind_OPERATION_KIND_UNSPECIFIED +} + +func (x *GeneratedOperation) GetDescription() string { + if x != nil { + return x.Description + } + return "" +} + +var File_wg_cosmo_code_mode_yoko_v1_yoko_proto protoreflect.FileDescriptor + +const file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDesc = "" + + "\n" + + "%wg/cosmo/code_mode/yoko/v1/yoko.proto\x12#wundergraph.cosmo.code_mode.yoko.v1\"-\n" + + "\fIndexRequest\x12\x1d\n" + + "\n" + + "schema_sdl\x18\x01 \x01(\tR\tschemaSdl\",\n" + + "\rIndexResponse\x12\x1b\n" + + "\tschema_id\x18\x01 \x01(\tR\bschemaId\"e\n" + + "\rSearchRequest\x12\x18\n" + + "\aprompts\x18\x01 \x03(\tR\aprompts\x12\x1b\n" + + "\tschema_id\x18\x02 \x01(\tR\bschemaId\x12\x1d\n" + + "\n" + + "session_id\x18\x03 \x01(\tR\tsessionId\"i\n" + + "\x0eSearchResponse\x12W\n" + + "\n" + + "operations\x18\x01 \x03(\v27.wundergraph.cosmo.code_mode.yoko.v1.GeneratedOperationR\n" + + "operations\"\xa6\x01\n" + + "\x12GeneratedOperation\x12\x12\n" + + "\x04name\x18\x01 \x01(\tR\x04name\x12\x12\n" + + "\x04body\x18\x02 \x01(\tR\x04body\x12F\n" + + "\x04kind\x18\x03 \x01(\x0e22.wundergraph.cosmo.code_mode.yoko.v1.OperationKindR\x04kind\x12 \n" + + "\vdescription\x18\x04 \x01(\tR\vdescription*f\n" + + "\rOperationKind\x12\x1e\n" + + "\x1aOPERATION_KIND_UNSPECIFIED\x10\x00\x12\x18\n" + + "\x14OPERATION_KIND_QUERY\x10\x01\x12\x1b\n" + + "\x17OPERATION_KIND_MUTATION\x10\x022\xf0\x01\n" + + "\vYokoService\x12n\n" + + "\x05Index\x121.wundergraph.cosmo.code_mode.yoko.v1.IndexRequest\x1a2.wundergraph.cosmo.code_mode.yoko.v1.IndexResponse\x12q\n" + + "\x06Search\x122.wundergraph.cosmo.code_mode.yoko.v1.SearchRequest\x1a3.wundergraph.cosmo.code_mode.yoko.v1.SearchResponseB\xb2\x02\n" + + "'com.wundergraph.cosmo.code_mode.yoko.v1B\tYokoProtoP\x01ZOgithub.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1;yokov1\xa2\x02\x04WCCY\xaa\x02\"Wundergraph.Cosmo.CodeMode.Yoko.V1\xca\x02\"Wundergraph\\Cosmo\\CodeMode\\Yoko\\V1\xe2\x02.Wundergraph\\Cosmo\\CodeMode\\Yoko\\V1\\GPBMetadata\xea\x02&Wundergraph::Cosmo::CodeMode::Yoko::V1b\x06proto3" + +var ( + file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescOnce sync.Once + file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescData []byte +) + +func file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescGZIP() []byte { + file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescOnce.Do(func() { + file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDesc), len(file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDesc))) + }) + return file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDescData +} + +var file_wg_cosmo_code_mode_yoko_v1_yoko_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes = make([]protoimpl.MessageInfo, 5) +var file_wg_cosmo_code_mode_yoko_v1_yoko_proto_goTypes = []any{ + (OperationKind)(0), // 0: wundergraph.cosmo.code_mode.yoko.v1.OperationKind + (*IndexRequest)(nil), // 1: wundergraph.cosmo.code_mode.yoko.v1.IndexRequest + (*IndexResponse)(nil), // 2: wundergraph.cosmo.code_mode.yoko.v1.IndexResponse + (*SearchRequest)(nil), // 3: wundergraph.cosmo.code_mode.yoko.v1.SearchRequest + (*SearchResponse)(nil), // 4: wundergraph.cosmo.code_mode.yoko.v1.SearchResponse + (*GeneratedOperation)(nil), // 5: wundergraph.cosmo.code_mode.yoko.v1.GeneratedOperation +} +var file_wg_cosmo_code_mode_yoko_v1_yoko_proto_depIdxs = []int32{ + 5, // 0: wundergraph.cosmo.code_mode.yoko.v1.SearchResponse.operations:type_name -> wundergraph.cosmo.code_mode.yoko.v1.GeneratedOperation + 0, // 1: wundergraph.cosmo.code_mode.yoko.v1.GeneratedOperation.kind:type_name -> wundergraph.cosmo.code_mode.yoko.v1.OperationKind + 1, // 2: wundergraph.cosmo.code_mode.yoko.v1.YokoService.Index:input_type -> wundergraph.cosmo.code_mode.yoko.v1.IndexRequest + 3, // 3: wundergraph.cosmo.code_mode.yoko.v1.YokoService.Search:input_type -> wundergraph.cosmo.code_mode.yoko.v1.SearchRequest + 2, // 4: wundergraph.cosmo.code_mode.yoko.v1.YokoService.Index:output_type -> wundergraph.cosmo.code_mode.yoko.v1.IndexResponse + 4, // 5: wundergraph.cosmo.code_mode.yoko.v1.YokoService.Search:output_type -> wundergraph.cosmo.code_mode.yoko.v1.SearchResponse + 4, // [4:6] is the sub-list for method output_type + 2, // [2:4] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name +} + +func init() { file_wg_cosmo_code_mode_yoko_v1_yoko_proto_init() } +func file_wg_cosmo_code_mode_yoko_v1_yoko_proto_init() { + if File_wg_cosmo_code_mode_yoko_v1_yoko_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDesc), len(file_wg_cosmo_code_mode_yoko_v1_yoko_proto_rawDesc)), + NumEnums: 1, + NumMessages: 5, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_wg_cosmo_code_mode_yoko_v1_yoko_proto_goTypes, + DependencyIndexes: file_wg_cosmo_code_mode_yoko_v1_yoko_proto_depIdxs, + EnumInfos: file_wg_cosmo_code_mode_yoko_v1_yoko_proto_enumTypes, + MessageInfos: file_wg_cosmo_code_mode_yoko_v1_yoko_proto_msgTypes, + }.Build() + File_wg_cosmo_code_mode_yoko_v1_yoko_proto = out.File + file_wg_cosmo_code_mode_yoko_v1_yoko_proto_goTypes = nil + file_wg_cosmo_code_mode_yoko_v1_yoko_proto_depIdxs = nil +} diff --git a/router/gen/proto/wg/cosmo/code_mode/yoko/v1/yokov1connect/yoko.connect.go b/router/gen/proto/wg/cosmo/code_mode/yoko/v1/yokov1connect/yoko.connect.go new file mode 100644 index 0000000000..1e157644aa --- /dev/null +++ b/router/gen/proto/wg/cosmo/code_mode/yoko/v1/yokov1connect/yoko.connect.go @@ -0,0 +1,142 @@ +// Code generated by protoc-gen-connect-go. DO NOT EDIT. +// +// Source: wg/cosmo/code_mode/yoko/v1/yoko.proto + +package yokov1connect + +import ( + connect "connectrpc.com/connect" + context "context" + errors "errors" + v1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1" + http "net/http" + strings "strings" +) + +// This is a compile-time assertion to ensure that this generated file and the connect package are +// compatible. If you get a compiler error that this constant is not defined, this code was +// generated with a version of connect newer than the one compiled into your binary. You can fix the +// problem by either regenerating this code with an older version of connect or updating the connect +// version compiled into your binary. +const _ = connect.IsAtLeastVersion1_13_0 + +const ( + // YokoServiceName is the fully-qualified name of the YokoService service. + YokoServiceName = "wundergraph.cosmo.code_mode.yoko.v1.YokoService" +) + +// These constants are the fully-qualified names of the RPCs defined in this package. They're +// exposed at runtime as Spec.Procedure and as the final two segments of the HTTP route. +// +// Note that these are different from the fully-qualified method names used by +// google.golang.org/protobuf/reflect/protoreflect. To convert from these constants to +// reflection-formatted method names, remove the leading slash and convert the remaining slash to a +// period. +const ( + // YokoServiceIndexProcedure is the fully-qualified name of the YokoService's Index RPC. + YokoServiceIndexProcedure = "/wundergraph.cosmo.code_mode.yoko.v1.YokoService/Index" + // YokoServiceSearchProcedure is the fully-qualified name of the YokoService's Search RPC. + YokoServiceSearchProcedure = "/wundergraph.cosmo.code_mode.yoko.v1.YokoService/Search" +) + +// These variables are the protoreflect.Descriptor objects for the RPCs defined in this package. +var ( + yokoServiceServiceDescriptor = v1.File_wg_cosmo_code_mode_yoko_v1_yoko_proto.Services().ByName("YokoService") + yokoServiceIndexMethodDescriptor = yokoServiceServiceDescriptor.Methods().ByName("Index") + yokoServiceSearchMethodDescriptor = yokoServiceServiceDescriptor.Methods().ByName("Search") +) + +// YokoServiceClient is a client for the wundergraph.cosmo.code_mode.yoko.v1.YokoService service. +type YokoServiceClient interface { + Index(context.Context, *connect.Request[v1.IndexRequest]) (*connect.Response[v1.IndexResponse], error) + Search(context.Context, *connect.Request[v1.SearchRequest]) (*connect.Response[v1.SearchResponse], error) +} + +// NewYokoServiceClient constructs a client for the wundergraph.cosmo.code_mode.yoko.v1.YokoService +// service. By default, it uses the Connect protocol with the binary Protobuf Codec, asks for +// gzipped responses, and sends uncompressed requests. To use the gRPC or gRPC-Web protocols, supply +// the connect.WithGRPC() or connect.WithGRPCWeb() options. +// +// The URL supplied here should be the base URL for the Connect or gRPC server (for example, +// http://api.acme.com or https://acme.com/grpc). +func NewYokoServiceClient(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) YokoServiceClient { + baseURL = strings.TrimRight(baseURL, "/") + return &yokoServiceClient{ + index: connect.NewClient[v1.IndexRequest, v1.IndexResponse]( + httpClient, + baseURL+YokoServiceIndexProcedure, + connect.WithSchema(yokoServiceIndexMethodDescriptor), + connect.WithClientOptions(opts...), + ), + search: connect.NewClient[v1.SearchRequest, v1.SearchResponse]( + httpClient, + baseURL+YokoServiceSearchProcedure, + connect.WithSchema(yokoServiceSearchMethodDescriptor), + connect.WithClientOptions(opts...), + ), + } +} + +// yokoServiceClient implements YokoServiceClient. +type yokoServiceClient struct { + index *connect.Client[v1.IndexRequest, v1.IndexResponse] + search *connect.Client[v1.SearchRequest, v1.SearchResponse] +} + +// Index calls wundergraph.cosmo.code_mode.yoko.v1.YokoService.Index. +func (c *yokoServiceClient) Index(ctx context.Context, req *connect.Request[v1.IndexRequest]) (*connect.Response[v1.IndexResponse], error) { + return c.index.CallUnary(ctx, req) +} + +// Search calls wundergraph.cosmo.code_mode.yoko.v1.YokoService.Search. +func (c *yokoServiceClient) Search(ctx context.Context, req *connect.Request[v1.SearchRequest]) (*connect.Response[v1.SearchResponse], error) { + return c.search.CallUnary(ctx, req) +} + +// YokoServiceHandler is an implementation of the wundergraph.cosmo.code_mode.yoko.v1.YokoService +// service. +type YokoServiceHandler interface { + Index(context.Context, *connect.Request[v1.IndexRequest]) (*connect.Response[v1.IndexResponse], error) + Search(context.Context, *connect.Request[v1.SearchRequest]) (*connect.Response[v1.SearchResponse], error) +} + +// NewYokoServiceHandler builds an HTTP handler from the service implementation. It returns the path +// on which to mount the handler and the handler itself. +// +// By default, handlers support the Connect, gRPC, and gRPC-Web protocols with the binary Protobuf +// and JSON codecs. They also support gzip compression. +func NewYokoServiceHandler(svc YokoServiceHandler, opts ...connect.HandlerOption) (string, http.Handler) { + yokoServiceIndexHandler := connect.NewUnaryHandler( + YokoServiceIndexProcedure, + svc.Index, + connect.WithSchema(yokoServiceIndexMethodDescriptor), + connect.WithHandlerOptions(opts...), + ) + yokoServiceSearchHandler := connect.NewUnaryHandler( + YokoServiceSearchProcedure, + svc.Search, + connect.WithSchema(yokoServiceSearchMethodDescriptor), + connect.WithHandlerOptions(opts...), + ) + return "/wundergraph.cosmo.code_mode.yoko.v1.YokoService/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case YokoServiceIndexProcedure: + yokoServiceIndexHandler.ServeHTTP(w, r) + case YokoServiceSearchProcedure: + yokoServiceSearchHandler.ServeHTTP(w, r) + default: + http.NotFound(w, r) + } + }) +} + +// UnimplementedYokoServiceHandler returns CodeUnimplemented from all methods. +type UnimplementedYokoServiceHandler struct{} + +func (UnimplementedYokoServiceHandler) Index(context.Context, *connect.Request[v1.IndexRequest]) (*connect.Response[v1.IndexResponse], error) { + return nil, connect.NewError(connect.CodeUnimplemented, errors.New("wundergraph.cosmo.code_mode.yoko.v1.YokoService.Index is not implemented")) +} + +func (UnimplementedYokoServiceHandler) Search(context.Context, *connect.Request[v1.SearchRequest]) (*connect.Response[v1.SearchResponse], error) { + return nil, connect.NewError(connect.CodeUnimplemented, errors.New("wundergraph.cosmo.code_mode.yoko.v1.YokoService.Search is not implemented")) +} diff --git a/router/go.mod b/router/go.mod index c2604da4a6..499f3b8aad 100644 --- a/router/go.mod +++ b/router/go.mod @@ -80,6 +80,7 @@ require ( github.com/posthog/posthog-go v1.5.5 github.com/pquerna/cachecontrol v0.2.0 github.com/santhosh-tekuri/jsonschema/v6 v6.0.1 + github.com/tdewolff/parse/v2 v2.8.12 github.com/tonglil/opentelemetry-go-datadog-propagator v0.1.3 github.com/wundergraph/astjson v1.1.0 github.com/wundergraph/go-arena v1.1.0 @@ -91,6 +92,8 @@ require ( golang.org/x/time v0.9.0 ) +require github.com/tetratelabs/wazero v1.9.0 // indirect + require ( github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect github.com/benbjohnson/clock v1.3.0 // indirect @@ -107,6 +110,8 @@ require ( github.com/docker/distribution v2.8.3+incompatible // indirect github.com/docker/docker-credential-helpers v0.9.3 // indirect github.com/docker/go-units v0.5.0 // indirect + github.com/evanw/esbuild v0.27.3 + github.com/fastschema/qjs v0.0.6 github.com/fatih/color v1.18.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/frankban/quicktest v1.14.6 // indirect diff --git a/router/go.sum b/router/go.sum index 561cbf94cd..7e82879a7d 100644 --- a/router/go.sum +++ b/router/go.sum @@ -73,8 +73,12 @@ github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4 github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/evanw/esbuild v0.27.3 h1:dH/to9tBKybig6hl25hg4SKIWP7U8COdJKbGEwnUkmU= +github.com/evanw/esbuild v0.27.3/go.mod h1:D2vIQZqV/vIf/VRHtViaUtViZmG7o+kKmlBfVQuRi48= github.com/expr-lang/expr v1.17.7 h1:Q0xY/e/2aCIp8g9s/LGvMDCC5PxYlvHgDZRQ4y16JX8= github.com/expr-lang/expr v1.17.7/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4= +github.com/fastschema/qjs v0.0.6 h1:C45KMmQMd21UwsUAmQHxUxiWOfzwTg1GJW0DA0AbFEE= +github.com/fastschema/qjs v0.0.6/go.mod h1:bbg36wxXnx8g0FdKIe5+nCubrQvHa7XEVWqUptjHt/A= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= @@ -299,6 +303,12 @@ github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/tdewolff/parse/v2 v2.8.12 h1:5BBjfaCv482v3nltlS0u6wH1xJaxjR6ofDrWttNvROg= +github.com/tdewolff/parse/v2 v2.8.12/go.mod h1:Hwlni2tiVNKyzR1o6nUs4FOF07URA+JLBLd6dlIXYqo= +github.com/tdewolff/test v1.0.11 h1:FdLbwQVHxqG16SlkGveC0JVyrJN62COWTRyUFzfbtBE= +github.com/tdewolff/test v1.0.11/go.mod h1:XPuWBzvdUzhCuxWO1ojpXsyzsA5bFoS3tO/Q3kFuTG8= +github.com/tetratelabs/wazero v1.9.0 h1:IcZ56OuxrtaEz8UYNRHBrUa9bYeX9oVY93KspZZBf/I= +github.com/tetratelabs/wazero v1.9.0/go.mod h1:TSbcXCfFP0L2FGkRPxHphadXPjo1T6W+CseNNY7EkjM= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= diff --git a/router/internal/codemode/calltrace/calltrace.go b/router/internal/codemode/calltrace/calltrace.go new file mode 100644 index 0000000000..a6d3286e3b --- /dev/null +++ b/router/internal/codemode/calltrace/calltrace.go @@ -0,0 +1,92 @@ +package calltrace + +import ( + "encoding/json" + "os" + "sync" + "time" +) + +type Recorder interface { + RecordRequest(toolName string, body []byte) + RecordResponse(toolName string, body []byte) +} + +type Record struct { + ToolName string `json:"tool_name"` + Timestamp time.Time `json:"timestamp"` + Body json.RawMessage `json:"body"` +} + +type NopRecorder struct{} + +func (NopRecorder) RecordRequest(string, []byte) {} +func (NopRecorder) RecordResponse(string, []byte) {} + +type FileRecorder struct { + path string + now func() time.Time + mu sync.Mutex +} + +type Option func(*FileRecorder) + +func WithNow(now func() time.Time) Option { + return func(r *FileRecorder) { + if now != nil { + r.now = now + } + } +} + +func NewFileRecorder(path string, opts ...Option) *FileRecorder { + recorder := &FileRecorder{ + path: path, + now: time.Now, + } + for _, opt := range opts { + opt(recorder) + } + return recorder +} + +func (r *FileRecorder) RecordRequest(toolName string, body []byte) { + r.record(toolName, body) +} + +func (r *FileRecorder) RecordResponse(toolName string, body []byte) { + r.record(toolName, body) +} + +func (r *FileRecorder) record(toolName string, body []byte) { + if r == nil || r.path == "" { + return + } + line, err := json.Marshal(Record{ + ToolName: toolName, + Timestamp: r.now(), + Body: json.RawMessage(body), + }) + if err != nil { + return + } + line = append(line, '\n') + + r.mu.Lock() + defer r.mu.Unlock() + file, err := os.OpenFile(r.path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600) + if err != nil { + return + } + defer file.Close() + _, _ = file.Write(line) +} + +func Enabled(recorder Recorder) bool { + switch recorder.(type) { + case nil, NopRecorder, *NopRecorder: + return false + default: + return true + } +} diff --git a/router/internal/codemode/calltrace/calltrace_test.go b/router/internal/codemode/calltrace/calltrace_test.go new file mode 100644 index 0000000000..b9af2315ac --- /dev/null +++ b/router/internal/codemode/calltrace/calltrace_test.go @@ -0,0 +1,51 @@ +package calltrace + +import ( + "bufio" + "encoding/json" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFileRecorderWritesRequestAndResponseJSONL(t *testing.T) { + path := filepath.Join(t.TempDir(), "call-trace.jsonl") + now := time.Date(2026, 5, 4, 10, 30, 0, 0, time.UTC) + recorder := NewFileRecorder(path, WithNow(func() time.Time { return now })) + + recorder.RecordRequest("code_mode_run_js", []byte(`{"source":"async () => 1"}`)) + recorder.RecordResponse("code_mode_run_js", []byte(`{"content":[{"type":"text","text":"1"}]}`)) + + file, err := os.Open(path) + require.NoError(t, err) + defer file.Close() + + var got []Record + scanner := bufio.NewScanner(file) + for scanner.Scan() { + var record Record + require.NoError(t, json.Unmarshal(scanner.Bytes(), &record)) + got = append(got, record) + } + require.NoError(t, scanner.Err()) + assert.Equal(t, []Record{ + { + ToolName: "code_mode_run_js", + Timestamp: now, + Body: json.RawMessage(`{"source":"async () =\u003e 1"}`), + }, + { + ToolName: "code_mode_run_js", + Timestamp: now, + Body: json.RawMessage(`{"content":[{"type":"text","text":"1"}]}`), + }, + }, got) +} + +func TestNopRecorderIsDisabled(t *testing.T) { + assert.Equal(t, false, Enabled(NopRecorder{})) +} diff --git a/router/internal/codemode/deps.go b/router/internal/codemode/deps.go new file mode 100644 index 0000000000..ed36aea04c --- /dev/null +++ b/router/internal/codemode/deps.go @@ -0,0 +1,8 @@ +//go:build tools + +package codemode + +import ( + _ "github.com/evanw/esbuild/pkg/api" + _ "github.com/fastschema/qjs" +) diff --git a/router/internal/codemode/harness/envelope.go b/router/internal/codemode/harness/envelope.go new file mode 100644 index 0000000000..7b09fec3ba --- /dev/null +++ b/router/internal/codemode/harness/envelope.go @@ -0,0 +1,207 @@ +package harness + +import ( + "bytes" + "encoding/json" + "fmt" + "strings" + "unicode/utf8" + + "github.com/wundergraph/cosmo/router/internal/codemode/sandbox" +) + +const defaultMaxResultBytes = 32 << 10 +const previewBytes = 1 << 10 + +type ErrorEnvelope = sandbox.ErrorEnvelope +type SerializationWarning = sandbox.SerializationWarning + +// ResultEnvelope is the MCP-facing tool-result body for code_mode_run_js. +// +// Wire shape: +// - result is always present (null if the agent threw). +// - warnings is omitted on the wire when empty. +// - truncated is omitted on the wire when false (only signals a non-default state). +// - error is omitted on the wire when nil (only present on the throw path). +type ResultEnvelope struct { + Result json.RawMessage `json:"result"` + Warnings []SerializationWarning `json:"warnings,omitempty"` + Truncated bool `json:"truncated,omitempty"` + Error *ErrorEnvelope `json:"error,omitempty"` +} + +func BuildEnvelope(sandboxResult sandbox.ExecuteResult, maxResultBytes int) (ResultEnvelope, error) { + if maxResultBytes <= 0 { + maxResultBytes = defaultMaxResultBytes + } + if !sandboxResult.OK { + return ResultEnvelope{ + Result: json.RawMessage("null"), + Warnings: sandboxResult.Warnings, + Truncated: false, + Error: cloneErrorEnvelope(sandboxResult.Error), + }, nil + } + if len(sandboxResult.Result) <= maxResultBytes { + return ResultEnvelope{Result: sandboxResult.Result, Warnings: sandboxResult.Warnings, Truncated: false, Error: nil}, nil + } + + truncated, ok, err := structurallyTruncate(sandboxResult.Result, maxResultBytes) + if err != nil { + return ResultEnvelope{}, err + } + if ok { + return ResultEnvelope{Result: truncated, Warnings: sandboxResult.Warnings, Truncated: true, Error: nil}, nil + } + fallback, err := previewEnvelope(sandboxResult.Result) + if err != nil { + return ResultEnvelope{}, err + } + return ResultEnvelope{Result: fallback, Warnings: sandboxResult.Warnings, Truncated: true, Error: nil}, nil +} + +func cloneErrorEnvelope(err *ErrorEnvelope) *ErrorEnvelope { + if err == nil { + return nil + } + return &ErrorEnvelope{ + Name: err.Name, + Message: err.Message, + Stack: err.Stack, + Cause: cloneErrorEnvelope(err.Cause), + } +} + +func structurallyTruncate(raw json.RawMessage, maxBytes int) (json.RawMessage, bool, error) { + trimmed := bytes.TrimSpace(raw) + if len(trimmed) == 0 { + return nil, false, fmt.Errorf("empty JSON result") + } + switch trimmed[0] { + case '[': + items, err := splitJSONArray(trimmed) + if err != nil { + return nil, false, err + } + for keep := len(items); keep >= 0; keep-- { + body := joinJSON('[', ']', items[:keep]) + if len(body) <= maxBytes { + return body, true, nil + } + } + case '{': + fields, err := splitJSONObject(trimmed) + if err != nil { + return nil, false, err + } + for keep := len(fields); keep >= 0; keep-- { + body := joinJSON('{', '}', fields[:keep]) + if len(body) <= maxBytes { + return body, true, nil + } + } + } + return nil, false, nil +} + +func splitJSONArray(raw []byte) ([]json.RawMessage, error) { + if !json.Valid(raw) { + return nil, fmt.Errorf("invalid JSON result") + } + inner := bytes.TrimSpace(raw[1 : len(raw)-1]) + if len(inner) == 0 { + return nil, nil + } + return splitTopLevel(inner), nil +} + +func splitJSONObject(raw []byte) ([]json.RawMessage, error) { + if !json.Valid(raw) { + return nil, fmt.Errorf("invalid JSON result") + } + inner := bytes.TrimSpace(raw[1 : len(raw)-1]) + if len(inner) == 0 { + return nil, nil + } + return splitTopLevel(inner), nil +} + +func splitTopLevel(raw []byte) []json.RawMessage { + parts := make([]json.RawMessage, 0) + start := 0 + depth := 0 + inString := false + escaped := false + for i, b := range raw { + if inString { + if escaped { + escaped = false + } else if b == '\\' { + escaped = true + } else if b == '"' { + inString = false + } + continue + } + switch b { + case '"': + inString = true + case '[', '{': + depth++ + case ']', '}': + depth-- + case ',': + if depth == 0 { + parts = append(parts, bytes.TrimSpace(raw[start:i])) + start = i + 1 + } + } + } + parts = append(parts, bytes.TrimSpace(raw[start:])) + return parts +} + +func joinJSON(open byte, close byte, parts []json.RawMessage) json.RawMessage { + var b strings.Builder + b.WriteByte(open) + for i, part := range parts { + if i > 0 { + b.WriteByte(',') + } + b.Write(bytes.TrimSpace(part)) + } + b.WriteByte(close) + return json.RawMessage(b.String()) +} + +func previewEnvelope(raw json.RawMessage) (json.RawMessage, error) { + preview := string(raw) + var value string + if err := json.Unmarshal(raw, &value); err == nil { + preview = value + } + body, err := json.Marshal(struct { + Truncated bool `json:"__truncated"` + OriginalSize int `json:"originalSize"` + Preview string `json:"preview"` + }{ + Truncated: true, + OriginalSize: len(raw), + Preview: firstUTF8Bytes(preview, previewBytes), + }) + if err != nil { + return nil, err + } + return body, nil +} + +func firstUTF8Bytes(s string, limit int) string { + if len(s) <= limit { + return s + } + cut := limit + for cut > 0 && !utf8.ValidString(s[:cut]) { + cut-- + } + return s[:cut] +} diff --git a/router/internal/codemode/harness/envelope_test.go b/router/internal/codemode/harness/envelope_test.go new file mode 100644 index 0000000000..9c6ce7f3a5 --- /dev/null +++ b/router/internal/codemode/harness/envelope_test.go @@ -0,0 +1,61 @@ +package harness + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/internal/codemode/sandbox" +) + +func TestBuildEnvelopePassesThroughSmallResult(t *testing.T) { + got, err := BuildEnvelope(sandbox.ExecuteResult{OK: true, Result: raw(`{"ok":true}`)}, 32<<10) + require.NoError(t, err) + + assert.Equal(t, ResultEnvelope{Result: raw(`{"ok":true}`), Truncated: false, Error: nil}, got) +} + +func TestBuildEnvelopeTruncatesTopLevelArray(t *testing.T) { + got, err := BuildEnvelope(sandbox.ExecuteResult{OK: true, Result: raw(`[{"id":1},{"id":2},{"id":3}]`)}, len(`[{"id":1},{"id":2}]`)) + require.NoError(t, err) + + assert.Equal(t, ResultEnvelope{Result: raw(`[{"id":1},{"id":2}]`), Truncated: true, Error: nil}, got) +} + +func TestBuildEnvelopeTruncatesTopLevelObject(t *testing.T) { + got, err := BuildEnvelope(sandbox.ExecuteResult{OK: true, Result: raw(`{"a":1,"b":2,"c":3}`)}, len(`{"a":1,"b":2}`)) + require.NoError(t, err) + + assert.Equal(t, ResultEnvelope{Result: raw(`{"a":1,"b":2}`), Truncated: true, Error: nil}, got) +} + +func TestBuildEnvelopeFallsBackToPreviewForHugeScalar(t *testing.T) { + value := strings.Repeat("a", 2048) + body, err := json.Marshal(value) + require.NoError(t, err) + + got, err := BuildEnvelope(sandbox.ExecuteResult{OK: true, Result: body}, 128) + require.NoError(t, err) + + var preview struct { + Truncated bool `json:"__truncated"` + OriginalSize int `json:"originalSize"` + Preview string `json:"preview"` + } + require.NoError(t, json.Unmarshal(got.Result, &preview)) + assert.Equal(t, true, got.Truncated) + assert.Equal(t, true, preview.Truncated) + assert.Equal(t, len(body), preview.OriginalSize) + assert.Equal(t, strings.Repeat("a", 1024), preview.Preview) +} + +func TestBuildEnvelopeCopiesSandboxError(t *testing.T) { + sandboxErr := &sandbox.ErrorEnvelope{Name: "Error", Message: "boom", Stack: "stack"} + + got, err := BuildEnvelope(sandbox.ExecuteResult{OK: false, Error: sandboxErr}, 32<<10) + require.NoError(t, err) + + assert.Equal(t, ResultEnvelope{Result: raw(`null`), Truncated: false, Error: sandboxErr}, got) +} diff --git a/router/internal/codemode/harness/pipeline.go b/router/internal/codemode/harness/pipeline.go new file mode 100644 index 0000000000..a11a607a72 --- /dev/null +++ b/router/internal/codemode/harness/pipeline.go @@ -0,0 +1,127 @@ +package harness + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + + "github.com/wundergraph/cosmo/router/internal/codemode/sandbox" +) + +const defaultMaxInputBytes = 64 << 10 + +type sandboxExecutor interface { + Execute(ctx context.Context, req sandbox.ExecuteRequest) (sandbox.ExecuteResult, error) +} + +type Executor interface { + Execute(ctx context.Context, req PipelineRequest) (PipelineResponse, error) +} + +type Pipeline struct { + Sandbox *sandbox.Sandbox + MaxInputBytes int + MaxResultBytes int + + executor sandboxExecutor +} + +type PipelineRequest struct { + SessionID string + ToolNames []string + Source string + RequestHeaders http.Header + ApprovalGate sandbox.ApprovalGate +} + +type PipelineResponse struct { + Envelope ResultEnvelope + Encoded []byte + Diagnostics []Diagnostic +} + +func (p *Pipeline) Execute(ctx context.Context, req PipelineRequest) (PipelineResponse, error) { + maxInputBytes := p.MaxInputBytes + if maxInputBytes <= 0 { + maxInputBytes = defaultMaxInputBytes + } + + // Raw-source guard rejects oversized input before esbuild parses it. The + // same limit applies post-transpile below because generated JS can differ + // slightly from source size. + if len(req.Source) > maxInputBytes { + return p.errorResponse(&ErrorEnvelope{ + Name: "InputTooLarge", + Message: fmt.Sprintf("input size %d bytes exceeds limit %d bytes", len(req.Source), maxInputBytes), + Stack: "", + }, nil) + } + + transpiled, err := Transpile(req.Source) + if err != nil { + return p.errorResponse(&ErrorEnvelope{Name: "TranspileError", Message: err.Error(), Stack: ""}, transpiled.Diagnostics) + } + + if len(transpiled.JS) > maxInputBytes { + return p.errorResponse(&ErrorEnvelope{ + Name: "InputTooLarge", + Message: fmt.Sprintf("input size %d bytes exceeds limit %d bytes", len(transpiled.JS), maxInputBytes), + Stack: "", + }, nil) + } + + if err := ShapeCheck(transpiled.JS); err != nil { + return p.errorResponse(&ErrorEnvelope{Name: "ShapeCheck", Message: err.Error(), Stack: ""}, nil) + } + + executor, err := p.sandboxExecutor() + if err != nil { + return PipelineResponse{}, err + } + sandboxResult, err := executor.Execute(ctx, sandbox.ExecuteRequest{ + SessionID: req.SessionID, + ToolNames: req.ToolNames, + WrappedJS: transpiled.JS, + SourceMap: transpiled.SourceMap, + RequestHeaders: req.RequestHeaders, + ApprovalGate: req.ApprovalGate, + }) + if err != nil { + return PipelineResponse{}, err + } + + envelope, err := BuildEnvelope(sandboxResult, p.MaxResultBytes) + if err != nil { + return PipelineResponse{}, err + } + encoded, err := json.Marshal(envelope) + if err != nil { + return PipelineResponse{}, err + } + return PipelineResponse{Envelope: envelope, Encoded: encoded}, nil +} + +func (p *Pipeline) sandboxExecutor() (sandboxExecutor, error) { + if p.executor != nil { + return p.executor, nil + } + if p.Sandbox == nil { + return nil, errors.New("code mode: pipeline sandbox is nil") + } + return p.Sandbox, nil +} + +func (p *Pipeline) errorResponse(errEnv *ErrorEnvelope, diagnostics []Diagnostic) (PipelineResponse, error) { + envelope := ResultEnvelope{ + Result: json.RawMessage("null"), + Truncated: false, + Error: errEnv, + } + encoded, err := json.Marshal(envelope) + if err != nil { + return PipelineResponse{}, err + } + return PipelineResponse{Envelope: envelope, Encoded: encoded, Diagnostics: diagnostics}, nil +} diff --git a/router/internal/codemode/harness/pipeline_test.go b/router/internal/codemode/harness/pipeline_test.go new file mode 100644 index 0000000000..981d943c3c --- /dev/null +++ b/router/internal/codemode/harness/pipeline_test.go @@ -0,0 +1,144 @@ +package harness + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/internal/codemode/sandbox" +) + +type fakeExecutor struct { + calls int + result sandbox.ExecuteResult + err error +} + +func (f *fakeExecutor) Execute(ctx context.Context, req sandbox.ExecuteRequest) (sandbox.ExecuteResult, error) { + f.calls++ + return f.result, f.err +} + +func TestPipelineShapeCheckFailureShortCircuits(t *testing.T) { + fake := &fakeExecutor{} + pipeline := Pipeline{executor: fake} + + got, err := pipeline.Execute(context.Background(), PipelineRequest{Source: `() => 1`}) + require.NoError(t, err) + + assert.Equal(t, 0, fake.calls) + require.NotNil(t, got.Envelope.Error) + assert.Equal(t, "ShapeCheck", got.Envelope.Error.Name) + assert.Equal(t, `code mode: source must be a single async-arrow root (got: missing async modifier)`, got.Envelope.Error.Message) + assert.Empty(t, got.Diagnostics) + assert.NotEmpty(t, got.Encoded) +} + +func TestPipelineTopLevelAwaitFailsAtTranspile(t *testing.T) { + fake := &fakeExecutor{} + pipeline := Pipeline{executor: fake} + + got, err := pipeline.Execute(context.Background(), PipelineRequest{Source: `await tools.getUser({})`}) + require.NoError(t, err) + + assert.Equal(t, 0, fake.calls) + require.NotNil(t, got.Envelope.Error) + assert.Equal(t, "TranspileError", got.Envelope.Error.Name) + // esbuild's exact message is target-version dependent. We only assert the + // transpile-error envelope name; the full message lives in Diagnostics. + assert.NotEmpty(t, got.Diagnostics) +} + +func TestPipelineAcceptsTypeScriptSource(t *testing.T) { + fake := &fakeExecutor{result: sandbox.ExecuteResult{OK: true, Result: raw(`{"id":"1"}`)}} + pipeline := Pipeline{executor: fake} + + // TypeScript source: type annotations, optional params, type parameters. + // All three are valid TS-only syntax. Pipeline must transpile then accept. + tsInputs := []string{ + `async (x: string) => ({ id: x })`, + `async (x: string, y?: number) => ({ id: x })`, + `async (x: T) => ({ id: String(x) })`, + } + for _, in := range tsInputs { + t.Run(in, func(t *testing.T) { + fake.calls = 0 + got, err := pipeline.Execute(context.Background(), PipelineRequest{Source: in}) + require.NoError(t, err) + assert.Equal(t, 1, fake.calls, "sandbox should be invoked") + assert.Nil(t, got.Envelope.Error, "no shape or transpile error expected") + }) + } +} + +func TestPipelineTranspileFailureReturnsDiagnostics(t *testing.T) { + fake := &fakeExecutor{} + pipeline := Pipeline{executor: fake} + + got, err := pipeline.Execute(context.Background(), PipelineRequest{Source: `async () => { let x = ; }`}) + require.NoError(t, err) + + assert.Equal(t, 0, fake.calls) + require.NotNil(t, got.Envelope.Error) + assert.Equal(t, "TranspileError", got.Envelope.Error.Name) + assert.NotEmpty(t, got.Diagnostics) + assert.NotEmpty(t, got.Encoded) +} + +func TestPipelineSandboxErrorIsFoldedIntoEnvelope(t *testing.T) { + fake := &fakeExecutor{result: sandbox.ExecuteResult{ + OK: false, + Error: &sandbox.ErrorEnvelope{Name: "RuntimeError", Message: "boom", Stack: "stack"}, + }} + pipeline := Pipeline{executor: fake} + + got, err := pipeline.Execute(context.Background(), PipelineRequest{Source: `async () => 1`}) + require.NoError(t, err) + + assert.Equal(t, 1, fake.calls) + require.NotNil(t, got.Envelope.Error) + assert.Equal(t, "RuntimeError", got.Envelope.Error.Name) + assert.Equal(t, false, got.Envelope.Truncated) +} + +func TestPipelineSandboxSuccessEncodesEnvelope(t *testing.T) { + fake := &fakeExecutor{result: sandbox.ExecuteResult{OK: true, Result: raw(`{"ok":true}`)}} + pipeline := Pipeline{executor: fake} + + got, err := pipeline.Execute(context.Background(), PipelineRequest{ + SessionID: "session-1", + ToolNames: []string{"getUser"}, + Source: `async () => ({ ok: true })`, + RequestHeaders: http.Header{"Authorization": []string{"Bearer token"}}, + ApprovalGate: nil, + }) + require.NoError(t, err) + + assert.Equal(t, 1, fake.calls) + assert.Equal(t, ResultEnvelope{Result: raw(`{"ok":true}`), Truncated: false, Error: nil}, got.Envelope) + + var decoded map[string]any + require.NoError(t, json.Unmarshal(got.Encoded, &decoded)) + assert.Equal(t, map[string]any{"result": map[string]any{"ok": true}}, decoded) +} + +func TestPipelineTruncationTriggers(t *testing.T) { + result, err := json.Marshal([]any{map[string]any{"id": 1}, map[string]any{"id": 2}, map[string]any{"id": 3}}) + require.NoError(t, err) + + fake := &fakeExecutor{result: sandbox.ExecuteResult{OK: true, Result: result}} + pipeline := Pipeline{executor: fake, MaxResultBytes: len(`[{"id":1},{"id":2}]`)} + + got, err := pipeline.Execute(context.Background(), PipelineRequest{Source: `async () => []`}) + require.NoError(t, err) + + assert.Equal(t, true, got.Envelope.Truncated) + assert.Equal(t, raw(`[{"id":1},{"id":2}]`), got.Envelope.Result) +} + +func raw(s string) json.RawMessage { + return json.RawMessage(s) +} diff --git a/router/internal/codemode/harness/shape.go b/router/internal/codemode/harness/shape.go new file mode 100644 index 0000000000..56b1a9cd7a --- /dev/null +++ b/router/internal/codemode/harness/shape.go @@ -0,0 +1,97 @@ +package harness + +import ( + "errors" + "strings" + + "github.com/tdewolff/parse/v2" + "github.com/tdewolff/parse/v2/js" +) + +const shapeErrorPrefix = "code mode: source must be a single async-arrow root (got: " + +// ShapeCheck verifies that the given JavaScript source is exactly one +// top-level expression statement whose expression is an async arrow function. +// +// Input contract: ShapeCheck expects the *post-esbuild* JavaScript. TypeScript +// syntax is stripped earlier in the pipeline by Transpile (esbuild loaderTS). +// Callers must run Transpile first. +// +// Note: parse error messages from tdewolff include line/col positions for the +// post-esbuild source, NOT the original TS source the user wrote. That's +// acceptable because (a) ShapeCheck failures are structural, not character-level, +// and (b) Transpile already surfaces TS-source diagnostics for syntactic errors. +func ShapeCheck(source string) error { + if strings.TrimSpace(source) == "" { + return shapeError("empty source") + } + + ast, err := js.Parse(parse.NewInputBytes([]byte(source)), js.Options{}) + if err != nil { + return shapeError("parse failed: " + err.Error()) + } + + stmts := ast.BlockStmt.List + if len(stmts) == 0 { + return shapeError("empty source") + } + + // Detect import/export *before* the multi-statement check. Otherwise an + // input like `import x from "x"; async () => x` would report + // "multiple statements" instead of the more useful "leading import/export". + switch stmts[0].(type) { + case *js.ImportStmt, *js.ExportStmt: + return shapeError("leading import/export") + } + + if len(stmts) > 1 { + return shapeError("multiple statements") + } + + switch stmt := stmts[0].(type) { + case *js.ExprStmt: + return checkExpression(stmt.Value) + default: + return shapeError("non-arrow root") + } +} + +// checkExpression verifies the expression is an async arrow function, +// transparently unwrapping any number of redundant parentheses. +func checkExpression(expr js.IExpr) error { + for { + group, ok := expr.(*js.GroupExpr) + if !ok { + break + } + expr = group.X + } + + if isTopLevelAwait(expr) { + return shapeError("top-level await") + } + + arrow, ok := expr.(*js.ArrowFunc) + if !ok { + return shapeError("non-arrow root") + } + if !arrow.Async { + return shapeError("missing async modifier") + } + return nil +} + +// isTopLevelAwait detects `await x` used as a top-level expression. tdewolff +// parses await as a UnaryExpr with the Await operator. We surface this as a +// distinct error because it's a common model mistake worth flagging clearly. +func isTopLevelAwait(expr js.IExpr) bool { + unary, ok := expr.(*js.UnaryExpr) + if !ok { + return false + } + return unary.Op == js.AwaitToken +} + +func shapeError(reason string) error { + return errors.New(shapeErrorPrefix + reason + ")") +} diff --git a/router/internal/codemode/harness/shape_test.go b/router/internal/codemode/harness/shape_test.go new file mode 100644 index 0000000000..c6632e4389 --- /dev/null +++ b/router/internal/codemode/harness/shape_test.go @@ -0,0 +1,73 @@ +package harness + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// ShapeCheck runs on post-esbuild JavaScript. Inputs in this file are written +// as the JS that Transpile would produce — never raw TypeScript. End-to-end +// TS handling is covered by pipeline_test.go and transpile_test.go. + +func TestShapeCheckAcceptsAsyncArrowRoots(t *testing.T) { + tests := []string{ + `async () => 1`, + `async()=>1`, + `async () => { return 1; }`, + `async (x) => x`, + `async (x, y) => x + y`, + `async (x) => ({ x })`, + `(async () => 1)`, + `((async () => 1))`, + " \n\tasync () => true", + "// leading\nasync () => true", + "/* leading */ async () => true", + `async ({ id }) => id`, + `async () => await tools.getUser({ id: "1" })`, + `async () => { const rows = await Promise.all([]); return rows; }`, + } + for _, source := range tests { + t.Run(source, func(t *testing.T) { + assert.NoError(t, ShapeCheck(source)) + }) + } +} + +func TestShapeCheckRejectsNonAsyncArrowRoots(t *testing.T) { + tests := []struct { + name string + source string + want string + }{ + // Top-level await: ShapeCheck handles this defensively for the case where the + // pipeline's esbuild target is later raised to ES2022. Under today's ES2020 target, + // `await x` is rejected at Transpile and never reaches ShapeCheck — but the AST + // path still works as a unit, so we keep the test. + {name: "top-level await", source: `await tools.getUser({})`, want: `code mode: source must be a single async-arrow root (got: top-level await)`}, + // Import/export must be detected before the multi-statement check, otherwise + // `import x from "x"; async () => x` reports "multiple statements" instead. + {name: "import then arrow", source: `import x from "x"; async () => x`, want: `code mode: source must be a single async-arrow root (got: leading import/export)`}, + {name: "import alone", source: `import x from "x"`, want: `code mode: source must be a single async-arrow root (got: leading import/export)`}, + {name: "export", source: `export default async () => 1`, want: `code mode: source must be a single async-arrow root (got: leading import/export)`}, + {name: "block", source: `{ async () => 1 }`, want: `code mode: source must be a single async-arrow root (got: non-arrow root)`}, + {name: "function declaration", source: `async function main() {}`, want: `code mode: source must be a single async-arrow root (got: non-arrow root)`}, + {name: "non async arrow", source: `() => 1`, want: `code mode: source must be a single async-arrow root (got: missing async modifier)`}, + {name: "paren non async arrow", source: `(() => 1)`, want: `code mode: source must be a single async-arrow root (got: missing async modifier)`}, + {name: "identifier", source: `foo`, want: `code mode: source must be a single async-arrow root (got: non-arrow root)`}, + {name: "empty", source: ` `, want: `code mode: source must be a single async-arrow root (got: empty source)`}, + {name: "comment-only", source: `// only trivia`, want: `code mode: source must be a single async-arrow root (got: empty source)`}, + {name: "async call", source: `async()`, want: `code mode: source must be a single async-arrow root (got: non-arrow root)`}, + {name: "multiple arrows", source: `async () => 1; async () => 2`, want: `code mode: source must be a single async-arrow root (got: multiple statements)`}, + {name: "var then arrow", source: `const x = 1; async () => x`, want: `code mode: source must be a single async-arrow root (got: multiple statements)`}, + {name: "class", source: `class X {}`, want: `code mode: source must be a single async-arrow root (got: non-arrow root)`}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ShapeCheck(tt.source) + if assert.Error(t, err) { + assert.Equal(t, tt.want, err.Error()) + } + }) + } +} diff --git a/router/internal/codemode/harness/transpile.go b/router/internal/codemode/harness/transpile.go new file mode 100644 index 0000000000..129cf7bd3c --- /dev/null +++ b/router/internal/codemode/harness/transpile.go @@ -0,0 +1,73 @@ +package harness + +import ( + "errors" + "strings" + + "github.com/evanw/esbuild/pkg/api" +) + +type TranspileResult struct { + JS string + SourceMap []byte + Diagnostics []Diagnostic +} + +type Diagnostic struct { + Text string + Line int + Column int + File string +} + +func Transpile(source string) (TranspileResult, error) { + result := api.Transform(source, api.TransformOptions{ + Loader: api.LoaderTS, + Target: api.ES2020, + Platform: api.PlatformNeutral, + Format: api.FormatDefault, + Sourcemap: api.SourceMapExternal, + Sourcefile: "agent.ts", + LogLevel: api.LogLevelSilent, + LegalComments: api.LegalCommentsNone, + Drop: api.DropDebugger, + Charset: api.CharsetASCII, + }) + + out := TranspileResult{ + JS: trimTranspiledExpression(string(result.Code)), + SourceMap: append([]byte(nil), result.Map...), + Diagnostics: diagnosticsFromMessages(result.Errors), + } + if len(result.Errors) > 0 { + return out, errors.New("transpile failed: " + strings.Join(diagnosticTexts(out.Diagnostics), "; ")) + } + return out, nil +} + +func trimTranspiledExpression(js string) string { + trimmed := strings.TrimSpace(js) + return strings.TrimSuffix(trimmed, ";") +} + +func diagnosticsFromMessages(messages []api.Message) []Diagnostic { + diagnostics := make([]Diagnostic, 0, len(messages)) + for _, message := range messages { + diagnostic := Diagnostic{Text: message.Text} + if message.Location != nil { + diagnostic.Line = message.Location.Line + diagnostic.Column = message.Location.Column + 1 + diagnostic.File = message.Location.File + } + diagnostics = append(diagnostics, diagnostic) + } + return diagnostics +} + +func diagnosticTexts(diagnostics []Diagnostic) []string { + texts := make([]string, 0, len(diagnostics)) + for _, diagnostic := range diagnostics { + texts = append(texts, diagnostic.Text) + } + return texts +} diff --git a/router/internal/codemode/harness/transpile_test.go b/router/internal/codemode/harness/transpile_test.go new file mode 100644 index 0000000000..1d9cc21597 --- /dev/null +++ b/router/internal/codemode/harness/transpile_test.go @@ -0,0 +1,61 @@ +package harness + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTranspileStripsTypeScriptAnnotations(t *testing.T) { + got, err := Transpile(`async () => { const x: string = "hi"; return x; }`) + require.NoError(t, err) + + assert.NotContains(t, got.JS, `: string`) + assert.Contains(t, got.JS, `"hi"`) + assert.False(t, strings.HasSuffix(strings.TrimSpace(got.JS), ";")) + assert.NotEmpty(t, got.SourceMap) + assert.Empty(t, got.Diagnostics) + + var sourceMap map[string]any + require.NoError(t, json.Unmarshal(got.SourceMap, &sourceMap)) + assert.Equal(t, float64(3), sourceMap["version"]) +} + +func TestTranspileTreatsTypesAsNotation(t *testing.T) { + got, err := Transpile(`async (value: { id: string }): Promise => value.id`) + require.NoError(t, err) + + assert.NotContains(t, got.JS, `Promise`) + assert.NotContains(t, got.JS, `id: string`) + assert.Contains(t, got.JS, `value.id`) +} + +func TestTranspileReportsDiagnosticsForSyntaxErrors(t *testing.T) { + got, err := Transpile(`async () => { let x = ; }`) + require.Error(t, err) + + require.NotEmpty(t, got.Diagnostics) + assert.NotEmpty(t, got.Diagnostics[0].Text) + assert.NotEqual(t, 0, got.Diagnostics[0].Line) + assert.NotEqual(t, 0, got.Diagnostics[0].Column) + assert.True(t, strings.Contains(err.Error(), got.Diagnostics[0].Text)) +} + +func TestTranspileDropsDebuggerStatement(t *testing.T) { + got, err := Transpile(`async () => { debugger; return 1; }`) + require.NoError(t, err) + + assert.NotContains(t, got.JS, "debugger", "Drop:DropDebugger should remove debugger statements") +} + +func TestTranspileEscapesNonASCII(t *testing.T) { + got, err := Transpile(`async () => "héllo"`) + require.NoError(t, err) + + // CharsetASCII tells esbuild to escape non-ASCII codepoints in string + // literals. The raw `é` byte sequence must not appear in the output. + assert.NotContains(t, got.JS, "é", "Charset:ASCII should escape non-ASCII codepoints") +} diff --git a/router/internal/codemode/observability/logging.go b/router/internal/codemode/observability/logging.go new file mode 100644 index 0000000000..20b3e24810 --- /dev/null +++ b/router/internal/codemode/observability/logging.go @@ -0,0 +1,48 @@ +package observability + +import ( + "go.uber.org/zap" +) + +func LogSessionLifecycle(logger *zap.Logger, event string, sessionID string, fields ...zap.Field) { + if logger == nil { + return + } + allFields := append([]zap.Field{ + zap.String("event", event), + zap.String("session_id", sessionID), + }, fields...) + logger.Info("code mode session lifecycle", allFields...) +} + +func LogTranspileFailure(logger *zap.Logger, sessionID string, diagnostic string) { + if logger == nil { + return + } + logger.Info("code mode transpile failure", + zap.String("session_id", sessionID), + zap.String("diagnostic", diagnostic), + ) +} + +func LogElicitationOutcome(logger *zap.Logger, sessionID string, approved bool, reason string) { + if logger == nil { + return + } + logger.Info("code mode elicitation outcome", + zap.String("session_id", sessionID), + zap.Bool("approved", approved), + zap.String("reason", reason), + ) +} + +func LogToolInvocationFailure(logger *zap.Logger, sessionID string, opName string, err error) { + if logger == nil { + return + } + logger.Info("code mode tool invocation failure", + zap.String("session_id", sessionID), + zap.String("op_name", opName), + zap.Error(err), + ) +} diff --git a/router/internal/codemode/observability/logging_test.go b/router/internal/codemode/observability/logging_test.go new file mode 100644 index 0000000000..0a883f4ab1 --- /dev/null +++ b/router/internal/codemode/observability/logging_test.go @@ -0,0 +1,43 @@ +package observability + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest/observer" +) + +func TestLoggingHelpersEmitStructuredInfoEntries(t *testing.T) { + core, observed := observer.New(zapcore.InfoLevel) + logger := zap.New(core) + + LogSessionLifecycle(logger, "created", "session-1", zap.String("storage", "memory")) + LogTranspileFailure(logger, "session-1", "Unexpected \";\"") + LogElicitationOutcome(logger, "session-1", false, "operator declined") + LogToolInvocationFailure(logger, "session-1", "getOrders", errors.New("upstream timeout")) + + entries := observed.AllUntimed() + require.Len(t, entries, 4) + assert.Equal(t, []observer.LoggedEntry{ + { + Entry: zapcore.Entry{Level: zapcore.InfoLevel, Message: "code mode session lifecycle"}, + Context: []zapcore.Field{zap.String("event", "created"), zap.String("session_id", "session-1"), zap.String("storage", "memory")}, + }, + { + Entry: zapcore.Entry{Level: zapcore.InfoLevel, Message: "code mode transpile failure"}, + Context: []zapcore.Field{zap.String("session_id", "session-1"), zap.String("diagnostic", "Unexpected \";\"")}, + }, + { + Entry: zapcore.Entry{Level: zapcore.InfoLevel, Message: "code mode elicitation outcome"}, + Context: []zapcore.Field{zap.String("session_id", "session-1"), zap.Bool("approved", false), zap.String("reason", "operator declined")}, + }, + { + Entry: zapcore.Entry{Level: zapcore.InfoLevel, Message: "code mode tool invocation failure"}, + Context: []zapcore.Field{zap.String("session_id", "session-1"), zap.String("op_name", "getOrders"), zap.Error(errors.New("upstream timeout"))}, + }, + }, entries) +} diff --git a/router/internal/codemode/observability/metrics.go b/router/internal/codemode/observability/metrics.go new file mode 100644 index 0000000000..a9aaf18428 --- /dev/null +++ b/router/internal/codemode/observability/metrics.go @@ -0,0 +1,56 @@ +package observability + +import ( + "context" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" +) + +const meterName = "wundergraph.cosmo.router.mcp.code_mode" + +type Meter struct { + executionsCounter metric.Int64Counter + durationHistogram metric.Float64Histogram +} + +func NewMeter(meterProvider metric.MeterProvider) (*Meter, error) { + if meterProvider == nil { + meterProvider = otel.GetMeterProvider() + } + meter := meterProvider.Meter(meterName) + + executionsCounter, err := meter.Int64Counter( + "mcp.code_mode.sandbox.executions", + metric.WithDescription("Code Mode sandbox executions."), + ) + if err != nil { + return nil, err + } + durationHistogram, err := meter.Float64Histogram( + "mcp.code_mode.sandbox.duration", + metric.WithDescription("Code Mode sandbox execution duration."), + metric.WithUnit("ms"), + ) + if err != nil { + return nil, err + } + + return &Meter{ + executionsCounter: executionsCounter, + durationHistogram: durationHistogram, + }, nil +} + +func (m *Meter) Record(ctx context.Context, toolName, status string, durationMs float64) { + if m == nil { + return + } + attrs := metric.WithAttributes( + attribute.String("mcp.tool", toolName), + attribute.String("mcp.status", status), + ) + m.executionsCounter.Add(ctx, 1, attrs) + m.durationHistogram.Record(ctx, durationMs, attrs) +} diff --git a/router/internal/codemode/observability/metrics_test.go b/router/internal/codemode/observability/metrics_test.go new file mode 100644 index 0000000000..e39a1c6021 --- /dev/null +++ b/router/internal/codemode/observability/metrics_test.go @@ -0,0 +1,76 @@ +package observability + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/attribute" + sdkmetric "go.opentelemetry.io/otel/sdk/metric" + "go.opentelemetry.io/otel/sdk/metric/metricdata" +) + +func TestMeterRecordEmitsCounterAndDurationHistogram(t *testing.T) { + reader := sdkmetric.NewManualReader() + provider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(reader)) + meter, err := NewMeter(provider) + require.NoError(t, err) + + meter.Record(context.Background(), "code_mode_run_js", "success", 12.5) + + var got metricdata.ResourceMetrics + require.NoError(t, reader.Collect(context.Background(), &got)) + counter, histogram := codeModeMetrics(t, got) + + counterData, ok := counter.Data.(metricdata.Sum[int64]) + require.True(t, ok) + require.Len(t, counterData.DataPoints, 1) + counterPoint := counterData.DataPoints[0] + counterPoint.StartTime = time.Time{} + counterPoint.Time = time.Time{} + assert.Equal(t, metricdata.DataPoint[int64]{ + Attributes: attribute.NewSet( + attribute.String("mcp.tool", "code_mode_run_js"), + attribute.String("mcp.status", "success"), + ), + Value: 1, + }, counterPoint) + + histogramData, ok := histogram.Data.(metricdata.Histogram[float64]) + require.True(t, ok) + require.Len(t, histogramData.DataPoints, 1) + histogramPoint := histogramData.DataPoints[0] + histogramPoint.StartTime = time.Time{} + histogramPoint.Time = time.Time{} + assert.Equal(t, metricdata.HistogramDataPoint[float64]{ + Attributes: attribute.NewSet( + attribute.String("mcp.tool", "code_mode_run_js"), + attribute.String("mcp.status", "success"), + ), + Count: 1, + Bounds: histogramPoint.Bounds, + BucketCounts: histogramPoint.BucketCounts, + Min: histogramPoint.Min, + Max: histogramPoint.Max, + Sum: 12.5, + }, histogramPoint) +} + +func codeModeMetrics(t *testing.T, metrics metricdata.ResourceMetrics) (metricdata.Metrics, metricdata.Metrics) { + t.Helper() + require.Len(t, metrics.ScopeMetrics, 1) + assert.Equal(t, "wundergraph.cosmo.router.mcp.code_mode", metrics.ScopeMetrics[0].Scope.Name) + + byName := make(map[string]metricdata.Metrics, len(metrics.ScopeMetrics[0].Metrics)) + for _, metric := range metrics.ScopeMetrics[0].Metrics { + byName[metric.Name] = metric + } + + counter, ok := byName["mcp.code_mode.sandbox.executions"] + require.True(t, ok) + histogram, ok := byName["mcp.code_mode.sandbox.duration"] + require.True(t, ok) + return counter, histogram +} diff --git a/router/internal/codemode/observability/tracing.go b/router/internal/codemode/observability/tracing.go new file mode 100644 index 0000000000..70456d8cf9 --- /dev/null +++ b/router/internal/codemode/observability/tracing.go @@ -0,0 +1,36 @@ +package observability + +import ( + "context" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" +) + +const tracerName = "wundergraph.cosmo.router.mcp.code_mode" + +func StartToolSpan(ctx context.Context, toolName string) (context.Context, trace.Span) { + return StartToolSpanWithProvider(ctx, otel.GetTracerProvider(), toolName) +} + +func StartToolSpanWithProvider(ctx context.Context, tracerProvider trace.TracerProvider, toolName string) (context.Context, trace.Span) { + if tracerProvider == nil { + tracerProvider = otel.GetTracerProvider() + } + return tracerProvider.Tracer(tracerName).Start(ctx, toolSpanName(toolName), + trace.WithSpanKind(trace.SpanKindServer), + trace.WithAttributes(attribute.String("mcp.tool", toolName)), + ) +} + +func toolSpanName(toolName string) string { + switch toolName { + case "code_mode_search_tools": + return "MCP Code Mode - Search" + case "code_mode_run_js": + return "MCP Code Mode - Execute" + default: + return "MCP Code Mode - " + toolName + } +} diff --git a/router/internal/codemode/observability/tracing_test.go b/router/internal/codemode/observability/tracing_test.go new file mode 100644 index 0000000000..6ce73bf0a9 --- /dev/null +++ b/router/internal/codemode/observability/tracing_test.go @@ -0,0 +1,67 @@ +package observability + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" + "go.opentelemetry.io/otel/trace" +) + +func TestStartToolSpanRecordsSearchServerSpan(t *testing.T) { + recorder := tracetest.NewSpanRecorder() + provider := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(recorder)) + previous := otel.GetTracerProvider() + otel.SetTracerProvider(provider) + t.Cleanup(func() { otel.SetTracerProvider(previous) }) + + _, span := StartToolSpan(context.Background(), "code_mode_search_tools") + span.End() + + ended := recorder.Ended() + require.Len(t, ended, 1) + stub := tracetest.SpanStubFromReadOnlySpan(ended[0]) + stub.SpanContext = trace.SpanContext{} + stub.StartTime = time.Time{} + stub.EndTime = time.Time{} + stub.Resource = nil + assert.Equal(t, tracetest.SpanStub{ + Name: "MCP Code Mode - Search", + SpanKind: trace.SpanKindServer, + Attributes: []attribute.KeyValue{ + attribute.String("mcp.tool", "code_mode_search_tools"), + }, + InstrumentationLibrary: stub.InstrumentationLibrary, + }, stub) +} + +func TestStartToolSpanRecordsExecuteServerSpan(t *testing.T) { + recorder := tracetest.NewSpanRecorder() + provider := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(recorder)) + + ctx, span := StartToolSpanWithProvider(context.Background(), provider, "code_mode_run_js") + require.True(t, trace.SpanFromContext(ctx).SpanContext().IsValid()) + span.End() + + ended := recorder.Ended() + require.Len(t, ended, 1) + stub := tracetest.SpanStubFromReadOnlySpan(ended[0]) + stub.SpanContext = trace.SpanContext{} + stub.StartTime = time.Time{} + stub.EndTime = time.Time{} + stub.Resource = nil + assert.Equal(t, tracetest.SpanStub{ + Name: "MCP Code Mode - Execute", + SpanKind: trace.SpanKindServer, + Attributes: []attribute.KeyValue{ + attribute.String("mcp.tool", "code_mode_run_js"), + }, + InstrumentationLibrary: stub.InstrumentationLibrary, + }, stub) +} diff --git a/router/internal/codemode/sandbox/errors.go b/router/internal/codemode/sandbox/errors.go new file mode 100644 index 0000000000..c22e75e7be --- /dev/null +++ b/router/internal/codemode/sandbox/errors.go @@ -0,0 +1,201 @@ +package sandbox + +import ( + "encoding/json" + "regexp" + "strconv" + "strings" + + "github.com/fastschema/qjs" +) + +func normalizeError(ctx *qjs.Context, errValue *qjs.Value, sourceMap []byte, program string) (*ErrorEnvelope, error) { + global := ctx.Global() + normalizer := global.GetPropertyStr("__codemodeNormalizeErrorJSON") + encoded, err := ctx.Invoke(normalizer, global, errValue) + if err != nil { + return nil, err + } + + var envelope ErrorEnvelope + if err := json.Unmarshal([]byte(encoded.String()), &envelope); err != nil { + return nil, err + } + envelope.Stack = rewriteStack(envelope.Stack, sourceMap, userCodeStartLine(program)) + rewriteCauseStacks(envelope.Cause, sourceMap, program) + return &envelope, nil +} + +var toolsCallRE = regexp.MustCompile(`tools\.([A-Za-z_$][A-Za-z0-9_$]*)\s*\(`) + +func missingToolName(source string, known []string) string { + knownSet := map[string]struct{}{} + for _, name := range known { + knownSet[name] = struct{}{} + } + for _, match := range toolsCallRE.FindAllStringSubmatch(source, -1) { + if len(match) != 2 { + continue + } + if _, ok := knownSet[match[1]]; !ok { + return match[1] + } + } + return "" +} + +func rewriteCauseStacks(err *ErrorEnvelope, sourceMap []byte, program string) { + for err != nil { + err.Stack = rewriteStack(err.Stack, sourceMap, userCodeStartLine(program)) + err = err.Cause + } +} + +var stackLocationRE = regexp.MustCompile(`(?:\w+\.js:)?(\d+):(\d+)`) + +func rewriteStack(stack string, sourceMap []byte, userStartLine int) string { + if len(sourceMap) == 0 || stack == "" { + return stack + } + sm, err := parseSourceMap(sourceMap) + if err != nil { + return stack + } + return stackLocationRE.ReplaceAllStringFunc(stack, func(match string) string { + parts := stackLocationRE.FindStringSubmatch(match) + if len(parts) != 3 { + return match + } + line, err := strconv.Atoi(parts[1]) + if err != nil { + return match + } + col, err := strconv.Atoi(parts[2]) + if err != nil { + return match + } + generatedLine := line - userStartLine + 1 + if generatedLine < 1 { + return match + } + mapped, ok := sm.lookup(generatedLine, col) + if !ok { + return match + } + prefix := strings.TrimSuffix(match, parts[1]+":"+parts[2]) + return prefix + mapped.source + ":" + strconv.Itoa(mapped.line) + ":" + strconv.Itoa(mapped.column) + }) +} + +type sourceMap struct { + lines [][]mapping +} + +type mapping struct { + generatedColumn int + source string + line int + column int +} + +func parseSourceMap(data []byte) (*sourceMap, error) { + var raw struct { + Sources []string `json:"sources"` + Mappings string `json:"mappings"` + } + if err := json.Unmarshal(data, &raw); err != nil { + return nil, err + } + + sm := &sourceMap{lines: make([][]mapping, 0)} + var sourceIndex, originalLine, originalColumn int + for _, lineMappings := range strings.Split(raw.Mappings, ";") { + var generatedColumn int + line := make([]mapping, 0) + for _, segment := range strings.Split(lineMappings, ",") { + if segment == "" { + continue + } + values, err := decodeVLQSegment(segment) + if err != nil { + return nil, err + } + if len(values) < 4 { + continue + } + generatedColumn += values[0] + sourceIndex += values[1] + originalLine += values[2] + originalColumn += values[3] + if sourceIndex >= 0 && sourceIndex < len(raw.Sources) { + line = append(line, mapping{ + generatedColumn: generatedColumn, + source: raw.Sources[sourceIndex], + line: originalLine + 1, + column: originalColumn + 1, + }) + } + } + sm.lines = append(sm.lines, line) + } + return sm, nil +} + +func (sm *sourceMap) lookup(generatedLine, generatedColumn int) (mapping, bool) { + if generatedLine < 1 || generatedLine > len(sm.lines) { + return mapping{}, false + } + line := sm.lines[generatedLine-1] + if len(line) == 0 { + return mapping{}, false + } + column0 := generatedColumn - 1 + best := line[0] + for _, candidate := range line { + if candidate.generatedColumn > column0 { + break + } + best = candidate + } + return best, true +} + +const vlqBaseShift = 5 +const vlqBase = 1 << vlqBaseShift +const vlqBaseMask = vlqBase - 1 +const vlqContinuationBit = vlqBase + +var base64VLQ = map[rune]int{ + 'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4, 'F': 5, 'G': 6, 'H': 7, + 'I': 8, 'J': 9, 'K': 10, 'L': 11, 'M': 12, 'N': 13, 'O': 14, 'P': 15, + 'Q': 16, 'R': 17, 'S': 18, 'T': 19, 'U': 20, 'V': 21, 'W': 22, 'X': 23, + 'Y': 24, 'Z': 25, 'a': 26, 'b': 27, 'c': 28, 'd': 29, 'e': 30, 'f': 31, + 'g': 32, 'h': 33, 'i': 34, 'j': 35, 'k': 36, 'l': 37, 'm': 38, 'n': 39, + 'o': 40, 'p': 41, 'q': 42, 'r': 43, 's': 44, 't': 45, 'u': 46, 'v': 47, + 'w': 48, 'x': 49, 'y': 50, 'z': 51, '0': 52, '1': 53, '2': 54, '3': 55, + '4': 56, '5': 57, '6': 58, '7': 59, '8': 60, '9': 61, '+': 62, '/': 63, +} + +func decodeVLQSegment(segment string) ([]int, error) { + values := make([]int, 0, 4) + var value, shift int + for _, r := range segment { + digit := base64VLQ[r] + continuation := digit&vlqContinuationBit != 0 + digit &= vlqBaseMask + value += digit << shift + if continuation { + shift += vlqBaseShift + continue + } + negative := value&1 == 1 + value >>= 1 + if negative { + value = -value + } + values = append(values, value) + value = 0 + shift = 0 + } + return values, nil +} diff --git a/router/internal/codemode/sandbox/execute.go b/router/internal/codemode/sandbox/execute.go new file mode 100644 index 0000000000..37747a83bc --- /dev/null +++ b/router/internal/codemode/sandbox/execute.go @@ -0,0 +1,211 @@ +package sandbox + +import ( + "context" + "encoding/json" + "fmt" + "io" + "strings" + "time" + + "github.com/fastschema/qjs" +) + +func (s *Sandbox) Execute(ctx context.Context, req ExecuteRequest) (execResult ExecuteResult, retErr error) { + if err := s.acquire(ctx); err != nil { + return ExecuteResult{}, err + } + defer s.release() + + // qjs v0.0.6 panics from inside its Eval/Free/Close paths when the underlying + // wazero module is closed by context cancellation (e.g. host call exceeded + // the sandbox wall-clock). Recover here so a panicking call cannot crash the + // router goroutine; surface as a Timeout envelope instead. + defer func() { + if r := recover(); r != nil { + errEnv := &ErrorEnvelope{Name: "Timeout", Message: fmt.Sprintf("sandbox runtime panic: %v", r)} + if ctx.Err() != nil { + errEnv.Message = ctx.Err().Error() + } + execResult = ExecuteResult{OK: false, Error: errEnv, OutputSize: envelopeSize(nil, errEnv)} + retErr = nil + } + }() + + program := buildPreamble(req.WrappedJS) + if len(program) > s.cfg.MaxInputSizeBytes { + errEnv := &ErrorEnvelope{ + Name: "InputTooLarge", + Message: fmt.Sprintf("input size %d bytes exceeds limit %d bytes", len(program), s.cfg.MaxInputSizeBytes), + Stack: "", + } + return ExecuteResult{OK: false, Error: errEnv, OutputSize: envelopeSize(nil, errEnv)}, nil + } + + execCtx, cancel := context.WithTimeout(ctx, s.cfg.RequestTimeout) + defer cancel() + + rt, err := qjs.New(qjs.Option{ + Context: execCtx, + CloseOnContextDone: true, + DisableBuildCache: true, + MemoryLimit: s.cfg.MemoryLimitBytes, + MaxExecutionTime: int(s.cfg.RequestTimeout / time.Millisecond), + Stdout: io.Discard, + Stderr: io.Discard, + }) + if err != nil { + return runtimeErrorResult(err, execCtx, 0), nil + } + + qctx := rt.Context() + state := &executeState{req: req} + defer func() { + // qjs panics on Close when the runtime context has already been cancelled. + // Treat the runtime as best-effort cleanup; a leaked WASM instance is bounded + // by GC and the per-call freshness contract. + defer func() { _ = recover() }() + rt.Close() + }() + s.installHostInvoke(execCtx, qctx, state) + if err := installValidationHelpers(qctx); err != nil { + return runtimeErrorResult(err, execCtx, int(state.hostCalls.Load())), nil + } + + global := qctx.Global() + toolNames := req.ToolNames + if toolNames == nil { + toolNames = []string{} + } + namesJSON, err := json.Marshal(toolNames) + if err != nil { + return ExecuteResult{}, err + } + names := qctx.ParseJSON(string(namesJSON)) + global.SetPropertyStr("__HOST_TOOL_NAMES", names) + + value, err := qctx.Eval("codemode_agent.js", qjs.Code(program)) + if err != nil { + return runtimeErrorResult(err, execCtx, int(state.hostCalls.Load())), nil + } + + value, err = awaitWithContext(execCtx, rt, value) + if err != nil { + return runtimeErrorResult(err, execCtx, int(state.hostCalls.Load())), nil + } + okValue := value.GetPropertyStr("ok") + ok := okValue.Bool() + + if !ok { + errValue := value.GetPropertyStr("error") + errEnv, err := normalizeError(qctx, errValue, req.SourceMap, program) + if err != nil { + return runtimeErrorResult(err, execCtx, int(state.hostCalls.Load())), nil + } + if errEnv.Name == "InternalError" { + errEnv.Name = "MemoryLimit" + } + if errEnv.Name == "TypeError" && errEnv.Message == "not a function" { + if missing := missingToolName(req.WrappedJS, req.ToolNames); missing != "" { + errEnv.Message = "tools." + missing + " is not a function" + } + } + hostCalls := int(state.hostCalls.Load()) + if errEnv.Name == "HostCallLimitExceeded" { + hostCalls = s.cfg.MaxToolInvocationsPerCall + 1 + } + return ExecuteResult{ + OK: false, + Error: errEnv, + OutputSize: envelopeSize(nil, errEnv), + HostCalls: hostCalls, + }, nil + } + + resultValue := value.GetPropertyStr("result") + result, warnings, validationErr, err := validateResult(qctx, resultValue, s.cfg.MaxOutputSizeBytes) + if err != nil { + return runtimeErrorResult(err, execCtx, int(state.hostCalls.Load())), nil + } + if validationErr != nil { + return ExecuteResult{ + OK: false, + Error: validationErr, + Warnings: warnings, + OutputSize: envelopeSize(nil, validationErr), + HostCalls: int(state.hostCalls.Load()), + }, nil + } + return ExecuteResult{ + OK: true, + Result: result, + Warnings: warnings, + OutputSize: envelopeSize(result, nil), + HostCalls: int(state.hostCalls.Load()), + }, nil +} + +type awaitResult struct { + value *qjs.Value + err error +} + +func awaitWithContext(ctx context.Context, rt *qjs.Runtime, value *qjs.Value) (*qjs.Value, error) { + if !value.IsPromise() { + return value, nil + } + + done := make(chan awaitResult, 1) + go func() { + awaited, err := value.Await() + done <- awaitResult{value: awaited, err: err} + }() + + select { + case result := <-done: + return result.value, result.err + case <-ctx.Done(): + // Best-effort runtime close so the await goroutine unblocks; the deferred + // close in Execute owns the canonical cleanup (and recovers any qjs panic). + func() { + defer func() { _ = recover() }() + rt.Close() + }() + select { + case result := <-done: + _ = result + case <-time.After(100 * time.Millisecond): + } + return nil, ctx.Err() + } +} + +func runtimeErrorResult(err error, ctx context.Context, hostCalls int) ExecuteResult { + errEnv := classifyRuntimeError(err, ctx) + return ExecuteResult{ + OK: false, + Error: errEnv, + OutputSize: envelopeSize(nil, errEnv), + HostCalls: hostCalls, + } +} + +func classifyRuntimeError(err error, ctx context.Context) *ErrorEnvelope { + if ctx.Err() != nil { + return &ErrorEnvelope{Name: "Timeout", Message: ctx.Err().Error(), Stack: ""} + } + msg := err.Error() + lower := strings.ToLower(msg) + if strings.Contains(lower, "memory") || strings.Contains(lower, "out of memory") { + return &ErrorEnvelope{Name: "MemoryLimit", Message: msg, Stack: ""} + } + return &ErrorEnvelope{Name: "Error", Message: msg, Stack: ""} +} + +func envelopeSize(result json.RawMessage, errEnv *ErrorEnvelope) int { + if errEnv != nil { + body, _ := json.Marshal(errEnv) + return len(body) + } + return len(result) +} diff --git a/router/internal/codemode/sandbox/headers.go b/router/internal/codemode/sandbox/headers.go new file mode 100644 index 0000000000..100ae0be85 --- /dev/null +++ b/router/internal/codemode/sandbox/headers.go @@ -0,0 +1,44 @@ +package sandbox + +import ( + "net/http" + "strings" +) + +var hopByHopHeaders = map[string]struct{}{ + "connection": {}, + "keep-alive": {}, + "proxy-authenticate": {}, + "proxy-authorization": {}, + "te": {}, + "trailer": {}, + "transfer-encoding": {}, + "upgrade": {}, +} + +func headerAllowList(headers []string) map[string]struct{} { + allow := make(map[string]struct{}, len(headers)) + for _, h := range headers { + canonical := strings.ToLower(http.CanonicalHeaderKey(h)) + if _, hop := hopByHopHeaders[canonical]; hop { + continue + } + allow[canonical] = struct{}{} + } + return allow +} + +func copyAllowedHeaders(dst, src http.Header, allow map[string]struct{}) { + for name, values := range src { + canonical := strings.ToLower(http.CanonicalHeaderKey(name)) + if _, hop := hopByHopHeaders[canonical]; hop { + continue + } + if _, ok := allow[canonical]; !ok { + continue + } + for _, value := range values { + dst.Add(name, value) + } + } +} diff --git a/router/internal/codemode/sandbox/host.go b/router/internal/codemode/sandbox/host.go new file mode 100644 index 0000000000..e81509be36 --- /dev/null +++ b/router/internal/codemode/sandbox/host.go @@ -0,0 +1,241 @@ +package sandbox + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "sync" + "sync/atomic" + + "github.com/fastschema/qjs" + "github.com/wundergraph/cosmo/router/internal/codemode/observability" + "github.com/wundergraph/cosmo/router/internal/codemode/storage" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" +) + +// TODO(code-mode §9): The plan calls for channel-based async host calls so +// Promise.all can overlap HTTP work. qjs v0.0.6 SetAsyncFunc invokes the Go +// callback synchronously on the QuickJS/Wazero call path, and resolving from +// arbitrary goroutines is not supported by the wrapper without a JS-thread +// drain loop, so host calls remain serialized for the MVP. + +type executeState struct { + req ExecuteRequest + hostCalls atomic.Int32 + qjsMu sync.Mutex +} + +func (s *Sandbox) installHostInvoke(ctx context.Context, qctx *qjs.Context, state *executeState) { + qctx.SetAsyncFunc("__hostInvokeTool", func(this *qjs.This) { + args := this.Args() + name := "" + if len(args) > 0 && !args[0].IsUndefined() && !args[0].IsNull() { + name = args[0].String() + } + vars, err := varsJSON(args) + if err != nil { + resolveString(this, state, hostErrorPayload("TypeError", err.Error())) + return + } + + result, invokeErr := s.invokeTool(ctx, state, name, vars) + if invokeErr != nil { + resolveString(this, state, hostErrorPayload(invokeErr.name, invokeErr.message)) + return + } + resolveString(this, state, string(result)) + }) +} + +func resolveString(this *qjs.This, state *executeState, payload string) { + state.qjsMu.Lock() + defer state.qjsMu.Unlock() + this.Promise().Resolve(this.Context().NewString(payload)) +} + +func hostErrorPayload(name, message string) string { + body, _ := json.Marshal(map[string]any{ + "__codemodeHostError": map[string]string{ + "name": name, + "message": message, + }, + }) + return string(body) +} + +type hostError struct { + name string + message string +} + +func varsJSON(args []*qjs.Value) (json.RawMessage, error) { + if len(args) < 2 || args[1].IsUndefined() || args[1].IsNull() { + return json.RawMessage(`{}`), nil + } + jsonString, err := args[1].JSONStringify() + if err != nil { + return nil, err + } + if jsonString == "" || jsonString == "null" { + return json.RawMessage(`{}`), nil + } + return json.RawMessage(jsonString), nil +} + +func (s *Sandbox) invokeTool(ctx context.Context, state *executeState, name string, vars json.RawMessage) (json.RawMessage, *hostError) { + count := int(state.hostCalls.Add(1)) + if count > s.cfg.MaxToolInvocationsPerCall { + return nil, &hostError{ + name: "HostCallLimitExceeded", + message: fmt.Sprintf("tools.* invocation cap of %d exceeded; batch independent calls with Promise.all.", s.cfg.MaxToolInvocationsPerCall), + } + } + + op, ok, err := s.cfg.StorageLookup(ctx, state.req.SessionID, name) + if err != nil { + observability.LogToolInvocationFailure(s.cfg.Logger, state.req.SessionID, name, err) + return nil, &hostError{name: "Error", message: err.Error()} + } + if !ok { + err := fmt.Errorf("tools.%s is not a function", name) + observability.LogToolInvocationFailure(s.cfg.Logger, state.req.SessionID, name, err) + return nil, &hostError{name: "TypeError", message: err.Error()} + } + + span := trace.SpanFromContext(ctx) + span.SetAttributes( + attribute.String("codemode.op.name", op.Name), + attribute.String("codemode.op.kind", string(op.Kind)), + ) + + if op.Kind == storage.OperationKindMutation { + gate := state.req.ApprovalGate + if gate == nil { + gate = approveAllGate{} + } + decision, err := gate.Decide(ctx, ApprovalRequest{Name: name, Source: op.Body, Vars: vars}) + if err != nil { + observability.LogToolInvocationFailure(s.cfg.Logger, state.req.SessionID, name, err) + return nil, &hostError{name: "Error", message: err.Error()} + } + span.SetAttributes( + attribute.Bool("code_mode.mutation.approved", decision.Approved), + attribute.String("code_mode.mutation.reason", decision.Reason), + ) + if !decision.Approved { + body := mutationDeclinedResponse(decision.Reason) + span.SetAttributes(attribute.Bool("codemode.op.success", false)) + return body, nil + } + } + + body, err := json.Marshal(graphQLRequest{ + Query: op.Body, + OperationName: name, + Variables: vars, + }) + if err != nil { + return nil, &hostError{name: "Error", message: err.Error()} + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, s.cfg.RouterGraphQLEndpoint, bytes.NewReader(body)) + if err != nil { + return nil, &hostError{name: "Error", message: err.Error()} + } + copyAllowedHeaders(httpReq.Header, state.req.RequestHeaders, s.allowList) + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := s.http.Do(httpReq) + if err != nil { + span.SetAttributes(attribute.Bool("codemode.op.success", false)) + observability.LogToolInvocationFailure(s.cfg.Logger, state.req.SessionID, name, err) + return nil, &hostError{name: "Error", message: err.Error()} + } + defer resp.Body.Close() + + respBody, err := readCapped(resp.Body, s.cfg.MaxResponseBodyBytes) + if err != nil { + span.SetAttributes(attribute.Bool("codemode.op.success", false)) + observability.LogToolInvocationFailure(s.cfg.Logger, state.req.SessionID, name, err) + return nil, &hostError{name: "Error", message: err.Error()} + } + + result := normalizeGraphQLResponse(resp.StatusCode, respBody) + if errorsJSON := graphQLErrors(result); errorsJSON != "" { + span.SetAttributes(attribute.String("codemode.graphql.errors", errorsJSON)) + observability.LogToolInvocationFailure(s.cfg.Logger, state.req.SessionID, name, fmt.Errorf("graphql errors: %s", errorsJSON)) + } + span.SetAttributes(attribute.Bool("codemode.op.success", resp.StatusCode < 400)) + if resp.StatusCode >= 400 { + observability.LogToolInvocationFailure(s.cfg.Logger, state.req.SessionID, name, fmt.Errorf("graphql http status %d", resp.StatusCode)) + } + return result, nil +} + +type graphQLRequest struct { + Query string `json:"query"` + OperationName string `json:"operationName"` + Variables json.RawMessage `json:"variables"` +} + +func mutationDeclinedResponse(reason string) json.RawMessage { + if reason == "" { + reason = "Mutation declined by operator" + } + body, _ := json.Marshal(map[string]any{ + "data": nil, + "errors": []map[string]string{{ + "message": "Mutation declined by operator: " + reason, + }}, + "declined": map[string]string{"reason": reason}, + }) + return body +} + +func normalizeGraphQLResponse(status int, body []byte) json.RawMessage { + var payload map[string]json.RawMessage + if err := json.Unmarshal(body, &payload); err == nil { + if status >= 400 { + if _, ok := payload["errors"]; ok { + out, _ := json.Marshal(payload) + return out + } + } + out, _ := json.Marshal(payload) + return out + } + msg := strings.TrimSpace(string(body)) + if msg == "" { + msg = http.StatusText(status) + } + out, _ := json.Marshal(map[string]any{ + "errors": []map[string]string{{"message": msg}}, + }) + return out +} + +func graphQLErrors(body json.RawMessage) string { + var payload struct { + Errors json.RawMessage `json:"errors"` + } + if err := json.Unmarshal(body, &payload); err != nil || len(payload.Errors) == 0 { + return "" + } + return string(payload.Errors) +} + +func readCapped(r io.Reader, capBytes int) ([]byte, error) { + data, err := io.ReadAll(io.LimitReader(r, int64(capBytes)+1)) + if err != nil { + return nil, err + } + if len(data) > capBytes { + return nil, fmt.Errorf("tools.* HTTP response body exceeded %d bytes", capBytes) + } + return data, nil +} diff --git a/router/internal/codemode/sandbox/preamble.go b/router/internal/codemode/sandbox/preamble.go new file mode 100644 index 0000000000..58b0bf57ef --- /dev/null +++ b/router/internal/codemode/sandbox/preamble.go @@ -0,0 +1,28 @@ +package sandbox + +import ( + _ "embed" + "strings" +) + +//go:embed sandbox_preamble.js +var preambleTemplate string + +const ( + spliceComment = "// Splice point: Execute.WrappedJS is already harness-wrapped and transpiled." + agentMainSpliceID = "__AGENT_MAIN_SPLICE__" +) + +func buildPreamble(wrappedJS string) string { + return strings.Replace(preambleTemplate, agentMainSpliceID, wrappedJS, 1) +} + +func userCodeStartLine(program string) int { + lines := strings.Split(program, "\n") + for i, line := range lines { + if line == spliceComment { + return i + 2 + } + } + return 1 +} diff --git a/router/internal/codemode/sandbox/preamble_test.go b/router/internal/codemode/sandbox/preamble_test.go new file mode 100644 index 0000000000..e00fd9d562 --- /dev/null +++ b/router/internal/codemode/sandbox/preamble_test.go @@ -0,0 +1,94 @@ +package sandbox + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBuildPreambleGolden(t *testing.T) { + got := buildPreamble("async () => ({ ok: true })") + + want := `"use strict"; + +const tools = {}; +for (const name of __HOST_TOOL_NAMES) { + tools[name] = async (vars) => { + const __hostPayload = await __hostInvokeTool(name, vars); + const __hostResult = JSON.parse(__hostPayload); + if (__hostResult?.__codemodeHostError) { + const e = new Error(__hostResult.__codemodeHostError.message); + e.name = __hostResult.__codemodeHostError.name; + throw e; + } + return __hostResult; + }; +} +Object.freeze(tools); +globalThis.tools = tools; + +const __consoleErr = () => { + const e = new Error( + "console is not available in this sandbox. " + + "Include diagnostics in your return value, e.g. ` + "`return { result, debug: { ... } }`" + `." + ); + e.name = "ConsoleUnavailable"; + throw e; +}; +globalThis.console = new Proxy({}, { + get: () => __consoleErr, +}); + +Math.random = () => 0; +Date.now = () => 0; + +const __OrigDate = Date; +const __PinnedDate = function Date(...args) { + return args.length === 0 ? new __OrigDate(0) : new __OrigDate(...args); +}; +Object.setPrototypeOf(__PinnedDate, __OrigDate); +__PinnedDate.prototype = __OrigDate.prototype; +__PinnedDate.now = () => 0; +__PinnedDate.UTC = __OrigDate.UTC; +__PinnedDate.parse = __OrigDate.parse; +globalThis.Date = __PinnedDate; + +globalThis.notNull = (v, msg) => { + if (v == null) throw new Error(msg ?? "notNull: value was null/undefined"); + return v; +}; +globalThis.compact = (v) => { + if (Array.isArray(v)) return v.map(compact).filter((x) => x != null); + if (v && typeof v === "object") { + const out = {}; + for (const k in v) { + const c = compact(v[k]); + if (c != null) out[k] = c; + } + return out; + } + return v; +}; + +delete globalThis.eval; +delete globalThis.Function; +// Also remove indirect access via the Function constructor on the function prototype. +// (Function.prototype.constructor still exists per JS spec, but with eval/Function deleted +// it no longer resolves to a usable constructor.) + +// Splice point: Execute.WrappedJS is already harness-wrapped and transpiled. +const __agentMain = (async () => ({ ok: true })); +(async () => { + try { return { ok: true, result: await __agentMain() }; } + catch (err) { + return { ok: false, error: { name: err?.name ?? "Error", message: err?.message ?? String(err), stack: err?.stack ?? "", cause: err?.cause } }; + } +})() +` + assert.Equal(t, want, got) +} + +func TestBuildPreambleReportsUserCodeStartLine(t *testing.T) { + got := userCodeStartLine(buildPreamble("async () => 1")) + assert.Equal(t, 69, got) +} diff --git a/router/internal/codemode/sandbox/sandbox.go b/router/internal/codemode/sandbox/sandbox.go new file mode 100644 index 0000000000..7866eda79d --- /dev/null +++ b/router/internal/codemode/sandbox/sandbox.go @@ -0,0 +1,177 @@ +package sandbox + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "time" + + "github.com/hashicorp/go-retryablehttp" + "github.com/wundergraph/cosmo/router/internal/codemode/storage" + "go.uber.org/zap" +) + +const ( + defaultRequestTimeout = 5 * time.Second + defaultMemoryLimitBytes = 16 << 20 + defaultMaxInputSizeBytes = 64 << 10 + defaultMaxOutputSizeBytes = 1 << 20 + defaultMaxResultBytes = 32 << 10 + defaultMaxToolInvocationsPerCall = 256 + defaultMaxResponseBodyBytes = 10 << 20 + defaultRetryAttempts = 3 + defaultRetryCeiling = 60 * time.Second + defaultMaxConcurrent = 4 +) + +type Sandbox struct { + cfg Config + sem chan struct{} + http *http.Client + allowList map[string]struct{} +} + +type Config struct { + RouterGraphQLEndpoint string + RequestTimeout time.Duration + MemoryLimitBytes int + MaxInputSizeBytes int + MaxOutputSizeBytes int + MaxResultBytes int + MaxToolInvocationsPerCall int + MaxResponseBodyBytes int + RetryAttempts int + RetryCeiling time.Duration + MaxConcurrent int + HeaderAllowList []string + StorageLookup func(ctx context.Context, sessionID string, name string) (storage.SessionOp, bool, error) + Logger *zap.Logger + Now func() time.Time + HTTPClient *http.Client +} + +type ExecuteRequest struct { + SessionID string + ToolNames []string + WrappedJS string + SourceMap []byte + RequestHeaders http.Header + ApprovalGate ApprovalGate +} + +type ExecuteResult struct { + OK bool + Result json.RawMessage + Error *ErrorEnvelope + Warnings []SerializationWarning + Truncated bool + OutputSize int + HostCalls int +} + +type ErrorEnvelope struct { + Name string `json:"name"` + Message string `json:"message"` + Stack string `json:"stack"` + Cause *ErrorEnvelope `json:"cause,omitempty"` +} + +// SerializationWarning records a non-serializable value found in the script's +// return value. The bad value is replaced in the response with the sentinel +// string "<>" where KIND matches the reported Kind. +type SerializationWarning struct { + Path string `json:"path"` + Kind string `json:"kind"` +} + +type ApprovalGate interface { + Decide(ctx context.Context, req ApprovalRequest) (ApprovalDecision, error) +} + +type ApprovalRequest struct { + Name string + Source string + Vars json.RawMessage +} + +type ApprovalDecision struct { + Approved bool + Reason string +} + +type approveAllGate struct{} + +var AutoApprove ApprovalGate = approveAllGate{} + +func (approveAllGate) Decide(context.Context, ApprovalRequest) (ApprovalDecision, error) { + return ApprovalDecision{Approved: true}, nil +} + +func New(cfg Config) (*Sandbox, error) { + cfg = withDefaults(cfg) + if cfg.MaxConcurrent <= 0 { + return nil, errors.New("sandbox max concurrent must be positive") + } + if cfg.StorageLookup == nil { + cfg.StorageLookup = func(context.Context, string, string) (storage.SessionOp, bool, error) { + return storage.SessionOp{}, false, nil + } + } + if cfg.Logger == nil { + cfg.Logger = zap.NewNop() + } + if cfg.Now == nil { + cfg.Now = time.Now + } + + client := cfg.HTTPClient + if client == nil { + retryClient := retryablehttp.NewClient() + retryClient.RetryMax = cfg.RetryAttempts + retryClient.RetryWaitMax = cfg.RetryCeiling + retryClient.Logger = nil + client = retryClient.StandardClient() + } + + return &Sandbox{ + cfg: cfg, + sem: make(chan struct{}, cfg.MaxConcurrent), + http: client, + allowList: headerAllowList(cfg.HeaderAllowList), + }, nil +} + +func withDefaults(cfg Config) Config { + if cfg.RequestTimeout <= 0 { + cfg.RequestTimeout = defaultRequestTimeout + } + if cfg.MemoryLimitBytes <= 0 { + cfg.MemoryLimitBytes = defaultMemoryLimitBytes + } + if cfg.MaxInputSizeBytes <= 0 { + cfg.MaxInputSizeBytes = defaultMaxInputSizeBytes + } + if cfg.MaxOutputSizeBytes <= 0 { + cfg.MaxOutputSizeBytes = defaultMaxOutputSizeBytes + } + if cfg.MaxResultBytes <= 0 { + cfg.MaxResultBytes = defaultMaxResultBytes + } + if cfg.MaxToolInvocationsPerCall <= 0 { + cfg.MaxToolInvocationsPerCall = defaultMaxToolInvocationsPerCall + } + if cfg.MaxResponseBodyBytes <= 0 { + cfg.MaxResponseBodyBytes = defaultMaxResponseBodyBytes + } + if cfg.RetryAttempts <= 0 { + cfg.RetryAttempts = defaultRetryAttempts + } + if cfg.RetryCeiling <= 0 { + cfg.RetryCeiling = defaultRetryCeiling + } + if cfg.MaxConcurrent <= 0 { + cfg.MaxConcurrent = defaultMaxConcurrent + } + return cfg +} diff --git a/router/internal/codemode/sandbox/sandbox_preamble.js b/router/internal/codemode/sandbox/sandbox_preamble.js new file mode 100644 index 0000000000..32ee04e1a4 --- /dev/null +++ b/router/internal/codemode/sandbox/sandbox_preamble.js @@ -0,0 +1,75 @@ +"use strict"; + +const tools = {}; +for (const name of __HOST_TOOL_NAMES) { + tools[name] = async (vars) => { + const __hostPayload = await __hostInvokeTool(name, vars); + const __hostResult = JSON.parse(__hostPayload); + if (__hostResult?.__codemodeHostError) { + const e = new Error(__hostResult.__codemodeHostError.message); + e.name = __hostResult.__codemodeHostError.name; + throw e; + } + return __hostResult; + }; +} +Object.freeze(tools); +globalThis.tools = tools; + +const __consoleErr = () => { + const e = new Error( + "console is not available in this sandbox. " + + "Include diagnostics in your return value, e.g. `return { result, debug: { ... } }`." + ); + e.name = "ConsoleUnavailable"; + throw e; +}; +globalThis.console = new Proxy({}, { + get: () => __consoleErr, +}); + +Math.random = () => 0; +Date.now = () => 0; + +const __OrigDate = Date; +const __PinnedDate = function Date(...args) { + return args.length === 0 ? new __OrigDate(0) : new __OrigDate(...args); +}; +Object.setPrototypeOf(__PinnedDate, __OrigDate); +__PinnedDate.prototype = __OrigDate.prototype; +__PinnedDate.now = () => 0; +__PinnedDate.UTC = __OrigDate.UTC; +__PinnedDate.parse = __OrigDate.parse; +globalThis.Date = __PinnedDate; + +globalThis.notNull = (v, msg) => { + if (v == null) throw new Error(msg ?? "notNull: value was null/undefined"); + return v; +}; +globalThis.compact = (v) => { + if (Array.isArray(v)) return v.map(compact).filter((x) => x != null); + if (v && typeof v === "object") { + const out = {}; + for (const k in v) { + const c = compact(v[k]); + if (c != null) out[k] = c; + } + return out; + } + return v; +}; + +delete globalThis.eval; +delete globalThis.Function; +// Also remove indirect access via the Function constructor on the function prototype. +// (Function.prototype.constructor still exists per JS spec, but with eval/Function deleted +// it no longer resolves to a usable constructor.) + +// Splice point: Execute.WrappedJS is already harness-wrapped and transpiled. +const __agentMain = (__AGENT_MAIN_SPLICE__); +(async () => { + try { return { ok: true, result: await __agentMain() }; } + catch (err) { + return { ok: false, error: { name: err?.name ?? "Error", message: err?.message ?? String(err), stack: err?.stack ?? "", cause: err?.cause } }; + } +})() diff --git a/router/internal/codemode/sandbox/sandbox_test.go b/router/internal/codemode/sandbox/sandbox_test.go new file mode 100644 index 0000000000..48e8e8692a --- /dev/null +++ b/router/internal/codemode/sandbox/sandbox_test.go @@ -0,0 +1,691 @@ +package sandbox + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/evanw/esbuild/pkg/api" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/internal/codemode/storage" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" +) + +type DeclinedGate struct { + reason string +} + +func (g DeclinedGate) Decide(context.Context, ApprovalRequest) (ApprovalDecision, error) { + return ApprovalDecision{Approved: false, Reason: g.reason}, nil +} + +type nameDeclinedGate struct { + name string + reason string +} + +func (g nameDeclinedGate) Decide(_ context.Context, req ApprovalRequest) (ApprovalDecision, error) { + if req.Name == g.name { + return ApprovalDecision{Approved: false, Reason: g.reason}, nil + } + return ApprovalDecision{Approved: true}, nil +} + +type lookup map[string]storage.SessionOp + +func (l lookup) get(_ context.Context, _ string, name string) (storage.SessionOp, bool, error) { + op, ok := l[name] + return op, ok, nil +} + +func clientFunc(fn roundTripFunc) *http.Client { + return &http.Client{Transport: fn} +} + +func jsonResponse(status int, body string) *http.Response { + return &http.Response{ + StatusCode: status, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: ioNopCloser{bytes.NewBufferString(body)}, + } +} + +func newTestSandbox(t *testing.T, endpoint string, ops lookup, opts func(*Config)) *Sandbox { + t.Helper() + + cfg := Config{ + RouterGraphQLEndpoint: endpoint, + StorageLookup: ops.get, + RequestTimeout: 30 * time.Second, + RetryAttempts: 0, + } + if opts != nil { + opts(&cfg) + } + s, err := New(cfg) + require.NoError(t, err) + return s +} + +func execute(t *testing.T, s *Sandbox, req ExecuteRequest) ExecuteResult { + t.Helper() + + got, err := s.Execute(context.Background(), req) + require.NoError(t, err) + return got +} + +func raw(s string) json.RawMessage { + return json.RawMessage(s) +} + +func TestExecuteHappyPathToolCall(t *testing.T) { + var gotBody map[string]any + client := clientFunc(func(r *http.Request) (*http.Response, error) { + require.Equal(t, http.MethodPost, r.Method) + require.NoError(t, json.NewDecoder(r.Body).Decode(&gotBody)) + return jsonResponse(http.StatusOK, `{"data":{"order":{"id":"o1"}}}`), nil + }) + + s := newTestSandbox(t, "http://router/graphql", lookup{ + "getOrder": {Name: "getOrder", Body: "query GetOrder($id: ID!) { order(id: $id) { id } }", Kind: storage.OperationKindQuery}, + }, func(cfg *Config) { cfg.HTTPClient = client }) + + got := execute(t, s, ExecuteRequest{ + SessionID: "s1", + ToolNames: []string{"getOrder"}, + WrappedJS: `async () => { + return await tools.getOrder({ id: "o1" }); +}`, + }) + + assert.Equal(t, "getOrder", gotBody["operationName"]) + assert.Equal(t, ExecuteResult{ + OK: true, + Result: raw(`{"data":{"order":{"id":"o1"}}}`), + HostCalls: 1, + }, ExecuteResult{OK: got.OK, Result: got.Result, HostCalls: got.HostCalls}) +} + +func TestExecuteGraphQLErrorsResolveVerbatimAndRecordSpan(t *testing.T) { + client := clientFunc(func(r *http.Request) (*http.Response, error) { + return jsonResponse(http.StatusOK, `{"data":null,"errors":[{"message":"x"}]}`), nil + }) + + exporter := tracetest.NewInMemoryExporter() + tp := trace.NewTracerProvider(trace.WithSyncer(exporter)) + old := otel.GetTracerProvider() + otel.SetTracerProvider(tp) + defer otel.SetTracerProvider(old) + + s := newTestSandbox(t, "http://router/graphql", lookup{ + "getBroken": {Name: "getBroken", Body: "query Broken { broken }", Kind: storage.OperationKindQuery}, + }, func(cfg *Config) { cfg.HTTPClient = client }) + + ctx, span := otel.Tracer("sandbox-test").Start(context.Background(), "parent") + got, err := s.Execute(ctx, ExecuteRequest{ + SessionID: "s1", + ToolNames: []string{"getBroken"}, + WrappedJS: `async () => await tools.getBroken()`, + }) + span.End() + require.NoError(t, err) + + assert.Equal(t, ExecuteResult{ + OK: true, + Result: raw(`{"data":null,"errors":[{"message":"x"}]}`), + HostCalls: 1, + }, ExecuteResult{OK: got.OK, Result: got.Result, HostCalls: got.HostCalls}) + spans := exporter.GetSpans() + require.NotEmpty(t, spans) + var found bool + for _, sp := range spans { + for _, attr := range sp.Attributes { + if string(attr.Key) == "codemode.graphql.errors" && strings.Contains(attr.Value.AsString(), `"message":"x"`) { + found = true + } + } + } + assert.Equal(t, true, found) +} + +func TestExecuteHTTP500CanBeReturnedOrThrownByAgent(t *testing.T) { + client := clientFunc(func(r *http.Request) (*http.Response, error) { + return jsonResponse(http.StatusInternalServerError, `{"errors":[{"message":"upstream failed"}]}`), nil + }) + + s := newTestSandbox(t, "http://router/graphql", lookup{ + "getBroken": {Name: "getBroken", Body: "query Broken { broken }", Kind: storage.OperationKindQuery}, + }, func(cfg *Config) { cfg.HTTPClient = client }) + + returned := execute(t, s, ExecuteRequest{ + SessionID: "s1", + ToolNames: []string{"getBroken"}, + WrappedJS: `async () => await tools.getBroken()`, + }) + assert.Equal(t, ExecuteResult{ + OK: true, + Result: raw(`{"errors":[{"message":"upstream failed"}]}`), + HostCalls: 1, + }, ExecuteResult{OK: returned.OK, Result: returned.Result, HostCalls: returned.HostCalls}) + + thrown := execute(t, s, ExecuteRequest{ + SessionID: "s1", + ToolNames: []string{"getBroken"}, + WrappedJS: `async () => { + const r = await tools.getBroken(); + if (r.errors?.length) throw new Error(r.errors[0].message); + return r; +}`, + }) + assert.Equal(t, false, thrown.OK) + require.NotNil(t, thrown.Error) + assert.Equal(t, "Error", thrown.Error.Name) + assert.Equal(t, "upstream failed", thrown.Error.Message) + assert.Equal(t, 1, thrown.HostCalls) +} + +func TestExecuteConsoleUnavailable(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, nil) + + got := execute(t, s, ExecuteRequest{WrappedJS: `async () => { console.log("x"); }`}) + + assert.Equal(t, false, got.OK) + require.NotNil(t, got.Error) + assert.Equal(t, ErrorEnvelope{ + Name: "ConsoleUnavailable", + Message: "console is not available in this sandbox. Include diagnostics in your return value, e.g. `return { result, debug: { ... } }`.", + Stack: got.Error.Stack, + }, *got.Error) +} + +func TestExecuteEvalAndFunctionRemoved(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, nil) + + tests := []struct { + name string + wrappedJS string + want json.RawMessage + }{ + { + name: "typeof eval", + wrappedJS: `async () => { return typeof eval; }`, + want: raw(`"undefined"`), + }, + { + name: "typeof Function", + wrappedJS: `async () => { return typeof Function; }`, + want: raw(`"undefined"`), + }, + { + name: "indirect eval", + wrappedJS: `async () => { try { (0, eval)("1+1"); return "ok"; } catch (e) { return e.name + ":" + e.message; } }`, + want: raw(`"ReferenceError:eval is not defined"`), + }, + { + name: "new Function", + wrappedJS: `async () => { try { new Function("return 1"); return "ok"; } catch (e) { return e.name + ":" + e.message; } }`, + want: raw(`"ReferenceError:Function is not defined"`), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := execute(t, s, ExecuteRequest{WrappedJS: tt.wrappedJS}) + + assert.Equal(t, ExecuteResult{OK: true, Result: tt.want}, ExecuteResult{OK: got.OK, Result: got.Result}) + }) + } +} + +func TestExecuteDeterministicDateAndRandom(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, nil) + + got := execute(t, s, ExecuteRequest{WrappedJS: `async () => ({ + random: Math.random(), + now: Date.now(), + epoch: new Date().getTime(), + parsed: new Date(123).getTime() +})`}) + + assert.Equal(t, ExecuteResult{ + OK: true, + Result: raw(`{"random":0,"now":0,"epoch":0,"parsed":123}`), + }, ExecuteResult{OK: got.OK, Result: got.Result}) +} + +func TestExecuteAllowsConfiguredHostCallCapAndThrowsOnNextCall(t *testing.T) { + var calls atomic.Int32 + client := clientFunc(func(r *http.Request) (*http.Response, error) { + calls.Add(1) + return jsonResponse(http.StatusOK, `{"data":{"ok":true}}`), nil + }) + s := newTestSandbox(t, "http://router/graphql", lookup{ + "foo": {Name: "foo", Body: "query Foo { foo }", Kind: storage.OperationKindQuery}, + }, func(cfg *Config) { cfg.HTTPClient = client }) + + withinCap := execute(t, s, ExecuteRequest{ + SessionID: "s1", + ToolNames: []string{"foo"}, + WrappedJS: `async () => { + for (let i = 0; i < 256; i++) await tools.foo({}); + return "ok"; +}`, + }) + + assert.Equal(t, ExecuteResult{ + OK: true, + Result: raw(`"ok"`), + HostCalls: 256, + }, ExecuteResult{OK: withinCap.OK, Result: withinCap.Result, HostCalls: withinCap.HostCalls}) + assert.Equal(t, int32(256), calls.Load()) + + got := execute(t, s, ExecuteRequest{ + SessionID: "s1", + ToolNames: []string{"foo"}, + WrappedJS: `async () => { + for (let i = 0; i < 257; i++) await tools.foo({}); + return null; +}`, + }) + + assert.Equal(t, false, got.OK) + require.NotNil(t, got.Error) + assert.Equal(t, "HostCallLimitExceeded", got.Error.Name) + assert.Equal(t, "tools.* invocation cap of 256 exceeded; batch independent calls with Promise.all.", got.Error.Message) + assert.Equal(t, 257, got.HostCalls) + assert.Equal(t, int32(512), calls.Load()) +} + +func TestExecutePromiseAllToolCallsRunInParallel(t *testing.T) { + var calls atomic.Int32 + client := clientFunc(func(r *http.Request) (*http.Response, error) { + calls.Add(1) + return jsonResponse(http.StatusOK, `{"data":{"ok":true}}`), nil + }) + s := newTestSandbox(t, "http://router/graphql", lookup{ + "ping": {Name: "ping", Body: "query Ping { ping }", Kind: storage.OperationKindQuery}, + }, func(cfg *Config) { cfg.HTTPClient = client }) + + got := execute(t, s, ExecuteRequest{ + SessionID: "s1", + ToolNames: []string{"ping"}, + WrappedJS: `async () => Promise.all([tools.ping(), tools.ping(), tools.ping(), tools.ping()])`, + }) + + assert.Equal(t, true, got.OK) + assert.Equal(t, 4, got.HostCalls) + assert.Equal(t, int32(4), calls.Load()) +} + +func TestExecuteAcceptsTopLevelAwaitStringAsHarnessDeviation(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, nil) + + got := execute(t, s, ExecuteRequest{WrappedJS: `async () => await Promise.resolve(1)`}) + + assert.Equal(t, ExecuteResult{OK: true, Result: raw(`1`)}, ExecuteResult{OK: got.OK, Result: got.Result}) +} + +func TestExecuteWallClockTimeout(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, func(cfg *Config) { + cfg.RequestTimeout = 25 * time.Millisecond + }) + + got := execute(t, s, ExecuteRequest{WrappedJS: `async () => await new Promise(() => {})`}) + + assert.Equal(t, false, got.OK) + require.NotNil(t, got.Error) + assert.Equal(t, "Timeout", got.Error.Name) +} + +func TestExecuteMemoryLimit(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, func(cfg *Config) { + cfg.MemoryLimitBytes = 2 << 20 + }) + + got := execute(t, s, ExecuteRequest{WrappedJS: `async () => { + const xs = []; + for (let i = 0; i < 1000000; i++) xs.push("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"); + return xs.length; +}`}) + + assert.Equal(t, false, got.OK) + require.NotNil(t, got.Error) + assert.Equal(t, "MemoryLimit", got.Error.Name) +} + +func TestExecuteSanitizesNonSerializableField(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, nil) + + got := execute(t, s, ExecuteRequest{WrappedJS: `async () => ({ x: () => 1 })`}) + + assert.Equal(t, true, got.OK) + assert.Nil(t, got.Error) + assert.Equal(t, json.RawMessage(`{"x":"<>"}`), got.Result) + assert.Equal(t, []SerializationWarning{{Path: "$.x", Kind: "function"}}, got.Warnings) +} + +func TestExecuteSanitizesMixedNonSerializableValues(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, nil) + + got := execute(t, s, ExecuteRequest{WrappedJS: `async () => { return { x: () => 1, y: 5n, cycle: (() => { const o = {}; o.self = o; return o; })() }; }`}) + + assert.Equal(t, true, got.OK) + assert.Nil(t, got.Error) + assert.Equal(t, json.RawMessage(`{"x":"<>","y":"<>","cycle":{"self":"<>"}}`), got.Result) + assert.Equal(t, []SerializationWarning{ + {Path: "$.x", Kind: "function"}, + {Path: "$.y", Kind: "bigint"}, + {Path: "$.cycle.self", Kind: "cycle"}, + }, got.Warnings) +} + +func TestExecuteSanitizesRootBigInt(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, nil) + + got := execute(t, s, ExecuteRequest{WrappedJS: `async () => 5n`}) + + assert.Equal(t, true, got.OK) + assert.Nil(t, got.Error) + assert.Equal(t, json.RawMessage(`"<>"`), got.Result) + assert.Equal(t, []SerializationWarning{{Path: "$", Kind: "bigint"}}, got.Warnings) +} + +func TestExecuteSanitizesRootUndefined(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, nil) + + got := execute(t, s, ExecuteRequest{WrappedJS: `async () => undefined`}) + + assert.Equal(t, true, got.OK) + assert.Nil(t, got.Error) + assert.Equal(t, json.RawMessage(`"<>"`), got.Result) + assert.Equal(t, []SerializationWarning{{Path: "$", Kind: "undefined"}}, got.Warnings) +} + +func TestExecuteSanitizesNonSerializableInArray(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, nil) + + got := execute(t, s, ExecuteRequest{WrappedJS: `async () => [1, undefined, () => 2]`}) + + assert.Equal(t, true, got.OK) + assert.Nil(t, got.Error) + assert.Equal(t, json.RawMessage(`[1,"<>","<>"]`), got.Result) + assert.Equal(t, []SerializationWarning{ + {Path: "$[1]", Kind: "undefined"}, + {Path: "$[2]", Kind: "function"}, + }, got.Warnings) +} + +func TestExecuteCleanResultProducesNoWarnings(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, nil) + + got := execute(t, s, ExecuteRequest{WrappedJS: `async () => ({ ok: true, n: 1, items: [1, 2, 3] })`}) + + assert.Equal(t, true, got.OK) + assert.Nil(t, got.Error) + assert.Equal(t, json.RawMessage(`{"ok":true,"n":1,"items":[1,2,3]}`), got.Result) + assert.Equal(t, []SerializationWarning(nil), got.Warnings) +} + +func TestExecuteOutputTooLarge(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, func(cfg *Config) { + cfg.MaxOutputSizeBytes = 10 + }) + + got := execute(t, s, ExecuteRequest{WrappedJS: `async () => "this is too large"`}) + + assert.Equal(t, false, got.OK) + require.NotNil(t, got.Error) + assert.Equal(t, "OutputTooLarge", got.Error.Name) + assert.Contains(t, got.Error.Message, "encoded result size") +} + +func TestExecuteErrorCauseChain(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, nil) + + got := execute(t, s, ExecuteRequest{WrappedJS: `async () => { + throw new Error("a", { cause: new Error("b", { cause: new Error("c") }) }); +}`}) + + assert.Equal(t, false, got.OK) + require.NotNil(t, got.Error) + assert.Equal(t, "a", got.Error.Message) + require.NotNil(t, got.Error.Cause) + assert.Equal(t, "b", got.Error.Cause.Message) + require.NotNil(t, got.Error.Cause.Cause) + assert.Equal(t, "c", got.Error.Cause.Cause.Message) + assert.Nil(t, got.Error.Cause.Cause.Cause) +} + +func TestExecuteErrorCauseChainTruncatesAfterDepthFive(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, nil) + + got := execute(t, s, ExecuteRequest{WrappedJS: `async () => { + let err = new Error("7"); + for (let i = 6; i >= 1; i--) err = new Error(String(i), { cause: err }); + throw err; +}`}) + + assert.Equal(t, false, got.OK) + require.NotNil(t, got.Error) + cause := got.Error + for range 5 { + require.NotNil(t, cause.Cause) + cause = cause.Cause + } + assert.Equal(t, "TruncatedCause", cause.Name) + assert.Equal(t, "cause chain exceeded depth 5", cause.Message) +} + +func TestExecuteSourceMapRewrite(t *testing.T) { + ts := "async () => {\n const x: number = 1;\n throw new Error(\"boom\");\n}" + transformed := api.Transform(ts, api.TransformOptions{ + Loader: api.LoaderTS, + Sourcemap: api.SourceMapExternal, + Sourcefile: "agent.ts", + }) + require.Empty(t, transformed.Errors) + js := strings.TrimSpace(string(transformed.Code)) + js = strings.TrimSuffix(js, ";") + s := newTestSandbox(t, "", lookup{}, nil) + + got := execute(t, s, ExecuteRequest{WrappedJS: js, SourceMap: []byte(transformed.Map)}) + + assert.Equal(t, false, got.OK) + require.NotNil(t, got.Error) + assert.Contains(t, got.Error.Stack, "agent.ts:3:") +} + +func TestExecuteMutationApprovalDeclined(t *testing.T) { + var calls atomic.Int32 + client := clientFunc(func(r *http.Request) (*http.Response, error) { + calls.Add(1) + return jsonResponse(http.StatusOK, `{"data":{"ok":true}}`), nil + }) + s := newTestSandbox(t, "http://router/graphql", lookup{ + "deleteOrder": {Name: "deleteOrder", Body: "mutation DeleteOrder { deleteOrder }", Kind: storage.OperationKindMutation}, + }, func(cfg *Config) { cfg.HTTPClient = client }) + + got := execute(t, s, ExecuteRequest{ + SessionID: "s1", + ToolNames: []string{"deleteOrder"}, + ApprovalGate: DeclinedGate{reason: "no thanks"}, + WrappedJS: `async () => await tools.deleteOrder({ id: "o1" })`, + RequestHeaders: http.Header{}, + }) + + assert.Equal(t, int32(0), calls.Load()) + assert.Equal(t, ExecuteResult{ + OK: true, + Result: raw(`{"data":null,"declined":{"reason":"no thanks"},"errors":[{"message":"Mutation declined by operator: no thanks"}]}`), + HostCalls: 1, + }, ExecuteResult{OK: got.OK, Result: got.Result, HostCalls: got.HostCalls}) +} + +func TestExecuteSpecificMutationApprovalDeclinedReturnsStructuredValue(t *testing.T) { + var calls atomic.Int32 + client := clientFunc(func(r *http.Request) (*http.Response, error) { + calls.Add(1) + return jsonResponse(http.StatusOK, `{"data":{"ok":true}}`), nil + }) + s := newTestSandbox(t, "http://router/graphql", lookup{ + "deleteOrders": {Name: "deleteOrders", Body: "mutation DeleteOrders($id: ID!) { deleteOrders(id: $id) }", Kind: storage.OperationKindMutation}, + }, func(cfg *Config) { cfg.HTTPClient = client }) + + got := execute(t, s, ExecuteRequest{ + SessionID: "s1", + ToolNames: []string{"deleteOrders"}, + ApprovalGate: nameDeclinedGate{name: "deleteOrders", reason: "policy forbids"}, + WrappedJS: `async () => { const r = await tools.deleteOrders({id:"x"}); return r; }`, + RequestHeaders: http.Header{}, + }) + + assert.Equal(t, int32(0), calls.Load()) + assert.Equal(t, ExecuteResult{ + OK: true, + Result: raw(`{"data":null,"declined":{"reason":"policy forbids"},"errors":[{"message":"Mutation declined by operator: policy forbids"}]}`), + HostCalls: 1, + }, ExecuteResult{OK: got.OK, Result: got.Result, HostCalls: got.HostCalls}) +} + +func TestExecuteHeaderAllowList(t *testing.T) { + seen := make(chan http.Header, 1) + client := clientFunc(func(r *http.Request) (*http.Response, error) { + seen <- r.Header.Clone() + return jsonResponse(http.StatusOK, `{"data":{"ok":true}}`), nil + }) + s := newTestSandbox(t, "http://router/graphql", lookup{ + "ping": {Name: "ping", Body: "query Ping { ping }", Kind: storage.OperationKindQuery}, + }, func(cfg *Config) { + cfg.HeaderAllowList = []string{"Authorization", "X-Trace"} + cfg.HTTPClient = client + }) + + got := execute(t, s, ExecuteRequest{ + SessionID: "s1", + ToolNames: []string{"ping"}, + WrappedJS: `async () => await tools.ping()`, + RequestHeaders: http.Header{ + "Authorization": []string{"Bearer token"}, + "X-Trace": []string{"trace-1"}, + "X-Skip": []string{"skip"}, + "Connection": []string{"keep-alive"}, + }, + }) + + headers := <-seen + assert.Equal(t, true, got.OK) + assert.Equal(t, "Bearer token", headers.Get("Authorization")) + assert.Equal(t, "trace-1", headers.Get("X-Trace")) + assert.Equal(t, "", headers.Get("X-Skip")) + assert.Equal(t, "", headers.Get("Connection")) + assert.Equal(t, "application/json", headers.Get("Content-Type")) +} + +func TestExecuteSemaphoreBoundsConcurrency(t *testing.T) { + var active atomic.Int32 + var maxActive atomic.Int32 + started := make(chan struct{}, 5) + release := make(chan struct{}) + client := &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + now := active.Add(1) + for { + max := maxActive.Load() + if now <= max || maxActive.CompareAndSwap(max, now) { + break + } + } + started <- struct{}{} + <-release + active.Add(-1) + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: ioNopCloser{bytes.NewBufferString(`{"data":{"ok":true}}`)}, + }, nil + })} + s := newTestSandbox(t, "http://router/graphql", lookup{ + "ping": {Name: "ping", Body: "query Ping { ping }", Kind: storage.OperationKindQuery}, + }, func(cfg *Config) { + cfg.MaxConcurrent = 4 + cfg.HTTPClient = client + }) + + var wg sync.WaitGroup + for range 5 { + wg.Add(1) + go func() { + defer wg.Done() + _, err := s.Execute(context.Background(), ExecuteRequest{ + SessionID: "s1", + ToolNames: []string{"ping"}, + WrappedJS: `async () => await tools.ping()`, + }) + assert.NoError(t, err) + }() + } + + for range 4 { + <-started + } + assert.Equal(t, int32(4), maxActive.Load()) + assert.Equal(t, int32(4), active.Load()) + select { + case <-started: + t.Fatal("fifth Execute entered before a semaphore slot was released") + default: + } + close(release) + wg.Wait() + assert.Equal(t, int32(4), maxActive.Load()) +} + +func TestExecuteFrozenToolsAssignmentThrowsInStrictMode(t *testing.T) { + s := newTestSandbox(t, "", lookup{ + "foo": {Name: "foo", Body: "query Foo { foo }", Kind: storage.OperationKindQuery}, + }, nil) + + got := execute(t, s, ExecuteRequest{ToolNames: []string{"foo"}, WrappedJS: `async () => { + tools.foo = () => null; + return tools.foo === null; +}`}) + + assert.Equal(t, false, got.OK) + require.NotNil(t, got.Error) + assert.Equal(t, "TypeError", got.Error.Name) +} + +func TestExecuteUnknownToolName(t *testing.T) { + s := newTestSandbox(t, "", lookup{}, nil) + + got := execute(t, s, ExecuteRequest{WrappedJS: `async () => await tools.nope()`}) + + assert.Equal(t, false, got.OK) + require.NotNil(t, got.Error) + assert.Equal(t, "TypeError", got.Error.Name) + // qjs reports native missing-method calls in this form for plain objects. + assert.Equal(t, "tools.nope is not a function", got.Error.Message) +} + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { + return f(r) +} + +type ioNopCloser struct { + *bytes.Buffer +} + +func (c ioNopCloser) Close() error { + return nil +} diff --git a/router/internal/codemode/sandbox/semaphore.go b/router/internal/codemode/sandbox/semaphore.go new file mode 100644 index 0000000000..3677255c7b --- /dev/null +++ b/router/internal/codemode/sandbox/semaphore.go @@ -0,0 +1,16 @@ +package sandbox + +import "context" + +func (s *Sandbox) acquire(ctx context.Context) error { + select { + case s.sem <- struct{}{}: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (s *Sandbox) release() { + <-s.sem +} diff --git a/router/internal/codemode/sandbox/validation.go b/router/internal/codemode/sandbox/validation.go new file mode 100644 index 0000000000..dae3ca74f3 --- /dev/null +++ b/router/internal/codemode/sandbox/validation.go @@ -0,0 +1,110 @@ +package sandbox + +import ( + "encoding/json" + "fmt" + + "github.com/fastschema/qjs" +) + +const validationHelpers = ` +globalThis.__codemodeNormalizeError = (err, depth = 0) => { + if (!err) return null; + if (depth >= 5) return { name: "TruncatedCause", message: "cause chain exceeded depth 5", stack: "" }; + return { + name: err?.name ?? "Error", + message: err?.message ?? String(err), + stack: err?.stack ?? "", + cause: err?.cause ? __codemodeNormalizeError(err.cause, depth + 1) : null, + }; +}; +globalThis.__codemodeNormalizeErrorJSON = (err) => JSON.stringify(__codemodeNormalizeError(err)); + +globalThis.__codemodeValidateResult = (value) => { + const warnings = []; + const seen = new WeakSet(); + const keyPath = (base, key) => { + if (typeof key === "number") return base + "[" + key + "]"; + return /^[A-Za-z_$][A-Za-z0-9_$]*$/.test(key) ? base + "." + key : base + "[" + JSON.stringify(key) + "]"; + }; + const sentinel = (kind) => "<>"; + const walk = (v, path, parent, key) => { + const t = typeof v; + if (t === "bigint" || t === "function" || t === "symbol" || t === "undefined") { + parent[key] = sentinel(t); + warnings.push({ path, kind: t }); + return; + } + if (v && t === "object") { + if (seen.has(v)) { + parent[key] = sentinel("cycle"); + warnings.push({ path, kind: "cycle" }); + return; + } + seen.add(v); + if (Array.isArray(v)) { + for (let i = 0; i < v.length; i++) walk(v[i], keyPath(path, i), v, i); + return; + } + for (const k of Object.keys(v)) walk(v[k], keyPath(path, k), v, k); + } + }; + const root = { value }; + walk(root.value, "$", root, "value"); + try { + const json = JSON.stringify(root.value); + if (json === undefined) { + return JSON.stringify({ ok: false, warnings, error: "value serialized to undefined" }); + } + return JSON.stringify({ ok: true, json, warnings }); + } catch (err) { + const msg = err && err.message ? String(err.message) : String(err); + return JSON.stringify({ ok: false, warnings, error: msg }); + } +}; +` + +type validationOutcome struct { + OK bool `json:"ok"` + JSON string `json:"json"` + Warnings []SerializationWarning `json:"warnings"` + Error string `json:"error"` +} + +func installValidationHelpers(ctx *qjs.Context) error { + val, err := ctx.Eval("codemode_validation.js", qjs.Code(validationHelpers)) + _ = val + return err +} + +func validateResult(ctx *qjs.Context, result *qjs.Value, maxOutputBytes int) (json.RawMessage, []SerializationWarning, *ErrorEnvelope, error) { + global := ctx.Global() + validator := global.GetPropertyStr("__codemodeValidateResult") + encoded, err := ctx.Invoke(validator, global, result) + if err != nil { + return nil, nil, nil, err + } + + var outcome validationOutcome + if err := json.Unmarshal([]byte(encoded.String()), &outcome); err != nil { + return nil, nil, nil, err + } + if len(outcome.Warnings) == 0 { + outcome.Warnings = nil + } + if !outcome.OK { + message := "JSON serialization failed after sanitization" + if outcome.Error != "" { + message = message + ": " + outcome.Error + } + return nil, outcome.Warnings, &ErrorEnvelope{Name: "NotSerializable", Message: message, Stack: ""}, nil + } + if len(outcome.JSON) > maxOutputBytes { + return nil, outcome.Warnings, &ErrorEnvelope{ + Name: "OutputTooLarge", + Message: fmt.Sprintf("encoded result size %d bytes exceeds limit %d bytes", len(outcome.JSON), maxOutputBytes), + Stack: "", + }, nil + } + return json.RawMessage(outcome.JSON), outcome.Warnings, nil, nil +} diff --git a/router/internal/codemode/server/approval.go b/router/internal/codemode/server/approval.go new file mode 100644 index 0000000000..81d937c38a --- /dev/null +++ b/router/internal/codemode/server/approval.go @@ -0,0 +1,195 @@ +package server + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "unicode/utf8" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/wundergraph/cosmo/router/internal/codemode/observability" + "github.com/wundergraph/cosmo/router/internal/codemode/sandbox" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astprinter" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "go.uber.org/zap" +) + +const defaultMutationDeclinedReason = "Mutation declined by operator" + +// Elicitor is the testable subset of the MCP elicitation API used by mutation approval. +type Elicitor interface { + Elicit(ctx context.Context, params ElicitParams) (ElicitResponse, error) +} + +type ElicitParams struct { + Message string + RequestedSchema any +} + +type ElicitResponse struct { + Action string + FormData map[string]any +} + +type ElicitationGate struct { + elicitor Elicitor + logger *zap.Logger +} + +func NewElicitationGate(elicitor Elicitor, logger *zap.Logger) *ElicitationGate { + if logger == nil { + logger = zap.NewNop() + } + return &ElicitationGate{elicitor: elicitor, logger: logger} +} + +func (g *ElicitationGate) Decide(ctx context.Context, req sandbox.ApprovalRequest) (sandbox.ApprovalDecision, error) { + if g == nil || g.elicitor == nil { + decision := unsupportedElicitationDecision(errors.New("elicitor is not configured")) + recordMutationApproval(ctx, decision) + observability.LogElicitationOutcome(g.logger, SessionIDFromContext(ctx), decision.Approved, decision.Reason) + return decision, nil + } + + resp, err := g.elicitor.Elicit(ctx, ElicitParams{ + Message: mutationApprovalMessage(req), + RequestedSchema: mutationApprovalSchema(), + }) + if err != nil { + decision := unsupportedElicitationDecision(err) + recordMutationApproval(ctx, decision) + observability.LogElicitationOutcome(g.logger, SessionIDFromContext(ctx), decision.Approved, decision.Reason) + return decision, nil + } + + decision := decisionFromElicitation(resp) + recordMutationApproval(ctx, decision) + observability.LogElicitationOutcome(g.logger, SessionIDFromContext(ctx), decision.Approved, decision.Reason) + return decision, nil +} + +type MCPElicitor struct { + session *mcp.ServerSession +} + +func NewMCPElicitor(session *mcp.ServerSession) *MCPElicitor { + return &MCPElicitor{session: session} +} + +func (e *MCPElicitor) Elicit(ctx context.Context, params ElicitParams) (ElicitResponse, error) { + if e == nil || e.session == nil { + return ElicitResponse{}, errors.New("MCP server session is not available") + } + resp, err := e.session.Elicit(ctx, &mcp.ElicitParams{ + Message: params.Message, + RequestedSchema: params.RequestedSchema, + }) + if err != nil { + return ElicitResponse{}, err + } + if resp == nil { + return ElicitResponse{}, nil + } + return ElicitResponse{Action: resp.Action, FormData: resp.Content}, nil +} + +func decisionFromElicitation(resp ElicitResponse) sandbox.ApprovalDecision { + if resp.Action != "accept" || resp.FormData == nil { + return sandbox.ApprovalDecision{Approved: false, Reason: defaultMutationDeclinedReason} + } + if approved, ok := resp.FormData["approved"].(bool); ok && approved { + return sandbox.ApprovalDecision{Approved: true} + } + reason, _ := resp.FormData["reason"].(string) + return sandbox.ApprovalDecision{Approved: false, Reason: sanitizeMutationApprovalReason(reason)} +} + +func unsupportedElicitationDecision(err error) sandbox.ApprovalDecision { + return sandbox.ApprovalDecision{ + Approved: false, + Reason: fmt.Sprintf("mutation approval is required but the MCP client does not support elicitation: %s", err), + } +} + +func mutationApprovalSchema() map[string]any { + return map[string]any{ + "type": "object", + "required": []string{"approved"}, + "properties": map[string]any{ + "approved": map[string]any{"type": "boolean"}, + "reason": map[string]any{"type": "string", "maxLength": 500}, + }, + } +} + +func mutationApprovalMessage(req sandbox.ApprovalRequest) string { + return fmt.Sprintf( + "Approve GraphQL mutation %q?\n\nGraphQL mutation:\n\n%s\n\nVariables:\n\n%s", + req.Name, + prettyMutationSource(req.Source), + prettyMutationVariables(req.Vars), + ) +} + +// prettyMutationSource reformats a GraphQL operation body with two-space indentation. +// On any parse failure the original source is returned verbatim — operator-visible +// readability is best-effort, and we never want to swallow what they actually approve. +func prettyMutationSource(source string) string { + doc, report := astparser.ParseGraphqlDocumentString(source) + if report.HasErrors() { + return source + } + pretty, err := astprinter.PrintStringIndent(&doc, " ") + if err != nil { + return source + } + return pretty +} + +func prettyMutationVariables(vars json.RawMessage) string { + if len(vars) == 0 { + return "{}" + } + var decoded any + if err := json.Unmarshal(vars, &decoded); err != nil { + return string(vars) + } + pretty, err := json.MarshalIndent(decoded, "", " ") + if err != nil { + return string(vars) + } + return string(pretty) +} + +func sanitizeMutationApprovalReason(reason string) string { + var b strings.Builder + for len(reason) > 0 { + r, size := utf8.DecodeRuneInString(reason) + if r == utf8.RuneError && size == 1 { + reason = reason[size:] + continue + } + if r < 0x20 { + reason = reason[size:] + continue + } + if b.Len()+size > 500 { + break + } + b.WriteString(reason[:size]) + reason = reason[size:] + } + return b.String() +} + +func recordMutationApproval(ctx context.Context, decision sandbox.ApprovalDecision) { + span := trace.SpanFromContext(ctx) + span.SetAttributes( + attribute.Bool("code_mode.mutation.approved", decision.Approved), + attribute.String("code_mode.mutation.reason", decision.Reason), + ) +} diff --git a/router/internal/codemode/server/approval_test.go b/router/internal/codemode/server/approval_test.go new file mode 100644 index 0000000000..e15bb3a76e --- /dev/null +++ b/router/internal/codemode/server/approval_test.go @@ -0,0 +1,150 @@ +package server + +import ( + "context" + "encoding/json" + "errors" + "strings" + "testing" + "unicode/utf8" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/internal/codemode/sandbox" + "go.uber.org/zap" +) + +type fakeElicitor struct { + response ElicitResponse + err error + params ElicitParams +} + +func (f *fakeElicitor) Elicit(ctx context.Context, params ElicitParams) (ElicitResponse, error) { + f.params = params + if f.err != nil { + return ElicitResponse{}, f.err + } + return f.response, nil +} + +func TestElicitationGateAcceptApprovedTrue(t *testing.T) { + elicitor := &fakeElicitor{ + response: ElicitResponse{Action: "accept", FormData: map[string]any{"approved": true}}, + } + gate := NewElicitationGate(elicitor, zap.NewNop()) + + got, err := gate.Decide(context.Background(), sandbox.ApprovalRequest{ + Name: "deleteOrders", + Source: "mutation DeleteOrders { deleteOrders }", + Vars: json.RawMessage(`{"id":"x"}`), + }) + + require.NoError(t, err) + assert.Equal(t, sandbox.ApprovalDecision{Approved: true, Reason: ""}, got) + assert.Equal(t, map[string]any{ + "type": "object", + "required": []string{"approved"}, + "properties": map[string]any{ + "approved": map[string]any{"type": "boolean"}, + "reason": map[string]any{"type": "string", "maxLength": 500}, + }, + }, elicitor.params.RequestedSchema) + assert.Equal(t, "Approve GraphQL mutation \"deleteOrders\"?\n\nGraphQL mutation:\n\nmutation DeleteOrders {\n deleteOrders\n}\n\nVariables:\n\n{\n \"id\": \"x\"\n}", elicitor.params.Message) +} + +func TestElicitationGateAcceptApprovedFalseUsesReason(t *testing.T) { + elicitor := &fakeElicitor{ + response: ElicitResponse{Action: "accept", FormData: map[string]any{"approved": false, "reason": "no thanks"}}, + } + gate := NewElicitationGate(elicitor, zap.NewNop()) + + got, err := gate.Decide(context.Background(), sandbox.ApprovalRequest{Name: "deleteOrders"}) + + require.NoError(t, err) + assert.Equal(t, sandbox.ApprovalDecision{Approved: false, Reason: "no thanks"}, got) +} + +func TestElicitationGateAcceptApprovedFalseStripsControlCharacters(t *testing.T) { + elicitor := &fakeElicitor{ + response: ElicitResponse{Action: "accept", FormData: map[string]any{"approved": false, "reason": "no\x00 \x01thanks\x1f"}}, + } + gate := NewElicitationGate(elicitor, zap.NewNop()) + + got, err := gate.Decide(context.Background(), sandbox.ApprovalRequest{Name: "deleteOrders"}) + + require.NoError(t, err) + assert.Equal(t, sandbox.ApprovalDecision{Approved: false, Reason: "no thanks"}, got) +} + +func TestElicitationGateAcceptApprovedFalseTruncatesReasonUTF8Safely(t *testing.T) { + elicitor := &fakeElicitor{ + response: ElicitResponse{Action: "accept", FormData: map[string]any{"approved": false, "reason": strings.Repeat("é", 300)}}, + } + gate := NewElicitationGate(elicitor, zap.NewNop()) + + got, err := gate.Decide(context.Background(), sandbox.ApprovalRequest{Name: "deleteOrders"}) + + require.NoError(t, err) + assert.Equal(t, sandbox.ApprovalDecision{Approved: false, Reason: strings.Repeat("é", 250)}, got) + assert.Equal(t, 500, len(got.Reason)) + assert.Equal(t, true, utf8.ValidString(got.Reason)) +} + +func TestElicitationGateDeclineAction(t *testing.T) { + elicitor := &fakeElicitor{response: ElicitResponse{Action: "decline"}} + gate := NewElicitationGate(elicitor, zap.NewNop()) + + got, err := gate.Decide(context.Background(), sandbox.ApprovalRequest{Name: "deleteOrders"}) + + require.NoError(t, err) + assert.Equal(t, sandbox.ApprovalDecision{Approved: false, Reason: "Mutation declined by operator"}, got) +} + +func TestElicitationGateCancelAction(t *testing.T) { + elicitor := &fakeElicitor{response: ElicitResponse{Action: "cancel"}} + gate := NewElicitationGate(elicitor, zap.NewNop()) + + got, err := gate.Decide(context.Background(), sandbox.ApprovalRequest{Name: "deleteOrders"}) + + require.NoError(t, err) + assert.Equal(t, sandbox.ApprovalDecision{Approved: false, Reason: "Mutation declined by operator"}, got) +} + +func TestElicitationGateUnsupportedElicitationErrorDeclines(t *testing.T) { + elicitor := &fakeElicitor{err: errors.New("elicitation not supported")} + gate := NewElicitationGate(elicitor, zap.NewNop()) + + got, err := gate.Decide(context.Background(), sandbox.ApprovalRequest{Name: "deleteOrders"}) + + require.NoError(t, err) + assert.Equal(t, sandbox.ApprovalDecision{ + Approved: false, + Reason: "mutation approval is required but the MCP client does not support elicitation: elicitation not supported", + }, got) +} + +func TestElicitationGateContextCanceledErrorDeclines(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + elicitor := &fakeElicitor{err: ctx.Err()} + gate := NewElicitationGate(elicitor, zap.NewNop()) + + got, err := gate.Decide(ctx, sandbox.ApprovalRequest{Name: "deleteOrders"}) + + require.NoError(t, err) + assert.Equal(t, sandbox.ApprovalDecision{ + Approved: false, + Reason: "mutation approval is required but the MCP client does not support elicitation: context canceled", + }, got) +} + +func TestElicitationGateAcceptWithoutFormDataDeclines(t *testing.T) { + elicitor := &fakeElicitor{response: ElicitResponse{Action: "accept"}} + gate := NewElicitationGate(elicitor, zap.NewNop()) + + got, err := gate.Decide(context.Background(), sandbox.ApprovalRequest{Name: "deleteOrders"}) + + require.NoError(t, err) + assert.Equal(t, sandbox.ApprovalDecision{Approved: false, Reason: "Mutation declined by operator"}, got) +} diff --git a/router/internal/codemode/server/descriptions/descriptions.go b/router/internal/codemode/server/descriptions/descriptions.go new file mode 100644 index 0000000000..3336fe7c77 --- /dev/null +++ b/router/internal/codemode/server/descriptions/descriptions.go @@ -0,0 +1,37 @@ +// Package descriptions holds the markdown text used as MCP server, tool, and +// resource descriptions for the Code Mode server. Each description lives in its +// own .md file and is embedded at compile time so prose can be edited without +// touching Go source. go:embed only supports vars (not consts), so each export +// is a package-level string treated as immutable. +package descriptions + +import ( + _ "embed" + "strings" +) + +//go:embed search_tool.md +var rawSearchTool string + +//go:embed execute_tool.md +var rawExecuteTool string + +//go:embed execute_source.md +var rawExecuteSource string + +//go:embed persisted_ops_resource.md +var rawPersistedOpsResource string + +// SearchTool is the description of the code_mode_search_tools MCP tool. +var SearchTool = strings.TrimRight(rawSearchTool, "\n") + +// ExecuteTool is the description of the code_mode_run_js MCP tool. +var ExecuteTool = strings.TrimRight(rawExecuteTool, "\n") + +// ExecuteSource is the description of the `source` input parameter of the +// code_mode_run_js MCP tool. +var ExecuteSource = strings.TrimRight(rawExecuteSource, "\n") + +// PersistedOpsResource is the description of the yoko://persisted-ops.d.ts MCP +// resource. +var PersistedOpsResource = strings.TrimRight(rawPersistedOpsResource, "\n") diff --git a/router/internal/codemode/server/descriptions/execute_source.md b/router/internal/codemode/server/descriptions/execute_source.md new file mode 100644 index 0000000000..178814932a --- /dev/null +++ b/router/internal/codemode/server/descriptions/execute_source.md @@ -0,0 +1,3 @@ +JavaScript source containing a single async arrow function. +The host wraps it as `()()` and awaits the resulting Promise; +the resolved JSON-serializable value is the tool result. \ No newline at end of file diff --git a/router/internal/codemode/server/descriptions/execute_tool.md b/router/internal/codemode/server/descriptions/execute_tool.md new file mode 100644 index 0000000000..28646da1e1 --- /dev/null +++ b/router/internal/codemode/server/descriptions/execute_tool.md @@ -0,0 +1,32 @@ +Run JavaScript source as a single async arrow function in the Code Mode sandbox. +Use `await tools.(vars)` for operations registered by code_mode_search_tools; +the cumulative tools namespace is available at `yoko://persisted-ops.d.ts`. + +Style: write compact source — single line if it fits, no // comments, no blank lines, short variable names. +The JSON wrapping that encodes your source charges you for every newline and indent space. + +Batch everything into ONE code_mode_run_js call. +≥3 `tools.*` invocations per call is normal; +over-fetch and decide in JS, don't round-trip. +A failing inner call degrades the result, not the whole script — wrap with try/catch and surface the error in the return value. + +The return value of your async arrow is the only output channel — `console` is not available. +To surface intermediate state, include it in the returned object (e.g. `return { result, debug: { ... } }`). +For resilient fan-out use `Promise.allSettled` — `Promise.all` rejects on first failure and discards partial results. +Up to 256 `tools.*` invocations per call. +Non-serializable leaves in the return value (`BigInt`, functions, symbols, `undefined`, circular refs) are replaced with the sentinel string `<>` and listed in the response's `warnings: [{path, kind}]` field; +the rest of the value still comes through. + +Example: `async()=>{const o=await tools.getOrders({customerId:"c_1"});if(o.errors?.length)throw new Error(o.errors[0].message);return o.data.orders;}` + +Type declarations for reference (consumed via `yoko://persisted-ops.d.ts`): + +```ts +type GraphQLError = { message: string; path?: (string | number)[]; extensions?: Record }; +type R = Promise<{ data: T | null; errors?: GraphQLError[] }>; + +declare const tools: {}; + +declare function notNull(value: T | null | undefined, message?: string): T; +declare function compact(value: T): T; +``` \ No newline at end of file diff --git a/router/internal/codemode/server/descriptions/persisted_ops_resource.md b/router/internal/codemode/server/descriptions/persisted_ops_resource.md new file mode 100644 index 0000000000..2e119392d6 --- /dev/null +++ b/router/internal/codemode/server/descriptions/persisted_ops_resource.md @@ -0,0 +1 @@ +Cumulative TypeScript definitions for the current Code Mode MCP session's named operations. \ No newline at end of file diff --git a/router/internal/codemode/server/descriptions/search_tool.md b/router/internal/codemode/server/descriptions/search_tool.md new file mode 100644 index 0000000000..64581a38e1 --- /dev/null +++ b/router/internal/codemode/server/descriptions/search_tool.md @@ -0,0 +1,38 @@ +Plan ALL data shapes you need up front, +then call ONCE with every prompt in a single batch. +Each extra search is a round-trip you pay for. + +DEFAULT TO ONE PROMPT. +If the entities are related in any way — same domain, joinable, fetched together to answer one question, +traversed via the same parent, or the user mentioned them in the same breath — combine them into a SINGLE prompt that describes the complete joined shape. +Multiple prompts should be the exception, not the default. + +Write each prompt as the COMPLETE final shape of data you want, including joins and correlation IDs. + +Write prompts in a graph-like shape with relationships and nesting, not as separate flat queries. + +BE PRECISE about what you need. +Vague prompts produce vague operations and force re-searches. +Always state: +- The exact fields you need on each entity ("id, forename, surname" — not "name info"). +- Any required filters/arguments but never specific values ("employee by id - not "employee 123", "employee filtered by department name" - not "employee in department 'Engineering'"). +- Concrete entity and relationship names from the domain when you know them; otherwise describe the relationship explicitly ("the team an employee belongs to"). + +When to use multiple prompts (rare): genuinely unrelated operations on disjoint domains, different argument shapes that can't share a parent, or queries vs mutations. +Never slice one joinable shape into fragments. +When in doubt, combine. + +Do NOT issue prompts for derived/computed values: averages, medians, counts, filters, exclusions ("without X"), sorting, top-N. +Fetch the raw rows once and compute in code_mode_run_js. +Yoko exposes data; arithmetic and reshaping happen in your JS. + +Anti-pattern: search → inspect result → notice a field or ID is missing → search again. +One well-formed prompt beats three round-trips. + +The response appends newly registered TypeScript declarations for use as `await tools.(vars)` inside code_mode_run_js; +the cumulative bundle is available at `yoko://persisted-ops.d.ts`. + +Good example: "employee filtered by id with fields id, forename, surname, role, startDate; their team with fields id, name and the team's department with fields id, name; the projects the employee is assigned to with fields id, title, status, dueDate and each project's owner (employee) with fields id, forename, surname" + +Bad examples: ["list of employees with name info", "team for employee 123", "projects in department 'Engineering'", "top 5 employees by project count", "average project duration per team"] +— five prompts instead of one joined shape, vague fields ("name info"), hardcoded filter values ("123", "'Engineering'"), and derived/computed results (top-N, average) that belong in code_mode_run_js, not in a search prompt. \ No newline at end of file diff --git a/router/internal/codemode/server/execute_handler.go b/router/internal/codemode/server/execute_handler.go new file mode 100644 index 0000000000..70234b3525 --- /dev/null +++ b/router/internal/codemode/server/execute_handler.go @@ -0,0 +1,101 @@ +package server + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/wundergraph/cosmo/router/internal/codemode/harness" + "github.com/wundergraph/cosmo/router/internal/codemode/observability" + "github.com/wundergraph/cosmo/router/internal/codemode/sandbox" +) + +type executeAPIInput struct { + Source string `json:"source"` +} + +func (s *Server) handleExecuteAPI(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ctx = contextWithSessionFromExtra(ctx, req.GetExtra()) + + source, err := decodeExecuteSource(req) + if err != nil { + return toolErrorResult(err.Error()), nil + } + + if !s.namedOpsEnabled || s.sessionStateless { + return toolErrorResult(namedOpsDisabledMessage), nil + } + + sessionID := SessionIDFromContext(ctx) + if sessionID == "" { + return toolErrorResult(namedOpsDisabledMessage), nil + } + if s.storage == nil { + return toolErrorResult("code_mode_run_js: storage is not configured"), nil + } + if s.pipeline == nil { + return toolErrorResult("code_mode_run_js: pipeline failed: code mode execute pipeline is not configured"), nil + } + + names, err := s.storage.ListNames(ctx, sessionID) + if err != nil { + return toolErrorResult(fmt.Sprintf("code_mode_run_js: failed to list tools: %v", err)), nil + } + + executeTimeout := s.executeTimeout + if executeTimeout <= 0 { + executeTimeout = defaultExecuteTimeout + } + execCtx, cancel := context.WithTimeout(ctx, executeTimeout) + defer cancel() + + response, err := s.pipeline.Execute(execCtx, harness.PipelineRequest{ + SessionID: sessionID, + ToolNames: names, + Source: source, + RequestHeaders: requestHeaders(req), + ApprovalGate: s.approvalGateForRequest(req), + }) + if err != nil { + return toolErrorResult(fmt.Sprintf("code_mode_run_js: pipeline failed: %v", err)), nil + } + if response.Envelope.Error != nil && response.Envelope.Error.Name == "TranspileError" { + observability.LogTranspileFailure(s.logger, sessionID, response.Envelope.Error.Message) + } + return textResult(string(response.Encoded)), nil +} + +func decodeExecuteSource(req *mcp.CallToolRequest) (string, error) { + var input executeAPIInput + if req != nil && req.Params != nil && len(req.Params.Arguments) > 0 { + if err := json.Unmarshal(req.Params.Arguments, &input); err != nil { + return "", errors.New("code_mode_run_js: source must be a non-empty string") + } + } + if strings.TrimSpace(input.Source) == "" { + return "", errors.New("code_mode_run_js: source must be a non-empty string") + } + return input.Source, nil +} + +func (s *Server) approvalGateForRequest(req *mcp.CallToolRequest) sandbox.ApprovalGate { + if s.approvalGate != nil { + return s.approvalGate + } + var session *mcp.ServerSession + if req != nil { + session = req.Session + } + return NewElicitationGate(NewMCPElicitor(session), s.logger) +} + +func requestHeaders(req *mcp.CallToolRequest) http.Header { + if req == nil || req.GetExtra() == nil { + return nil + } + return req.GetExtra().Header.Clone() +} diff --git a/router/internal/codemode/server/execute_handler_test.go b/router/internal/codemode/server/execute_handler_test.go new file mode 100644 index 0000000000..57cef16d6e --- /dev/null +++ b/router/internal/codemode/server/execute_handler_test.go @@ -0,0 +1,431 @@ +package server + +import ( + "context" + "encoding/json" + "net/http" + "sync" + "testing" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/internal/codemode/harness" + "github.com/wundergraph/cosmo/router/internal/codemode/sandbox" + "github.com/wundergraph/cosmo/router/internal/codemode/storage" + "github.com/wundergraph/cosmo/router/internal/codemode/tsgen" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "go.opentelemetry.io/otel/trace" + "go.uber.org/zap" +) + +func TestHandleExecuteValidatesSource(t *testing.T) { + srv := newExecuteTestServer(t, Config{ + CodeModeEnabled: true, + NamedOpsEnabled: true, + SessionStateless: false, + Pipeline: &recordingPipeline{}, + }, newExecuteTestStorage()) + + got, err := srv.handleExecute(context.Background(), executeToolRequest(t, "session-1", map[string]any{ + "source": "", + })) + + require.NoError(t, err) + assert.Equal(t, toolError("code_mode_run_js: source must be a non-empty string"), got) +} + +func TestHandleExecuteNamedOpsDisabled(t *testing.T) { + srv := newExecuteTestServer(t, Config{ + CodeModeEnabled: true, + NamedOpsEnabled: false, + SessionStateless: false, + Pipeline: &recordingPipeline{}, + }, newExecuteTestStorage()) + + got, err := srv.handleExecute(context.Background(), executeToolRequest(t, "session-1", map[string]any{ + "source": "async () => null", + })) + + require.NoError(t, err) + assert.Equal(t, toolError("named operations are disabled"), got) +} + +func TestHandleExecuteStatelessDisablesNamedOps(t *testing.T) { + srv := newExecuteTestServer(t, Config{ + CodeModeEnabled: true, + NamedOpsEnabled: true, + SessionStateless: true, + Pipeline: &recordingPipeline{}, + }, newExecuteTestStorage()) + + got, err := srv.handleExecute(context.Background(), executeToolRequest(t, "session-1", map[string]any{ + "source": "async () => null", + })) + + require.NoError(t, err) + assert.Equal(t, toolError("named operations are disabled"), got) +} + +func TestHandleExecuteStatefulHappyPathReturnsEncodedEnvelope(t *testing.T) { + store := newExecuteTestStorage() + store.ops["session-1"] = []storage.SessionOp{{ + Name: "someName", + Body: "query SomeName { orders { id total } }", + Kind: storage.OperationKindQuery, + }} + pipeline := &recordingPipeline{ + response: pipelineResponse(t, harness.ResultEnvelope{ + Result: json.RawMessage(`{"orders":[{"id":"o1","total":12.5}]}`), + Truncated: false, + Error: nil, + }), + } + srv := newExecuteTestServer(t, Config{ + CodeModeEnabled: true, + NamedOpsEnabled: true, + SessionStateless: false, + Pipeline: pipeline, + ApprovalGate: sandbox.AutoApprove, + }, store) + + got, err := srv.handleExecute(context.Background(), executeToolRequest(t, "session-1", map[string]any{ + "source": "async () => { const r = await tools.someName({}); return r.data; }", + })) + + require.NoError(t, err) + assert.Equal(t, textToolResult(string(pipeline.response.Encoded)), got) + assert.Equal(t, harness.PipelineRequest{ + SessionID: "session-1", + ToolNames: []string{ + "someName", + }, + Source: "async () => { const r = await tools.someName({}); return r.data; }", + RequestHeaders: http.Header{ + mcpSessionIDHeader: []string{"session-1"}, + "X-Test": []string{"yes"}, + }, + ApprovalGate: sandbox.AutoApprove, + }, pipeline.lastRequest()) + + var decoded map[string]any + require.NoError(t, json.Unmarshal(pipeline.response.Encoded, &decoded)) + assert.Equal(t, map[string]any{ + "result": map[string]any{ + "orders": []any{ + map[string]any{"id": "o1", "total": 12.5}, + }, + }, + }, decoded) +} + +func TestHandleExecuteSandboxErrorEnvelopeReturnsAsText(t *testing.T) { + store := newExecuteTestStorage() + store.ops["session-1"] = []storage.SessionOp{{Name: "someName", Body: "query SomeName { orders { id } }", Kind: storage.OperationKindQuery}} + pipeline := &recordingPipeline{ + response: pipelineResponse(t, harness.ResultEnvelope{ + Result: json.RawMessage("null"), + Truncated: false, + Error: &harness.ErrorEnvelope{Name: "RuntimeError", Message: "boom", Stack: "stack"}, + }), + } + srv := newExecuteTestServer(t, Config{ + CodeModeEnabled: true, + NamedOpsEnabled: true, + SessionStateless: false, + Pipeline: pipeline, + }, store) + + got, err := srv.handleExecute(context.Background(), executeToolRequest(t, "session-1", map[string]any{ + "source": "async () => { throw new Error('boom'); }", + })) + + require.NoError(t, err) + assert.Equal(t, textToolResult(string(pipeline.response.Encoded)), got) + + var decoded map[string]any + require.NoError(t, json.Unmarshal(pipeline.response.Encoded, &decoded)) + assert.Equal(t, map[string]any{ + "result": nil, + "error": map[string]any{ + "name": "RuntimeError", + "message": "boom", + "stack": "stack", + }, + }, decoded) +} + +func TestHandleExecutePerCallTimeoutRoutesEnvelope(t *testing.T) { + store := newExecuteTestStorage() + store.ops["session-1"] = []storage.SessionOp{{Name: "someName", Body: "query SomeName { orders { id } }", Kind: storage.OperationKindQuery}} + pipeline := &recordingPipeline{sleep: 100 * time.Millisecond} + pipeline.onCancel = pipelineResponse(t, harness.ResultEnvelope{ + Result: json.RawMessage("null"), + Truncated: false, + Error: &harness.ErrorEnvelope{Name: "Timeout", Message: "context deadline exceeded", Stack: ""}, + }) + srv := newExecuteTestServer(t, Config{ + CodeModeEnabled: true, + NamedOpsEnabled: true, + SessionStateless: false, + ExecuteTimeout: 10 * time.Millisecond, + Pipeline: pipeline, + }, store) + + got, err := srv.handleExecute(context.Background(), executeToolRequest(t, "session-1", map[string]any{ + "source": "async () => tools.someName({})", + })) + + require.NoError(t, err) + assert.Equal(t, textToolResult(string(pipeline.onCancel.Encoded)), got) + + var decoded map[string]any + require.NoError(t, json.Unmarshal(pipeline.onCancel.Encoded, &decoded)) + assert.Equal(t, map[string]any{ + "result": nil, + "error": map[string]any{ + "name": "Timeout", + "message": "context deadline exceeded", + "stack": "", + }, + }, decoded) +} + +func TestHandleExecuteTranspileErrorEnvelopeReturnsAsText(t *testing.T) { + store := newExecuteTestStorage() + store.ops["session-1"] = []storage.SessionOp{{Name: "someName", Body: "query SomeName { orders { id } }", Kind: storage.OperationKindQuery}} + srv := newExecuteTestServer(t, Config{ + CodeModeEnabled: true, + NamedOpsEnabled: true, + SessionStateless: false, + Pipeline: &harness.Pipeline{}, + }, store) + + got, err := srv.handleExecute(context.Background(), executeToolRequest(t, "session-1", map[string]any{ + "source": "async () => { let x = ; }", + })) + + require.NoError(t, err) + require.Len(t, got.Content, 1) + text, ok := got.Content[0].(*mcp.TextContent) + require.True(t, ok) + + var decoded map[string]any + require.NoError(t, json.Unmarshal([]byte(text.Text), &decoded)) + assert.Equal(t, map[string]any{ + "result": nil, + "error": map[string]any{ + "name": "TranspileError", + "message": "transpile failed: Unexpected \";\"", + "stack": "", + }, + }, decoded) +} + +func TestPersistedOpsResourceReturnsCumulativeBundle(t *testing.T) { + schema := searchHandlerTestSchema(t) + store := storage.NewMemoryBackend(storage.MemoryConfig{Renderer: tsgen.Adapter(schema, 0)}) + store.SetSchema(schema) + _, err := store.Append(context.Background(), "session-1", []storage.SessionOp{{ + Name: "getOrders", + Body: "query GetOrders($limit: Int) { orders(limit: $limit) { id total } }", + Kind: storage.OperationKindQuery, + Description: "Fetch orders.", + }}) + require.NoError(t, err) + srv := newExecuteTestServer(t, Config{ + CodeModeEnabled: true, + NamedOpsEnabled: true, + SessionStateless: false, + Storage: store, + Pipeline: &recordingPipeline{}, + }, nil) + + got, err := srv.handlePersistedOpsResource(context.Background(), resourceRequest("session-1")) + + require.NoError(t, err) + wantBundle, err := tsgen.RenderBundle([]storage.SessionOp{{ + Name: "getOrders", + Body: "query GetOrders($limit: Int) { orders(limit: $limit) { id total } }", + Kind: storage.OperationKindQuery, + Description: "Fetch orders.", + }}, schema, 0) + require.NoError(t, err) + assert.Equal(t, &mcp.ReadResourceResult{ + Contents: []*mcp.ResourceContents{{ + URI: persistedOpsURI, + MIMEType: "text/plain", + Text: wantBundle, + }}, + }, got) +} + +func TestPersistedOpsResourceWithoutSessionReturnsEmptyBundle(t *testing.T) { + schema := searchHandlerTestSchema(t) + store := storage.NewMemoryBackend(storage.MemoryConfig{Renderer: tsgen.Adapter(schema, 0)}) + store.SetSchema(schema) + srv := newExecuteTestServer(t, Config{ + CodeModeEnabled: true, + NamedOpsEnabled: true, + SessionStateless: false, + Storage: store, + Pipeline: &recordingPipeline{}, + }, nil) + + got, err := srv.handlePersistedOpsResource(context.Background(), resourceRequest("")) + + require.NoError(t, err) + wantBundle, err := tsgen.RenderBundle(nil, schema, 0) + require.NoError(t, err) + assert.Equal(t, &mcp.ReadResourceResult{ + Contents: []*mcp.ResourceContents{{ + URI: persistedOpsURI, + MIMEType: "text/plain", + Text: wantBundle, + }}, + }, got) +} + +type recordingPipeline struct { + mu sync.Mutex + requests []harness.PipelineRequest + response harness.PipelineResponse + onCancel harness.PipelineResponse + sleep time.Duration + err error + lastSpan trace.SpanContext +} + +func (p *recordingPipeline) Execute(ctx context.Context, req harness.PipelineRequest) (harness.PipelineResponse, error) { + p.mu.Lock() + p.requests = append(p.requests, req) + p.lastSpan = trace.SpanFromContext(ctx).SpanContext() + p.mu.Unlock() + + if p.sleep > 0 { + select { + case <-ctx.Done(): + return p.onCancel, nil + case <-time.After(p.sleep): + } + } + if p.err != nil { + return harness.PipelineResponse{}, p.err + } + return p.response, nil +} + +func (p *recordingPipeline) lastRequest() harness.PipelineRequest { + p.mu.Lock() + defer p.mu.Unlock() + if len(p.requests) == 0 { + return harness.PipelineRequest{} + } + return p.requests[len(p.requests)-1] +} + +func (p *recordingPipeline) lastSpanContext() trace.SpanContext { + p.mu.Lock() + defer p.mu.Unlock() + return p.lastSpan +} + +type executeTestStorage struct { + mu sync.Mutex + ops map[string][]storage.SessionOp +} + +func newExecuteTestStorage() *executeTestStorage { + return &executeTestStorage{ops: make(map[string][]storage.SessionOp)} +} + +func (s *executeTestStorage) Append(_ context.Context, sessionID string, ops []storage.SessionOp) ([]storage.SessionOp, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.ops[sessionID] = append(s.ops[sessionID], ops...) + return ops, nil +} + +func (s *executeTestStorage) GetOp(_ context.Context, sessionID string, name string) (storage.SessionOp, bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + for _, op := range s.ops[sessionID] { + if op.Name == name { + return op, true, nil + } + } + return storage.SessionOp{}, false, nil +} + +func (s *executeTestStorage) ListNames(_ context.Context, sessionID string) ([]string, error) { + s.mu.Lock() + defer s.mu.Unlock() + names := make([]string, 0, len(s.ops[sessionID])) + for _, op := range s.ops[sessionID] { + names = append(names, op.Name) + } + return names, nil +} + +func (s *executeTestStorage) Bundle(context.Context, string) (string, error) { + return "", nil +} + +func (s *executeTestStorage) Reset(_ context.Context, sessionID string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.ops, sessionID) + return nil +} + +func (s *executeTestStorage) SetSchema(*ast.Document) {} + +func (s *executeTestStorage) Schema() *ast.Document { return nil } + +func (s *executeTestStorage) Start(context.Context) error { return nil } + +func (s *executeTestStorage) Stop() error { return nil } + +func pipelineResponse(t *testing.T, envelope harness.ResultEnvelope) harness.PipelineResponse { + t.Helper() + encoded, err := json.Marshal(envelope) + require.NoError(t, err) + return harness.PipelineResponse{Envelope: envelope, Encoded: encoded} +} + +func executeToolRequest(t *testing.T, sessionID string, arguments map[string]any) *mcp.CallToolRequest { + t.Helper() + body, err := json.Marshal(arguments) + require.NoError(t, err) + return &mcp.CallToolRequest{ + Params: &mcp.CallToolParamsRaw{ + Name: "code_mode_run_js", + Arguments: body, + }, + Extra: &mcp.RequestExtra{Header: http.Header{ + mcpSessionIDHeader: []string{sessionID}, + "X-Test": []string{"yes"}, + }}, + } +} + +func resourceRequest(sessionID string) *mcp.ReadResourceRequest { + return &mcp.ReadResourceRequest{ + Params: &mcp.ReadResourceParams{URI: persistedOpsURI}, + Extra: &mcp.RequestExtra{Header: http.Header{mcpSessionIDHeader: []string{sessionID}}}, + } +} + +func newExecuteTestServer(t *testing.T, cfg Config, store storage.SessionStorage) *Server { + t.Helper() + if store != nil { + cfg.Storage = store + } + if cfg.Logger == nil { + cfg.Logger = zap.NewNop() + } + srv, err := New(cfg) + require.NoError(t, err) + return srv +} diff --git a/router/internal/codemode/server/lifecycle.go b/router/internal/codemode/server/lifecycle.go new file mode 100644 index 0000000000..42d3be7600 --- /dev/null +++ b/router/internal/codemode/server/lifecycle.go @@ -0,0 +1,182 @@ +package server + +import ( + "context" + "fmt" + "net/http" + + "github.com/wundergraph/cosmo/router/internal/codemode/harness" + "github.com/wundergraph/cosmo/router/internal/codemode/sandbox" + "github.com/wundergraph/cosmo/router/internal/codemode/storage" + "github.com/wundergraph/cosmo/router/internal/codemode/tsgen" + "github.com/wundergraph/cosmo/router/internal/codemode/yoko" + "github.com/wundergraph/cosmo/router/internal/rediscloser" + "github.com/wundergraph/cosmo/router/pkg/config" + "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/trace" + "go.uber.org/zap" +) + +type BuildOptions struct { + Config config.MCPCodeModeConfiguration + SessionStateless bool + RouterGraphQLURL string + Logger *zap.Logger + TracerProvider trace.TracerProvider + MeterProvider metric.MeterProvider + + // RedisProvider is the resolved storage_providers.redis entry referenced by + // cfg.NamedOps.Storage.ProviderID. When nil, the in-memory backend is used. + // Provider lookup (and the "unknown id" error) is performed by the router. + RedisProvider *config.RedisStorageProvider + // RedisFactory is an optional override used by tests. When nil, the default + // rediscloser.NewRedisCloser is used. + RedisFactory func(opts *rediscloser.RedisCloserOptions) (rediscloser.RDCloser, error) +} + +func BuildFromConfig(opts BuildOptions) (*Server, error) { + logger := opts.Logger + if logger == nil { + logger = zap.NewNop() + } + + cfg := opts.Config + if !cfg.Enabled { + return New(Config{ + ListenAddr: cfg.Server.ListenAddr, + CodeModeEnabled: cfg.Enabled, + NamedOpsEnabled: cfg.NamedOps.Enabled, + SessionStateless: opts.SessionStateless, + ExecuteTimeout: cfg.ExecuteTimeout, + MaxResultBytes: cfg.MaxResultBytes, + Logger: logger, + TracerProvider: opts.TracerProvider, + MeterProvider: opts.MeterProvider, + ApprovalGate: sandbox.AutoApprove, + CallTraceRecorder: nil, + }) + } + + renderer := tsgen.Adapter(nil, cfg.NamedOps.MaxBundleBytes) + store, err := buildStorage(cfg, renderer, opts, logger) + if err != nil { + return nil, err + } + + sbx, err := sandbox.New(sandbox.Config{ + RouterGraphQLEndpoint: opts.RouterGraphQLURL, + RequestTimeout: cfg.Sandbox.Timeout, + MemoryLimitBytes: cfg.Sandbox.MaxMemoryMB * 1024 * 1024, + MaxInputSizeBytes: cfg.Sandbox.MaxInputSizeBytes, + MaxOutputSizeBytes: cfg.Sandbox.MaxOutputSizeBytes, + MaxResultBytes: cfg.MaxResultBytes, + StorageLookup: func(ctx context.Context, sessionID string, name string) (storage.SessionOp, bool, error) { + if store == nil { + return storage.SessionOp{}, false, nil + } + return store.GetOp(ctx, sessionID, name) + }, + Logger: logger, + }) + if err != nil { + return nil, fmt.Errorf("create code mode sandbox: %w", err) + } + + return New(Config{ + ListenAddr: cfg.Server.ListenAddr, + CodeModeEnabled: cfg.Enabled, + NamedOpsEnabled: cfg.NamedOps.Enabled, + SessionStateless: opts.SessionStateless, + Storage: store, + Pipeline: &harness.Pipeline{Sandbox: sbx, MaxInputBytes: cfg.Sandbox.MaxInputSizeBytes, MaxResultBytes: cfg.MaxResultBytes}, + YokoClient: buildYokoClient(cfg.QueryGeneration, logger), + BundleRenderer: renderer, + ExecuteTimeout: cfg.ExecuteTimeout, + MaxResultBytes: cfg.MaxResultBytes, + ApprovalGate: buildApprovalGate(cfg, logger), + Logger: logger, + MeterProvider: opts.MeterProvider, + TracerProvider: opts.TracerProvider, + CallTraceRecorder: nil, + }) +} + +func buildStorage(cfg config.MCPCodeModeConfiguration, renderer storage.Renderer, opts BuildOptions, logger *zap.Logger) (storage.SessionStorage, error) { + if !cfg.NamedOps.Enabled { + return nil, nil + } + + if opts.RedisProvider == nil { + return storage.NewMemoryBackend(storage.MemoryConfig{ + SessionTTL: cfg.NamedOps.SessionTTL, + MaxSessions: cfg.NamedOps.MaxSessions, + MaxBundleBytes: cfg.NamedOps.MaxBundleBytes, + Renderer: renderer, + }), nil + } + + factory := opts.RedisFactory + if factory == nil { + factory = rediscloser.NewRedisCloser + } + client, err := factory(&rediscloser.RedisCloserOptions{ + Logger: logger, + URLs: opts.RedisProvider.URLs, + ClusterEnabled: opts.RedisProvider.ClusterEnabled, + }) + if err != nil { + return nil, fmt.Errorf("create code mode redis storage client: %w", err) + } + backend, err := storage.NewRedisBackend(storage.RedisConfig{ + Client: client, + KeyPrefix: cfg.NamedOps.Storage.KeyPrefix, + SessionTTL: cfg.NamedOps.SessionTTL, + Renderer: renderer, + Logger: logger, + }) + if err != nil { + return nil, fmt.Errorf("create code mode redis storage backend: %w", err) + } + return backend, nil +} + +func buildYokoClient(cfg config.MCPCodeModeQueryGenConfig, logger *zap.Logger) *yoko.Client { + if !cfg.Enabled { + return nil + } + client := &http.Client{Timeout: cfg.Timeout} + if token := cfg.Auth.StaticToken; cfg.Auth.Type == "" || cfg.Auth.Type == "static" { + if token != "" { + client.Transport = staticBearerRoundTripper{ + token: token, + next: http.DefaultTransport, + } + } + } else if cfg.Auth.Type == "jwt" { + logger.Warn("code mode query generation jwt auth is not implemented; proceeding without auth") + } + return yoko.New(client, cfg.Endpoint, logger) +} + +func buildApprovalGate(cfg config.MCPCodeModeConfiguration, _ *zap.Logger) sandbox.ApprovalGate { + if cfg.RequireMutationApproval { + return nil + } + return sandbox.AutoApprove +} + +type staticBearerRoundTripper struct { + token string + next http.RoundTripper +} + +func (t staticBearerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + next := t.next + if next == nil { + next = http.DefaultTransport + } + cloned := req.Clone(req.Context()) + cloned.Header = req.Header.Clone() + cloned.Header.Set("Authorization", "Bearer "+t.token) + return next.RoundTrip(cloned) +} diff --git a/router/internal/codemode/server/lifecycle_test.go b/router/internal/codemode/server/lifecycle_test.go new file mode 100644 index 0000000000..349e2f1419 --- /dev/null +++ b/router/internal/codemode/server/lifecycle_test.go @@ -0,0 +1,206 @@ +package server + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/internal/codemode/storage" + "github.com/wundergraph/cosmo/router/internal/codemode/yoko" + "github.com/wundergraph/cosmo/router/internal/rediscloser" + "github.com/wundergraph/cosmo/router/pkg/config" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" + "github.com/wundergraph/graphql-go-tools/v2/pkg/asttransform" + "go.uber.org/zap" +) + +func TestBuildFromConfigDisabledIsNoOp(t *testing.T) { + srv, err := BuildFromConfig(BuildOptions{ + Config: config.MCPCodeModeConfiguration{Enabled: false}, + SessionStateless: false, + Logger: zap.NewNop(), + }) + require.NoError(t, err) + + require.NoError(t, srv.Start(context.Background())) + assert.Equal(t, "", srv.addr()) + require.NoError(t, srv.Reload(&ast.Document{}, "schema { query: Query }")) + require.NoError(t, srv.Stop(context.Background())) +} + +func TestBuildFromConfigMemoryBackendReloadsSchemaAndSDL(t *testing.T) { + cfg := fullLifecycleConfig() + srv, err := BuildFromConfig(BuildOptions{ + Config: cfg, + SessionStateless: false, + RouterGraphQLURL: "http://router.local/graphql", + Logger: zap.NewNop(), + }) + require.NoError(t, err) + + backend, ok := srv.storage.(*storage.MemoryBackend) + require.True(t, ok) + + schema := lifecycleTestSchema(t) + require.NoError(t, srv.Reload(schema, "type Query { orders: [Order!]! }")) + + assert.Equal(t, schema, backend.Schema()) + client, ok := srv.yokoClient.(*yoko.Client) + require.True(t, ok) + assert.Equal(t, "type Query { orders: [Order!]! }", client.Schema()) +} + +func TestBuildFromConfigRedisFactoryError(t *testing.T) { + cfg := fullLifecycleConfig() + cfg.NamedOps.Storage.ProviderID = "my_redis" + + srv, err := BuildFromConfig(BuildOptions{ + Config: cfg, + SessionStateless: false, + RouterGraphQLURL: "http://router.local/graphql", + Logger: zap.NewNop(), + RedisProvider: &config.RedisStorageProvider{ + ID: "my_redis", + URLs: []string{"redis://127.0.0.1:6379"}, + }, + RedisFactory: func(*rediscloser.RedisCloserOptions) (rediscloser.RDCloser, error) { + return nil, errors.New("redis unavailable") + }, + }) + + require.Nil(t, srv) + require.ErrorContains(t, err, "create code mode redis storage client: redis unavailable") +} + +func TestBuildFromConfigRedisBackendWithMiniredis(t *testing.T) { + mr, err := miniredis.Run() + if err != nil { + if isBindPermissionError(err) { + t.Skipf("local miniredis bind is not permitted in this environment: %v", err) + } + require.NoError(t, err) + } + t.Cleanup(mr.Close) + var gotOpts rediscloser.RedisCloserOptions + var client *redis.Client + t.Cleanup(func() { + if client != nil { + require.NoError(t, client.Close()) + } + }) + + cfg := fullLifecycleConfig() + cfg.NamedOps.Storage.ProviderID = "my_redis" + cfg.NamedOps.Storage.KeyPrefix = "test_code_mode" + + srv, err := BuildFromConfig(BuildOptions{ + Config: cfg, + SessionStateless: false, + RouterGraphQLURL: "http://router.local/graphql", + Logger: zap.NewNop(), + RedisProvider: &config.RedisStorageProvider{ + ID: "my_redis", + URLs: []string{"redis://" + mr.Addr()}, + ClusterEnabled: true, + }, + RedisFactory: func(opts *rediscloser.RedisCloserOptions) (rediscloser.RDCloser, error) { + gotOpts = *opts + client = redis.NewClient(&redis.Options{Addr: mr.Addr()}) + return client, nil + }, + }) + require.NoError(t, err) + + _, ok := srv.storage.(*storage.RedisBackend) + require.True(t, ok) + assert.NotNil(t, gotOpts.Logger) + assert.Equal(t, []string{"redis://" + mr.Addr()}, gotOpts.URLs) + assert.Equal(t, true, gotOpts.ClusterEnabled) +} + +func TestBuildFromConfigReloadEvictsMemorySessions(t *testing.T) { + srv, err := BuildFromConfig(BuildOptions{ + Config: fullLifecycleConfig(), + SessionStateless: false, + RouterGraphQLURL: "http://router.local/graphql", + Logger: zap.NewNop(), + }) + require.NoError(t, err) + + _, err = srv.storage.Append(context.Background(), "session-1", []storage.SessionOp{{ + Name: "getOrders", + Body: "query GetOrders { orders { id } }", + Kind: storage.OperationKindQuery, + Description: "Fetch orders.", + }}) + require.NoError(t, err) + + _, ok, err := srv.storage.GetOp(context.Background(), "session-1", "getOrders") + require.NoError(t, err) + assert.Equal(t, true, ok) + + require.NoError(t, srv.Reload(lifecycleTestSchema(t), "type Query { customer: Customer }")) + + got, ok, err := srv.storage.GetOp(context.Background(), "session-1", "getOrders") + require.NoError(t, err) + assert.Equal(t, false, ok) + assert.Equal(t, storage.SessionOp{}, got) +} + +func TestBuildFromConfigDisabledReloadIsNoOp(t *testing.T) { + srv, err := BuildFromConfig(BuildOptions{ + Config: config.MCPCodeModeConfiguration{Enabled: false}, + SessionStateless: false, + Logger: zap.NewNop(), + }) + require.NoError(t, err) + + require.NoError(t, srv.Reload(lifecycleTestSchema(t), "type Query { orders: [Order!]! }")) + assert.Nil(t, srv.storage) + assert.Nil(t, srv.yokoClient) +} + +func fullLifecycleConfig() config.MCPCodeModeConfiguration { + return config.MCPCodeModeConfiguration{ + Enabled: true, + Server: config.MCPCodeModeServerConfig{ListenAddr: "127.0.0.1:0"}, + RequireMutationApproval: true, + ExecuteTimeout: 120 * time.Second, + MaxResultBytes: 32 << 10, + Sandbox: config.MCPCodeModeSandboxConfig{ + Timeout: 5 * time.Second, + MaxMemoryMB: 16, + MaxInputSizeBytes: 64 << 10, + MaxOutputSizeBytes: 1 << 20, + }, + QueryGeneration: config.MCPCodeModeQueryGenConfig{ + Enabled: true, + Endpoint: "http://yoko.local", + Timeout: 10 * time.Second, + Auth: config.MCPCodeModeQueryGenAuthConfig{Type: "static", StaticToken: "token"}, + }, + NamedOps: config.MCPCodeModeNamedOpsConfig{ + Enabled: true, + SessionTTL: 30 * time.Minute, + MaxSessions: 1000, + MaxBundleBytes: 256 << 10, + Storage: config.MCPCodeModeNamedOpsStorageConfig{ + KeyPrefix: "cosmo_code_mode", + }, + }, + } +} + +func lifecycleTestSchema(t *testing.T) *ast.Document { + t.Helper() + doc, report := astparser.ParseGraphqlDocumentString(searchHandlerTestSchemaSDL) + require.False(t, report.HasErrors(), report.Error()) + require.NoError(t, asttransform.MergeDefinitionWithBaseSchema(&doc)) + return &doc +} diff --git a/router/internal/codemode/server/observability_handler_test.go b/router/internal/codemode/server/observability_handler_test.go new file mode 100644 index 0000000000..8a4621326e --- /dev/null +++ b/router/internal/codemode/server/observability_handler_test.go @@ -0,0 +1,180 @@ +package server + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + yokov1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1" + "github.com/wundergraph/cosmo/router/internal/codemode/harness" + "github.com/wundergraph/cosmo/router/internal/codemode/sandbox" + "github.com/wundergraph/cosmo/router/internal/codemode/storage" + "go.opentelemetry.io/otel/attribute" + sdkmetric "go.opentelemetry.io/otel/sdk/metric" + "go.opentelemetry.io/otel/sdk/metric/metricdata" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" + "go.opentelemetry.io/otel/trace" + "go.uber.org/zap" +) + +func TestHandleSearchRecordsObservability(t *testing.T) { + traces, meterProvider, reader := newHandlerTelemetry() + searcher := newFakeYoko() + searcher.responses <- &yokov1.SearchResponse{Operations: []*yokov1.GeneratedOperation{{ + Name: "getOrders", + Body: "query GetOrders { orders { id } }", + Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, + }}} + store := newSearchTestStorage(t) + srv, err := New(Config{ + CodeModeEnabled: true, + NamedOpsEnabled: true, + Storage: store, + YokoClient: searcher, + Logger: zap.NewNop(), + TracerProvider: sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(traces)), + MeterProvider: meterProvider, + }) + require.NoError(t, err) + + got, err := srv.handleSearch(context.Background(), searchToolRequest(t, "session-1", map[string]any{ + "prompts": []string{"orders"}, + })) + + require.NoError(t, err) + require.False(t, got.IsError) + assert.Equal(t, []tracetest.SpanStub{{ + Name: "MCP Code Mode - Search", + SpanKind: trace.SpanKindServer, + Attributes: []attribute.KeyValue{ + attribute.String("mcp.tool", "code_mode_search_tools"), + attribute.String("mcp.status", "success"), + }, + InstrumentationLibrary: normalizedSpanStubs(traces.Ended())[0].InstrumentationLibrary, + }}, normalizedSpanStubs(traces.Ended())) + assertCodeModeMetric(t, reader, "code_mode_search_tools", "success") +} + +func TestHandleExecuteRecordsObservability(t *testing.T) { + traces, meterProvider, reader := newHandlerTelemetry() + store := newExecuteTestStorage() + store.ops["session-1"] = []storage.SessionOp{{ + Name: "someName", + Body: "query SomeName { orders { id total } }", + Kind: storage.OperationKindQuery, + }} + pipeline := &recordingPipeline{ + response: pipelineResponse(t, harness.ResultEnvelope{ + Result: json.RawMessage(`{"orders":[{"id":"o1"}]}`), + Truncated: false, + Error: nil, + }), + } + srv := newExecuteTestServer(t, Config{ + CodeModeEnabled: true, + NamedOpsEnabled: true, + SessionStateless: false, + Pipeline: pipeline, + ApprovalGate: sandbox.AutoApprove, + Logger: zap.NewNop(), + TracerProvider: sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(traces)), + MeterProvider: meterProvider, + }, store) + + got, err := srv.handleExecute(context.Background(), executeToolRequest(t, "session-1", map[string]any{ + "source": "async () => tools.someName({})", + })) + + require.NoError(t, err) + require.False(t, got.IsError) + assert.Equal(t, []tracetest.SpanStub{{ + Name: "MCP Code Mode - Execute", + SpanKind: trace.SpanKindServer, + Attributes: []attribute.KeyValue{ + attribute.String("mcp.tool", "code_mode_run_js"), + attribute.String("mcp.status", "success"), + }, + InstrumentationLibrary: normalizedSpanStubs(traces.Ended())[0].InstrumentationLibrary, + }}, normalizedSpanStubs(traces.Ended())) + assertCodeModeMetric(t, reader, "code_mode_run_js", "success") + require.True(t, pipeline.lastSpanContext().IsValid()) +} + +func newHandlerTelemetry() (*tracetest.SpanRecorder, *sdkmetric.MeterProvider, *sdkmetric.ManualReader) { + reader := sdkmetric.NewManualReader() + return tracetest.NewSpanRecorder(), sdkmetric.NewMeterProvider(sdkmetric.WithReader(reader)), reader +} + +func normalizedSpanStubs(spans []sdktrace.ReadOnlySpan) []tracetest.SpanStub { + stubs := make([]tracetest.SpanStub, 0, len(spans)) + for _, span := range spans { + stub := tracetest.SpanStubFromReadOnlySpan(span) + stub.SpanContext = trace.SpanContext{} + stub.StartTime = time.Time{} + stub.EndTime = time.Time{} + stub.Resource = nil + stubs = append(stubs, stub) + } + return stubs +} + +func assertCodeModeMetric(t *testing.T, reader *sdkmetric.ManualReader, toolName string, status string) { + t.Helper() + var got metricdata.ResourceMetrics + require.NoError(t, reader.Collect(context.Background(), &got)) + + counter, histogram := handlerCodeModeMetrics(t, got) + counterData, ok := counter.Data.(metricdata.Sum[int64]) + require.True(t, ok) + require.Len(t, counterData.DataPoints, 1) + counterPoint := counterData.DataPoints[0] + counterPoint.StartTime = time.Time{} + counterPoint.Time = time.Time{} + assert.Equal(t, metricdata.DataPoint[int64]{ + Attributes: attribute.NewSet( + attribute.String("mcp.tool", toolName), + attribute.String("mcp.status", status), + ), + Value: 1, + }, counterPoint) + + histogramData, ok := histogram.Data.(metricdata.Histogram[float64]) + require.True(t, ok) + require.Len(t, histogramData.DataPoints, 1) + histogramPoint := histogramData.DataPoints[0] + require.Greater(t, histogramPoint.Sum, 0.0) + histogramPoint.StartTime = time.Time{} + histogramPoint.Time = time.Time{} + assert.Equal(t, metricdata.HistogramDataPoint[float64]{ + Attributes: attribute.NewSet( + attribute.String("mcp.tool", toolName), + attribute.String("mcp.status", status), + ), + Count: 1, + Bounds: histogramPoint.Bounds, + BucketCounts: histogramPoint.BucketCounts, + Min: histogramPoint.Min, + Max: histogramPoint.Max, + Sum: histogramPoint.Sum, + }, histogramPoint) +} + +func handlerCodeModeMetrics(t *testing.T, metrics metricdata.ResourceMetrics) (metricdata.Metrics, metricdata.Metrics) { + t.Helper() + require.Len(t, metrics.ScopeMetrics, 1) + assert.Equal(t, "wundergraph.cosmo.router.mcp.code_mode", metrics.ScopeMetrics[0].Scope.Name) + + byName := make(map[string]metricdata.Metrics, len(metrics.ScopeMetrics[0].Metrics)) + for _, metric := range metrics.ScopeMetrics[0].Metrics { + byName[metric.Name] = metric + } + counter, ok := byName["mcp.code_mode.sandbox.executions"] + require.True(t, ok) + histogram, ok := byName["mcp.code_mode.sandbox.duration"] + require.True(t, ok) + return counter, histogram +} diff --git a/router/internal/codemode/server/search_handler.go b/router/internal/codemode/server/search_handler.go new file mode 100644 index 0000000000..8860cc3eca --- /dev/null +++ b/router/internal/codemode/server/search_handler.go @@ -0,0 +1,264 @@ +package server + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sort" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" + yokov1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1" + "github.com/wundergraph/cosmo/router/internal/codemode/storage" + "go.uber.org/zap" +) + +const ( + maxSearchPrompts = 20 + emptySearchAPIResponseMessage = "// 0 new ops; previous code_mode_search_tools calls already cover these prompts." + + // The generated proto currently has query and mutation constants. Yoko may + // still send the planned subscription enum value; host behavior is to drop it. + yokoOperationKindSubscription yokov1.OperationKind = 3 +) + +type searchAPIInput struct { + Prompts []string `json:"prompts"` +} + +type legacyCatalogueOperation struct { + Name string `json:"name"` + Body string `json:"body"` + Kind string `json:"kind"` + Description string `json:"description"` + Variables *string `json:"variables"` +} + +func (s *Server) handleSearchAPI(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ctx = contextWithSessionFromExtra(ctx, req.GetExtra()) + + prompts, validationErr := decodeSearchPrompts(req) + if validationErr != nil { + return toolErrorResult(validationErr.Error()), nil + } + + if s.sessionStateless { + return s.handleSearchStateless(ctx, prompts), nil + } + + sessionID := SessionIDFromContext(ctx) + if sessionID == "" { + s.warnMissingSessionIDOnce() + return s.handleSearchStateless(ctx, prompts), nil + } + + key := searchSingleFlightKey(sessionID, prompts) + value, _, _ := s.searchGroup.Do(key, func() (any, error) { + return s.handleSearchStateful(ctx, sessionID, prompts), nil + }) + return value.(*mcp.CallToolResult), nil +} + +func decodeSearchPrompts(req *mcp.CallToolRequest) ([]string, error) { + var input searchAPIInput + if req != nil && req.Params != nil && len(req.Params.Arguments) > 0 { + if err := json.Unmarshal(req.Params.Arguments, &input); err != nil { + return nil, errors.New("code_mode_search_tools: prompts must be a non-empty array of strings") + } + } + + if len(input.Prompts) == 0 { + return nil, errors.New("code_mode_search_tools: prompts must be a non-empty array of strings") + } + if len(input.Prompts) > maxSearchPrompts { + return nil, fmt.Errorf("too many prompts: %d (max 20) — pass all prompts in one call", len(input.Prompts)) + } + for i, prompt := range input.Prompts { + if strings.TrimSpace(prompt) == "" { + return nil, fmt.Errorf("code_mode_search_tools: prompt at index %d is empty", i) + } + } + return input.Prompts, nil +} + +func (s *Server) handleSearchStateless(ctx context.Context, prompts []string) *mcp.CallToolResult { + response, err := s.searchYoko(ctx, "", prompts) + if err != nil { + return toolErrorResult(fmt.Sprintf("code_mode_search_tools: yoko search failed: %v", err)) + } + + catalogue := make([]legacyCatalogueOperation, 0, len(response.GetOperations())) + droppedSubscription := false + for _, op := range response.GetOperations() { + kind, ok, subscription := yokoOperationKindLabel(op.GetKind()) + if subscription { + droppedSubscription = true + continue + } + if !ok { + s.logger.Warn("code_mode_search_tools dropped unsupported operation kind", + zap.String("name", op.GetName()), + zap.String("kind", op.GetKind().String()), + ) + continue + } + catalogue = append(catalogue, legacyCatalogueOperation{ + Name: op.GetName(), + Body: op.GetBody(), + Kind: kind, + Description: op.GetDescription(), + Variables: extractGraphQLVariablesBlock(op.GetBody()), + }) + } + if droppedSubscription { + s.logger.Warn("code_mode_search_tools dropped subscription operations returned by yoko") + } + + encoded, err := json.Marshal(catalogue) + if err != nil { + return toolErrorResult(fmt.Sprintf("code_mode_search_tools: failed to encode legacy catalogue: %v", err)) + } + return textResult(string(encoded)) +} + +func (s *Server) handleSearchStateful(ctx context.Context, sessionID string, prompts []string) *mcp.CallToolResult { + response, err := s.searchYoko(ctx, sessionID, prompts) + if err != nil { + return toolErrorResult(fmt.Sprintf("code_mode_search_tools: yoko search failed: %v", err)) + } + + rawOps := make([]storage.SessionOp, 0, len(response.GetOperations())) + droppedSubscription := false + for _, op := range response.GetOperations() { + kind, ok, subscription := storageOperationKind(op.GetKind()) + if subscription { + droppedSubscription = true + continue + } + if !ok { + s.logger.Warn("code_mode_search_tools dropped unsupported operation kind", + zap.String("name", op.GetName()), + zap.String("kind", op.GetKind().String()), + ) + continue + } + rawOps = append(rawOps, storage.SessionOp{ + Name: storage.NormalizeName(op.GetName()), + Body: op.GetBody(), + Kind: kind, + Description: op.GetDescription(), + }) + } + if droppedSubscription { + s.logger.Warn("code_mode_search_tools dropped subscription operations returned by yoko") + } + + if len(rawOps) == 0 { + return textResult(emptySearchAPIResponseMessage) + } + if s.storage == nil { + return toolErrorResult("code_mode_search_tools: failed to register ops: code mode storage is not configured") + } + + // Collision handling approach: Append-applies-suffix. The storage backend is + // the serialization point for a session and returns the final stored names. + appendedOps, err := s.storage.Append(ctx, sessionID, rawOps) + if err != nil { + return toolErrorResult(fmt.Sprintf("code_mode_search_tools: failed to register ops: %v", err)) + } + if len(appendedOps) == 0 { + return textResult(emptySearchAPIResponseMessage) + } + + rendered, err := s.newOpsFragment(appendedOps, s.storage.Schema()) + if err != nil { + return toolErrorResult(fmt.Sprintf("code_mode_search_tools: failed to render new ops: %v", err)) + } + return textResult(rendered) +} + +func (s *Server) searchYoko(ctx context.Context, sessionID string, prompts []string) (*yokov1.SearchResponse, error) { + if s.yokoClient == nil { + return nil, errors.New("yoko client is not configured") + } + return s.yokoClient.Search(ctx, sessionID, prompts) +} + +func storageOperationKind(kind yokov1.OperationKind) (storage.OperationKind, bool, bool) { + switch kind { + case yokov1.OperationKind_OPERATION_KIND_QUERY: + return storage.OperationKindQuery, true, false + case yokov1.OperationKind_OPERATION_KIND_MUTATION: + return storage.OperationKindMutation, true, false + case yokoOperationKindSubscription: + return "", false, true + default: + return "", false, false + } +} + +func yokoOperationKindLabel(kind yokov1.OperationKind) (string, bool, bool) { + switch kind { + case yokov1.OperationKind_OPERATION_KIND_QUERY: + return "Query", true, false + case yokov1.OperationKind_OPERATION_KIND_MUTATION: + return "Mutation", true, false + case yokoOperationKindSubscription: + return "", false, true + default: + return "", false, false + } +} + +func searchSingleFlightKey(sessionID string, prompts []string) string { + sortedPrompts := append([]string(nil), prompts...) + sort.Strings(sortedPrompts) + keyParts := []string{sessionID} + for _, p := range sortedPrompts { + keyParts = append(keyParts, fmt.Sprintf("%d:%s", len(p), p)) + } + return strings.Join(keyParts, "|") +} + +func extractGraphQLVariablesBlock(body string) *string { + open := strings.IndexByte(body, '(') + if open < 0 { + return nil + } + selection := strings.IndexByte(body, '{') + if selection >= 0 && selection < open { + return nil + } + + depth := 0 + for i := open; i < len(body); i++ { + switch body[i] { + case '(': + depth++ + case ')': + depth-- + if depth == 0 { + value := strings.TrimSpace(body[open : i+1]) + return &value + } + } + } + return nil +} + +func textResult(text string) *mcp.CallToolResult { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: text}}, + } +} + +func (s *Server) warnMissingSessionIDOnce() { + s.mu.Lock() + defer s.mu.Unlock() + if s.warnedMissingSessionID { + return + } + s.warnedMissingSessionID = true + s.logger.Warn("code mode code_mode_search_tools missing MCP session id; falling back to legacy stateless catalogue") +} diff --git a/router/internal/codemode/server/search_handler_test.go b/router/internal/codemode/server/search_handler_test.go new file mode 100644 index 0000000000..8a6c31b357 --- /dev/null +++ b/router/internal/codemode/server/search_handler_test.go @@ -0,0 +1,678 @@ +package server + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "sync" + "testing" + "time" + + "connectrpc.com/connect" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + yokov1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1" + "github.com/wundergraph/cosmo/router/internal/codemode/storage" + "github.com/wundergraph/cosmo/router/internal/codemode/tsgen" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" + "github.com/wundergraph/graphql-go-tools/v2/pkg/asttransform" + "go.uber.org/zap" +) + +const searchHandlerTestSchemaSDL = ` +schema { + query: Query + mutation: Mutation +} + +type Query { + orders(limit: Int): [Order!]! + customer(id: ID!): Customer +} + +type Mutation { + cancelOrder(id: ID!): Order! +} + +type Order { + id: ID! + total: Float! +} + +type Customer { + id: ID! + name: String! +} +` + +const emptySearchMessage = "// 0 new ops; previous code_mode_search_tools calls already cover these prompts." + +func TestHandleSearchValidatesPrompts(t *testing.T) { + tests := []struct { + name string + arguments map[string]any + want string + }{ + { + name: "missing prompts", + arguments: map[string]any{}, + want: "code_mode_search_tools: prompts must be a non-empty array of strings", + }, + { + name: "empty prompts", + arguments: map[string]any{"prompts": []string{}}, + want: "code_mode_search_tools: prompts must be a non-empty array of strings", + }, + { + name: "too many prompts", + arguments: map[string]any{"prompts": func() []string { + prompts := make([]string, 21) + for i := range prompts { + prompts[i] = fmt.Sprintf("prompt %d", i) + } + return prompts + }()}, + want: "too many prompts: 21 (max 20) — pass all prompts in one call", + }, + { + name: "empty prompt", + arguments: map[string]any{"prompts": []string{"orders", " \t\n"}}, + want: "code_mode_search_tools: prompt at index 1 is empty", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + srv := newSearchTestServer(t, false, newFakeYoko(), newSearchTestStorage(t)) + + got, err := srv.handleSearch(context.Background(), searchToolRequest(t, "session-1", tt.arguments)) + + require.NoError(t, err) + assert.Equal(t, toolError(tt.want), got) + }) + } +} + +func TestHandleSearchStatelessReturnsLegacyJSONCatalogue(t *testing.T) { + searcher := newFakeYoko() + searcher.responses <- &yokov1.SearchResponse{Operations: []*yokov1.GeneratedOperation{ + { + Name: "getOrders", + Body: "query GetOrders($limit: Int) { orders(limit: $limit) { id } }", + Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, + Description: "Fetch orders.", + }, + { + Name: "watchOrders", + Body: "subscription WatchOrders { orders { id } }", + Kind: yokoOperationKindSubscription, + Description: "Watch orders.", + }, + }} + store := newSearchTestStorage(t) + srv := newSearchTestServer(t, true, searcher, store) + + got, err := srv.handleSearch(context.Background(), searchToolRequest(t, "", map[string]any{ + "prompts": []string{"orders"}, + })) + + require.NoError(t, err) + expectedJSON := mustJSON(t, []legacyCatalogueEntry{ + { + Name: "getOrders", + Body: "query GetOrders($limit: Int) { orders(limit: $limit) { id } }", + Kind: "Query", + Description: "Fetch orders.", + Variables: ptrString("($limit: Int)"), + }, + }) + assert.Equal(t, textToolResult(expectedJSON), got) + assert.Equal(t, []searchCall{{sessionID: "", prompts: []string{"orders"}}}, searcher.callsSnapshot()) + assert.Equal(t, []storage.SessionOp(nil), store.opsSnapshot("session-1")) +} + +func TestHandleSearchStatefulAppendsAndReturnsNewOpsFragment(t *testing.T) { + searcher := newFakeYoko() + searcher.responses <- &yokov1.SearchResponse{Operations: []*yokov1.GeneratedOperation{ + { + Name: "getOrders", + Body: "query GetOrders($limit: Int) { orders(limit: $limit) { id total } }", + Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, + Description: "Fetch orders.", + }, + { + Name: "cancelOrder", + Body: "mutation CancelOrder($id: ID!) { cancelOrder(id: $id) { id } }", + Kind: yokov1.OperationKind_OPERATION_KIND_MUTATION, + Description: "Cancel an order.", + }, + { + Name: "watchOrders", + Body: "subscription WatchOrders { orders { id } }", + Kind: yokoOperationKindSubscription, + Description: "Watch orders.", + }, + }} + store := newSearchTestStorage(t) + srv := newSearchTestServer(t, false, searcher, store) + + got, err := srv.handleSearch(context.Background(), searchToolRequest(t, "session-1", map[string]any{ + "prompts": []string{"orders", "cancel order"}, + })) + + require.NoError(t, err) + wantOps := []storage.SessionOp{ + { + Name: "getOrders", + Body: "query GetOrders($limit: Int) { orders(limit: $limit) { id total } }", + Kind: storage.OperationKindQuery, + Description: "Fetch orders.", + }, + { + Name: "cancelOrder", + Body: "mutation CancelOrder($id: ID!) { cancelOrder(id: $id) { id } }", + Kind: storage.OperationKindMutation, + Description: "Cancel an order.", + }, + } + wantFragment, err := tsgen.NewOpsFragment(wantOps, searchHandlerTestSchema(t)) + require.NoError(t, err) + assert.Equal(t, textToolResult(wantFragment), got) + assert.Equal(t, wantOps, store.opsSnapshot("session-1")) + assert.Equal(t, []searchCall{{sessionID: "session-1", prompts: []string{"orders", "cancel order"}}}, searcher.callsSnapshot()) +} + +func TestHandleSearchFallsBackToStatelessWhenSessionIDMissing(t *testing.T) { + searcher := newFakeYoko() + searcher.responses <- &yokov1.SearchResponse{Operations: []*yokov1.GeneratedOperation{{ + Name: "getOrders", + Body: "query GetOrders { orders { id } }", + Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, + Description: "Fetch orders.", + }}} + store := newSearchTestStorage(t) + srv := newSearchTestServer(t, false, searcher, store) + + got, err := srv.handleSearch(context.Background(), searchToolRequest(t, "", map[string]any{ + "prompts": []string{"orders"}, + })) + + require.NoError(t, err) + expectedJSON := mustJSON(t, []legacyCatalogueEntry{{ + Name: "getOrders", + Body: "query GetOrders { orders { id } }", + Kind: "Query", + Description: "Fetch orders.", + Variables: nil, + }}) + assert.Equal(t, textToolResult(expectedJSON), got) + assert.Equal(t, []storage.SessionOp(nil), store.opsSnapshot("session-1")) +} + +func TestHandleSearchNamingCollisionUsesFinalStoredName(t *testing.T) { + searcher := newFakeYoko() + searcher.responses <- &yokov1.SearchResponse{Operations: []*yokov1.GeneratedOperation{{ + Name: "getOrders", + Body: "query GetOrdersAgain { orders { total } }", + Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, + Description: "Fetch order totals.", + }}} + store := newSearchTestStorage(t) + _, err := store.Append(context.Background(), "session-1", []storage.SessionOp{{ + Name: "getOrders", + Body: "query GetOrders { orders { id } }", + Kind: storage.OperationKindQuery, + }}) + require.NoError(t, err) + srv := newSearchTestServer(t, false, searcher, store) + + got, err := srv.handleSearch(context.Background(), searchToolRequest(t, "session-1", map[string]any{ + "prompts": []string{"orders again"}, + })) + + require.NoError(t, err) + wantOps := []storage.SessionOp{ + {Name: "getOrders", Body: "query GetOrders { orders { id } }", Kind: storage.OperationKindQuery}, + {Name: "getOrders_2", Body: "query GetOrdersAgain { orders { total } }", Kind: storage.OperationKindQuery, Description: "Fetch order totals."}, + } + wantFragment, err := tsgen.NewOpsFragment(wantOps[1:], searchHandlerTestSchema(t)) + require.NoError(t, err) + assert.Equal(t, textToolResult(wantFragment), got) + assert.Equal(t, wantOps, store.opsSnapshot("session-1")) +} + +func TestHandleSearchEmptyYokoResponseIsSuccess(t *testing.T) { + searcher := newFakeYoko() + searcher.responses <- &yokov1.SearchResponse{} + srv := newSearchTestServer(t, false, searcher, newSearchTestStorage(t)) + + got, err := srv.handleSearch(context.Background(), searchToolRequest(t, "session-1", map[string]any{ + "prompts": []string{"orders"}, + })) + + require.NoError(t, err) + assert.Equal(t, textToolResult(emptySearchMessage), got) +} + +func TestHandleSearchDoesNotRetryNotFoundFromSearcher(t *testing.T) { + searcher := newFakeYoko() + searcher.errs <- connect.NewError(connect.CodeNotFound, errors.New("missing index")) + srv := newSearchTestServer(t, false, searcher, newSearchTestStorage(t)) + + got, err := srv.handleSearch(context.Background(), searchToolRequest(t, "session-1", map[string]any{ + "prompts": []string{"orders"}, + })) + + require.NoError(t, err) + assert.Equal(t, toolError("code_mode_search_tools: yoko search failed: not_found: missing index"), got) + assert.Equal(t, 1, searcher.callCount()) +} + +func TestHandleSearchYokoErrorIsToolError(t *testing.T) { + searcher := newFakeYoko() + searcher.errs <- errors.New("dial tcp: connection refused") + srv := newSearchTestServer(t, false, searcher, newSearchTestStorage(t)) + + got, err := srv.handleSearch(context.Background(), searchToolRequest(t, "session-1", map[string]any{ + "prompts": []string{"orders"}, + })) + + require.NoError(t, err) + assert.Equal(t, toolError("code_mode_search_tools: yoko search failed: dial tcp: connection refused"), got) +} + +func TestHandleSearchSingleFlight(t *testing.T) { + t.Run("identical calls share leader result", func(t *testing.T) { + searcher := newFakeYoko() + searcher.block = make(chan struct{}) + searcher.responses <- &yokov1.SearchResponse{Operations: []*yokov1.GeneratedOperation{{ + Name: "getOrders", + Body: "query GetOrders { orders { id } }", + Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, + }}} + srv := newSearchTestServer(t, false, searcher, newSearchTestStorage(t)) + + ctx := context.Background() + var wg sync.WaitGroup + results := make([]*mcp.CallToolResult, 2) + wg.Add(1) + go func() { + defer wg.Done() + result, err := srv.handleSearch(ctx, searchToolRequest(t, "session-1", map[string]any{ + "prompts": []string{"orders", "customers"}, + })) + require.NoError(t, err) + results[0] = result + }() + require.Eventually(t, func() bool { return searcher.callCount() == 1 }, time.Second, time.Millisecond) + wg.Add(1) + go func() { + defer wg.Done() + result, err := srv.handleSearch(ctx, searchToolRequest(t, "session-1", map[string]any{ + "prompts": []string{"orders", "customers"}, + })) + require.NoError(t, err) + results[1] = result + }() + time.Sleep(10 * time.Millisecond) + close(searcher.block) + wg.Wait() + + assert.Equal(t, 1, searcher.callCount()) + assert.Equal(t, results[0], results[1]) + }) + + t.Run("different calls do not share result", func(t *testing.T) { + searcher := newFakeYoko() + searcher.responses <- &yokov1.SearchResponse{} + searcher.responses <- &yokov1.SearchResponse{} + srv := newSearchTestServer(t, false, searcher, newSearchTestStorage(t)) + + var wg sync.WaitGroup + for _, prompt := range []string{"orders", "customers"} { + wg.Add(1) + go func(prompt string) { + defer wg.Done() + _, err := srv.handleSearch(context.Background(), searchToolRequest(t, "session-1", map[string]any{ + "prompts": []string{prompt}, + })) + require.NoError(t, err) + }(prompt) + } + wg.Wait() + + assert.Equal(t, 2, searcher.callCount()) + }) + + t.Run("ambiguous spacing prompt sets do not share result", func(t *testing.T) { + searcher := newFakeYoko() + searcher.block = make(chan struct{}) + searcher.responses <- &yokov1.SearchResponse{} + searcher.responses <- &yokov1.SearchResponse{} + srv := newSearchTestServer(t, false, searcher, newSearchTestStorage(t)) + + var wg sync.WaitGroup + for _, prompts := range [][]string{ + {"a b", "c"}, + {"a", "b c"}, + } { + prompts := prompts + wg.Add(1) + go func() { + defer wg.Done() + _, err := srv.handleSearch(context.Background(), searchToolRequest(t, "session-1", map[string]any{ + "prompts": prompts, + })) + require.NoError(t, err) + }() + } + + require.Eventually(t, func() bool { return searcher.callCount() == 2 }, time.Second, time.Millisecond) + close(searcher.block) + wg.Wait() + + assert.Equal(t, 2, searcher.callCount()) + }) +} + +func TestHandleSearchRenderErrorIsToolError(t *testing.T) { + searcher := newFakeYoko() + searcher.responses <- &yokov1.SearchResponse{Operations: []*yokov1.GeneratedOperation{{ + Name: "getOrders", + Body: "query GetOrders { orders { id } }", + Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, + }}} + srv := newSearchTestServer(t, false, searcher, newSearchTestStorage(t)) + srv.newOpsFragment = func([]storage.SessionOp, *ast.Document) (string, error) { + return "", errors.New("render exploded") + } + + got, err := srv.handleSearch(context.Background(), searchToolRequest(t, "session-1", map[string]any{ + "prompts": []string{"orders"}, + })) + + require.NoError(t, err) + assert.Equal(t, toolError("code_mode_search_tools: failed to render new ops: render exploded"), got) +} + +func TestHandleSearchCancelMaySurfaceLeaderCancellationToFollower(t *testing.T) { + searcher := newFakeYoko() + searcher.block = make(chan struct{}) + srv := newSearchTestServer(t, false, searcher, newSearchTestStorage(t)) + + leaderCtx, cancelLeader := context.WithCancel(context.Background()) + defer cancelLeader() + + var wg sync.WaitGroup + results := make([]*mcp.CallToolResult, 2) + wg.Add(1) + go func() { + defer wg.Done() + result, err := srv.handleSearch(leaderCtx, searchToolRequest(t, "session-1", map[string]any{ + "prompts": []string{"orders"}, + })) + require.NoError(t, err) + results[0] = result + }() + require.Eventually(t, func() bool { return searcher.callCount() == 1 }, time.Second, time.Millisecond) + + wg.Add(1) + go func() { + defer wg.Done() + result, err := srv.handleSearch(context.Background(), searchToolRequest(t, "session-1", map[string]any{ + "prompts": []string{"orders"}, + })) + require.NoError(t, err) + results[1] = result + }() + time.Sleep(10 * time.Millisecond) + cancelLeader() + close(searcher.block) + wg.Wait() + + assert.Equal(t, 1, searcher.callCount()) + assert.Equal(t, toolError("code_mode_search_tools: yoko search failed: context canceled"), results[0]) + assert.Equal(t, toolError("code_mode_search_tools: yoko search failed: context canceled"), results[1]) +} + +type searchCall struct { + sessionID string + prompts []string +} + +type fakeYoko struct { + mu sync.Mutex + calls []searchCall + responses chan *yokov1.SearchResponse + errs chan error + block chan struct{} + schema string + ensureIndexed int + ensureIndexedErr error +} + +func newFakeYoko() *fakeYoko { + return &fakeYoko{ + responses: make(chan *yokov1.SearchResponse, 16), + errs: make(chan error, 16), + } +} + +func (f *fakeYoko) Search(ctx context.Context, sessionID string, prompts []string) (*yokov1.SearchResponse, error) { + f.mu.Lock() + f.calls = append(f.calls, searchCall{sessionID: sessionID, prompts: append([]string(nil), prompts...)}) + f.mu.Unlock() + + if f.block != nil { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-f.block: + } + } + + select { + case err := <-f.errs: + return nil, err + default: + } + select { + case response := <-f.responses: + return response, nil + default: + return &yokov1.SearchResponse{}, nil + } +} + +func (f *fakeYoko) SetSchema(schema string) { + f.mu.Lock() + defer f.mu.Unlock() + f.schema = schema +} + +func (f *fakeYoko) Schema() string { + f.mu.Lock() + defer f.mu.Unlock() + return f.schema +} + +func (f *fakeYoko) EnsureIndexed(context.Context) error { + f.mu.Lock() + defer f.mu.Unlock() + f.ensureIndexed++ + return f.ensureIndexedErr +} + +func (f *fakeYoko) ensureIndexedCallCount() int { + f.mu.Lock() + defer f.mu.Unlock() + return f.ensureIndexed +} + +func (f *fakeYoko) callCount() int { + f.mu.Lock() + defer f.mu.Unlock() + return len(f.calls) +} + +func (f *fakeYoko) callsSnapshot() []searchCall { + f.mu.Lock() + defer f.mu.Unlock() + calls := make([]searchCall, 0, len(f.calls)) + for _, call := range f.calls { + calls = append(calls, searchCall{sessionID: call.sessionID, prompts: append([]string(nil), call.prompts...)}) + } + return calls +} + +type searchTestStorage struct { + mu sync.Mutex + schema *ast.Document + ops map[string][]storage.SessionOp +} + +func newSearchTestStorage(t *testing.T) *searchTestStorage { + t.Helper() + return &searchTestStorage{ + schema: searchHandlerTestSchema(t), + ops: make(map[string][]storage.SessionOp), + } +} + +func (s *searchTestStorage) Append(ctx context.Context, sessionID string, ops []storage.SessionOp) ([]storage.SessionOp, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + s.mu.Lock() + defer s.mu.Unlock() + + taken := make(map[string]struct{}, len(s.ops[sessionID])+len(ops)) + for _, op := range s.ops[sessionID] { + taken[op.Name] = struct{}{} + } + + appended := make([]storage.SessionOp, 0, len(ops)) + for _, op := range ops { + op.Name = storage.SuffixedName(storage.NormalizeName(op.Name), taken) + taken[op.Name] = struct{}{} + s.ops[sessionID] = append(s.ops[sessionID], op) + appended = append(appended, op) + } + return appended, nil +} + +func (s *searchTestStorage) GetOp(_ context.Context, sessionID string, name string) (storage.SessionOp, bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + for _, op := range s.ops[sessionID] { + if op.Name == name { + return op, true, nil + } + } + return storage.SessionOp{}, false, nil +} + +func (s *searchTestStorage) ListNames(_ context.Context, sessionID string) ([]string, error) { + s.mu.Lock() + defer s.mu.Unlock() + names := make([]string, 0, len(s.ops[sessionID])) + for _, op := range s.ops[sessionID] { + names = append(names, op.Name) + } + return names, nil +} + +func (s *searchTestStorage) Bundle(context.Context, string) (string, error) { + return "", nil +} + +func (s *searchTestStorage) Reset(_ context.Context, sessionID string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.ops, sessionID) + return nil +} + +func (s *searchTestStorage) SetSchema(schema *ast.Document) { + s.mu.Lock() + defer s.mu.Unlock() + s.schema = schema +} + +func (s *searchTestStorage) Schema() *ast.Document { + s.mu.Lock() + defer s.mu.Unlock() + return s.schema +} + +func (s *searchTestStorage) Start(context.Context) error { + return nil +} + +func (s *searchTestStorage) Stop() error { + return nil +} + +func (s *searchTestStorage) opsSnapshot(sessionID string) []storage.SessionOp { + s.mu.Lock() + defer s.mu.Unlock() + return append([]storage.SessionOp(nil), s.ops[sessionID]...) +} + +type legacyCatalogueEntry struct { + Name string `json:"name"` + Body string `json:"body"` + Kind string `json:"kind"` + Description string `json:"description"` + Variables *string `json:"variables"` +} + +func newSearchTestServer(t *testing.T, stateless bool, searcher *fakeYoko, store *searchTestStorage) *Server { + t.Helper() + srv, err := New(Config{ + CodeModeEnabled: true, + NamedOpsEnabled: true, + SessionStateless: stateless, + Storage: store, + YokoClient: searcher, + BundleRenderer: storage.RendererFunc(func([]storage.SessionOp) (string, error) { return "", nil }), + Logger: zap.NewNop(), + }) + require.NoError(t, err) + return srv +} + +func searchToolRequest(t *testing.T, sessionID string, arguments map[string]any) *mcp.CallToolRequest { + t.Helper() + body, err := json.Marshal(arguments) + require.NoError(t, err) + return &mcp.CallToolRequest{ + Params: &mcp.CallToolParamsRaw{ + Name: "code_mode_search_tools", + Arguments: body, + }, + Extra: &mcp.RequestExtra{Header: http.Header{mcpSessionIDHeader: []string{sessionID}}}, + } +} + +func textToolResult(text string) *mcp.CallToolResult { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: text}}, + } +} + +func ptrString(value string) *string { + return &value +} + +func searchHandlerTestSchema(t *testing.T) *ast.Document { + t.Helper() + doc, report := astparser.ParseGraphqlDocumentString(searchHandlerTestSchemaSDL) + require.False(t, report.HasErrors(), report.Error()) + require.NoError(t, asttransform.MergeDefinitionWithBaseSchema(&doc)) + return &doc +} diff --git a/router/internal/codemode/server/server.go b/router/internal/codemode/server/server.go new file mode 100644 index 0000000000..c185585ec7 --- /dev/null +++ b/router/internal/codemode/server/server.go @@ -0,0 +1,494 @@ +package server + +import ( + "context" + "encoding/json" + "errors" + "net" + "net/http" + "sync" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/wundergraph/cosmo/router/internal/codemode/calltrace" + "github.com/wundergraph/cosmo/router/internal/codemode/harness" + "github.com/wundergraph/cosmo/router/internal/codemode/observability" + "github.com/wundergraph/cosmo/router/internal/codemode/sandbox" + "github.com/wundergraph/cosmo/router/internal/codemode/server/descriptions" + "github.com/wundergraph/cosmo/router/internal/codemode/storage" + "github.com/wundergraph/cosmo/router/internal/codemode/tsgen" + "github.com/wundergraph/cosmo/router/internal/codemode/yoko" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + otelmetric "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/trace" + "go.uber.org/zap" + "golang.org/x/sync/singleflight" +) + +const ( + defaultListenAddr = "localhost:5027" + defaultExecuteTimeout = 120 * time.Second + defaultMaxResultBytes = 32 << 10 + mcpPath = "/mcp" + persistedOpsURI = "yoko://persisted-ops.d.ts" + statelessNamedOpsWarnMessage = "code mode named operations are disabled because MCP session stateless mode is enabled" + namedOpsDisabledMessage = "named operations are disabled" +) + +// Config configures the Code Mode MCP server. +type Config struct { + ListenAddr string + CodeModeEnabled bool + NamedOpsEnabled bool + SessionStateless bool + Storage storage.SessionStorage + Pipeline harness.Executor + YokoClient yoko.Searcher + BundleRenderer storage.Renderer + ExecuteTimeout time.Duration + MaxResultBytes int + ApprovalGate sandbox.ApprovalGate + Logger *zap.Logger + MeterProvider otelmetric.MeterProvider + TracerProvider trace.TracerProvider + CallTraceRecorder calltrace.Recorder +} + +// Server owns the Code Mode MCP server and its separate HTTP listener. +type Server struct { + listenAddr string + codeModeEnabled bool + namedOpsEnabled bool + sessionStateless bool + storage storage.SessionStorage + pipeline harness.Executor + yokoClient yoko.Searcher + bundleRenderer storage.Renderer + executeTimeout time.Duration + maxResultBytes int + approvalGate sandbox.ApprovalGate + logger *zap.Logger + meter *observability.Meter + tracerProvider trace.TracerProvider + callTraceRecorder calltrace.Recorder + + mcpServer *mcp.Server + searchGroup singleflight.Group + newOpsFragment func([]storage.SessionOp, *ast.Document) (string, error) + + mu sync.Mutex + httpServer *http.Server + actualAddr string + warnedStatelessNamedOps bool + warnedMissingSessionID bool +} + +// New creates a Code Mode MCP server. +func New(cfg Config) (*Server, error) { + if cfg.ListenAddr == "" { + cfg.ListenAddr = defaultListenAddr + } + if cfg.Logger == nil { + cfg.Logger = zap.NewNop() + } + if cfg.MeterProvider == nil { + cfg.MeterProvider = otel.GetMeterProvider() + } + if cfg.TracerProvider == nil { + cfg.TracerProvider = otel.GetTracerProvider() + } + if cfg.CallTraceRecorder == nil { + cfg.CallTraceRecorder = calltrace.NopRecorder{} + } + if cfg.ExecuteTimeout <= 0 { + cfg.ExecuteTimeout = defaultExecuteTimeout + } + if cfg.MaxResultBytes <= 0 { + cfg.MaxResultBytes = defaultMaxResultBytes + } + if pipeline, ok := cfg.Pipeline.(*harness.Pipeline); ok { + pipeline.MaxResultBytes = cfg.MaxResultBytes + } + meter, err := observability.NewMeter(cfg.MeterProvider) + if err != nil { + return nil, err + } + + s := &Server{ + listenAddr: cfg.ListenAddr, + codeModeEnabled: cfg.CodeModeEnabled, + namedOpsEnabled: cfg.NamedOpsEnabled, + sessionStateless: cfg.SessionStateless, + storage: cfg.Storage, + pipeline: cfg.Pipeline, + yokoClient: cfg.YokoClient, + bundleRenderer: cfg.BundleRenderer, + executeTimeout: cfg.ExecuteTimeout, + maxResultBytes: cfg.MaxResultBytes, + approvalGate: cfg.ApprovalGate, + logger: cfg.Logger, + meter: meter, + tracerProvider: cfg.TracerProvider, + callTraceRecorder: cfg.CallTraceRecorder, + newOpsFragment: tsgen.NewOpsFragment, + } + + s.mcpServer = mcp.NewServer(&mcp.Implementation{ + Name: "yoko", + Title: "Yoko (Cosmo Code Mode)", + Version: "v0.1.0", + }, &mcp.ServerOptions{ + HasResources: true, + }) + + if cfg.CodeModeEnabled { + s.registerTools() + if cfg.NamedOpsEnabled && !cfg.SessionStateless { + s.registerPersistedOpsResource() + } + } + + return s, nil +} + +// Start binds the separate Code Mode MCP HTTP listener and serves until the +// server shuts down or ctx is canceled. When Code Mode is disabled it is a no-op. +func (s *Server) Start(ctx context.Context) error { + if !s.codeModeEnabled { + return nil + } + + if s.storage != nil { + if err := s.storage.Start(ctx); err != nil { + return err + } + } + + listener, err := net.Listen("tcp", s.listenAddr) + if err != nil { + if s.storage != nil { + _ = s.storage.Stop() + } + return err + } + + // WriteTimeout must exceed executeTimeout — net/http enforces it as a + // hard deadline on the whole response phase, which would cut off + // legitimately long code_mode_run_js calls. ReadHeaderTimeout bounds the + // header read separately so the listener still resists slow-loris clients. + httpServer := &http.Server{ + Addr: s.listenAddr, + Handler: s.handler(), + ReadHeaderTimeout: 30 * time.Second, + WriteTimeout: s.executeTimeout + 30*time.Second, + IdleTimeout: 60 * time.Second, + } + + s.mu.Lock() + s.httpServer = httpServer + s.actualAddr = listener.Addr().String() + s.mu.Unlock() + + done := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = s.Stop(shutdownCtx) + case <-done: + } + }() + + s.logger.Info("Code Mode MCP server started", zap.String("listen_addr", listener.Addr().String()), zap.String("path", mcpPath)) + err = httpServer.Serve(listener) + close(done) + if errors.Is(err, http.ErrServerClosed) { + return nil + } + return err +} + +// Stop gracefully shuts down the Code Mode MCP HTTP server. Disabled or unstarted +// servers are no-ops. +func (s *Server) Stop(ctx context.Context) error { + if !s.codeModeEnabled { + return nil + } + + s.mu.Lock() + httpServer := s.httpServer + s.mu.Unlock() + if httpServer == nil { + if s.storage != nil { + return s.storage.Stop() + } + return nil + } + err := httpServer.Shutdown(ctx) + if err == nil || errors.Is(err, http.ErrServerClosed) { + s.mu.Lock() + if s.httpServer == httpServer { + s.httpServer = nil + } + s.mu.Unlock() + if s.storage != nil { + return s.storage.Stop() + } + return nil + } + return err +} + +// Reload forwards schema state into Code Mode dependencies. Disabled servers +// ignore reloads. +func (s *Server) Reload(schema *ast.Document, sdl string) error { + if !s.codeModeEnabled { + return nil + } + if s.storage != nil { + s.storage.SetSchema(schema) + } + if s.yokoClient != nil { + s.yokoClient.SetSchema(sdl) + // Eagerly index the new SDL in the background so the first user-facing + // code_mode_search_tools call doesn't pay the IndexSchema round-trip + // latency. Failures are logged and ignored — the lazy path inside + // Search will retry on the next call. + // + // recover guard: an unrecovered panic here would bring the whole + // router down because the goroutine runs outside any caller frame. + // The warm-up is best-effort, so a panic must never escape. + if sdl != "" { + yokoClient := s.yokoClient + logger := s.logger + sdlBytes := len(sdl) + go func() { + start := time.Now() + defer func() { + if r := recover(); r != nil { + logger.Error("code mode eager schema index panicked", + zap.Any("panic", r), + zap.Duration("duration", time.Since(start)), + ) + } + }() + logger.Info("code mode eager schema index started", + zap.Int("sdl_bytes", sdlBytes), + ) + if err := yokoClient.EnsureIndexed(context.Background()); err != nil { + logger.Warn("code mode eager schema index failed", + zap.Error(err), + zap.Duration("duration", time.Since(start)), + ) + return + } + logger.Info("code mode eager schema index completed", + zap.Duration("duration", time.Since(start)), + ) + }() + } + } + if s.sessionStateless && s.namedOpsEnabled { + s.warnStatelessNamedOpsOnce() + } + observability.LogSessionLifecycle(s.logger, "schema_swap", "", zap.Int("sdl_bytes", len(sdl))) + return nil +} + +func (s *Server) registerTools() { + s.mcpServer.AddTool(&mcp.Tool{ + Name: "code_mode_search_tools", + Description: descriptions.SearchTool, + InputSchema: searchAPIInputSchema(), + }, s.handleSearch) + + s.mcpServer.AddTool(&mcp.Tool{ + Name: "code_mode_run_js", + Description: descriptions.ExecuteTool, + InputSchema: executeAPIInputSchema(), + }, s.handleExecute) +} + +func (s *Server) registerPersistedOpsResource() { + s.mcpServer.AddResource(&mcp.Resource{ + URI: persistedOpsURI, + Name: "persisted-ops.d.ts", + Title: "Persisted operations TypeScript definitions", + Description: descriptions.PersistedOpsResource, + MIMEType: "text/plain", + }, s.handlePersistedOpsResource) +} + +func (s *Server) handler() http.Handler { + cop := http.NewCrossOriginProtection() + cop.AddInsecureBypassPattern("/{path...}") + + streamableHTTPHandler := mcp.NewStreamableHTTPHandler( + func(*http.Request) *mcp.Server { + return s.mcpServer + }, + &mcp.StreamableHTTPOptions{ + Stateless: s.sessionStateless, + CrossOriginProtection: cop, + }, + ) + + mux := http.NewServeMux() + mux.Handle(mcpPath, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + req = req.WithContext(withSessionIDFromRequest(req.Context(), req)) + streamableHTTPHandler.ServeHTTP(w, req) + })) + return mux +} + +func (s *Server) handleSearch(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return s.handleTool(ctx, req, "code_mode_search_tools", s.handleSearchAPI) +} + +func (s *Server) handleExecute(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return s.handleTool(ctx, req, "code_mode_run_js", s.handleExecuteAPI) +} + +func (s *Server) handleTool(ctx context.Context, req *mcp.CallToolRequest, toolName string, next func(context.Context, *mcp.CallToolRequest) (*mcp.CallToolResult, error)) (result *mcp.CallToolResult, err error) { + start := time.Now() + ctx, span := observability.StartToolSpanWithProvider(ctx, s.tracerProvider, toolName) + sessionID := sessionIDFromToolRequest(req) + if calltrace.Enabled(s.callTraceRecorder) { + s.callTraceRecorder.RecordRequest(toolName, toolRequestBody(req)) + } + observability.LogSessionLifecycle(s.logger, toolName+".started", sessionID) + defer func() { + status := toolStatus(result, err) + durationMs := float64(time.Since(start)) / float64(time.Millisecond) + span.SetAttributes(attribute.String("mcp.status", status)) + s.meter.Record(ctx, toolName, status, durationMs) + observability.LogSessionLifecycle(s.logger, toolName+".completed", sessionID, + zap.String("status", status), + zap.Float64("duration_ms", durationMs), + ) + span.End() + }() + + result, err = next(ctx, req) + if calltrace.Enabled(s.callTraceRecorder) { + if body, marshalErr := json.Marshal(result); marshalErr == nil { + s.callTraceRecorder.RecordResponse(toolName, body) + } + } + return result, err +} + +func toolStatus(result *mcp.CallToolResult, err error) string { + if err != nil || (result != nil && result.IsError) { + return "error" + } + return "success" +} + +func sessionIDFromToolRequest(req *mcp.CallToolRequest) string { + if req == nil || req.GetExtra() == nil { + return "" + } + return req.GetExtra().Header.Get(mcpSessionIDHeader) +} + +func toolRequestBody(req *mcp.CallToolRequest) []byte { + if req == nil || req.Params == nil || len(req.Params.Arguments) == 0 { + return []byte(`null`) + } + return append([]byte(nil), req.Params.Arguments...) +} + +func (s *Server) handlePersistedOpsResource(ctx context.Context, req *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { + ctx = contextWithSessionFromExtra(ctx, req.GetExtra()) + if s.storage == nil { + return &mcp.ReadResourceResult{ + Contents: []*mcp.ResourceContents{{ + URI: persistedOpsURI, + MIMEType: "text/plain", + Text: "", + }}, + }, nil + } + bundle, err := s.storage.Bundle(ctx, SessionIDFromContext(ctx)) + if err != nil { + return nil, err + } + return &mcp.ReadResourceResult{ + Contents: []*mcp.ResourceContents{{ + URI: persistedOpsURI, + MIMEType: "text/plain", + Text: bundle, + }}, + }, nil +} + +func contextWithSessionFromExtra(ctx context.Context, extra *mcp.RequestExtra) context.Context { + if extra == nil { + return WithSessionID(ctx, "") + } + return WithSessionID(ctx, extra.Header.Get(mcpSessionIDHeader)) +} + +func toolErrorResult(message string) *mcp.CallToolResult { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: message}}, + IsError: true, + } +} + +func searchAPIInputSchema() map[string]any { + return map[string]any{ + "type": "object", + "required": []any{"prompts"}, + "properties": map[string]any{ + "prompts": map[string]any{ + "type": "array", + "minItems": 1, + "maxItems": 20, + "items": map[string]any{ + "type": "string", + "minLength": 1, + }, + }, + }, + } +} + +func executeAPIInputSchema() map[string]any { + return map[string]any{ + "type": "object", + "required": []any{"source"}, + "properties": map[string]any{ + "source": map[string]any{ + "type": "string", + "minLength": 1, + "description": descriptions.ExecuteSource, + }, + }, + } +} + +func (s *Server) warnStatelessNamedOpsOnce() { + s.mu.Lock() + defer s.mu.Unlock() + if s.warnedStatelessNamedOps { + return + } + s.warnedStatelessNamedOps = true + s.logger.Warn(statelessNamedOpsWarnMessage) +} + +// Addr returns the listener address once Start has bound it. +func (s *Server) Addr() string { + return s.addr() +} + +func (s *Server) addr() string { + s.mu.Lock() + defer s.mu.Unlock() + return s.actualAddr +} diff --git a/router/internal/codemode/server/server_test.go b/router/internal/codemode/server/server_test.go new file mode 100644 index 0000000000..ed7151c5c6 --- /dev/null +++ b/router/internal/codemode/server/server_test.go @@ -0,0 +1,556 @@ +package server + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "slices" + "sync" + "syscall" + "testing" + "time" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/internal/codemode/server/descriptions" + "github.com/wundergraph/cosmo/router/internal/codemode/storage" + "github.com/wundergraph/cosmo/router/internal/codemode/yoko" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "go.uber.org/zap" + "go.uber.org/zap/zaptest/observer" +) + +func TestStartDisabledReturnsWithoutBinding(t *testing.T) { + srv, err := New(Config{ + ListenAddr: "127.0.0.1:0", + CodeModeEnabled: false, + Storage: newRecordingStorage(), + YokoClient: yoko.New(nil, "http://127.0.0.1", zap.NewNop()), + BundleRenderer: storage.RendererFunc(func([]storage.SessionOp) (string, error) { return "", nil }), + Logger: zap.NewNop(), + NamedOpsEnabled: true, + SessionStateless: false, + }) + require.NoError(t, err) + + err = srv.Start(context.Background()) + + require.NoError(t, err) + assert.Equal(t, "", srv.addr()) + require.NoError(t, srv.Stop(context.Background())) +} + +func TestListToolsReturnsCodeModeTools(t *testing.T) { + srv := newTestServer(t, Config{ + CodeModeEnabled: true, + NamedOpsEnabled: true, + SessionStateless: false, + }) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + startServer(t, ctx, srv) + defer stopServer(t, srv) + + session := connectHTTPClient(t, ctx, "http://"+srv.addr()+"/mcp") + defer session.Close() + + got, err := session.ListTools(ctx, &mcp.ListToolsParams{}) + require.NoError(t, err) + require.Len(t, got.Tools, 2) + slices.SortFunc(got.Tools, func(a, b *mcp.Tool) int { + if a.Name < b.Name { + return -1 + } + if a.Name > b.Name { + return 1 + } + return 0 + }) + + assert.Equal(t, mustJSON(t, []*mcp.Tool{ + { + Name: "code_mode_run_js", + Description: descriptions.ExecuteTool, + InputSchema: map[string]any{ + "type": "object", + "required": []any{"source"}, + "properties": map[string]any{ + "source": map[string]any{ + "type": "string", + "minLength": float64(1), + "description": descriptions.ExecuteSource, + }, + }, + }, + }, + { + Name: "code_mode_search_tools", + Description: descriptions.SearchTool, + InputSchema: map[string]any{ + "type": "object", + "required": []any{"prompts"}, + "properties": map[string]any{ + "prompts": map[string]any{ + "type": "array", + "minItems": float64(1), + "maxItems": float64(20), + "items": map[string]any{ + "type": "string", + "minLength": float64(1), + }, + }, + }, + }, + }, + }), mustJSON(t, got.Tools)) +} + +func TestListResourcesGating(t *testing.T) { + tests := []struct { + name string + cfg Config + wantPresent bool + }{ + { + name: "code mode disabled", + cfg: Config{ + CodeModeEnabled: false, + NamedOpsEnabled: true, + SessionStateless: false, + }, + }, + { + name: "named ops disabled", + cfg: Config{ + CodeModeEnabled: true, + NamedOpsEnabled: false, + SessionStateless: false, + }, + }, + { + name: "stateless disables named ops", + cfg: Config{ + CodeModeEnabled: true, + NamedOpsEnabled: true, + SessionStateless: true, + }, + }, + { + name: "stateful named ops", + cfg: Config{ + CodeModeEnabled: true, + NamedOpsEnabled: true, + SessionStateless: false, + }, + wantPresent: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + srv := newTestServer(t, tt.cfg) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + session := connectInMemoryClient(t, ctx, srv) + defer session.Close() + + got, err := session.ListResources(ctx, &mcp.ListResourcesParams{}) + require.NoError(t, err) + assert.Equal(t, tt.wantPresent, hasResource(got.Resources, persistedOpsURI)) + }) + } +} + +func TestStatelessNamedOpsReloadWarnsOnce(t *testing.T) { + core, recorded := observer.New(zap.WarnLevel) + store := newRecordingStorage() + client := yoko.New(nil, "http://127.0.0.1", zap.NewNop()) + srv, err := New(Config{ + ListenAddr: "127.0.0.1:0", + CodeModeEnabled: true, + NamedOpsEnabled: true, + SessionStateless: true, + Storage: store, + YokoClient: client, + BundleRenderer: storage.RendererFunc(func([]storage.SessionOp) (string, error) { return "", nil }), + Logger: zap.New(core), + }) + require.NoError(t, err) + + require.NoError(t, srv.Reload(&ast.Document{}, "schema { query: Query }")) + require.NoError(t, srv.Reload(&ast.Document{}, "schema { query: Query }")) + + assert.Equal(t, 1, recorded.FilterMessage(statelessNamedOpsWarnMessage).Len()) +} + +func TestExecuteToolStubReturnsDeterministicToolError(t *testing.T) { + srv := newTestServer(t, Config{ + CodeModeEnabled: true, + NamedOpsEnabled: false, + SessionStateless: false, + }) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + startServer(t, ctx, srv) + defer stopServer(t, srv) + + session := connectHTTPClient(t, ctx, "http://"+srv.addr()+"/mcp") + defer session.Close() + + executeResult, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "code_mode_run_js", + Arguments: map[string]any{"source": "async () => null"}, + }) + require.NoError(t, err) + assert.Equal(t, mustJSON(t, toolError("named operations are disabled")), mustJSON(t, executeResult)) +} + +func TestSessionIDExtraction(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, "http://example.com/mcp", nil) + require.NoError(t, err) + req.Header.Set("Mcp-Session-Id", "session-123") + + ctx := withSessionIDFromRequest(context.Background(), req) + + assert.Equal(t, "session-123", SessionIDFromContext(ctx)) + assert.Equal(t, "", SessionIDFromContext(context.Background())) + assert.Equal(t, "manual", SessionIDFromContext(WithSessionID(context.Background(), "manual"))) +} + +func TestResourceHandlerUsesCurrentSessionID(t *testing.T) { + store := newRecordingStorage() + store.bundle = "declare const tools: { getUser(): R<{ id: string }> };" + srv := newTestServer(t, Config{ + CodeModeEnabled: true, + NamedOpsEnabled: true, + SessionStateless: false, + Storage: store, + }) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + startServer(t, ctx, srv) + defer stopServer(t, srv) + + session := connectHTTPClient(t, ctx, "http://"+srv.addr()+"/mcp") + defer session.Close() + + got, err := session.ReadResource(ctx, &mcp.ReadResourceParams{URI: persistedOpsURI}) + require.NoError(t, err) + + require.NotEmpty(t, session.ID()) + assert.Equal(t, session.ID(), store.lastBundleSessionID()) + assert.Equal(t, mustJSON(t, &mcp.ReadResourceResult{ + Contents: []*mcp.ResourceContents{{ + URI: persistedOpsURI, + MIMEType: "text/plain", + Text: store.bundle, + }}, + }), mustJSON(t, got)) +} + +func TestReloadForwardsSchemaAndSDL(t *testing.T) { + store := newRecordingStorage() + client := yoko.New(nil, "http://127.0.0.1", zap.NewNop()) + srv, err := New(Config{ + ListenAddr: "127.0.0.1:0", + CodeModeEnabled: true, + NamedOpsEnabled: false, + SessionStateless: false, + Storage: store, + YokoClient: client, + BundleRenderer: storage.RendererFunc(func([]storage.SessionOp) (string, error) { return "", nil }), + Logger: zap.NewNop(), + }) + require.NoError(t, err) + + schema := &ast.Document{} + require.NoError(t, srv.Reload(schema, "schema { query: Query }")) + + assert.Equal(t, schema, store.schema) + assert.Equal(t, 1, store.setSchemaCalls) + assert.Equal(t, "schema { query: Query }", client.Schema()) +} + +func TestReloadEagerlyIndexesViaBackgroundGoroutine(t *testing.T) { + core, recorded := observer.New(zap.InfoLevel) + searcher := newFakeYoko() + srv, err := New(Config{ + ListenAddr: "127.0.0.1:0", + CodeModeEnabled: true, + NamedOpsEnabled: false, + SessionStateless: false, + Storage: newRecordingStorage(), + YokoClient: searcher, + BundleRenderer: storage.RendererFunc(func([]storage.SessionOp) (string, error) { return "", nil }), + Logger: zap.New(core), + }) + require.NoError(t, err) + + require.NoError(t, srv.Reload(&ast.Document{}, "schema { query: Query }")) + + require.Eventually(t, func() bool { + return searcher.ensureIndexedCallCount() == 1 + }, 2*time.Second, 5*time.Millisecond, "eager index should fire once after Reload") + + require.Eventually(t, func() bool { + return recorded.FilterMessage("code mode eager schema index started").Len() == 1 && + recorded.FilterMessage("code mode eager schema index completed").Len() == 1 + }, 2*time.Second, 5*time.Millisecond, "expected start+completed info logs") +} + +func TestReloadEagerIndexLogsWarnOnFailure(t *testing.T) { + core, recorded := observer.New(zap.InfoLevel) + searcher := newFakeYoko() + searcher.ensureIndexedErr = errors.New("yoko unreachable") + srv, err := New(Config{ + ListenAddr: "127.0.0.1:0", + CodeModeEnabled: true, + NamedOpsEnabled: false, + SessionStateless: false, + Storage: newRecordingStorage(), + YokoClient: searcher, + BundleRenderer: storage.RendererFunc(func([]storage.SessionOp) (string, error) { return "", nil }), + Logger: zap.New(core), + }) + require.NoError(t, err) + + require.NoError(t, srv.Reload(&ast.Document{}, "schema { query: Query }")) + + require.Eventually(t, func() bool { + return recorded.FilterMessage("code mode eager schema index started").Len() == 1 && + recorded.FilterMessage("code mode eager schema index failed").Len() == 1 && + recorded.FilterMessage("code mode eager schema index completed").Len() == 0 + }, 2*time.Second, 5*time.Millisecond, "expected start+failed logs without completed log") +} + +func TestReloadEagerIndexSkippedWhenSDLEmpty(t *testing.T) { + searcher := newFakeYoko() + srv, err := New(Config{ + ListenAddr: "127.0.0.1:0", + CodeModeEnabled: true, + NamedOpsEnabled: false, + SessionStateless: false, + Storage: newRecordingStorage(), + YokoClient: searcher, + BundleRenderer: storage.RendererFunc(func([]storage.SessionOp) (string, error) { return "", nil }), + Logger: zap.NewNop(), + }) + require.NoError(t, err) + + require.NoError(t, srv.Reload(&ast.Document{}, "")) + + // Give the goroutine that EnsureIndexed *would* have launched a chance to + // run; assert it never did. + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 0, searcher.ensureIndexedCallCount()) +} + +func TestReloadDisabledIsNoOp(t *testing.T) { + store := newRecordingStorage() + client := yoko.New(nil, "http://127.0.0.1", zap.NewNop()) + srv, err := New(Config{ + ListenAddr: "127.0.0.1:0", + CodeModeEnabled: false, + NamedOpsEnabled: true, + SessionStateless: false, + Storage: store, + YokoClient: client, + BundleRenderer: storage.RendererFunc(func([]storage.SessionOp) (string, error) { return "", nil }), + Logger: zap.NewNop(), + }) + require.NoError(t, err) + + require.NoError(t, srv.Reload(&ast.Document{}, "schema { query: Query }")) + + assert.Equal(t, 0, store.setSchemaCalls) + assert.Equal(t, "", client.Schema()) +} + +func newTestServer(t *testing.T, cfg Config) *Server { + t.Helper() + if cfg.ListenAddr == "" { + cfg.ListenAddr = "127.0.0.1:0" + } + if cfg.Storage == nil { + cfg.Storage = newRecordingStorage() + } + if cfg.YokoClient == nil { + cfg.YokoClient = yoko.New(nil, "http://127.0.0.1", zap.NewNop()) + } + if cfg.BundleRenderer == nil { + cfg.BundleRenderer = storage.RendererFunc(func([]storage.SessionOp) (string, error) { return "", nil }) + } + if cfg.Logger == nil { + cfg.Logger = zap.NewNop() + } + srv, err := New(cfg) + require.NoError(t, err) + return srv +} + +func startServer(t *testing.T, ctx context.Context, srv *Server) { + t.Helper() + errs := make(chan error, 1) + go func() { + errs <- srv.Start(ctx) + }() + deadline := time.After(5 * time.Second) + tick := time.NewTicker(10 * time.Millisecond) + defer tick.Stop() + bound := false + for { + select { + case err := <-errs: + if isBindPermissionError(err) { + t.Skipf("local listener bind is not permitted in this environment: %v", err) + } + require.NoError(t, err) + case <-deadline: + require.FailNow(t, "server listener was not bound") + case <-tick.C: + if srv.addr() != "" { + bound = true + } + } + if bound { + break + } + } + t.Cleanup(func() { + select { + case err := <-errs: + require.NoError(t, err) + case <-time.After(5 * time.Second): + require.FailNow(t, "server did not stop") + } + }) +} + +func isBindPermissionError(err error) bool { + return errors.Is(err, syscall.EACCES) || errors.Is(err, syscall.EPERM) +} + +func stopServer(t *testing.T, srv *Server) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + require.NoError(t, srv.Stop(ctx)) +} + +func connectHTTPClient(t *testing.T, ctx context.Context, endpoint string) *mcp.ClientSession { + t.Helper() + client := mcp.NewClient(&mcp.Implementation{Name: "code-mode-test-client", Version: "test"}, nil) + session, err := client.Connect(ctx, &mcp.StreamableClientTransport{ + Endpoint: endpoint, + DisableStandaloneSSE: true, + }, nil) + require.NoError(t, err) + return session +} + +func connectInMemoryClient(t *testing.T, ctx context.Context, srv *Server) *mcp.ClientSession { + t.Helper() + clientTransport, serverTransport := mcp.NewInMemoryTransports() + errs := make(chan error, 1) + go func() { + errs <- srv.mcpServer.Run(ctx, serverTransport) + }() + t.Cleanup(func() { + select { + case err := <-errs: + if err != nil && !errors.Is(err, context.Canceled) { + require.NoError(t, err) + } + default: + } + }) + + client := mcp.NewClient(&mcp.Implementation{Name: "code-mode-test-client", Version: "test"}, nil) + session, err := client.Connect(ctx, clientTransport, nil) + require.NoError(t, err) + return session +} + +func hasResource(resources []*mcp.Resource, uri string) bool { + return slices.ContainsFunc(resources, func(resource *mcp.Resource) bool { + return resource.URI == uri + }) +} + +func mustJSON(t *testing.T, value any) string { + t.Helper() + data, err := json.Marshal(value) + require.NoError(t, err) + return string(data) +} + +func toolError(message string) *mcp.CallToolResult { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: message}}, + IsError: true, + } +} + +type recordingStorage struct { + mu sync.Mutex + schema *ast.Document + setSchemaCalls int + bundle string + bundleSessionID string +} + +func newRecordingStorage() *recordingStorage { + return &recordingStorage{bundle: "declare const tools: {};"} +} + +func (s *recordingStorage) Append(_ context.Context, _ string, ops []storage.SessionOp) ([]storage.SessionOp, error) { + return ops, nil +} + +func (s *recordingStorage) GetOp(context.Context, string, string) (storage.SessionOp, bool, error) { + return storage.SessionOp{}, false, nil +} + +func (s *recordingStorage) ListNames(context.Context, string) ([]string, error) { + return nil, nil +} + +func (s *recordingStorage) Bundle(_ context.Context, sessionID string) (string, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.bundleSessionID = sessionID + return s.bundle, nil +} + +func (s *recordingStorage) Reset(context.Context, string) error { + return nil +} + +func (s *recordingStorage) SetSchema(schema *ast.Document) { + s.mu.Lock() + defer s.mu.Unlock() + s.schema = schema + s.setSchemaCalls++ +} + +func (s *recordingStorage) Schema() *ast.Document { + s.mu.Lock() + defer s.mu.Unlock() + return s.schema +} + +func (s *recordingStorage) Start(context.Context) error { + return nil +} + +func (s *recordingStorage) Stop() error { + return nil +} + +func (s *recordingStorage) lastBundleSessionID() string { + s.mu.Lock() + defer s.mu.Unlock() + return s.bundleSessionID +} diff --git a/router/internal/codemode/server/session.go b/router/internal/codemode/server/session.go new file mode 100644 index 0000000000..ed2dcaeba5 --- /dev/null +++ b/router/internal/codemode/server/session.go @@ -0,0 +1,34 @@ +package server + +import ( + "context" + "net/http" +) + +const mcpSessionIDHeader = "Mcp-Session-Id" + +type sessionIDContextKey struct{} + +// SessionIDFromContext returns the MCP Streamable-HTTP session ID stored on ctx. +// An empty value is meaningful: it indicates stateless mode or a request without +// Mcp-Session-Id, and callers must not synthesize a replacement. +func SessionIDFromContext(ctx context.Context) string { + id, _ := ctx.Value(sessionIDContextKey{}).(string) + return id +} + +// WithSessionID stores id on ctx for Code Mode handlers. +func WithSessionID(ctx context.Context, id string) context.Context { + return context.WithValue(ctx, sessionIDContextKey{}, id) +} + +// withSessionIDFromRequest reads Mcp-Session-Id directly from the HTTP request. +// The modelcontextprotocol/go-sdk exposes transport headers to MCP handlers as +// req.Extra.Header; handlers call WithSessionID(ctx, req.Extra.Header.Get(...)). +// This helper is used for HTTP middleware/tests where the raw request is known. +func withSessionIDFromRequest(ctx context.Context, req *http.Request) context.Context { + if req == nil { + return WithSessionID(ctx, "") + } + return WithSessionID(ctx, req.Header.Get(mcpSessionIDHeader)) +} diff --git a/router/internal/codemode/storage/memory_backend.go b/router/internal/codemode/storage/memory_backend.go new file mode 100644 index 0000000000..7467d17635 --- /dev/null +++ b/router/internal/codemode/storage/memory_backend.go @@ -0,0 +1,354 @@ +package storage + +import ( + "context" + "errors" + "sort" + "sync" + "sync/atomic" + "time" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" +) + +const ( + defaultSessionTTL = 30 * time.Minute + defaultMaxSessions = 1_000 + defaultMaxBundleBytes = 1 << 20 +) + +type MemoryConfig struct { + SessionTTL time.Duration + MaxSessions int + MaxBundleBytes int + Renderer Renderer + Now func() time.Time +} + +type MemoryBackend struct { + sessionTTL time.Duration + maxSessions int + maxBundleBytes int + renderer Renderer + now func() time.Time + + sessions sync.Map + + schemaMu sync.RWMutex + schema *ast.Document + + schemaVer atomic.Uint64 + + lifecycleMu sync.Mutex + cancel context.CancelFunc + done chan struct{} +} + +type memSession struct { + mu sync.Mutex + ops []SessionOp + lastUsed time.Time + bundle string + bundleValid bool +} + +type sessionSnapshot struct { + id string + lastUsed time.Time +} + +func NewMemoryBackend(config MemoryConfig) *MemoryBackend { + if config.SessionTTL <= 0 { + config.SessionTTL = defaultSessionTTL + } + if config.MaxSessions <= 0 { + config.MaxSessions = defaultMaxSessions + } + if config.MaxBundleBytes < 0 { + config.MaxBundleBytes = 0 + } + if config.MaxBundleBytes == 0 { + config.MaxBundleBytes = defaultMaxBundleBytes + } + if config.Now == nil { + config.Now = time.Now + } + + return &MemoryBackend{ + sessionTTL: config.SessionTTL, + maxSessions: config.MaxSessions, + maxBundleBytes: config.MaxBundleBytes, + renderer: config.Renderer, + now: config.Now, + } +} + +func (b *MemoryBackend) Append(ctx context.Context, sessionID string, ops []SessionOp) ([]SessionOp, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + if len(ops) == 0 { + return nil, nil + } + + session := b.loadOrCreateSession(sessionID) + session.mu.Lock() + appended := make([]SessionOp, 0, len(ops)) + taken := make(map[string]struct{}, len(session.ops)+len(ops)) + for _, op := range session.ops { + taken[op.Name] = struct{}{} + } + for _, op := range ops { + op.Name = SuffixedName(NormalizeName(op.Name), taken) + taken[op.Name] = struct{}{} + session.ops = append(session.ops, op) + appended = append(appended, op) + } + session.lastUsed = b.now() + session.bundle = "" + session.bundleValid = false + session.mu.Unlock() + + b.enforceMaxSessions() + return appended, nil +} + +func (b *MemoryBackend) GetOp(ctx context.Context, sessionID string, name string) (SessionOp, bool, error) { + if err := ctx.Err(); err != nil { + return SessionOp{}, false, err + } + + value, ok := b.sessions.Load(sessionID) + if !ok { + return SessionOp{}, false, nil + } + session := value.(*memSession) + session.mu.Lock() + defer session.mu.Unlock() + + session.lastUsed = b.now() + for _, op := range session.ops { + if op.Name == name { + return op, true, nil + } + } + return SessionOp{}, false, nil +} + +func (b *MemoryBackend) ListNames(ctx context.Context, sessionID string) ([]string, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + + value, ok := b.sessions.Load(sessionID) + if !ok { + return nil, nil + } + session := value.(*memSession) + session.mu.Lock() + defer session.mu.Unlock() + + session.lastUsed = b.now() + names := make([]string, 0, len(session.ops)) + for _, op := range session.ops { + names = append(names, op.Name) + } + return names, nil +} + +func (b *MemoryBackend) Bundle(ctx context.Context, sessionID string) (string, error) { + if err := ctx.Err(); err != nil { + return "", err + } + + value, ok := b.sessions.Load(sessionID) + if !ok { + return b.renderCapped(ctx, nil) + } + session := value.(*memSession) + + session.mu.Lock() + defer session.mu.Unlock() + + session.lastUsed = b.now() + if session.bundleValid { + return session.bundle, nil + } + + if b.renderer == nil { + return "", errors.New("code mode storage renderer is not configured") + } + + ops := append([]SessionOp(nil), session.ops...) + bundle, err := b.renderCapped(ctx, ops) + if err != nil { + return "", err + } + + session.bundle = bundle + session.bundleValid = true + return bundle, nil +} + +func (b *MemoryBackend) Reset(ctx context.Context, sessionID string) error { + if err := ctx.Err(); err != nil { + return err + } + b.sessions.Delete(sessionID) + return nil +} + +func (b *MemoryBackend) SetSchema(schema *ast.Document) { + b.schemaMu.Lock() + b.schema = schema + b.schemaMu.Unlock() + + b.schemaVer.Add(1) + b.clearSessions() +} + +func (b *MemoryBackend) Schema() *ast.Document { + b.schemaMu.RLock() + defer b.schemaMu.RUnlock() + return b.schema +} + +func (b *MemoryBackend) SchemaVersion() uint64 { + return b.schemaVer.Load() +} + +func (b *MemoryBackend) Start(ctx context.Context) error { + b.lifecycleMu.Lock() + defer b.lifecycleMu.Unlock() + + if b.cancel != nil { + return nil + } + + runCtx, cancel := context.WithCancel(ctx) + b.cancel = cancel + b.done = make(chan struct{}) + go b.runSweeper(runCtx, b.done) + return nil +} + +func (b *MemoryBackend) Stop() error { + b.lifecycleMu.Lock() + cancel := b.cancel + done := b.done + b.cancel = nil + b.done = nil + b.lifecycleMu.Unlock() + + if cancel == nil { + return nil + } + cancel() + <-done + return nil +} + +func (b *MemoryBackend) loadOrCreateSession(sessionID string) *memSession { + now := b.now() + session := &memSession{lastUsed: now} + value, _ := b.sessions.LoadOrStore(sessionID, session) + return value.(*memSession) +} + +func (b *MemoryBackend) renderCapped(ctx context.Context, ops []SessionOp) (string, error) { + bundle, err := b.renderer.Render(ctx, ops, b.Schema()) + if err != nil { + return "", err + } + if b.maxBundleBytes <= 0 || len(bundle) <= b.maxBundleBytes { + return bundle, nil + } + + for keep := len(ops) - 1; keep >= 0; keep-- { + if err := ctx.Err(); err != nil { + return "", err + } + truncated, err := b.renderer.Render(ctx, ops[:keep], b.Schema()) + if err != nil { + return "", err + } + if len(truncated) <= b.maxBundleBytes { + return truncated, nil + } + } + return "", nil +} + +func (b *MemoryBackend) runSweeper(ctx context.Context, done chan<- struct{}) { + defer close(done) + + interval := b.sessionTTL / 4 + if interval <= 0 { + interval = time.Second + } + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + b.sweepIdle() + b.enforceMaxSessions() + } + } +} + +func (b *MemoryBackend) sweepIdle() { + if b.sessionTTL <= 0 { + return + } + + cutoff := b.now().Add(-b.sessionTTL) + b.sessions.Range(func(key, value any) bool { + session := value.(*memSession) + session.mu.Lock() + expired := !session.lastUsed.After(cutoff) + session.mu.Unlock() + if expired { + b.sessions.Delete(key) + } + return true + }) +} + +func (b *MemoryBackend) enforceMaxSessions() { + if b.maxSessions <= 0 { + return + } + + snapshots := make([]sessionSnapshot, 0) + b.sessions.Range(func(key, value any) bool { + session := value.(*memSession) + session.mu.Lock() + snapshots = append(snapshots, sessionSnapshot{id: key.(string), lastUsed: session.lastUsed}) + session.mu.Unlock() + return true + }) + if len(snapshots) <= b.maxSessions { + return + } + + sort.Slice(snapshots, func(i, j int) bool { + if snapshots[i].lastUsed.Equal(snapshots[j].lastUsed) { + return snapshots[i].id < snapshots[j].id + } + return snapshots[i].lastUsed.Before(snapshots[j].lastUsed) + }) + for _, snapshot := range snapshots[:len(snapshots)-b.maxSessions] { + b.sessions.Delete(snapshot.id) + } +} + +func (b *MemoryBackend) clearSessions() { + b.sessions.Range(func(key, _ any) bool { + b.sessions.Delete(key) + return true + }) +} diff --git a/router/internal/codemode/storage/memory_backend_test.go b/router/internal/codemode/storage/memory_backend_test.go new file mode 100644 index 0000000000..662816ce64 --- /dev/null +++ b/router/internal/codemode/storage/memory_backend_test.go @@ -0,0 +1,332 @@ +package storage + +import ( + "context" + "fmt" + "sort" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" +) + +type testClock struct { + mu sync.Mutex + now time.Time +} + +func newTestClock() *testClock { + return &testClock{now: time.Unix(1_700_000_000, 0).UTC()} +} + +func (c *testClock) Now() time.Time { + c.mu.Lock() + defer c.mu.Unlock() + return c.now +} + +func (c *testClock) Advance(d time.Duration) { + c.mu.Lock() + defer c.mu.Unlock() + c.now = c.now.Add(d) +} + +func newTestBackend(t *testing.T, clock *testClock, renderer Renderer) *MemoryBackend { + t.Helper() + + if renderer == nil { + renderer = RendererFunc(func(ops []SessionOp) (string, error) { + names := make([]string, 0, len(ops)) + for _, op := range ops { + names = append(names, op.Name) + } + return strings.Join(names, "\n"), nil + }) + } + + backend := NewMemoryBackend(MemoryConfig{ + SessionTTL: time.Hour, + MaxSessions: 100, + MaxBundleBytes: 1 << 20, + Renderer: renderer, + Now: clock.Now, + }) + require.NoError(t, backend.Start(context.Background())) + t.Cleanup(func() { + require.NoError(t, backend.Stop()) + }) + + return backend +} + +func TestMemoryBackendAppendGetOpBundleResetRoundTrip(t *testing.T) { + ctx := context.Background() + clock := newTestClock() + backend := newTestBackend(t, clock, nil) + + ops := []SessionOp{ + {Name: "get-user", Body: "query GetUser { user { id } }", Kind: OperationKindQuery, Description: "Fetch a user"}, + {Name: "delete", Body: "mutation DeleteUser { deleteUser(id: 1) }", Kind: OperationKindMutation, Description: "Delete a user"}, + {Name: "get-user", Body: "query GetUserAgain { user { name } }", Kind: OperationKindQuery, Description: "Fetch user name"}, + } + + appended, err := backend.Append(ctx, "session-1", ops) + require.NoError(t, err) + assert.Equal(t, []SessionOp{ + {Name: "getUser", Body: "query GetUser { user { id } }", Kind: OperationKindQuery, Description: "Fetch a user"}, + {Name: "op_delete", Body: "mutation DeleteUser { deleteUser(id: 1) }", Kind: OperationKindMutation, Description: "Delete a user"}, + {Name: "getUser_2", Body: "query GetUserAgain { user { name } }", Kind: OperationKindQuery, Description: "Fetch user name"}, + }, appended) + + gotQuery, ok, err := backend.GetOp(ctx, "session-1", "getUser") + require.NoError(t, err) + assert.Equal(t, true, ok) + assert.Equal(t, SessionOp{Name: "getUser", Body: "query GetUser { user { id } }", Kind: OperationKindQuery, Description: "Fetch a user"}, gotQuery) + + gotMutation, ok, err := backend.GetOp(ctx, "session-1", "op_delete") + require.NoError(t, err) + assert.Equal(t, true, ok) + assert.Equal(t, SessionOp{Name: "op_delete", Body: "mutation DeleteUser { deleteUser(id: 1) }", Kind: OperationKindMutation, Description: "Delete a user"}, gotMutation) + + gotCollision, ok, err := backend.GetOp(ctx, "session-1", "getUser_2") + require.NoError(t, err) + assert.Equal(t, true, ok) + assert.Equal(t, SessionOp{Name: "getUser_2", Body: "query GetUserAgain { user { name } }", Kind: OperationKindQuery, Description: "Fetch user name"}, gotCollision) + + bundle, err := backend.Bundle(ctx, "session-1") + require.NoError(t, err) + assert.Equal(t, "getUser\nop_delete\ngetUser_2", bundle) + + require.NoError(t, backend.Reset(ctx, "session-1")) + gotAfterReset, ok, err := backend.GetOp(ctx, "session-1", "getUser") + require.NoError(t, err) + assert.Equal(t, false, ok) + assert.Equal(t, SessionOp{}, gotAfterReset) + + bundleAfterReset, err := backend.Bundle(ctx, "session-1") + require.NoError(t, err) + assert.Equal(t, "", bundleAfterReset) +} + +func TestMemoryBackendSetSchemaClearsSessionsAndIncrementsSchemaVersion(t *testing.T) { + ctx := context.Background() + clock := newTestClock() + backend := newTestBackend(t, clock, nil) + initialVersion := backend.SchemaVersion() + + _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: "get-user", Body: "query { user { id } }", Kind: OperationKindQuery}}) + require.NoError(t, err) + schema := &ast.Document{} + + backend.SetSchema(schema) + + assert.Equal(t, initialVersion+1, backend.SchemaVersion()) + assert.Equal(t, schema, backend.Schema()) + + got, ok, err := backend.GetOp(ctx, "session-1", "getUser") + require.NoError(t, err) + assert.Equal(t, false, ok) + assert.Equal(t, SessionOp{}, got) + + backend.SetSchema(&ast.Document{}) + assert.Equal(t, initialVersion+2, backend.SchemaVersion()) +} + +func TestMemoryBackendTTLEvictionUsesInjectedClock(t *testing.T) { + ctx := context.Background() + clock := newTestClock() + backend := NewMemoryBackend(MemoryConfig{ + SessionTTL: time.Minute, + MaxSessions: 100, + MaxBundleBytes: 1 << 20, + Renderer: RendererFunc(func(ops []SessionOp) (string, error) { return "", nil }), + Now: clock.Now, + }) + + _, err := backend.Append(ctx, "idle", []SessionOp{{Name: "idle-op", Body: "query { idle }", Kind: OperationKindQuery}}) + require.NoError(t, err) + _, err = backend.Append(ctx, "fresh", []SessionOp{{Name: "fresh-op", Body: "query { fresh }", Kind: OperationKindQuery}}) + require.NoError(t, err) + clock.Advance(30 * time.Second) + _, ok, err := backend.GetOp(ctx, "fresh", "freshOp") + require.NoError(t, err) + assert.Equal(t, true, ok) + + clock.Advance(31 * time.Second) + backend.sweepIdle() + + _, idleOK, err := backend.GetOp(ctx, "idle", "idleOp") + require.NoError(t, err) + assert.Equal(t, false, idleOK) + + _, freshOK, err := backend.GetOp(ctx, "fresh", "freshOp") + require.NoError(t, err) + assert.Equal(t, true, freshOK) +} + +func TestMemoryBackendLRUEvictionAtMaxSessions(t *testing.T) { + ctx := context.Background() + clock := newTestClock() + backend := NewMemoryBackend(MemoryConfig{ + SessionTTL: time.Hour, + MaxSessions: 2, + MaxBundleBytes: 1 << 20, + Renderer: RendererFunc(func(ops []SessionOp) (string, error) { return "", nil }), + Now: clock.Now, + }) + + _, err := backend.Append(ctx, "session-a", []SessionOp{{Name: "a-op", Body: "query { a }", Kind: OperationKindQuery}}) + require.NoError(t, err) + clock.Advance(time.Second) + _, err = backend.Append(ctx, "session-b", []SessionOp{{Name: "b-op", Body: "query { b }", Kind: OperationKindQuery}}) + require.NoError(t, err) + clock.Advance(time.Second) + _, ok, err := backend.GetOp(ctx, "session-a", "aOp") + require.NoError(t, err) + assert.Equal(t, true, ok) + clock.Advance(time.Second) + + _, err = backend.Append(ctx, "session-c", []SessionOp{{Name: "c-op", Body: "query { c }", Kind: OperationKindQuery}}) + require.NoError(t, err) + + _, aOK, err := backend.GetOp(ctx, "session-a", "aOp") + require.NoError(t, err) + assert.Equal(t, true, aOK) + + _, bOK, err := backend.GetOp(ctx, "session-b", "bOp") + require.NoError(t, err) + assert.Equal(t, false, bOK) + + _, cOK, err := backend.GetOp(ctx, "session-c", "cOp") + require.NoError(t, err) + assert.Equal(t, true, cOK) +} + +func TestMemoryBackendConcurrentAppendIsRaceFreeAndSuffixesNames(t *testing.T) { + ctx := context.Background() + clock := newTestClock() + backend := newTestBackend(t, clock, nil) + + const goroutines = 32 + var wg sync.WaitGroup + errs := make(chan error, goroutines) + + for i := range goroutines { + wg.Add(1) + go func(i int) { + defer wg.Done() + _, err := backend.Append(ctx, "shared", []SessionOp{{ + Name: "shared-op", + Body: fmt.Sprintf("query Shared%d { shared%d }", i, i), + Kind: OperationKindQuery, + Description: fmt.Sprintf("Shared %d", i), + }}) + errs <- err + }(i) + } + + wg.Wait() + close(errs) + + for err := range errs { + require.NoError(t, err) + } + + names := make([]string, 0, goroutines) + for i := range goroutines { + name := "sharedOp" + if i > 0 { + name = fmt.Sprintf("sharedOp_%d", i+1) + } + op, ok, err := backend.GetOp(ctx, "shared", name) + require.NoError(t, err) + assert.Equal(t, true, ok) + names = append(names, op.Name) + } + + sort.Strings(names) + want := make([]string, 0, goroutines) + for i := range goroutines { + name := "sharedOp" + if i > 0 { + name = fmt.Sprintf("sharedOp_%d", i+1) + } + want = append(want, name) + } + sort.Strings(want) + assert.Equal(t, want, names) +} + +func TestMemoryBackendBundleCacheInvalidatesOnAppend(t *testing.T) { + ctx := context.Background() + clock := newTestClock() + var mu sync.Mutex + rendered := make([]string, 0, 3) + renderer := RendererFunc(func(ops []SessionOp) (string, error) { + names := make([]string, 0, len(ops)) + for _, op := range ops { + names = append(names, op.Name) + } + bundle := strings.Join(names, ",") + mu.Lock() + rendered = append(rendered, bundle) + mu.Unlock() + return bundle, nil + }) + backend := newTestBackend(t, clock, renderer) + + _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: "one", Body: "query { one }", Kind: OperationKindQuery}}) + require.NoError(t, err) + first, err := backend.Bundle(ctx, "session-1") + require.NoError(t, err) + assert.Equal(t, "one", first) + + second, err := backend.Bundle(ctx, "session-1") + require.NoError(t, err) + assert.Equal(t, "one", second) + + _, err = backend.Append(ctx, "session-1", []SessionOp{{Name: "two", Body: "query { two }", Kind: OperationKindQuery}}) + require.NoError(t, err) + third, err := backend.Bundle(ctx, "session-1") + require.NoError(t, err) + assert.Equal(t, "one,two", third) + + mu.Lock() + gotRendered := append([]string(nil), rendered...) + mu.Unlock() + assert.Equal(t, []string{"one", "one,two"}, gotRendered) +} + +func TestMemoryBackendBundleDropsWholeOpsAtMaxBundleBytes(t *testing.T) { + ctx := context.Background() + clock := newTestClock() + renderer := RendererFunc(func(ops []SessionOp) (string, error) { + names := make([]string, 0, len(ops)) + for _, op := range ops { + names = append(names, op.Name) + } + return strings.Join(names, "|"), nil + }) + backend := NewMemoryBackend(MemoryConfig{ + SessionTTL: time.Hour, + MaxSessions: 100, + MaxBundleBytes: len("one|two"), + Renderer: renderer, + Now: clock.Now, + }) + + _, err := backend.Append(ctx, "session-1", []SessionOp{ + {Name: "one", Body: "query { one }", Kind: OperationKindQuery}, + {Name: "two", Body: "query { two }", Kind: OperationKindQuery}, + {Name: "three", Body: "query { three }", Kind: OperationKindQuery}, + }) + require.NoError(t, err) + + bundle, err := backend.Bundle(ctx, "session-1") + require.NoError(t, err) + assert.Equal(t, "one|two", bundle) +} diff --git a/router/internal/codemode/storage/naming.go b/router/internal/codemode/storage/naming.go new file mode 100644 index 0000000000..a3b91eadc9 --- /dev/null +++ b/router/internal/codemode/storage/naming.go @@ -0,0 +1,191 @@ +package storage + +import ( + "slices" + "strconv" + "strings" + "unicode" +) + +var reservedWords = map[string]struct{}{ + "abstract": {}, + "any": {}, + "as": {}, + "async": {}, + "await": {}, + "boolean": {}, + "break": {}, + "case": {}, + "catch": {}, + "class": {}, + "const": {}, + "constructor": {}, + "continue": {}, + "debugger": {}, + "declare": {}, + "default": {}, + "delete": {}, + "do": {}, + "else": {}, + "enum": {}, + "export": {}, + "extends": {}, + "false": {}, + "finally": {}, + "for": {}, + "from": {}, + "function": {}, + "get": {}, + "if": {}, + "implements": {}, + "import": {}, + "in": {}, + "infer": {}, + "instanceof": {}, + "interface": {}, + "is": {}, + "keyof": {}, + "let": {}, + "module": {}, + "namespace": {}, + "never": {}, + "new": {}, + "null": {}, + "number": {}, + "object": {}, + "of": {}, + "package": {}, + "private": {}, + "protected": {}, + "public": {}, + "readonly": {}, + "require": {}, + "return": {}, + "satisfies": {}, + "set": {}, + "static": {}, + "string": {}, + "super": {}, + "switch": {}, + "symbol": {}, + "this": {}, + "throw": {}, + "true": {}, + "try": {}, + "type": {}, + "typeof": {}, + "undefined": {}, + "unique": {}, + "unknown": {}, + "var": {}, + "void": {}, + "while": {}, + "with": {}, + "yield": {}, +} + +func NormalizeName(raw string) string { + // Idempotency: names produced by an earlier NormalizeName call (carrying our reserved-word + // or leading-digit prefixes) round-trip without re-splitting. + if rest, ok := strings.CutPrefix(raw, "op_"); ok { + if _, reserved := reservedWords[rest]; reserved && isLowerCamel(rest) { + return raw + } + } + if rest, ok := strings.CutPrefix(raw, "_"); ok { + if len(rest) > 0 && unicode.IsDigit(rune(rest[0])) && isIdentTail(rest) { + return raw + } + } + words := strings.FieldsFunc(raw, func(r rune) bool { + return !unicode.IsLetter(r) && !unicode.IsDigit(r) + }) + words = slices.DeleteFunc(words, func(word string) bool { + return word == "" + }) + if len(words) == 0 { + return "operation" + } + + var builder strings.Builder + for i, word := range words { + if i == 0 { + builder.WriteString(lowerFirst(word)) + continue + } + builder.WriteString(upperFirst(word)) + } + + name := builder.String() + if name == "" { + name = "operation" + } + if first, _ := firstRune(name); unicode.IsDigit(first) { + name = "_" + name + } + if _, ok := reservedWords[name]; ok { + name = "op_" + name + } + return name +} + +func SuffixedName(base string, taken map[string]struct{}) string { + if _, ok := taken[base]; !ok { + return base + } + for i := 2; ; i++ { + name := base + "_" + strconv.Itoa(i) + if _, ok := taken[name]; !ok { + return name + } + } +} + +func lowerFirst(value string) string { + if value == "" { + return value + } + runes := []rune(value) + runes[0] = unicode.ToLower(runes[0]) + return string(runes) +} + +func upperFirst(value string) string { + if value == "" { + return value + } + runes := []rune(strings.ToLower(value)) + runes[0] = unicode.ToUpper(runes[0]) + return string(runes) +} + +func isLowerCamel(value string) bool { + if value == "" { + return false + } + for i, r := range value { + if i == 0 && !unicode.IsLower(r) { + return false + } + if !unicode.IsLetter(r) && !unicode.IsDigit(r) { + return false + } + } + return true +} + +func isIdentTail(value string) bool { + for _, r := range value { + if !unicode.IsLetter(r) && !unicode.IsDigit(r) { + return false + } + } + return true +} + +func firstRune(value string) (rune, bool) { + for _, r := range value { + return r, true + } + return 0, false +} diff --git a/router/internal/codemode/storage/naming_test.go b/router/internal/codemode/storage/naming_test.go new file mode 100644 index 0000000000..10215f9730 --- /dev/null +++ b/router/internal/codemode/storage/naming_test.go @@ -0,0 +1,84 @@ +package storage + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNormalizeName(t *testing.T) { + tests := []struct { + name string + raw string + want string + }{ + {name: "kebab case", raw: "get-user-by-id", want: "getUserById"}, + {name: "snake case", raw: "get_user_by_id", want: "getUserById"}, + {name: "space separated", raw: "Get User By ID", want: "getUserById"}, + {name: "mixed separators", raw: "get__user--by id", want: "getUserById"}, + {name: "already camel", raw: "getUserById", want: "getUserById"}, + {name: "leading digit", raw: "123foo", want: "_123foo"}, + {name: "leading digit with separators", raw: "123-foo-bar", want: "_123FooBar"}, + {name: "reserved word", raw: "delete", want: "op_delete"}, + {name: "reserved word after normalization", raw: "class", want: "op_class"}, + {name: "invalid punctuation", raw: "get$user#by%id", want: "getUserById"}, + {name: "empty input", raw: "", want: "operation"}, + {name: "only invalid input", raw: "$$$", want: "operation"}, + {name: "underscore output for reserved word is not rechecked", raw: "op-delete", want: "opDelete"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, NormalizeName(tt.raw)) + }) + } +} + +func TestSuffixedName(t *testing.T) { + tests := []struct { + name string + base string + taken map[string]struct{} + want string + }{ + { + name: "first use keeps base", + base: "getUser", + taken: map[string]struct{}{}, + want: "getUser", + }, + { + name: "first collision uses suffix two", + base: "getUser", + taken: map[string]struct{}{ + "getUser": {}, + }, + want: "getUser_2", + }, + { + name: "skips occupied suffixes", + base: "getUser", + taken: map[string]struct{}{ + "getUser": {}, + "getUser_2": {}, + "getUser_3": {}, + }, + want: "getUser_4", + }, + { + name: "gap is reused", + base: "getUser", + taken: map[string]struct{}{ + "getUser": {}, + "getUser_3": {}, + }, + want: "getUser_2", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, SuffixedName(tt.base, tt.taken)) + }) + } +} diff --git a/router/internal/codemode/storage/redis_backend.go b/router/internal/codemode/storage/redis_backend.go new file mode 100644 index 0000000000..90e6e66883 --- /dev/null +++ b/router/internal/codemode/storage/redis_backend.go @@ -0,0 +1,362 @@ +package storage + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/redis/go-redis/v9" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "go.uber.org/zap" +) + +const defaultRedisKeyPrefix = "cosmo_code_mode" + +var _ SessionStorage = (*RedisBackend)(nil) + +type RedisConfig struct { + Client redis.UniversalClient + KeyPrefix string + SessionTTL time.Duration + Renderer Renderer + Logger *zap.Logger + Now func() time.Time +} + +type RedisBackend struct { + client redis.UniversalClient + keyPrefix string + sessionTTL time.Duration + renderer Renderer + logger *zap.Logger + now func() time.Time + + schemaMu sync.RWMutex + schema *ast.Document + schemaVer atomic.Uint64 +} + +type redisOpEntry struct { + SessionOp + LastUsed time.Time `json:"last_used"` +} + +type redisBundleEntry struct { + Bundle string `json:"bundle"` + SchemaVer uint64 `json:"schema_ver"` + RenderedAt time.Time `json:"rendered_at"` +} + +func NewRedisBackend(cfg RedisConfig) (*RedisBackend, error) { + if cfg.Client == nil { + return nil, errors.New("code mode redis storage client is not configured") + } + if cfg.KeyPrefix == "" { + cfg.KeyPrefix = defaultRedisKeyPrefix + } + if cfg.SessionTTL <= 0 { + cfg.SessionTTL = defaultSessionTTL + } + if cfg.Logger == nil { + cfg.Logger = zap.NewNop() + } + if cfg.Now == nil { + cfg.Now = time.Now + } + + return &RedisBackend{ + client: cfg.Client, + keyPrefix: cfg.KeyPrefix, + sessionTTL: cfg.SessionTTL, + renderer: cfg.Renderer, + logger: cfg.Logger, + now: cfg.Now, + }, nil +} + +func (b *RedisBackend) Append(ctx context.Context, sessionID string, ops []SessionOp) ([]SessionOp, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + if len(ops) == 0 { + return nil, nil + } + + backoff := 5 * time.Millisecond + var appended []SessionOp + for { + if err := ctx.Err(); err != nil { + return nil, err + } + + opsKey := b.opsKey(sessionID) + bundleKey := b.bundleKey(sessionID) + now := b.now() + err := b.client.Watch(ctx, func(tx *redis.Tx) error { + entries, err := b.readOps(ctx, tx, opsKey) + if err != nil { + return err + } + + taken := make(map[string]struct{}, len(entries)+len(ops)) + for _, entry := range entries { + taken[entry.Name] = struct{}{} + } + appended = make([]SessionOp, 0, len(ops)) + for _, op := range ops { + op.Name = SuffixedName(NormalizeName(op.Name), taken) + taken[op.Name] = struct{}{} + entries = append(entries, redisOpEntry{ + SessionOp: op, + LastUsed: now, + }) + appended = append(appended, op) + } + payload, err := json.Marshal(entries) + if err != nil { + return err + } + + _, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error { + pipe.Set(ctx, opsKey, payload, 0) + pipe.Expire(ctx, opsKey, b.sessionTTL) + pipe.Del(ctx, bundleKey) + return nil + }) + return err + }, opsKey) + if err == nil { + return appended, nil + } + + b.logger.Debug("retrying code mode redis append", + zap.String("session_id", sessionID), + zap.Error(err), + ) + if err := sleepWithContext(ctx, backoff); err != nil { + return nil, err + } + backoff *= 2 + if backoff > 100*time.Millisecond { + backoff = 100 * time.Millisecond + } + } +} + +func (b *RedisBackend) GetOp(ctx context.Context, sessionID string, name string) (SessionOp, bool, error) { + if err := ctx.Err(); err != nil { + return SessionOp{}, false, err + } + + opsKey := b.opsKey(sessionID) + entries, err := b.readOps(ctx, b.client, opsKey) + if err != nil { + return SessionOp{}, false, err + } + + for i, entry := range entries { + if entry.Name != name { + continue + } + entries[i].LastUsed = b.now() + b.touchOpBestEffort(ctx, opsKey, name) + return entry.SessionOp, true, nil + } + return SessionOp{}, false, nil +} + +func (b *RedisBackend) ListNames(ctx context.Context, sessionID string) ([]string, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + + entries, err := b.readOps(ctx, b.client, b.opsKey(sessionID)) + if err != nil { + return nil, err + } + + names := make([]string, 0, len(entries)) + for _, entry := range entries { + names = append(names, entry.Name) + } + return names, nil +} + +func (b *RedisBackend) Bundle(ctx context.Context, sessionID string) (string, error) { + if err := ctx.Err(); err != nil { + return "", err + } + + bundleKey := b.bundleKey(sessionID) + cached, err := b.client.Get(ctx, bundleKey).Bytes() + if err == nil { + var entry redisBundleEntry + if err := json.Unmarshal(cached, &entry); err != nil { + return "", fmt.Errorf("decode code mode redis bundle: %w", err) + } + if entry.SchemaVer == b.SchemaVersion() { + return entry.Bundle, nil + } + } else if !errors.Is(err, redis.Nil) { + return "", err + } + + opsKey := b.opsKey(sessionID) + entries, err := b.readOps(ctx, b.client, opsKey) + if err != nil { + return "", err + } + if len(entries) == 0 { + if b.renderer == nil { + return "", errors.New("code mode storage renderer is not configured") + } + return b.renderer.Render(ctx, nil, b.Schema()) + } + if b.renderer == nil { + return "", errors.New("code mode storage renderer is not configured") + } + + ops := make([]SessionOp, 0, len(entries)) + for _, entry := range entries { + ops = append(ops, entry.SessionOp) + } + bundle, err := b.renderer.Render(ctx, ops, b.Schema()) + if err != nil { + return "", err + } + + payload, err := json.Marshal(redisBundleEntry{ + Bundle: bundle, + SchemaVer: b.SchemaVersion(), + RenderedAt: b.now(), + }) + if err != nil { + return "", err + } + if err := b.setWithTTL(ctx, bundleKey, payload); err != nil { + b.logger.Warn("failed to cache code mode redis bundle", + zap.String("session_id", sessionID), + zap.Error(err), + ) + } + return bundle, nil +} + +func (b *RedisBackend) Reset(ctx context.Context, sessionID string) error { + if err := ctx.Err(); err != nil { + return err + } + return b.client.Del(ctx, b.opsKey(sessionID), b.bundleKey(sessionID)).Err() +} + +func (b *RedisBackend) SetSchema(schema *ast.Document) { + b.schemaMu.Lock() + b.schema = schema + b.schemaMu.Unlock() + + b.schemaVer.Add(1) +} + +func (b *RedisBackend) Schema() *ast.Document { + b.schemaMu.RLock() + defer b.schemaMu.RUnlock() + return b.schema +} + +func (b *RedisBackend) SchemaVersion() uint64 { + return b.schemaVer.Load() +} + +func (b *RedisBackend) Start(context.Context) error { + return nil +} + +func (b *RedisBackend) Stop() error { + return nil +} + +func (b *RedisBackend) opsKey(sessionID string) string { + return fmt.Sprintf("%s:s:%d:%s:ops", b.keyPrefix, b.SchemaVersion(), sessionID) +} + +func (b *RedisBackend) bundleKey(sessionID string) string { + return fmt.Sprintf("%s:s:%d:%s:bundle", b.keyPrefix, b.SchemaVersion(), sessionID) +} + +type redisStringGetter interface { + Get(context.Context, string) *redis.StringCmd +} + +func (b *RedisBackend) readOps(ctx context.Context, getter redisStringGetter, key string) ([]redisOpEntry, error) { + raw, err := getter.Get(ctx, key).Bytes() + if errors.Is(err, redis.Nil) { + return nil, nil + } + if err != nil { + return nil, err + } + + var entries []redisOpEntry + if err := json.Unmarshal(raw, &entries); err != nil { + return nil, fmt.Errorf("decode code mode redis ops: %w", err) + } + return entries, nil +} + +func (b *RedisBackend) touchOpBestEffort(ctx context.Context, key string, name string) { + err := b.client.Watch(ctx, func(tx *redis.Tx) error { + entries, err := b.readOps(ctx, tx, key) + if err != nil { + return err + } + + found := false + for i := range entries { + if entries[i].Name == name { + entries[i].LastUsed = b.now() + found = true + break + } + } + if !found { + return nil + } + + payload, err := json.Marshal(entries) + if err != nil { + return err + } + _, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error { + pipe.Set(ctx, key, payload, 0) + pipe.Expire(ctx, key, b.sessionTTL) + return nil + }) + return err + }, key) + if err != nil && !errors.Is(err, redis.TxFailedErr) { + b.logger.Warn("failed to update code mode redis op last_used", zap.Error(err)) + } +} + +func (b *RedisBackend) setWithTTL(ctx context.Context, key string, value []byte) error { + if err := b.client.Set(ctx, key, value, 0).Err(); err != nil { + return err + } + return b.client.Expire(ctx, key, b.sessionTTL).Err() +} + +func sleepWithContext(ctx context.Context, duration time.Duration) error { + timer := time.NewTimer(duration) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} diff --git a/router/internal/codemode/storage/redis_backend_test.go b/router/internal/codemode/storage/redis_backend_test.go new file mode 100644 index 0000000000..3bb736353c --- /dev/null +++ b/router/internal/codemode/storage/redis_backend_test.go @@ -0,0 +1,264 @@ +package storage + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + miniredisserver "github.com/alicebob/miniredis/v2/server" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" +) + +type testRedisRenderer func(context.Context, []SessionOp, *ast.Document) (string, error) + +func (f testRedisRenderer) Render(ctx context.Context, ops []SessionOp, schema *ast.Document) (string, error) { + return f(ctx, ops, schema) +} + +func newTestRedisBackend(t *testing.T, renderer Renderer, ttl time.Duration) (*RedisBackend, *miniredis.Miniredis, *redis.Client) { + t.Helper() + + if renderer == nil { + renderer = testRedisRenderer(func(_ context.Context, ops []SessionOp, _ *ast.Document) (string, error) { + names := make([]string, 0, len(ops)) + for _, op := range ops { + names = append(names, op.Name) + } + return strings.Join(names, "\n"), nil + }) + } + if ttl == 0 { + ttl = time.Hour + } + + mr := miniredis.RunT(t) + client := redis.NewClient(&redis.Options{Addr: mr.Addr()}) + t.Cleanup(func() { + require.NoError(t, client.Close()) + }) + + backend, err := NewRedisBackend(RedisConfig{ + Client: client, + KeyPrefix: "test_code_mode", + SessionTTL: ttl, + Renderer: renderer, + Now: func() time.Time { return time.Unix(1_700_000_000, 0).UTC() }, + }) + require.NoError(t, err) + require.NoError(t, backend.Start(context.Background())) + t.Cleanup(func() { + require.NoError(t, backend.Stop()) + }) + + return backend, mr, client +} + +func TestRedisBackendAppendGetOpRoundTrip(t *testing.T) { + ctx := context.Background() + backend, _, _ := newTestRedisBackend(t, nil, time.Hour) + + ops := []SessionOp{ + {Name: "get-user", Body: "query GetUser { user { id } }", Kind: OperationKindQuery, Description: "Fetch a user"}, + {Name: "op_delete", Body: "mutation DeleteUser { deleteUser(id: 1) }", Kind: OperationKindMutation, Description: "Delete a user"}, + {Name: "get-user", Body: "query GetUserAgain { user { name } }", Kind: OperationKindQuery, Description: "Fetch user name"}, + } + appended, err := backend.Append(ctx, "session-1", ops) + require.NoError(t, err) + assert.Equal(t, []SessionOp{ + {Name: "getUser", Body: "query GetUser { user { id } }", Kind: OperationKindQuery, Description: "Fetch a user"}, + {Name: "op_delete", Body: "mutation DeleteUser { deleteUser(id: 1) }", Kind: OperationKindMutation, Description: "Delete a user"}, + {Name: "getUser_2", Body: "query GetUserAgain { user { name } }", Kind: OperationKindQuery, Description: "Fetch user name"}, + }, appended) + + gotQuery, ok, err := backend.GetOp(ctx, "session-1", "getUser") + require.NoError(t, err) + assert.Equal(t, true, ok) + assert.Equal(t, SessionOp{Name: "getUser", Body: "query GetUser { user { id } }", Kind: OperationKindQuery, Description: "Fetch a user"}, gotQuery) + + gotCollision, ok, err := backend.GetOp(ctx, "session-1", "getUser_2") + require.NoError(t, err) + assert.Equal(t, true, ok) + assert.Equal(t, SessionOp{Name: "getUser_2", Body: "query GetUserAgain { user { name } }", Kind: OperationKindQuery, Description: "Fetch user name"}, gotCollision) + + gotMissing, ok, err := backend.GetOp(ctx, "session-1", "missing") + require.NoError(t, err) + assert.Equal(t, false, ok) + assert.Equal(t, SessionOp{}, gotMissing) +} + +func TestRedisBackendBundleRendersAndReadsFromCache(t *testing.T) { + ctx := context.Background() + var renders atomic.Int64 + backend, mr, _ := newTestRedisBackend(t, testRedisRenderer(func(_ context.Context, ops []SessionOp, _ *ast.Document) (string, error) { + renders.Add(1) + return fmt.Sprintf("render-%d:%s", renders.Load(), ops[0].Name), nil + }), time.Hour) + + _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: "getUser", Body: "query { user { id } }", Kind: OperationKindQuery}}) + require.NoError(t, err) + + first, err := backend.Bundle(ctx, "session-1") + require.NoError(t, err) + assert.Equal(t, "render-1:getUser", first) + assert.Equal(t, true, mr.Exists(backend.bundleKey("session-1"))) + + second, err := backend.Bundle(ctx, "session-1") + require.NoError(t, err) + assert.Equal(t, "render-1:getUser", second) + assert.Equal(t, int64(1), renders.Load()) +} + +func TestRedisBackendResetClearsOpsAndBundleKeys(t *testing.T) { + ctx := context.Background() + backend, mr, _ := newTestRedisBackend(t, nil, time.Hour) + + _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: "getUser", Body: "query { user { id } }", Kind: OperationKindQuery}}) + require.NoError(t, err) + _, err = backend.Bundle(ctx, "session-1") + require.NoError(t, err) + opsKey := backend.opsKey("session-1") + bundleKey := backend.bundleKey("session-1") + assert.Equal(t, true, mr.Exists(opsKey)) + assert.Equal(t, true, mr.Exists(bundleKey)) + + require.NoError(t, backend.Reset(ctx, "session-1")) + + assert.Equal(t, false, mr.Exists(opsKey)) + assert.Equal(t, false, mr.Exists(bundleKey)) +} + +func TestRedisBackendSetSchemaRotatesKeysAndKeepsOldKeysUntilTTL(t *testing.T) { + ctx := context.Background() + schemaA := &ast.Document{RootNodes: []ast.Node{{Kind: ast.NodeKindSchemaDefinition}}} + schemaB := &ast.Document{RootNodes: []ast.Node{{Kind: ast.NodeKindObjectTypeDefinition}}} + backend, mr, _ := newTestRedisBackend(t, testRedisRenderer(func(_ context.Context, _ []SessionOp, schema *ast.Document) (string, error) { + return fmt.Sprintf("schema-kind-%d", schema.RootNodes[0].Kind), nil + }), time.Hour) + backend.SetSchema(schemaA) + + _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: "getUser", Body: "query { user { id } }", Kind: OperationKindQuery}}) + require.NoError(t, err) + oldOpsKey := backend.opsKey("session-1") + oldBundleKey := backend.bundleKey("session-1") + first, err := backend.Bundle(ctx, "session-1") + require.NoError(t, err) + assert.Equal(t, fmt.Sprintf("schema-kind-%d", schemaA.RootNodes[0].Kind), first) + assert.Equal(t, true, mr.Exists(oldOpsKey)) + assert.Equal(t, true, mr.Exists(oldBundleKey)) + + oldVersion := backend.SchemaVersion() + backend.SetSchema(schemaB) + + assert.Equal(t, oldVersion+1, backend.SchemaVersion()) + assert.Equal(t, schemaB, backend.Schema()) + assert.Equal(t, true, mr.Exists(oldOpsKey)) + assert.Equal(t, true, mr.Exists(oldBundleKey)) + + _, err = backend.Append(ctx, "session-1", []SessionOp{{Name: "getUser", Body: "query { user { id } }", Kind: OperationKindQuery}}) + require.NoError(t, err) + second, err := backend.Bundle(ctx, "session-1") + require.NoError(t, err) + assert.Equal(t, fmt.Sprintf("schema-kind-%d", schemaB.RootNodes[0].Kind), second) + assert.Equal(t, true, mr.Exists(backend.opsKey("session-1"))) + assert.Equal(t, true, mr.Exists(backend.bundleKey("session-1"))) +} + +func TestRedisBackendConcurrentAppendRetriesWatchConflicts(t *testing.T) { + ctx := context.Background() + backend, mr, _ := newTestRedisBackend(t, nil, time.Hour) + const goroutines = 12 + const opsPerGoroutine = 8 + + var wg sync.WaitGroup + errs := make(chan error, goroutines) + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(worker int) { + defer wg.Done() + ops := make([]SessionOp, 0, opsPerGoroutine) + for j := 0; j < opsPerGoroutine; j++ { + ops = append(ops, SessionOp{Name: fmt.Sprintf("op_%02d_%02d", worker, j), Body: "query { ok }", Kind: OperationKindQuery}) + } + _, err := backend.Append(ctx, "session-1", ops) + errs <- err + }(i) + } + wg.Wait() + close(errs) + for err := range errs { + require.NoError(t, err) + } + + raw, err := mr.Get(backend.opsKey("session-1")) + require.NoError(t, err) + var entries []redisOpEntry + require.NoError(t, json.Unmarshal([]byte(raw), &entries)) + assert.Equal(t, goroutines*opsPerGoroutine, len(entries)) +} + +func TestRedisBackendAppendAbandonsOnContextDone(t *testing.T) { + backend, mr, _ := newTestRedisBackend(t, nil, time.Hour) + mr.SetError("LOADING Redis is loading the dataset in memory") + ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancel() + + _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: "getUser", Body: "query { user { id } }", Kind: OperationKindQuery}}) + + require.Error(t, err) + assert.Equal(t, true, errors.Is(err, context.DeadlineExceeded)) +} + +func TestRedisBackendExpiresKeysOnWrites(t *testing.T) { + ctx := context.Background() + backend, mr, _ := newTestRedisBackend(t, nil, 10*time.Second) + + _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: "getUser", Body: "query { user { id } }", Kind: OperationKindQuery}}) + require.NoError(t, err) + opsKey := backend.opsKey("session-1") + assert.Equal(t, 10*time.Second, mr.TTL(opsKey)) + + _, err = backend.Bundle(ctx, "session-1") + require.NoError(t, err) + bundleKey := backend.bundleKey("session-1") + assert.Equal(t, 10*time.Second, mr.TTL(bundleKey)) + + mr.FastForward(11 * time.Second) + assert.Equal(t, false, mr.Exists(opsKey)) + assert.Equal(t, false, mr.Exists(bundleKey)) +} + +func TestRedisBackendBundleWriteBackIsBestEffort(t *testing.T) { + ctx := context.Background() + backend, mr, _ := newTestRedisBackend(t, testRedisRenderer(func(_ context.Context, ops []SessionOp, _ *ast.Document) (string, error) { + return "rendered:" + ops[0].Name, nil + }), time.Hour) + _, err := backend.Append(ctx, "session-1", []SessionOp{{Name: "getUser", Body: "query { user { id } }", Kind: OperationKindQuery}}) + require.NoError(t, err) + + mr.Server().SetPreHook(func(c *miniredisserver.Peer, cmd string, _ ...string) bool { + if strings.EqualFold(cmd, "set") { + c.WriteError("ERR forced set failure") + return true + } + return false + }) + t.Cleanup(func() { + mr.Server().SetPreHook(nil) + }) + + bundle, err := backend.Bundle(ctx, "session-1") + + require.NoError(t, err) + assert.Equal(t, "rendered:getUser", bundle) + assert.Equal(t, false, mr.Exists(backend.bundleKey("session-1"))) +} diff --git a/router/internal/codemode/storage/storage.go b/router/internal/codemode/storage/storage.go new file mode 100644 index 0000000000..fc847a7acb --- /dev/null +++ b/router/internal/codemode/storage/storage.go @@ -0,0 +1,29 @@ +package storage + +import ( + "context" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" +) + +type SessionStorage interface { + Append(ctx context.Context, sessionID string, ops []SessionOp) ([]SessionOp, error) + GetOp(ctx context.Context, sessionID string, name string) (SessionOp, bool, error) + ListNames(ctx context.Context, sessionID string) ([]string, error) + Bundle(ctx context.Context, sessionID string) (string, error) + Reset(ctx context.Context, sessionID string) error + SetSchema(*ast.Document) + Schema() *ast.Document + Start(ctx context.Context) error + Stop() error +} + +type Renderer interface { + Render(ctx context.Context, ops []SessionOp, schema *ast.Document) (string, error) +} + +type RendererFunc func([]SessionOp) (string, error) + +func (f RendererFunc) Render(_ context.Context, ops []SessionOp, _ *ast.Document) (string, error) { + return f(ops) +} diff --git a/router/internal/codemode/storage/types.go b/router/internal/codemode/storage/types.go new file mode 100644 index 0000000000..ba3f1c7df2 --- /dev/null +++ b/router/internal/codemode/storage/types.go @@ -0,0 +1,15 @@ +package storage + +type OperationKind string + +const ( + OperationKindQuery OperationKind = "Query" + OperationKindMutation OperationKind = "Mutation" +) + +type SessionOp struct { + Name string + Body string + Kind OperationKind + Description string +} diff --git a/router/internal/codemode/tsgen/bundle_test.go b/router/internal/codemode/tsgen/bundle_test.go new file mode 100644 index 0000000000..462fd3f0d0 --- /dev/null +++ b/router/internal/codemode/tsgen/bundle_test.go @@ -0,0 +1,138 @@ +package tsgen + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/internal/codemode/storage" +) + +func TestRenderBundleEmptyOps(t *testing.T) { + got, err := RenderBundle(nil, testSchema(t), 0) + require.NoError(t, err) + + want := "type GraphQLError = { message: string; path?: (string | number)[]; extensions?: Record };\n" + + "type R = Promise<{ data: T | null; errors?: GraphQLError[] }>;\n" + + "\n" + + "declare const tools: {};\n" + + "\n" + + "declare function notNull(value: T | null | undefined, message?: string): T;\n" + + "declare function compact(value: T): T;" + + assert.Equal(t, want, got) +} + +func TestRenderBundleThreeOpsNoTruncation(t *testing.T) { + ops := []storage.SessionOp{ + {Name: "health", Body: `query Health { health }`, Kind: storage.OperationKindQuery, Description: "Checks router health."}, + {Name: "viewer", Body: `query Viewer { viewer { id name } }`, Kind: storage.OperationKindQuery, Description: "Fetches viewer."}, + {Name: "renameUser", Body: `mutation RenameUser($id: ID!, $name: String!) { renameUser(id: $id, name: $name) { id } }`, Kind: storage.OperationKindMutation, Description: "Renames a user."}, + } + + got, err := RenderBundle(ops, testSchema(t), 0) + require.NoError(t, err) + + want := "type GraphQLError = { message: string; path?: (string | number)[]; extensions?: Record };\n" + + "type R = Promise<{ data: T | null; errors?: GraphQLError[] }>;\n" + + "\n" + + "declare const tools: {\n" + + " /** Checks router health. */\n" + + " health(): R<{ health: string }>;\n" + + "\n" + + " /** Fetches viewer. */\n" + + " viewer(): R<{ viewer: { id: string; name: string } | null }>;\n" + + "\n" + + " /** Renames a user. */\n" + + " renameUser(vars: { id: string; name: string }): R<{ renameUser: { id: string } }>;\n" + + "};\n" + + "\n" + + "declare function notNull(value: T | null | undefined, message?: string): T;\n" + + "declare function compact(value: T): T;" + + assert.Equal(t, want, got) +} + +func TestRenderBundleTruncatesWholeOpsFromEnd(t *testing.T) { + ops := []storage.SessionOp{ + {Name: "health", Body: `query Health { health }`, Kind: storage.OperationKindQuery, Description: "Checks router health."}, + {Name: "viewer", Body: `query Viewer { viewer { id name } }`, Kind: storage.OperationKindQuery, Description: "Fetches viewer."}, + {Name: "renameUser", Body: `mutation RenameUser($id: ID!, $name: String!) { renameUser(id: $id, name: $name) { id } }`, Kind: storage.OperationKindMutation, Description: "Renames a user."}, + } + fullWithTwo := "type GraphQLError = { message: string; path?: (string | number)[]; extensions?: Record };\n" + + "type R = Promise<{ data: T | null; errors?: GraphQLError[] }>;\n" + + "\n" + + "declare const tools: {\n" + + " /** Checks router health. */\n" + + " health(): R<{ health: string }>;\n" + + "\n" + + " /** Fetches viewer. */\n" + + " viewer(): R<{ viewer: { id: string; name: string } | null }>;\n" + + "};\n" + + "\n" + + "declare function notNull(value: T | null | undefined, message?: string): T;\n" + + "declare function compact(value: T): T;\n" + + "// truncated: 1 ops omitted" + + got, err := RenderBundle(ops, testSchema(t), len(fullWithTwo)) + require.NoError(t, err) + + assert.Equal(t, fullWithTwo, got) +} + +func TestRenderBundleErrorsWhenPreludeCannotFit(t *testing.T) { + _, err := RenderBundle(nil, testSchema(t), 12) + require.Error(t, err) +} + +func TestRenderBundleRoundTripsAbstractField(t *testing.T) { + ops := []storage.SessionOp{ + { + Name: "petsList", + Body: `query PetsList { pets { __typename ... on Cat { name } ... on Dog { bark } } }`, + Kind: storage.OperationKindQuery, + Description: "Lists pets.", + }, + } + + got, err := RenderBundle(ops, testSchema(t), 0) + require.NoError(t, err) + + want := "type GraphQLError = { message: string; path?: (string | number)[]; extensions?: Record };\n" + + "type R = Promise<{ data: T | null; errors?: GraphQLError[] }>;\n" + + "\n" + + "declare const tools: {\n" + + " /** Lists pets. */\n" + + " petsList(): R<{ pets: ({ __typename: \"Cat\"; name: string } | { __typename: \"Dog\"; bark: string } | { __typename: \"Mouse\" })[] }>;\n" + + "};\n" + + "\n" + + "declare function notNull(value: T | null | undefined, message?: string): T;\n" + + "declare function compact(value: T): T;" + + assert.Equal(t, want, got) +} + +func TestNewOpsFragmentReturnsOnlySignatures(t *testing.T) { + ops := []storage.SessionOp{ + {Name: "health", Body: `query Health { health }`, Kind: storage.OperationKindQuery, Description: "Checks router health."}, + {Name: "viewer", Body: `query Viewer { viewer { id } }`, Kind: storage.OperationKindQuery, Description: "Fetches viewer."}, + {Name: "animal", Body: `query Animal { animal { id } }`, Kind: storage.OperationKindQuery, Description: "Fetches animal."}, + } + + got, err := NewOpsFragment(ops, testSchema(t)) + require.NoError(t, err) + + want := "/** Checks router health. */\n" + + "health(): R<{ health: string }>;\n" + + "\n" + + "/** Fetches viewer. */\n" + + "viewer(): R<{ viewer: { id: string } | null }>;\n" + + "\n" + + "/** Fetches animal. */\n" + + "animal(): R<{ animal: { id: string } | null }>;" + + assert.Equal(t, want, got) + assert.False(t, strings.Contains(got, "declare const tools")) + assert.False(t, strings.Contains(got, "type R")) +} diff --git a/router/internal/codemode/tsgen/graphql.go b/router/internal/codemode/tsgen/graphql.go new file mode 100644 index 0000000000..352f0d21a8 --- /dev/null +++ b/router/internal/codemode/tsgen/graphql.go @@ -0,0 +1,674 @@ +package tsgen + +import ( + "fmt" + "strconv" + "strings" + + "github.com/wundergraph/cosmo/router/internal/codemode/storage" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" +) + +type operationRenderer struct { + schema *ast.Document +} + +func (r operationRenderer) renderOperation(op storage.SessionOp) (string, error) { + if r.schema == nil { + return "", fmt.Errorf("render op %q: schema is nil", op.Name) + } + + opDoc, report := astparser.ParseGraphqlDocumentString(op.Body) + if report.HasErrors() { + return "", fmt.Errorf("render op %q: parse GraphQL operation: %s", op.Name, report.Error()) + } + + opRef, err := singleOperationRef(&opDoc) + if err != nil { + return "", fmt.Errorf("render op %q: %w", op.Name, err) + } + + varsType, varsOptional, err := r.variablesType(&opDoc, opRef) + if err != nil { + return "", fmt.Errorf("render op %q: %w", op.Name, err) + } + + outputType, err := r.outputType(&opDoc, opRef) + if err != nil { + return "", fmt.Errorf("render op %q: %w", op.Name, err) + } + + return writeFieldSignature(op.Description, op.Name, varsType, outputType, varsOptional), nil +} + +func singleOperationRef(doc *ast.Document) (int, error) { + var refs []int + for _, node := range doc.RootNodes { + if node.Kind == ast.NodeKindOperationDefinition { + refs = append(refs, node.Ref) + } + } + if len(refs) == 0 { + return 0, fmt.Errorf("operation document contains no operation definition") + } + if len(refs) > 1 { + return 0, fmt.Errorf("operation document contains %d operation definitions", len(refs)) + } + return refs[0], nil +} + +func (r operationRenderer) variablesType(opDoc *ast.Document, opRef int) (string, bool, error) { + op := opDoc.OperationDefinitions[opRef] + if !op.HasVariableDefinitions || len(op.VariableDefinitions.Refs) == 0 { + return "{}", true, nil + } + + fields := make([]tsProperty, 0, len(op.VariableDefinitions.Refs)) + varsOptional := true + for _, varRef := range op.VariableDefinitions.Refs { + name := opDoc.VariableDefinitionNameString(varRef) + typeRef := opDoc.VariableDefinitionType(varRef) + required := opDoc.Types[typeRef].TypeKind == ast.TypeKindNonNull + + typ, nullable, err := r.inputType(opDoc, typeRef) + if err != nil { + return "", false, err + } + if nullable { + typ = writeNullable(typ) + } else { + varsOptional = false + } + + fields = append(fields, tsProperty{name: name, typ: typ, optional: !required}) + } + + return writeInlineObject(fields), varsOptional, nil +} + +func (r operationRenderer) inputType(doc *ast.Document, typeRef int) (string, bool, error) { + gqlType := doc.Types[typeRef] + switch gqlType.TypeKind { + case ast.TypeKindNonNull: + typ, _, err := r.inputType(doc, gqlType.OfType) + return typ, false, err + case ast.TypeKindList: + item, itemNullable, err := r.inputType(doc, gqlType.OfType) + if err != nil { + return "", false, err + } + if itemNullable { + item = writeNullable(item) + } + return writeArray(item), true, nil + case ast.TypeKindNamed: + typ, err := r.inputNamedType(doc.TypeNameString(typeRef)) + return typ, true, err + default: + return "", false, fmt.Errorf("unsupported GraphQL input type kind %s", gqlType.TypeKind.String()) + } +} + +func (r operationRenderer) inputNamedType(typeName string) (string, error) { + switch typeName { + case "ID", "String": + return "string", nil + case "Int", "Float": + return "number", nil + case "Boolean": + return "boolean", nil + } + + node, exists := r.schema.Index.FirstNonExtensionNodeByNameBytes([]byte(typeName)) + if !exists { + return "", fmt.Errorf("missing schema type %q", typeName) + } + + switch node.Kind { + case ast.NodeKindEnumTypeDefinition: + values := r.enumValues(node.Ref) + return writeStringLiteralUnion(values), nil + case ast.NodeKindInputObjectTypeDefinition: + return r.inputObjectType(node.Ref) + case ast.NodeKindScalarTypeDefinition: + return "unknown", nil + default: + return "unknown", nil + } +} + +func (r operationRenderer) enumValues(enumRef int) []string { + def := r.schema.EnumTypeDefinitions[enumRef] + values := make([]string, 0, len(def.EnumValuesDefinition.Refs)) + for _, valueRef := range def.EnumValuesDefinition.Refs { + values = append(values, r.schema.EnumValueDefinitionNameString(valueRef)) + } + return values +} + +func (r operationRenderer) inputObjectType(inputObjectRef int) (string, error) { + def := r.schema.InputObjectTypeDefinitions[inputObjectRef] + fields := make([]tsProperty, 0, len(def.InputFieldsDefinition.Refs)) + for _, fieldRef := range def.InputFieldsDefinition.Refs { + name := r.schema.InputValueDefinitionNameString(fieldRef) + typeRef := r.schema.InputValueDefinitionType(fieldRef) + required := r.schema.Types[typeRef].TypeKind == ast.TypeKindNonNull + + typ, nullable, err := r.inputType(r.schema, typeRef) + if err != nil { + return "", err + } + if nullable { + typ = writeNullable(typ) + } + + fields = append(fields, tsProperty{name: name, typ: typ, optional: !required}) + } + + return writeInlineObject(fields), nil +} + +func (r operationRenderer) outputType(opDoc *ast.Document, opRef int) (string, error) { + op := opDoc.OperationDefinitions[opRef] + rootNode, err := r.rootOperationNode(op.OperationType) + if err != nil { + return "", err + } + + return r.selectionSetType(opDoc, op.SelectionSet, rootNode) +} + +func (r operationRenderer) rootOperationNode(operationType ast.OperationType) (ast.Node, error) { + var typeName []byte + switch operationType { + case ast.OperationTypeQuery: + typeName = r.schema.Index.QueryTypeName + if len(typeName) == 0 { + typeName = []byte("Query") + } + case ast.OperationTypeMutation: + typeName = r.schema.Index.MutationTypeName + if len(typeName) == 0 { + typeName = []byte("Mutation") + } + case ast.OperationTypeSubscription: + typeName = r.schema.Index.SubscriptionTypeName + if len(typeName) == 0 { + typeName = []byte("Subscription") + } + default: + return ast.Node{}, fmt.Errorf("unsupported operation type %s", operationType.Name()) + } + + node, exists := r.schema.Index.FirstNonExtensionNodeByNameBytes(typeName) + if !exists { + return ast.Node{}, fmt.Errorf("missing schema root type %q", string(typeName)) + } + return node, nil +} + +func (r operationRenderer) selectionSetType(opDoc *ast.Document, selectionSetRef int, parent ast.Node) (string, error) { + selections := opDoc.SelectionSets[selectionSetRef] + fields := make([]tsProperty, 0, len(selections.SelectionRefs)) + + for _, selectionRef := range selections.SelectionRefs { + selection := opDoc.Selections[selectionRef] + switch selection.Kind { + case ast.SelectionKindField: + field, err := r.fieldProperty(opDoc, selection.Ref, parent) + if err != nil { + return "", err + } + fields = append(fields, field) + case ast.SelectionKindInlineFragment: + inlineFields, err := r.inlineFragmentProperties(opDoc, selection.Ref, parent) + if err != nil { + return "", err + } + fields = append(fields, inlineFields...) + case ast.SelectionKindFragmentSpread: + fragmentFields, err := r.fragmentSpreadProperties(opDoc, selection.Ref, parent) + if err != nil { + return "", err + } + fields = append(fields, fragmentFields...) + default: + return "", fmt.Errorf("unsupported selection kind %s", selection.Kind.String()) + } + } + + return writeInlineObject(fields), nil +} + +func (r operationRenderer) fieldProperty(opDoc *ast.Document, fieldRef int, parent ast.Node) (tsProperty, error) { + name := opDoc.FieldNameString(fieldRef) + propName := opDoc.FieldAliasOrNameString(fieldRef) + + if name == "__typename" { + return tsProperty{name: propName, typ: "string"}, nil + } + + fieldDefRef, exists := r.schema.NodeFieldDefinitionByName(parent, []byte(name)) + if !exists { + return tsProperty{}, fmt.Errorf("missing field %q on schema type %q", name, parent.NameString(r.schema)) + } + + selectionSetRef := -1 + if opDoc.Fields[fieldRef].HasSelections { + selectionSetRef = opDoc.Fields[fieldRef].SelectionSet + } + + typeRef := r.schema.FieldDefinitionType(fieldDefRef) + typ, nullable, err := r.outputGraphQLType(opDoc, typeRef, selectionSetRef) + if err != nil { + return tsProperty{}, err + } + if nullable { + typ = writeNullable(typ) + } + + return tsProperty{name: propName, typ: typ}, nil +} + +func (r operationRenderer) outputGraphQLType(opDoc *ast.Document, typeRef int, selectionSetRef int) (string, bool, error) { + gqlType := r.schema.Types[typeRef] + switch gqlType.TypeKind { + case ast.TypeKindNonNull: + typ, _, err := r.outputGraphQLType(opDoc, gqlType.OfType, selectionSetRef) + return typ, false, err + case ast.TypeKindList: + item, itemNullable, err := r.outputGraphQLType(opDoc, gqlType.OfType, selectionSetRef) + if err != nil { + return "", false, err + } + if itemNullable { + item = writeNullable(item) + } + return writeArray(item), true, nil + case ast.TypeKindNamed: + typ, err := r.outputNamedType(opDoc, r.schema.TypeNameString(typeRef), selectionSetRef) + return typ, true, err + default: + return "", false, fmt.Errorf("unsupported GraphQL output type kind %s", gqlType.TypeKind.String()) + } +} + +func (r operationRenderer) outputNamedType(opDoc *ast.Document, typeName string, selectionSetRef int) (string, error) { + switch typeName { + case "ID", "String": + return "string", nil + case "Int", "Float": + return "number", nil + case "Boolean": + return "boolean", nil + } + + node, exists := r.schema.Index.FirstNonExtensionNodeByNameBytes([]byte(typeName)) + if !exists { + return "", fmt.Errorf("missing schema type %q", typeName) + } + + switch node.Kind { + case ast.NodeKindEnumTypeDefinition: + return writeStringLiteralUnion(r.enumValues(node.Ref)), nil + case ast.NodeKindObjectTypeDefinition: + if selectionSetRef < 0 { + return "", fmt.Errorf("object type %q requires a selection set", typeName) + } + return r.selectionSetType(opDoc, selectionSetRef, node) + case ast.NodeKindInterfaceTypeDefinition, ast.NodeKindUnionTypeDefinition: + if selectionSetRef < 0 { + return "", fmt.Errorf("abstract type %q requires a selection set", typeName) + } + return r.abstractFieldType(opDoc, selectionSetRef, node) + case ast.NodeKindScalarTypeDefinition: + return "unknown", nil + default: + return "unknown", nil + } +} + +func (r operationRenderer) inlineFragmentProperties(opDoc *ast.Document, inlineRef int, parent ast.Node) ([]tsProperty, error) { + fragment := opDoc.InlineFragments[inlineRef] + fragmentParent := parent + if opDoc.InlineFragmentHasTypeCondition(inlineRef) { + typeName := opDoc.InlineFragmentTypeConditionNameString(inlineRef) + node, exists := r.schema.Index.FirstNonExtensionNodeByNameBytes([]byte(typeName)) + if !exists { + return nil, fmt.Errorf("missing schema type %q", typeName) + } + fragmentParent = node + } + + typ, err := r.selectionSetType(opDoc, fragment.SelectionSet, fragmentParent) + if err != nil { + return nil, err + } + + return propertiesFromInlineObject(typ), nil +} + +func (r operationRenderer) fragmentSpreadProperties(opDoc *ast.Document, spreadRef int, parent ast.Node) ([]tsProperty, error) { + fragmentName := opDoc.FragmentSpreadNameBytes(spreadRef) + fragmentRef, exists := opDoc.FragmentDefinitionRef(fragmentName) + if !exists { + return nil, fmt.Errorf("missing fragment %q", string(fragmentName)) + } + + fragment := opDoc.FragmentDefinitions[fragmentRef] + fragmentParent := parent + typeName := opDoc.ResolveTypeNameString(fragment.TypeCondition.Type) + if typeName != "" { + node, nodeExists := r.schema.Index.FirstNonExtensionNodeByNameBytes([]byte(typeName)) + if !nodeExists { + return nil, fmt.Errorf("missing schema type %q", typeName) + } + fragmentParent = node + } + + typ, err := r.selectionSetType(opDoc, fragment.SelectionSet, fragmentParent) + if err != nil { + return nil, err + } + + return propertiesFromInlineObject(typ), nil +} + +func propertiesFromInlineObject(typ string) []tsProperty { + if typ == "{}" { + return nil + } + + inner := typ[2 : len(typ)-2] + parts := splitInlineObjectFields(inner) + props := make([]tsProperty, 0, len(parts)) + for _, part := range parts { + nameAndType := splitProperty(part) + if nameAndType.name == "" { + continue + } + props = append(props, nameAndType) + } + + return props +} + +func splitInlineObjectFields(inner string) []string { + var parts []string + start := 0 + depth := 0 + for i := 0; i < len(inner); i++ { + switch inner[i] { + case '{': + depth++ + case '}': + depth-- + case ';': + if depth == 0 && i+1 < len(inner) && inner[i+1] == ' ' { + parts = append(parts, inner[start:i]) + start = i + 2 + } + } + } + parts = append(parts, inner[start:]) + return parts +} + +func splitProperty(part string) tsProperty { + for i := 0; i < len(part); i++ { + if part[i] != ':' { + continue + } + optional := i > 0 && part[i-1] == '?' + nameEnd := i + if optional { + nameEnd-- + } + return tsProperty{name: part[:nameEnd], typ: part[i+2:], optional: optional} + } + return tsProperty{} +} + +// abstractSelectionSet describes a fragment to be applied to the matching +// branches when lowering an abstract-typed field. `condition` is the schema +// node referenced by the fragment's type condition (or the parent abstract +// node itself for inline fragments without a type condition). +type abstractSelectionSet struct { + condition ast.Node + selectionSetRef int +} + +// abstractFieldType lowers a selection set on an interface- or union-typed +// field into a flat discriminated union of branches, one per concrete +// implementor. +func (r operationRenderer) abstractFieldType(opDoc *ast.Document, selectionSetRef int, parent ast.Node) (string, error) { + parentName := parent.NameString(r.schema) + possibleNames := r.possibleTypeNames(parent) + if len(possibleNames) == 0 { + return "", fmt.Errorf("abstract type %q has no possible types", parentName) + } + possibleSet := make(map[string]struct{}, len(possibleNames)) + for _, name := range possibleNames { + possibleSet[name] = struct{}{} + } + + selections := opDoc.SelectionSets[selectionSetRef] + if len(selections.SelectionRefs) == 0 { + return "", fmt.Errorf("abstract type %q requires at least one selection", parentName) + } + + // Bucket the selections. + var bareFieldRefs []int // Field selections defined on the abstract parent itself + var typenameSelected bool // unaliased __typename selected directly + var fragments []abstractSelectionSet + + for _, selRef := range selections.SelectionRefs { + sel := opDoc.Selections[selRef] + switch sel.Kind { + case ast.SelectionKindField: + fieldRef := sel.Ref + fieldName := opDoc.FieldNameString(fieldRef) + if fieldName == "__typename" { + if opDoc.FieldAliasOrNameString(fieldRef) == "__typename" { + typenameSelected = true + } else { + // aliased __typename: render through normal field path on each branch + bareFieldRefs = append(bareFieldRefs, fieldRef) + } + continue + } + // Non-typename bare field is only valid on interface parents and must + // be defined on the parent interface. + if parent.Kind != ast.NodeKindInterfaceTypeDefinition { + return "", fmt.Errorf("field %q is not valid on union type %q", fieldName, parentName) + } + if _, exists := r.schema.NodeFieldDefinitionByName(parent, []byte(fieldName)); !exists { + return "", fmt.Errorf("missing field %q on interface %q", fieldName, parentName) + } + bareFieldRefs = append(bareFieldRefs, fieldRef) + case ast.SelectionKindInlineFragment: + inlineRef := sel.Ref + inline := opDoc.InlineFragments[inlineRef] + condition := parent + if opDoc.InlineFragmentHasTypeCondition(inlineRef) { + typeName := opDoc.InlineFragmentTypeConditionNameString(inlineRef) + node, exists := r.schema.Index.FirstNonExtensionNodeByNameBytes([]byte(typeName)) + if !exists { + return "", fmt.Errorf("missing schema type %q", typeName) + } + condition = node + } + if err := r.checkAbstractFragmentCondition(condition, possibleSet, parentName); err != nil { + return "", err + } + fragments = append(fragments, abstractSelectionSet{ + condition: condition, + selectionSetRef: inline.SelectionSet, + }) + case ast.SelectionKindFragmentSpread: + spreadRef := sel.Ref + fragmentName := opDoc.FragmentSpreadNameBytes(spreadRef) + fragRef, exists := opDoc.FragmentDefinitionRef(fragmentName) + if !exists { + return "", fmt.Errorf("missing fragment %q", string(fragmentName)) + } + fragment := opDoc.FragmentDefinitions[fragRef] + typeName := opDoc.ResolveTypeNameString(fragment.TypeCondition.Type) + if typeName == "" { + return "", fmt.Errorf("fragment %q has no type condition", string(fragmentName)) + } + node, nodeExists := r.schema.Index.FirstNonExtensionNodeByNameBytes([]byte(typeName)) + if !nodeExists { + return "", fmt.Errorf("missing schema type %q", typeName) + } + if err := r.checkAbstractFragmentCondition(node, possibleSet, parentName); err != nil { + return "", err + } + fragments = append(fragments, abstractSelectionSet{ + condition: node, + selectionSetRef: fragment.SelectionSet, + }) + default: + return "", fmt.Errorf("unsupported selection kind %s", sel.Kind.String()) + } + } + + // Build a branch per concrete implementor. + branches := make([]string, 0, len(possibleNames)) + for _, typeName := range possibleNames { + concreteNode, exists := r.schema.Index.FirstNonExtensionNodeByNameBytes([]byte(typeName)) + if !exists || concreteNode.Kind != ast.NodeKindObjectTypeDefinition { + continue + } + + fields := make([]tsProperty, 0) + + // Bare fields rendered against the concrete type. (For unions there + // will only be aliased __typename here, since other bare fields are + // rejected above.) + for _, fieldRef := range bareFieldRefs { + prop, err := r.fieldProperty(opDoc, fieldRef, concreteNode) + if err != nil { + return "", err + } + fields = append(fields, prop) + } + + // Fragments whose target includes this concrete type. + for _, frag := range fragments { + if !abstractFragmentApplies(frag.condition, typeName, possibleSet, r.schema) { + continue + } + fragTyp, err := r.selectionSetType(opDoc, frag.selectionSetRef, concreteNode) + if err != nil { + return "", err + } + fields = append(fields, propertiesFromInlineObject(fragTyp)...) + } + + // __typename literal: prepend if explicitly selected. + if typenameSelected { + literal := tsProperty{name: "__typename", typ: strconv.Quote(typeName)} + fields = append([]tsProperty{literal}, fields...) + } + + // Drop empty branches. + if len(fields) == 0 { + continue + } + + branches = append(branches, writeInlineObject(fields)) + } + + if len(branches) == 0 { + // Every implementor has zero observable fields. Fall back to a single + // empty object so the type checker still sees a valid shape. + return "{}", nil + } + + if len(branches) == 1 { + return branches[0], nil + } + + // Single-shape collapse: every branch identical → one shape. + allEqual := true + for i := 1; i < len(branches); i++ { + if branches[i] != branches[0] { + allEqual = false + break + } + } + if allEqual { + return branches[0], nil + } + + return strings.Join(branches, " | "), nil +} + +// possibleTypeNames returns the concrete object type names that satisfy the +// given abstract parent, in schema declaration order. +func (r operationRenderer) possibleTypeNames(parent ast.Node) []string { + switch parent.Kind { + case ast.NodeKindInterfaceTypeDefinition: + names, _ := r.schema.InterfaceTypeDefinitionImplementedByObjectWithNames(parent.Ref) + return names + case ast.NodeKindUnionTypeDefinition: + names, _ := r.schema.UnionTypeDefinitionMemberTypeNames(parent.Ref) + return names + case ast.NodeKindObjectTypeDefinition: + return []string{r.schema.ObjectTypeDefinitionNameString(parent.Ref)} + } + return nil +} + +// abstractFragmentApplies decides whether a fragment with the given condition +// applies to the concrete branch named typeName under the parent abstract +// (whose possible types are in parentSet). +func abstractFragmentApplies(condition ast.Node, typeName string, parentSet map[string]struct{}, schema *ast.Document) bool { + switch condition.Kind { + case ast.NodeKindObjectTypeDefinition: + return schema.ObjectTypeDefinitionNameString(condition.Ref) == typeName + case ast.NodeKindInterfaceTypeDefinition: + // applies to any T that implements this interface AND is in parentSet. + impls, _ := schema.InterfaceTypeDefinitionImplementedByObjectWithNames(condition.Ref) + for _, name := range impls { + if name == typeName { + if _, ok := parentSet[name]; ok { + return true + } + } + } + return false + case ast.NodeKindUnionTypeDefinition: + members, _ := schema.UnionTypeDefinitionMemberTypeNames(condition.Ref) + for _, name := range members { + if name == typeName { + if _, ok := parentSet[name]; ok { + return true + } + } + } + return false + } + return false +} + +// checkAbstractFragmentCondition rejects fragments whose type condition can +// never apply under the given parent abstract. +func (r operationRenderer) checkAbstractFragmentCondition(condition ast.Node, parentSet map[string]struct{}, parentName string) error { + switch condition.Kind { + case ast.NodeKindObjectTypeDefinition: + name := r.schema.ObjectTypeDefinitionNameString(condition.Ref) + if _, ok := parentSet[name]; !ok { + return fmt.Errorf("type %q is not a possible type of %q", name, parentName) + } + case ast.NodeKindInterfaceTypeDefinition, ast.NodeKindUnionTypeDefinition: + // abstract conditions are always allowed; their target is the + // intersection with the parent's possible types (which may be empty + // — that just means the fragment contributes nothing). + default: + return fmt.Errorf("unsupported fragment type condition %s", condition.Kind.String()) + } + return nil +} diff --git a/router/internal/codemode/tsgen/tsgen.go b/router/internal/codemode/tsgen/tsgen.go new file mode 100644 index 0000000000..ad18f5ab71 --- /dev/null +++ b/router/internal/codemode/tsgen/tsgen.go @@ -0,0 +1,117 @@ +package tsgen + +import ( + "context" + "fmt" + "strings" + + "github.com/wundergraph/cosmo/router/internal/codemode/storage" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" +) + +const ( + defaultMaxBundleBytes = 64 * 1024 + graphQLErrorAlias = "type GraphQLError = { message: string; path?: (string | number)[]; extensions?: Record };" + responseAlias = "type R = Promise<{ data: T | null; errors?: GraphQLError[] }>;" + notNullHelper = "declare function notNull(value: T | null | undefined, message?: string): T;" + compactHelper = "declare function compact(value: T): T;" +) + +type Renderer struct { + Schema *ast.Document + MaxBytes int +} + +func Adapter(schema *ast.Document, maxBytes ...int) storage.Renderer { + limit := defaultMaxBundleBytes + if len(maxBytes) > 0 { + limit = maxBytes[0] + } + + return Renderer{Schema: schema, MaxBytes: limit} +} + +func (r Renderer) Render(_ context.Context, ops []storage.SessionOp, schema *ast.Document) (string, error) { + if schema == nil { + schema = r.Schema + } + return RenderBundle(ops, schema, r.MaxBytes) +} + +func NewOpsFragment(ops []storage.SessionOp, schema *ast.Document) (string, error) { + renderer := operationRenderer{schema: schema} + + blocks := make([]string, 0, len(ops)) + for _, op := range ops { + block, err := renderer.renderOperation(op) + if err != nil { + return "", err + } + blocks = append(blocks, block) + } + + return strings.Join(blocks, "\n\n"), nil +} + +func RenderBundle(ops []storage.SessionOp, schema *ast.Document, maxBytes int) (string, error) { + renderer := operationRenderer{schema: schema} + + blocks := make([]string, 0, len(ops)) + for _, op := range ops { + block, err := renderer.renderOperation(op) + if err != nil { + return "", err + } + blocks = append(blocks, block) + } + + if maxBytes <= 0 { + return renderBundleBlocks(blocks, 0), nil + } + + full := renderBundleBlocks(blocks, 0) + if len([]byte(full)) <= maxBytes { + return full, nil + } + + for omitted := 1; omitted <= len(blocks); omitted++ { + candidate := renderBundleBlocks(blocks[:len(blocks)-omitted], omitted) + if len([]byte(candidate)) <= maxBytes { + return candidate, nil + } + } + + return "", fmt.Errorf("render TypeScript bundle: maxBytes %d is too small for bundle prelude", maxBytes) +} + +func renderBundleBlocks(blocks []string, omitted int) string { + var b strings.Builder + b.WriteString(graphQLErrorAlias) + b.WriteByte('\n') + b.WriteString(responseAlias) + b.WriteString("\n\n") + + if len(blocks) == 0 { + b.WriteString("declare const tools: {};") + } else { + b.WriteString("declare const tools: {\n") + for i, block := range blocks { + if i > 0 { + b.WriteString("\n\n") + } + b.WriteString(indentBlock(block, " ")) + } + b.WriteString("\n};") + } + + b.WriteString("\n\n") + b.WriteString(notNullHelper) + b.WriteByte('\n') + b.WriteString(compactHelper) + + if omitted > 0 { + fmt.Fprintf(&b, "\n// truncated: %d ops omitted", omitted) + } + + return b.String() +} diff --git a/router/internal/codemode/tsgen/tsgen_test.go b/router/internal/codemode/tsgen/tsgen_test.go new file mode 100644 index 0000000000..9e2dc2e626 --- /dev/null +++ b/router/internal/codemode/tsgen/tsgen_test.go @@ -0,0 +1,411 @@ +package tsgen + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/internal/codemode/storage" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" + "github.com/wundergraph/graphql-go-tools/v2/pkg/asttransform" +) + +const testSchemaSDL = ` +schema { + query: Query + mutation: Mutation +} + +type Query { + health: String! + node(id: ID!): User + search(cursor: String): SearchConnection! + tagged(tags: [String!]!): [User!]! + byStatus(status: Status!): [User!]! + filterUsers(filter: UserFilter): [User!]! + viewer: User + animal: Animal + pet: Pet + pets: [Pet!]! + maybePet: Pet + maybePets: [Pet] + requiredPets: [Pet!]! + searchResult: SearchResult + outsider: Outsider +} + +type Mutation { + renameUser(id: ID!, name: String!): User! +} + +type User { + id: ID! + name: String! + friend: User + tags: [String!]! +} + +type SearchConnection { + nodes: [User]! + nextCursor: String +} + +interface Animal { + id: ID! +} + +type Cat implements Animal & Pet & Friendly { + id: ID! + name: String! + friendliness: Int! + companion: Animal +} + +type Dog implements Pet & Friendly { + id: ID! + bark: String! + friendliness: Int! +} + +type Mouse implements Pet { + id: ID! + squeak: Boolean! +} + +interface Pet { + id: ID! +} + +interface Friendly { + friendliness: Int! +} + +interface Unrelated { + unrelated: String! +} + +type Outsider implements Unrelated { + id: ID! + unrelated: String! +} + +union SearchResult = User | Cat + +enum Status { + OPEN + CLOSED +} + +input UserFilter { + status: Status + tags: [String!] + limit: Int! +} +` + +func testSchema(t *testing.T) *ast.Document { + t.Helper() + + doc, report := astparser.ParseGraphqlDocumentString(testSchemaSDL) + require.False(t, report.HasErrors(), report.Error()) + require.NoError(t, asttransform.MergeDefinitionWithBaseSchema(&doc)) + + return &doc +} + +func TestNewOpsFragmentSignatures(t *testing.T) { + schema := testSchema(t) + + tests := []struct { + name string + op storage.SessionOp + want string + }{ + { + name: "var-less query", + op: storage.SessionOp{ + Name: "health", + Body: `query Health { health }`, + Kind: storage.OperationKindQuery, + Description: "Checks router health.", + }, + want: "/** Checks router health. */\nhealth(): R<{ health: string }>;", + }, + { + name: "required scalar var", + op: storage.SessionOp{ + Name: "getNode", + Body: `query GetNode($id: ID!) { node(id: $id) { id } }`, + Kind: storage.OperationKindQuery, + Description: "Fetches a node.", + }, + want: "/** Fetches a node. */\ngetNode(vars: { id: string }): R<{ node: { id: string } | null }>;", + }, + { + name: "optional nullable var", + op: storage.SessionOp{ + Name: "search", + Body: `query Search($cursor: String) { search(cursor: $cursor) { nextCursor } }`, + Kind: storage.OperationKindQuery, + Description: "Searches users.", + }, + want: "/** Searches users. */\nsearch(vars?: { cursor?: string | null }): R<{ search: { nextCursor: string | null } }>;", + }, + { + name: "list non-null var", + op: storage.SessionOp{ + Name: "tagged", + Body: `query Tagged($tags: [String!]!) { tagged(tags: $tags) { id } }`, + Kind: storage.OperationKindQuery, + Description: "Fetches users by tag.", + }, + want: "/** Fetches users by tag. */\ntagged(vars: { tags: string[] }): R<{ tagged: { id: string }[] }>;", + }, + { + name: "enum var", + op: storage.SessionOp{ + Name: "byStatus", + Body: `query ByStatus($status: Status!) { byStatus(status: $status) { id } }`, + Kind: storage.OperationKindQuery, + Description: "Fetches users by status.", + }, + want: "/** Fetches users by status. */\nbyStatus(vars: { status: \"OPEN\" | \"CLOSED\" }): R<{ byStatus: { id: string }[] }>;", + }, + { + name: "input object var", + op: storage.SessionOp{ + Name: "filterUsers", + Body: `query FilterUsers($filter: UserFilter) { filterUsers(filter: $filter) { id } }`, + Kind: storage.OperationKindQuery, + Description: "Filters users.", + }, + want: "/** Filters users. */\nfilterUsers(vars?: { filter?: { status?: \"OPEN\" | \"CLOSED\" | null; tags?: string[] | null; limit: number } | null }): R<{ filterUsers: { id: string }[] }>;", + }, + { + name: "nested object", + op: storage.SessionOp{ + Name: "viewer", + Body: `query Viewer { viewer { id friend { name } } }`, + Kind: storage.OperationKindQuery, + Description: "Fetches viewer.", + }, + want: "/** Fetches viewer. */\nviewer(): R<{ viewer: { id: string; friend: { name: string } | null } | null }>;", + }, + { + name: "aliased field", + op: storage.SessionOp{ + Name: "viewerAlias", + Body: `query ViewerAlias { me: viewer { id } }`, + Kind: storage.OperationKindQuery, + Description: "Fetches viewer with alias.", + }, + want: "/** Fetches viewer with alias. */\nviewerAlias(): R<{ me: { id: string } | null }>;", + }, + { + name: "inline fragment", + op: storage.SessionOp{ + Name: "viewerFragment", + Body: `query ViewerFragment { viewer { id ... on User { name } } }`, + Kind: storage.OperationKindQuery, + Description: "Fetches viewer fields.", + }, + want: "/** Fetches viewer fields. */\nviewerFragment(): R<{ viewer: { id: string; name: string } | null }>;", + }, + { + name: "union or interface output", + op: storage.SessionOp{ + Name: "animal", + Body: `query Animal { animal { id } }`, + Kind: storage.OperationKindQuery, + Description: "Fetches animal.", + }, + want: "/** Fetches animal. */\nanimal(): R<{ animal: { id: string } | null }>;", + }, + { + name: "mutation kind", + op: storage.SessionOp{ + Name: "renameUser", + Body: `mutation RenameUser($id: ID!, $name: String!) { renameUser(id: $id, name: $name) { id name } }`, + Kind: storage.OperationKindMutation, + Description: "Renames a user.", + }, + want: "/** Renames a user. */\nrenameUser(vars: { id: string; name: string }): R<{ renameUser: { id: string; name: string } }>;", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewOpsFragment([]storage.SessionOp{tt.op}, schema) + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestNewOpsFragmentAbstractSelections(t *testing.T) { + schema := testSchema(t) + + tests := []struct { + name string + op storage.SessionOp + want string + wantErr string + }{ + { + name: "interface, only __typename", + op: storage.SessionOp{ + Name: "petKind", + Body: `query PetKind { pet { __typename } }`, + Kind: storage.OperationKindQuery, + Description: "Pet kind.", + }, + want: "/** Pet kind. */\npetKind(): R<{ pet: { __typename: \"Cat\" } | { __typename: \"Dog\" } | { __typename: \"Mouse\" } | null }>;", + }, + { + name: "interface, bare field + one concrete fragment", + op: storage.SessionOp{ + Name: "petWithCatName", + Body: `query PetWithCatName { pet { id ... on Cat { name } } }`, + Kind: storage.OperationKindQuery, + Description: "Pet with cat name.", + }, + want: "/** Pet with cat name. */\npetWithCatName(): R<{ pet: { id: string; name: string } | { id: string } | { id: string } | null }>;", + }, + { + name: "interface, fragment on the same interface", + op: storage.SessionOp{ + Name: "petSameInterface", + Body: `query PetSameInterface { pet { ... on Pet { id } } }`, + Kind: storage.OperationKindQuery, + Description: "Pet same interface.", + }, + want: "/** Pet same interface. */\npetSameInterface(): R<{ pet: { id: string } | null }>;", + }, + { + name: "interface, fragment on an unrelated abstract", + op: storage.SessionOp{ + Name: "petUnrelated", + Body: `query PetUnrelated { pet { id ... on Unrelated { unrelated } } }`, + Kind: storage.OperationKindQuery, + Description: "Pet unrelated.", + }, + want: "/** Pet unrelated. */\npetUnrelated(): R<{ pet: { id: string } | null }>;", + }, + { + name: "interface, fragment on a related abstract", + op: storage.SessionOp{ + Name: "petFriendly", + Body: `query PetFriendly { pet { id ... on Friendly { friendliness } } }`, + Kind: storage.OperationKindQuery, + Description: "Pet friendly.", + }, + want: "/** Pet friendly. */\npetFriendly(): R<{ pet: { id: string; friendliness: number } | { id: string; friendliness: number } | { id: string } | null }>;", + }, + { + name: "concrete fragment on a non-implementor type", + op: storage.SessionOp{ + Name: "petBadFragment", + Body: `query PetBadFragment { pet { ... on User { name } } }`, + Kind: storage.OperationKindQuery, + Description: "Pet with non-implementor fragment.", + }, + wantErr: `render op "petBadFragment": type "User" is not a possible type of "Pet"`, + }, + { + name: "union, __typename-only selection", + op: storage.SessionOp{ + Name: "searchKind", + Body: `query SearchKind { searchResult { __typename } }`, + Kind: storage.OperationKindQuery, + Description: "Search kind.", + }, + want: "/** Search kind. */\nsearchKind(): R<{ searchResult: { __typename: \"User\" } | { __typename: \"Cat\" } | null }>;", + }, + { + name: "union with ... on Member for a subset", + op: storage.SessionOp{ + Name: "searchSubset", + Body: `query SearchSubset { searchResult { __typename ... on Cat { name } } }`, + Kind: storage.OperationKindQuery, + Description: "Search subset.", + }, + want: "/** Search subset. */\nsearchSubset(): R<{ searchResult: { __typename: \"User\" } | { __typename: \"Cat\"; name: string } | null }>;", + }, + { + name: "named fragment spread on abstract field", + op: storage.SessionOp{ + Name: "petSpread", + Body: `query PetSpread { pet { ...Bits } } fragment Bits on Pet { id }`, + Kind: storage.OperationKindQuery, + Description: "Pet spread.", + }, + want: "/** Pet spread. */\npetSpread(): R<{ pet: { id: string } | null }>;", + }, + { + name: "aliased __typename", + op: storage.SessionOp{ + Name: "petAliasedKind", + Body: `query PetAliasedKind { pet { kind: __typename } }`, + Kind: storage.OperationKindQuery, + Description: "Pet aliased kind.", + }, + want: "/** Pet aliased kind. */\npetAliasedKind(): R<{ pet: { kind: string } | null }>;", + }, + { + name: "duplicate response keys, identical", + op: storage.SessionOp{ + Name: "petDupIdentical", + Body: `query PetDupIdentical { pet { id id } }`, + Kind: storage.OperationKindQuery, + Description: "Pet dup identical.", + }, + // merging is out of scope for this PR; pin duplicates as duplicates + want: "/** Pet dup identical. */\npetDupIdentical(): R<{ pet: { id: string; id: string } | null }>;", + }, + { + name: "duplicate response keys, conflicting", + op: storage.SessionOp{ + Name: "petDupConflict", + Body: `query PetDupConflict { pet { id ... on Cat { id: name } } }`, + Kind: storage.OperationKindQuery, + Description: "Pet dup conflict.", + }, + // merging is out of scope; conflicting duplicates are emitted as-is + // instead of erroring (mirrors current object-selection behavior). + want: "/** Pet dup conflict. */\npetDupConflict(): R<{ pet: { id: string; id: string } | { id: string } | { id: string } | null }>;", + }, + { + name: "nested abstract inside an inline fragment", + op: storage.SessionOp{ + Name: "petCompanion", + Body: `query PetCompanion { pet { ... on Cat { companion { __typename } } } }`, + Kind: storage.OperationKindQuery, + Description: "Pet companion.", + }, + want: "/** Pet companion. */\npetCompanion(): R<{ pet: { companion: { __typename: \"Cat\" } | null } | null }>;", + }, + { + name: "list / nullable / non-nullable wrapping", + op: storage.SessionOp{ + Name: "petsWrappers", + Body: `query PetsWrappers { pets { id } maybePet { id } maybePets { id } requiredPets { id } }`, + Kind: storage.OperationKindQuery, + Description: "Pets wrappers.", + }, + want: "/** Pets wrappers. */\npetsWrappers(): R<{ pets: { id: string }[]; maybePet: { id: string } | null; maybePets: ({ id: string } | null)[] | null; requiredPets: { id: string }[] }>;", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewOpsFragment([]storage.SessionOp{tt.op}, schema) + if tt.wantErr != "" { + require.EqualError(t, err, tt.wantErr) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/router/internal/codemode/tsgen/typescript.go b/router/internal/codemode/tsgen/typescript.go new file mode 100644 index 0000000000..3807f6c7a7 --- /dev/null +++ b/router/internal/codemode/tsgen/typescript.go @@ -0,0 +1,102 @@ +package tsgen + +import ( + "strconv" + "strings" +) + +type tsProperty struct { + name string + typ string + optional bool +} + +func writeJSDoc(description string) string { + clean := strings.Join(strings.Fields(description), " ") + clean = strings.ReplaceAll(clean, "*/", "* /") + if clean == "" { + clean = "Registered GraphQL operation." + } + return "/** " + clean + " */" +} + +func writeFieldSignature(description, name, varsType, outputType string, varsOptional bool) string { + var b strings.Builder + b.WriteString(writeJSDoc(description)) + b.WriteByte('\n') + b.WriteString(name) + if varsType == "{}" { + b.WriteString("()") + } else { + b.WriteString("(vars") + if varsOptional { + b.WriteByte('?') + } + b.WriteString(": ") + b.WriteString(varsType) + b.WriteByte(')') + } + b.WriteString(": R<") + b.WriteString(outputType) + b.WriteString(">;") + return b.String() +} + +func writeInlineObject(fields []tsProperty) string { + if len(fields) == 0 { + return "{}" + } + + parts := make([]string, 0, len(fields)) + for _, field := range fields { + suffix := ": " + if field.optional { + suffix = "?: " + } + parts = append(parts, field.name+suffix+field.typ) + } + + return "{ " + strings.Join(parts, "; ") + " }" +} + +func writeArray(item string) string { + if strings.Contains(item, " | ") { + item = "(" + item + ")" + } + return item + "[]" +} + +func writeNullable(typ string) string { + if strings.HasSuffix(typ, " | null") { + return typ + } + return typ + " | null" +} + +func writeStringLiteralUnion(values []string) string { + if len(values) == 0 { + return "unknown" + } + + quoted := make([]string, 0, len(values)) + for _, value := range values { + quoted = append(quoted, strconv.Quote(value)) + } + + return strings.Join(quoted, " | ") +} + +func indentBlock(block, indent string) string { + if block == "" { + return "" + } + + lines := strings.Split(block, "\n") + for i := range lines { + if lines[i] != "" { + lines[i] = indent + lines[i] + } + } + + return strings.Join(lines, "\n") +} diff --git a/router/internal/codemode/yoko/client.go b/router/internal/codemode/yoko/client.go new file mode 100644 index 0000000000..4768e0afa3 --- /dev/null +++ b/router/internal/codemode/yoko/client.go @@ -0,0 +1,163 @@ +package yoko + +import ( + "context" + "net/http" + "sync" + + "connectrpc.com/connect" + yokov1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1" + yokoconnect "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1/yokov1connect" + "go.uber.org/zap" + "golang.org/x/sync/singleflight" +) + +type Option func(*Client) + +func WithServiceClient(serviceClient yokoconnect.YokoServiceClient) Option { + return func(c *Client) { + if serviceClient != nil { + c.serviceClient = serviceClient + } + } +} + +type Client struct { + serviceClient yokoconnect.YokoServiceClient + logger *zap.Logger + + schemaMu sync.RWMutex + schemaSDL string + schemaID string + + indexGroup singleflight.Group +} + +func New(httpClient *http.Client, baseURL string, logger *zap.Logger, opts ...Option) *Client { + if httpClient == nil { + httpClient = http.DefaultClient + } + if logger == nil { + logger = zap.NewNop() + } + + client := &Client{ + serviceClient: yokoconnect.NewYokoServiceClient(httpClient, baseURL), + logger: logger, + } + for _, opt := range opts { + opt(client) + } + return client +} + +func (c *Client) Search(ctx context.Context, sessionID string, prompts []string) (*yokov1.SearchResponse, error) { + schemaID, err := c.ensureSchemaID(ctx) + if err != nil { + return nil, err + } + + resp, err := c.search(ctx, schemaID, sessionID, prompts) + if err == nil { + return resp, nil + } + if connect.CodeOf(err) != connect.CodeNotFound { + return nil, err + } + + c.invalidateSchemaID(schemaID) + + schemaID, err = c.ensureSchemaID(ctx) + if err != nil { + return nil, err + } + + resp, err = c.search(ctx, schemaID, sessionID, prompts) + if err != nil { + c.invalidateSchemaID(schemaID) + return nil, err + } + return resp, nil +} + +func (c *Client) SetSchema(sdl string) { + c.schemaMu.Lock() + defer c.schemaMu.Unlock() + c.schemaSDL = sdl + c.schemaID = "" +} + +func (c *Client) Schema() string { + c.schemaMu.RLock() + defer c.schemaMu.RUnlock() + return c.schemaSDL +} + +func (c *Client) EnsureIndexed(ctx context.Context) error { + _, err := c.ensureSchemaID(ctx) + return err +} + +func (c *Client) ensureSchemaID(ctx context.Context) (string, error) { + sdl, schemaID := c.schemaState() + if schemaID != "" { + return schemaID, nil + } + + // Key by raw SDL because Yoko, not the router, owns schema identity. + value, err, _ := c.indexGroup.Do(sdl, func() (any, error) { + currentSDL, currentSchemaID := c.schemaState() + if currentSDL == sdl && currentSchemaID != "" { + return currentSchemaID, nil + } + + resp, err := c.serviceClient.Index(ctx, connect.NewRequest(&yokov1.IndexRequest{ + SchemaSdl: sdl, + })) + if err != nil { + return "", err + } + + indexedSchemaID := resp.Msg.GetSchemaId() + c.cacheSchemaID(currentSDL, indexedSchemaID) + return indexedSchemaID, nil + }) + if err != nil { + return "", err + } + return value.(string), nil +} + +func (c *Client) search(ctx context.Context, schemaID string, sessionID string, prompts []string) (*yokov1.SearchResponse, error) { + resp, err := c.serviceClient.Search(ctx, connect.NewRequest(&yokov1.SearchRequest{ + Prompts: prompts, + SchemaId: schemaID, + SessionId: sessionID, + })) + if err != nil { + return nil, err + } + return resp.Msg, nil +} + +func (c *Client) schemaState() (string, string) { + c.schemaMu.RLock() + defer c.schemaMu.RUnlock() + return c.schemaSDL, c.schemaID +} + +func (c *Client) cacheSchemaID(sdl string, schemaID string) { + c.schemaMu.Lock() + defer c.schemaMu.Unlock() + if c.schemaSDL == sdl { + c.schemaID = schemaID + } +} + +func (c *Client) invalidateSchemaID(schemaID string) { + c.schemaMu.Lock() + defer c.schemaMu.Unlock() + if c.schemaID == schemaID { + c.schemaID = "" + } +} diff --git a/router/internal/codemode/yoko/client_test.go b/router/internal/codemode/yoko/client_test.go new file mode 100644 index 0000000000..136e5193c8 --- /dev/null +++ b/router/internal/codemode/yoko/client_test.go @@ -0,0 +1,434 @@ +package yoko + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "connectrpc.com/connect" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + yokov1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1" +) + +type fakeYokoServiceClient struct { + mu sync.Mutex + + indexRequests []*yokov1.IndexRequest + searchRequests []*yokov1.SearchRequest + + indexFunc func(context.Context, *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) + searchFunc func(context.Context, *connect.Request[yokov1.SearchRequest]) (*connect.Response[yokov1.SearchResponse], error) +} + +func (f *fakeYokoServiceClient) Index(ctx context.Context, req *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { + f.mu.Lock() + f.indexRequests = append(f.indexRequests, req.Msg) + indexFunc := f.indexFunc + f.mu.Unlock() + + if indexFunc != nil { + return indexFunc(ctx, req) + } + return connect.NewResponse(&yokov1.IndexResponse{SchemaId: "schema-1"}), nil +} + +func (f *fakeYokoServiceClient) Search(ctx context.Context, req *connect.Request[yokov1.SearchRequest]) (*connect.Response[yokov1.SearchResponse], error) { + f.mu.Lock() + f.searchRequests = append(f.searchRequests, req.Msg) + searchFunc := f.searchFunc + f.mu.Unlock() + + if searchFunc != nil { + return searchFunc(ctx, req) + } + return connect.NewResponse(searchResponse("op")), nil +} + +func (f *fakeYokoServiceClient) indexRequestMessages() []*yokov1.IndexRequest { + f.mu.Lock() + defer f.mu.Unlock() + return append([]*yokov1.IndexRequest(nil), f.indexRequests...) +} + +func (f *fakeYokoServiceClient) searchRequestMessages() []*yokov1.SearchRequest { + f.mu.Lock() + defer f.mu.Unlock() + return append([]*yokov1.SearchRequest(nil), f.searchRequests...) +} + +func newTestClient(fake *fakeYokoServiceClient) *Client { + client := New(nil, "http://yoko.example", nil, WithServiceClient(fake)) + client.SetSchema("type Query { product: Product }") + return client +} + +func searchResponse(name string) *yokov1.SearchResponse { + return &yokov1.SearchResponse{ + Operations: []*yokov1.GeneratedOperation{ + { + Name: name, + Body: "query " + name + " { product { id } }", + Kind: yokov1.OperationKind_OPERATION_KIND_QUERY, + Description: "Fetch product", + }, + }, + } +} + +func connectError(code connect.Code, message string) error { + return connect.NewError(code, errors.New(message)) +} + +func TestSearchFirstCallIndexesSchemaThenSearchesWithReturnedID(t *testing.T) { + fake := &fakeYokoServiceClient{ + indexFunc: func(context.Context, *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { + return connect.NewResponse(&yokov1.IndexResponse{SchemaId: "schema-from-yoko"}), nil + }, + searchFunc: func(context.Context, *connect.Request[yokov1.SearchRequest]) (*connect.Response[yokov1.SearchResponse], error) { + return connect.NewResponse(searchResponse("fromSearch")), nil + }, + } + client := newTestClient(fake) + + actual, err := client.Search(context.Background(), "session-1", []string{"find products"}) + + require.NoError(t, err) + require.Equal(t, searchResponse("fromSearch"), actual) + require.Equal(t, []*yokov1.IndexRequest{ + {SchemaSdl: "type Query { product: Product }"}, + }, fake.indexRequestMessages()) + require.Equal(t, []*yokov1.SearchRequest{ + { + Prompts: []string{"find products"}, + SchemaId: "schema-from-yoko", + SessionId: "session-1", + }, + }, fake.searchRequestMessages()) +} + +func TestSearchSubsequentCallUsesCachedSchemaID(t *testing.T) { + fake := &fakeYokoServiceClient{} + client := newTestClient(fake) + + first, firstErr := client.Search(context.Background(), "session-1", []string{"first"}) + second, secondErr := client.Search(context.Background(), "session-2", []string{"second"}) + + require.NoError(t, firstErr) + require.NoError(t, secondErr) + require.Equal(t, searchResponse("op"), first) + require.Equal(t, searchResponse("op"), second) + require.Equal(t, []*yokov1.IndexRequest{ + {SchemaSdl: "type Query { product: Product }"}, + }, fake.indexRequestMessages()) + require.Equal(t, []*yokov1.SearchRequest{ + { + Prompts: []string{"first"}, + SchemaId: "schema-1", + SessionId: "session-1", + }, + { + Prompts: []string{"second"}, + SchemaId: "schema-1", + SessionId: "session-2", + }, + }, fake.searchRequestMessages()) +} + +func TestSearchReindexesAndRetriesOnceAfterNotFound(t *testing.T) { + var searchCount int + fake := &fakeYokoServiceClient{} + indexIDs := []string{"schema-initial", "schema-reindexed"} + fake.indexFunc = func(context.Context, *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { + id := indexIDs[len(fake.indexRequestMessages())-1] + return connect.NewResponse(&yokov1.IndexResponse{SchemaId: id}), nil + } + fake.searchFunc = func(context.Context, *connect.Request[yokov1.SearchRequest]) (*connect.Response[yokov1.SearchResponse], error) { + searchCount++ + if searchCount == 1 { + return nil, connectError(connect.CodeNotFound, "schema evicted") + } + return connect.NewResponse(searchResponse("retried")), nil + } + client := newTestClient(fake) + + actual, err := client.Search(context.Background(), "session-1", []string{"find products"}) + + require.NoError(t, err) + require.Equal(t, searchResponse("retried"), actual) + require.Equal(t, []*yokov1.IndexRequest{ + {SchemaSdl: "type Query { product: Product }"}, + {SchemaSdl: "type Query { product: Product }"}, + }, fake.indexRequestMessages()) + require.Equal(t, []*yokov1.SearchRequest{ + { + Prompts: []string{"find products"}, + SchemaId: "schema-initial", + SessionId: "session-1", + }, + { + Prompts: []string{"find products"}, + SchemaId: "schema-reindexed", + SessionId: "session-1", + }, + }, fake.searchRequestMessages()) +} + +func TestSearchRetryFailureSurfacesErrorAndLeavesCacheEmpty(t *testing.T) { + retryErr := connectError(connect.CodeUnavailable, "retry transport down") + indexIDs := []string{"schema-initial", "schema-reindexed", "schema-after-failure"} + fake := &fakeYokoServiceClient{} + fake.indexFunc = func(context.Context, *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { + id := indexIDs[len(fake.indexRequestMessages())-1] + return connect.NewResponse(&yokov1.IndexResponse{SchemaId: id}), nil + } + searchErrors := []error{ + connectError(connect.CodeNotFound, "schema evicted"), + retryErr, + nil, + } + fake.searchFunc = func(context.Context, *connect.Request[yokov1.SearchRequest]) (*connect.Response[yokov1.SearchResponse], error) { + err := searchErrors[len(fake.searchRequestMessages())-1] + if err != nil { + return nil, err + } + return connect.NewResponse(searchResponse("afterFailure")), nil + } + client := newTestClient(fake) + + actual, err := client.Search(context.Background(), "session-1", []string{"find products"}) + + require.Nil(t, actual) + require.ErrorIs(t, err, retryErr) + + actualAfterFailure, errAfterFailure := client.Search(context.Background(), "session-2", []string{"find products again"}) + + require.NoError(t, errAfterFailure) + require.Equal(t, searchResponse("afterFailure"), actualAfterFailure) + require.Equal(t, []*yokov1.IndexRequest{ + {SchemaSdl: "type Query { product: Product }"}, + {SchemaSdl: "type Query { product: Product }"}, + {SchemaSdl: "type Query { product: Product }"}, + }, fake.indexRequestMessages()) + require.Equal(t, []*yokov1.SearchRequest{ + { + Prompts: []string{"find products"}, + SchemaId: "schema-initial", + SessionId: "session-1", + }, + { + Prompts: []string{"find products"}, + SchemaId: "schema-reindexed", + SessionId: "session-1", + }, + { + Prompts: []string{"find products again"}, + SchemaId: "schema-after-failure", + SessionId: "session-2", + }, + }, fake.searchRequestMessages()) +} + +func TestSearchRetryNotFoundSurfacesErrorAndLeavesCacheEmpty(t *testing.T) { + retryErr := connectError(connect.CodeNotFound, "schema evicted again") + indexIDs := []string{"schema-initial", "schema-reindexed", "schema-after-failure"} + fake := &fakeYokoServiceClient{} + fake.indexFunc = func(context.Context, *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { + id := indexIDs[len(fake.indexRequestMessages())-1] + return connect.NewResponse(&yokov1.IndexResponse{SchemaId: id}), nil + } + searchErrors := []error{ + connectError(connect.CodeNotFound, "schema evicted"), + retryErr, + nil, + } + fake.searchFunc = func(context.Context, *connect.Request[yokov1.SearchRequest]) (*connect.Response[yokov1.SearchResponse], error) { + err := searchErrors[len(fake.searchRequestMessages())-1] + if err != nil { + return nil, err + } + return connect.NewResponse(searchResponse("afterFailure")), nil + } + client := newTestClient(fake) + + actual, err := client.Search(context.Background(), "session-1", []string{"find products"}) + + require.Nil(t, actual) + require.ErrorIs(t, err, retryErr) + + actualAfterFailure, errAfterFailure := client.Search(context.Background(), "session-2", []string{"find products again"}) + + require.NoError(t, errAfterFailure) + require.Equal(t, searchResponse("afterFailure"), actualAfterFailure) + require.Equal(t, []*yokov1.IndexRequest{ + {SchemaSdl: "type Query { product: Product }"}, + {SchemaSdl: "type Query { product: Product }"}, + {SchemaSdl: "type Query { product: Product }"}, + }, fake.indexRequestMessages()) +} + +func TestSetSchemaInvalidatesCachedIDAndNextSearchReindexes(t *testing.T) { + indexIDs := []string{"schema-v1", "schema-v2"} + fake := &fakeYokoServiceClient{} + fake.indexFunc = func(context.Context, *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { + id := indexIDs[len(fake.indexRequestMessages())-1] + return connect.NewResponse(&yokov1.IndexResponse{SchemaId: id}), nil + } + client := newTestClient(fake) + + _, firstErr := client.Search(context.Background(), "session-1", []string{"first"}) + client.SetSchema("type Query { review: Review }") + _, secondErr := client.Search(context.Background(), "session-2", []string{"second"}) + + require.NoError(t, firstErr) + require.NoError(t, secondErr) + require.Equal(t, "type Query { review: Review }", client.Schema()) + require.Equal(t, []*yokov1.IndexRequest{ + {SchemaSdl: "type Query { product: Product }"}, + {SchemaSdl: "type Query { review: Review }"}, + }, fake.indexRequestMessages()) + require.Equal(t, []*yokov1.SearchRequest{ + { + Prompts: []string{"first"}, + SchemaId: "schema-v1", + SessionId: "session-1", + }, + { + Prompts: []string{"second"}, + SchemaId: "schema-v2", + SessionId: "session-2", + }, + }, fake.searchRequestMessages()) +} + +func TestConcurrentFirstSearchIndexesOnce(t *testing.T) { + indexStarted := make(chan struct{}) + releaseIndex := make(chan struct{}) + var indexStartedOnce sync.Once + fake := &fakeYokoServiceClient{ + indexFunc: func(context.Context, *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { + indexStartedOnce.Do(func() { + close(indexStarted) + }) + <-releaseIndex + return connect.NewResponse(&yokov1.IndexResponse{SchemaId: "schema-shared"}), nil + }, + } + client := newTestClient(fake) + + var wg sync.WaitGroup + wg.Add(2) + results := make([]*yokov1.SearchResponse, 2) + errs := make([]error, 2) + go func() { + defer wg.Done() + results[0], errs[0] = client.Search(context.Background(), "session-1", []string{"first"}) + }() + <-indexStarted + go func() { + defer wg.Done() + results[1], errs[1] = client.Search(context.Background(), "session-2", []string{"second"}) + }() + time.Sleep(25 * time.Millisecond) + close(releaseIndex) + wg.Wait() + + require.NoError(t, errs[0]) + require.NoError(t, errs[1]) + require.Equal(t, searchResponse("op"), results[0]) + require.Equal(t, searchResponse("op"), results[1]) + require.Equal(t, []*yokov1.IndexRequest{ + {SchemaSdl: "type Query { product: Product }"}, + }, fake.indexRequestMessages()) + assert.Equal(t, 2, len(fake.searchRequestMessages())) +} + +func TestConcurrentFirstSearchIndexFailureReturnsErrorToBothAndLeavesCacheEmpty(t *testing.T) { + indexErr := connectError(connect.CodeUnavailable, "index unavailable") + indexStarted := make(chan struct{}) + releaseIndex := make(chan struct{}) + var indexStartedOnce sync.Once + fake := &fakeYokoServiceClient{ + indexFunc: func(context.Context, *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { + indexStartedOnce.Do(func() { + close(indexStarted) + }) + <-releaseIndex + return nil, indexErr + }, + } + client := newTestClient(fake) + + var wg sync.WaitGroup + wg.Add(2) + results := make([]*yokov1.SearchResponse, 2) + errs := make([]error, 2) + go func() { + defer wg.Done() + results[0], errs[0] = client.Search(context.Background(), "session-1", []string{"first"}) + }() + <-indexStarted + go func() { + defer wg.Done() + results[1], errs[1] = client.Search(context.Background(), "session-2", []string{"second"}) + }() + time.Sleep(25 * time.Millisecond) + close(releaseIndex) + wg.Wait() + + require.Nil(t, results[0]) + require.Nil(t, results[1]) + require.ErrorIs(t, errs[0], indexErr) + require.ErrorIs(t, errs[1], indexErr) + require.Equal(t, []*yokov1.IndexRequest{ + {SchemaSdl: "type Query { product: Product }"}, + }, fake.indexRequestMessages()) + require.Equal(t, []*yokov1.SearchRequest(nil), fake.searchRequestMessages()) + + fake.indexFunc = func(context.Context, *connect.Request[yokov1.IndexRequest]) (*connect.Response[yokov1.IndexResponse], error) { + return connect.NewResponse(&yokov1.IndexResponse{SchemaId: "schema-after-error"}), nil + } + actual, err := client.Search(context.Background(), "session-3", []string{"third"}) + + require.NoError(t, err) + require.Equal(t, searchResponse("op"), actual) + require.Equal(t, []*yokov1.IndexRequest{ + {SchemaSdl: "type Query { product: Product }"}, + {SchemaSdl: "type Query { product: Product }"}, + }, fake.indexRequestMessages()) +} + +func TestSearchBubblesUpArbitraryConnectErrors(t *testing.T) { + searchErr := connectError(connect.CodeUnavailable, "search unavailable") + fake := &fakeYokoServiceClient{ + searchFunc: func(context.Context, *connect.Request[yokov1.SearchRequest]) (*connect.Response[yokov1.SearchResponse], error) { + return nil, searchErr + }, + } + client := newTestClient(fake) + + actual, err := client.Search(context.Background(), "session-1", []string{"find products"}) + + require.Nil(t, actual) + require.ErrorIs(t, err, searchErr) + require.Equal(t, []*yokov1.IndexRequest{ + {SchemaSdl: "type Query { product: Product }"}, + }, fake.indexRequestMessages()) + require.Equal(t, []*yokov1.SearchRequest{ + { + Prompts: []string{"find products"}, + SchemaId: "schema-1", + SessionId: "session-1", + }, + }, fake.searchRequestMessages()) +} + +func TestSchemaGetterReturnsCurrentSchema(t *testing.T) { + client := New(nil, "http://yoko.example", nil, WithServiceClient(&fakeYokoServiceClient{})) + + require.Equal(t, "", client.Schema()) + client.SetSchema("type Query { store: Store }") + require.Equal(t, "type Query { store: Store }", client.Schema()) +} diff --git a/router/internal/codemode/yoko/searcher.go b/router/internal/codemode/yoko/searcher.go new file mode 100644 index 0000000000..8f40ed8a72 --- /dev/null +++ b/router/internal/codemode/yoko/searcher.go @@ -0,0 +1,16 @@ +package yoko + +import ( + "context" + + yokov1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/code_mode/yoko/v1" +) + +type Searcher interface { + Search(ctx context.Context, sessionID string, prompts []string) (*yokov1.SearchResponse, error) + SetSchema(string) + Schema() string + EnsureIndexed(ctx context.Context) error +} + +var _ Searcher = (*Client)(nil) diff --git a/router/mcp.config.yaml b/router/mcp.config.yaml new file mode 100644 index 0000000000..a4726b577b --- /dev/null +++ b/router/mcp.config.yaml @@ -0,0 +1,82 @@ +# MCP (Model Context Protocol) multi-graph demo +# yaml-language-server: $schema=./pkg/config/config.schema.json +# +# Quick start: +# go run ./cmd/router -config mcp.config.yaml + +execution_config: + file: + path: "../router-tests/testenv/testdata/config.json" + watch: true + watch_interval: 20s + +mcp: + enabled: true + graph_name: "my-graph" + omit_tool_name_prefix: true + + servers: + - name: "anilist" + path: /anilist + expose_schema: true + enable_arbitrary_operations: true + omit_tool_name_prefix: true + storage: + provider_id: anilist + watch: true + watch_interval: 1s + upstream: + url: "https://graphql.anilist.co" + - name: "countries" + path: /countries + expose_schema: true + enable_arbitrary_operations: true + omit_tool_name_prefix: true + storage: + provider_id: countries + watch: true + watch_interval: 1s + upstream: + url: "https://countries.trevorblades.com" + - name: "rickandmorty" + path: /rickandmorty + expose_schema: true + enable_arbitrary_operations: true + omit_tool_name_prefix: true + storage: + provider_id: rickandmorty + watch: true + watch_interval: 1s + upstream: + url: "https://rickandmortyapi.com/graphql" + - name: "swapi" + path: /swapi + expose_schema: true + enable_arbitrary_operations: true + omit_tool_name_prefix: true + storage: + provider_id: swapi + watch: true + watch_interval: 1s + upstream: + url: "https://swapi-graphql.netlify.app/graphql" + + session: + stateless: true + + exclude_mutations: false + enable_arbitrary_operations: true + expose_schema: true + + router_url: "http://localhost:3002/graphql" + +storage_providers: + file_system: + - id: rickandmorty + path: /Users/asoorm/go/src/github.com/wundergraph/conference-mcp-demo/.operations/rickandmorty + - id: countries + path: /Users/asoorm/go/src/github.com/wundergraph/conference-mcp-demo/.operations/countries + - id: anilist + path: /Users/asoorm/go/src/github.com/wundergraph/conference-mcp-demo/.operations/anilist + - id: swapi + path: /Users/asoorm/go/src/github.com/wundergraph/conference-mcp-demo/.operations/swapi diff --git a/router/pkg/config/code_mode_config_test.go b/router/pkg/config/code_mode_config_test.go new file mode 100644 index 0000000000..839c0ab403 --- /dev/null +++ b/router/pkg/config/code_mode_config_test.go @@ -0,0 +1,278 @@ +package config + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMCPCodeModeConfigurationDefaults(t *testing.T) { + f := createTempFileFromFixture(t, ` +version: "1" +`) + + cfg, err := LoadConfig([]string{f}) + require.NoError(t, err) + + assert.Equal(t, MCPCodeModeConfiguration{ + Enabled: false, + Server: MCPCodeModeServerConfig{ListenAddr: "localhost:5027"}, + RequireMutationApproval: true, + ExecuteTimeout: 120 * time.Second, + MaxResultBytes: 32768, + Sandbox: MCPCodeModeSandboxConfig{ + Timeout: 5 * time.Second, + MaxMemoryMB: 16, + MaxInputSizeBytes: 65536, + MaxOutputSizeBytes: 1048576, + }, + QueryGeneration: MCPCodeModeQueryGenConfig{ + Enabled: false, + Endpoint: "", + Timeout: 10 * time.Second, + Auth: MCPCodeModeQueryGenAuthConfig{ + Type: "static", + StaticToken: "", + TokenEndpoint: "", + ClientID: "", + ClientSecret: "", + }, + }, + NamedOps: MCPCodeModeNamedOpsConfig{ + Enabled: false, + SessionTTL: 30 * time.Minute, + MaxSessions: 1000, + MaxBundleBytes: 262144, + Storage: MCPCodeModeNamedOpsStorageConfig{ + ProviderID: "", + KeyPrefix: "cosmo_code_mode", + }, + }, + }, cfg.Config.MCP.CodeMode) +} + +func TestMCPCodeModeConfigurationFullYAMLOverride(t *testing.T) { + f := createTempFileFromFixture(t, ` +version: "1" + +mcp: + session: + stateless: false + code_mode: + enabled: true + server: + listen_addr: "0.0.0.0:6027" + require_mutation_approval: false + execute_timeout: "45s" + max_result_bytes: 64000 + sandbox: + timeout: "7s" + max_memory_mb: 32 + max_input_size_bytes: 131072 + max_output_size_bytes: 2097152 + query_generation: + enabled: true + endpoint: "https://yoko.example.com" + timeout: "15s" + auth: + type: "jwt" + static_token: "unused-static" + token_endpoint: "https://auth.example.com/token" + client_id: "router-client" + client_secret: "router-secret" + named_ops: + enabled: true + session_ttl: "45m" + max_sessions: 2000 + max_bundle_bytes: 524288 + storage: + provider_id: "my_redis" + key_prefix: "custom_code_mode" +`) + + cfg, err := LoadConfig([]string{f}) + require.NoError(t, err) + + assert.Equal(t, MCPCodeModeConfiguration{ + Enabled: true, + Server: MCPCodeModeServerConfig{ListenAddr: "0.0.0.0:6027"}, + RequireMutationApproval: false, + ExecuteTimeout: 45 * time.Second, + MaxResultBytes: 64000, + Sandbox: MCPCodeModeSandboxConfig{ + Timeout: 7 * time.Second, + MaxMemoryMB: 32, + MaxInputSizeBytes: 131072, + MaxOutputSizeBytes: 2097152, + }, + QueryGeneration: MCPCodeModeQueryGenConfig{ + Enabled: true, + Endpoint: "https://yoko.example.com", + Timeout: 15 * time.Second, + Auth: MCPCodeModeQueryGenAuthConfig{ + Type: "jwt", + StaticToken: "unused-static", + TokenEndpoint: "https://auth.example.com/token", + ClientID: "router-client", + ClientSecret: "router-secret", + }, + }, + NamedOps: MCPCodeModeNamedOpsConfig{ + Enabled: true, + SessionTTL: 45 * time.Minute, + MaxSessions: 2000, + MaxBundleBytes: 524288, + Storage: MCPCodeModeNamedOpsStorageConfig{ + ProviderID: "my_redis", + KeyPrefix: "custom_code_mode", + }, + }, + }, cfg.Config.MCP.CodeMode) +} + +func TestMCPCodeModeConfigurationEnvOverride(t *testing.T) { + t.Setenv("MCP_CODE_MODE_ENABLED", "true") + t.Setenv("MCP_CODE_MODE_LISTEN_ADDR", "127.0.0.1:6027") + t.Setenv("MCP_CODE_MODE_REQUIRE_MUTATION_APPROVAL", "false") + t.Setenv("MCP_CODE_MODE_EXECUTE_TIMEOUT", "30s") + t.Setenv("MCP_CODE_MODE_MAX_RESULT_BYTES", "49152") + t.Setenv("MCP_CODE_MODE_SANDBOX_TIMEOUT", "8s") + t.Setenv("MCP_CODE_MODE_SANDBOX_MAX_MEMORY_MB", "64") + t.Setenv("MCP_CODE_MODE_SANDBOX_MAX_INPUT_SIZE_BYTES", "262144") + t.Setenv("MCP_CODE_MODE_SANDBOX_MAX_OUTPUT_SIZE_BYTES", "3145728") + t.Setenv("MCP_CODE_MODE_QUERY_GENERATION_ENABLED", "true") + t.Setenv("MCP_CODE_MODE_QUERY_GENERATION_ENDPOINT", "https://env-yoko.example.com") + t.Setenv("MCP_CODE_MODE_QUERY_GENERATION_TIMEOUT", "20s") + t.Setenv("MCP_CODE_MODE_QUERY_GENERATION_AUTH_TYPE", "jwt") + t.Setenv("MCP_CODE_MODE_QUERY_GENERATION_AUTH_STATIC_TOKEN", "env-static-token") + t.Setenv("MCP_CODE_MODE_QUERY_GENERATION_AUTH_TOKEN_ENDPOINT", "https://env-auth.example.com/token") + t.Setenv("MCP_CODE_MODE_QUERY_GENERATION_AUTH_CLIENT_ID", "env-client") + t.Setenv("MCP_CODE_MODE_QUERY_GENERATION_AUTH_CLIENT_SECRET", "env-secret") + t.Setenv("MCP_CODE_MODE_NAMED_OPS_ENABLED", "true") + t.Setenv("MCP_CODE_MODE_NAMED_OPS_SESSION_TTL", "1h") + t.Setenv("MCP_CODE_MODE_NAMED_OPS_MAX_SESSIONS", "3000") + t.Setenv("MCP_CODE_MODE_NAMED_OPS_MAX_BUNDLE_BYTES", "1048576") + t.Setenv("MCP_CODE_MODE_NAMED_OPS_STORAGE_PROVIDER_ID", "env_redis") + t.Setenv("MCP_CODE_MODE_NAMED_OPS_STORAGE_KEY_PREFIX", "env_code_mode") + + f := createTempFileFromFixture(t, ` +version: "1" + +mcp: + session: + stateless: false +`) + + cfg, err := LoadConfig([]string{f}) + require.NoError(t, err) + + assert.Equal(t, MCPCodeModeConfiguration{ + Enabled: true, + Server: MCPCodeModeServerConfig{ListenAddr: "127.0.0.1:6027"}, + RequireMutationApproval: false, + ExecuteTimeout: 30 * time.Second, + MaxResultBytes: 49152, + Sandbox: MCPCodeModeSandboxConfig{ + Timeout: 8 * time.Second, + MaxMemoryMB: 64, + MaxInputSizeBytes: 262144, + MaxOutputSizeBytes: 3145728, + }, + QueryGeneration: MCPCodeModeQueryGenConfig{ + Enabled: true, + Endpoint: "https://env-yoko.example.com", + Timeout: 20 * time.Second, + Auth: MCPCodeModeQueryGenAuthConfig{ + Type: "jwt", + StaticToken: "env-static-token", + TokenEndpoint: "https://env-auth.example.com/token", + ClientID: "env-client", + ClientSecret: "env-secret", + }, + }, + NamedOps: MCPCodeModeNamedOpsConfig{ + Enabled: true, + SessionTTL: time.Hour, + MaxSessions: 3000, + MaxBundleBytes: 1048576, + Storage: MCPCodeModeNamedOpsStorageConfig{ + ProviderID: "env_redis", + KeyPrefix: "env_code_mode", + }, + }, + }, cfg.Config.MCP.CodeMode) +} + +func TestValidateMCPCodeMode(t *testing.T) { + tests := []struct { + name string + cfg MCPCodeModeConfiguration + sessionStateless bool + wantErr string + }{ + { + name: "code mode disabled skips validation", + cfg: MCPCodeModeConfiguration{ + Enabled: false, + NamedOps: MCPCodeModeNamedOpsConfig{ + Enabled: true, + }, + }, + }, + { + name: "named ops disabled skips validation", + cfg: MCPCodeModeConfiguration{ + Enabled: true, + NamedOps: MCPCodeModeNamedOpsConfig{ + Enabled: false, + }, + }, + }, + { + name: "memory backend (no provider_id) is valid", + cfg: MCPCodeModeConfiguration{ + Enabled: true, + NamedOps: MCPCodeModeNamedOpsConfig{ + Enabled: true, + Storage: MCPCodeModeNamedOpsStorageConfig{KeyPrefix: "cosmo_code_mode"}, + }, + }, + }, + { + name: "redis-backed (provider_id set) is valid", + cfg: MCPCodeModeConfiguration{ + Enabled: true, + NamedOps: MCPCodeModeNamedOpsConfig{ + Enabled: true, + Storage: MCPCodeModeNamedOpsStorageConfig{ + ProviderID: "my_redis", + KeyPrefix: "cosmo_code_mode", + }, + }, + }, + }, + { + name: "stateless named ops does not fail boot validation", + cfg: MCPCodeModeConfiguration{ + Enabled: true, + NamedOps: MCPCodeModeNamedOpsConfig{ + Enabled: true, + }, + }, + sessionStateless: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateMCPCodeMode(&tt.cfg, tt.sessionStateless) + if tt.wantErr == "" { + require.NoError(t, err) + return + } + require.EqualError(t, err, tt.wantErr) + }) + } +} diff --git a/router/pkg/config/code_mode_validation.go b/router/pkg/config/code_mode_validation.go new file mode 100644 index 0000000000..5039ec9a4e --- /dev/null +++ b/router/pkg/config/code_mode_validation.go @@ -0,0 +1,23 @@ +package config + +func ValidateMCPCodeMode(cfg *MCPCodeModeConfiguration, sessionStateless bool) error { + if !cfg.Enabled { + return nil + } + + if !cfg.NamedOps.Enabled { + return nil + } + + // Storage backend selection: when ProviderID is set, the router resolves it + // against the central storage_providers registry (Redis backend). Otherwise + // the in-memory backend is used. The provider lookup error (unknown id) is + // emitted by the router at startup, not here. + + // Named ops require stateful MCP sessions to work, but this intentionally + // does not fail boot. The Code Mode runtime emits the warn log on first + // reload so deployments can enable Code Mode before flipping session mode. + _ = sessionStateless + + return nil +} diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index 22f58cf72a..82035691e8 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -1142,15 +1142,16 @@ type CacheWarmupConfiguration struct { } type MCPConfiguration struct { - Enabled bool `yaml:"enabled" envDefault:"false" env:"MCP_ENABLED"` - Server MCPServer `yaml:"server,omitempty"` - Storage MCPStorageConfig `yaml:"storage,omitempty"` - Session MCPSessionConfig `yaml:"session,omitempty"` - GraphName string `yaml:"graph_name" envDefault:"mygraph" env:"MCP_GRAPH_NAME"` - ExcludeMutations bool `yaml:"exclude_mutations" envDefault:"false" env:"MCP_EXCLUDE_MUTATIONS"` - EnableArbitraryOperations bool `yaml:"enable_arbitrary_operations" envDefault:"false" env:"MCP_ENABLE_ARBITRARY_OPERATIONS"` - ExposeSchema bool `yaml:"expose_schema" envDefault:"false" env:"MCP_EXPOSE_SCHEMA"` - RouterURL string `yaml:"router_url,omitempty" env:"MCP_ROUTER_URL"` + Enabled bool `yaml:"enabled" envDefault:"false" env:"MCP_ENABLED"` + Server MCPServer `yaml:"server,omitempty"` + Storage MCPStorageConfig `yaml:"storage,omitempty"` + Session MCPSessionConfig `yaml:"session,omitempty"` + CodeMode MCPCodeModeConfiguration `yaml:"code_mode,omitempty" envPrefix:"MCP_CODE_MODE_"` + GraphName string `yaml:"graph_name" envDefault:"mygraph" env:"MCP_GRAPH_NAME"` + ExcludeMutations bool `yaml:"exclude_mutations" envDefault:"false" env:"MCP_EXCLUDE_MUTATIONS"` + EnableArbitraryOperations bool `yaml:"enable_arbitrary_operations" envDefault:"false" env:"MCP_ENABLE_ARBITRARY_OPERATIONS"` + ExposeSchema bool `yaml:"expose_schema" envDefault:"false" env:"MCP_EXPOSE_SCHEMA"` + RouterURL string `yaml:"router_url,omitempty" env:"MCP_ROUTER_URL"` // OmitToolNamePrefix removes the "execute_operation_" prefix from MCP tool names. // When enabled, GetUser becomes get_user. When disabled (default), GetUser becomes execute_operation_get_user. OmitToolNamePrefix bool `yaml:"omit_tool_name_prefix" envDefault:"false" env:"MCP_OMIT_TOOL_NAME_PREFIX"` @@ -1158,6 +1159,105 @@ type MCPConfiguration struct { // ResourceDocumentation is a URL to a human-readable page describing this MCP resource, // its access policies, and how to get started. Included in RFC 9728 Protected Resource Metadata if set. ResourceDocumentation string `yaml:"resource_documentation,omitempty" env:"MCP_RESOURCE_DOCUMENTATION"` + + // Servers is the list of MCP servers ("collections") to expose on the shared HTTP listener. + // When empty, the legacy top-level fields above define a single implicit server at path "/mcp". + // When non-empty, each entry is fully self-described — no inheritance from top-level fields. + Servers []MCPServerEntry `yaml:"servers,omitempty"` +} + +// MCPServerEntry describes a single MCP server (collection) mounted at a path on the shared listener. +type MCPServerEntry struct { + // Name is a unique identifier for this server, used in metrics, logs, and tool descriptions. + Name string `yaml:"name"` + // Path is the URL path this server is mounted at (e.g. "/mcp", "/internal"). Must start with "/". + Path string `yaml:"path"` + // Storage references a file_system storage provider that holds operation files for this server. + Storage MCPStorageConfig `yaml:"storage,omitempty"` + // Upstream optionally overrides the local Cosmo supergraph as the GraphQL backend. + // When nil, the server routes to the local supergraph (default behavior). + Upstream *MCPUpstreamConfig `yaml:"upstream,omitempty"` + // OAuth configures authentication for this server; absent or enabled=false means no auth. + OAuth MCPOAuthConfiguration `yaml:"oauth,omitempty"` + // Per-server feature toggles (see MCPConfiguration for descriptions). + ExcludeMutations bool `yaml:"exclude_mutations"` + EnableArbitraryOperations bool `yaml:"enable_arbitrary_operations"` + ExposeSchema bool `yaml:"expose_schema"` + OmitToolNamePrefix bool `yaml:"omit_tool_name_prefix"` + Session MCPSessionConfig `yaml:"session,omitempty"` + ResourceDocumentation string `yaml:"resource_documentation,omitempty"` +} + +// MCPUpstreamConfig points an MCP server at a non-Cosmo GraphQL endpoint. +type MCPUpstreamConfig struct { + // URL is the GraphQL endpoint to which operations are forwarded. + URL string `yaml:"url"` + // Schema describes how to obtain the GraphQL schema for this upstream + // (used to compile operations into MCP tools). + Schema MCPUpstreamSchemaConfig `yaml:"schema,omitempty"` + // Headers are forwarded to the upstream on every request (in addition to per-request headers). + Headers map[string]string `yaml:"headers,omitempty"` +} + +// MCPUpstreamSchemaConfig describes the schema source for an upstream-bound collection. +type MCPUpstreamSchemaConfig struct { + // File is the path to an SDL file. If the file does not exist, the upstream is introspected + // at startup and the result is written to this path (introspection fallback is a v2 feature + // — for v1, the file must exist). + File string `yaml:"file,omitempty"` +} + +// NormalizeServers returns the canonical list of MCPServerEntry the router will mount. +// +// When Servers is non-empty, entries are returned as-is after uniqueness/format validation. +// When Servers is empty, a single implicit entry is synthesized from the legacy top-level +// fields and mounted at "/mcp" — preserving backwards compatibility with existing configs. +func (c *MCPConfiguration) NormalizeServers() ([]MCPServerEntry, error) { + if len(c.Servers) > 0 { + seenPaths := make(map[string]struct{}, len(c.Servers)) + seenNames := make(map[string]struct{}, len(c.Servers)) + for i := range c.Servers { + srv := &c.Servers[i] + if srv.Name == "" { + return nil, fmt.Errorf("mcp.servers[%d]: name is required", i) + } + if srv.Path == "" { + return nil, fmt.Errorf("mcp.servers[%d] (%q): path is required", i, srv.Name) + } + if !strings.HasPrefix(srv.Path, "/") { + return nil, fmt.Errorf("mcp.servers[%d] (%q): path must start with '/'", i, srv.Name) + } + if _, dup := seenPaths[srv.Path]; dup { + return nil, fmt.Errorf("mcp.servers: duplicate path %q", srv.Path) + } + seenPaths[srv.Path] = struct{}{} + if _, dup := seenNames[srv.Name]; dup { + return nil, fmt.Errorf("mcp.servers: duplicate name %q", srv.Name) + } + seenNames[srv.Name] = struct{}{} + if srv.Upstream != nil && srv.Upstream.URL == "" { + return nil, fmt.Errorf("mcp.servers[%d] (%q): upstream.url is required when upstream is set", i, srv.Name) + } + } + return c.Servers, nil + } + + name := c.GraphName + if name == "" { + name = "mygraph" + } + return []MCPServerEntry{{ + Name: name, + Path: "/mcp", + Storage: c.Storage, + OAuth: c.OAuth, + ExcludeMutations: c.ExcludeMutations, + EnableArbitraryOperations: c.EnableArbitraryOperations, + ExposeSchema: c.ExposeSchema, + OmitToolNamePrefix: c.OmitToolNamePrefix, + Session: c.Session, + ResourceDocumentation: c.ResourceDocumentation, + }}, nil } type MCPOAuthConfiguration struct { @@ -1203,8 +1303,65 @@ type MCPSessionConfig struct { Stateless bool `yaml:"stateless" envDefault:"true" env:"MCP_SESSION_STATELESS"` } +type MCPCodeModeConfiguration struct { + Enabled bool `yaml:"enabled" envDefault:"false" env:"ENABLED"` + Server MCPCodeModeServerConfig `yaml:"server,omitempty" envPrefix:""` + RequireMutationApproval bool `yaml:"require_mutation_approval" envDefault:"true" env:"REQUIRE_MUTATION_APPROVAL"` + ExecuteTimeout time.Duration `yaml:"execute_timeout" envDefault:"120s" env:"EXECUTE_TIMEOUT"` + MaxResultBytes int `yaml:"max_result_bytes" envDefault:"32768" env:"MAX_RESULT_BYTES"` + Sandbox MCPCodeModeSandboxConfig `yaml:"sandbox,omitempty" envPrefix:"SANDBOX_"` + QueryGeneration MCPCodeModeQueryGenConfig `yaml:"query_generation,omitempty" envPrefix:"QUERY_GENERATION_"` + NamedOps MCPCodeModeNamedOpsConfig `yaml:"named_ops,omitempty" envPrefix:"NAMED_OPS_"` +} + +type MCPCodeModeServerConfig struct { + ListenAddr string `yaml:"listen_addr" envDefault:"localhost:5027" env:"LISTEN_ADDR"` +} + +type MCPCodeModeSandboxConfig struct { + Timeout time.Duration `yaml:"timeout" envDefault:"5s" env:"TIMEOUT"` + MaxMemoryMB int `yaml:"max_memory_mb" envDefault:"16" env:"MAX_MEMORY_MB"` + MaxInputSizeBytes int `yaml:"max_input_size_bytes" envDefault:"65536" env:"MAX_INPUT_SIZE_BYTES"` + MaxOutputSizeBytes int `yaml:"max_output_size_bytes" envDefault:"1048576" env:"MAX_OUTPUT_SIZE_BYTES"` +} + +type MCPCodeModeQueryGenConfig struct { + Enabled bool `yaml:"enabled" envDefault:"false" env:"ENABLED"` + Endpoint string `yaml:"endpoint,omitempty" env:"ENDPOINT"` + Timeout time.Duration `yaml:"timeout" envDefault:"10s" env:"TIMEOUT"` + Auth MCPCodeModeQueryGenAuthConfig `yaml:"auth,omitempty" envPrefix:"AUTH_"` +} + +type MCPCodeModeQueryGenAuthConfig struct { + Type string `yaml:"type" envDefault:"static" env:"TYPE"` + StaticToken string `yaml:"static_token,omitempty" env:"STATIC_TOKEN"` + TokenEndpoint string `yaml:"token_endpoint,omitempty" env:"TOKEN_ENDPOINT"` + ClientID string `yaml:"client_id,omitempty" env:"CLIENT_ID"` + ClientSecret string `yaml:"client_secret,omitempty" env:"CLIENT_SECRET"` +} + +type MCPCodeModeNamedOpsConfig struct { + Enabled bool `yaml:"enabled" envDefault:"false" env:"ENABLED"` + SessionTTL time.Duration `yaml:"session_ttl" envDefault:"30m" env:"SESSION_TTL"` + MaxSessions int `yaml:"max_sessions" envDefault:"1000" env:"MAX_SESSIONS"` + MaxBundleBytes int `yaml:"max_bundle_bytes" envDefault:"262144" env:"MAX_BUNDLE_BYTES"` + Storage MCPCodeModeNamedOpsStorageConfig `yaml:"storage,omitempty" envPrefix:"STORAGE_"` +} + +type MCPCodeModeNamedOpsStorageConfig struct { + ProviderID string `yaml:"provider_id,omitempty" env:"PROVIDER_ID"` + KeyPrefix string `yaml:"key_prefix" envDefault:"cosmo_code_mode" env:"KEY_PREFIX"` +} + type MCPStorageConfig struct { ProviderID string `yaml:"provider_id,omitempty" env:"MCP_STORAGE_PROVIDER_ID"` + // Watch enables periodic scanning of the storage provider directory for added, + // modified, or removed operation files. When a change is detected, the MCP + // collection's tools are reloaded without restarting the router. + // Currently only supported when the referenced provider is a file_system provider. + Watch bool `yaml:"watch,omitempty" envDefault:"false" env:"MCP_STORAGE_WATCH"` + // WatchInterval is the polling interval used when Watch is true. Defaults to 1s. + WatchInterval time.Duration `yaml:"watch_interval,omitempty" envDefault:"1s" env:"MCP_STORAGE_WATCH_INTERVAL"` } type MCPServer struct { @@ -1462,5 +1619,9 @@ func LoadConfig(configFilePaths []string) (*LoadResult, error) { cfg.Config.SubgraphErrorPropagation.AllowedExtensionFields = unique.SliceElements(append(cfg.Config.SubgraphErrorPropagation.AllowedExtensionFields, "code", "stacktrace")) } + if err := ValidateMCPCodeMode(&cfg.Config.MCP.CodeMode, cfg.Config.MCP.Session.Stateless); err != nil { + return nil, err + } + return cfg, nil } diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index e90ae50407..b89da14be0 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -2396,6 +2396,150 @@ } } }, + "code_mode": { + "type": "object", + "description": "Configuration for the Code Mode MCP server surface.", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "default": false + }, + "server": { + "type": "object", + "additionalProperties": false, + "properties": { + "listen_addr": { + "type": "string", + "default": "localhost:5027", + "format": "hostname-port" + } + } + }, + "require_mutation_approval": { + "type": "boolean", + "default": true + }, + "execute_timeout": { + "type": "string", + "default": "120s", + "duration": { + "minimum": "0s" + } + }, + "max_result_bytes": { + "type": "integer", + "default": 32768 + }, + "sandbox": { + "type": "object", + "additionalProperties": false, + "properties": { + "timeout": { + "type": "string", + "default": "5s", + "duration": { + "minimum": "0s" + } + }, + "max_memory_mb": { + "type": "integer", + "default": 16 + }, + "max_input_size_bytes": { + "type": "integer", + "default": 65536 + }, + "max_output_size_bytes": { + "type": "integer", + "default": 1048576 + } + } + }, + "query_generation": { + "type": "object", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "default": false + }, + "endpoint": { + "type": "string" + }, + "timeout": { + "type": "string", + "default": "10s", + "duration": { + "minimum": "0s" + } + }, + "auth": { + "type": "object", + "additionalProperties": false, + "properties": { + "type": { + "type": "string", + "default": "static" + }, + "static_token": { + "type": "string" + }, + "token_endpoint": { + "type": "string" + }, + "client_id": { + "type": "string" + }, + "client_secret": { + "type": "string" + } + } + } + } + }, + "named_ops": { + "type": "object", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "default": false + }, + "session_ttl": { + "type": "string", + "default": "30m", + "duration": { + "minimum": "0s" + } + }, + "max_sessions": { + "type": "integer", + "default": 1000 + }, + "max_bundle_bytes": { + "type": "integer", + "default": 262144 + }, + "storage": { + "type": "object", + "additionalProperties": false, + "properties": { + "provider_id": { + "type": "string", + "description": "ID of an entry in storage_providers.redis used to back named ops. When unset, the in-memory backend is used." + }, + "key_prefix": { + "type": "string", + "default": "cosmo_code_mode", + "description": "Key prefix applied to all named-ops keys written to the Redis storage provider." + } + } + } + } + } + } + }, "graph_name": { "type": "string", "default": "mygraph", @@ -2521,11 +2665,7 @@ "default": ["sig"], "items": { "type": "string", - "enum": [ - "sig", - "enc", - "" - ] + "enum": ["sig", "enc", ""] } }, "algorithms": { @@ -2633,6 +2773,87 @@ } } } + }, + "servers": { + "type": "array", + "description": "List of MCP servers (collections) mounted on the shared HTTP listener. When empty, the legacy top-level fields define a single implicit server at path '/mcp'. When non-empty, each entry is fully self-described — no inheritance from top-level fields.", + "items": { + "type": "object", + "additionalProperties": false, + "required": ["name", "path"], + "properties": { + "name": { + "type": "string", + "description": "Unique identifier for this server, used in metrics, logs, and tool descriptions." + }, + "path": { + "type": "string", + "description": "URL path this server is mounted at (e.g. '/mcp', '/internal'). Must start with '/' and be unique across servers." + }, + "storage": { + "type": "object", + "additionalProperties": false, + "description": "Storage provider for this server's operation files.", + "properties": { + "provider_id": { + "type": "string", + "description": "ID of a configured file_system storage provider." + }, + "watch": { + "type": "boolean", + "default": false, + "description": "Periodically scan the storage provider directory for changes and hot-reload the collection's tools when files are added, modified, or removed. Currently only honored for file_system providers." + }, + "watch_interval": { + "type": "string", + "default": "1s", + "description": "Polling interval used when watch is true (e.g. '500ms', '1s', '10s')." + } + }, + "required": ["provider_id"] + }, + "upstream": { + "type": "object", + "additionalProperties": false, + "description": "Override the local Cosmo supergraph as the GraphQL backend for this server. When omitted, this server routes to the local supergraph.", + "required": ["url"], + "properties": { + "url": { + "type": "string", + "description": "GraphQL endpoint URL to which operations are forwarded.", + "format": "http-url" + }, + "schema": { + "type": "object", + "additionalProperties": false, + "description": "Schema source for this upstream (used to compile operations into MCP tools).", + "properties": { + "file": { + "type": "string", + "description": "Path to an SDL file containing the upstream's GraphQL schema." + } + } + }, + "headers": { + "type": "object", + "additionalProperties": { "type": "string" }, + "description": "Headers forwarded to the upstream on every request." + } + } + }, + "oauth": { "$ref": "#/properties/mcp/properties/oauth" }, + "session": { "$ref": "#/properties/mcp/properties/session" }, + "exclude_mutations": { "type": "boolean", "default": false }, + "enable_arbitrary_operations": { "type": "boolean", "default": false }, + "expose_schema": { "type": "boolean", "default": false }, + "omit_tool_name_prefix": { "type": "boolean", "default": false }, + "resource_documentation": { + "type": "string", + "description": "URL to a human-readable page describing this MCP resource. Included in RFC 9728 Protected Resource Metadata if set.", + "format": "http-url" + } + } + } } }, "if": { diff --git a/router/pkg/config/fixtures/full.yaml b/router/pkg/config/fixtures/full.yaml index 42c6986d38..391d0838a0 100644 --- a/router/pkg/config/fixtures/full.yaml +++ b/router/pkg/config/fixtures/full.yaml @@ -64,11 +64,43 @@ mcp: omit_tool_name_prefix: false graph_name: cosmo router_url: https://cosmo-router.wundergraph.com + session: + stateless: false server: listen_addr: localhost:5025 base_url: 'http://localhost:5025' storage: provider_id: mcp + code_mode: + enabled: true + server: + listen_addr: localhost:6027 + require_mutation_approval: false + execute_timeout: 45s + max_result_bytes: 64000 + sandbox: + timeout: 7s + max_memory_mb: 32 + max_input_size_bytes: 131072 + max_output_size_bytes: 2097152 + query_generation: + enabled: true + endpoint: https://yoko.example.com + timeout: 15s + auth: + type: jwt + static_token: static-token + token_endpoint: https://auth.example.com/token + client_id: router-client + client_secret: router-secret + named_ops: + enabled: true + session_ttl: 45m + max_sessions: 2000 + max_bundle_bytes: 524288 + storage: + provider_id: my_redis + key_prefix: custom_code_mode watch_config: enabled: true diff --git a/router/pkg/config/testdata/config_defaults.json b/router/pkg/config/testdata/config_defaults.json index 8dc81bf6ed..a14830dd64 100644 --- a/router/pkg/config/testdata/config_defaults.json +++ b/router/pkg/config/testdata/config_defaults.json @@ -154,11 +154,50 @@ "BaseURL": "" }, "Storage": { - "ProviderID": "" + "ProviderID": "", + "Watch": false, + "WatchInterval": 1000000000 }, "Session": { "Stateless": true }, + "CodeMode": { + "Enabled": false, + "Server": { + "ListenAddr": "localhost:5027" + }, + "RequireMutationApproval": true, + "ExecuteTimeout": 120000000000, + "MaxResultBytes": 32768, + "Sandbox": { + "Timeout": 5000000000, + "MaxMemoryMB": 16, + "MaxInputSizeBytes": 65536, + "MaxOutputSizeBytes": 1048576 + }, + "QueryGeneration": { + "Enabled": false, + "Endpoint": "", + "Timeout": 10000000000, + "Auth": { + "Type": "static", + "StaticToken": "", + "TokenEndpoint": "", + "ClientID": "", + "ClientSecret": "" + } + }, + "NamedOps": { + "Enabled": false, + "SessionTTL": 1800000000000, + "MaxSessions": 1000, + "MaxBundleBytes": 262144, + "Storage": { + "ProviderID": "", + "KeyPrefix": "cosmo_code_mode" + } + } + }, "GraphName": "mygraph", "ExcludeMutations": false, "EnableArbitraryOperations": false, @@ -180,7 +219,8 @@ "ScopeChallengeIncludeTokenScopes": false, "MaxScopeCombinations": 2048 }, - "ResourceDocumentation": "" + "ResourceDocumentation": "", + "Servers": null }, "ConnectRPC": { "Enabled": false, diff --git a/router/pkg/config/testdata/config_full.json b/router/pkg/config/testdata/config_full.json index 33cc0c92e6..6d0e304557 100644 --- a/router/pkg/config/testdata/config_full.json +++ b/router/pkg/config/testdata/config_full.json @@ -199,10 +199,49 @@ "BaseURL": "http://localhost:5025" }, "Storage": { - "ProviderID": "mcp" + "ProviderID": "mcp", + "Watch": false, + "WatchInterval": 1000000000 }, "Session": { - "Stateless": true + "Stateless": false + }, + "CodeMode": { + "Enabled": true, + "Server": { + "ListenAddr": "localhost:6027" + }, + "RequireMutationApproval": false, + "ExecuteTimeout": 45000000000, + "MaxResultBytes": 64000, + "Sandbox": { + "Timeout": 7000000000, + "MaxMemoryMB": 32, + "MaxInputSizeBytes": 131072, + "MaxOutputSizeBytes": 2097152 + }, + "QueryGeneration": { + "Enabled": true, + "Endpoint": "https://yoko.example.com", + "Timeout": 15000000000, + "Auth": { + "Type": "jwt", + "StaticToken": "static-token", + "TokenEndpoint": "https://auth.example.com/token", + "ClientID": "router-client", + "ClientSecret": "router-secret" + } + }, + "NamedOps": { + "Enabled": true, + "SessionTTL": 2700000000000, + "MaxSessions": 2000, + "MaxBundleBytes": 524288, + "Storage": { + "ProviderID": "my_redis", + "KeyPrefix": "custom_code_mode" + } + } }, "GraphName": "cosmo", "ExcludeMutations": false, @@ -225,7 +264,8 @@ "ScopeChallengeIncludeTokenScopes": false, "MaxScopeCombinations": 2048 }, - "ResourceDocumentation": "" + "ResourceDocumentation": "", + "Servers": null }, "ConnectRPC": { "Enabled": false, diff --git a/router/pkg/mcpserver/code_mode.go b/router/pkg/mcpserver/code_mode.go new file mode 100644 index 0000000000..d97f9a5322 --- /dev/null +++ b/router/pkg/mcpserver/code_mode.go @@ -0,0 +1,279 @@ +package mcpserver + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "go.uber.org/zap" + + "github.com/wundergraph/cosmo/router/internal/codemode/harness" + "github.com/wundergraph/cosmo/router/internal/codemode/sandbox" + "github.com/wundergraph/cosmo/router/internal/codemode/storage" +) + +const codeModeToolName = "code_mode_run_js" + +// codeModeToolDescription returns the LLM-facing description of code_mode_run_js. +// opNames are the operations bound as tools.(vars). Variable shapes and +// return types are NOT duplicated here — they live on the per-op MCP tools the +// model already sees alongside this one. We only list bound names + the rules +// the model can't infer (sandbox shape, restrictions, return-value discipline). +func codeModeToolDescription(opNames []string) string { + body := `Run an async arrow function in a V8 sandbox where every operation tool on this MCP server is pre-bound as ` + "`tools.(vars)`" + `. Compose, batch, or aggregate multiple ops in ONE call so the calling model only ever sees the final answer — not the raw payloads. + +# Shape (strict) + +` + "`source`" + ` MUST be exactly one expression: an async arrow function. The harness invokes it for you. + +CORRECT: +` + "```" + `js +async () => { + const r = await tools.SomeOp({ id: 1 }); + return r.data; +} +` + "```" + ` + +WRONG — these all fail with ShapeCheck or TranspileError: +- top-level await: ` + "`const r = await tools.X(); return r;`" + ` +- IIFE: ` + "`(async () => {...})()`" + ` +- non-arrow root: ` + "`function main(){...}`" + `, ` + "`tools.X().then(...)`" + ` +- multiple statements: ` + "`const x = 1; async () => x;`" + ` +- import/export at top: ` + "`import x from 'y'; async () => {}`" + ` + +# Tool bindings — refer to the per-op MCP tools + +Every operation tool you see on this server is also bound inside the sandbox: + + tools.(vars) → Promise<{ data, errors? }> + +` + "`vars`" + ` matches the per-op MCP tool's ` + "`inputSchema`" + ` exactly. ` + "`data`" + ` is the GraphQL response data shape — call ` + "`get_operation_info({operationName: ''})`" + ` directly (outside code_mode) if you need to see the full query body. + +DO NOT GUESS NAMES — only the ones below are bound. + +` + buildBoundList(opNames) + ` + +# Helpers in scope + +- ` + "`notNull(value, msg?)`" + ` — throws if null/undefined. +- ` + "`compact(value)`" + ` — recursively strips null/undefined from objects/arrays. +- ` + "`Promise.all([...])`" + ` for parallel calls. +- Standard JS array methods, destructuring, optional chaining. + +# Sandbox restrictions + +- No ` + "`console`" + ` (throws ConsoleUnavailable). Return diagnostics in the result instead. +- ` + "`Date.now()` returns 0, `Math.random()` returns 0" + ` — pinned for determinism. +- No ` + "`eval` / `Function` / `import` / `require`" + `; no arbitrary HTTP. +- ~256 ` + "`tools.*`" + ` calls per execution; ~64 KB result cap. + +# Output envelope + +Success: ` + "`{ \"result\": , \"truncated\": false, \"warnings\": [] }`" + ` +Failure: ` + "`{ \"result\": null, \"error\": { \"name\", \"message\", \"stack\" } }`" + ` + +Common error names: ShapeCheck, TranspileError, InputTooLarge, HostCallLimitExceeded, ConsoleUnavailable. + +# Return-value discipline (the whole point of this tool) + +Every byte you return is context the calling model pays tokens for on the next turn. **Collapse data INSIDE the sandbox.** + +1. Return the answer, not the data. "How many" → a number, not the list. +2. Don't enrich helpfully. No names, IDs, samples, or metadata the user didn't ask for. +3. Aggregate before returning. Count/sum/group/filter in JS, not in the model. +4. If a list IS the answer, project to only the fields needed: ` + "`{ id, name }`" + ` not the full object. +5. One op, no transform → use the per-op tool directly. code_mode is overhead unless you're transforming. + +**Worked example — "how many countries speak Spanish":** + +GOOD: +` + "```" + `js +async () => { + const r = await tools.GetLanguage({ code: 'es' }); + return r.data?.language?.countries?.length ?? 0; +} +// → 31 +` + "```" + ` + +BAD (every country name leaks back into the model's context): +` + "```" + `js +async () => { + const r = await tools.GetLanguage({ code: 'es' }); + return { + languageName: r.data?.language?.name, // not asked for + countryCount: r.data?.language?.countries?.length, + countries: r.data?.language?.countries.map(c => c.name), // pollution + }; +} +// → { languageName: "Spanish", countryCount: 31, countries: [...31 strings...] } +` + "```" + ` + +If the user asks a follow-up like "name them", that's a separate code_mode call. Don't pre-fetch.` + + return body +} + +// buildBoundList renders the list of bound operation names. We surface only +// names — variable/return shapes are already in the per-op MCP tools' own +// descriptions and inputSchemas, which the model sees alongside this one. +func buildBoundList(names []string) string { + if len(names) == 0 { + return "Bound names: (none — no operations are loaded for this server)" + } + var b strings.Builder + b.WriteString("Bound names:\n") + for _, n := range names { + b.WriteString(" - tools.") + b.WriteString(n) + b.WriteString("\n") + } + return b.String() +} + +// codeModeRunJSInput is the JSON input schema for the code_mode_run_js tool. +type codeModeRunJSInput struct { + Source string `json:"source"` +} + +// ensureCodeModeSandbox lazily creates the V8 sandbox bound to this server's +// upstream GraphQL endpoint. The sandbox is reused across reloads — only the +// op catalog (looked up via StorageLookup) needs to refresh. +func (s *GraphQLSchemaServer) ensureCodeModeSandbox() (*sandbox.Sandbox, error) { + if s.codeModeSandbox != nil { + return s.codeModeSandbox, nil + } + sb, err := sandbox.New(sandbox.Config{ + RouterGraphQLEndpoint: s.routerGraphQLEndpoint, + StorageLookup: s.codeModeStorageLookup, + Logger: s.logger.With(zap.String("component", "code_mode_sandbox")), + }) + if err != nil { + return nil, fmt.Errorf("create code mode sandbox: %w", err) + } + s.codeModeSandbox = sb + return sb, nil +} + +// codeModeStorageLookup adapts the file-loaded operation catalog to the +// storage.SessionOp shape expected by the sandbox host. SessionID is ignored — +// each server has a single, file-driven catalog shared across all calls. +func (s *GraphQLSchemaServer) codeModeStorageLookup(_ context.Context, _ string, name string) (storage.SessionOp, bool, error) { + if s.operationsManager == nil { + return storage.SessionOp{}, false, nil + } + op := s.operationsManager.GetOperation(name) + if op == nil { + return storage.SessionOp{}, false, nil + } + return storage.SessionOp{ + Name: op.Name, + Body: op.OperationString, + Kind: operationKindFromType(op.OperationType), + Description: op.Description, + }, true, nil +} + +func operationKindFromType(opType string) storage.OperationKind { + if opType == "mutation" { + return storage.OperationKindMutation + } + return storage.OperationKindQuery +} + +// codeModeToolDescriptor builds the code_mode_run_js MCP tool. Variable shapes +// and return types are NOT inlined here — the model sees those on the per-op +// MCP tools that share this server. Only the bound names + sandbox/usage rules +// are described to keep this tool's description lean. +func (s *GraphQLSchemaServer) codeModeToolDescriptor() *mcp.Tool { + ops := s.operationsManager.GetFilteredOperations() + names := make([]string, 0, len(ops)) + for _, op := range ops { + names = append(names, op.Name) + } + + desc := codeModeToolDescription(names) + + inputSchema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "source": map[string]any{ + "type": "string", + "description": "A single async arrow function expression. MUST be of the form `async () => { ... }`. Do NOT invoke it (no trailing `()`), do NOT use IIFE wrappers like `(async () => {...})()`, do NOT use top-level await, do NOT use multiple statements. The harness invokes the arrow.", + }, + }, + "required": []string{"source"}, + "additionalProperties": false, + } + + return &mcp.Tool{ + Name: codeModeToolName, + Description: desc, + InputSchema: inputSchema, + Annotations: &mcp.ToolAnnotations{ + Title: "Code Mode (compose ops in JS)", + }, + } +} + +// handleCodeModeRunJS executes user-supplied JS/TS in a V8 sandbox where this +// server's loaded GraphQL operations are bound as tools.(vars). +func (s *GraphQLSchemaServer) handleCodeModeRunJS() func(ctx context.Context, request *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + var input codeModeRunJSInput + if req != nil && req.Params != nil && len(req.Params.Arguments) > 0 { + if err := json.Unmarshal(req.Params.Arguments, &input); err != nil { + return codeModeErrorResult("source must be a string: " + err.Error()), nil + } + } + if input.Source == "" { + return codeModeErrorResult("source is required"), nil + } + + sb, err := s.ensureCodeModeSandbox() + if err != nil { + return codeModeErrorResult(err.Error()), nil + } + + ops := s.operationsManager.GetFilteredOperations() + names := make([]string, 0, len(ops)) + for _, op := range ops { + names = append(names, op.Name) + } + + pipeline := &harness.Pipeline{ + Sandbox: sb, + MaxInputBytes: 64 * 1024, + MaxResultBytes: 64 * 1024, + } + + var headers http.Header + if h, hErr := headersFromContext(ctx); hErr == nil { + headers = h + } + + resp, err := pipeline.Execute(ctx, harness.PipelineRequest{ + SessionID: s.graphName, + ToolNames: names, + Source: input.Source, + RequestHeaders: headers, + ApprovalGate: sandbox.AutoApprove, + }) + if err != nil { + return codeModeErrorResult("execute failed: " + err.Error()), nil + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: string(resp.Encoded)}}, + }, nil + } +} + +func codeModeErrorResult(msg string) *mcp.CallToolResult { + return &mcp.CallToolResult{ + IsError: true, + Content: []mcp.Content{&mcp.TextContent{Text: "code_mode_run_js: " + msg}}, + } +} diff --git a/router/pkg/mcpserver/introspection.go b/router/pkg/mcpserver/introspection.go new file mode 100644 index 0000000000..78250f1505 --- /dev/null +++ b/router/pkg/mcpserver/introspection.go @@ -0,0 +1,153 @@ +package mcpserver + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/astprinter" + "github.com/wundergraph/graphql-go-tools/v2/pkg/introspection" +) + +// introspectionQuery is the standard GraphQL introspection query as defined in the +// graphql-spec. Servers that support introspection respond with the full schema. +const introspectionQuery = `query IntrospectionQuery { + __schema { + queryType { name } + mutationType { name } + subscriptionType { name } + types { ...FullType } + directives { + name description locations + args { ...InputValue } + } + } +} +fragment FullType on __Type { + kind name description + fields(includeDeprecated: true) { + name description + args { ...InputValue } + type { ...TypeRef } + isDeprecated deprecationReason + } + inputFields { ...InputValue } + interfaces { ...TypeRef } + enumValues(includeDeprecated: true) { + name description isDeprecated deprecationReason + } + possibleTypes { ...TypeRef } +} +fragment InputValue on __InputValue { + name description + type { ...TypeRef } + defaultValue +} +fragment TypeRef on __Type { + kind name + ofType { + kind name + ofType { + kind name + ofType { + kind name + ofType { + kind name + ofType { + kind name + ofType { + kind name + ofType { kind name } + } + } + } + } + } + } +}` + +// IntrospectUpstreamSDL runs the standard GraphQL introspection query against the given +// upstream URL and returns the result as SDL text. Extra headers are sent on the request. +// +// Used by upstream-bound MCP collections when no SDL file is provided — the schema is +// fetched from the live upstream and (optionally) cached to disk for subsequent runs. +func IntrospectUpstreamSDL(ctx context.Context, url string, extraHeaders map[string]string) (string, error) { + body, err := json.Marshal(struct { + Query string `json:"query"` + OperationName string `json:"operationName"` + }{ + Query: introspectionQuery, + OperationName: "IntrospectionQuery", + }) + if err != nil { + return "", fmt.Errorf("encode introspection query: %w", err) + } + + reqCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return "", fmt.Errorf("build introspection request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + for k, v := range extraHeaders { + req.Header.Set(k, v) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", fmt.Errorf("introspect %s: %w", url, err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode/100 != 2 { + raw, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("introspect %s: status %d: %s", url, resp.StatusCode, truncate(string(raw), 256)) + } + + var envelope struct { + Data json.RawMessage `json:"data"` + Errors []struct { + Message string `json:"message"` + } `json:"errors"` + } + if err := json.NewDecoder(resp.Body).Decode(&envelope); err != nil { + return "", fmt.Errorf("decode introspection response: %w", err) + } + if len(envelope.Errors) > 0 { + msgs := make([]string, 0, len(envelope.Errors)) + for _, e := range envelope.Errors { + msgs = append(msgs, e.Message) + } + return "", fmt.Errorf("upstream returned introspection errors: %s", strings.Join(msgs, "; ")) + } + if len(envelope.Data) == 0 { + return "", fmt.Errorf("upstream returned no introspection data (introspection may be disabled)") + } + + conv := &introspection.JsonConverter{} + doc, err := conv.GraphQLDocument(bytes.NewReader(envelope.Data)) + if err != nil { + return "", fmt.Errorf("convert introspection JSON to schema: %w", err) + } + + sdl, err := astprinter.PrintString(doc) + if err != nil { + return "", fmt.Errorf("print SDL: %w", err) + } + return sdl, nil +} + +func truncate(s string, n int) string { + if len(s) <= n { + return s + } + return s[:n] + "..." +} \ No newline at end of file diff --git a/router/pkg/mcpserver/multi_server.go b/router/pkg/mcpserver/multi_server.go new file mode 100644 index 0000000000..7a3dc365af --- /dev/null +++ b/router/pkg/mcpserver/multi_server.go @@ -0,0 +1,179 @@ +package mcpserver + +import ( + "context" + "errors" + "fmt" + "net/http" + "time" + + "go.uber.org/zap" + + nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" +) + +// MultiServer hosts multiple MCP collections on a single shared HTTP listener. +// Each collection has its own URL path, operations directory, optional OAuth policy, +// and optional upstream override. Use NewMultiServer to construct, Start to bind the +// listener, Reload to fan out supergraph reloads, and Stop to shut down gracefully. +type MultiServer struct { + listenAddr string + logger *zap.Logger + handlers []*GraphQLSchemaServer + httpServer *http.Server +} + +// NewMultiServer constructs a MultiServer that will mount the given handlers on +// listenAddr when Start is called. The handlers are not yet started. +func NewMultiServer(listenAddr string, logger *zap.Logger, handlers ...*GraphQLSchemaServer) (*MultiServer, error) { + if listenAddr == "" { + return nil, fmt.Errorf("listen_addr is required") + } + if len(handlers) == 0 { + return nil, fmt.Errorf("at least one MCP server handler is required") + } + if logger == nil { + logger = zap.NewNop() + } + seenPaths := make(map[string]struct{}, len(handlers)) + for _, h := range handlers { + if _, dup := seenPaths[h.path]; dup { + return nil, fmt.Errorf("duplicate MCP server path %q", h.path) + } + seenPaths[h.path] = struct{}{} + } + return &MultiServer{ + listenAddr: listenAddr, + logger: logger, + handlers: handlers, + }, nil +} + +// Start mounts every handler on a shared mux, primes upstream-bound handlers with +// their SDL-derived schema (so their tools are available immediately), and binds +// the HTTP listener. The listener runs in a background goroutine. +// +// Supergraph-bound handlers remain "empty" until Reload is called by the router +// with the federated schema — same lifecycle as the legacy single-server flow. +func (m *MultiServer) Start() error { + mux := http.NewServeMux() + for _, h := range m.handlers { + h.RegisterRoutes(mux) + + // Upstream-bound handlers carry their own SDL — load it now so their + // tools are ready before any client connects. Supergraph-bound handlers + // wait for the router's Reload(supergraphSchema, ...) call. + if h.HasUpstreamSchema() { + doc, err := parseSDL(h.upstreamSchemaSDL) + if err != nil { + return fmt.Errorf("mcp server %q: parse upstream SDL: %w", h.graphName, err) + } + if err := h.Reload(doc, nil); err != nil { + return fmt.Errorf("mcp server %q: initial reload: %w", h.graphName, err) + } + } + + // Per-collection operations directory watcher: hot-reloads tools when + // .graphql / .gql files are added, modified, or removed. + // Supergraph-bound handlers without an initial Reload yet still benefit — + // the watcher is no-op until the first Reload populates a schema, after + // which it picks up file changes on the next tick. + watchEnabled, interval := h.WatchSettings() + if watchEnabled && h.OperationsDir() != "" { + handler := h // capture loop variable for the callback + err := WatchOperationsDir(handler.Context(), handler.OperationsDir(), interval, func() { + if err := handler.ReloadOperations(); err != nil { + m.logger.Warn("hot-reload of MCP operations failed", + zap.String("name", handler.graphName), + zap.String("path", handler.path), + zap.Error(err)) + return + } + m.logger.Info("MCP operations hot-reloaded", + zap.String("name", handler.graphName), + zap.String("path", handler.path)) + }, m.logger.With(zap.String("mcp_server_name", handler.graphName))) + if err != nil { + return fmt.Errorf("mcp server %q: start operations watcher: %w", h.graphName, err) + } + } + } + + m.httpServer = &http.Server{ + Addr: m.listenAddr, + Handler: mux, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 60 * time.Second, + } + + paths := make([]string, 0, len(m.handlers)) + for _, h := range m.handlers { + paths = append(paths, h.path) + } + m.logger.Info("MCP multi-server starting", + zap.String("listen_addr", m.listenAddr), + zap.Strings("paths", paths), + ) + + go func() { + defer m.logger.Info("MCP multi-server stopped") + if err := m.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + m.logger.Error("MCP multi-server failed", zap.Error(err)) + } + }() + + return nil +} + +// Reload fans out a supergraph schema update to every handler that tracks the +// supergraph (i.e. has no upstream override). Upstream-bound handlers are skipped. +func (m *MultiServer) Reload(schema *ast.Document, fieldConfigs []*nodev1.FieldConfiguration) error { + var firstErr error + for _, h := range m.handlers { + if h.HasUpstreamSchema() { + continue + } + if err := h.Reload(schema, fieldConfigs); err != nil { + m.logger.Error("MCP server reload failed", + zap.String("name", h.graphName), + zap.String("path", h.path), + zap.Error(err)) + if firstErr == nil { + firstErr = err + } + } + } + return firstErr +} + +// Stop gracefully shuts down the HTTP listener and cancels every handler's +// background context (JWKS pollers, etc.). +func (m *MultiServer) Stop(ctx context.Context) error { + for _, h := range m.handlers { + if h.cancel != nil { + h.cancel() + } + } + if m.httpServer == nil { + return nil + } + shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + if err := m.httpServer.Shutdown(shutdownCtx); err != nil { + return fmt.Errorf("MCP multi-server shutdown: %w", err) + } + return nil +} + +// parseSDL parses an SDL string into an *ast.Document. +func parseSDL(sdl string) (*ast.Document, error) { + doc, report := astparser.ParseGraphqlDocumentString(sdl) + if report.HasErrors() { + return nil, fmt.Errorf("%s", report.Error()) + } + return &doc, nil +} \ No newline at end of file diff --git a/router/pkg/mcpserver/operations_watcher.go b/router/pkg/mcpserver/operations_watcher.go new file mode 100644 index 0000000000..a704605ef2 --- /dev/null +++ b/router/pkg/mcpserver/operations_watcher.go @@ -0,0 +1,126 @@ +package mcpserver + +import ( + "context" + "errors" + "io/fs" + "path/filepath" + "strings" + "time" + + "go.uber.org/zap" +) + +// WatchOperationsDir starts a ticker that scans dir on every interval, detects +// when the set of .graphql / .gql files (or their modification times) has changed, +// and invokes onChange after a settling tick. The settling tick avoids firing the +// callback in the middle of a multi-file save: the watcher waits until a tick +// passes with no further changes before treating the directory as "settled". +// +// The watcher runs until ctx is cancelled. It is non-blocking — start it in a goroutine. +// +// Errors from individual scans are logged at debug level; the watcher does not exit +// on transient I/O errors so that flaky filesystems (network mounts, container +// volumes) don't take down hot-reload. +func WatchOperationsDir(ctx context.Context, dir string, interval time.Duration, onChange func(), logger *zap.Logger) error { + if dir == "" { + return errors.New("dir is required") + } + if interval <= 0 { + return errors.New("interval must be greater than zero") + } + if onChange == nil { + return errors.New("onChange callback is required") + } + if logger == nil { + logger = zap.NewNop() + } + + logger = logger.With(zap.String("component", "mcp_operations_watcher"), zap.String("dir", dir)) + + prev, err := snapshotOperationFiles(dir) + if err != nil { + // Don't fail startup on transient errors — start with an empty snapshot + // and the next successful scan will pick everything up. + logger.Debug("initial directory snapshot failed; starting with empty baseline", zap.Error(err)) + prev = map[string]fileFingerprint{} + } + + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + pendingReload := false + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + curr, scanErr := snapshotOperationFiles(dir) + if scanErr != nil { + logger.Debug("scan failed", zap.Error(scanErr)) + continue + } + if !fingerprintsEqual(prev, curr) { + prev = curr + pendingReload = true + continue + } + if pendingReload { + pendingReload = false + logger.Info("operations directory changed; reloading tools and notifying connected clients") + onChange() + } + } + } + }() + + return nil +} + +// fileFingerprint identifies a file's relevant state (modification time + size). +type fileFingerprint struct { + modTime time.Time + size int64 +} + +// snapshotOperationFiles returns a map of path → fingerprint for every .graphql / .gql +// file under dir. Used by the watcher to detect added, removed, or modified operations. +func snapshotOperationFiles(dir string) (map[string]fileFingerprint, error) { + out := map[string]fileFingerprint{} + walkErr := filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + return nil + } + ext := strings.ToLower(filepath.Ext(path)) + if ext != ".graphql" && ext != ".gql" { + return nil + } + info, err := d.Info() + if err != nil { + return nil // skip unreadable entries; treat as if absent + } + out[path] = fileFingerprint{modTime: info.ModTime(), size: info.Size()} + return nil + }) + if walkErr != nil { + return nil, walkErr + } + return out, nil +} + +func fingerprintsEqual(a, b map[string]fileFingerprint) bool { + if len(a) != len(b) { + return false + } + for k, va := range a { + vb, ok := b[k] + if !ok || !va.modTime.Equal(vb.modTime) || va.size != vb.size { + return false + } + } + return true +} diff --git a/router/pkg/mcpserver/server.go b/router/pkg/mcpserver/server.go index a3c5c36858..0811a43ead 100644 --- a/router/pkg/mcpserver/server.go +++ b/router/pkg/mcpserver/server.go @@ -3,12 +3,15 @@ package mcpserver import ( "bytes" "context" + "crypto/sha256" + "encoding/hex" "encoding/json" "errors" "fmt" "io" "net/http" "slices" + "sort" "strings" "time" @@ -19,6 +22,7 @@ import ( "go.uber.org/zap" nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" + "github.com/wundergraph/cosmo/router/internal/codemode/sandbox" "github.com/wundergraph/cosmo/router/internal/headers" "github.com/wundergraph/cosmo/router/pkg/authentication" "github.com/wundergraph/cosmo/router/pkg/config" @@ -67,8 +71,13 @@ type Options struct { GraphName string // OperationsDir is the directory where GraphQL operations are stored OperationsDir string - // ListenAddr is the address where the server should listen to + // ListenAddr is the address where the server should listen to. + // Only used when this server is started standalone via Serve(); ignored when + // mounted on a shared listener via RegisterRoutes(mux). ListenAddr string + // Path is the URL path this MCP server is mounted at (e.g. "/mcp", "/internal"). + // Defaults to "/mcp" when empty. + Path string // Enabled determines whether the MCP server should be started Enabled bool // Logger is the logger to be used @@ -93,6 +102,19 @@ type Options struct { ServerBaseURL string // ResourceDocumentation is a URL to a human-readable page describing this resource ResourceDocumentation string + // UpstreamSchemaSDL is the GraphQL schema (as SDL text) for upstream-bound collections + // that don't share the local supergraph's schema. When set, Reload uses this schema + // instead of the supergraph schema passed in to Reload(). + UpstreamSchemaSDL string + // UpstreamHeaders are forwarded to the upstream GraphQL endpoint on every request + // (in addition to per-request headers). + UpstreamHeaders map[string]string + // WatchOperations enables periodic scanning of OperationsDir for added, + // modified, or removed .graphql / .gql files. When a change is detected, + // the collection's tools are reloaded without restarting the router. + WatchOperations bool + // OperationsWatchInterval is the polling interval used when WatchOperations is true. + OperationsWatchInterval time.Duration } // GraphQLSchemaServer represents an MCP server that works with GraphQL schemas and operations @@ -101,6 +123,7 @@ type GraphQLSchemaServer struct { graphName string operationsDir string listenAddr string + path string logger *zap.Logger httpClient *http.Client requestTimeout time.Duration @@ -120,6 +143,34 @@ type GraphQLSchemaServer struct { serverBaseURL string resourceDocumentation string authMiddleware *MCPAuthMiddleware + upstreamSchemaSDL string + upstreamHeaders map[string]string + watchOperations bool + operationsWatchInterval time.Duration + // lastSchema and lastFieldConfigs are remembered so ReloadOperations() can + // re-run the operations directory load without a fresh schema input. + lastSchema *ast.Document + lastFieldConfigs []*nodev1.FieldConfiguration + // lastToolFingerprints is the fingerprint of every tool currently registered + // with s.server. Used by Reload to compute a diff and only emit Remove/Add + // calls for tools that actually changed — avoiding spurious tools/list_changed + // notifications when an mtime touch produced no semantic change. + lastToolFingerprints map[string]string + // ctx is the per-server context (cancelled on Stop) — used for operation watchers. + ctx context.Context + // codeModeSandbox is the V8 isolate used by the code_mode_run_js tool. + // Lazily initialized on first use; reused across reloads since it only + // depends on the upstream endpoint, not the op catalog. + codeModeSandbox *sandbox.Sandbox +} + +// desiredTool bundles a Tool spec, its handler, and a content fingerprint so +// the diff-aware reload path can compare against the prior set without +// re-invoking the SDK's AddTool when nothing has changed. +type desiredTool struct { + tool *mcp.Tool + handler mcp.ToolHandler + fingerprint string } type graphqlRequest struct { @@ -206,6 +257,7 @@ func NewGraphQLSchemaServer(ctx context.Context, routerGraphQLEndpoint string, o GraphName: "graph", OperationsDir: "operations", ListenAddr: "0.0.0.0:5025", + Path: "/mcp", Enabled: false, Logger: zap.NewNop(), RequestTimeout: 30 * time.Second, @@ -218,6 +270,13 @@ func NewGraphQLSchemaServer(ctx context.Context, routerGraphQLEndpoint string, o opt(options) } + if options.Path == "" { + options.Path = "/mcp" + } + if !strings.HasPrefix(options.Path, "/") { + return nil, fmt.Errorf("MCP server path must start with '/': got %q", options.Path) + } + ctx, cancel := context.WithCancel(ctx) var authMiddleware *MCPAuthMiddleware @@ -261,10 +320,12 @@ func NewGraphQLSchemaServer(ctx context.Context, routerGraphQLEndpoint string, o return nil, fmt.Errorf("failed to create token decoder: %w", err) } - // Build resource metadata URL for WWW-Authenticate header + // Build resource metadata URL for WWW-Authenticate header. + // Per RFC 9728, each protected resource gets its own metadata endpoint + // at /.well-known/oauth-protected-resource{path}. resourceMetadataURL := "" if options.ServerBaseURL != "" { - resourceMetadataURL = fmt.Sprintf("%s/.well-known/oauth-protected-resource/mcp", options.ServerBaseURL) + resourceMetadataURL = fmt.Sprintf("%s/.well-known/oauth-protected-resource%s", options.ServerBaseURL, options.Path) } authMiddleware, err = NewMCPAuthMiddleware(tokenDecoder, resourceMetadataURL, options.OAuthConfig.Scopes, options.OAuthConfig.ScopeChallengeIncludeTokenScopes) @@ -304,6 +365,7 @@ func NewGraphQLSchemaServer(ctx context.Context, routerGraphQLEndpoint string, o graphName: options.GraphName, operationsDir: options.OperationsDir, listenAddr: options.ListenAddr, + path: options.Path, logger: options.Logger, httpClient: httpClient, requestTimeout: options.RequestTimeout, @@ -319,6 +381,11 @@ func NewGraphQLSchemaServer(ctx context.Context, routerGraphQLEndpoint string, o serverBaseURL: options.ServerBaseURL, resourceDocumentation: options.ResourceDocumentation, authMiddleware: authMiddleware, + upstreamSchemaSDL: options.UpstreamSchemaSDL, + upstreamHeaders: options.UpstreamHeaders, + watchOperations: options.WatchOperations, + operationsWatchInterval: options.OperationsWatchInterval, + ctx: ctx, } return gs, nil @@ -350,6 +417,38 @@ func WithListenAddr(listenAddr string) func(*Options) { } } +// WithPath sets the URL path this MCP server is mounted at (e.g. "/mcp", "/internal"). +func WithPath(path string) func(*Options) { + return func(o *Options) { + o.Path = path + } +} + +// WithUpstreamSchemaSDL sets the SDL text used as this server's GraphQL schema. +// When set, Reload uses this schema instead of the supergraph schema passed in. +func WithUpstreamSchemaSDL(sdl string) func(*Options) { + return func(o *Options) { + o.UpstreamSchemaSDL = sdl + } +} + +// WithUpstreamHeaders sets headers forwarded to the upstream GraphQL endpoint +// on every request (in addition to per-request headers). +func WithUpstreamHeaders(headers map[string]string) func(*Options) { + return func(o *Options) { + o.UpstreamHeaders = headers + } +} + +// WithWatchOperations enables periodic scanning of the operations directory and +// hot-reload of MCP tools when files are added, modified, or removed. +func WithWatchOperations(enabled bool, interval time.Duration) func(*Options) { + return func(o *Options) { + o.WatchOperations = enabled + o.OperationsWatchInterval = interval + } +} + func WithLogger(logger *zap.Logger) func(*Options) { return func(o *Options) { o.Logger = logger @@ -426,18 +525,13 @@ func WithResourceDocumentation(url string) func(*Options) { } } -// Serve starts the server with the configured options and returns the HTTP server. -func (s *GraphQLSchemaServer) Serve() (*http.Server, error) { - // Create custom HTTP server - httpServer := &http.Server{ - Addr: s.listenAddr, - ReadTimeout: 30 * time.Second, - WriteTimeout: 30 * time.Second, - IdleTimeout: 60 * time.Second, - } - - // Create MCP streamable HTTP handler - // The getServer function returns our MCP server instance for each request +// RegisterRoutes registers this server's HTTP handlers (the MCP endpoint and, +// if OAuth is enabled, the RFC 9728 Protected Resource Metadata endpoint) on the +// given mux. The CORS middleware configured on this server is applied to its handlers. +// +// Use RegisterRoutes when mounting multiple MCP servers on a shared HTTP listener +// (see MultiServer). Use Serve when running a single server with its own listener. +func (s *GraphQLSchemaServer) RegisterRoutes(mux *http.ServeMux) { // Disable the SDK's built-in cross-origin protection (Sec-Fetch-Site check) // because the router already applies its own CORS middleware around the handler. cop := http.NewCrossOriginProtection() @@ -455,11 +549,10 @@ func (s *GraphQLSchemaServer) Serve() (*http.Server, error) { middleware := cors.New(s.corsConfig) - mux := http.NewServeMux() - - // OAuth 2.0 Protected Resource Metadata (RFC 9728) — public discovery endpoint + // OAuth 2.0 Protected Resource Metadata (RFC 9728) — per-resource discovery endpoint. + // Each MCP server gets its own /.well-known/oauth-protected-resource{path} entry. if s.oauthConfig != nil && s.oauthConfig.Enabled && s.oauthConfig.AuthorizationServerURL != "" { - mux.Handle("/.well-known/oauth-protected-resource/mcp", middleware(http.HandlerFunc(s.handleProtectedResourceMetadata))) + mux.Handle("/.well-known/oauth-protected-resource"+s.path, middleware(http.HandlerFunc(s.handleProtectedResourceMetadata))) } // Inject request headers into context so tool handlers can forward them @@ -469,16 +562,29 @@ func (s *GraphQLSchemaServer) Serve() (*http.Server, error) { streamableHTTPHandler.ServeHTTP(w, r) }) if s.authMiddleware != nil { - mux.Handle("/mcp", middleware(s.authMiddleware.HTTPMiddleware(mcpHandler))) + mux.Handle(s.path, middleware(s.authMiddleware.HTTPMiddleware(mcpHandler))) } else { - mux.Handle("/mcp", middleware(mcpHandler)) + mux.Handle(s.path, middleware(mcpHandler)) } +} +// Serve starts the server with the configured options and returns the HTTP server. +func (s *GraphQLSchemaServer) Serve() (*http.Server, error) { + // Create custom HTTP server + httpServer := &http.Server{ + Addr: s.listenAddr, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 60 * time.Second, + } + + mux := http.NewServeMux() + s.RegisterRoutes(mux) httpServer.Handler = mux logger := []zap.Field{ zap.String("listen_addr", s.listenAddr), - zap.String("path", "/mcp"), + zap.String("path", s.path), zap.String("operations_dir", s.operationsDir), zap.String("graph_name", s.graphName), zap.Bool("exclude_mutations", s.excludeMutations), @@ -519,6 +625,8 @@ func (s *GraphQLSchemaServer) Reload(schema *ast.Document, fieldConfigs []*nodev return fmt.Errorf("server is not started") } + s.lastSchema = schema + s.lastFieldConfigs = fieldConfigs s.schemaCompiler = NewSchemaCompiler(s.logger) s.operationsManager = NewOperationsManager(schema, s.logger, s.excludeMutations) @@ -539,16 +647,88 @@ func (s *GraphQLSchemaServer) Reload(schema *ast.Document, fieldConfigs []*nodev s.authMiddleware.SetScopeExtractor(NewScopeExtractor(fieldConfigs, schema, maxScopeCombinations)) } - s.server.RemoveTools(s.registeredTools...) - s.registeredTools = nil - - if err := s.registerTools(); err != nil { - return fmt.Errorf("failed to register tools: %w", err) + desired, err := s.buildDesiredTools() + if err != nil { + return fmt.Errorf("failed to build tool set: %w", err) } + s.applyToolDiff(desired) + return nil } +// applyToolDiff applies the difference between the currently-registered tools +// (s.lastToolFingerprints) and the desired set. Tools whose fingerprint has not +// changed are left untouched — the SDK fires a tools/list_changed notification +// on every AddTool/RemoveTools call, so skipping unchanged tools keeps client +// chatter to a minimum and means an mtime-only file touch produces zero +// notifications. +func (s *GraphQLSchemaServer) applyToolDiff(desired map[string]desiredTool) { + addNames := make([]string, 0, len(desired)) + for name := range desired { + addNames = append(addNames, name) + } + sort.Strings(addNames) + + var added, changed []string + for _, name := range addNames { + d := desired[name] + prev, existed := s.lastToolFingerprints[name] + switch { + case !existed: + added = append(added, name) + case prev != d.fingerprint: + changed = append(changed, name) + } + } + + var removed []string + for name := range s.lastToolFingerprints { + if _, keep := desired[name]; !keep { + removed = append(removed, name) + } + } + sort.Strings(removed) + + // Apply the diff. RemoveTools batches into one notification; AddTool sends + // one per call but we only invoke it for actually-changed tools. + if len(removed) > 0 { + s.server.RemoveTools(removed...) + } + for _, name := range added { + d := desired[name] + s.server.AddTool(d.tool, d.handler) + } + for _, name := range changed { + d := desired[name] + s.server.AddTool(d.tool, d.handler) + } + + if len(added)+len(changed)+len(removed) == 0 { + s.logger.Debug("MCP tool refresh: no changes detected, no notification sent to clients") + } else { + s.logger.Info("MCP tool refresh broadcast to connected clients (tools/list_changed)", + zap.Strings("added", added), + zap.Strings("changed", changed), + zap.Strings("removed", removed), + zap.Int("total_tools", len(desired)), + ) + } + + // Remember the current state for the next reload's diff. + s.lastToolFingerprints = make(map[string]string, len(desired)) + for name, d := range desired { + s.lastToolFingerprints[name] = d.fingerprint + } + + // Maintain s.registeredTools as a sorted slice for any code that still + // reads it (collision detection in buildDesiredTools, etc.). + s.registeredTools = s.registeredTools[:0] + for _, name := range addNames { + s.registeredTools = append(s.registeredTools, name) + } +} + // Stop gracefully shuts down the MCP server func (s *GraphQLSchemaServer) Stop(ctx context.Context) error { if s.httpServer == nil { @@ -573,34 +753,38 @@ func (s *GraphQLSchemaServer) Stop(ctx context.Context) error { return nil } -// registerTools registers all tools for the MCP server -func (s *GraphQLSchemaServer) registerTools() error { - // Only register the schema tool if exposeSchema is enabled +// buildDesiredTools computes the full set of tools that should be registered +// with the MCP server given the current operations and config flags. It does +// NOT register them — the caller (Reload via applyToolDiff) compares against the +// previous set and only emits SDK Add/Remove calls for actual differences. +func (s *GraphQLSchemaServer) buildDesiredTools() (map[string]desiredTool, error) { + desired := make(map[string]desiredTool) + + // get_schema — only when exposeSchema is enabled. if s.exposeSchema { - // Create a schema with empty properties since get_schema takes no input - getSchemaInputSchema := map[string]any{ + schemaInput := map[string]any{ "type": "object", "properties": map[string]any{}, } - tool := &mcp.Tool{ Name: "get_schema", Description: "Provides the full GraphQL schema of the API.", - InputSchema: getSchemaInputSchema, + InputSchema: schemaInput, Annotations: &mcp.ToolAnnotations{ Title: "Get GraphQL Schema", ReadOnlyHint: true, }, } - - s.server.AddTool(tool, s.handleGetGraphQLSchema()) - s.registeredTools = append(s.registeredTools, "get_schema") + desired[tool.Name] = desiredTool{ + tool: tool, + handler: s.handleGetGraphQLSchema(), + fingerprint: fingerprintTool(tool, ""), + } } - // Only register the execute_graphql tool if enableArbitraryOperations is enabled + // execute_graphql — only when arbitrary operations are enabled. if s.enableArbitraryOperations { - // Add a tool to execute arbitrary GraphQL queries - executeGraphQLSchema := map[string]any{ + execInput := map[string]any{ "type": "object", "description": "The query and variables to execute.", "properties": map[string]any{ @@ -617,31 +801,28 @@ func (s *GraphQLSchemaServer) registerTools() error { "additionalProperties": false, "required": []string{"query"}, } - - destructiveHint := true - openWorldHint := true + destructive := true + openWorld := true tool := &mcp.Tool{ Name: "execute_graphql", Description: "Executes a GraphQL query or mutation.", - InputSchema: executeGraphQLSchema, + InputSchema: execInput, Annotations: &mcp.ToolAnnotations{ Title: "Execute GraphQL Query", - DestructiveHint: &destructiveHint, + DestructiveHint: &destructive, IdempotentHint: false, - OpenWorldHint: &openWorldHint, + OpenWorldHint: &openWorld, }, } - - s.server.AddTool(tool, s.handleExecuteGraphQL()) - s.registeredTools = append(s.registeredTools, "execute_graphql") + desired[tool.Name] = desiredTool{ + tool: tool, + handler: s.handleExecuteGraphQL(), + fingerprint: fingerprintTool(tool, ""), + } } - // Get operations filtered by the excludeMutations setting operations := s.operationsManager.GetFilteredOperations() - graphqlOperationNames := make([]string, 0, len(operations)) - - // Build per-tool scope map for the auth middleware toolScopes := make(map[string][][]string) for _, op := range operations { @@ -651,35 +832,33 @@ func (s *GraphQLSchemaServer) registerTools() error { graphqlOperationNames = append(graphqlOperationNames, op.Name) if len(op.JSONSchema) > 0 { - // Validate the JSON schema before compiling it if err := s.schemaCompiler.ValidateJSONSchema(op.JSONSchema); err != nil { s.logger.Error("invalid schema for operation", - zap.String("operation", op.Name), - zap.Error(err)) + zap.String("operation", op.Name), zap.Error(err)) continue } - - // Now compile the validated schema schemaName := fmt.Sprintf("schema-%s.json", op.Name) compiledSchema, err = s.schemaCompiler.CompileJSONSchema(op.JSONSchema, schemaName) if err != nil { s.logger.Error("failed to compile schema for operation", - zap.String("operation", op.Name), - zap.Error(err)) + zap.String("operation", op.Name), zap.Error(err)) continue } } - // Create handler with pre-compiled schema - handler := &operationHandler{ - operation: op, - compiledSchema: compiledSchema, - } + handler := &operationHandler{operation: op, compiledSchema: compiledSchema} - // Convert the operation name to snake_case for consistent tool naming operationToolName := strcase.ToSnake(op.Name) + toolName := operationToolName + if !s.omitToolNamePrefix { + toolName = fmt.Sprintf("execute_operation_%s", operationToolName) + } else if _, dup := desired[operationToolName]; dup || slices.Contains(reservedToolNames, operationToolName) { + s.logger.Error("Skipping operation due to tool name collision", + zap.String("operation", op.Name), + zap.String("conflicting_tool", operationToolName)) + continue + } - // Use the operation description directly if provided, otherwise generate a default description var toolDescription string if op.Description != "" { toolDescription = op.Description @@ -687,23 +866,11 @@ func (s *GraphQLSchemaServer) registerTools() error { toolDescription = fmt.Sprintf("Executes the GraphQL operation '%s' of type %s.", op.Name, op.OperationType) } - toolName := operationToolName - if !s.omitToolNamePrefix { - toolName = fmt.Sprintf("execute_operation_%s", operationToolName) - } else if slices.Contains(s.registeredTools, operationToolName) || slices.Contains(reservedToolNames, operationToolName) { - s.logger.Error("Skipping operation due to tool name collision", - zap.String("operation", op.Name), - zap.String("conflicting_tool", operationToolName), - ) - continue - } - // Parse JSON schema into map for the official SDK var inputSchema any if len(op.JSONSchema) > 0 { if err := json.Unmarshal(op.JSONSchema, &inputSchema); err != nil { s.logger.Error("failed to parse JSON schema for operation", - zap.String("operation", op.Name), - zap.Error(err)) + zap.String("operation", op.Name), zap.Error(err)) continue } } else { @@ -723,22 +890,30 @@ func (s *GraphQLSchemaServer) registerTools() error { }, } - s.server.AddTool(tool, s.handleOperation(handler)) - - s.registeredTools = append(s.registeredTools, toolName) + // Per-operation tools incorporate the query body and required scopes + // into the fingerprint, so editing the operation triggers a re-add and + // editing whitespace alone does not (the parser normalizes the body). + extra := op.OperationString + scopesFingerprint(op.RequiredScopes) + desired[toolName] = desiredTool{ + tool: tool, + handler: s.handleOperation(handler), + fingerprint: fingerprintTool(tool, extra), + } - // Record per-tool scope requirements for auth middleware enforcement if len(op.RequiredScopes) > 0 { toolScopes[toolName] = op.RequiredScopes } } - // Update auth middleware with per-tool scopes (thread-safe) if s.authMiddleware != nil { s.authMiddleware.SetToolScopes(toolScopes) } - getOperationInfoTool := &mcp.Tool{ + // get_operation_info — always present, but its description includes the list + // of operation names, so its fingerprint changes when operations are added or + // removed (correctly triggering a notification only when the enum actually shifts). + sort.Strings(graphqlOperationNames) + getOpInfo := &mcp.Tool{ Name: "get_operation_info", Description: "Provides instructions on how to execute the GraphQL operation via HTTP and how to integrate it into your application.", InputSchema: map[string]any{ @@ -757,12 +932,76 @@ func (s *GraphQLSchemaServer) registerTools() error { ReadOnlyHint: true, }, } + desired[getOpInfo.Name] = desiredTool{ + tool: getOpInfo, + handler: s.handleGraphQLOperationInfo(), + fingerprint: fingerprintTool(getOpInfo, ""), + } - s.server.AddTool(getOperationInfoTool, s.handleGraphQLOperationInfo()) + // code_mode_run_js — when at least one operation is loaded, expose a single + // V8-sandboxed tool where every operation is bound as `tools.(vars)`. + // Lets an LLM compose multiple ops in one round-trip instead of N MCP calls. + if len(operations) > 0 { + codeModeTool := s.codeModeToolDescriptor() + desired[codeModeTool.Name] = desiredTool{ + tool: codeModeTool, + handler: s.handleCodeModeRunJS(), + fingerprint: fingerprintTool(codeModeTool, fmt.Sprintf("ops=%d", len(operations))), + } + } - s.registeredTools = append(s.registeredTools, "get_operation_info") + return desired, nil +} - return nil +// fingerprintTool computes a stable hash of the tool's user-visible content: +// name, description, input schema, annotations, plus an operation-specific extra +// (query body + scopes for operation tools, empty for built-ins). +// +// Two tools with the same fingerprint produce identical tools/list and +// tools/call experiences for an MCP client — so we can skip re-registering +// (and skip the tools/list_changed notification). +func fingerprintTool(t *mcp.Tool, extra string) string { + h := sha256.New() + h.Write([]byte(t.Name)) + h.Write([]byte{0}) + h.Write([]byte(t.Description)) + h.Write([]byte{0}) + if t.InputSchema != nil { + if buf, err := json.Marshal(t.InputSchema); err == nil { + h.Write(buf) + } + } + h.Write([]byte{0}) + if t.Annotations != nil { + if buf, err := json.Marshal(t.Annotations); err == nil { + h.Write(buf) + } + } + h.Write([]byte{0}) + h.Write([]byte(extra)) + return hex.EncodeToString(h.Sum(nil)) +} + +// scopesFingerprint produces a stable string for an OR-of-AND scope list. +func scopesFingerprint(scopes [][]string) string { + if len(scopes) == 0 { + return "" + } + cp := make([][]string, len(scopes)) + for i, group := range scopes { + grp := make([]string, len(group)) + copy(grp, group) + sort.Strings(grp) + cp[i] = grp + } + sort.Slice(cp, func(i, j int) bool { + return strings.Join(cp[i], ",") < strings.Join(cp[j], ",") + }) + parts := make([]string, len(cp)) + for i, grp := range cp { + parts[i] = strings.Join(grp, "&") + } + return strings.Join(parts, "|") } // handleOperation handles a specific operation @@ -1098,7 +1337,7 @@ func (s *GraphQLSchemaServer) handleProtectedResourceMetadata(w http.ResponseWri scopes = []string{} // Ensure non-nil for JSON encoding } - mcpResourceURL := strings.TrimRight(resourceURL, "/") + "/mcp" + mcpResourceURL := strings.TrimRight(resourceURL, "/") + s.path metadata := ProtectedResourceMetadata{ Resource: mcpResourceURL, @@ -1124,7 +1363,43 @@ func (s *GraphQLSchemaServer) handleProtectedResourceMetadata(w http.ResponseWri // GetResourceMetadataURL returns the URL for the OAuth 2.0 Protected Resource Metadata endpoint func (s *GraphQLSchemaServer) GetResourceMetadataURL() string { if s.serverBaseURL != "" { - return fmt.Sprintf("%s/.well-known/oauth-protected-resource/mcp", s.serverBaseURL) + return fmt.Sprintf("%s/.well-known/oauth-protected-resource%s", s.serverBaseURL, s.path) } return "" } + +// Path returns the URL path this MCP server is mounted at. +func (s *GraphQLSchemaServer) Path() string { return s.path } + +// Name returns the configured graph name (used in metrics/logs). +func (s *GraphQLSchemaServer) Name() string { return s.graphName } + +// HasUpstreamSchema reports whether this server uses an SDL-provided upstream schema +// (i.e. it does not track the local supergraph schema). +func (s *GraphQLSchemaServer) HasUpstreamSchema() bool { return s.upstreamSchemaSDL != "" } + +// OperationsDir returns the configured operations directory for this server, +// or "" if no storage provider is wired up. +func (s *GraphQLSchemaServer) OperationsDir() string { return s.operationsDir } + +// WatchSettings returns the operations-directory watcher configuration. +// (enabled, interval). Used by MultiServer to start watchers after Reload. +func (s *GraphQLSchemaServer) WatchSettings() (bool, time.Duration) { + return s.watchOperations, s.operationsWatchInterval +} + +// Context returns the per-server context (cancelled on Stop). Used by +// background goroutines (operations watcher, etc.) to know when to exit. +func (s *GraphQLSchemaServer) Context() context.Context { return s.ctx } + +// ReloadOperations re-reads the operations directory using the most recently +// loaded schema and field configurations. Used by the per-collection storage +// directory watcher to hot-reload tools when files change. Safe to call +// before any initial Reload — in that case it returns an error rather than +// panicking on a nil schema. +func (s *GraphQLSchemaServer) ReloadOperations() error { + if s.lastSchema == nil { + return fmt.Errorf("ReloadOperations called before initial Reload") + } + return s.Reload(s.lastSchema, s.lastFieldConfigs) +} diff --git a/router/pkg/mcpserver/server_test.go b/router/pkg/mcpserver/server_test.go index 3a3c044609..72f21232a1 100644 --- a/router/pkg/mcpserver/server_test.go +++ b/router/pkg/mcpserver/server_test.go @@ -151,7 +151,7 @@ func TestReload_ReservedToolNameCollision(t *testing.T) { assert.Equal(t, "get_operation_info", entry.ContextMap()["conflicting_tool"]) } - assert.ElementsMatch(t, []string{"get_schema", "list_employees", "get_operation_info"}, srv.registeredTools) + assert.ElementsMatch(t, []string{"get_schema", "list_employees", "get_operation_info", "code_mode_run_js"}, srv.registeredTools) } func TestReload_PrefixModeAvoidsReservedNameCollision(t *testing.T) { @@ -193,5 +193,6 @@ func TestReload_PrefixModeAvoidsReservedNameCollision(t *testing.T) { "execute_operation_get_operation_info", "execute_operation_list_employees", "get_operation_info", + "code_mode_run_js", }, srv.registeredTools) } diff --git a/router/pkg/schemaloader/loader.go b/router/pkg/schemaloader/loader.go index e82f584357..5374a56913 100644 --- a/router/pkg/schemaloader/loader.go +++ b/router/pkg/schemaloader/loader.go @@ -12,6 +12,7 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astprinter" "github.com/wundergraph/graphql-go-tools/v2/pkg/astvalidation" "github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport" ) @@ -121,12 +122,27 @@ func (l *OperationLoader) LoadOperationsFromDirectory(dirPath string) ([]Operati // Extract description from operation definition opDescription := extractOperationDescription(&opDoc) + // Strip operation/fragment descriptions before storing OperationString. + // Descriptions on OperationDefinition / FragmentDefinition are a September 2025 + // GraphQL spec addition; most upstream parsers (including third-party APIs like + // rickandmorty, anilist, countries) still reject them as syntax errors. We keep + // the description on Operation.Description for MCP tool metadata. + clearOperationAndFragmentDescriptions(&opDoc) + cleanedOperationString, printErr := astprinter.PrintString(&opDoc) + if printErr != nil { + l.Logger.Error("failed to re-print MCP operation after stripping descriptions", + zap.String("operation", opName), + zap.String("file", path), + zap.Error(printErr)) + return nil + } + // Add to our list of operations operations = append(operations, Operation{ Name: opName, FilePath: path, Document: opDoc, - OperationString: operationString, + OperationString: cleanedOperationString, OperationType: opType, Description: opDescription, }) @@ -188,6 +204,19 @@ func GetOperationNameAndType(doc *ast.Document) (string, string, error) { return "", "", fmt.Errorf("no operation found in document") } +// clearOperationAndFragmentDescriptions marks every OperationDefinition and +// FragmentDefinition description in the document as undefined so the printer +// emits a description-free GraphQL document. Used when forwarding operations +// to upstreams that don't yet support the September 2025 description spec. +func clearOperationAndFragmentDescriptions(doc *ast.Document) { + for i := range doc.OperationDefinitions { + doc.OperationDefinitions[i].Description.IsDefined = false + } + for i := range doc.FragmentDefinitions { + doc.FragmentDefinitions[i].Description.IsDefined = false + } +} + // extractOperationDescription extracts the description string from an operation definition func extractOperationDescription(doc *ast.Document) string { for _, ref := range doc.RootNodes {