|
| 1 | +//! Plan-time constant folding for `SqlExpr`. |
| 2 | +//! |
| 3 | +//! Evaluates literal expressions and registered zero-or-few-arg scalar |
| 4 | +//! functions (e.g. `now()`, `current_timestamp`, `date_add(now(), '1h')`) |
| 5 | +//! at plan time via the shared `nodedb_query::functions::eval_function` |
| 6 | +//! evaluator. |
| 7 | +//! |
| 8 | +//! This keeps the bare-`SELECT` projection path, the `INSERT`/`UPSERT` |
| 9 | +//! `VALUES` path, and any future default-expression paths from drifting |
| 10 | +//! apart — they all reach the same evaluator that the Data Plane uses |
| 11 | +//! for column-reference evaluation. |
| 12 | +//! |
| 13 | +//! Semantics: Postgres / SQL-standard compatible. `now()` and |
| 14 | +//! `current_timestamp` snapshot once per statement — `CURRENT_TIMESTAMP` |
| 15 | +//! is defined to return the same value for every row of a single |
| 16 | +//! statement, and Postgres goes further (same value for the whole |
| 17 | +//! transaction). Folding at plan time satisfies both contracts and is |
| 18 | +//! cheaper than per-row runtime dispatch. |
| 19 | +
|
| 20 | +use std::sync::LazyLock; |
| 21 | + |
| 22 | +use nodedb_types::Value; |
| 23 | + |
| 24 | +use crate::functions::registry::{FunctionCategory, FunctionRegistry}; |
| 25 | +use crate::types::{BinaryOp, SqlExpr, SqlValue, UnaryOp}; |
| 26 | + |
| 27 | +/// Process-wide default registry. Used by call sites that don't already |
| 28 | +/// thread a `FunctionRegistry` through (e.g. the DML `VALUES` path). |
| 29 | +static DEFAULT_REGISTRY: LazyLock<FunctionRegistry> = LazyLock::new(FunctionRegistry::new); |
| 30 | + |
| 31 | +/// Access the shared default registry. |
| 32 | +pub fn default_registry() -> &'static FunctionRegistry { |
| 33 | + &DEFAULT_REGISTRY |
| 34 | +} |
| 35 | + |
| 36 | +/// Convenience wrapper around [`fold_constant`] using the default registry. |
| 37 | +pub fn fold_constant_default(expr: &SqlExpr) -> Option<SqlValue> { |
| 38 | + fold_constant(expr, default_registry()) |
| 39 | +} |
| 40 | + |
| 41 | +/// Fold a `SqlExpr` to a literal `SqlValue` at plan time, or return |
| 42 | +/// `None` if the expression depends on row/runtime state (column refs, |
| 43 | +/// subqueries, unknown functions, etc.). |
| 44 | +pub fn fold_constant(expr: &SqlExpr, registry: &FunctionRegistry) -> Option<SqlValue> { |
| 45 | + match expr { |
| 46 | + SqlExpr::Literal(v) => Some(v.clone()), |
| 47 | + SqlExpr::UnaryOp { |
| 48 | + op: UnaryOp::Neg, |
| 49 | + expr, |
| 50 | + } => match fold_constant(expr, registry)? { |
| 51 | + SqlValue::Int(i) => Some(SqlValue::Int(-i)), |
| 52 | + SqlValue::Float(f) => Some(SqlValue::Float(-f)), |
| 53 | + _ => None, |
| 54 | + }, |
| 55 | + SqlExpr::BinaryOp { left, op, right } => { |
| 56 | + let l = fold_constant(left, registry)?; |
| 57 | + let r = fold_constant(right, registry)?; |
| 58 | + fold_binary(l, *op, r) |
| 59 | + } |
| 60 | + SqlExpr::Function { name, args, .. } => fold_function_call(name, args, registry), |
| 61 | + _ => None, |
| 62 | + } |
| 63 | +} |
| 64 | + |
| 65 | +fn fold_binary(l: SqlValue, op: BinaryOp, r: SqlValue) -> Option<SqlValue> { |
| 66 | + Some(match (l, op, r) { |
| 67 | + (SqlValue::Int(a), BinaryOp::Add, SqlValue::Int(b)) => SqlValue::Int(a + b), |
| 68 | + (SqlValue::Int(a), BinaryOp::Sub, SqlValue::Int(b)) => SqlValue::Int(a - b), |
| 69 | + (SqlValue::Int(a), BinaryOp::Mul, SqlValue::Int(b)) => SqlValue::Int(a * b), |
| 70 | + (SqlValue::Float(a), BinaryOp::Add, SqlValue::Float(b)) => SqlValue::Float(a + b), |
| 71 | + (SqlValue::Float(a), BinaryOp::Sub, SqlValue::Float(b)) => SqlValue::Float(a - b), |
| 72 | + (SqlValue::Float(a), BinaryOp::Mul, SqlValue::Float(b)) => SqlValue::Float(a * b), |
| 73 | + (SqlValue::String(a), BinaryOp::Concat, SqlValue::String(b)) => { |
| 74 | + SqlValue::String(format!("{a}{b}")) |
| 75 | + } |
| 76 | + _ => return None, |
| 77 | + }) |
| 78 | +} |
| 79 | + |
| 80 | +/// Fold a function call by recursively folding its arguments, dispatching |
| 81 | +/// through the shared scalar evaluator, and converting the result back to |
| 82 | +/// `SqlValue`. Only folds functions that are present in `registry`, so |
| 83 | +/// callers can distinguish "unknown function" from "known function, all |
| 84 | +/// args folded". |
| 85 | +pub fn fold_function_call( |
| 86 | + name: &str, |
| 87 | + args: &[SqlExpr], |
| 88 | + registry: &FunctionRegistry, |
| 89 | +) -> Option<SqlValue> { |
| 90 | + // Gate on registry so unknown-function paths keep their existing |
| 91 | + // fallbacks instead of collapsing to SqlValue::Null. Aggregates and |
| 92 | + // window functions aren't foldable — they need a row stream. |
| 93 | + let meta = registry.lookup(name)?; |
| 94 | + if matches!( |
| 95 | + meta.category, |
| 96 | + FunctionCategory::Aggregate | FunctionCategory::Window |
| 97 | + ) { |
| 98 | + return None; |
| 99 | + } |
| 100 | + |
| 101 | + let folded_args: Vec<Value> = args |
| 102 | + .iter() |
| 103 | + .map(|a| fold_constant(a, registry).map(sql_to_ndb_value)) |
| 104 | + .collect::<Option<_>>()?; |
| 105 | + |
| 106 | + let result = nodedb_query::functions::eval_function(name, &folded_args); |
| 107 | + Some(ndb_to_sql_value(result)) |
| 108 | +} |
| 109 | + |
| 110 | +fn sql_to_ndb_value(v: SqlValue) -> Value { |
| 111 | + match v { |
| 112 | + SqlValue::Null => Value::Null, |
| 113 | + SqlValue::Bool(b) => Value::Bool(b), |
| 114 | + SqlValue::Int(i) => Value::Integer(i), |
| 115 | + SqlValue::Float(f) => Value::Float(f), |
| 116 | + SqlValue::String(s) => Value::String(s), |
| 117 | + SqlValue::Bytes(b) => Value::Bytes(b), |
| 118 | + SqlValue::Array(a) => Value::Array(a.into_iter().map(sql_to_ndb_value).collect()), |
| 119 | + } |
| 120 | +} |
| 121 | + |
| 122 | +fn ndb_to_sql_value(v: Value) -> SqlValue { |
| 123 | + match v { |
| 124 | + Value::Null => SqlValue::Null, |
| 125 | + Value::Bool(b) => SqlValue::Bool(b), |
| 126 | + Value::Integer(i) => SqlValue::Int(i), |
| 127 | + Value::Float(f) => SqlValue::Float(f), |
| 128 | + Value::String(s) => SqlValue::String(s), |
| 129 | + Value::Bytes(b) => SqlValue::Bytes(b), |
| 130 | + Value::Array(a) => SqlValue::Array(a.into_iter().map(ndb_to_sql_value).collect()), |
| 131 | + Value::DateTime(dt) => SqlValue::String(dt.to_iso8601()), |
| 132 | + Value::Uuid(s) | Value::Ulid(s) | Value::Regex(s) => SqlValue::String(s), |
| 133 | + Value::Duration(d) => SqlValue::String(d.to_human()), |
| 134 | + Value::Decimal(d) => SqlValue::String(d.to_string()), |
| 135 | + // Structured and opaque types collapse to Null — callers that |
| 136 | + // need these go through the runtime expression path, not folding. |
| 137 | + Value::Object(_) |
| 138 | + | Value::Geometry(_) |
| 139 | + | Value::Set(_) |
| 140 | + | Value::Range { .. } |
| 141 | + | Value::Record { .. } => SqlValue::Null, |
| 142 | + } |
| 143 | +} |
| 144 | + |
| 145 | +#[cfg(test)] |
| 146 | +mod tests { |
| 147 | + use super::*; |
| 148 | + |
| 149 | + #[test] |
| 150 | + fn fold_now_produces_non_epoch_string() { |
| 151 | + let registry = FunctionRegistry::new(); |
| 152 | + let expr = SqlExpr::Function { |
| 153 | + name: "now".into(), |
| 154 | + args: vec![], |
| 155 | + distinct: false, |
| 156 | + }; |
| 157 | + let val = fold_constant(&expr, ®istry).expect("now() should fold"); |
| 158 | + match val { |
| 159 | + SqlValue::String(s) => { |
| 160 | + assert!(!s.starts_with("1970"), "got {s}"); |
| 161 | + assert!(s.contains('T'), "not ISO-8601: {s}"); |
| 162 | + } |
| 163 | + other => panic!("expected string, got {other:?}"), |
| 164 | + } |
| 165 | + } |
| 166 | + |
| 167 | + #[test] |
| 168 | + fn fold_current_timestamp() { |
| 169 | + let registry = FunctionRegistry::new(); |
| 170 | + let expr = SqlExpr::Function { |
| 171 | + name: "current_timestamp".into(), |
| 172 | + args: vec![], |
| 173 | + distinct: false, |
| 174 | + }; |
| 175 | + assert!(matches!( |
| 176 | + fold_constant(&expr, ®istry), |
| 177 | + Some(SqlValue::String(_)) |
| 178 | + )); |
| 179 | + } |
| 180 | + |
| 181 | + #[test] |
| 182 | + fn fold_unknown_function_returns_none() { |
| 183 | + let registry = FunctionRegistry::new(); |
| 184 | + let expr = SqlExpr::Function { |
| 185 | + name: "definitely_not_a_real_function".into(), |
| 186 | + args: vec![], |
| 187 | + distinct: false, |
| 188 | + }; |
| 189 | + assert!(fold_constant(&expr, ®istry).is_none()); |
| 190 | + } |
| 191 | + |
| 192 | + #[test] |
| 193 | + fn fold_literal_arithmetic_still_works() { |
| 194 | + let registry = FunctionRegistry::new(); |
| 195 | + let expr = SqlExpr::BinaryOp { |
| 196 | + left: Box::new(SqlExpr::Literal(SqlValue::Int(2))), |
| 197 | + op: BinaryOp::Add, |
| 198 | + right: Box::new(SqlExpr::Literal(SqlValue::Int(3))), |
| 199 | + }; |
| 200 | + assert_eq!(fold_constant(&expr, ®istry), Some(SqlValue::Int(5))); |
| 201 | + } |
| 202 | + |
| 203 | + #[test] |
| 204 | + fn fold_column_ref_returns_none() { |
| 205 | + let registry = FunctionRegistry::new(); |
| 206 | + let expr = SqlExpr::Column { |
| 207 | + table: None, |
| 208 | + name: "name".into(), |
| 209 | + }; |
| 210 | + assert!(fold_constant(&expr, ®istry).is_none()); |
| 211 | + } |
| 212 | +} |
0 commit comments