Skip to content
Open
6 changes: 6 additions & 0 deletions ooniauth-py/src/exceptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ pub enum OoniErr {

#[error("Deserialization Error: {reason}")]
DeserializationFailed { reason: String },

#[error("Submit emission date {emission_date} cannot build the request window")]
SubmitDateOutOfRange { emission_date: u32 },
}

pub type OoniResult<T> = Result<T, OoniErr>;
Expand All @@ -62,6 +65,9 @@ impl From<OoniErr> for PyErr {
reason: errors::CredentialError::CMZError(e),
} => ProtocolError::new_err(format!("{e}")),
OoniErr::CredentialError { reason } => CredentialError::new_err(format!("{reason}")),
OoniErr::SubmitDateOutOfRange { emission_date } => ProtocolError::new_err(format!(
"submit emission date {emission_date} cannot build the request window"
)),
}
}
}
Expand Down
168 changes: 145 additions & 23 deletions ooniauth-py/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,40 @@ use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
use crate::utils::{from_pystring, to_pystring};
use crate::{exceptions::OoniResult, OoniErr};

fn py_string_arg<'py>(
py: Python<'py>,
value: &'py Py<PyString>,
name: &str,
) -> OoniResult<&'py str> {
value
.to_str(py)
.map_err(|e| OoniErr::DeserializationFailed {
reason: format!("invalid {name}: {e}"),
})
}

fn py_u32_range(
py: Python<'_>,
value: &Py<PyList>,
name: &str,
) -> OoniResult<std::ops::Range<u32>> {
let range = value
.extract::<Vec<u32>>(py)
.map_err(|e| OoniErr::DeserializationFailed {
reason: format!("invalid {name}: {e}"),
})?;
let [start, end]: [u32; 2] =
range
.try_into()
.map_err(|range: Vec<u32>| OoniErr::DeserializationFailed {
reason: format!(
"{name} must contain exactly 2 integers, got {}",
range.len()
),
})?;
Ok(start..end)
}

#[gen_stub_pyclass]
#[pyclass]
pub struct ServerState {
Expand Down Expand Up @@ -84,41 +118,33 @@ impl ServerState {
measurement_count_range: Py<PyList>,
) -> OoniResult<Py<PyString>> {
// Convert arguments from py types to rust types
let nym = nym.to_str(py).map_err(|e| OoniErr::DeserializationFailed {
reason: e.to_string(),
})?;

let nym = BASE64_STANDARD
.decode(nym)
let nym: [u8; 32] = BASE64_STANDARD
.decode(py_string_arg(py, &nym, "nym")?)
.map_err(|e| OoniErr::DeserializationFailed {
reason: e.to_string(),
})?
.try_into()
.map_err(|nym: Vec<u8>| OoniErr::DeserializationFailed {
reason: format!("nym must decode to 32 bytes, got {}", nym.len()),
})?;

let mut nym_32: [u8; 32] = [0; 32];
nym_32.copy_from_slice(nym.as_ref());

let request = from_pystring::<ooniauth_core::submit::SubmitRequest>(py, &request)?;

let probe_cc = probe_cc.to_str(py).expect("Could not get str");
let probe_asn = probe_asn.to_str(py).expect("Could not get str");

let age_range = age_range
.extract::<Vec<u32>>(py)
.expect("could not get list");
let measurement_count_range = measurement_count_range
.extract::<Vec<u32>>(py)
.expect("could not get list");
let probe_cc = py_string_arg(py, &probe_cc, "probe_cc")?;
Comment thread
mmaker marked this conversation as resolved.
let probe_asn = py_string_arg(py, &probe_asn, "probe_asn")?;
let age_range = py_u32_range(py, &age_range, "age_range")?;
let measurement_count_range =
py_u32_range(py, &measurement_count_range, "measurement_count_range")?;

// Handle submission
let mut rng = rand::thread_rng();
let result = self.state.handle_submit(
&mut rng,
request,
&nym_32,
&nym,
probe_cc,
probe_asn,
age_range[0]..age_range[1],
measurement_count_range[0]..measurement_count_range[1],
age_range,
measurement_count_range,
)?;

Ok(to_pystring(py, &result))
Expand Down Expand Up @@ -226,13 +252,19 @@ impl UserState {
) -> OoniResult<SubmitRequest> {
let probe_cc = probe_cc.to_str(py).expect("unable to get string");
let probe_asn = probe_asn.to_str(py).expect("unable to get string");
let age_start = emission_date
.checked_sub(30)
.ok_or(OoniErr::SubmitDateOutOfRange { emission_date })?;
let age_end = emission_date
.checked_add(1)
.ok_or(OoniErr::SubmitDateOutOfRange { emission_date })?;

let mut rng = rand::thread_rng();
let ((result, client_state), nym) = self.state.submit_request(
&mut rng,
probe_cc.into(),
probe_asn.into(),
(emission_date - 30)..(emission_date + 1),
age_start..age_end,
0..100,
)?;

Expand Down Expand Up @@ -306,6 +338,7 @@ pub struct SubmitRequest {

#[cfg(test)]
mod tests {
use crate::OoniErr;
use base64::{prelude::BASE64_STANDARD, Engine};
use ooniauth_core::{registration::open_registration::Request, ServerState, UserState};
use pyo3::{
Expand All @@ -314,6 +347,15 @@ mod tests {
};
use rand::{rngs::ThreadRng, thread_rng};

fn registered_client(py: Python<'_>) -> crate::UserState {
let server = crate::ServerState::new();
let mut client = crate::UserState::new(py, server.get_public_parameters(py)).unwrap();
let req = client.make_registration_request(py).unwrap();
let resp = server.handle_registration_request(py, req).unwrap();
client.handle_registration_response(py, resp).unwrap();
client
}

#[test]
fn test_encoding_verifies() {
// Check that the string encoding still let us verify
Expand Down Expand Up @@ -370,6 +412,24 @@ mod tests {
});
}

#[test]
fn test_make_submit_request_rejects_date_overflow() {
pyo3::Python::initialize();
Python::attach(|py| {
let mut client = registered_client(py);
let cc: Py<PyString> = PyString::new(py, "VE").into();
let asn: Py<PyString> = PyString::new(py, "AS1234").into();
assert!(matches!(
client.make_submit_request(py, cc.clone_ref(py), asn.clone_ref(py), 0),
Err(crate::OoniErr::SubmitDateOutOfRange { .. })
));
assert!(matches!(
client.make_submit_request(py, cc, asn, u32::MAX),
Err(crate::OoniErr::SubmitDateOutOfRange { .. })
));
});
}

#[test]
fn test_credential_update_simple() {
pyo3::Python::initialize();
Expand Down Expand Up @@ -514,6 +574,68 @@ mod tests {
});
}

fn submit_fixture(
Comment thread
mmaker marked this conversation as resolved.
py: Python<'_>,
) -> (
crate::ServerState,
crate::SubmitRequest,
Py<PyString>,
Py<PyString>,
Py<PyList>,
Py<PyList>,
) {
let server = crate::ServerState::new();
let mut client = crate::UserState::new(py, server.get_public_parameters(py)).unwrap();
let req = client.make_registration_request(py).unwrap();
let resp = server.handle_registration_request(py, req).unwrap();
client.handle_registration_response(py, resp).unwrap();
let cc: Py<PyString> = PyString::new(py, "VE").into();
let asn: Py<PyString> = PyString::new(py, "AS1234").into();
let today = crate::ServerState::today();
let submit = client
.make_submit_request(py, cc.clone_ref(py), asn.clone_ref(py), today)
.unwrap();
let age_range: Py<PyList> = PyList::new(py, vec![today - 30, today + 1]).unwrap().into();
let count_range: Py<PyList> = PyList::new(py, vec![0, 100]).unwrap().into();
(server, submit, cc, asn, age_range, count_range)
}

#[test]
fn test_handle_submit_request_rejects_short_nym() {
pyo3::Python::initialize();
Python::attach(|py| {
let (server, submit, cc, asn, age_range, count_range) = submit_fixture(py);
let bad_nym = PyString::new(py, &BASE64_STANDARD.encode([7u8; 31])).into();
let err = server
.handle_submit_request(py, bad_nym, submit.request, cc, asn, age_range, count_range)
.unwrap_err();
assert!(matches!(err, OoniErr::DeserializationFailed { .. }));
});
}

#[test]
fn test_handle_submit_request_rejects_short_age_range() {
pyo3::Python::initialize();
Python::attach(|py| {
let (server, submit, cc, asn, _age_range, count_range) = submit_fixture(py);
let short_range = PyList::new(py, vec![crate::ServerState::today()])
.unwrap()
.into();
let err = server
.handle_submit_request(
py,
submit.nym,
submit.request,
cc,
asn,
short_range,
count_range,
)
.unwrap_err();
assert!(matches!(err, OoniErr::DeserializationFailed { .. }));
});
}

fn setup() -> (ThreadRng, UserState, ServerState) {
let mut rng = thread_rng();
let server = ServerState::new(&mut rng);
Expand Down
Loading