Skip to content

Commit ebe4787

Browse files
committed
fix(pgwire): fix extended-query protocol for DSL statements and typed result columns
Two issues in the prepared-statement extended-query path: 1. DSL statements (SEARCH, GRAPH, MATCH, UPSERT INTO, etc.) were not handled by the Execute phase. Route them through the same DSL dispatcher used by simple queries; bound parameters are intentionally ignored for DSL. 2. When a statement declares typed result columns via Describe, Execute was producing a single-column JSON response against the N-column schema described to the client, causing null values for columns 2..N. Add a reproject step that parses each JSON object and re-encodes it with one pgwire field per declared column, with missing fields sent as SQL NULL.
1 parent 861f7a5 commit ebe4787

File tree

4 files changed

+178
-3
lines changed

4 files changed

+178
-3
lines changed

nodedb/src/control/server/pgwire/ddl/backup.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ pub async fn restore_tenant(
244244
})?;
245245
let mut aad = [0u8; nodedb_wal::record::HEADER_SIZE];
246246
aad[..6].copy_from_slice(b"BACKUP");
247-
key.decrypt(0, &aad, &raw_bytes[4..])
247+
key.decrypt(key.epoch(), 0, &aad, &raw_bytes[4..])
248248
.map_err(|e| sqlstate_error("XX000", &format!("backup decryption failed: {e}")))?
249249
} else {
250250
raw_bytes

nodedb/src/control/server/pgwire/handler/prepared/execute.rs

Lines changed: 147 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,17 @@
55
//! all DDL dispatch, transaction handling, and permission checks.
66
77
use std::fmt::Debug;
8+
use std::sync::Arc;
89

910
use bytes::Bytes;
11+
use futures::StreamExt;
1012
use futures::sink::Sink;
1113
use pgwire::api::portal::Portal;
12-
use pgwire::api::results::Response;
14+
use pgwire::api::results::{DataRowEncoder, FieldInfo, QueryResponse, Response};
1315
use pgwire::api::{ClientInfo, ClientPortalStore, Type};
1416
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
1517
use pgwire::messages::PgWireBackendMessage;
18+
use sonic_rs;
1619

1720
use super::super::core::NodeDbPgHandler;
1821
use super::statement::ParsedStatement;
@@ -39,15 +42,134 @@ impl NodeDbPgHandler {
3942
let stmt = &portal.statement.statement;
4043
let tenant_id = identity.tenant_id;
4144

45+
// DSL passthroughs (SEARCH, GRAPH, MATCH, UPSERT INTO, etc.) cannot be
46+
// handled by the planned-SQL path. Route them through the same full DSL
47+
// dispatcher used by the simple-query handler. DSL statements do not use
48+
// SQL parameter placeholders, so bound parameters are intentionally ignored.
49+
if stmt.is_dsl {
50+
let mut results = self.execute_sql(&identity, &addr, &stmt.sql).await?;
51+
return Ok(results.pop().unwrap_or(Response::EmptyQuery));
52+
}
53+
4254
// Convert pgwire binary parameters to typed ParamValues for AST binding.
4355
let params = convert_portal_params(&portal.parameters, &stmt.param_types)?;
4456

4557
// Execute through the planned SQL path with AST-level parameter binding.
4658
let mut results = self
4759
.execute_planned_sql_with_params(&identity, &stmt.sql, tenant_id, &addr, &params)
4860
.await?;
49-
Ok(results.pop().unwrap_or(Response::EmptyQuery))
61+
let result = results.pop().unwrap_or(Response::EmptyQuery);
62+
63+
// When the statement declared typed result columns via Describe, the
64+
// client expects DataRow messages with one field per declared column.
65+
//
66+
// The generic `payload_to_response` path produces a single-column
67+
// QueryResponse with the full JSON as one text field. In the extended-
68+
// query protocol the RowDescription was already sent by Describe, so
69+
// pgwire sends only the DataRow messages on Execute — the client maps
70+
// them against the previously-described schema. A 1-field row against
71+
// an N-column schema causes null values for columns 2..N.
72+
//
73+
// Fix: when result_fields is non-empty, consume the single-field stream,
74+
// parse each JSON object, and re-encode with one pgwire field per
75+
// declared column.
76+
if !stmt.result_fields.is_empty() {
77+
reproject_response(result, &stmt.result_fields).await
78+
} else {
79+
Ok(result)
80+
}
81+
}
82+
}
83+
84+
/// Re-encode a query response to match the column schema declared by Describe.
85+
///
86+
/// Each DataRow from `payload_to_response` contains a single text field holding
87+
/// a JSON object. We parse each object and extract fields in `result_fields`
88+
/// order, producing a new QueryResponse whose rows have one field per declared
89+
/// column. Missing fields are sent as SQL NULL.
90+
///
91+
/// Non-query responses (execution tags) pass through unchanged.
92+
async fn reproject_response(
93+
response: Response,
94+
result_fields: &[FieldInfo],
95+
) -> PgWireResult<Response> {
96+
let qr = match response {
97+
Response::Query(qr) => qr,
98+
other => return Ok(other),
99+
};
100+
101+
let schema = Arc::new(result_fields.to_vec());
102+
let field_names: Vec<String> = result_fields.iter().map(|f| f.name().to_string()).collect();
103+
104+
// Collect JSON objects from the single-column stream produced by
105+
// payload_to_response. Each DataRow has exactly one field: a JSON string.
106+
let json_rows = collect_json_rows(qr).await?;
107+
108+
let mut pgwire_rows = Vec::with_capacity(json_rows.len());
109+
for obj in &json_rows {
110+
let mut encoder = DataRowEncoder::new(schema.clone());
111+
for name in &field_names {
112+
match obj.get(name) {
113+
None | Some(serde_json::Value::Null) => {
114+
let _ = encoder.encode_field(&Option::<String>::None);
115+
}
116+
Some(v) => {
117+
let text = match v {
118+
serde_json::Value::String(s) => s.clone(),
119+
other => other.to_string(),
120+
};
121+
let _ = encoder.encode_field(&text);
122+
}
123+
}
124+
}
125+
pgwire_rows.push(Ok(encoder.take_row()));
50126
}
127+
128+
Ok(Response::Query(QueryResponse::new(
129+
schema,
130+
futures::stream::iter(pgwire_rows),
131+
)))
132+
}
133+
134+
/// Consume a `QueryResponse` stream and decode the single text field of each
135+
/// `DataRow` as a JSON object.
136+
///
137+
/// `payload_to_response` always produces rows where field[0] is a JSON string.
138+
/// The pgwire `DataRow.data` format is: for each field, 4-byte length (i32,
139+
/// big-endian) followed by the field bytes. `-1` (0xFFFFFFFF) means SQL NULL.
140+
async fn collect_json_rows(mut qr: QueryResponse) -> PgWireResult<Vec<serde_json::Value>> {
141+
let mut rows = Vec::new();
142+
while let Some(row_result) = qr.data_rows.next().await {
143+
let row = row_result?;
144+
// Decode field[0] from the raw DataRow wire format.
145+
let text = decode_first_field_text(&row.data);
146+
if let Some(t) = text {
147+
let val: serde_json::Value =
148+
sonic_rs::from_str(t).unwrap_or_else(|_| serde_json::Value::String(t.to_string()));
149+
rows.push(val);
150+
}
151+
}
152+
Ok(rows)
153+
}
154+
155+
/// Decode the text bytes of the first field from a pgwire `DataRow` wire buffer.
///
/// Wire format: for each field, 4-byte big-endian length followed by that many
/// bytes. A length of `-1` (0xFFFFFFFF) means SQL NULL.
///
/// Returns `None` for NULL fields, truncated buffers, or invalid UTF-8.
///
/// Takes `&[u8]` rather than `&bytes::BytesMut` so the helper has no crate
/// coupling and is trivially unit-testable; existing call sites passing
/// `&row.data` (a `BytesMut`) continue to work via deref coercion.
fn decode_first_field_text(data: &[u8]) -> Option<&str> {
    // The 4-byte length header must be present in full.
    let header = data.get(..4)?;
    let len = i32::from_be_bytes(header.try_into().ok()?);
    if len < 0 {
        // NULL field.
        return None;
    }
    // Checked slicing: a buffer shorter than the declared length yields None
    // instead of panicking.
    let field = data.get(4..4 + len as usize)?;
    std::str::from_utf8(field).ok()
}
52174

53175
/// Convert pgwire portal parameters to typed `ParamValue` for AST-level binding.
@@ -156,4 +278,27 @@ mod tests {
156278
assert!(matches!(result[0], nodedb_sql::ParamValue::Bool(v) if v == expected));
157279
}
158280
}
281+
282+
#[test]
fn decode_first_field_text_normal() {
    // A well-formed field: 4-byte big-endian length followed by UTF-8 bytes.
    let payload = b"hello";
    let mut buf = bytes::BytesMut::new();
    buf.extend_from_slice(&(payload.len() as i32).to_be_bytes());
    buf.extend_from_slice(payload);
    assert_eq!(decode_first_field_text(&buf), Some("hello"));
}
291+
292+
#[test]
fn decode_first_field_text_null() {
    // A wire length of -1 encodes SQL NULL and must decode to None.
    let mut buf = bytes::BytesMut::new();
    buf.extend_from_slice(&(-1i32).to_be_bytes());
    assert_eq!(decode_first_field_text(&buf), None);
}
299+
300+
#[test]
fn decode_first_field_text_empty() {
    // A buffer too short to hold the 4-byte length header decodes to None.
    assert_eq!(decode_first_field_text(&bytes::BytesMut::new()), None);
}
159304
}

nodedb/src/control/server/pgwire/handler/prepared/parser.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,17 @@ impl QueryParser for NodeDbQueryParser {
112112
.unwrap_or(1);
113113
let (param_types, result_fields) = self.try_infer_types(sql, types, tenant_id);
114114

115+
// If type inference produced no result fields and the SQL matches a
116+
// known DSL prefix, mark the statement as a DSL passthrough. The
117+
// Execute handler will route it through the full DSL dispatcher
118+
// (same as the simple-query path) instead of `execute_planned_sql_with_params`.
119+
let is_dsl = result_fields.is_empty() && is_dsl_statement(sql);
120+
115121
Ok(ParsedStatement {
116122
sql: sql.to_owned(),
117123
param_types,
118124
result_fields,
125+
is_dsl,
119126
})
120127
}
121128

@@ -136,6 +143,25 @@ impl QueryParser for NodeDbQueryParser {
136143
}
137144
}
138145

146+
/// Return true if `sql` starts with a DSL keyword that `plan_sql` cannot parse.
///
/// Mirrors the prefix checks in `ddl/router/dsl.rs` so the extended-query
/// Parse handler can mark such statements as DSL passthroughs and route them
/// through the DSL dispatcher at Execute time.
///
/// Matching is ASCII case-insensitive on the leading keyword only; this avoids
/// allocating an upper-cased copy of the entire statement text (which for a
/// bulk UPSERT can carry a large document payload).
fn is_dsl_statement(sql: &str) -> bool {
    // Keep in sync with the prefix list in `ddl/router/dsl.rs`.
    const DSL_PREFIXES: &[&str] = &[
        "SEARCH ",
        "GRAPH ",
        "MATCH ",
        "OPTIONAL MATCH ",
        "CRDT MERGE ",
        "UPSERT INTO ",
        "CREATE VECTOR INDEX ",
        "CREATE FULLTEXT INDEX ",
        "CREATE SEARCH INDEX ",
        "CREATE SPARSE INDEX ",
    ];
    let head = sql.trim_start().as_bytes();
    DSL_PREFIXES.iter().any(|prefix| {
        // Checked prefix slice: shorter input simply fails the match.
        head.get(..prefix.len())
            .is_some_and(|h| h.eq_ignore_ascii_case(prefix.as_bytes()))
    })
}
164+
139165
/// Count $1, $2, ... placeholders in SQL text.
140166
fn count_placeholders(sql: &str) -> usize {
141167
let mut max_idx = 0usize;

nodedb/src/control/server/pgwire/handler/prepared/statement.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,8 @@ pub struct ParsedStatement {
2121
/// Result column schema inferred from the logical plan.
2222
/// Empty for DML statements (INSERT/UPDATE/DELETE).
2323
pub result_fields: Vec<FieldInfo>,
24+
/// True when the SQL is a DSL statement (SEARCH, GRAPH, MATCH, UPSERT INTO,
25+
/// etc.) that `plan_sql` cannot parse. The Execute handler routes these
26+
/// through the full DSL dispatcher instead of `execute_planned_sql_with_params`.
27+
pub is_dsl: bool,
2428
}

0 commit comments

Comments
 (0)