diff --git a/serde-reflection/src/de.rs b/serde-reflection/src/de.rs index 1eb8f25da..f95f1c63c 100644 --- a/serde-reflection/src/de.rs +++ b/serde-reflection/src/de.rs @@ -442,6 +442,14 @@ impl<'de, 'a> de::Deserializer<'de> for Deserializer<'de, 'a> { { continue; } + // Skip variants that were already fully traced during serialization. + if self + .tracer + .serialized_variants + .contains(&(enum_name.to_string(), variant_name.to_string())) + { + continue; + } // Insert into known_variants with a provisional index. let provisional_index = provisional_min + i as u32; let variant = known_variants diff --git a/serde-reflection/src/trace.rs b/serde-reflection/src/trace.rs index d3917d3e4..d57bbefb5 100644 --- a/serde-reflection/src/trace.rs +++ b/serde-reflection/src/trace.rs @@ -12,7 +12,7 @@ use erased_discriminant::Discriminant; use once_cell::sync::Lazy; use serde::{de::DeserializeSeed, Deserialize, Serialize}; use std::any::TypeId; -use std::collections::BTreeMap; +use std::collections::{BTreeMap, BTreeSet}; /// A map of container formats. pub type Registry = BTreeMap; @@ -34,6 +34,11 @@ pub struct Tracer { /// Discriminant associated with each variant of each enum. pub(crate) discriminants: BTreeMap<(TypeId, VariantId<'static>), Discriminant>, + + /// Enum variants whose serialized `VariantFormat` is already complete. + /// Keyed by (enum_name, variant_name). This allows `deserialize_enum` to skip + /// re-exploring variants that do not need further tracing. + pub(crate) serialized_variants: BTreeSet<(String, String)>, } /// Type of untraced enum variants @@ -191,6 +196,7 @@ impl Tracer { registry: BTreeMap::new(), incomplete_enums: BTreeMap::new(), discriminants: BTreeMap::new(), + serialized_variants: BTreeSet::new(), } } @@ -367,6 +373,8 @@ impl Tracer { variant: VariantFormat, variant_value: Value, ) -> Result<(Format, Value)> { + let mut normalized_variant = variant.clone(); + let is_complete = normalized_variant.normalize().is_ok(); let mut variants = BTreeMap::new(); variants.insert( variant_index, @@ -377,6 +385,10 @@ impl Tracer { ); let format = ContainerFormat::Enum(variants); let value = Value::Variant(variant_index, Box::new(variant_value)); + if is_complete { + self.serialized_variants + .insert((name.to_string(), variant_name.to_string())); + } self.record_container(samples, name, format, value, false) } diff --git a/serde-reflection/tests/serde.rs b/serde-reflection/tests/serde.rs index 2aeee78e4..6b235f2fa 100644 --- a/serde-reflection/tests/serde.rs +++ b/serde-reflection/tests/serde.rs @@ -406,6 +406,113 @@ fn test_mixed_tracing_for_multiple_enums() { assert_eq!(variants.len(), 2); } +#[test] +fn test_trace_type_revisits_partially_serialized_enum_variants() { + #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)] + enum E { + A, + B(Option), + } + + let mut samples = Samples::new(); + let mut tracer = Tracer::new(TracerConfig::default()); + + tracer.trace_value(&mut samples, &E::B(None)).unwrap(); + tracer.trace_type::(&samples).unwrap(); + + let registry = tracer.registry().unwrap(); + let variants = match registry.get("E").unwrap() { + ContainerFormat::Enum(variants) => variants, + _ => panic!("should be an enum"), + }; + assert_eq!(variants.len(), 2); + assert_eq!(variants.get(&0).unwrap().name, "A"); + assert_eq!(variants.get(&1).unwrap().name, "B"); + assert_eq!( + variants.get(&1).unwrap().value, + VariantFormat::NewType(Box::new(Format::Option(Box::new(Format::U32)))) + ); +} + +#[test] +fn test_trace_type_skips_fully_serialized_bytes_variant() { + use serde::de::{self, Visitor}; + use std::fmt; + + #[derive(PartialEq, Eq, Debug, Clone)] + struct Exact16([u8; 16]); + + impl Serialize for Exact16 { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + serializer.serialize_bytes(&self.0) + } + } + + impl<'de> Deserialize<'de> for Exact16 { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct Exact16Visitor; + + impl<'de> Visitor<'de> for Exact16Visitor { + type Value = Exact16; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter.write_str("exactly 16 bytes") + } + + fn visit_bytes(self, value: &[u8]) -> std::result::Result + where + E: de::Error, + { + let bytes: [u8; 16] = value + .try_into() + .map_err(|_| E::invalid_length(value.len(), &self))?; + Ok(Exact16(bytes)) + } + + fn visit_byte_buf(self, value: Vec) -> std::result::Result + where + E: de::Error, + { + self.visit_bytes(&value) + } + } + + deserializer.deserialize_bytes(Exact16Visitor) + } + } + + #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)] + enum Message { + Noop, + Send(Exact16), + } + + let mut samples = Samples::new(); + let mut tracer = Tracer::new(TracerConfig::default()); + + tracer + .trace_value(&mut samples, &Message::Send(Exact16([7; 16]))) + .unwrap(); + tracer.trace_type::(&samples).unwrap(); + + let registry = tracer.registry().unwrap(); + let variants = match registry.get("Message").unwrap() { + ContainerFormat::Enum(variants) => variants, + _ => panic!("should be an enum"), + }; + assert_eq!(variants.len(), 2); + assert_eq!( + variants.get(&1).unwrap().value, + VariantFormat::NewType(Box::new(Format::Bytes)) + ); +} + #[test] fn test_value_recording_for_structs() { #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)]