diff --git a/Cargo.lock b/Cargo.lock index 394c907..d9e68f0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1320,7 +1320,7 @@ dependencies = [ [[package]] name = "ooniauth_py" -version = "0.2.0" +version = "0.2.1" dependencies = [ "base64", "bincode", diff --git a/ooniauth-py/Cargo.toml b/ooniauth-py/Cargo.toml index a392ff8..476c5d8 100644 --- a/ooniauth-py/Cargo.toml +++ b/ooniauth-py/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ooniauth_py" -version = "0.2.0" +version = "0.2.1" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/ooniauth-py/ooniauth_py.pyi b/ooniauth-py/ooniauth_py.pyi index 33950ac..2e56892 100644 --- a/ooniauth-py/ooniauth_py.pyi +++ b/ooniauth-py/ooniauth_py.pyi @@ -48,8 +48,8 @@ class ServerState: request: str, probe_cc: str, probe_asn: str, - age_range: list, - measurement_count_range: list, + age_range: tuple[builtins.int, builtins.int], + min_measurement_count: builtins.int, ) -> str: ... def handle_update_request( self, req: str, old_public_params: str, old_secret_key: str @@ -79,7 +79,7 @@ class UserState: probe_cc: str, probe_asn: str, age_range: tuple[builtins.int, builtins.int], - measurement_count_range: tuple[builtins.int, builtins.int], + min_measurement_count: builtins.int, ) -> SubmitRequest: ... def handle_submit_response(self, response: str) -> None: r""" diff --git a/ooniauth-py/src/protocol.rs b/ooniauth-py/src/protocol.rs index 29a50bf..c499f34 100644 --- a/ooniauth-py/src/protocol.rs +++ b/ooniauth-py/src/protocol.rs @@ -6,13 +6,25 @@ use ooniauth_core::{self as ooni, PublicParameters, SecretKey}; use pyo3::{ prelude::*, - types::{PyList, PyString}, + types::PyString, }; use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyfunction, gen_stub_pymethods}; use crate::utils::{from_pystring, to_pystring}; use crate::{exceptions::OoniResult, OoniErr}; +fn py_string_arg<'py>( + py: Python<'py>, + value: &'py Py, + name: &str, +) -> OoniResult<&'py str> { + value + .to_str(py) + .map_err(|e| OoniErr::DeserializationFailed { + reason: format!("invalid {name}: {e}"), + }) +} + /// Returns the version of the `ooniauth-core`, the actual protocol implementation. #[gen_stub_pyfunction(module = "ooniauth-py")] #[pyfunction] @@ -87,45 +99,34 @@ impl ServerState { request: Py, probe_cc: Py, probe_asn: Py, - age_range: Py, - measurement_count_range: Py, + age_range: (u32, u32), + min_measurement_count: u32, ) -> OoniResult> { // Convert arguments from py types to rust types - let nym = nym.to_str(py).map_err(|e| OoniErr::DeserializationFailed { - reason: e.to_string(), - })?; - - let nym = BASE64_STANDARD - .decode(nym) + let nym: [u8; 32] = BASE64_STANDARD + .decode(py_string_arg(py, &nym, "nym")?) .map_err(|e| OoniErr::DeserializationFailed { reason: e.to_string(), + })? + .try_into() + .map_err(|nym: Vec| OoniErr::DeserializationFailed { + reason: format!("nym must decode to 32 bytes, got {}", nym.len()), })?; - let mut nym_32: [u8; 32] = [0; 32]; - nym_32.copy_from_slice(nym.as_ref()); - let request = from_pystring::(py, &request)?; - - let probe_cc = probe_cc.to_str(py).expect("Could not get str"); - let probe_asn = probe_asn.to_str(py).expect("Could not get str"); - - let age_range = age_range - .extract::>(py) - .expect("could not get list"); - let measurement_count_range = measurement_count_range - .extract::>(py) - .expect("could not get list"); + let probe_cc = py_string_arg(py, &probe_cc, "probe_cc")?; + let probe_asn = py_string_arg(py, &probe_asn, "probe_asn")?; // Handle submission let mut rng = rand::thread_rng(); let result = self.state.handle_submit( &mut rng, request, - &nym_32, + &nym, probe_cc, probe_asn, - age_range[0]..age_range[1], - measurement_count_range[0]..measurement_count_range[1], + age_range.0..age_range.1, + min_measurement_count..u32::MAX, )?; Ok(to_pystring(py, &result)) @@ -230,7 +231,7 @@ impl UserState { probe_cc: Py, probe_asn: Py, age_range: (u32, u32), - measurement_count_range: (u32, u32), + min_measurement_count: u32, ) -> OoniResult { let probe_cc = probe_cc.to_str(py).expect("unable to get string"); let probe_asn = probe_asn.to_str(py).expect("unable to get string"); @@ -241,7 +242,7 @@ impl UserState { probe_cc.into(), probe_asn.into(), age_range.0..age_range.1, - measurement_count_range.0..measurement_count_range.1, + min_measurement_count..u32::MAX, )?; self.submit_client_state = Some(client_state); @@ -314,12 +315,10 @@ pub struct SubmitRequest { #[cfg(test)] mod tests { + use crate::OoniErr; use base64::{prelude::BASE64_STANDARD, Engine}; use ooniauth_core::{registration::open_registration::Request, ServerState, UserState}; - use pyo3::{ - types::{PyList, PyString}, - Py, Python, - }; + use pyo3::{types::PyString, Py, Python}; use rand::{rngs::ThreadRng, thread_rng}; #[test] @@ -358,15 +357,15 @@ mod tests { let cc = PyString::new(py, "VE"); let asn = PyString::new(py, "AS1234"); let today = ServerState::today(); - let age_range = PyList::new(py, vec![today - 30, today + 1]).unwrap(); - let msm_range = PyList::new(py, vec![0, 100]).unwrap(); + let age_tuple = (today - 30, today + 1); + let min_msm = 0u32; let submit_req = client .make_submit_request( py, cc.clone().into(), asn.clone().into(), - (today - 30, today + 1), - (0, 100), + age_tuple, + min_msm, ) .unwrap(); @@ -377,8 +376,8 @@ mod tests { submit_req.request, cc.into(), asn.into(), - age_range.into(), - msm_range.into() + age_tuple, + min_msm, ) .is_ok()); }); @@ -453,17 +452,16 @@ mod tests { let probe_cc: Py = PyString::new(py, "VE").into(); let probe_asn: Py = PyString::new(py, "AS8048").into(); let today = ServerState::today(); - let age_range: Py = - PyList::new(py, vec![today - 30, today + 1]).unwrap().into(); - let count_range: Py = PyList::new(py, vec![0, 100]).unwrap().into(); + let age_tuple = (today - 30, today + 1); + let min_msm = 0u32; let submit = client .make_submit_request( py, probe_cc.clone_ref(py), probe_asn.clone_ref(py), - (today - 30, today + 1), - (0, 100), + age_tuple, + min_msm, ) .expect("Unable to make submit request"); @@ -474,8 +472,8 @@ mod tests { submit.request, probe_cc.clone_ref(py), probe_asn.clone_ref(py), - age_range.clone_ref(py), - count_range.clone_ref(py), + age_tuple, + min_msm, ) .expect("Invalid submit request"); @@ -507,8 +505,8 @@ mod tests { py, probe_cc.clone_ref(py), probe_asn.clone_ref(py), - (today - 30, today + 1), - (0, 100), + age_tuple, + min_msm, ) .expect("Unable to make submit request"); @@ -519,8 +517,8 @@ mod tests { submit.request, probe_cc, probe_asn, - age_range, - count_range, + age_tuple, + min_msm, ) .expect("Invalid submit request"); @@ -530,6 +528,51 @@ mod tests { }); } + fn submit_fixture( + py: Python<'_>, + ) -> ( + crate::ServerState, + crate::SubmitRequest, + Py, + Py, + (u32, u32), + u32, + ) { + let server = crate::ServerState::new(); + let mut client = crate::UserState::new(py, server.get_public_parameters(py)).unwrap(); + let req = client.make_registration_request(py).unwrap(); + let resp = server.handle_registration_request(py, req).unwrap(); + client.handle_registration_response(py, resp).unwrap(); + let cc: Py = PyString::new(py, "VE").into(); + let asn: Py = PyString::new(py, "AS1234").into(); + let today = crate::ServerState::today(); + let age_tuple = (today - 30, today + 1); + let min_msm = 0u32; + let submit = client + .make_submit_request( + py, + cc.clone_ref(py), + asn.clone_ref(py), + age_tuple, + min_msm, + ) + .unwrap(); + (server, submit, cc, asn, age_tuple, min_msm) + } + + #[test] + fn test_handle_submit_request_rejects_short_nym() { + pyo3::Python::initialize(); + Python::attach(|py| { + let (server, submit, cc, asn, age_range, min_msm) = submit_fixture(py); + let bad_nym = PyString::new(py, &BASE64_STANDARD.encode([7u8; 31])).into(); + let err = server + .handle_submit_request(py, bad_nym, submit.request, cc, asn, age_range, min_msm) + .unwrap_err(); + assert!(matches!(err, OoniErr::DeserializationFailed { .. })); + }); + } + fn setup() -> (ThreadRng, UserState, ServerState) { let mut rng = thread_rng(); let server = ServerState::new(&mut rng);