diff --git a/changelog.d/19837.misc b/changelog.d/19837.misc new file mode 100644 index 00000000000..64a531d4ed2 --- /dev/null +++ b/changelog.d/19837.misc @@ -0,0 +1 @@ +Port the synchronous core of client event serialization to Rust. diff --git a/rust/src/duration.rs b/rust/src/duration.rs index 6c2e2653d11..6863e3349cd 100644 --- a/rust/src/duration.rs +++ b/rust/src/duration.rs @@ -40,6 +40,11 @@ impl SynapseDuration { Self { milliseconds } } + /// Returns the duration as a number of milliseconds. + pub const fn as_millis(&self) -> u64 { + self.milliseconds + } + /// Creates a `SynapseDuration` from a number of hours. pub const fn from_hours(hours: u32) -> Self { // We take a u32 here so that we know the multiplication won't overflow. diff --git a/rust/src/events/constants.rs b/rust/src/events/constants.rs index 811794d48c3..b965cd0db05 100644 --- a/rust/src/events/constants.rs +++ b/rust/src/events/constants.rs @@ -74,6 +74,36 @@ pub mod unsigned_field { pub const AGE_TS: &str = "age_ts"; /// Unsigned field: redacted_because pub const REDACTED_BECAUSE: &str = "redacted_because"; + /// Unsigned field: redacted_by + pub const REDACTED_BY: &str = "redacted_by"; + /// Unsigned field: transaction_id + pub const TRANSACTION_ID: &str = "transaction_id"; + /// Unsigned field: org.matrix.msc4140.delay_id + pub const DELAY_ID: &str = "org.matrix.msc4140.delay_id"; + /// Unsigned field: membership (MSC4115) + pub const MEMBERSHIP: &str = "membership"; + /// Unsigned field: msc4354_sticky_duration_ttl_ms (MSC4354) + pub const STICKY_TTL: &str = "msc4354_sticky_duration_ttl_ms"; + /// Unsigned field: io.element.synapse.soft_failed (admin metadata) + pub const SOFT_FAILED: &str = "io.element.synapse.soft_failed"; + /// Unsigned field: io.element.synapse.policy_server_spammy (admin metadata) + pub const POLICY_SERVER_SPAMMY: &str = "io.element.synapse.policy_server_spammy"; + /// Unsigned field: invite_room_state + pub const INVITE_ROOM_STATE: &str = "invite_room_state"; + 
/// Unsigned field: knock_room_state + pub const KNOCK_ROOM_STATE: &str = "knock_room_state"; + /// Unsigned field: m.relations + pub const M_RELATIONS: &str = "m.relations"; +} + +/// Relation types (the `rel_type` of an `m.relates_to`). +pub mod relation_type { + /// Relation type: m.reference + pub const REFERENCE: &str = "m.reference"; + /// Relation type: m.replace + pub const REPLACE: &str = "m.replace"; + /// Relation type: m.thread + pub const THREAD: &str = "m.thread"; } /// Membership Event Fields diff --git a/rust/src/events/formats/mod.rs b/rust/src/events/formats/mod.rs index 86023add17f..8f24ffd7574 100644 --- a/rust/src/events/formats/mod.rs +++ b/rust/src/events/formats/mod.rs @@ -95,9 +95,13 @@ pub use vmsc4242::EventFormatVMSC4242; /// pyclass. /// /// The `signatures` and `unsigned` fields are kept separate from the other -/// fields as they are mutable (and must be deep-copied if the event is cloned). -/// `common_fields` and `specific_fields` are both `#[serde(flatten)]`ed so that -/// the serialised JSON is a single flat object matching the Matrix spec. +/// fields as they are mutable. Note the derived [`Clone`] is *shallow*: it +/// shares the mutable `signatures`/`unsigned`/internal state behind their +/// `Arc`s (cheap, and fine for read-only uses such as bundled aggregations). +/// Use [`FormattedEvent::deep_copy`] when an independently-mutable copy is +/// required. `common_fields` and `specific_fields` are both +/// `#[serde(flatten)]`ed so that the serialised JSON is a single flat object +/// matching the Matrix spec. /// /// Note, deserialization of this struct must not be done from /// [`serde_json::Value`] nor [`pythonize::depythonize`], due to a bug with @@ -105,7 +109,7 @@ pub use vmsc4242::EventFormatVMSC4242; /// Instead, deserialize directly from a JSON string with /// `serde_json::from_str`. See https://github.com/serde-rs/serde/issues/2230 /// for details. 
-#[derive(Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize)] pub struct FormattedEvent> { /// The event's signatures. /// diff --git a/rust/src/events/internal_metadata.rs b/rust/src/events/internal_metadata.rs index d61bf0d48c8..0778fbfeaa2 100644 --- a/rust/src/events/internal_metadata.rs +++ b/rust/src/events/internal_metadata.rs @@ -498,6 +498,44 @@ impl EventInternalMetadata { .write() .map_err(|_| PyRuntimeError::new_err("EventInternalMetadata lock poisoned")) } + + /// The event ID of the redaction event, if this event has been redacted. + pub fn redacted_by(&self) -> PyResult> { + Ok(self.read_inner()?.redacted_by.clone()) + } + + /// The transaction ID, if set when the event was created. + /// + /// The transaction ID comes from the `txn_id` path parameter of the + /// client-server API request used to send the event. + pub fn txn_id(&self) -> PyResult> { + Ok(self.read_inner()?.get_txn_id().map(|s| s.to_owned())) + } + + /// The device ID of the sender, if set. + pub fn device_id(&self) -> PyResult> { + Ok(self.read_inner()?.get_device_id().map(|s| s.to_owned())) + } + + /// The access token ID of the sender, if set. + pub fn token_id(&self) -> PyResult> { + Ok(self.read_inner()?.get_token_id()) + } + + /// The delay ID, set only if the event was a delayed event. + pub fn delay_id(&self) -> PyResult> { + Ok(self.read_inner()?.get_delay_id().map(|s| s.to_owned())) + } + + /// Whether the event has been soft failed. + pub fn soft_failed(&self) -> PyResult { + Ok(self.read_inner()?.is_soft_failed()) + } + + /// Whether the policy server marked this event as spammy. + pub fn policy_server_spammy(&self) -> PyResult { + Ok(self.read_inner()?.get_policy_server_spammy()) + } } /// Helper to convert `None` to an `AttributeError` for a property getter. 
diff --git a/rust/src/events/json_object.rs b/rust/src/events/json_object.rs index bb4877d482f..ef6f3ef4424 100644 --- a/rust/src/events/json_object.rs +++ b/rust/src/events/json_object.rs @@ -17,14 +17,16 @@ use std::{collections::BTreeMap, sync::Arc}; use pyo3::{ exceptions::{PyKeyError, PyTypeError}, + prelude::Borrowed, pyclass, pymethods, types::{ PyAnyMethods, PyIterator, PyList, PyListMethods, PyMapping, PySet, PySetMethods, PyTuple, }, - Bound, IntoPyObject, IntoPyObjectExt, Py, PyAny, PyResult, Python, + Bound, FromPyObject, IntoPyObject, IntoPyObjectExt, Py, PyAny, PyErr, PyResult, Python, }; use pythonize::{depythonize, pythonize}; use serde::{Deserialize, Serialize}; +use serde_json::Value; /// A generic class for representing immutable JSON objects. /// @@ -40,34 +42,46 @@ pub struct JsonObject { object: Arc, serde_json::Value>>, } -#[pymethods] -impl JsonObject { - #[new] - #[pyo3(signature = (content = None))] - fn new<'a, 'py>(content: Option<&'a Bound<'py, PyAny>>) -> PyResult { - let Some(content) = content else { - // If no content is provided, default to an empty object. - return Ok(Self::default()); - }; +// We implement `FromPyObject` to allow `JsonObject` to be used as function +// arguments. +impl<'py> FromPyObject<'_, 'py> for JsonObject { + type Error = PyErr; - if let Ok(content) = content.cast::() { - // If the content is already a JsonObject, we can just clone the - // underlying map (this is safe as the object is immutable). + fn extract(ob: Borrowed<'_, 'py, PyAny>) -> Result { + // Fast path: already a JsonObject, so just share the underlying map + // (cheap, as it's immutable and behind an `Arc`). + if let Ok(obj) = ob.cast::() { return Ok(JsonObject { - object: content.get().object.clone(), + object: obj.get().object.clone(), }); } - let Ok(content) = content.cast::() else { - return Err(PyTypeError::new_err("'content' must be a mapping")); - }; - - // Use pythonize to try and convert from a mapping. 
- let content = depythonize(content)?; - Ok(Self { - object: Arc::new(content), + // Otherwise accept any mapping and convert it via pythonize. Unlike the + // `#[new]` constructor we don't accept `None` here: an absent value is + // represented as `Option` at the field/argument level. + let mapping = ob + .cast::() + .map_err(|_| PyTypeError::new_err("expected a mapping"))?; + let object: BTreeMap, Value> = depythonize(&mapping)?; + Ok(JsonObject { + object: Arc::new(object), }) } +} + +#[pymethods] +impl JsonObject { + #[new] + #[pyo3(signature = (content = None))] + fn new(content: Option<&Bound<'_, PyAny>>) -> PyResult { + match content { + // If no content is provided, default to an empty object. + None => Ok(Self::default()), + // Otherwise reuse the `FromPyObject` path, which accepts an + // existing `JsonObject` or any Python mapping. + Some(content) => JsonObject::extract(content.as_borrowed()), + } + } fn __len__(&self) -> usize { self.object.len() @@ -197,6 +211,29 @@ impl JsonObject { pub fn get_field(&self, key: &str) -> Option<&serde_json::Value> { self.object.get(key) } + + /// Returns a reference to the underlying map of this object's entries. + pub fn as_map(&self) -> &BTreeMap, Value> { + &self.object + } + + /// Whether the object has no entries. 
+ pub fn is_empty(&self) -> bool { + self.object.is_empty() + } + + pub fn iter(&self) -> impl Iterator, &Value)> { + self.object.iter() + } +} + +impl<'a> IntoIterator for &'a JsonObject { + type Item = (&'a Box, &'a serde_json::Value); + type IntoIter = std::collections::btree_map::Iter<'a, Box, serde_json::Value>; + + fn into_iter(self) -> Self::IntoIter { + self.object.as_ref().iter() + } } /// Helper class returned by `JsonObject.keys()` to act as a view into the keys diff --git a/rust/src/events/mod.rs b/rust/src/events/mod.rs index 83900e14baa..21d56e8e7a3 100644 --- a/rust/src/events/mod.rs +++ b/rust/src/events/mod.rs @@ -58,6 +58,7 @@ use pyo3::{ wrap_pyfunction, Bound, IntoPyObject, PyAny, PyResult, Python, }; use pythonize::{depythonize, pythonize}; +use serde_json::Value; use crate::events::{ constants::event_type::M_ROOM_MEMBER, @@ -87,6 +88,8 @@ pub mod filter; pub mod formats; pub mod internal_metadata; pub mod json_object; +pub mod relations; +pub mod serialize; pub mod signatures; pub mod unsigned; pub mod utils; @@ -107,9 +110,14 @@ pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> child_module.add_class::()?; child_module.add_class::()?; child_module.add_class::()?; + child_module.add_class::()?; + child_module.add_class::()?; + child_module.add_class::()?; + child_module.add_class::()?; child_module.add_function(wrap_pyfunction!(filter::event_visible_to_server_py, m)?)?; child_module.add_function(wrap_pyfunction!(redact_event_py, m)?)?; child_module.add_function(wrap_pyfunction!(redact_event_dict, m)?)?; + child_module.add_function(wrap_pyfunction!(serialize::serialize_events, m)?)?; m.add_submodule(&child_module)?; @@ -129,7 +137,11 @@ pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> /// metadata, rejection reason, and a reference to the room version that /// produced this event). See the module-level docs for the high-level /// design. 
-#[pyclass(frozen, weakref)] +/// +/// `Clone` is shallow (see [`FormattedEvent`]) and lets an `Event` be held by +/// value, e.g. inside [`BundledAggregations`](crate::events::relations::BundledAggregations). +#[pyclass(frozen, weakref, skip_from_py_object)] +#[derive(Clone)] pub struct Event { /// The parsed event JSON. parsed_event: FormattedEvent, @@ -593,8 +605,19 @@ impl Event { } } - #[getter] - fn redacts<'py>(&self, py: Python<'py>) -> PyResult>> { + /// Returns the `redacts` field of this event, if it has one. + #[getter(redacts)] + fn redacts_py<'py>(&self, py: Python<'py>) -> PyResult>> { + let value = self.redacts(); + value + .map(|v| pythonize(py, v).map_err(Into::into)) + .transpose() + } +} + +impl Event { + /// Returns the `redacts` field of this event, if it has one. + pub fn redacts(&self) -> Option<&Value> { let common = &self.parsed_event.common_fields; let value = if self.room_version.updated_redaction_rules { common.content.get_field(REDACTS) @@ -602,8 +625,6 @@ impl Event { common.other_fields.get(REDACTS) }; value - .map(|v| pythonize(py, v).map_err(Into::into)) - .transpose() } } diff --git a/rust/src/events/relations.rs b/rust/src/events/relations.rs new file mode 100644 index 00000000000..472d3e4f624 --- /dev/null +++ b/rust/src/events/relations.rs @@ -0,0 +1,138 @@ +/* + * This file is licensed under the Affero General Public License (AGPL) version 3. + * + * Copyright (C) 2026 Element Creations Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * See the GNU Affero General Public License for more details: + * . + */ + +//! The bundled aggregations attached to an event for client serialization. +//! +//! These mirror the Matrix "server-side aggregation" data (references, edits +//! 
and thread summaries) that is folded into an event's `unsigned.m.relations` +//! section when serializing for clients. They are built by the Python +//! `RelationsHandler` and consumed by [`serialize_events`](crate::events::serialize::serialize_events). +//! +//! The events they reference ([`Event`]) are stored by value rather than as +//! Python handles; cloning an `Event` is cheap (it shares the underlying data +//! behind `Arc`s) and the events are only ever read here. + +use pyo3::{pyclass, pymethods, Py, PyTraverseError, PyVisit}; + +use crate::events::{json_object::JsonObject, Event}; + +/// A thread's bundled summary: its latest event, the number of events in the +/// thread, and whether the requesting user has participated. +#[pyclass(frozen, skip_from_py_object, get_all)] +pub struct ThreadAggregation { + /// The latest event in the thread. + pub latest_event: Py, + /// The total number of events in the thread. + pub count: i64, + /// Whether the requesting user has sent an event to the thread. + pub current_user_participated: bool, +} + +#[pymethods] +impl ThreadAggregation { + #[new] + fn new(latest_event: Py, count: i64, current_user_participated: bool) -> Self { + Self { + latest_event, + count, + current_user_participated, + } + } + + #[getter] + fn latest_event(&self) -> &Py { + &self.latest_event + } + + #[getter] + fn count(&self) -> i64 { + self.count + } + + #[getter] + fn current_user_participated(&self) -> bool { + self.current_user_participated + } + + /// The Python GC needs to know that this object references the latest + /// event. + fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { + visit.call(&self.latest_event)?; + Ok(()) + } +} + +/// The bundled aggregations for a single event. +/// +/// Some values require additional processing during serialization (the edit +/// and the thread's latest event are themselves serialized). 
+#[pyclass(frozen, skip_from_py_object, get_all)] +pub struct BundledAggregations { + /// The `m.reference` aggregation (e.g. `{"chunk": [{"event_id": ...}]}`). + pub references: Option, + /// The edit (`m.replace`) event that applies to this event. + pub replace: Option>, + /// The thread (`m.thread`) summary for this event. + pub thread: Option>, +} + +#[pymethods] +impl BundledAggregations { + #[new] + #[pyo3(signature = (references = None, replace = None, thread = None))] + fn new( + references: Option, + replace: Option>, + thread: Option>, + ) -> Self { + Self { + references, + replace, + thread, + } + } + + #[getter] + fn references(&self) -> Option { + self.references.clone() + } + + #[getter] + fn replace(&self) -> Option<&Py> { + self.replace.as_ref() + } + + #[getter] + fn thread(&self) -> Option<&Py> { + self.thread.as_ref() + } + + /// Whether there are any aggregations to bundle. + /// + /// Matches the Python `bool(self.references or self.replace or self.thread)`: + /// an empty `references` mapping counts as falsey. + fn __bool__(&self) -> bool { + self.references.as_ref().is_some_and(|r| !r.is_empty()) + || self.replace.is_some() + || self.thread.is_some() + } + + /// The Python GC needs to know that this object references the latest + /// event. + fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { + visit.call(&self.replace)?; + visit.call(&self.thread)?; + Ok(()) + } +} diff --git a/rust/src/events/serialize.rs b/rust/src/events/serialize.rs new file mode 100644 index 00000000000..1f1e5083d77 --- /dev/null +++ b/rust/src/events/serialize.rs @@ -0,0 +1,794 @@ +/* + * This file is licensed under the Affero General Public License (AGPL) version 3. 
+ * + * Copyright (C) 2026 Element Creations Ltd + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * See the GNU Affero General Public License for more details: + * . + */ + +//! The synchronous core of client event serialization. +//! +//! This module turns events from their internal/federation shape into the JSON +//! shape sent to clients: applying the requested [`EventFormat`], folding in +//! redactions, module-callback unsigned additions, bundled aggregations +//! (references, edits and thread summaries) and field filtering. +//! +//! It operates purely on already-fetched data — all DB/IO (fetching redactions, +//! running module callbacks, resolving the admin/MSC4354 config) is performed +//! up front by the Python caller and passed in. +//! +//! The entry point is [`serialize_events`], which reads the Python inputs (the +//! redaction map, bundled aggregations and module-callback additions) once per +//! batch and then serializes each event — recursing entirely in Rust into +//! redactions and bundled aggregations via [`serialize_event`]. + +use std::collections::HashMap; + +use pyo3::{ + exceptions::{PyTypeError, PyValueError}, + pyclass, pyfunction, pymethods, Bound, IntoPyObject, PyAny, PyResult, Python, +}; +use pythonize::pythonize; +use serde_json::{Map, Number, Value}; + +use crate::{ + events::{ + constants::{ + event_field::{CONTENT, EVENT_ID, ROOM_ID, SENDER, UNSIGNED}, + event_type::{M_ROOM_CREATE, M_ROOM_REDACTION}, + redaction_field::REDACTS, + relation_type, unsigned_field, + }, + json_object::JsonObject, + relations::BundledAggregations, + Event, + }, + types::Requester, +}; + +/// The user_id field copied from `sender` by the v1 client format. 
+const USER_ID: &str = "user_id"; + +/// Keys dropped by the v2 client event format. +const V2_DROP_KEYS: [&str; 7] = [ + "auth_events", + "prev_events", + "hashes", + "signatures", + "depth", + "origin", + "prev_state", +]; + +/// Keys copied from `unsigned` to the top level by the v1 client event format. +const V1_COPY_KEYS: [&str; 6] = [ + "age", + "redacted_because", + "replaces_state", + "prev_content", + "invite_room_state", + "knock_room_state", +]; + +/// The format used to convert an event from its federation shape to the shape +/// sent to clients. +#[pyclass(eq, eq_int, frozen, from_py_object)] +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub enum EventFormat { + /// Return the event dict unchanged (federation format). + Raw, + /// The legacy `/events`-style client format. + ClientV1, + /// The `/sync`-style client format. + ClientV2, + /// Like `ClientV2`, but also strips `room_id`. + ClientV2WithoutRoomId, +} + +/// Configuration for serializing an event for clients. +/// +/// The output shape is chosen by [`EventFormat`]. The field `requester`, when +/// set, controls whether sender-only fields (such as the transaction ID) are +/// included. +#[pyclass(frozen, skip_from_py_object)] +#[derive(Clone)] +pub struct SerializeEventConfig { + /// Whether to apply the client event format transform (v1/v2/raw). When + /// `false`, the federation-format PDU event is returned as-is. + /// + /// FIXME: Can we remove this and rely on [`Self::event_format`]? + as_client_event: bool, + /// Which client event format variant to apply (only used when + /// `as_client_event` is `true`). + event_format: EventFormat, + /// The entity requesting the event. Used to gate sender-only fields such as + /// `transaction_id` and `delay_id`. + requester: Option, + /// If set, only include these field paths in the output. An empty list + /// returns an empty event; `None` returns all fields. + /// + /// The fields can be "dotted" fields, e.g. `content.body`. 
+ event_field_allowlist: Option>, + /// Whether to include `invite_room_state` / `knock_room_state` in + /// `unsigned`. These are stripped by default and only included for specific + /// endpoints (e.g. `/sync` invite/knock handling). + include_stripped_room_state: bool, + /// When `true`, add server-admin-only metadata to `unsigned` + /// (`io.element.synapse.soft_failed`, + /// `io.element.synapse.policy_server_spammy`). + include_admin_metadata: bool, + /// Whether MSC4354 (sticky events) is enabled. When `true`, the remaining + /// stickiness TTL is computed and added to `unsigned`. + msc4354_enabled: bool, +} + +#[pymethods] +impl SerializeEventConfig { + #[new] + #[allow(clippy::too_many_arguments)] + fn new( + as_client_event: bool, + event_format: EventFormat, + requester: Option>, + event_field_allowlist: Option>, + include_stripped_room_state: bool, + include_admin_metadata: bool, + msc4354_enabled: bool, + ) -> PyResult { + let requester = requester.map(|r| r.get().clone()); + + Ok(Self { + as_client_event, + event_format, + requester, + event_field_allowlist, + include_stripped_room_state, + include_admin_metadata, + msc4354_enabled, + }) + } + + #[getter] + fn as_client_event(&self) -> bool { + self.as_client_event + } + + #[getter] + fn event_format(&self) -> EventFormat { + self.event_format + } + + #[getter] + fn requester<'py>(&self, py: Python<'py>) -> PyResult>> { + self.requester + .as_ref() + .map(|r| r.clone().into_pyobject(py)) + .transpose() + } + + #[getter] + fn event_field_allowlist(&self) -> Option> { + self.event_field_allowlist.clone() + } + + #[getter] + fn include_stripped_room_state(&self) -> bool { + self.include_stripped_room_state + } + + #[getter] + fn include_admin_metadata(&self) -> bool { + self.include_admin_metadata + } + + #[getter] + fn msc4354_enabled(&self) -> bool { + self.msc4354_enabled + } +} + +/// Synchronously serialize a batch of events for clients. 
+/// +/// `events` is a list of `(event, membership)` pairs, where `event` is a +/// `FilteredEvent.event` and `membership` the corresponding +/// `FilteredEvent.membership`. All DB/IO must already have been performed by the +/// Python caller: `redaction_map` maps redaction event IDs to events, +/// `unsigned_additions` maps event IDs to module-callback unsigned fields, and +/// `bundle_aggregations` maps event IDs to their bundled aggregations. +/// +/// These three maps are shared across the whole batch, so they are read out of +/// Python once and then reused for every event. +#[pyfunction] +#[pyo3(signature = ( + events, + time_now_ms, + config, + *, + bundle_aggregations = None, + redaction_map = None, + unsigned_additions = None, +))] +pub fn serialize_events<'py>( + py: Python<'py>, + events: Vec<(Bound<'py, Event>, Option)>, + time_now_ms: i64, + config: &SerializeEventConfig, + bundle_aggregations: Option>>, + redaction_map: Option>>, + unsigned_additions: Option>, +) -> PyResult>> { + let redaction_map = redaction_map.unwrap_or_default(); + let unsigned_additions = unsigned_additions.unwrap_or_default(); + + events + .iter() + .map(|(event, membership)| { + let serialized = serialize_event( + event.get(), + time_now_ms, + config, + membership.as_deref(), + bundle_aggregations.as_ref(), + &redaction_map, + &unsigned_additions, + )?; + Ok(pythonize(py, &Value::Object(serialized))?) + }) + .collect() +} + +/// The recursive core: serialize a single event, fold in its redaction, +/// module-callback additions and field filtering, then recurse into any +/// bundled aggregations. 
+#[allow(clippy::too_many_arguments)] +fn serialize_event( + event: &Event, + time_now_ms: i64, + config: &SerializeEventConfig, + membership: Option<&str>, + bundle_aggregations: Option<&HashMap>>, + redaction_map: &HashMap>, + unsigned_additions: &HashMap, +) -> PyResult> { + let mut serialized = serialize_event_value(event, time_now_ms, config, membership)?; + + // If the event was redacted, include the (pre-fetched) redaction event in + // the serialized event's unsigned section. + if let Some(redacted_by) = event.internal_metadata.redacted_by()? { + unsigned_mut(&mut serialized)?.insert( + unsigned_field::REDACTED_BY.to_owned(), + Value::String(redacted_by.clone()), + ); + + if let Some(redaction_event) = redaction_map.get(&redacted_by) { + let serialized_redaction = Value::Object(serialize_event_value( + redaction_event.get(), + time_now_ms, + config, + None, + )?); + unsigned_mut(&mut serialized)?.insert( + unsigned_field::REDACTED_BECAUSE.to_owned(), + serialized_redaction.clone(), + ); + // The v1 client format (apply_event_format) copies redacted_because + // up to the top level, but since we add it after that runs, do it + // here too. + if config.as_client_event && config.event_format == EventFormat::ClientV1 { + serialized.insert( + unsigned_field::REDACTED_BECAUSE.to_owned(), + serialized_redaction, + ); + } + } + } + + // Merge in the module-callback additions. Start from a copy of the additions + // and overlay the event's own unsigned on top, so modules can't clobber + // existing fields. + if let Some(adds) = unsigned_additions.get(event.event_id()) { + let unsigned = unsigned_mut(&mut serialized)?; + for (key, value) in adds { + // Don't let modules clobber existing unsigned fields. + if let serde_json::map::Entry::Vacant(entry) = unsigned.entry(&**key) { + entry.insert(value.clone()); + } + } + } + + // Only include fields that the client has requested. 
+ if let Some(fields) = &config.event_field_allowlist { + if !fields.is_empty() { + serialized = only_fields(&serialized, fields)?; + } + } + + // Inject any bundled aggregations. Note this happens after field filtering; + // aggregations are always returned. + if let Some(bundles) = bundle_aggregations { + if let Some(aggregation) = bundles.get(event.event_id()) { + inject_bundled_aggregations( + time_now_ms, + config, + aggregation.get(), + &mut serialized, + bundle_aggregations, + redaction_map, + unsigned_additions, + )?; + } + } + + Ok(serialized) +} + +/// Inject an event's bundled aggregations (references, edit, thread summary) +/// into the `m.relations` section of its serialized `unsigned`. +#[allow(clippy::too_many_arguments)] +fn inject_bundled_aggregations( + time_now_ms: i64, + config: &SerializeEventConfig, + aggregation: &BundledAggregations, + serialized_event: &mut Map, + bundle_aggregations: Option<&HashMap>>, + redaction_map: &HashMap>, + unsigned_additions: &HashMap, +) -> PyResult<()> { + let mut serialized_aggregations = Map::new(); + + if let Some(references) = &aggregation.references { + if !references.is_empty() { + serialized_aggregations.insert( + relation_type::REFERENCE.to_owned(), + Value::Object( + references + .iter() + .map(|(k, v)| (k.clone().into_string(), v.clone())) + .collect(), + ), + ); + } + } + + if let Some(replace) = &aggregation.replace { + // Bundle the *whole* edit event (serialized without its own bundled + // aggregations). The spec (v1.5) only requires event_id/origin_server_ts/ + // sender, but per MSC3925 we include the full edit. 
+ // https://spec.matrix.org/v1.5/client-server-api/#server-side-aggregation-of-mreplace-relationships + let serialized = serialize_event( + replace.get(), + time_now_ms, + config, + None, + None, + redaction_map, + unsigned_additions, + )?; + serialized_aggregations + .insert(relation_type::REPLACE.to_owned(), Value::Object(serialized)); + } + + if let Some(thread) = &aggregation.thread { + let thread = thread.get(); + // The thread's latest event is serialized with the same bundle map, so + // it may recurse further. + let serialized_latest = serialize_event( + thread.latest_event.get(), + time_now_ms, + config, + None, + bundle_aggregations, + redaction_map, + unsigned_additions, + )?; + + let mut thread_summary = Map::new(); + thread_summary.insert("latest_event".to_owned(), Value::Object(serialized_latest)); + thread_summary.insert( + "count".to_owned(), + Value::Number(Number::from(thread.count)), + ); + thread_summary.insert( + "current_user_participated".to_owned(), + Value::Bool(thread.current_user_participated), + ); + serialized_aggregations.insert( + relation_type::THREAD.to_owned(), + Value::Object(thread_summary), + ); + } + + if !serialized_aggregations.is_empty() { + let unsigned = unsigned_mut(serialized_event)?; + let relations = object_entry_mut(unsigned, unsigned_field::M_RELATIONS)?; + for (key, value) in serialized_aggregations { + relations.insert(key, value); + } + } + + Ok(()) +} + +/// Serialize a single event to its client JSON shape, without recursing into +/// redactions or bundled aggregations (those are handled by the caller). 
+fn serialize_event_value( + event: &Event, + time_now_ms: i64, + config: &SerializeEventConfig, + membership: Option<&str>, +) -> PyResult> { + let mut d: Map = match serde_json::to_value(&event.parsed_event) { + Ok(Value::Object(map)) => map, + Ok(_) => { + return Err(PyValueError::new_err( + "event did not serialize to a JSON object", + )) + } + Err(err) => { + return Err(PyValueError::new_err(format!( + "Failed to serialize event: {err}" + ))) + } + }; + + // Always include the `event_id` field in a client event. For room version + // v3+, these aren't in the PDU event JSON. + d.insert( + EVENT_ID.to_owned(), + Value::String(event.event_id().to_owned()), + ); + + // Replace `age_ts` with `age`, with `age` calculated as the difference + // between the current time and `age_ts`. This is an optional field in the + // spec. + // + // We might not have an `age_ts`, e.g. if a remote server did not include + // the `age` field in the event it sent us. Since `age_ts` is generated by + // us, it *should* be an integer, but it is possible for it to be out of i64 + // range (e.g. if the original `age` was close to the maximum i64 value). In + // that case, just omit `age` rather than erroring (otherwise a once valid + // event could start failing). + let unsigned = unsigned_mut(&mut d)?; + if let Some(age_ts) = event.unsigned().age_ts()?.and_then(|n| n.as_i64()) { + unsigned.insert( + unsigned_field::AGE.to_owned(), + Value::Number(Number::from(time_now_ms - age_ts)), + ); + unsigned.remove(unsigned_field::AGE_TS); + } + + // Include the transaction_id / delay_id in the unsigned section if the event + // was sent by the same session (or, where appropriate, the same sender) as + // the one requesting the event. + if let Some(requester) = &config.requester { + if requester.user_id == event.sender() { + if let Some(txn_id) = event.internal_metadata.txn_id()? { + if let Some(event_device_id) = event.internal_metadata.device_id()? 
{ + if Some(event_device_id.as_str()) == requester.device_id.as_deref() { + unsigned_mut(&mut d)?.insert( + unsigned_field::TRANSACTION_ID.to_owned(), + Value::String(txn_id), + ); + } + } else { + // No device ID is stored for some events: old events, and + // those created by appservices, guests, or with admin-API + // tokens. For those, fall back to the access token: only + // include the transaction ID if the event was sent from the + // same token (or for guests/appservices, which we can't + // check, so assume the same session). + let event_token_id = event.internal_metadata.token_id()?; + let token_matches = event_token_id.is_some() + && requester.access_token_id.is_some() + && event_token_id == requester.access_token_id; + if token_matches || requester.is_guest || requester.app_service_id.is_some() { + unsigned_mut(&mut d)?.insert( + unsigned_field::TRANSACTION_ID.to_owned(), + Value::String(txn_id), + ); + } + } + } + + if let Some(delay_id) = event.internal_metadata.delay_id()? { + unsigned_mut(&mut d)? + .insert(unsigned_field::DELAY_ID.to_owned(), Value::String(delay_id)); + } + } + } + + // Strip invite/knock room state unless requested. + if !config.include_stripped_room_state { + let unsigned = unsigned_mut(&mut d)?; + unsigned.remove(unsigned_field::INVITE_ROOM_STATE); + unsigned.remove(unsigned_field::KNOCK_ROOM_STATE); + } + + if config.as_client_event { + apply_event_format(config.event_format, &mut d); + } + + // Ensure the room_id field is set for create events in MSC4291 rooms. + if event.r#type() == M_ROOM_CREATE && event.room_version.msc4291_room_ids_as_hashes { + d.insert( + ROOM_ID.to_owned(), + Value::String(event.room_id().to_owned()), + ); + } + + // A redaction stores the redacted event ID in different places depending + // on the room version (top-level `redacts` vs `content.redacts`). It's + // already in the version-correct place; copy it to the *other* one too, + // for forwards/backwards-compatibility with clients. 
+ if event.r#type() == M_ROOM_REDACTION { + let redacts = event.redacts(); + // Skip a present-but-null value: the Python `e.redacts` property + // surfaced JSON null as `None`, and the old code guarded with + // `e.redacts is not None`. + if let Some(redacts) = redacts.filter(|v| !v.is_null()) { + let redacts = redacts.clone(); + if event.room_version.updated_redaction_rules { + d.insert(REDACTS.to_owned(), redacts); + } else { + object_entry_mut(&mut d, CONTENT)?.insert(REDACTS.to_owned(), redacts); + } + } + } + + let unsigned = unsigned_mut(&mut d)?; + if config.include_admin_metadata { + if event.internal_metadata.soft_failed()? { + unsigned.insert(unsigned_field::SOFT_FAILED.to_owned(), Value::Bool(true)); + } + if event.internal_metadata.policy_server_spammy()? { + unsigned.insert( + unsigned_field::POLICY_SERVER_SPAMMY.to_owned(), + Value::Bool(true), + ); + } + } + + if config.msc4354_enabled { + if let Some(sticky_duration) = event.sticky_duration() { + // min() ensures the origin server can't claim a time in the future + // to exceed the stickiness duration limit. + // + // The `as i64` cast is safe as sticky duration are capped to an + // hour, which is well within the i64 range. + let expires_at = std::cmp::min(event.origin_server_ts(), time_now_ms) + + sticky_duration.as_millis() as i64; + if expires_at > time_now_ms { + unsigned.insert( + unsigned_field::STICKY_TTL.to_owned(), + Value::Number(Number::from(expires_at - time_now_ms)), + ); + } + } + } + + if let Some(membership) = membership { + unsigned.insert( + unsigned_field::MEMBERSHIP.to_owned(), + Value::String(membership.to_owned()), + ); + } + + Ok(d) +} + +/// Apply the client event format transform in place. 
+///
+/// `Raw` leaves `d` untouched. `ClientV2` removes every key in `V2_DROP_KEYS`.
+/// `ClientV2WithoutRoomId` does the same and also removes the `ROOM_ID` key.
+/// `ClientV1` applies the V2 transform, copies a non-null `SENDER` value to
+/// `USER_ID`, and copies each `V1_COPY_KEYS` entry that is present in the
+/// `UNSIGNED` object up to the top level.
+fn apply_event_format(format: EventFormat, d: &mut Map<String, Value>) {
+    match format {
+        EventFormat::Raw => {}
+        EventFormat::ClientV2 => format_for_client_v2(d),
+        EventFormat::ClientV2WithoutRoomId => {
+            format_for_client_v2(d);
+            d.remove(ROOM_ID);
+        }
+        EventFormat::ClientV1 => {
+            format_for_client_v2(d);
+
+            let sender = d.get(SENDER).filter(|v| !v.is_null()).cloned();
+            if let Some(sender) = sender {
+                d.insert(USER_ID.to_owned(), sender);
+            }
+
+            // Collect first, insert afterwards: `unsigned` borrows `d`
+            // immutably, so `d` cannot be modified while it is being read.
+            let mut to_copy = Vec::new();
+            if let Some(Value::Object(unsigned)) = d.get(UNSIGNED) {
+                for key in V1_COPY_KEYS {
+                    if let Some(value) = unsigned.get(key) {
+                        to_copy.push((key.to_owned(), value.clone()));
+                    }
+                }
+            }
+            for (key, value) in to_copy {
+                d.insert(key, value);
+            }
+        }
+    }
+}
+
+/// Remove every key listed in `V2_DROP_KEYS` from the top level of `d`.
+fn format_for_client_v2(d: &mut Map<String, Value>) {
+    for key in V2_DROP_KEYS {
+        d.remove(key);
+    }
+}
+
+/// Return a mutable reference to `map["unsigned"]`, creating it as an empty
+/// object if it is missing.
+///
+/// Returns a `PyTypeError` if the key is present but its value is not an
+/// object; the existing value is left untouched in that case.
+fn unsigned_mut(map: &mut Map<String, Value>) -> PyResult<&mut Map<String, Value>> {
+    object_entry_mut(map, UNSIGNED)
+}
+
+/// Return a mutable reference to `map[key]`, creating it as an empty object if
+/// it is missing.
+///
+/// Returns a `PyTypeError` if the key is present but its value is not an
+/// object; the existing value is not replaced.
+fn object_entry_mut<'a>(
+    map: &'a mut Map<String, Value>,
+    key: &str,
+) -> PyResult<&'a mut Map<String, Value>> {
+    map.entry(key.to_owned())
+        .or_insert_with(|| Value::Object(Map::new()))
+        .as_object_mut()
+        .ok_or_else(|| PyTypeError::new_err(format!("Expected an object for key '{key}'")))
+}
+
+/// Return a new map containing only the given (possibly dotted) field paths,
+/// implementing the `event_field_allowlist` client filter.
+///
+/// Each entry of `fields` is split with `split_field` and copied with
+/// `copy_field`. An empty `fields` slice therefore yields an empty map.
+fn only_fields(
+    dictionary: &Map<String, Value>,
+    fields: &[String],
+) -> PyResult<Map<String, Value>> {
+    let mut output = Map::new();
+    for field in fields {
+        copy_field(dictionary, &mut output, &split_field(field))?;
+    }
+    Ok(output)
+}
+
+/// Copy a single (possibly nested) field path from `src` into `dst`, creating
+/// intermediate objects in `dst` as needed.
A missing path is a no-op. +fn copy_field( + src: &Map, + dst: &mut Map, + field: &[String], +) -> PyResult<()> { + if field.is_empty() { + return Ok(()); + } + if field.len() == 1 { + if let Some(value) = src.get(&field[0]) { + dst.insert(field[0].clone(), value.clone()); + } + return Ok(()); + } + + let (key_to_move, parents) = field.split_last().expect("field is non-empty"); + + // Drill down into `src`. + let mut sub = src; + for parent in parents { + match sub.get(parent) { + Some(Value::Object(obj)) => sub = obj, + _ => return Ok(()), + } + } + + let Some(value) = sub.get(key_to_move) else { + return Ok(()); + }; + let value = value.clone(); + + // Build the nested objects in `dst` as required. + let mut out = dst; + for parent in parents { + out = object_entry_mut(out, parent)?; + } + out.insert(key_to_move.clone(), value); + + Ok(()) +} + +/// Split a dotted field path into its components, splitting on unescaped dots +/// and removing the escaping. A literal `.` or `\` in a key is escaped with `\`. +fn split_field(field: &str) -> Vec { + let bytes = field.as_bytes(); + let mut result = Vec::new(); + let mut prev_start = 0; + + for (i, &b) in bytes.iter().enumerate() { + if b != b'.' { + continue; + } + // Count the run of backslashes immediately preceding the dot. The dot is + // escaped iff that count is odd. + let mut backslashes = 0; + let mut j = i; + while j > 0 && bytes[j - 1] == b'\\' { + backslashes += 1; + j -= 1; + } + if backslashes % 2 == 0 { + result.push(unescape(&field[prev_start..i])); + prev_start = i + 1; + } + } + + result.push(unescape(&field[prev_start..])); + result +} + +/// Remove field-path escaping: `\\` and `\.` collapse to the second character; +/// any other `\x` is left as-is. 
+fn unescape(s: &str) -> String {
+    let mut out = String::with_capacity(s.len());
+    let mut chars = s.chars().peekable();
+    while let Some(c) = chars.next() {
+        if c != '\\' {
+            out.push(c);
+            continue;
+        }
+        // A backslash consumes the next character only when that character is
+        // `\` or `.`; otherwise the backslash itself is emitted and the next
+        // character is handled by the following iteration.
+        let escaped = chars.next_if(|&next| matches!(next, '\\' | '.'));
+        out.push(escaped.unwrap_or('\\'));
+    }
+    out
+}
+
+#[cfg(test)]
+mod tests {
+    use super::split_field;
+
+    #[test]
+    fn test_split_field() {
+        // Ported from the Python `SplitFieldTestCase` that previously lived in
+        // tests/events/test_utils.py (removed alongside `_split_field`).
+        let cases: &[(&str, &[&str])] = &[
+            // A field with no dots.
+            ("m", &["m"]),
+            // Simple dotted fields.
+            ("m.foo", &["m", "foo"]),
+            ("m.foo.bar", &["m", "foo", "bar"]),
+            // Backslash is used as an escape character.
+            (r"m\.foo", &["m.foo"]),
+            (r"m\\.foo", &["m\\", "foo"]),
+            (r"m\\\.foo", &[r"m\.foo"]),
+            (r"m\\\\.foo", &["m\\\\", "foo"]),
+            (r"m\foo", &[r"m\foo"]),
+            (r"m\\foo", &[r"m\foo"]),
+            (r"m\\\foo", &[r"m\\foo"]),
+            (r"m\\\\foo", &[r"m\\foo"]),
+            // Ensure that escapes at the end don't cause issues.
+            ("m.foo\\", &["m", "foo\\"]),
+            (r"m.foo\.", &["m", "foo."]),
+            (r"m.foo\\.", &["m", "foo\\", ""]),
+            (r"m.foo\\\.", &["m", r"foo\."]),
+            // Empty parts (corresponding to empty-string properties) are allowed.
+            (".m", &["", "m"]),
+            ("..m", &["", "", "m"]),
+            ("m.", &["m", ""]),
+            ("m..", &["m", "", ""]),
+            ("m..foo", &["m", "", "foo"]),
+            // Invalid escape sequences are left alone.
+ (r"\m", &[r"\m"]), + ]; + + for (input, expected) in cases { + let expected: Vec = expected.iter().map(|s| s.to_string()).collect(); + assert_eq!(split_field(input), expected, "split_field({input:?})"); + } + } +} diff --git a/rust/src/events/unsigned.rs b/rust/src/events/unsigned.rs index 931c412325b..8db0a338bd6 100644 --- a/rust/src/events/unsigned.rs +++ b/rust/src/events/unsigned.rs @@ -291,6 +291,14 @@ impl Unsigned { } } +impl Unsigned { + /// Get the `age_ts` field, which is used to generate the `age` field when + /// serializing an event. + pub fn age_ts(&self) -> PyResult> { + Ok(self.py_read()?.persisted_fields.age_ts.clone()) + } +} + fn room_state_to_py<'py>( py: Python<'py>, state: &[serde_json::Value], diff --git a/rust/src/types/mod.rs b/rust/src/types/mod.rs index ffb19a83a2e..f23b44da8e7 100644 --- a/rust/src/types/mod.rs +++ b/rust/src/types/mod.rs @@ -38,31 +38,31 @@ fn user_id_class(py: Python<'_>) -> PyResult<&Bound<'_, PyAny>> { /// Represents the user making a request. #[pyclass(frozen, skip_from_py_object, get_all, eq)] -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct Requester { /// The ID of the user making the request, in string form (see /// [`Self::user`] for accessing the parsed `UserID`). 
- user_id: String, + pub user_id: String, /// The ID of the access token used for this request, or None for /// appservices, guests, and tokens generated by the admin API - access_token_id: Option, + pub access_token_id: Option, /// True if the user making this request is a guest - is_guest: bool, + pub is_guest: bool, /// Any scopes associated with the access token used for this request, or an /// empty set if no token or a non-oauth token was used - scope: HashSet, + pub scope: HashSet, /// True if the user making this request is shadow banned - shadow_banned: bool, + pub shadow_banned: bool, /// The device_id which was set at authentication time, or None for /// appservices, guests, and tokens generated by the admin API - device_id: Option, + pub device_id: Option, /// The ID of the AS requesting on behalf of the user, or None. - app_service_id: Option, + pub app_service_id: Option, /// The entity that authenticated when making the request. /// /// This is different to the `user_id` when an admin user or the server is /// "puppeting" the user. 
- authenticated_entity: String, + pub authenticated_entity: String, } #[pymethods] diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 6303cde1826..5f6fb4971c6 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -40,7 +40,7 @@ TransactionUnusedFallbackKeys, ) from synapse.events import EventBase -from synapse.events.utils import FilteredEvent, SerializeEventConfig +from synapse.events.utils import FilteredEvent from synapse.http.client import SimpleHttpClient, is_unknown_endpoint from synapse.logging import opentracing from synapse.metrics import SERVER_NAME_LABEL @@ -560,7 +560,7 @@ async def _serialize( return await self._event_serializer.serialize_events( [FilteredEvent(event=e, membership=None) for e in events], time_now, - config=SerializeEventConfig( + config=await self._event_serializer.create_config( as_client_event=True, # If this is an invite or a knock membership event, then include # any stripped state alongside the event. We could narrow this diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 8ce795052a5..a396ad58b43 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -20,7 +20,6 @@ # # import collections.abc -import re from typing import ( TYPE_CHECKING, Any, @@ -28,7 +27,6 @@ Callable, Collection, Mapping, - Match, MutableMapping, ) @@ -40,27 +38,31 @@ CANONICALJSON_MIN_INT, MAX_PDU_SIZE, EventTypes, - EventUnsignedContentFields, - RelationTypes, ) from synapse.api.errors import Codes, SynapseError from synapse.logging.opentracing import SynapseTags, set_tag, trace -from synapse.synapse_rust.events import Unsigned, redact_event -from synapse.types import JsonDict, Requester +from synapse.synapse_rust.events import ( + EventFormat, + SerializeEventConfig, + Unsigned, + redact_event, + serialize_events, +) +from synapse.synapse_rust.types import Requester +from synapse.types import JsonDict from . 
import EventBase, StrippedStateEvent +# These are imported only to re-export them (callers import them from this +# module); listing them in __all__ stops the unused-import lint flagging them +# and re-exports them for `import *`. +__all__ = ["EventFormat", "SerializeEventConfig"] + if TYPE_CHECKING: from synapse.handlers.relations import BundledAggregations from synapse.server import HomeServer -# Split strings on "." but not "\." (or "\\\."). -SPLIT_FIELD_REGEX = re.compile(r"\\*\.") -# Find escaped characters, e.g. those with a \ in front of them. -ESCAPE_SEQUENCE_PATTERN = re.compile(r"\\(.)") - - # Module API callback that allows adding fields to the unsigned section of # events that are sent to clients. ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK = Callable[ @@ -90,177 +92,6 @@ def clone_event(event: EventBase) -> EventBase: return event.deep_copy() -def _copy_field(src: JsonDict, dst: JsonDict, field: list[str]) -> None: - """Copy the field in 'src' to 'dst'. - - For example, if src={"foo":{"bar":5}} and dst={}, and field=["foo","bar"] - then dst={"foo":{"bar":5}}. - - Args: - src: The dict to read from. - dst: The dict to modify. - field: List of keys to drill down to in 'src'. - """ - if len(field) == 0: # this should be impossible - return - if len(field) == 1: # common case e.g. 'origin_server_ts' - if field[0] in src: - dst[field[0]] = src[field[0]] - return - - # Else is a nested field e.g. 'content.body' - # Pop the last field as that's the key to move across and we need the - # parent dict in order to access the data. Drill down to the right dict. - key_to_move = field.pop(-1) - sub_dict = src - for sub_field in field: # e.g. sub_field => "content" - if sub_field in sub_dict and isinstance( - sub_dict[sub_field], collections.abc.Mapping - ): - sub_dict = sub_dict[sub_field] - else: - return - - if key_to_move not in sub_dict: - return - - # Insert the key into the output dictionary, creating nested objects - # as required. 
We couldn't do this any earlier or else we'd need to delete - # the empty objects if the key didn't exist. - sub_out_dict = dst - for sub_field in field: - sub_out_dict = sub_out_dict.setdefault(sub_field, {}) - sub_out_dict[key_to_move] = sub_dict[key_to_move] - - -def _escape_slash(m: Match[str]) -> str: - """ - Replacement function; replace a backslash-backslash or backslash-dot with the - second character. Leaves any other string alone. - """ - if m.group(1) in ("\\", "."): - return m.group(1) - return m.group(0) - - -def _split_field(field: str) -> list[str]: - """ - Splits strings on unescaped dots and removes escaping. - - Args: - field: A string representing a path to a field. - - Returns: - A list of nested fields to traverse. - """ - - # Convert the field and remove escaping: - # - # 1. "content.body.thing\.with\.dots" - # 2. ["content", "body", "thing\.with\.dots"] - # 3. ["content", "body", "thing.with.dots"] - - # Find all dots (and their preceding backslashes). If the dot is unescaped - # then emit a new field part. - result = [] - prev_start = 0 - for match in SPLIT_FIELD_REGEX.finditer(field): - # If the match is an *even* number of characters than the dot was escaped. - if len(match.group()) % 2 == 0: - continue - - # Add a new part (up to the dot, exclusive) after escaping. - result.append( - ESCAPE_SEQUENCE_PATTERN.sub( - _escape_slash, field[prev_start : match.end() - 1] - ) - ) - prev_start = match.end() - - # Add any part of the field after the last unescaped dot. (Note that if the - # character is a dot this correctly adds a blank string.) - result.append(re.sub(r"\\(.)", _escape_slash, field[prev_start:])) - - return result - - -def only_fields(dictionary: JsonDict, fields: list[str]) -> JsonDict: - """Return a new dict with only the fields in 'dictionary' which are present - in 'fields'. - - If there are no event fields specified then all fields are included. - The entries may include '.' characters to indicate sub-fields. 
- So ['content.body'] will include the 'body' field of the 'content' object. - A literal '.' or '\' character in a field name may be escaped using a '\'. - - Args: - dictionary: The dictionary to read from. - fields: A list of fields to copy over. Only shallow refs are - taken. - Returns: - A new dictionary with only the given fields. If fields was empty, - the same dictionary is returned. - """ - if len(fields) == 0: - return dictionary - - # for each field, convert it: - # ["content.body.thing\.with\.dots"] => [["content", "body", "thing\.with\.dots"]] - split_fields = [_split_field(f) for f in fields] - - output: JsonDict = {} - for field_array in split_fields: - _copy_field(dictionary, output, field_array) - return output - - -def format_event_raw(d: JsonDict) -> JsonDict: - return d - - -def format_event_for_client_v1(d: JsonDict) -> JsonDict: - d = format_event_for_client_v2(d) - - sender = d.get("sender") - if sender is not None: - d["user_id"] = sender - - copy_keys = ( - "age", - "redacted_because", - "replaces_state", - "prev_content", - "invite_room_state", - "knock_room_state", - ) - for key in copy_keys: - if key in d["unsigned"]: - d[key] = d["unsigned"][key] - - return d - - -def format_event_for_client_v2(d: JsonDict) -> JsonDict: - drop_keys = ( - "auth_events", - "prev_events", - "hashes", - "signatures", - "depth", - "origin", - "prev_state", - ) - for key in drop_keys: - d.pop(key, None) - return d - - -def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict: - d = format_event_for_client_v2(d) - d.pop("room_id", None) - return d - - @attr.s(slots=True, frozen=True, auto_attribs=True) class FilteredEvent: """An event annotated with per-user data for client serialization. 
@@ -305,180 +136,6 @@ def admin_override(cls, event: "EventBase") -> "FilteredEvent": return cls(event=event, membership=None) -@attr.s(slots=True, frozen=True, auto_attribs=True) -class SerializeEventConfig: - as_client_event: bool = True - # Function to convert from federation format to client format - event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1 - # The entity that requested the event. This is used to determine whether to include - # the transaction_id and delay_id in the unsigned section of the event. - requester: Requester | None = None - # List of event fields to include. If empty, all fields will be returned. - only_event_fields: list[str] | None = attr.ib(default=None) - # Some events can have stripped room state stored in the `unsigned` field. - # This is required for invite and knock functionality. If this option is - # False, that state will be removed from the event before it is returned. - # Otherwise, it will be kept. - include_stripped_room_state: bool = False - # When True, sets unsigned fields to help clients identify events which - # only server admins can see through other configuration. For example, - # whether an event was soft failed by the server. - include_admin_metadata: bool = False - # Whether MSC4354 (sticky events) is enabled. When True, the sticky TTL - # will be computed and included in the unsigned section of sticky events. - msc4354_enabled: bool = False - - @only_event_fields.validator - def _validate_only_event_fields( - self, attribute: attr.Attribute, value: Any - ) -> None: - if value is None: - return - - if not isinstance(value, list) or not all(isinstance(f, str) for f in value): - raise TypeError("only_event_fields must be a list of strings") - - -_DEFAULT_SERIALIZE_EVENT_CONFIG = SerializeEventConfig() - - -def make_config_for_admin(existing: SerializeEventConfig) -> SerializeEventConfig: - # Set the options which are only available to server admins, - # and copy the rest. 
- return attr.evolve(existing, include_admin_metadata=True) - - -def _serialize_event( - e: JsonDict | EventBase, - time_now_ms: int, - *, - config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, - membership: str | None = None, -) -> JsonDict: - """Serialize event for clients - - Args: - e - time_now_ms - config: Event serialization config - membership: The requesting user's membership at the time of the event, - to be injected into unsigned.membership (MSC4115). - - Returns: - The serialized event dictionary. - """ - - # FIXME(erikj): To handle the case of presence events and the like - if not isinstance(e, EventBase): - return e - - time_now_ms = int(time_now_ms) - - # Should this strip out None's? - d = dict(e.get_dict().items()) - - d["event_id"] = e.event_id - - if "age_ts" in d["unsigned"]: - d["unsigned"]["age"] = time_now_ms - d["unsigned"]["age_ts"] - del d["unsigned"]["age_ts"] - - # If we have applicable fields saved in the internal_metadata, include them in the - # unsigned section of the event if the event was sent by the same session (or when - # appropriate, just the same sender) as the one requesting the event. - if config.requester is not None and config.requester.user.to_string() == e.sender: - txn_id: str | None = getattr(e.internal_metadata, "txn_id", None) - if txn_id is not None: - # Some events do not have the device ID stored in the internal metadata, - # this includes old events as well as those created by appservice, guests, - # or with tokens minted with the admin API. For those events, fallback - # to using the access token instead. - event_device_id: str | None = getattr( - e.internal_metadata, "device_id", None - ) - if event_device_id is not None: - if event_device_id == config.requester.device_id: - d["unsigned"]["transaction_id"] = txn_id - - else: - # Fallback behaviour: only include the transaction ID if the event - # was sent from the same access token. 
- # - # For regular users, the access token ID can be used to determine this. - # This includes access tokens minted with the admin API. - # - # For guests and appservice users, we can't check the access token ID - # so assume it is the same session. - event_token_id: int | None = getattr( - e.internal_metadata, "token_id", None - ) - if ( - ( - event_token_id is not None - and config.requester.access_token_id is not None - and event_token_id == config.requester.access_token_id - ) - or config.requester.is_guest - or config.requester.app_service_id - ): - d["unsigned"]["transaction_id"] = txn_id - - delay_id: str | None = getattr(e.internal_metadata, "delay_id", None) - if delay_id is not None: - d["unsigned"]["org.matrix.msc4140.delay_id"] = delay_id - - # invite_room_state and knock_room_state are a list of stripped room state events - # that are meant to provide metadata about a room to an invitee/knocker. They are - # intended to only be included in specific circumstances, such as down sync, and - # should not be included in any other case. - if not config.include_stripped_room_state: - d["unsigned"].pop("invite_room_state", None) - d["unsigned"].pop("knock_room_state", None) - - if config.as_client_event: - d = config.event_format(d) - - # Ensure the room_id field is set for create events in MSC4291 rooms - if e.type == EventTypes.Create and e.room_version.msc4291_room_ids_as_hashes: - d["room_id"] = e.room_id - - # If the event is a redaction, the field with the redacted event ID appears - # in a different location depending on the room version. e.redacts handles - # fetching from the proper location; copy it to the other location for forwards- - # and backwards-compatibility with clients. 
- if e.type == EventTypes.Redaction and e.redacts is not None: - if e.room_version.updated_redaction_rules: - d["redacts"] = e.redacts - else: - d["content"] = dict(d["content"]) - d["content"]["redacts"] = e.redacts - - if config.include_admin_metadata: - if e.internal_metadata.is_soft_failed(): - d["unsigned"]["io.element.synapse.soft_failed"] = True - if e.internal_metadata.policy_server_spammy: - d["unsigned"]["io.element.synapse.policy_server_spammy"] = True - - if config.msc4354_enabled: - sticky_duration = e.sticky_duration() - if sticky_duration: - expires_at = ( - # min() ensures that the origin server can't lie about the time and - # send the event 'in the future', as that would allow them to exceed - # the 1 hour limit on stickiness duration. - min(e.origin_server_ts, time_now_ms) + sticky_duration.as_millis() - ) - if expires_at > time_now_ms: - d["unsigned"][EventUnsignedContentFields.STICKY_TTL] = ( - expires_at - time_now_ms - ) - - if membership is not None: - d["unsigned"][EventUnsignedContentFields.MEMBERSHIP] = membership - - return d - - class EventClientSerializer: """Serializes events that are to be sent to clients. @@ -495,12 +152,65 @@ def __init__(self, hs: "HomeServer") -> None: ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK ] = [] + async def create_config( + self, + *, + as_client_event: bool = True, + event_format: EventFormat = EventFormat.ClientV1, + requester: Requester | None = None, + event_field_allowlist: list[str] | None = None, + include_stripped_room_state: bool = False, + include_admin_metadata: bool | None = None, + ) -> SerializeEventConfig: + """ + Create a new SerializeEventConfig for the given parameters. + + Helper method that sets the `include_admin_metadata` field based on + whether the requester is a server admin if it is not explicitly + provided. Also sets the `msc4354_enabled` field based on the homeserver + config. + + Args: + as_client_event: Whether to serialize the events as client events.
+ event_format: The format to serialize events in. + requester: The user requesting the events, if any. Used to determine + whether to include admin-only metadata in the serialized events. + event_field_allowlist: A list of event fields to include in the + serialized events. + include_stripped_room_state: Whether to include stripped room state + in the serialized events. + include_admin_metadata: Whether to include admin-only metadata in + the serialized events. If None, this will be determined based on + whether the requester is a server admin. + Returns: + A SerializeEventConfig instance. + """ + + # If include_admin_metadata is None, determine whether to include + # admin-only metadata based on the requester. + if include_admin_metadata is None: + # Check if the requester is a server admin. + if requester is not None and await self._auth.is_server_admin(requester): + include_admin_metadata = True + else: + include_admin_metadata = False + + return SerializeEventConfig( + as_client_event=as_client_event, + event_format=event_format, + requester=requester, + event_field_allowlist=event_field_allowlist, + include_stripped_room_state=include_stripped_room_state, + include_admin_metadata=include_admin_metadata, + msc4354_enabled=self._config.experimental.msc4354_enabled, + ) + async def serialize_event( self, event: JsonDict | FilteredEvent, time_now: int, *, - config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, + config: SerializeEventConfig | None = None, bundle_aggregations: dict[str, "BundledAggregations"] | None = None, redaction_map: Mapping[str, "EventBase"] | None = None, ) -> JsonDict: @@ -522,154 +232,105 @@ async def serialize_event( if not isinstance(event, FilteredEvent): return event - # Force-enable server admin metadata because the only time an event with - # relevant metadata will be when the admin requested it via their admin - # client config account data.
Also, it's "just" some `unsigned` fields, so - # shouldn't cause much in terms of problems to downstream consumers. - if config.requester is not None and await self._auth.is_server_admin( - config.requester - ): - config = make_config_for_admin(config) - - if self._config.experimental.msc4354_enabled: - config = attr.evolve(config, msc4354_enabled=True) + if config is None: + # Generate default config if none was provided. + config = await self.create_config() - serialized_event = _serialize_event( - event.event, time_now, config=config, membership=event.membership + # Perform all the async DB/IO work up front, then run the synchronous + # serialization core. + redaction_map, unsigned_additions = await self._prepare_serialization( + [event], bundle_aggregations, redaction_map ) - # If the event was redacted, fetch the redaction event from the database - # and include it in the serialized event's unsigned section. - redacted_by: str | None = event.event.internal_metadata.redacted_by - if redacted_by is not None: - serialized_event.setdefault("unsigned", {})["redacted_by"] = redacted_by - if redaction_map is not None: - redaction_event: EventBase | None = redaction_map.get(redacted_by) - else: - redaction_event = await self._store.get_event( - redacted_by, - allow_none=True, - ) - if redaction_event is not None: - serialized_redaction = _serialize_event( - redaction_event, time_now, config=config - ) - serialized_event.setdefault("unsigned", {})["redacted_because"] = ( - serialized_redaction - ) - # format_event_for_client_v1 copies redacted_because to the - # top level, but since we add it after that runs, do it here. 
- if ( - config.as_client_event - and config.event_format is format_event_for_client_v1 - ): - serialized_event["redacted_because"] = serialized_redaction - - new_unsigned = {} - for callback in self._add_extra_fields_to_unsigned_client_event_callbacks: - u = await callback(event.event) - new_unsigned.update(u) - - if new_unsigned: - # We do the `update` this way round so that modules can't clobber - # existing fields. - new_unsigned.update(serialized_event["unsigned"]) - serialized_event["unsigned"] = new_unsigned - - # Only include fields that the client has requested. - # - # Note: we always return bundled aggregations, though it is unclear why. - only_event_fields = config.only_event_fields - if only_event_fields: - serialized_event = only_fields(serialized_event, only_event_fields) - - # Check if there are any bundled aggregations to include with the event. - if bundle_aggregations: - if event.event.event_id in bundle_aggregations: - await self._inject_bundled_aggregations( - event.event, - time_now, - config, - bundle_aggregations, - serialized_event, - ) - - return serialized_event - - async def _inject_bundled_aggregations( + return serialize_events( + [(event.event, event.membership)], + time_now, + config, + bundle_aggregations=bundle_aggregations, + redaction_map=redaction_map, + unsigned_additions=unsigned_additions, + )[0] + + async def _prepare_serialization( self, - event: EventBase, - time_now: int, - config: SerializeEventConfig, - bundled_aggregations: dict[str, "BundledAggregations"], - serialized_event: JsonDict, - ) -> None: - """Potentially injects bundled aggregations into the unsigned portion of the serialized event. + events: Collection[FilteredEvent], + bundle_aggregations: dict[str, "BundledAggregations"] | None, + redaction_map: Mapping[str, "EventBase"] | None = None, + ) -> tuple[dict[str, "EventBase"], dict[str, JsonDict]]: + """Perform all the async DB/IO work needed to serialize `events`. + + Does two things: + 1. 
Fetches any redaction events needed to serialize `events` (and any + bundled events) and returns a map from redaction event_id to event. + 2. Runs the module callbacks for each event to build up the additional + `unsigned` fields they contribute. Args: - event: The event being serialized. - time_now: The current time in milliseconds - config: Event serialization config - bundled_aggregations: Bundled aggregations to be injected. - A map from event_id to aggregation data. Must contain at least an - entry for `event`. + events: The events that will be serialized. + bundle_aggregations: A map from event_id to the aggregations to be + bundled into the event. Used to discover the sub-events (edits + and thread latest events) that will also be serialized. + redaction_map: An optional caller-supplied map from redaction + event_id to the redaction event. Any redactions already present + here are not re-fetched, and these entries take precedence over + anything we fetch ourselves. - While serializing the bundled aggregations this map may be searched - again for additional events in a recursive manner. - serialized_event: The serialized event which may be modified. + Returns: + A tuple of: + - a map from redaction event_id to the redaction event, + - a map from event_id to the extra `unsigned` fields contributed + by the registered module callbacks. """ - # We have already checked that aggregations exist for this event. - event_aggregations = bundled_aggregations[event.event_id] + # First we collect all events that get included in the serialization of + # `events`, including the events themselves and any bundled events (edits + # and thread latest events, which are themselves serialized). 
+ collected = {e.event.event_id: e.event for e in events} + if bundle_aggregations is not None: + for aggregation in bundle_aggregations.values(): + if aggregation.replace: + collected[aggregation.replace.event_id] = aggregation.replace + if aggregation.thread: + latest_event = aggregation.thread.latest_event + collected[latest_event.event_id] = latest_event + + # Next, check the redaction status of all events, and fetch the + # redactions if needed. + redaction_map = redaction_map or {} + + redaction_ids_to_fetch = { + redacted_by + for collected_event in collected.values() + if (redacted_by := collected_event.internal_metadata.redacted_by) + is not None + and redacted_by not in redaction_map + } + + if redaction_ids_to_fetch: + fetched_redaction_map = await self._store.get_events(redaction_ids_to_fetch) + else: + fetched_redaction_map = {} - # The JSON dictionary to be added under the unsigned property of the event - # being serialized. - serialized_aggregations = {} + # Ensure the returned redaction map includes any caller-supplied + # redactions + fetched_redaction_map.update(redaction_map) - if event_aggregations.references: - serialized_aggregations[RelationTypes.REFERENCE] = ( - event_aggregations.references - ) + # Run the module callbacks for each event (once per event_id, since + # `collected` is already de-duplicated) to build up the additional + # `unsigned` fields they contribute. + unsigned_additions: dict[str, JsonDict] = {} + if self._add_extra_fields_to_unsigned_client_event_callbacks: + for collected_event in collected.values(): + new_unsigned: JsonDict = {} + for ( + callback + ) in self._add_extra_fields_to_unsigned_client_event_callbacks: + new_unsigned.update(await callback(collected_event)) - if event_aggregations.replace: - # Include information about it in the relations dict. 
- # - # Matrix spec v1.5 (https://spec.matrix.org/v1.5/client-server-api/#server-side-aggregation-of-mreplace-relationships) - # said that we should only include the `event_id`, `origin_server_ts` and - # `sender` of the edit; however MSC3925 proposes extending it to the whole - # of the edit, which is what we do here. - serialized_aggregations[RelationTypes.REPLACE] = await self.serialize_event( - FilteredEvent(event=event_aggregations.replace, membership=None), - time_now, - config=config, - ) - - # Include any threaded replies to this event. - if event_aggregations.thread: - thread = event_aggregations.thread - - serialized_latest_event = await self.serialize_event( - FilteredEvent(event=thread.latest_event, membership=None), - time_now, - config=config, - bundle_aggregations=bundled_aggregations, - ) + if new_unsigned: + unsigned_additions[collected_event.event_id] = new_unsigned - thread_summary = { - "latest_event": serialized_latest_event, - "count": thread.count, - "current_user_participated": thread.current_user_participated, - } - serialized_aggregations[RelationTypes.THREAD] = thread_summary - - # Include the bundled aggregations in the event. - if serialized_aggregations: - # There is likely already an "unsigned" field, but a filter might - # have stripped it off (via the event_fields option). The server is - # allowed to return additional fields, so add it back. - serialized_event.setdefault("unsigned", {}).setdefault( - "m.relations", {} - ).update(serialized_aggregations) + return fetched_redaction_map, unsigned_additions @trace async def serialize_events( @@ -677,7 +338,7 @@ async def serialize_events( events: Collection[JsonDict | FilteredEvent], time_now: int, *, - config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, + config: SerializeEventConfig | None = None, bundle_aggregations: dict[str, "BundledAggregations"] | None = None, ) -> list[JsonDict]: """Serializes multiple events. 
@@ -698,26 +359,33 @@ async def serialize_events( str(len(events)), ) - # Batch-fetch all redaction events in one go rather than one per event. - redaction_ids: set[str] = set() - for e in events: - base = e.event if isinstance(e, FilteredEvent) else e - if isinstance(base, EventBase): - redacted_by = base.internal_metadata.redacted_by - if redacted_by is not None: - redaction_ids.add(redacted_by) - redaction_map = ( - await self._store.get_events(redaction_ids) if redaction_ids else {} + if config is None: + # Generate default config if none was provided. + config = await self.create_config() + + filtered_events = [e for e in events if isinstance(e, FilteredEvent)] + + # Perform all the async DB/IO work up front, then run the synchronous + # serialization core for the whole batch in one go. + redaction_map, unsigned_additions = await self._prepare_serialization( + filtered_events, bundle_aggregations ) - return [ - await self.serialize_event( - event, + serialized = iter( + serialize_events( + [(e.event, e.membership) for e in filtered_events], time_now, - config=config, + config, bundle_aggregations=bundle_aggregations, redaction_map=redaction_map, + unsigned_additions=unsigned_additions, ) + ) + + # Stitch the serialized events back in, passing through anything that + # wasn't a FilteredEvent (e.g. presence events) unchanged. 
+ return [ + event if not isinstance(event, FilteredEvent) else next(serialized) for event in events ] diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 2518716bc70..56ded1634ea 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -25,7 +25,7 @@ from synapse.api.constants import EduTypes, EventTypes, Membership, PresenceState from synapse.api.errors import AuthError, SynapseError -from synapse.events.utils import FilteredEvent, SerializeEventConfig +from synapse.events.utils import FilteredEvent from synapse.handlers.presence import format_user_presence_state from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.streams.config import PaginationConfig @@ -129,7 +129,7 @@ async def get_stream( chunks = await self._event_serializer.serialize_events( events, time_now, - config=SerializeEventConfig( + config=await self._event_serializer.create_config( as_client_event=as_client_event, requester=requester ), ) diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 591a0aefd33..56f4d86d414 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -30,7 +30,7 @@ Membership, ) from synapse.api.errors import SynapseError -from synapse.events.utils import FilteredEvent, SerializeEventConfig +from synapse.events.utils import FilteredEvent from synapse.events.validator import EventValidator from synapse.handlers.presence import format_user_presence_state from synapse.handlers.receipts import ReceiptEventSource @@ -169,7 +169,9 @@ async def _snapshot_all_rooms( public_room_ids = await self.store.get_public_room_ids() - serializer_options = SerializeEventConfig(as_client_event=as_client_event) + serializer_options = await self._event_serializer.create_config( + as_client_event=as_client_event + ) async def handle_room(event: RoomsForUser) -> None: d: JsonDict = { @@ -395,7 +397,9 @@ async def _room_initial_sync_parted( end_token = 
StreamToken.START.copy_and_replace(StreamKeyType.ROOM, stream_token) time_now = self.clock.time_msec() - serialize_options = SerializeEventConfig(requester=requester) + serialize_options = await self._event_serializer.create_config( + requester=requester + ) return { "membership": membership, @@ -436,7 +440,9 @@ async def _room_initial_sync_joined( # TODO: These concurrently time_now = self.clock.time_msec() - serialize_options = SerializeEventConfig(requester=requester) + serialize_options = await self._event_serializer.create_config( + requester=requester + ) # Don't bundle aggregations as this is a deprecated API. state = await self._event_serializer.serialize_events( [FilteredEvent.state(e) for e in current_state.values()], diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 6bdef7e2025..b34ee9d50f7 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -64,7 +64,6 @@ ) from synapse.events.utils import ( FilteredEvent, - SerializeEventConfig, maybe_upsert_event_field, ) from synapse.events.validator import EventValidator @@ -269,7 +268,7 @@ async def get_state_events( events = await self._event_serializer.serialize_events( [FilteredEvent.state(e) for e in room_state.values()], self.clock.time_msec(), - config=SerializeEventConfig(requester=requester), + config=await self._event_serializer.create_config(requester=requester), ) return events diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index ee4f8d672ee..a8db082febf 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -28,16 +28,22 @@ Sequence, ) -import attr - from synapse.api.constants import Direction, EventTypes, RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event -from synapse.events.utils import FilteredEvent, SerializeEventConfig +from synapse.events.utils import FilteredEvent from synapse.logging.context import make_deferred_yieldable, 
run_in_background from synapse.logging.opentracing import trace from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent from synapse.streams.config import PaginationConfig + +# `BundledAggregations` and `ThreadAggregation` are implemented in Rust; they +# are re-exported here so existing call sites can keep importing them from +# `synapse.handlers.relations`. +from synapse.synapse_rust.events import ( # noqa: F401 + BundledAggregations, + ThreadAggregation, +) from synapse.types import JsonDict, Requester, UserID from synapse.util.async_helpers import gather_results from synapse.visibility import filter_and_transform_events_for_client @@ -56,32 +62,6 @@ class ThreadsListInclude(str, enum.Enum): participated = "participated" -@attr.s(slots=True, frozen=True, auto_attribs=True) -class _ThreadAggregation: - # The latest event in the thread. - latest_event: EventBase - # The total number of events in the thread. - count: int - # True if the current user has sent an event to the thread. - current_user_participated: bool - - -@attr.s(slots=True, auto_attribs=True) -class BundledAggregations: - """ - The bundled aggregations for an event. - - Some values require additional processing during serialization. 
- """ - - references: JsonDict | None = None - replace: EventBase | None = None - thread: _ThreadAggregation | None = None - - def __bool__(self) -> bool: - return bool(self.references or self.replace or self.thread) - - class RelationsHandler: def __init__(self, hs: "HomeServer"): self._main_store = hs.get_datastores().main @@ -170,7 +150,9 @@ async def get_relations( ) now = self._clock.time_msec() - serialize_options = SerializeEventConfig(requester=requester) + serialize_options = await self._event_serializer.create_config( + requester=requester + ) return_value: JsonDict = { "chunk": await self._event_serializer.serialize_events( filtered_events, @@ -310,7 +292,7 @@ async def _get_threads_for_events( relations_by_id: dict[str, str], user_id: str, ignored_users: frozenset[str], - ) -> dict[str, _ThreadAggregation]: + ) -> dict[str, ThreadAggregation]: """Get the bundled aggregations for threads for the requested events. Args: @@ -421,7 +403,7 @@ async def _get_threads_for_events( continue latest_thread_event = event.event - results[event_id] = _ThreadAggregation( + results[event_id] = ThreadAggregation( latest_event=latest_thread_event, count=thread_count, # If there's a thread summary it must also exist in the @@ -478,8 +460,12 @@ async def get_bundled_aggregations( # The event should get bundled aggregations. events_by_id[event.event_id] = event - # event ID -> bundled aggregation in non-serialized form. - results: dict[str, BundledAggregations] = {} + # `BundledAggregations` is immutable, so we collect each kind of + # aggregation into its own map keyed by event ID and assemble the + # results once everything has been fetched. + thread_by_id: dict[str, ThreadAggregation] = {} + references_by_id: dict[str, JsonDict] = {} + replace_by_id: dict[str, EventBase] = {} # Fetch any ignored users of the requesting user. 
ignored_users = await self._main_store.ignored_users(user_id) @@ -495,7 +481,7 @@ async def get_bundled_aggregations( ignored_users, ) for event_id, thread in threads.items(): - results.setdefault(event_id, BundledAggregations()).thread = thread + thread_by_id[event_id] = thread # If the latest event in a thread is not already being fetched, # add it. This ensures that the bundled aggregations for the @@ -516,7 +502,7 @@ async def _fetch_references() -> None: ) for event_id, references in references_by_event_id.items(): if references: - results.setdefault(event_id, BundledAggregations()).references = { + references_by_id[event_id] = { "chunk": [{"event_id": ev.event_id} for ev in references] } @@ -535,7 +521,13 @@ async def _fetch_edits() -> None: ] ) for event_id, edit in edits.items(): - results.setdefault(event_id, BundledAggregations()).replace = edit + # `get_applicable_edits` returns `None` for events with no + # applicable edit. Skip those rather than recording an entry: a + # `None` replace contributes nothing during serialization, so + # the old code's empty `BundledAggregations` for such events was + # inert anyway. + if edit is not None: + replace_by_id[event_id] = edit # Parallelize the calls for annotations, references, and edits since they # are unrelated. @@ -548,7 +540,17 @@ async def _fetch_edits() -> None: ) ) - return results + # Assemble one (immutable) bundled aggregation per event that has any. 
+ return { + event_id: BundledAggregations( + references=references_by_id.get(event_id), + replace=replace_by_id.get(event_id), + thread=thread_by_id.get(event_id), + ) + for event_id in thread_by_id.keys() + | references_by_id.keys() + | replace_by_id.keys() + } async def get_threads( self, diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 30e072d011e..eb0492ff595 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -29,7 +29,7 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, SynapseError from synapse.api.filtering import Filter -from synapse.events.utils import FilteredEvent, SerializeEventConfig +from synapse.events.utils import FilteredEvent from synapse.types import JsonDict, Requester, StrCollection, StreamKeyType, UserID from synapse.types.state import StateFilter from synapse.visibility import filter_and_transform_events_for_client @@ -377,7 +377,9 @@ async def _search( # blocking calls after this. Otherwise, the 'age' will be wrong. 
time_now = self.clock.time_msec() - serialize_options = SerializeEventConfig(requester=requester) + serialize_options = await self._event_serializer.create_config( + requester=requester + ) for context in contexts.values(): context["events_before"] = await self._event_serializer.serialize_events( diff --git a/synapse/rest/admin/events.py b/synapse/rest/admin/events.py index 1c311b04713..7dbd7f5d2be 100644 --- a/synapse/rest/admin/events.py +++ b/synapse/rest/admin/events.py @@ -3,9 +3,8 @@ from synapse.api.errors import NotFoundError from synapse.events.utils import ( + EventFormat, FilteredEvent, - SerializeEventConfig, - format_event_raw, ) from synapse.http.servlet import RestServlet from synapse.http.site import SynapseRequest @@ -57,11 +56,11 @@ async def on_GET( if event is None: raise NotFoundError("Event not found") - config = SerializeEventConfig( + config = await self._event_serializer.create_config( as_client_event=False, - event_format=format_event_raw, + event_format=EventFormat.Raw, requester=requester, - only_event_fields=None, + event_field_allowlist=None, include_stripped_room_state=True, include_admin_metadata=True, ) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index f6693e09236..e47b6e9efee 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -1028,7 +1028,7 @@ async def on_GET( ): as_client_event = False - serialize_options = SerializeEventConfig( + serialize_options = await self._event_serializer.create_config( as_client_event=as_client_event, requester=requester ) diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py index de73c96fd0d..f5b894038ed 100644 --- a/synapse/rest/client/events.py +++ b/synapse/rest/client/events.py @@ -25,7 +25,6 @@ from typing import TYPE_CHECKING from synapse.api.errors import SynapseError -from synapse.events.utils import SerializeEventConfig from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_string 
from synapse.http.site import SynapseRequest @@ -104,7 +103,7 @@ async def on_GET( result = await self._event_serializer.serialize_event( event, self.clock.time_msec(), - config=SerializeEventConfig(requester=requester), + config=await self._event_serializer.create_config(requester=requester), ) return 200, result else: diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index f80a43b2978..ae3893d2965 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -24,9 +24,8 @@ from synapse.api.constants import ReceiptTypes from synapse.events.utils import ( + EventFormat, FilteredEvent, - SerializeEventConfig, - format_event_for_client_v2_without_room_id, ) from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_integer, parse_string @@ -98,8 +97,8 @@ async def on_GET(self, request: SynapseRequest) -> tuple[int, JsonDict]: next_token = None - serialize_options = SerializeEventConfig( - event_format=format_event_for_client_v2_without_room_id, + serialize_options = await self._event_serializer.create_config( + event_format=EventFormat.ClientV2WithoutRoomId, requester=requester, ) now = self.clock.time_msec() diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 6e00197b6de..36f638e236f 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -53,9 +53,9 @@ from synapse.api.filtering import Filter from synapse.events.utils import ( EventClientSerializer, + EventFormat, FilteredEvent, SerializeEventConfig, - format_event_for_client_v2, ) from synapse.handlers.pagination import GetMessagesResult from synapse.http.server import HttpServer @@ -289,8 +289,8 @@ async def on_GET( event = await self._event_serializer.serialize_event( FilteredEvent.state(data), self.clock.time_msec(), - config=SerializeEventConfig( - event_format=format_event_for_client_v2, + config=await self._event_serializer.create_config( + 
event_format=EventFormat.ClientV2, requester=requester, ), ) @@ -925,7 +925,7 @@ async def on_GET( ): as_client_event = False - serialize_options = SerializeEventConfig( + serialize_options = await self.event_serializer.create_config( as_client_event=as_client_event, requester=requester ) @@ -1114,7 +1114,7 @@ async def on_GET( event, self.clock.time_msec(), bundle_aggregations=aggregations, - config=SerializeEventConfig(requester=requester), + config=await self._event_serializer.create_config(requester=requester), ) return 200, event_dict @@ -1154,7 +1154,9 @@ async def on_GET( raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) time_now = self.clock.time_msec() - serializer_options = SerializeEventConfig(requester=requester) + serializer_options = await self._event_serializer.create_config( + requester=requester + ) results = { "events_before": await self._event_serializer.serialize_events( event_context.events_before, diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 702ddcd6ca1..962317dedbe 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -30,10 +30,9 @@ from synapse.api.presence import UserPresenceState from synapse.api.ratelimiting import Ratelimiter from synapse.events.utils import ( + EventFormat, FilteredEvent, SerializeEventConfig, - format_event_for_client_v2_without_room_id, - format_event_raw, ) from synapse.handlers.presence import format_user_presence_state from synapse.handlers.sliding_sync import SlidingSyncConfig, SlidingSyncResult @@ -304,18 +303,18 @@ async def encode_response( ) -> JsonDict: logger.debug("Formatting events in sync response") if filter.event_format == "client": - event_formatter = format_event_for_client_v2_without_room_id + event_formatter = EventFormat.ClientV2WithoutRoomId elif filter.event_format == "federation": - event_formatter = format_event_raw + event_formatter = EventFormat.Raw else: raise Exception("Unknown event format %s" % 
(filter.event_format,)) - serialize_options = SerializeEventConfig( + serialize_options = await self._event_serializer.create_config( event_format=event_formatter, requester=requester, - only_event_fields=filter.event_fields, + event_field_allowlist=filter.event_fields, ) - stripped_serialize_options = SerializeEventConfig( + stripped_serialize_options = await self._event_serializer.create_config( event_format=event_formatter, requester=requester, include_stripped_room_state=True, @@ -931,8 +930,8 @@ async def encode_rooms( ) -> JsonDict: time_now = self.clock.time_msec() - serialize_options = SerializeEventConfig( - event_format=format_event_for_client_v2_without_room_id, + serialize_options = await self.event_serializer.create_config( + event_format=EventFormat.ClientV2WithoutRoomId, requester=requester, ) @@ -1158,8 +1157,8 @@ async def _serialise_sticky_events( time_now = self.clock.time_msec() # Same as SSS timelines. # - serialize_options = SerializeEventConfig( - event_format=format_event_for_client_v2_without_room_id, + serialize_options = await self.event_serializer.create_config( + event_format=EventFormat.ClientV2WithoutRoomId, requester=requester, ) diff --git a/synapse/synapse_rust/events.pyi b/synapse/synapse_rust/events.pyi index f84eeb55d65..fc9173acb24 100644 --- a/synapse/synapse_rust/events.pyi +++ b/synapse/synapse_rust/events.pyi @@ -13,7 +13,7 @@ from typing import Any, Iterator, Mapping from synapse.synapse_rust.room_versions import RoomVersion -from synapse.types import JsonDict, JsonMapping, StrSequence +from synapse.types import JsonDict, JsonMapping, Requester, StrSequence from synapse.util.duration import Duration class EventInternalMetadata: @@ -309,6 +309,139 @@ class Event: ``SynapseDuration`` representing the sticky duration. 
Otherwise returns ``None``.""" +class ThreadAggregation: + """The bundled thread summary for an event.""" + + def __init__( + self, + latest_event: Event, + count: int, + current_user_participated: bool, + ) -> None: ... + @property + def latest_event(self) -> Event: + """The latest event in the thread.""" + + @property + def count(self) -> int: + """The total number of events in the thread.""" + + @property + def current_user_participated(self) -> bool: + """Whether the requesting user has sent an event to the thread.""" + +class BundledAggregations: + """The bundled aggregations for an event. + + Some values require additional processing during serialization. + """ + + def __init__( + self, + references: JsonMapping | None = None, + replace: Event | None = None, + thread: ThreadAggregation | None = None, + ) -> None: ... + @property + def references(self) -> JsonMapping | None: ... + @property + def replace(self) -> Event | None: ... + @property + def thread(self) -> ThreadAggregation | None: ... + def __bool__(self) -> bool: ... + +class EventFormat: + """The format used to convert an event to the shape sent to clients.""" + + Raw: EventFormat + ClientV1: EventFormat + ClientV2: EventFormat + ClientV2WithoutRoomId: EventFormat + +class SerializeEventConfig: + """Configuration for serializing an event for clients.""" + + def __init__( + self, + *, + as_client_event: bool, + event_format: EventFormat, + requester: Requester | None, + event_field_allowlist: list[str] | None, + include_stripped_room_state: bool, + include_admin_metadata: bool, + msc4354_enabled: bool, + ) -> None: ... + @property + def as_client_event(self) -> bool: + """Whether to apply the client event format transform (v1/v2/raw). 
When + ``False``, the federation-format event is returned as-is.""" + + @property + def event_format(self) -> EventFormat: + """Which client event format variant to apply (only used when + ``as_client_event`` is ``True``).""" + + @property + def requester(self) -> Requester | None: + """The entity requesting the event. Used to gate sender-only fields such + as ``transaction_id`` and ``delay_id``.""" + + @property + def event_field_allowlist(self) -> list[str] | None: + """If set, only include these field paths in the output. An empty list + returns an empty event; ``None`` returns all fields. + + The fields can be "dotted" fields, e.g. ``content.body``.""" + + @property + def include_stripped_room_state(self) -> bool: + """Whether to include ``invite_room_state`` / ``knock_room_state`` in + ``unsigned``. These are stripped by default and only included for + specific endpoints (e.g. ``/sync`` invite/knock handling).""" + + @property + def include_admin_metadata(self) -> bool: + """When ``True``, add server-admin-only metadata to ``unsigned`` + (``io.element.synapse.soft_failed``, + ``io.element.synapse.policy_server_spammy``).""" + + @property + def msc4354_enabled(self) -> bool: + """Whether MSC4354 (sticky events) is enabled. When ``True``, the + remaining stickiness TTL is computed and added to ``unsigned``.""" + +def serialize_events( + events: list[tuple[Event, str | None]], + time_now_ms: int, + config: SerializeEventConfig, + *, + bundle_aggregations: Mapping[str, BundledAggregations] | None = None, + redaction_map: Mapping[str, Event] | None = None, + unsigned_additions: Mapping[str, JsonDict] | None = None, +) -> list[JsonDict]: + """Synchronously serialize a batch of events for clients using pre-fetched data. + + All DB/IO must already have been done by the caller; the keyword maps below + are all keyed by event ID and shared across the whole batch. + + Args: + events: The events to serialize, as `(event, membership)` pairs. 
+ `membership` is the requesting user's membership at the time of the + event, injected into `unsigned.membership` (MSC4115). + time_now_ms: The current time in milliseconds. + config: The serialization config. + bundle_aggregations: Map from event_id to the `BundledAggregations` to + bundle into the event's `unsigned.m.relations`. + redaction_map: Map from redaction event_id to the redaction `Event`, + used to populate `unsigned.redacted_because` for redacted events. + unsigned_additions: Map from event_id to extra `unsigned` fields + contributed by module callbacks. + + Returns: + The serialized events, in the same order as `events`. + """ + def redact_event(event: Event) -> Event: """Returns a pruned version of the given event, which removes all keys we don't know about or think could potentially be dodgy. diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index 8f78ae49448..b45f7ccad0c 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -22,24 +22,18 @@ import unittest as stdlib_unittest from typing import TYPE_CHECKING, Any, Mapping -from parameterized import parameterized - from synapse.api.constants import EventContentFields from synapse.api.room_versions import RoomVersions from synapse.events import EventBase from synapse.events.utils import ( FilteredEvent, PowerLevelsContent, - SerializeEventConfig, - _split_field, clone_event, copy_and_fixup_power_levels_contents, - format_event_raw, - make_config_for_admin, maybe_upsert_event_field, prune_event, ) -from synapse.types import JsonDict, create_requester +from synapse.types import JsonDict from synapse.util.frozenutils import freeze from tests.test_utils.event_builders import make_test_event @@ -665,9 +659,11 @@ def serialize( self._event_serializer.serialize_event( FilteredEvent(event=ev, membership=None), 1479807801915, - config=SerializeEventConfig( - only_event_fields=fields, - include_admin_metadata=include_admin_metadata, + config=self.get_success( + 
self._event_serializer.create_config( + event_field_allowlist=fields, + include_admin_metadata=include_admin_metadata, + ) ), redaction_map=redaction_map, ) @@ -788,13 +784,19 @@ def test_event_fields_all_fields_if_empty(self) -> None: def test_event_fields_fail_if_fields_not_str(self) -> None: with self.assertRaises(TypeError): - SerializeEventConfig( - only_event_fields=["room_id", 4], # type: ignore[list-item] + self.get_success_or_raise( + self._event_serializer.create_config( + event_field_allowlist=["room_id", 4], # type: ignore[list-item] + ) ) def test_default_serialize_config_excludes_admin_metadata(self) -> None: # We just really don't want this to be set to True accidentally - self.assertFalse(SerializeEventConfig().include_admin_metadata) + self.assertFalse( + self.get_success( + self._event_serializer.create_config() + ).include_admin_metadata + ) def test_event_flagged_for_admins(self) -> None: # Default behaviour should be *not* to include it @@ -875,34 +877,10 @@ def test_event_flagged_for_admins(self) -> None: }, ) - def test_make_serialize_config_for_admin_retains_other_fields(self) -> None: - non_default_config = SerializeEventConfig( - include_admin_metadata=False, # should be True in a moment - as_client_event=False, # default True - event_format=format_event_raw, # default format_event_for_client_v1 - requester=create_requester("@example:example.org"), # default None - only_event_fields=["foo"], # default None - include_stripped_room_state=True, # default False - ) - admin_config = make_config_for_admin(non_default_config) - self.assertEqual( - admin_config.as_client_event, non_default_config.as_client_event - ) - self.assertEqual(admin_config.event_format, non_default_config.event_format) - self.assertEqual(admin_config.requester, non_default_config.requester) - self.assertEqual( - admin_config.only_event_fields, non_default_config.only_event_fields - ) - self.assertEqual( - admin_config.include_stripped_room_state, - 
admin_config.include_stripped_room_state, - ) - self.assertTrue(admin_config.include_admin_metadata) - def test_redacted_because_is_filtered_out(self) -> None: """If an event's unsigned dict has a `redacted_by` field, then the `redacted_because` should be filtered out if not specified in - `only_event_fields`.""" + `event_field_allowlist`.""" redaction_id = "$redaction_event_id" @@ -1019,40 +997,3 @@ def test_invalid_types_raise_type_error(self) -> None: def test_invalid_nesting_raises_type_error(self) -> None: with self.assertRaises(TypeError): copy_and_fixup_power_levels_contents({"a": {"b": {"c": 1}}}) # type: ignore[dict-item] - - -class SplitFieldTestCase(stdlib_unittest.TestCase): - @parameterized.expand( - [ - # A field with no dots. - ["m", ["m"]], - # Simple dotted fields. - ["m.foo", ["m", "foo"]], - ["m.foo.bar", ["m", "foo", "bar"]], - # Backslash is used as an escape character. - [r"m\.foo", ["m.foo"]], - [r"m\\.foo", ["m\\", "foo"]], - [r"m\\\.foo", [r"m\.foo"]], - [r"m\\\\.foo", ["m\\\\", "foo"]], - [r"m\foo", [r"m\foo"]], - [r"m\\foo", [r"m\foo"]], - [r"m\\\foo", [r"m\\foo"]], - [r"m\\\\foo", [r"m\\foo"]], - # Ensure that escapes at the end don't cause issues. - ["m.foo\\", ["m", "foo\\"]], - ["m.foo\\", ["m", "foo\\"]], - [r"m.foo\.", ["m", "foo."]], - [r"m.foo\\.", ["m", "foo\\", ""]], - [r"m.foo\\\.", ["m", r"foo\."]], - # Empty parts (corresponding to properties which are an empty string) are allowed. - [".m", ["", "m"]], - ["..m", ["", "", "m"]], - ["m.", ["m", ""]], - ["m..", ["m", "", ""]], - ["m..foo", ["m", "", "foo"]], - # Invalid escape sequences. - [r"\m", [r"\m"]], - ] - ) - def test_split_field(self, input: str, expected: str) -> None: - self.assertEqual(_split_field(input), expected)