Skip to content
43 changes: 42 additions & 1 deletion core/service/src/client/tests/threshold/common.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use core::future::Future;

use crate::client::client_wasm::Client;
use crate::client::test_tools::ServerHandle;
use crate::conf::{Keychain, SecretSharingKeychain};
use crate::consts::{
BACKUP_STORAGE_PREFIX_THRESHOLD_ALL, DEFAULT_EPOCH_ID, DEFAULT_MPC_CONTEXT,
BACKUP_STORAGE_PREFIX_THRESHOLD_ALL, DEFAULT_EPOCH_ID, DEFAULT_MPC_CONTEXT, MAX_TRIES,
PRIVATE_STORAGE_PREFIX_THRESHOLD_ALL, PUBLIC_STORAGE_PREFIX_THRESHOLD_ALL, SIGNING_KEY_ID,
};
use crate::engine::base::derive_request_id;
Expand All @@ -18,14 +20,20 @@ use crate::util::rate_limiter::RateLimiterConfig;
use crate::vault::Vault;
use crate::vault::storage::delete_at_request_id;
use crate::vault::storage::{StorageType, file::FileStorage};
use kms_grpc::RequestId;
use kms_grpc::kms_service::v1::core_service_endpoint_client::CoreServiceEndpointClient;
use kms_grpc::rpc_types::PrivDataType;
use std::collections::HashMap;
use std::path::Path;
use std::pin::Pin;
use tfhe::core_crypto::commons::utils::ZipChecked;
use threshold_execution::endpoints::decryption::DecryptionMode;
use threshold_execution::tfhe_internals::parameters::DKGParams;
use tonic::transport::Channel;
use tonic::{Request, Response, Status};

/// RequestIds as they are represented in the current version of the ProtoBuf API.
type ProtoRequestId = kms_grpc::kms::v1::RequestId;

#[allow(clippy::too_many_arguments)]
async fn threshold_handles_w_vaults(
Expand Down Expand Up @@ -444,3 +452,36 @@ pub async fn threshold_key_gen_secure(

Ok(responses)
}

/// Helper to retry a single poll call until it succeeds or we exhaust [`crate::consts::MAX_TRIES`].
pub async fn poll_with_retries<R: Send>(
mut client: CoreServiceEndpointClient<Channel>,
server_id: u32,
req_id: ProtoRequestId,
poll_fn: impl for<'a> Fn(
&'a mut CoreServiceEndpointClient<Channel>,
Request<ProtoRequestId>,
)
-> Pin<Box<dyn Future<Output = Result<Response<R>, Status>> + Send + 'a>>,
) -> (u32, ProtoRequestId, R) {
Comment thread
dvdplm marked this conversation as resolved.
for count in 0..MAX_TRIES {
// By default our gRPC calls do not time out. Here we're giving it 2sec per poll attempt to reply.
tokio::select! {
result = poll_fn(&mut client, Request::new(req_id.clone())) => {
match result {
Ok(resp) => return (server_id, req_id, resp.into_inner()),
Err(e) => {
let id_str = RequestId::try_from(req_id.clone()).unwrap().to_string();
tracing::trace!("Attempt {count} for server {server_id}, req {id_str}: {e:?}");
}
}
}
_ = tokio::time::sleep(tokio::time::Duration::from_secs(2)) => {
Comment thread
dvdplm marked this conversation as resolved.
tracing::trace!("Attempt {count} for server {server_id} timed out");
}
}
// Back-off a little bit before re-trying
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
}
panic!("no response for server {server_id} after {MAX_TRIES} tries");
}
154 changes: 51 additions & 103 deletions core/service/src/client/tests/threshold/crs_gen_tests.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
cfg_if::cfg_if! {
if #[cfg(any(feature = "slow_tests", feature = "insecure"))] {
use std::collections::HashMap;

use futures_util::future::join_all;
use itertools::Itertools;


use kms_grpc::rpc_types::protobuf_to_alloy_domain;


use crate::client::client_wasm::Client;
use crate::client::tests::{threshold::common::{poll_with_retries}};
use crate::cryptography::internal_crypto_types::WrappedDKGParams;
use crate::dummy_domain;
use crate::engine::base::derive_request_id;
Expand All @@ -11,7 +21,6 @@ cfg_if::cfg_if! {
use kms_grpc::kms::v1::CrsInfo;
use kms_grpc::kms_service::v1::core_service_endpoint_client::CoreServiceEndpointClient;
use kms_grpc::RequestId;
use std::collections::HashMap;
use std::path::Path;
use threshold_execution::tfhe_internals::parameters::DKGParams;
use tokio::task::JoinSet;
Expand Down Expand Up @@ -265,12 +274,12 @@ pub async fn run_crs(
.crs_gen_request(crs_req_id, None, None, max_bits, Some(parameter), &domain)
.unwrap();

let responses = launch_crs(&vec![crs_req.clone()], kms_clients, insecure).await;
let responses = launch_crs(&crs_req, kms_clients, insecure).await;
for response in responses {
response.unwrap();
}
wait_for_crsgen_result(
&vec![crs_req],
&[crs_req],
kms_clients,
internal_client,
&dkg_param,
Expand All @@ -281,63 +290,71 @@ pub async fn run_crs(

#[cfg(any(feature = "slow_tests", feature = "insecure"))]
async fn launch_crs(
reqs: &Vec<CrsGenRequest>,
req: &CrsGenRequest,
kms_clients: &HashMap<u32, CoreServiceEndpointClient<Channel>>,
insecure: bool,
) -> Vec<Result<tonic::Response<Empty>, tonic::Status>> {
let amount_parties = kms_clients.len();
let mut tasks_gen = JoinSet::new();
for req in reqs {
for i in 1..=amount_parties as u32 {
let mut cur_client = kms_clients.get(&i).unwrap().clone();
let req_clone = req.clone();
tasks_gen.spawn(async move {
if insecure {
#[cfg(feature = "insecure")]
{
cur_client
.insecure_crs_gen(tonic::Request::new(req_clone))
.await
}
#[cfg(not(feature = "insecure"))]
{
panic!("cannot perform insecure crs gen")
}
} else {
cur_client.crs_gen(tonic::Request::new(req_clone)).await
for i in 1..=amount_parties as u32 {
let mut cur_client = kms_clients.get(&i).unwrap().clone();
let req_clone = req.clone();
tasks_gen.spawn(async move {
if insecure {
#[cfg(feature = "insecure")]
{
cur_client
.insecure_crs_gen(tonic::Request::new(req_clone))
.await
}
});
}
#[cfg(not(feature = "insecure"))]
{
panic!("Asked for insecure crs gen but feature 'insecure' is not active.")
}
} else {
cur_client.crs_gen(tonic::Request::new(req_clone)).await
}
});
}
let mut responses_gen = Vec::new();
while let Some(inner) = tasks_gen.join_next().await {
let resp = inner.unwrap();
responses_gen.push(resp);
}
assert_eq!(responses_gen.len(), amount_parties * reqs.len());
assert_eq!(responses_gen.len(), amount_parties);
responses_gen
}

#[cfg(any(feature = "slow_tests", feature = "insecure"))]
pub async fn wait_for_crsgen_result(
reqs: &Vec<CrsGenRequest>,
reqs: &[CrsGenRequest],
kms_clients: &HashMap<u32, CoreServiceEndpointClient<Channel>>,
internal_client: &Client,
param: &DKGParams,
test_path: Option<&Path>,
) -> Vec<CrsInfo> {
let amount_parties = kms_clients.len();
// wait a bit for the crs generation to finish
let joined_responses =
crate::par_poll_responses!(kms_clients, reqs, get_crs_gen_result, amount_parties);
let amount_parties: usize = kms_clients.len();

// Poll each (client, request) pair independently until all succeed.
let mut futs = Vec::new();
for req in reqs {
let req_id = req.request_id.clone().unwrap();
for (server_id, client) in kms_clients.iter() {
let client = client.clone();
futs.push(poll_with_retries(
client,
*server_id,
req_id.clone(),
|c, req| Box::pin(c.get_crs_gen_result(req)),
))
}
}
let joined_responses = join_all(futs).await;

let mut results = Vec::new();
// first check the happy path
// the public parameter is checked in ddec tests, so we don't specifically check _pp
for req in reqs {
use itertools::Itertools;
use kms_grpc::rpc_types::protobuf_to_alloy_domain;

let req_id: RequestId = req.clone().request_id.unwrap().try_into().unwrap();
let joined_responses: Vec<_> = joined_responses
.iter()
Expand All @@ -355,7 +372,7 @@ pub async fn wait_for_crsgen_result(

// we need to setup the storage devices in the right order
// so that the client can read the CRS
tracing::info!(
tracing::debug!(
"Got {} responses for CRS gen request id {}",
joined_responses.len(),
req_id
Expand Down Expand Up @@ -538,72 +555,3 @@ fn set_digests(
crs_gen_result.crs_digest = digest.to_vec();
}
}

// Poll the client method function `f_to_poll` until there is a result
// or error out until some timeout.
// The requests from the `reqs` argument need to implement `RequestIdGetter`.
#[macro_export]
macro_rules! par_poll_responses {
($kms_clients:expr,$reqs:expr,$f_to_poll:ident,$amount_parties:expr) => {{
use $crate::consts::MAX_TRIES;
let mut joined_responses = vec![];
for count in 0..MAX_TRIES {
// Reset the list every time since we get all old results as well
joined_responses = vec![];
tokio::time::sleep(tokio::time::Duration::from_secs(30 * $reqs.len() as u64)).await;

let mut tasks_get = JoinSet::new();
for req in $reqs {
for i in 1..=$amount_parties as u32 {
// Make sure we only consider clients for which
// we haven't killed the corresponding server
if let Some(cur_client) = $kms_clients.get(&i) {
let mut cur_client = cur_client.clone();
let req_id_proto = req.request_id.clone().unwrap();
tasks_get.spawn(async move {
(
i,
req_id_proto.clone(),
cur_client
.$f_to_poll(tonic::Request::new(req_id_proto))
.await,
)
});
}
}
}

while let Some(res) = tasks_get.join_next().await {
match res {
Ok(inner) => {
// Validate if the result returned is ok, if not we ignore, since it likely means that the process is still running on the server
if let (j, req_id, Ok(resp)) = inner {
joined_responses.push((j, req_id, resp.into_inner()));
} else {
let (j, req_id, inner_resp) = inner;
// Explicitly convert to string to avoid any type conversion issues
let req_id_str = match kms_grpc::RequestId::try_from(req_id.clone()).unwrap() {
id => id.to_string(),
};
tracing::info!("Response in iteration {count} for server {j} and req_id {req_id_str} is: {:?}", inner_resp);
}
}
_ => {
panic!("Something went wrong while polling for responses");
}
}
}

if joined_responses.len() >= $kms_clients.len() * $reqs.len() {
break;
}

// fail if we can't find a response
if count >= MAX_TRIES - 1 {
panic!("could not get response after {} tries", count);
}
}

joined_responses
}};
}
Loading