Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .config/nextest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ binary(/cluster/)
| binary(descriptor_versioning_cross_node)
| binary(prepared_cache_invalidation)
| binary(sql_cluster_cross_node_dml)
| binary(pgwire_gateway_migration)
'''
test-group = 'cluster'
threads-required = 'num-test-threads'
Expand Down
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions nodedb-sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ description = "SQL parser, planner, and optimizer for NodeDB"

[dependencies]
nodedb-types = { workspace = true }
nodedb-query = { workspace = true }
sqlparser = "0.61"
thiserror = { workspace = true }
212 changes: 212 additions & 0 deletions nodedb-sql/src/planner/const_fold.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
//! Plan-time constant folding for `SqlExpr`.
//!
//! Evaluates literal expressions and registered zero-or-few-arg scalar
//! functions (e.g. `now()`, `current_timestamp`, `date_add(now(), '1h')`)
//! at plan time via the shared `nodedb_query::functions::eval_function`
//! evaluator.
//!
//! This keeps the bare-`SELECT` projection path, the `INSERT`/`UPSERT`
//! `VALUES` path, and any future default-expression paths from drifting
//! apart — they all reach the same evaluator that the Data Plane uses
//! for column-reference evaluation.
//!
//! Semantics: Postgres / SQL-standard compatible. `now()` and
//! `current_timestamp` snapshot once per statement — `CURRENT_TIMESTAMP`
//! is defined to return the same value for every row of a single
//! statement, and Postgres goes further (same value for the whole
//! transaction). Folding at plan time satisfies both contracts and is
//! cheaper than per-row runtime dispatch.

use std::sync::LazyLock;

use nodedb_types::Value;

use crate::functions::registry::{FunctionCategory, FunctionRegistry};
use crate::types::{BinaryOp, SqlExpr, SqlValue, UnaryOp};

/// Process-wide default registry. Used by call sites that don't already
/// thread a `FunctionRegistry` through (e.g. the DML `VALUES` path).
static DEFAULT_REGISTRY: LazyLock<FunctionRegistry> = LazyLock::new(FunctionRegistry::new);

/// Access the shared default registry.
pub fn default_registry() -> &'static FunctionRegistry {
&DEFAULT_REGISTRY
}

/// Convenience wrapper around [`fold_constant`] using the default registry.
pub fn fold_constant_default(expr: &SqlExpr) -> Option<SqlValue> {
fold_constant(expr, default_registry())
}

/// Fold a `SqlExpr` to a literal `SqlValue` at plan time, or return
/// `None` if the expression depends on row/runtime state (column refs,
/// subqueries, unknown functions, etc.).
pub fn fold_constant(expr: &SqlExpr, registry: &FunctionRegistry) -> Option<SqlValue> {
match expr {
SqlExpr::Literal(v) => Some(v.clone()),
SqlExpr::UnaryOp {
op: UnaryOp::Neg,
expr,
} => match fold_constant(expr, registry)? {
SqlValue::Int(i) => Some(SqlValue::Int(-i)),
SqlValue::Float(f) => Some(SqlValue::Float(-f)),
_ => None,
},
SqlExpr::BinaryOp { left, op, right } => {
let l = fold_constant(left, registry)?;
let r = fold_constant(right, registry)?;
fold_binary(l, *op, r)
}
SqlExpr::Function { name, args, .. } => fold_function_call(name, args, registry),
_ => None,
}
}

fn fold_binary(l: SqlValue, op: BinaryOp, r: SqlValue) -> Option<SqlValue> {
Some(match (l, op, r) {
(SqlValue::Int(a), BinaryOp::Add, SqlValue::Int(b)) => SqlValue::Int(a + b),
(SqlValue::Int(a), BinaryOp::Sub, SqlValue::Int(b)) => SqlValue::Int(a - b),
(SqlValue::Int(a), BinaryOp::Mul, SqlValue::Int(b)) => SqlValue::Int(a * b),
(SqlValue::Float(a), BinaryOp::Add, SqlValue::Float(b)) => SqlValue::Float(a + b),
(SqlValue::Float(a), BinaryOp::Sub, SqlValue::Float(b)) => SqlValue::Float(a - b),
(SqlValue::Float(a), BinaryOp::Mul, SqlValue::Float(b)) => SqlValue::Float(a * b),
(SqlValue::String(a), BinaryOp::Concat, SqlValue::String(b)) => {
SqlValue::String(format!("{a}{b}"))
}
_ => return None,
})
}

/// Fold a function call by recursively folding its arguments, dispatching
/// through the shared scalar evaluator, and converting the result back to
/// `SqlValue`. Only folds functions that are present in `registry`, so
/// callers can distinguish "unknown function" from "known function, all
/// args folded".
pub fn fold_function_call(
name: &str,
args: &[SqlExpr],
registry: &FunctionRegistry,
) -> Option<SqlValue> {
// Gate on registry so unknown-function paths keep their existing
// fallbacks instead of collapsing to SqlValue::Null. Aggregates and
// window functions aren't foldable — they need a row stream.
let meta = registry.lookup(name)?;
if matches!(
meta.category,
FunctionCategory::Aggregate | FunctionCategory::Window
) {
return None;
}

let folded_args: Vec<Value> = args
.iter()
.map(|a| fold_constant(a, registry).map(sql_to_ndb_value))
.collect::<Option<_>>()?;

let result = nodedb_query::functions::eval_function(name, &folded_args);
Some(ndb_to_sql_value(result))
}

fn sql_to_ndb_value(v: SqlValue) -> Value {
match v {
SqlValue::Null => Value::Null,
SqlValue::Bool(b) => Value::Bool(b),
SqlValue::Int(i) => Value::Integer(i),
SqlValue::Float(f) => Value::Float(f),
SqlValue::String(s) => Value::String(s),
SqlValue::Bytes(b) => Value::Bytes(b),
SqlValue::Array(a) => Value::Array(a.into_iter().map(sql_to_ndb_value).collect()),
}
}

fn ndb_to_sql_value(v: Value) -> SqlValue {
match v {
Value::Null => SqlValue::Null,
Value::Bool(b) => SqlValue::Bool(b),
Value::Integer(i) => SqlValue::Int(i),
Value::Float(f) => SqlValue::Float(f),
Value::String(s) => SqlValue::String(s),
Value::Bytes(b) => SqlValue::Bytes(b),
Value::Array(a) => SqlValue::Array(a.into_iter().map(ndb_to_sql_value).collect()),
Value::DateTime(dt) => SqlValue::String(dt.to_iso8601()),
Value::Uuid(s) | Value::Ulid(s) | Value::Regex(s) => SqlValue::String(s),
Value::Duration(d) => SqlValue::String(d.to_human()),
Value::Decimal(d) => SqlValue::String(d.to_string()),
// Structured and opaque types collapse to Null — callers that
// need these go through the runtime expression path, not folding.
Value::Object(_)
| Value::Geometry(_)
| Value::Set(_)
| Value::Range { .. }
| Value::Record { .. } => SqlValue::Null,
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn fold_now_produces_non_epoch_string() {
let registry = FunctionRegistry::new();
let expr = SqlExpr::Function {
name: "now".into(),
args: vec![],
distinct: false,
};
let val = fold_constant(&expr, &registry).expect("now() should fold");
match val {
SqlValue::String(s) => {
assert!(!s.starts_with("1970"), "got {s}");
assert!(s.contains('T'), "not ISO-8601: {s}");
}
other => panic!("expected string, got {other:?}"),
}
}

#[test]
fn fold_current_timestamp() {
let registry = FunctionRegistry::new();
let expr = SqlExpr::Function {
name: "current_timestamp".into(),
args: vec![],
distinct: false,
};
assert!(matches!(
fold_constant(&expr, &registry),
Some(SqlValue::String(_))
));
}

#[test]
fn fold_unknown_function_returns_none() {
let registry = FunctionRegistry::new();
let expr = SqlExpr::Function {
name: "definitely_not_a_real_function".into(),
args: vec![],
distinct: false,
};
assert!(fold_constant(&expr, &registry).is_none());
}

#[test]
fn fold_literal_arithmetic_still_works() {
let registry = FunctionRegistry::new();
let expr = SqlExpr::BinaryOp {
left: Box::new(SqlExpr::Literal(SqlValue::Int(2))),
op: BinaryOp::Add,
right: Box::new(SqlExpr::Literal(SqlValue::Int(3))),
};
assert_eq!(fold_constant(&expr, &registry), Some(SqlValue::Int(5)));
}

#[test]
fn fold_column_ref_returns_none() {
let registry = FunctionRegistry::new();
let expr = SqlExpr::Column {
table: None,
name: "name".into(),
};
assert!(fold_constant(&expr, &registry).is_none());
}
}
15 changes: 13 additions & 2 deletions nodedb-sql/src/planner/dml.rs
Original file line number Diff line number Diff line change
Expand Up @@ -471,8 +471,19 @@ fn expr_to_sql_value(expr: &ast::Expr) -> Result<SqlValue> {
}
}
_ => {
// Other functions like now() — store as string for runtime eval.
Ok(SqlValue::String(format!("{expr}")))
// Try folding via the shared scalar evaluator. Handles
// `now()`, `current_timestamp`, `date_add(now(),'1h')`,
// etc. — Postgres semantics: one snapshot per statement.
// Unknown or non-foldable functions fall back to the
// legacy string passthrough so existing behavior for
// other callers is preserved.
if let Ok(sql_expr) = crate::resolver::expr::convert_expr(expr)
&& let Some(v) = super::const_fold::fold_constant_default(&sql_expr)
{
Ok(v)
} else {
Ok(SqlValue::String(format!("{expr}")))
}
}
}
}
Expand Down
1 change: 1 addition & 0 deletions nodedb-sql/src/planner/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod aggregate;
pub mod const_fold;
pub mod cte;
pub mod dml;
pub mod join;
Expand Down
38 changes: 7 additions & 31 deletions nodedb-sql/src/planner/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ fn plan_select(
match proj {
Projection::Computed { expr, alias } => {
columns.push(alias.clone());
values.push(eval_constant_expr(expr));
values.push(eval_constant_expr(expr, functions));
}
Projection::Column(name) => {
columns.push(name.clone());
Expand Down Expand Up @@ -797,36 +797,12 @@ pub(crate) fn extract_func_args(func: &ast::Function) -> Result<Vec<ast::Expr>>
}
}

/// Evaluate a constant SqlExpr to a SqlValue.
fn eval_constant_expr(expr: &SqlExpr) -> SqlValue {
match expr {
SqlExpr::Literal(v) => v.clone(),
SqlExpr::UnaryOp {
op: UnaryOp::Neg,
expr,
} => match eval_constant_expr(expr) {
SqlValue::Int(i) => SqlValue::Int(-i),
SqlValue::Float(f) => SqlValue::Float(-f),
other => other,
},
SqlExpr::BinaryOp { left, op, right } => {
let l = eval_constant_expr(left);
let r = eval_constant_expr(right);
match (l, op, r) {
(SqlValue::Int(a), BinaryOp::Add, SqlValue::Int(b)) => SqlValue::Int(a + b),
(SqlValue::Int(a), BinaryOp::Sub, SqlValue::Int(b)) => SqlValue::Int(a - b),
(SqlValue::Int(a), BinaryOp::Mul, SqlValue::Int(b)) => SqlValue::Int(a * b),
(SqlValue::Float(a), BinaryOp::Add, SqlValue::Float(b)) => SqlValue::Float(a + b),
(SqlValue::Float(a), BinaryOp::Sub, SqlValue::Float(b)) => SqlValue::Float(a - b),
(SqlValue::Float(a), BinaryOp::Mul, SqlValue::Float(b)) => SqlValue::Float(a * b),
(SqlValue::String(a), BinaryOp::Concat, SqlValue::String(b)) => {
SqlValue::String(format!("{a}{b}"))
}
_ => SqlValue::Null,
}
}
_ => SqlValue::Null,
}
/// Evaluate a constant SqlExpr to a SqlValue. Delegates to the shared
/// `const_fold::fold_constant` helper so that zero-arg scalar functions
/// like `now()` and `current_timestamp` go through the same evaluator
/// as the runtime expression path.
fn eval_constant_expr(expr: &SqlExpr, functions: &FunctionRegistry) -> SqlValue {
super::const_fold::fold_constant(expr, functions).unwrap_or(SqlValue::Null)
}

/// Extract a geometry argument: handles ST_Point(lon, lat), ST_GeomFromGeoJSON('...'),
Expand Down
73 changes: 73 additions & 0 deletions nodedb/src/control/server/pgwire/ddl/sql_parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,82 @@ pub(super) fn parse_sql_value(val: &str) -> nodedb_types::Value {
if let Ok(f) = trimmed.parse::<f64>() {
return nodedb_types::Value::Float(f);
}
// Scalar function call like `now()` or `date_add(now(), '1h')`, or a
// bare identifier like `current_timestamp` that SQL treats as a
// zero-arg function. Route through the shared evaluator so the
// UPSERT fast-path stays aligned with the SQL planner's VALUES path.
// Unknown names fall through to the legacy string behavior.
if let Some(v) = try_eval_scalar_function(trimmed) {
return v;
}
nodedb_types::Value::String(trimmed.to_string())
}

/// Evaluate a scalar function expression like `now()` or a bare SQL
/// keyword like `current_timestamp` via the shared `nodedb_query`
/// evaluator. Returns `None` if the input isn't a recognizable call
/// form or the function is unknown.
fn try_eval_scalar_function(s: &str) -> Option<nodedb_types::Value> {
// Bare identifier: SQL treats `current_timestamp`, `current_date`,
// etc. as zero-arg function references without parentheses.
let is_bare_ident = s.chars().all(|c| c.is_ascii_alphanumeric() || c == '_')
&& !s.is_empty()
&& !s.chars().next().is_some_and(|c| c.is_ascii_digit());

if is_bare_ident {
let name = s.to_lowercase();
// Only fold if the registry knows this name. Gate via nodedb-sql's
// registry so we don't accidentally evaluate user identifiers.
let registry = nodedb_sql::planner::const_fold::default_registry();
if registry.lookup(&name).is_some() {
let val = nodedb_query::functions::eval_function(&name, &[]);
if !matches!(val, nodedb_types::Value::Null) {
return Some(val);
}
}
return None;
}

// Call form `name(args...)`. Parse via sqlparser + fold via const_fold.
if !s.ends_with(')') || !s.contains('(') {
return None;
}
let stmt_sql = format!("SELECT {s}");
let dialect = sqlparser::dialect::PostgreSqlDialect {};
let stmts = sqlparser::parser::Parser::parse_sql(&dialect, &stmt_sql).ok()?;
let stmt = stmts.into_iter().next()?;
let sqlparser::ast::Statement::Query(query) = stmt else {
return None;
};
let sqlparser::ast::SetExpr::Select(select) = *query.body else {
return None;
};
let item = select.projection.into_iter().next()?;
let ast_expr = match item {
sqlparser::ast::SelectItem::UnnamedExpr(e)
| sqlparser::ast::SelectItem::ExprWithAlias { expr: e, .. } => e,
_ => return None,
};
let sql_expr = nodedb_sql::resolver::expr::convert_expr(&ast_expr).ok()?;
let folded = nodedb_sql::planner::const_fold::fold_constant_default(&sql_expr)?;
Some(sql_value_to_ndb_value(folded))
}

fn sql_value_to_ndb_value(v: nodedb_sql::types::SqlValue) -> nodedb_types::Value {
use nodedb_sql::types::SqlValue;
match v {
SqlValue::Null => nodedb_types::Value::Null,
SqlValue::Bool(b) => nodedb_types::Value::Bool(b),
SqlValue::Int(i) => nodedb_types::Value::Integer(i),
SqlValue::Float(f) => nodedb_types::Value::Float(f),
SqlValue::String(s) => nodedb_types::Value::String(s),
SqlValue::Bytes(b) => nodedb_types::Value::Bytes(b),
SqlValue::Array(a) => {
nodedb_types::Value::Array(a.into_iter().map(sql_value_to_ndb_value).collect())
}
}
}

/// Extract a clause value delimited by known keywords.
///
/// Given `upper = "TYPE INT DEFAULT 0 ASSERT $value > 0"`, `original` (same
Expand Down
Loading
Loading