55//! all DDL dispatch, transaction handling, and permission checks.
66
77use std:: fmt:: Debug ;
8+ use std:: sync:: Arc ;
89
910use bytes:: Bytes ;
11+ use futures:: StreamExt ;
1012use futures:: sink:: Sink ;
1113use pgwire:: api:: portal:: Portal ;
12- use pgwire:: api:: results:: Response ;
14+ use pgwire:: api:: results:: { DataRowEncoder , FieldInfo , QueryResponse , Response } ;
1315use pgwire:: api:: { ClientInfo , ClientPortalStore , Type } ;
1416use pgwire:: error:: { ErrorInfo , PgWireError , PgWireResult } ;
1517use pgwire:: messages:: PgWireBackendMessage ;
18+ use sonic_rs;
1619
1720use super :: super :: core:: NodeDbPgHandler ;
1821use super :: statement:: ParsedStatement ;
@@ -39,15 +42,134 @@ impl NodeDbPgHandler {
3942 let stmt = & portal. statement . statement ;
4043 let tenant_id = identity. tenant_id ;
4144
45+ // DSL passthroughs (SEARCH, GRAPH, MATCH, UPSERT INTO, etc.) cannot be
46+ // handled by the planned-SQL path. Route them through the same full DSL
47+ // dispatcher used by the simple-query handler. DSL statements do not use
48+ // SQL parameter placeholders, so bound parameters are intentionally ignored.
49+ if stmt. is_dsl {
50+ let mut results = self . execute_sql ( & identity, & addr, & stmt. sql ) . await ?;
51+ return Ok ( results. pop ( ) . unwrap_or ( Response :: EmptyQuery ) ) ;
52+ }
53+
4254 // Convert pgwire binary parameters to typed ParamValues for AST binding.
4355 let params = convert_portal_params ( & portal. parameters , & stmt. param_types ) ?;
4456
4557 // Execute through the planned SQL path with AST-level parameter binding.
4658 let mut results = self
4759 . execute_planned_sql_with_params ( & identity, & stmt. sql , tenant_id, & addr, & params)
4860 . await ?;
49- Ok ( results. pop ( ) . unwrap_or ( Response :: EmptyQuery ) )
61+ let result = results. pop ( ) . unwrap_or ( Response :: EmptyQuery ) ;
62+
63+ // When the statement declared typed result columns via Describe, the
64+ // client expects DataRow messages with one field per declared column.
65+ //
66+ // The generic `payload_to_response` path produces a single-column
67+ // QueryResponse with the full JSON as one text field. In the extended-
68+ // query protocol the RowDescription was already sent by Describe, so
69+ // pgwire sends only the DataRow messages on Execute — the client maps
70+ // them against the previously-described schema. A 1-field row against
71+ // an N-column schema causes null values for columns 2..N.
72+ //
73+ // Fix: when result_fields is non-empty, consume the single-field stream,
74+ // parse each JSON object, and re-encode with one pgwire field per
75+ // declared column.
76+ if !stmt. result_fields . is_empty ( ) {
77+ reproject_response ( result, & stmt. result_fields ) . await
78+ } else {
79+ Ok ( result)
80+ }
81+ }
82+ }
83+
84+ /// Re-encode a query response to match the column schema declared by Describe.
85+ ///
86+ /// Each DataRow from `payload_to_response` contains a single text field holding
87+ /// a JSON object. We parse each object and extract fields in `result_fields`
88+ /// order, producing a new QueryResponse whose rows have one field per declared
89+ /// column. Missing fields are sent as SQL NULL.
90+ ///
91+ /// Non-query responses (execution tags) pass through unchanged.
92+ async fn reproject_response (
93+ response : Response ,
94+ result_fields : & [ FieldInfo ] ,
95+ ) -> PgWireResult < Response > {
96+ let qr = match response {
97+ Response :: Query ( qr) => qr,
98+ other => return Ok ( other) ,
99+ } ;
100+
101+ let schema = Arc :: new ( result_fields. to_vec ( ) ) ;
102+ let field_names: Vec < String > = result_fields. iter ( ) . map ( |f| f. name ( ) . to_string ( ) ) . collect ( ) ;
103+
104+ // Collect JSON objects from the single-column stream produced by
105+ // payload_to_response. Each DataRow has exactly one field: a JSON string.
106+ let json_rows = collect_json_rows ( qr) . await ?;
107+
108+ let mut pgwire_rows = Vec :: with_capacity ( json_rows. len ( ) ) ;
109+ for obj in & json_rows {
110+ let mut encoder = DataRowEncoder :: new ( schema. clone ( ) ) ;
111+ for name in & field_names {
112+ match obj. get ( name) {
113+ None | Some ( serde_json:: Value :: Null ) => {
114+ let _ = encoder. encode_field ( & Option :: < String > :: None ) ;
115+ }
116+ Some ( v) => {
117+ let text = match v {
118+ serde_json:: Value :: String ( s) => s. clone ( ) ,
119+ other => other. to_string ( ) ,
120+ } ;
121+ let _ = encoder. encode_field ( & text) ;
122+ }
123+ }
124+ }
125+ pgwire_rows. push ( Ok ( encoder. take_row ( ) ) ) ;
50126 }
127+
128+ Ok ( Response :: Query ( QueryResponse :: new (
129+ schema,
130+ futures:: stream:: iter ( pgwire_rows) ,
131+ ) ) )
132+ }
133+
134+ /// Consume a `QueryResponse` stream and decode the single text field of each
135+ /// `DataRow` as a JSON object.
136+ ///
137+ /// `payload_to_response` always produces rows where field[0] is a JSON string.
138+ /// The pgwire `DataRow.data` format is: for each field, 4-byte length (i32,
139+ /// big-endian) followed by the field bytes. `-1` (0xFFFFFFFF) means SQL NULL.
140+ async fn collect_json_rows ( mut qr : QueryResponse ) -> PgWireResult < Vec < serde_json:: Value > > {
141+ let mut rows = Vec :: new ( ) ;
142+ while let Some ( row_result) = qr. data_rows . next ( ) . await {
143+ let row = row_result?;
144+ // Decode field[0] from the raw DataRow wire format.
145+ let text = decode_first_field_text ( & row. data ) ;
146+ if let Some ( t) = text {
147+ let val: serde_json:: Value =
148+ sonic_rs:: from_str ( t) . unwrap_or_else ( |_| serde_json:: Value :: String ( t. to_string ( ) ) ) ;
149+ rows. push ( val) ;
150+ }
151+ }
152+ Ok ( rows)
153+ }
154+
155+ /// Decode the text bytes of the first field from a pgwire `DataRow` wire buffer.
156+ ///
157+ /// Wire format: for each field, 4-byte big-endian length followed by bytes.
158+ /// Returns `None` for NULL fields or invalid encodings.
159+ fn decode_first_field_text ( data : & bytes:: BytesMut ) -> Option < & str > {
160+ if data. len ( ) < 4 {
161+ return None ;
162+ }
163+ let len = i32:: from_be_bytes ( [ data[ 0 ] , data[ 1 ] , data[ 2 ] , data[ 3 ] ] ) ;
164+ if len < 0 {
165+ // NULL field.
166+ return None ;
167+ }
168+ let len = len as usize ;
169+ if data. len ( ) < 4 + len {
170+ return None ;
171+ }
172+ std:: str:: from_utf8 ( & data[ 4 ..4 + len] ) . ok ( )
51173}
52174
53175/// Convert pgwire portal parameters to typed `ParamValue` for AST-level binding.
@@ -156,4 +278,27 @@ mod tests {
156278 assert ! ( matches!( result[ 0 ] , nodedb_sql:: ParamValue :: Bool ( v) if v == expected) ) ;
157279 }
158280 }
281+
282+ #[ test]
283+ fn decode_first_field_text_normal ( ) {
284+ // Wire format: 4-byte length (big-endian) + UTF-8 bytes.
285+ let text = b"hello" ;
286+ let mut data = bytes:: BytesMut :: new ( ) ;
287+ data. extend_from_slice ( & ( text. len ( ) as i32 ) . to_be_bytes ( ) ) ;
288+ data. extend_from_slice ( text) ;
289+ assert_eq ! ( decode_first_field_text( & data) , Some ( "hello" ) ) ;
290+ }
291+
292+ #[ test]
293+ fn decode_first_field_text_null ( ) {
294+ // -1 length means SQL NULL.
295+ let mut data = bytes:: BytesMut :: new ( ) ;
296+ data. extend_from_slice ( & ( -1i32 ) . to_be_bytes ( ) ) ;
297+ assert_eq ! ( decode_first_field_text( & data) , None ) ;
298+ }
299+
300+ #[ test]
301+ fn decode_first_field_text_empty ( ) {
302+ assert_eq ! ( decode_first_field_text( & bytes:: BytesMut :: new( ) ) , None ) ;
303+ }
159304}
0 commit comments