@@ -16,6 +16,7 @@ use log::{debug, error, info, warn};
1616use rustls:: ServerConfig ;
1717use serde_json:: json;
1818use sha2:: { Digest , Sha256 } ;
19+ use sp1_sdk:: SP1ProofWithPublicValues ;
1920use std:: { net:: SocketAddr , sync:: Arc } ;
2021use tokio:: {
2122 net:: TcpListener ,
@@ -29,6 +30,8 @@ use internal::job::*;
2930use internal:: runner:: * ;
3031use internal:: util:: * ;
3132
33+ use zkvm_common:: { chacha, std_only:: ZkvmOutput } ;
34+
3235type GenericError = Box < dyn std:: error:: Error + Send + Sync > ;
3336type BoxBody = http_body_util:: combinators:: BoxBody < Bytes , GenericError > ;
3437
@@ -198,40 +201,47 @@ async fn main() -> Result<()> {
198201 }
199202
200203 let mut request_method: String = Default :: default ( ) ;
201- let maybe_wrapped_req =
202- match inbound_handler ( plaintext_req, & mut request_method, runner)
203- . await
204- {
205- Ok ( req) => req,
206- Err ( e) => return Ok ( internal_error_response ( e. to_string ( ) ) ) ,
207- } ;
204+ let maybe_wrapped_req = match inbound_handler (
205+ plaintext_req,
206+ & mut request_method,
207+ runner. clone ( ) ,
208+ )
209+ . await
210+ {
211+ Ok ( req) => req,
212+ Err ( e) => return Ok ( internal_error_response ( e. to_string ( ) ) ) ,
213+ } ;
208214
209215 match maybe_wrapped_req {
210216 Some ( wrapped_req) => {
211- debug ! (
212- "Forwarding (maybe modified) `{}` --> DA: {:?}" ,
213- request_method, wrapped_req
214- ) ;
217+ debug ! ( "Forwarding (maybe modified) --> DA" , ) ;
218+ // debug!(
219+ // "Forwarding (maybe modified) `{}` --> DA: {:?}",
220+ // request_method, wrapped_req
221+ // );
215222 let returned = match celestia_client. request ( wrapped_req) . await
216223 {
217- Ok ( resp) => {
218- outbound_handler ( resp, request_method. to_owned ( ) )
219- . await
220- . unwrap_or_else ( |e| {
221- internal_error_response ( format ! (
222- "Outbound Handler: {}" ,
223- e
224- ) )
225- } )
226- }
224+ Ok ( resp) => outbound_handler (
225+ resp,
226+ request_method. to_owned ( ) ,
227+ runner,
228+ )
229+ . await
230+ . unwrap_or_else ( |e| {
231+ internal_error_response ( format ! (
232+ "Outbound Handler: {}" ,
233+ e
234+ ) )
235+ } ) ,
227236 Err ( e) => {
228237 internal_error_response ( format ! ( "DA Client: {}" , e) )
229238 }
230239 } ;
231- debug ! (
232- "Responding (maybe modified) `{}` <-- DA: {:?}" ,
233- request_method, returned
234- ) ;
240+ debug ! ( "Responding (maybe modified) <-- DA" , ) ;
241+ // debug!(
242+ // "Responding (maybe modified) `{}` <-- DA: {:?}",
243+ // request_method, returned
244+ // );
235245 anyhow:: Ok ( returned)
236246 }
237247 None => Ok ( pending_response ( ) ) ,
@@ -284,11 +294,10 @@ pub async fn inbound_handler(
284294 . get_mut ( 0 )
285295 . ok_or_else ( || anyhow:: anyhow!( "Expected first 'params' entry" ) ) ?;
286296
287- let blobs: Vec < Blob > = serde_json:: from_value ( blobs_value . clone ( ) ) ?;
297+ let blobs: Vec < Blob > = serde_json:: from_value ( std :: mem :: take ( blobs_value ) ) ?;
288298
289299 let mut encrypted_blobs = Vec :: with_capacity ( blobs. len ( ) ) ;
290300
291- // TODO: consider only allowing one blob and one job on the queue
292301 for blob in blobs {
293302 let pda_runner = pda_runner. clone ( ) ;
294303 let data = blob. data . clone ( ) ;
@@ -304,10 +313,11 @@ pub async fn inbound_handler(
304313 if let Some ( proof_with_values) =
305314 pda_runner. get_verifiable_encryption ( job) . await ?
306315 {
307- debug ! (
308- "Replacing blob.data with <Sp1ProofWithPublicValues>.proof = {:?}" ,
309- proof_with_values. proof
310- ) ;
316+ debug ! ( "Replacing blob.data with <Sp1ProofWithPublicValues>.proof" ) ;
317+ // debug!(
318+ // "Replacing blob.data with <Sp1ProofWithPublicValues>.proof = {:?}",
319+ // proof_with_values.proof
320+ // );
311321
312322 let encrypted_data = bincode:: serialize ( & proof_with_values) ?;
313323 let encrypted_blob = Blob :: new (
@@ -318,9 +328,12 @@ pub async fn inbound_handler(
318328
319329 encrypted_blobs. push ( encrypted_blob) ;
320330 } else {
321- return Ok ( None ) ;
331+ return Ok ( None ) ; // Bail out if any blob can't be encrypted
322332 }
323333 }
334+
335+ // Overwrite the original blob array in the params with encrypted blobs
336+ * blobs_value = serde_json:: to_value ( encrypted_blobs) ?;
324337 }
325338 } else {
326339 debug ! ( "Forwarding `blob.Submit` error: missing params" ) ;
@@ -355,48 +368,93 @@ pub async fn inbound_handler(
355368async fn outbound_handler (
356369 resp : Response < IncomingBody > ,
357370 request_method : String ,
371+ pda_runner : Arc < PdaRunner > ,
358372) -> Result < Response < BoxBody > > {
359373 let ( mut parts, body_stream) = resp. into_parts ( ) ;
360374 let mut body_buf = body_stream. collect ( ) . await ?. aggregate ( ) ;
361375 let body_bytes = body_buf. copy_to_bytes ( body_buf. remaining ( ) ) ;
362376
363- debug ! ( "Raw upstream response: {:?}" , body_bytes) ;
364-
365377 let status = parts. status ;
366-
367378 if status == StatusCode :: UNAUTHORIZED {
368379 return Ok ( bad_auth_response ( ) ) ;
369380 }
370381
371- let body_json: serde_json:: Value = serde_json:: from_slice ( & body_bytes) ?;
372- // (Optional) handle blob.Get/All for decryption etc.
373- match request_method. as_str ( ) {
374- "blob.Get" => {
375- if let Some ( result_raw) = body_json. get ( "result" ) {
382+ let mut body_json: serde_json:: Value = serde_json:: from_slice ( & body_bytes) ?;
383+
384+ let try_mutate_response = async {
385+ let result_raw = body_json
386+ . get_mut ( "result" )
387+ . ok_or_else ( || anyhow:: anyhow!( "Missing 'result' field" ) ) ?;
388+
389+ let key = <[ u8 ; 32 ] >:: from_hex (
390+ std:: env:: var ( "ENCRYPTION_KEY" ) . expect ( "Missing ENCRYPTION_KEY env var" ) ,
391+ )
392+ . expect ( "ENCRYPTION_KEY must be 32 bytes, hex encoded (ex: `1234...abcd`)" ) ;
393+
394+ match request_method. as_str ( ) {
395+ "blob.Get" => {
376396 let blob: Blob = serde_json:: from_value ( result_raw. clone ( ) ) ?;
377- debug ! ( "{blob:?}" ) ;
397+ let plaintext_only_blob =
398+ verify_decrypt_blob ( blob, key, pda_runner. clone ( ) ) . await ?;
399+ * result_raw = serde_json:: to_value ( plaintext_only_blob) ?;
378400 }
379- }
380- "blob.GetAll" => {
381- if let Some ( result_raw) = body_json. get ( "result" ) {
382- let blobs: Vec < Blob > = serde_json:: from_value ( result_raw. clone ( ) ) ?;
383- for blob in blobs {
384- debug ! ( "{blob:?}" ) ;
401+
402+ "blob.GetAll" => {
403+ let original_array = result_raw
404+ . as_array ( )
405+ . ok_or_else ( || anyhow:: anyhow!( "Expected 'result' to be an array" ) ) ?
406+ . clone ( ) ;
407+
408+ let mut plaintext_blobs = Vec :: with_capacity ( original_array. len ( ) ) ;
409+ for blob_val in original_array {
410+ let blob: Blob = serde_json:: from_value ( blob_val) ?;
411+ let plaintext_only_blob =
412+ verify_decrypt_blob ( blob, key, pda_runner. clone ( ) ) . await ?;
413+ plaintext_blobs. push ( serde_json:: to_value ( plaintext_only_blob) ?) ;
385414 }
415+
416+ * result_raw = serde_json:: Value :: Array ( plaintext_blobs) ;
386417 }
418+
419+ _ => { }
387420 }
388- _ => { }
421+
422+ Ok :: < _ , anyhow:: Error > ( ( ) )
423+ }
424+ . await ;
425+
426+ if let Err ( err) = try_mutate_response {
427+ warn ! ( "Failed to decrypt response: {:?}" , err) ;
428+ let orig_body = Full :: new ( body_bytes)
429+ . map_err ( |err : std:: convert:: Infallible | match err { } )
430+ . boxed ( ) ;
431+ return Ok ( Response :: from_parts ( parts, orig_body) ) ;
389432 }
390433
391434 let json = serde_json:: to_string ( & body_json) ?;
392435 parts. headers . remove ( "content-length" ) ;
436+
393437 let new_body = Full :: new ( Bytes :: from ( json) )
394438 . map_err ( |err : std:: convert:: Infallible | match err { } )
395439 . boxed ( ) ;
396440
397441 Ok ( Response :: from_parts ( parts, new_body) )
398442}
399443
444+ async fn verify_decrypt_blob (
445+ blob : Blob ,
446+ key : [ u8 ; 32 ] ,
447+ pda_runner : Arc < PdaRunner > ,
448+ ) -> Result < Blob , anyhow:: Error > {
449+ let proof: SP1ProofWithPublicValues = bincode:: deserialize ( & blob. data ) ?;
450+ let output = extract_verified_proof_output ( & proof, pda_runner) . await ?;
451+ let mut buffer = output. ciphertext . to_owned ( ) ;
452+ chacha ( & key, & output. nonce , & mut buffer) ;
453+ let mut decrypted_plaintext_blob = blob. clone ( ) ;
454+ decrypted_plaintext_blob. data = buffer. to_vec ( ) ;
455+ Ok ( decrypted_plaintext_blob)
456+ }
457+
400458/// Job is in queue, we are waiting on it to finish
401459fn pending_response ( ) -> Response < BoxBody > {
402460 let raw_json = r#"{ "id": 1, "jsonrpc": "2.0", "status": "[pda-proxy] Verifiable encryption processing... Call back for result" }"# ;
@@ -438,3 +496,18 @@ fn new_response_from(raw_json: &str, status: StatusCode) -> Response<BoxBody> {
438496 * response. status_mut ( ) = status;
439497 response
440498}
499+
500+ /// Verify a proof before returning it's attested output
501+ async fn extract_verified_proof_output < ' a > (
502+ proof : & ' a SP1ProofWithPublicValues ,
503+ runner : Arc < PdaRunner > ,
504+ ) -> Result < ZkvmOutput < ' a > > {
505+ let zk_client_local = runner. get_zk_client_local ( ) . await ;
506+ let vk = & runner
507+ . get_proof_setup_local ( & get_program_id ( ) . await , zk_client_local. clone ( ) )
508+ . await ?
509+ . vk ;
510+ zk_client_local. verify ( proof, vk) ?;
511+
512+ ZkvmOutput :: from_bytes ( proof. public_values . as_slice ( ) ) . map_err ( anyhow:: Error :: msg)
513+ }
0 commit comments