diff --git a/readme.md b/readme.md index 24e8991..cd5cc7c 100644 --- a/readme.md +++ b/readme.md @@ -90,6 +90,7 @@ Optional parameters: - `--password`: Password for PostgreSQL authentication - `--port`: Port number (default: 5432) - `--ssl`: Enable SSL connection (true/false) +- `--ssl-reject-unauthorized`: Reject unauthorized SSL certificates (true/false, default: true). Set to `false` to accept self-signed certificates. - `--connection-timeout`: Connection timeout in milliseconds (default: 30000) ### MySQL Database diff --git a/src/db/postgresql-adapter.ts b/src/db/postgresql-adapter.ts index 6e95cf1..138ca46 100644 --- a/src/db/postgresql-adapter.ts +++ b/src/db/postgresql-adapter.ts @@ -1,14 +1,25 @@ import { DbAdapter } from "./adapter.js"; import pg from 'pg'; +// Default timeouts (in milliseconds) +const DEFAULT_STATEMENT_TIMEOUT_MS = 30_000; // 30 s — kills runaway queries +const DEFAULT_IDLE_IN_TRANSACTION_TIMEOUT_MS = 60_000; // 60 s — kills forgotten transactions +const DEFAULT_IDLE_TIMEOUT_MS = 10_000; // 10 s — release idle pool connections +const DEFAULT_CONNECTION_TIMEOUT_MS = 30_000; // 30 s — give up connecting + /** * PostgreSQL database adapter implementation + * + * Uses pg.Pool (max 1) instead of a bare pg.Client so that: + * - idle connections are reaped after idleTimeoutMillis + * - statement_timeout prevents queries from running forever + * - idle_in_transaction_session_timeout kills abandoned transactions */ export class PostgresqlAdapter implements DbAdapter { - private client: pg.Client | null = null; - private config: pg.ClientConfig; + private pool: pg.Pool | null = null; private host: string; private database: string; + private poolConfig: pg.PoolConfig; constructor(connectionInfo: { host: string; @@ -19,41 +30,63 @@ export class PostgresqlAdapter implements DbAdapter { ssl?: boolean | object; options?: any; connectionTimeout?: number; + statementTimeout?: number; + idleTimeout?: number; }) { this.host = connectionInfo.host; this.database = connectionInfo.database; - - // Create PostgreSQL connection config - this.config = { + + const statementTimeout = connectionInfo.statementTimeout || DEFAULT_STATEMENT_TIMEOUT_MS; + + this.poolConfig = { host: connectionInfo.host, database: connectionInfo.database, port: connectionInfo.port || 5432, user: connectionInfo.user, password: connectionInfo.password, ssl: connectionInfo.ssl, - // Add connection timeout if provided (in milliseconds) - connectionTimeoutMillis: connectionInfo.connectionTimeout || 30000, - }; + // Single connection — MCP server is single-threaded + max: 1, + connectionTimeoutMillis: connectionInfo.connectionTimeout || DEFAULT_CONNECTION_TIMEOUT_MS, + idleTimeoutMillis: connectionInfo.idleTimeout || DEFAULT_IDLE_TIMEOUT_MS, + // Server-side timeouts applied to every connection + statement_timeout: statementTimeout, + idle_in_transaction_session_timeout: DEFAULT_IDLE_IN_TRANSACTION_TIMEOUT_MS, + } as pg.PoolConfig; } /** - * Initialize PostgreSQL connection + * Initialize PostgreSQL connection pool */ async init(): Promise { try { console.error(`[INFO] Connecting to PostgreSQL: ${this.host}, Database: ${this.database}`); - console.error(`[DEBUG] Connection details:`, { - host: this.host, + console.error(`[DEBUG] Pool config:`, { + host: this.host, database: this.database, - port: this.config.port, - user: this.config.user, - connectionTimeoutMillis: this.config.connectionTimeoutMillis, - ssl: !!this.config.ssl + port: this.poolConfig.port, + user: this.poolConfig.user, + max: this.poolConfig.max, + connectionTimeoutMillis: this.poolConfig.connectionTimeoutMillis, + idleTimeoutMillis: this.poolConfig.idleTimeoutMillis, + ssl: !!this.poolConfig.ssl, + }); + + this.pool = new pg.Pool(this.poolConfig); + + // Log pool errors instead of crashing + this.pool.on('error', (err) => { + console.error('[ERROR] Unexpected pool client error:', err.message); }); - - this.client = new pg.Client(this.config); - await this.client.connect(); - console.error(`[INFO] PostgreSQL connection established successfully`); + + // Verify connectivity + const client = await this.pool.connect(); + try { + await client.query('SELECT 1'); + } finally { + client.release(); + } + console.error(`[INFO] PostgreSQL connection pool ready`); } catch (err) { console.error(`[ERROR] PostgreSQL connection error: ${(err as Error).message}`); throw new Error(`Failed to connect to PostgreSQL: ${(err as Error).message}`); @@ -62,20 +95,15 @@ export class PostgresqlAdapter implements DbAdapter { /** * Execute a SQL query and get all results - * @param query SQL query to execute - * @param params Query parameters - * @returns Promise with query results */ async all(query: string, params: any[] = []): Promise { - if (!this.client) { + if (!this.pool) { throw new Error("Database not initialized"); } try { - // PostgreSQL uses $1, $2, etc. for parameterized queries const preparedQuery = query.replace(/\?/g, (_, i) => `$${i + 1}`); - - const result = await this.client.query(preparedQuery, params); + const result = await this.pool.query(preparedQuery, params); return result.rows; } catch (err) { throw new Error(`PostgreSQL query error: ${(err as Error).message}`); @@ -84,37 +112,31 @@ export class PostgresqlAdapter implements DbAdapter { /** * Execute a SQL query that modifies data - * @param query SQL query to execute - * @param params Query parameters - * @returns Promise with result info */ async run(query: string, params: any[] = []): Promise<{ changes: number, lastID: number }> { - if (!this.client) { + if (!this.pool) { throw new Error("Database not initialized"); } try { - // Replace ? with numbered parameters const preparedQuery = query.replace(/\?/g, (_, i) => `$${i + 1}`); - + let lastID = 0; let changes = 0; - - // For INSERT queries, try to get the inserted ID + if (query.trim().toUpperCase().startsWith('INSERT')) { - // Add RETURNING clause to get the inserted ID if it doesn't already have one - const returningQuery = preparedQuery.includes('RETURNING') - ? preparedQuery + const returningQuery = preparedQuery.includes('RETURNING') + ? preparedQuery : `${preparedQuery} RETURNING id`; - - const result = await this.client.query(returningQuery, params); + + const result = await this.pool.query(returningQuery, params); changes = result.rowCount || 0; lastID = result.rows[0]?.id || 0; } else { - const result = await this.client.query(preparedQuery, params); + const result = await this.pool.query(preparedQuery, params); changes = result.rowCount || 0; } - + return { changes, lastID }; } catch (err) { throw new Error(`PostgreSQL query error: ${(err as Error).message}`); @@ -123,28 +145,26 @@ export class PostgresqlAdapter implements DbAdapter { /** * Execute multiple SQL statements - * @param query SQL statements to execute - * @returns Promise that resolves when execution completes */ async exec(query: string): Promise { - if (!this.client) { + if (!this.pool) { throw new Error("Database not initialized"); } try { - await this.client.query(query); + await this.pool.query(query); } catch (err) { throw new Error(`PostgreSQL batch error: ${(err as Error).message}`); } } /** - * Close the database connection + * Close the connection pool — releases all connections back to the server */ async close(): Promise { - if (this.client) { - await this.client.end(); - this.client = null; + if (this.pool) { + await this.pool.end(); + this.pool = null; } } @@ -169,7 +189,6 @@ export class PostgresqlAdapter implements DbAdapter { /** * Get database-specific query for describing a table - * @param tableName Table name */ getDescribeTableQuery(tableName: string): string { return ` @@ -194,4 +213,4 @@ export class PostgresqlAdapter implements DbAdapter { c.ordinal_position `; } -} \ No newline at end of file +} diff --git a/src/handlers/toolHandlers.ts b/src/handlers/toolHandlers.ts index 463bcc3..274865c 100644 --- a/src/handlers/toolHandlers.ts +++ b/src/handlers/toolHandlers.ts @@ -3,6 +3,7 @@ import { formatErrorResponse } from '../utils/formatUtils.js'; // Import all tool implementations import { readQuery, writeQuery, exportQuery } from '../tools/queryTools.js'; import { createTable, alterTable, dropTable, listTables, describeTable } from '../tools/schemaTools.js'; +import { createType, createIndex, createFunctionAndTrigger } from '../tools/ddlTools.js'; import { appendInsight, listInsights } from '../tools/insightTools.js'; /** @@ -68,6 +69,43 @@ export function handleListTools() { required: ["table_name", "confirm"], }, }, + { + name: "create_type", + description: + "Create a new PostgreSQL enum type. Only CREATE TYPE ... AS ENUM (...) statements are allowed. Safe-additive only; no CREATE OR REPLACE, no ALTER TYPE, no DROP TYPE.", + inputSchema: { + type: "object", + properties: { + query: { type: "string" }, + }, + required: ["query"], + }, + }, + { + name: "create_index", + description: + "Create a PostgreSQL index. Allows CREATE [UNIQUE] INDEX [IF NOT EXISTS] ... ON ... (...). Safe-additive only.", + inputSchema: { + type: "object", + properties: { + query: { type: "string" }, + }, + required: ["query"], + }, + }, + { + name: "create_function_and_trigger", + description: + "Create paired PostgreSQL function + trigger for audit-immutability and similar append-only enforcement. Accepts two separate statements. Function runs first, then trigger; CREATE OR REPLACE is allowed for the function (trigger helpers must be idempotent) but not for the trigger (pick a unique trigger name).", + inputSchema: { + type: "object", + properties: { + functionSql: { type: "string" }, + triggerSql: { type: "string" }, + }, + required: ["functionSql", "triggerSql"], + }, + }, { name: "export_query", description: "Export query results to various formats (CSV, JSON)", @@ -145,7 +183,16 @@ export async function handleToolCall(name: string, args: any) { case "drop_table": return await dropTable(args.table_name, args.confirm); - + + case "create_type": + return await createType(args.query); + + case "create_index": + return await createIndex(args.query); + + case "create_function_and_trigger": + return await createFunctionAndTrigger(args.functionSql, args.triggerSql); + case "export_query": return await exportQuery(args.query, args.format); diff --git a/src/index.ts b/src/index.ts index 9bf7fda..d955456 100644 --- a/src/index.ts +++ b/src/index.ts @@ -44,7 +44,7 @@ if (args.length === 0) { logger.error("Please provide database connection information"); logger.error("Usage for SQLite: node index.js "); logger.error("Usage for SQL Server: node index.js --sqlserver --server --database [--user --password ]"); - logger.error("Usage for PostgreSQL: node index.js --postgresql --host --database [--user --password --port ]"); + logger.error("Usage for PostgreSQL: node index.js --postgresql --host --database [--user --password --port --ssl true --ssl-reject-unauthorized false]"); logger.error("Usage for MySQL: node index.js --mysql --host --database [--user --password --port ]"); logger.error("Usage for MySQL with AWS IAM: node index.js --mysql --aws-iam-auth --host --database --user --aws-region "); process.exit(1); @@ -95,9 +95,10 @@ else if (args.includes('--postgresql') || args.includes('--postgres')) { password: undefined, port: undefined, ssl: undefined, + sslRejectUnauthorized: undefined, connectionTimeout: undefined }; - + // Parse PostgreSQL connection parameters for (let i = 0; i < args.length; i++) { if (args[i] === '--host' && i + 1 < args.length) { @@ -112,11 +113,19 @@ else if (args.includes('--postgresql') || args.includes('--postgres')) { connectionInfo.port = parseInt(args[i + 1], 10); } else if (args[i] === '--ssl' && i + 1 < args.length) { connectionInfo.ssl = args[i + 1] === 'true'; + } else if (args[i] === '--ssl-reject-unauthorized' && i + 1 < args.length) { + connectionInfo.sslRejectUnauthorized = args[i + 1] === 'true'; } else if (args[i] === '--connection-timeout' && i + 1 < args.length) { connectionInfo.connectionTimeout = parseInt(args[i + 1], 10); } } - + + // Build SSL configuration object if needed + if (connectionInfo.ssl && connectionInfo.sslRejectUnauthorized === false) { + connectionInfo.ssl = { rejectUnauthorized: false }; + logger.info("SSL enabled with self-signed certificate support (rejectUnauthorized: false)"); + } + // Validate PostgreSQL connection info if (!connectionInfo.host || !connectionInfo.database) { logger.error("Error: PostgreSQL requires --host and --database parameters"); diff --git a/src/tools/ddlTools.test.ts b/src/tools/ddlTools.test.ts new file mode 100644 index 0000000..fb6507d --- /dev/null +++ b/src/tools/ddlTools.test.ts @@ -0,0 +1,207 @@ +/** + * Pure-validator smoke tests for the safe-additive DDL channels. Exercises + * each validator with accepted + rejected inputs without touching the + * database — validators are pulled out as standalone exports so we can test + * them without a live connection or an ESM module-mock shim. + * + * Run with: node dist/src/tools/ddlTools.test.js + */ + +import { + validateCreateType, + validateCreateIndex, + validateFunctionAndTrigger, +} from './ddlTools.js'; + +type Case = { + label: string; + run: () => void; + expect: 'accept' | 'reject'; + errorContains?: string; +}; + +function mk(label: string, fn: () => void, expect: 'accept' | 'reject', errorContains?: string): Case { + return { label, run: fn, expect, errorContains }; +} + +const cases: Case[] = [ + // ---- create_type ---- + mk( + 'create_type: simple enum', + () => void validateCreateType("CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy')"), + 'accept' + ), + mk( + 'create_type: lowercase keywords', + () => void validateCreateType("create type status as enum ('a','b')"), + 'accept' + ), + mk( + 'create_type: with trailing semicolon + whitespace', + () => void validateCreateType("CREATE TYPE t1 AS ENUM ('x') ; "), + 'accept' + ), + mk( + 'create_type: rejects DROP keyword anywhere in body', + () => void validateCreateType("CREATE TYPE mood AS ENUM ('DROP', 'ok')"), + 'reject', + 'Disallowed keyword' + ), + mk( + 'create_type: rejects composite type', + () => void validateCreateType('CREATE TYPE person AS (name text, age int)'), + 'reject', + 'AS ENUM' + ), + mk( + 'create_type: rejects multi-statement', + () => void validateCreateType("CREATE TYPE a AS ENUM ('x'); CREATE TYPE b AS ENUM ('y')"), + 'reject', + 'Multi-statement' + ), + + // ---- create_index ---- + mk( + 'create_index: simple', + () => void validateCreateIndex('CREATE INDEX idx_users_email ON users (email)'), + 'accept' + ), + mk( + 'create_index: unique + if not exists', + () => void validateCreateIndex( + 'CREATE UNIQUE INDEX IF NOT EXISTS idx_uniq ON accounts (tenant_id, slug)' + ), + 'accept' + ), + mk( + 'create_index: partial with WHERE', + () => void validateCreateIndex( + 'CREATE INDEX idx_active ON users (id) WHERE deleted_at IS NULL' + ), + 'accept' + ), + mk( + 'create_index: identifier containing "drop" as substring is accepted', + // Word-boundary regex treats underscore as a word character, so index + // names like idx_drop_me do NOT trigger the DROP keyword check. This is + // correct — banning any identifier containing "drop" would be too + // aggressive. Real DROP keywords (standalone) are still rejected. + () => void validateCreateIndex('CREATE INDEX idx_drop_me ON users (id)'), + 'accept' + ), + mk( + 'create_index: rejects REINDEX command (wrong starting token)', + () => void validateCreateIndex('REINDEX INDEX idx_users_email'), + 'reject', + 'CREATE' + ), + mk( + 'create_index: rejects DROP keyword as standalone word', + () => void validateCreateIndex( + "CREATE INDEX idx_x ON users (id) WHERE kind = 'DROP'" + ), + 'reject', + 'Disallowed keyword' + ), + mk( + 'create_index: rejects missing ON', + () => void validateCreateIndex('CREATE INDEX idx_bad (email)'), + 'reject', + 'ON clause' + ), + + // ---- create_function_and_trigger ---- + mk( + 'fn+trigger: simple audit immutability', + () => void validateFunctionAndTrigger( + `CREATE OR REPLACE FUNCTION audit_ledger_immutable() RETURNS TRIGGER AS $$ + BEGIN + RAISE EXCEPTION 'ledger rows are append-only'; + END; + $$ LANGUAGE plpgsql`, + `CREATE TRIGGER trg_ledger_no_update + BEFORE UPDATE OR DELETE ON ledger_entries + FOR EACH ROW EXECUTE FUNCTION audit_ledger_immutable()` + ), + 'accept' + ), + mk( + 'fn+trigger: plain CREATE FUNCTION (no OR REPLACE)', + () => void validateFunctionAndTrigger( + `CREATE FUNCTION f1() RETURNS TRIGGER AS $$ BEGIN RETURN NEW; END; $$ LANGUAGE plpgsql`, + `CREATE TRIGGER trg_f1 AFTER INSERT ON t1 FOR EACH ROW EXECUTE FUNCTION f1()` + ), + 'accept' + ), + mk( + 'fn+trigger: tagged dollar-quote body', + () => void validateFunctionAndTrigger( + `CREATE FUNCTION f2() RETURNS TRIGGER AS $body$ BEGIN RETURN NEW; END; $body$ LANGUAGE plpgsql`, + `CREATE TRIGGER trg_f2 AFTER INSERT ON t2 FOR EACH ROW EXECUTE FUNCTION f2()` + ), + 'accept' + ), + mk( + 'fn+trigger: rejects CREATE OR REPLACE TRIGGER', + () => void validateFunctionAndTrigger( + `CREATE FUNCTION f3() RETURNS TRIGGER AS $$ BEGIN RETURN NEW; END; $$ LANGUAGE plpgsql`, + `CREATE OR REPLACE TRIGGER trg_f3 AFTER INSERT ON t3 FOR EACH ROW EXECUTE FUNCTION f3()` + ), + 'reject', + 'CREATE TRIGGER' + ), + mk( + 'fn+trigger: rejects DROP in trigger', + () => void validateFunctionAndTrigger( + `CREATE FUNCTION f4() RETURNS TRIGGER AS $$ BEGIN RETURN NEW; END; $$ LANGUAGE plpgsql`, + `DROP TRIGGER IF EXISTS trg_f4 ON t4` + ), + 'reject', + 'CREATE TRIGGER' + ), + mk( + 'fn+trigger: rejects bogus first statement', + () => void validateFunctionAndTrigger( + `SELECT 1`, + `CREATE TRIGGER trg_f5 AFTER INSERT ON t5 FOR EACH ROW EXECUTE FUNCTION f5()` + ), + 'reject', + 'functionSql' + ), +]; + +function main() { + let pass = 0; + let fail = 0; + for (const c of cases) { + try { + c.run(); + if (c.expect === 'accept') { + console.log(`PASS ${c.label}`); + pass++; + } else { + console.log(`FAIL ${c.label} (expected reject, got accept)`); + fail++; + } + } catch (err: any) { + if (c.expect === 'reject') { + if (c.errorContains && !err.message.includes(c.errorContains)) { + console.log( + `FAIL ${c.label} (rejected as expected but message did not contain "${c.errorContains}": ${err.message})` + ); + fail++; + } else { + console.log(`PASS ${c.label} (rejected: ${err.message})`); + pass++; + } + } else { + console.log(`FAIL ${c.label} (expected accept, got: ${err.message})`); + fail++; + } + } + } + console.log(`\n${pass} passed, ${fail} failed, ${pass + fail} total`); + if (fail > 0) process.exit(1); +} + +main(); diff --git a/src/tools/ddlTools.ts b/src/tools/ddlTools.ts new file mode 100644 index 0000000..cad44d7 --- /dev/null +++ b/src/tools/ddlTools.ts @@ -0,0 +1,254 @@ +import { dbExec } from '../db/index.js'; +import { formatSuccessResponse } from '../utils/formatUtils.js'; + +/** + * Shared guardrails applied to every safe-additive DDL channel: + * - Reject DROP / ALTER / TRUNCATE / REINDEX / REPLACE (where not explicitly + * permitted) anywhere in the body. + * - Reject multi-statement submissions. A single optional trailing semicolon + * (possibly followed by whitespace) is permitted; any other semicolon + * terminates a prior statement and is rejected. + * + * These checks are deliberately coarse: the MCP tool surface is intended for + * the accounting project's migration workflow, not as a general SQL gateway. + * Anything that smells like a destructive or destructive-adjacent operation + * must go through out-of-band psql with human review. + */ + +const BANNED_KEYWORDS_DEFAULT = ['DROP', 'ALTER', 'REPLACE', 'TRUNCATE', 'REINDEX']; + +function assertNoBannedKeywords(sql: string, banned: string[]): void { + const upper = sql.toUpperCase(); + for (const kw of banned) { + // Word-boundary match so we don't choke on substrings like "dropdown_value" + const re = new RegExp(`\\b${kw}\\b`); + if (re.test(upper)) { + throw new Error(`Disallowed keyword in DDL: ${kw}`); + } + } +} + +function assertSingleStatement(sql: string): void { + // Strip a single trailing semicolon + whitespace. + const stripped = sql.replace(/;\s*$/, ''); + if (stripped.includes(';')) { + throw new Error('Multi-statement submissions are not allowed; submit one statement per call.'); + } +} + +function normalizeTokens(sql: string): string[] { + // Collapse whitespace, uppercase, split. Good enough for first-N-tokens + // prefix checks; we never hand this tokenization to the database. + return sql.trim().replace(/\s+/g, ' ').toUpperCase().split(' '); +} + +/** + * Pure validator for create_type. Throws on invalid input, returns the + * trimmed query string on success. Exported for unit testing without a DB + * connection. + */ +export function validateCreateType(query: string): string { + if (typeof query !== 'string' || !query.trim()) { + throw new Error('query is required'); + } + + const trimmed = query.trim(); + const tokens = normalizeTokens(trimmed); + + if (tokens.length < 2 || tokens[0] !== 'CREATE' || tokens[1] !== 'TYPE') { + throw new Error('Only CREATE TYPE statements are allowed with create_type'); + } + + // Must be an enum definition specifically. Composite types, ranges, and + // base types are intentionally out of scope for the safe-additive channel. + if (!/\bAS\s+ENUM\b/i.test(trimmed)) { + throw new Error('create_type only supports CREATE TYPE ... AS ENUM (...) statements'); + } + + assertNoBannedKeywords(trimmed, BANNED_KEYWORDS_DEFAULT); + assertSingleStatement(trimmed); + + return trimmed; +} + +/** + * Create a PostgreSQL enum type. + * + * Accepts only: CREATE TYPE AS ENUM (...) + * Rejects: CREATE OR REPLACE TYPE, ALTER TYPE, DROP TYPE, multi-statement. + */ +export async function createType(query: string) { + try { + const trimmed = validateCreateType(query); + await dbExec(trimmed); + return formatSuccessResponse({ success: true, message: 'Type created successfully' }); + } catch (error: any) { + throw new Error(`SQL Error: ${error.message}`); + } +} + +/** + * Pure validator for create_index. Throws on invalid input, returns the + * trimmed query string on success. Exported for unit testing. + */ +export function validateCreateIndex(query: string): string { + if (typeof query !== 'string' || !query.trim()) { + throw new Error('query is required'); + } + + const trimmed = query.trim(); + const tokens = normalizeTokens(trimmed); + + if (tokens.length < 2 || tokens[0] !== 'CREATE') { + throw new Error('Only CREATE [UNIQUE] INDEX statements are allowed with create_index'); + } + + if (tokens[1] === 'UNIQUE') { + if (tokens.length < 3 || tokens[2] !== 'INDEX') { + throw new Error('Expected CREATE UNIQUE INDEX ...'); + } + } else if (tokens[1] !== 'INDEX') { + throw new Error('Expected CREATE INDEX or CREATE UNIQUE INDEX'); + } + + // The statement must target a table — i.e. contain an " ON " clause + // somewhere after the name. We only need a presence check because the + // database will reject anything syntactically malformed. + if (!/\bON\b/i.test(trimmed)) { + throw new Error('CREATE INDEX statement must contain an ON clause'); + } + + assertNoBannedKeywords(trimmed, BANNED_KEYWORDS_DEFAULT); + assertSingleStatement(trimmed); + + return trimmed; +} + +/** + * Create a PostgreSQL index. + * + * Accepts: CREATE [UNIQUE] INDEX [IF NOT EXISTS] ON (...) + * Rejects: DROP INDEX, ALTER INDEX, REINDEX, multi-statement. + */ +export async function createIndex(query: string) { + try { + const trimmed = validateCreateIndex(query); + await dbExec(trimmed); + return formatSuccessResponse({ success: true, message: 'Index created successfully' }); + } catch (error: any) { + throw new Error(`SQL Error: ${error.message}`); + } +} + +/** + * Create a paired PostgreSQL function + trigger. + * + * Intended for narrow audit-immutability style use cases where a trigger + * function must be defined before the trigger that references it. The + * function is executed first; if the trigger creation then fails, the + * function is left in place (harmless) and the trigger error is returned. + * + * CREATE OR REPLACE FUNCTION is allowed because trigger helpers need to be + * idempotent across migration reruns. CREATE TRIGGER does not allow OR + * REPLACE — replacing a trigger requires a prior DROP TRIGGER, which is not + * safe-additive. Callers must use unique trigger names. + */ +/** + * Pure validator for create_function_and_trigger. Throws on invalid input, + * returns `{ fn, trg }` trimmed on success. Exported for unit testing. + */ +export function validateFunctionAndTrigger( + functionSql: string, + triggerSql: string, +): { fn: string; trg: string } { + if (typeof functionSql !== 'string' || !functionSql.trim()) { + throw new Error('functionSql is required'); + } + if (typeof triggerSql !== 'string' || !triggerSql.trim()) { + throw new Error('triggerSql is required'); + } + + const fnTrim = functionSql.trim(); + const trgTrim = triggerSql.trim(); + + // Function: must begin with CREATE FUNCTION or CREATE OR REPLACE FUNCTION. + { + const fnTokens = normalizeTokens(fnTrim); + const startsWithCreateFn = + fnTokens[0] === 'CREATE' && fnTokens[1] === 'FUNCTION'; + const startsWithCreateOrReplaceFn = + fnTokens[0] === 'CREATE' && + fnTokens[1] === 'OR' && + fnTokens[2] === 'REPLACE' && + fnTokens[3] === 'FUNCTION'; + + if (!startsWithCreateFn && !startsWithCreateOrReplaceFn) { + throw new Error( + 'functionSql must start with CREATE FUNCTION or CREATE OR REPLACE FUNCTION' + ); + } + } + + // For the function body we whitelist REPLACE (because CREATE OR REPLACE + // is permitted) but still reject DROP / ALTER / TRUNCATE. + assertNoBannedKeywords(fnTrim, ['DROP', 'ALTER', 'TRUNCATE']); + + // Trigger: must begin with CREATE TRIGGER (no OR REPLACE). + { + const trgTokens = normalizeTokens(trgTrim); + if (trgTokens[0] !== 'CREATE' || trgTokens[1] !== 'TRIGGER') { + throw new Error( + 'triggerSql must start with CREATE TRIGGER (OR REPLACE is not supported — pick a unique trigger name)' + ); + } + } + assertNoBannedKeywords(trgTrim, ['DROP', 'ALTER', 'TRUNCATE', 'REPLACE']); + + // Both inputs must be single statements. Function bodies can legally + // contain semicolons inside $$ ... $$ delimited code blocks, so we + // strip those out before checking. + assertSingleStatementAllowingDollarQuotes(fnTrim); + assertSingleStatement(trgTrim); + + return { fn: fnTrim, trg: trgTrim }; +} + +export async function createFunctionAndTrigger(functionSql: string, triggerSql: string) { + try { + const { fn, trg } = validateFunctionAndTrigger(functionSql, triggerSql); + + // Execute function first. If it fails, surface that error directly. + await dbExec(fn); + + // Execute trigger. If it fails, leave the function in place and report. + try { + await dbExec(trg); + } catch (triggerErr: any) { + throw new Error( + `Function created, but trigger creation failed: ${triggerErr.message}. The function has been left in place; delete it manually if needed.` + ); + } + + return formatSuccessResponse({ + success: true, + message: 'Function and trigger created successfully', + }); + } catch (error: any) { + throw new Error(`SQL Error: ${error.message}`); + } +} + +/** + * Single-statement check that tolerates $$-quoted function bodies. PL/pgSQL + * function bodies are commonly wrapped in $$ ... $$ (or $tag$ ... $tag$) and + * will legitimately contain semicolons for inner statements. We strip the + * dollar-quoted regions before doing the outer semicolon check. + */ +function assertSingleStatementAllowingDollarQuotes(sql: string): void { + // Remove $$-delimited regions first, then $tag$-delimited regions. Doing + // them separately sidesteps JS regex edge cases around backreferences to + // optional empty capture groups. + let cleaned = sql.replace(/\$\$[\s\S]*?\$\$/g, ''); + cleaned = cleaned.replace(/\$([A-Za-z_][A-Za-z0-9_]*)\$[\s\S]*?\$\1\$/g, ''); + assertSingleStatement(cleaned); +}