Skip to content
Open
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion ooniauth-py/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "ooniauth_py"
version = "0.2.0"
version = "0.2.1"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand Down
4 changes: 2 additions & 2 deletions ooniauth-py/ooniauth_py.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ class ServerState:
request: str,
probe_cc: str,
probe_asn: str,
age_range: list,
measurement_count_range: list,
age_range: tuple[builtins.int, builtins.int],
measurement_count_range: tuple[builtins.int, builtins.int],
) -> str: ...
def handle_update_request(
self, req: str, old_public_params: str, old_secret_key: str
Expand Down
135 changes: 89 additions & 46 deletions ooniauth-py/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,25 @@ use ooniauth_core::{self as ooni, PublicParameters, SecretKey};

use pyo3::{
prelude::*,
types::{PyList, PyString},
types::PyString,
};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyfunction, gen_stub_pymethods};

use crate::utils::{from_pystring, to_pystring};
use crate::{exceptions::OoniResult, OoniErr};

fn py_string_arg<'py>(
py: Python<'py>,
value: &'py Py<PyString>,
name: &str,
) -> OoniResult<&'py str> {
value
.to_str(py)
.map_err(|e| OoniErr::DeserializationFailed {
reason: format!("invalid {name}: {e}"),
})
}

/// Returns the version of the `ooniauth-core`, the actual protocol implementation.
#[gen_stub_pyfunction(module = "ooniauth-py")]
#[pyfunction]
Expand Down Expand Up @@ -87,45 +99,34 @@ impl ServerState {
request: Py<PyString>,
probe_cc: Py<PyString>,
probe_asn: Py<PyString>,
age_range: Py<PyList>,
measurement_count_range: Py<PyList>,
age_range: (u32, u32),
measurement_count_range: (u32, u32),
) -> OoniResult<Py<PyString>> {
// Convert arguments from py types to rust types
let nym = nym.to_str(py).map_err(|e| OoniErr::DeserializationFailed {
reason: e.to_string(),
})?;

let nym = BASE64_STANDARD
.decode(nym)
let nym: [u8; 32] = BASE64_STANDARD
.decode(py_string_arg(py, &nym, "nym")?)
.map_err(|e| OoniErr::DeserializationFailed {
reason: e.to_string(),
})?
.try_into()
.map_err(|nym: Vec<u8>| OoniErr::DeserializationFailed {
reason: format!("nym must decode to 32 bytes, got {}", nym.len()),
})?;

let mut nym_32: [u8; 32] = [0; 32];
nym_32.copy_from_slice(nym.as_ref());

let request = from_pystring::<ooniauth_core::submit::SubmitRequest>(py, &request)?;

let probe_cc = probe_cc.to_str(py).expect("Could not get str");
let probe_asn = probe_asn.to_str(py).expect("Could not get str");

let age_range = age_range
.extract::<Vec<u32>>(py)
.expect("could not get list");
let measurement_count_range = measurement_count_range
.extract::<Vec<u32>>(py)
.expect("could not get list");
let probe_cc = py_string_arg(py, &probe_cc, "probe_cc")?;
Comment thread
mmaker marked this conversation as resolved.
let probe_asn = py_string_arg(py, &probe_asn, "probe_asn")?;

// Handle submission
let mut rng = rand::thread_rng();
let result = self.state.handle_submit(
&mut rng,
request,
&nym_32,
&nym,
probe_cc,
probe_asn,
age_range[0]..age_range[1],
measurement_count_range[0]..measurement_count_range[1],
age_range.0..age_range.1,
measurement_count_range.0..measurement_count_range.1,
)?;

Ok(to_pystring(py, &result))
Expand Down Expand Up @@ -314,12 +315,10 @@ pub struct SubmitRequest {

#[cfg(test)]
mod tests {
use crate::OoniErr;
use base64::{prelude::BASE64_STANDARD, Engine};
use ooniauth_core::{registration::open_registration::Request, ServerState, UserState};
use pyo3::{
types::{PyList, PyString},
Py, Python,
};
use pyo3::{types::PyString, Py, Python};
use rand::{rngs::ThreadRng, thread_rng};

#[test]
Expand Down Expand Up @@ -358,15 +357,15 @@ mod tests {
let cc = PyString::new(py, "VE");
let asn = PyString::new(py, "AS1234");
let today = ServerState::today();
let age_range = PyList::new(py, vec![today - 30, today + 1]).unwrap();
let msm_range = PyList::new(py, vec![0, 100]).unwrap();
let age_tuple = (today - 30, today + 1);
let msm_tuple = (0u32, 100u32);
let submit_req = client
.make_submit_request(
py,
cc.clone().into(),
asn.clone().into(),
(today - 30, today + 1),
(0, 100),
age_tuple,
msm_tuple,
)
.unwrap();

Expand All @@ -377,8 +376,8 @@ mod tests {
submit_req.request,
cc.into(),
asn.into(),
age_range.into(),
msm_range.into()
age_tuple,
msm_tuple,
)
.is_ok());
});
Expand Down Expand Up @@ -453,17 +452,16 @@ mod tests {
let probe_cc: Py<PyString> = PyString::new(py, "VE").into();
let probe_asn: Py<PyString> = PyString::new(py, "AS8048").into();
let today = ServerState::today();
let age_range: Py<PyList> =
PyList::new(py, vec![today - 30, today + 1]).unwrap().into();
let count_range: Py<PyList> = PyList::new(py, vec![0, 100]).unwrap().into();
let age_tuple = (today - 30, today + 1);
let count_tuple = (0u32, 100u32);

let submit = client
.make_submit_request(
py,
probe_cc.clone_ref(py),
probe_asn.clone_ref(py),
(today - 30, today + 1),
(0, 100),
age_tuple,
count_tuple,
)
.expect("Unable to make submit request");

Expand All @@ -474,8 +472,8 @@ mod tests {
submit.request,
probe_cc.clone_ref(py),
probe_asn.clone_ref(py),
age_range.clone_ref(py),
count_range.clone_ref(py),
age_tuple,
count_tuple,
)
.expect("Invalid submit request");

Expand Down Expand Up @@ -507,8 +505,8 @@ mod tests {
py,
probe_cc.clone_ref(py),
probe_asn.clone_ref(py),
(today - 30, today + 1),
(0, 100),
age_tuple,
count_tuple,
)
.expect("Unable to make submit request");

Expand All @@ -519,8 +517,8 @@ mod tests {
submit.request,
probe_cc,
probe_asn,
age_range,
count_range,
age_tuple,
count_tuple,
)
.expect("Invalid submit request");

Expand All @@ -530,6 +528,51 @@ mod tests {
});
}

fn submit_fixture(
Comment thread
mmaker marked this conversation as resolved.
py: Python<'_>,
) -> (
crate::ServerState,
crate::SubmitRequest,
Py<PyString>,
Py<PyString>,
(u32, u32),
(u32, u32),
) {
let server = crate::ServerState::new();
let mut client = crate::UserState::new(py, server.get_public_parameters(py)).unwrap();
let req = client.make_registration_request(py).unwrap();
let resp = server.handle_registration_request(py, req).unwrap();
client.handle_registration_response(py, resp).unwrap();
let cc: Py<PyString> = PyString::new(py, "VE").into();
let asn: Py<PyString> = PyString::new(py, "AS1234").into();
let today = crate::ServerState::today();
let age_tuple = (today - 30, today + 1);
let count_tuple = (0u32, 100u32);
let submit = client
.make_submit_request(
py,
cc.clone_ref(py),
asn.clone_ref(py),
age_tuple,
count_tuple,
)
.unwrap();
(server, submit, cc, asn, age_tuple, count_tuple)
}

#[test]
fn test_handle_submit_request_rejects_short_nym() {
pyo3::Python::initialize();
Python::attach(|py| {
let (server, submit, cc, asn, age_range, count_range) = submit_fixture(py);
let bad_nym = PyString::new(py, &BASE64_STANDARD.encode([7u8; 31])).into();
let err = server
.handle_submit_request(py, bad_nym, submit.request, cc, asn, age_range, count_range)
.unwrap_err();
assert!(matches!(err, OoniErr::DeserializationFailed { .. }));
});
}

fn setup() -> (ThreadRng, UserState, ServerState) {
let mut rng = thread_rng();
let server = ServerState::new(&mut rng);
Expand Down
Loading