Skip to content

Commit 717a56e

Browse files
authored
enable env to set FulfilmentStrategy for network provers (#32)
* enable env to set FulfilmentStrategy for network provers * rm network zk prover timeout * actually submit encrypted blob lol * defensive take to ensure no unencrypted data remains in intercepted body
1 parent 3c499a2 commit 717a56e

File tree

7 files changed

+153
-59
lines changed

7 files changed

+153
-59
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

example.env

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ SP1_PROVER=cuda
7373
# TODO: This feature is NOT fully implemented! We need a TEE wrapper around proving remote for this.
7474
NETWORK_RPC_URL=https://rpc.production.succinct.xyz
7575
NETWORK_PRIVATE_KEY=
76+
# Strategy one of "UNSPECIFIED_FULFILLMENT_STRATEGY", "HOSTED", "RESERVED", "AUCTION"
77+
# You likely want HOSTED (the hardcoded default for network provers)
78+
# https://github.com/succinctlabs/sp1/blob/11ab6b783cfce295b6f1113af088cc5f0a8caa5b/crates/sdk/src/network/prover.rs#L138
79+
SP1_FULFILLMENT_STRATEGY=HOSTED
7680

7781
#### Dependent & Provider Settings
7882

@@ -90,7 +94,3 @@ CELESTIA_NODE_NAME=celestia-light-mocha-4
9094
CELESTIA_DATA_DIR=/home/you/.celestia-light-mocha-4
9195
CELESTIA_NODE_WRITE_TOKEN=
9296

93-
94-
95-
96-

justfile

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,12 +139,23 @@ celestia-node-balance:
139139
--data '{"jsonrpc":"2.0","id":1,"method":"state.Balance","params":[]}' \
140140
http://127.0.0.1:26658 | jq
141141

142+
# https://mocha-4.celenium.io/tx/28fa01d026ac5a229e5d5472a204d290beda02ea229f6b3f42da520b00154e58?tab=messages
143+
142144
# Test blob.Get for PDA Proxy
143145
curl-blob-get:
144-
curl -H "Content-Type: application/json" -H "Authorization: Bearer $CELESTIA_NODE_WRITE_TOKEN" --data '{"id": 1,"jsonrpc": "2.0", "method": "blob.Get", "params": [ 42, "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAMJ/xGlNMdE=", "aHlbp+J9yub6hw/uhK6dP8hBLR2mFy78XNRRdLf2794=" ] }' \
146+
curl -H "Content-Type: application/json" -H "Authorization: Bearer $CELESTIA_NODE_WRITE_TOKEN" --data '{"id": 1,"jsonrpc": "2.0", "method": "blob.Get", "params": [ 6629478, "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAMJ/xGlNMdE=", "yS3XX33mc1uXkGinkTCvS9oqE0k9mtHMWTz0mwZccOc=" ] }' \
145147
https://127.0.0.1:26657 \
146148
--insecure | jq
147149

150+
# https://mocha-4.celenium.io/tx/28fa01d026ac5a229e5d5472a204d290beda02ea229f6b3f42da520b00154e58?tab=messages
151+
152+
# Test blob.Get for local light node
153+
curl-blob-get-passthrough:
154+
curl -H "Content-Type: application/json" -H "Authorization: Bearer $CELESTIA_NODE_WRITE_TOKEN" --data '{"id": 1,"jsonrpc": "2.0", "method": "blob.Get", "params": [ 6629478, "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAMJ/xGlNMdE=", "yS3XX33mc1uXkGinkTCvS9oqE0k9mtHMWTz0mwZccOc=" ] }' \
155+
http://127.0.0.1:26658 | jq
156+
157+
# https://mocha.celenium.io/tx/436f223bfa8c4adf1e1b79dde43a84918f3a50809583c57c33c1c079568b47cb?tab=messages
158+
148159
# Test blob.Submit for PDA proxy
149160
curl-blob-submit:
150161
curl -H "Content-Type: application/json" -H "Authorization: Bearer $CELESTIA_NODE_WRITE_TOKEN" \

service/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ version.workspace = true
44
edition.workspace = true
55

66
[dependencies]
7+
zkvm-common = { workspace = true, features = ["std"] }
8+
79
anyhow.workspace = true
810
serde.workspace = true
911
serde_json.workspace = true

service/src/internal/runner.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ use sha2::Digest;
66
use sled::{Transactional, Tree as SledTree};
77
use sp1_sdk::{
88
EnvProver as SP1EnvProver, NetworkProver as SP1NetworkProver, Prover, SP1ProofWithPublicValues,
9-
SP1Stdin, network::Error as SP1NetworkError,
9+
SP1Stdin,
10+
network::{Error as SP1NetworkError, FulfillmentStrategy},
1011
};
1112
use std::sync::Arc;
1213
use tokio::sync::{OnceCell, mpsc};
@@ -521,6 +522,12 @@ impl PdaRunner {
521522
.get_proof_setup_remote(program_id, zk_client_handle.clone())
522523
.await?;
523524

525+
let fs_string = std::env::var("SP1_FULFILLMENT_STRATEGY")
526+
.expect("Missing SP1_FULFILLMENT_STRATEGY env var");
527+
let strategy = FulfillmentStrategy::from_str_name(fs_string.as_str()).ok_or(
528+
PdaRunnerError::InternalError("SP1_FULFILLMENT_STRATEGY env var malformed".to_string()),
529+
)?;
530+
524531
let mut stdin = SP1Stdin::new();
525532
// Setup the inputs:
526533
// - key = 32 bytes
@@ -544,9 +551,10 @@ impl PdaRunner {
544551
debug!("0x{} - Starting proof", hex::encode(job_key));
545552
let request_id: util::SuccNetJobId = zk_client_handle
546553
.prove(&proof_setup.pk, &stdin)
554+
.strategy(strategy)
547555
.groth16()
548556
.skip_simulation(false)
549-
.timeout(std::time::Duration::from_secs(5)) // Don't hang too long on this. If it's gonna fail, fail fast.
557+
// .timeout(std::time::Duration::from_secs(60))
550558
.request_async()
551559
.await
552560
// TODO: how to handle errors without a concrete type? Anyhow is not the right thing for us...

service/src/main.rs

Lines changed: 121 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use log::{debug, error, info, warn};
1616
use rustls::ServerConfig;
1717
use serde_json::json;
1818
use sha2::{Digest, Sha256};
19+
use sp1_sdk::SP1ProofWithPublicValues;
1920
use std::{net::SocketAddr, sync::Arc};
2021
use tokio::{
2122
net::TcpListener,
@@ -29,6 +30,8 @@ use internal::job::*;
2930
use internal::runner::*;
3031
use internal::util::*;
3132

33+
use zkvm_common::{chacha, std_only::ZkvmOutput};
34+
3235
type GenericError = Box<dyn std::error::Error + Send + Sync>;
3336
type BoxBody = http_body_util::combinators::BoxBody<Bytes, GenericError>;
3437

@@ -198,40 +201,47 @@ async fn main() -> Result<()> {
198201
}
199202

200203
let mut request_method: String = Default::default();
201-
let maybe_wrapped_req =
202-
match inbound_handler(plaintext_req, &mut request_method, runner)
203-
.await
204-
{
205-
Ok(req) => req,
206-
Err(e) => return Ok(internal_error_response(e.to_string())),
207-
};
204+
let maybe_wrapped_req = match inbound_handler(
205+
plaintext_req,
206+
&mut request_method,
207+
runner.clone(),
208+
)
209+
.await
210+
{
211+
Ok(req) => req,
212+
Err(e) => return Ok(internal_error_response(e.to_string())),
213+
};
208214

209215
match maybe_wrapped_req {
210216
Some(wrapped_req) => {
211-
debug!(
212-
"Forwarding (maybe modified) `{}` --> DA: {:?}",
213-
request_method, wrapped_req
214-
);
217+
debug!("Forwarding (maybe modified) --> DA",);
218+
// debug!(
219+
// "Forwarding (maybe modified) `{}` --> DA: {:?}",
220+
// request_method, wrapped_req
221+
// );
215222
let returned = match celestia_client.request(wrapped_req).await
216223
{
217-
Ok(resp) => {
218-
outbound_handler(resp, request_method.to_owned())
219-
.await
220-
.unwrap_or_else(|e| {
221-
internal_error_response(format!(
222-
"Outbound Handler: {}",
223-
e
224-
))
225-
})
226-
}
224+
Ok(resp) => outbound_handler(
225+
resp,
226+
request_method.to_owned(),
227+
runner,
228+
)
229+
.await
230+
.unwrap_or_else(|e| {
231+
internal_error_response(format!(
232+
"Outbound Handler: {}",
233+
e
234+
))
235+
}),
227236
Err(e) => {
228237
internal_error_response(format!("DA Client: {}", e))
229238
}
230239
};
231-
debug!(
232-
"Responding (maybe modified) `{}` <-- DA: {:?}",
233-
request_method, returned
234-
);
240+
debug!("Responding (maybe modified) <-- DA",);
241+
// debug!(
242+
// "Responding (maybe modified) `{}` <-- DA: {:?}",
243+
// request_method, returned
244+
// );
235245
anyhow::Ok(returned)
236246
}
237247
None => Ok(pending_response()),
@@ -284,11 +294,10 @@ pub async fn inbound_handler(
284294
.get_mut(0)
285295
.ok_or_else(|| anyhow::anyhow!("Expected first 'params' entry"))?;
286296

287-
let blobs: Vec<Blob> = serde_json::from_value(blobs_value.clone())?;
297+
let blobs: Vec<Blob> = serde_json::from_value(std::mem::take(blobs_value))?;
288298

289299
let mut encrypted_blobs = Vec::with_capacity(blobs.len());
290300

291-
// TODO: consider only allowing one blob and one job on the queue
292301
for blob in blobs {
293302
let pda_runner = pda_runner.clone();
294303
let data = blob.data.clone();
@@ -304,10 +313,11 @@ pub async fn inbound_handler(
304313
if let Some(proof_with_values) =
305314
pda_runner.get_verifiable_encryption(job).await?
306315
{
307-
debug!(
308-
"Replacing blob.data with <Sp1ProofWithPublicValues>.proof = {:?}",
309-
proof_with_values.proof
310-
);
316+
debug!("Replacing blob.data with <Sp1ProofWithPublicValues>.proof");
317+
// debug!(
318+
// "Replacing blob.data with <Sp1ProofWithPublicValues>.proof = {:?}",
319+
// proof_with_values.proof
320+
// );
311321

312322
let encrypted_data = bincode::serialize(&proof_with_values)?;
313323
let encrypted_blob = Blob::new(
@@ -318,9 +328,12 @@ pub async fn inbound_handler(
318328

319329
encrypted_blobs.push(encrypted_blob);
320330
} else {
321-
return Ok(None);
331+
return Ok(None); // Bail out if any blob can't be encrypted
322332
}
323333
}
334+
335+
// Overwrite the original blob array in the params with encrypted blobs
336+
*blobs_value = serde_json::to_value(encrypted_blobs)?;
324337
}
325338
} else {
326339
debug!("Forwarding `blob.Submit` error: missing params");
@@ -355,48 +368,93 @@ pub async fn inbound_handler(
355368
async fn outbound_handler(
356369
resp: Response<IncomingBody>,
357370
request_method: String,
371+
pda_runner: Arc<PdaRunner>,
358372
) -> Result<Response<BoxBody>> {
359373
let (mut parts, body_stream) = resp.into_parts();
360374
let mut body_buf = body_stream.collect().await?.aggregate();
361375
let body_bytes = body_buf.copy_to_bytes(body_buf.remaining());
362376

363-
debug!("Raw upstream response: {:?}", body_bytes);
364-
365377
let status = parts.status;
366-
367378
if status == StatusCode::UNAUTHORIZED {
368379
return Ok(bad_auth_response());
369380
}
370381

371-
let body_json: serde_json::Value = serde_json::from_slice(&body_bytes)?;
372-
// (Optional) handle blob.Get/All for decryption etc.
373-
match request_method.as_str() {
374-
"blob.Get" => {
375-
if let Some(result_raw) = body_json.get("result") {
382+
let mut body_json: serde_json::Value = serde_json::from_slice(&body_bytes)?;
383+
384+
let try_mutate_response = async {
385+
let result_raw = body_json
386+
.get_mut("result")
387+
.ok_or_else(|| anyhow::anyhow!("Missing 'result' field"))?;
388+
389+
let key = <[u8; 32]>::from_hex(
390+
std::env::var("ENCRYPTION_KEY").expect("Missing ENCRYPTION_KEY env var"),
391+
)
392+
.expect("ENCRYPTION_KEY must be 32 bytes, hex encoded (ex: `1234...abcd`)");
393+
394+
match request_method.as_str() {
395+
"blob.Get" => {
376396
let blob: Blob = serde_json::from_value(result_raw.clone())?;
377-
debug!("{blob:?}");
397+
let plaintext_only_blob =
398+
verify_decrypt_blob(blob, key, pda_runner.clone()).await?;
399+
*result_raw = serde_json::to_value(plaintext_only_blob)?;
378400
}
379-
}
380-
"blob.GetAll" => {
381-
if let Some(result_raw) = body_json.get("result") {
382-
let blobs: Vec<Blob> = serde_json::from_value(result_raw.clone())?;
383-
for blob in blobs {
384-
debug!("{blob:?}");
401+
402+
"blob.GetAll" => {
403+
let original_array = result_raw
404+
.as_array()
405+
.ok_or_else(|| anyhow::anyhow!("Expected 'result' to be an array"))?
406+
.clone();
407+
408+
let mut plaintext_blobs = Vec::with_capacity(original_array.len());
409+
for blob_val in original_array {
410+
let blob: Blob = serde_json::from_value(blob_val)?;
411+
let plaintext_only_blob =
412+
verify_decrypt_blob(blob, key, pda_runner.clone()).await?;
413+
plaintext_blobs.push(serde_json::to_value(plaintext_only_blob)?);
385414
}
415+
416+
*result_raw = serde_json::Value::Array(plaintext_blobs);
386417
}
418+
419+
_ => {}
387420
}
388-
_ => {}
421+
422+
Ok::<_, anyhow::Error>(())
423+
}
424+
.await;
425+
426+
if let Err(err) = try_mutate_response {
427+
warn!("Failed to decrypt response: {:?}", err);
428+
let orig_body = Full::new(body_bytes)
429+
.map_err(|err: std::convert::Infallible| match err {})
430+
.boxed();
431+
return Ok(Response::from_parts(parts, orig_body));
389432
}
390433

391434
let json = serde_json::to_string(&body_json)?;
392435
parts.headers.remove("content-length");
436+
393437
let new_body = Full::new(Bytes::from(json))
394438
.map_err(|err: std::convert::Infallible| match err {})
395439
.boxed();
396440

397441
Ok(Response::from_parts(parts, new_body))
398442
}
399443

444+
async fn verify_decrypt_blob(
445+
blob: Blob,
446+
key: [u8; 32],
447+
pda_runner: Arc<PdaRunner>,
448+
) -> Result<Blob, anyhow::Error> {
449+
let proof: SP1ProofWithPublicValues = bincode::deserialize(&blob.data)?;
450+
let output = extract_verified_proof_output(&proof, pda_runner).await?;
451+
let mut buffer = output.ciphertext.to_owned();
452+
chacha(&key, &output.nonce, &mut buffer);
453+
let mut decrypted_plaintext_blob = blob.clone();
454+
decrypted_plaintext_blob.data = buffer.to_vec();
455+
Ok(decrypted_plaintext_blob)
456+
}
457+
400458
/// Job is in queue, we are waiting on it to finish
401459
fn pending_response() -> Response<BoxBody> {
402460
let raw_json = r#"{ "id": 1, "jsonrpc": "2.0", "status": "[pda-proxy] Verifiable encryption processing... Call back for result" }"#;
@@ -438,3 +496,18 @@ fn new_response_from(raw_json: &str, status: StatusCode) -> Response<BoxBody> {
438496
*response.status_mut() = status;
439497
response
440498
}
499+
500+
/// Verify a proof before returning it's attested output
501+
async fn extract_verified_proof_output<'a>(
502+
proof: &'a SP1ProofWithPublicValues,
503+
runner: Arc<PdaRunner>,
504+
) -> Result<ZkvmOutput<'a>> {
505+
let zk_client_local = runner.get_zk_client_local().await;
506+
let vk = &runner
507+
.get_proof_setup_local(&get_program_id().await, zk_client_local.clone())
508+
.await?
509+
.vk;
510+
zk_client_local.verify(proof, vk)?;
511+
512+
ZkvmOutput::from_bytes(proof.public_values.as_slice()).map_err(anyhow::Error::msg)
513+
}

zkVM/common/src/lib.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ pub const HASH_LEN: usize = 32;
1818
pub const NONCE_LEN: usize = 12;
1919
pub const HEADER_LEN: usize = HASH_LEN + NONCE_LEN + HASH_LEN;
2020

21-
/// Encrypt a buffer in-place using [ChaCha20](https://en.wikipedia.org/wiki/Salsa20#ChaCha_variant).
21+
/// Encrypt or Decrypt a buffer in-place using [ChaCha20](https://en.wikipedia.org/wiki/Salsa20#ChaCha_variant).
2222
///
2323
/// ## Important Notice
2424
///
@@ -32,7 +32,6 @@ pub fn chacha(key: &[u8; KEY_LEN], nonce: &[u8; NONCE_LEN], buffer: &mut [u8]) {
3232
cipher.apply_keystream(buffer);
3333
}
3434

35-
// Only compile this when the standard library is available
3635
#[cfg(feature = "std")]
3736
pub mod std_only {
3837
use super::*;
@@ -93,14 +92,14 @@ pub mod std_only {
9392
}
9493
}
9594

96-
// Helper to get a OsRng nonce of correct length
95+
/// Helper to get a OsRng nonce of correct length
9796
pub fn random_nonce() -> [u8; NONCE_LEN] {
9897
let mut nonce = [0u8; NONCE_LEN];
9998
OsRng.try_fill_bytes(&mut nonce).expect("Rng->buffer");
10099
nonce
101100
}
102101

103-
// Helper to format bytes as hex for pretty printing
102+
/// Helper to format bytes as hex for pretty printing
104103
pub fn bytes_to_hex(bytes: &[u8]) -> String {
105104
let digest_hex: String = bytes.iter().map(|b| format!("{:02x}", b)).collect();
106105
digest_hex

0 commit comments

Comments
 (0)