diff --git a/core/service/src/client/tests/threshold/common.rs b/core/service/src/client/tests/threshold/common.rs index 587ec607cd..e2a978b265 100644 --- a/core/service/src/client/tests/threshold/common.rs +++ b/core/service/src/client/tests/threshold/common.rs @@ -1,8 +1,10 @@ +use core::future::Future; + use crate::client::client_wasm::Client; use crate::client::test_tools::ServerHandle; use crate::conf::{Keychain, SecretSharingKeychain}; use crate::consts::{ - BACKUP_STORAGE_PREFIX_THRESHOLD_ALL, DEFAULT_EPOCH_ID, DEFAULT_MPC_CONTEXT, + BACKUP_STORAGE_PREFIX_THRESHOLD_ALL, DEFAULT_EPOCH_ID, DEFAULT_MPC_CONTEXT, MAX_TRIES, PRIVATE_STORAGE_PREFIX_THRESHOLD_ALL, PUBLIC_STORAGE_PREFIX_THRESHOLD_ALL, SIGNING_KEY_ID, }; use crate::engine::base::derive_request_id; @@ -18,14 +20,20 @@ use crate::util::rate_limiter::RateLimiterConfig; use crate::vault::Vault; use crate::vault::storage::delete_at_request_id; use crate::vault::storage::{StorageType, file::FileStorage}; +use kms_grpc::RequestId; use kms_grpc::kms_service::v1::core_service_endpoint_client::CoreServiceEndpointClient; use kms_grpc::rpc_types::PrivDataType; use std::collections::HashMap; use std::path::Path; +use std::pin::Pin; use tfhe::core_crypto::commons::utils::ZipChecked; use threshold_execution::endpoints::decryption::DecryptionMode; use threshold_execution::tfhe_internals::parameters::DKGParams; use tonic::transport::Channel; +use tonic::{Request, Response, Status}; + +/// RequestIds as they are represented in the current version of the ProtoBuf API. +type ProtoRequestId = kms_grpc::kms::v1::RequestId; #[allow(clippy::too_many_arguments)] async fn threshold_handles_w_vaults( @@ -444,3 +452,36 @@ pub async fn threshold_key_gen_secure( Ok(responses) } + +/// Helper to retry a single poll call until it succeeds or we exhaust [`crate::consts::MAX_TRIES`]. +pub async fn poll_with_retries( + mut client: CoreServiceEndpointClient, + server_id: u32, + req_id: ProtoRequestId, + poll_fn: impl for<'a> Fn( + &'a mut CoreServiceEndpointClient, + Request, + ) + -> Pin, Status>> + Send + 'a>>, +) -> (u32, ProtoRequestId, R) { + for count in 0..MAX_TRIES { + // By default our gRPC calls do not time out. Here we're giving it 2sec per poll attempt to reply. + tokio::select! { + result = poll_fn(&mut client, Request::new(req_id.clone())) => { + match result { + Ok(resp) => return (server_id, req_id, resp.into_inner()), + Err(e) => { + let id_str = RequestId::try_from(req_id.clone()).unwrap().to_string(); + tracing::trace!("Attempt {count} for server {server_id}, req {id_str}: {e:?}"); + } + } + } + _ = tokio::time::sleep(tokio::time::Duration::from_secs(2)) => { + tracing::trace!("Attempt {count} for server {server_id} timed out"); + } + } + // Back-off a little bit before re-trying + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + } + panic!("no response for server {server_id} after {MAX_TRIES} tries"); +} diff --git a/core/service/src/client/tests/threshold/crs_gen_tests.rs b/core/service/src/client/tests/threshold/crs_gen_tests.rs index 01bb76e5b0..1c03dfe642 100644 --- a/core/service/src/client/tests/threshold/crs_gen_tests.rs +++ b/core/service/src/client/tests/threshold/crs_gen_tests.rs @@ -1,6 +1,16 @@ cfg_if::cfg_if! { if #[cfg(any(feature = "slow_tests", feature = "insecure"))] { + use std::collections::HashMap; + + use futures_util::future::join_all; + use itertools::Itertools; + + + use kms_grpc::rpc_types::protobuf_to_alloy_domain; + + use crate::client::client_wasm::Client; + use crate::client::tests::{threshold::common::{poll_with_retries}}; use crate::cryptography::internal_crypto_types::WrappedDKGParams; use crate::dummy_domain; use crate::engine::base::derive_request_id; @@ -11,7 +21,6 @@ cfg_if::cfg_if! { use kms_grpc::kms::v1::CrsInfo; use kms_grpc::kms_service::v1::core_service_endpoint_client::CoreServiceEndpointClient; use kms_grpc::RequestId; - use std::collections::HashMap; use std::path::Path; use threshold_execution::tfhe_internals::parameters::DKGParams; use tokio::task::JoinSet; @@ -265,12 +274,12 @@ pub async fn run_crs( .crs_gen_request(crs_req_id, None, None, max_bits, Some(parameter), &domain) .unwrap(); - let responses = launch_crs(&vec![crs_req.clone()], kms_clients, insecure).await; + let responses = launch_crs(&crs_req, kms_clients, insecure).await; for response in responses { response.unwrap(); } wait_for_crsgen_result( - &vec![crs_req], + &[crs_req], kms_clients, internal_client, &dkg_param, @@ -281,63 +290,71 @@ pub async fn run_crs( #[cfg(any(feature = "slow_tests", feature = "insecure"))] async fn launch_crs( - reqs: &Vec, + req: &CrsGenRequest, kms_clients: &HashMap>, insecure: bool, ) -> Vec, tonic::Status>> { let amount_parties = kms_clients.len(); let mut tasks_gen = JoinSet::new(); - for req in reqs { - for i in 1..=amount_parties as u32 { - let mut cur_client = kms_clients.get(&i).unwrap().clone(); - let req_clone = req.clone(); - tasks_gen.spawn(async move { - if insecure { - #[cfg(feature = "insecure")] - { - cur_client - .insecure_crs_gen(tonic::Request::new(req_clone)) - .await - } - #[cfg(not(feature = "insecure"))] - { - panic!("cannot perform insecure crs gen") - } - } else { - cur_client.crs_gen(tonic::Request::new(req_clone)).await + for i in 1..=amount_parties as u32 { + let mut cur_client = kms_clients.get(&i).unwrap().clone(); + let req_clone = req.clone(); + tasks_gen.spawn(async move { + if insecure { + #[cfg(feature = "insecure")] + { + cur_client + .insecure_crs_gen(tonic::Request::new(req_clone)) + .await } - }); - } + #[cfg(not(feature = "insecure"))] + { + panic!("Asked for insecure crs gen but feature 'insecure' is not active.") + } + } else { + cur_client.crs_gen(tonic::Request::new(req_clone)).await + } + }); } let mut responses_gen = Vec::new(); while let Some(inner) = tasks_gen.join_next().await { let resp = inner.unwrap(); responses_gen.push(resp); } - assert_eq!(responses_gen.len(), amount_parties * reqs.len()); + assert_eq!(responses_gen.len(), amount_parties); responses_gen } #[cfg(any(feature = "slow_tests", feature = "insecure"))] pub async fn wait_for_crsgen_result( - reqs: &Vec, + reqs: &[CrsGenRequest], kms_clients: &HashMap>, internal_client: &Client, param: &DKGParams, test_path: Option<&Path>, ) -> Vec { - let amount_parties = kms_clients.len(); - // wait a bit for the crs generation to finish - let joined_responses = - crate::par_poll_responses!(kms_clients, reqs, get_crs_gen_result, amount_parties); + let amount_parties: usize = kms_clients.len(); + + // Poll each (client, request) pair independently until all succeed. + let mut futs = Vec::new(); + for req in reqs { + let req_id = req.request_id.clone().unwrap(); + for (server_id, client) in kms_clients.iter() { + let client = client.clone(); + futs.push(poll_with_retries( + client, + *server_id, + req_id.clone(), + |c, req| Box::pin(c.get_crs_gen_result(req)), + )) + } + } + let joined_responses = join_all(futs).await; let mut results = Vec::new(); // first check the happy path // the public parameter is checked in ddec tests, so we don't specifically check _pp for req in reqs { - use itertools::Itertools; - use kms_grpc::rpc_types::protobuf_to_alloy_domain; - let req_id: RequestId = req.clone().request_id.unwrap().try_into().unwrap(); let joined_responses: Vec<_> = joined_responses .iter() @@ -355,7 +372,7 @@ pub async fn wait_for_crsgen_result( // we need to setup the storage devices in the right order // so that the client can read the CRS - tracing::info!( + tracing::debug!( "Got {} responses for CRS gen request id {}", joined_responses.len(), req_id @@ -538,72 +555,3 @@ fn set_digests( crs_gen_result.crs_digest = digest.to_vec(); } } - -// Poll the client method function `f_to_poll` until there is a result -// or error out until some timeout. -// The requests from the `reqs` argument need to implement `RequestIdGetter`. -#[macro_export] -macro_rules! par_poll_responses { - ($kms_clients:expr,$reqs:expr,$f_to_poll:ident,$amount_parties:expr) => {{ - use $crate::consts::MAX_TRIES; - let mut joined_responses = vec![]; - for count in 0..MAX_TRIES { - // Reset the list every time since we get all old results as well - joined_responses = vec![]; - tokio::time::sleep(tokio::time::Duration::from_secs(30 * $reqs.len() as u64)).await; - - let mut tasks_get = JoinSet::new(); - for req in $reqs { - for i in 1..=$amount_parties as u32 { - // Make sure we only consider clients for which - // we haven't killed the corresponding server - if let Some(cur_client) = $kms_clients.get(&i) { - let mut cur_client = cur_client.clone(); - let req_id_proto = req.request_id.clone().unwrap(); - tasks_get.spawn(async move { - ( - i, - req_id_proto.clone(), - cur_client - .$f_to_poll(tonic::Request::new(req_id_proto)) - .await, - ) - }); - } - } - } - - while let Some(res) = tasks_get.join_next().await { - match res { - Ok(inner) => { - // Validate if the result returned is ok, if not we ignore, since it likely means that the process is still running on the server - if let (j, req_id, Ok(resp)) = inner { - joined_responses.push((j, req_id, resp.into_inner())); - } else { - let (j, req_id, inner_resp) = inner; - // Explicitly convert to string to avoid any type conversion issues - let req_id_str = match kms_grpc::RequestId::try_from(req_id.clone()).unwrap() { - id => id.to_string(), - }; - tracing::info!("Response in iteration {count} for server {j} and req_id {req_id_str} is: {:?}", inner_resp); - } - } - _ => { - panic!("Something went wrong while polling for responses"); - } - } - } - - if joined_responses.len() >= $kms_clients.len() * $reqs.len() { - break; - } - - // fail if we can't find a response - if count >= MAX_TRIES - 1 { - panic!("could not get response after {} tries", count); - } - } - - joined_responses - }}; -}