Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions serde-reflection/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,14 @@ impl<'de, 'a> de::Deserializer<'de> for Deserializer<'de, 'a> {
{
continue;
}
// Skip variants that were already fully traced during serialization.
if self
.tracer
.serialized_variants
.contains(&(enum_name.to_string(), variant_name.to_string()))
{
continue;
}
// Insert into known_variants with a provisional index.
let provisional_index = provisional_min + i as u32;
let variant = known_variants
Expand Down
14 changes: 13 additions & 1 deletion serde-reflection/src/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use erased_discriminant::Discriminant;
use once_cell::sync::Lazy;
use serde::{de::DeserializeSeed, Deserialize, Serialize};
use std::any::TypeId;
use std::collections::BTreeMap;
use std::collections::{BTreeMap, BTreeSet};

/// A map of container formats.
pub type Registry = BTreeMap<String, ContainerFormat>;
Expand All @@ -34,6 +34,11 @@ pub struct Tracer {

/// Discriminant associated with each variant of each enum.
pub(crate) discriminants: BTreeMap<(TypeId, VariantId<'static>), Discriminant>,

/// Enum variants whose serialized `VariantFormat` is already complete.
/// Keyed by (enum_name, variant_name). This allows `deserialize_enum` to skip
/// re-exploring variants that do not need further tracing.
pub(crate) serialized_variants: BTreeSet<(String, String)>,
}

/// Type of untraced enum variants
Expand Down Expand Up @@ -191,6 +196,7 @@ impl Tracer {
registry: BTreeMap::new(),
incomplete_enums: BTreeMap::new(),
discriminants: BTreeMap::new(),
serialized_variants: BTreeSet::new(),
}
}

Expand Down Expand Up @@ -367,6 +373,8 @@ impl Tracer {
variant: VariantFormat,
variant_value: Value,
) -> Result<(Format, Value)> {
let mut normalized_variant = variant.clone();
let is_complete = normalized_variant.normalize().is_ok();
let mut variants = BTreeMap::new();
variants.insert(
variant_index,
Expand All @@ -377,6 +385,10 @@ impl Tracer {
);
let format = ContainerFormat::Enum(variants);
let value = Value::Variant(variant_index, Box::new(variant_value));
if is_complete {
self.serialized_variants
.insert((name.to_string(), variant_name.to_string()));
}
self.record_container(samples, name, format, value, false)
}

Expand Down
107 changes: 107 additions & 0 deletions serde-reflection/tests/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,113 @@ fn test_mixed_tracing_for_multiple_enums() {
assert_eq!(variants.len(), 2);
}

#[test]
fn test_trace_type_revisits_partially_serialized_enum_variants() {
#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)]
enum E {
A,
B(Option<u32>),
}

let mut samples = Samples::new();
let mut tracer = Tracer::new(TracerConfig::default());

tracer.trace_value(&mut samples, &E::B(None)).unwrap();
tracer.trace_type::<E>(&samples).unwrap();

let registry = tracer.registry().unwrap();
let variants = match registry.get("E").unwrap() {
ContainerFormat::Enum(variants) => variants,
_ => panic!("should be an enum"),
};
assert_eq!(variants.len(), 2);
assert_eq!(variants.get(&0).unwrap().name, "A");
assert_eq!(variants.get(&1).unwrap().name, "B");
assert_eq!(
variants.get(&1).unwrap().value,
VariantFormat::NewType(Box::new(Format::Option(Box::new(Format::U32))))
);
}

#[test]
fn test_trace_type_skips_fully_serialized_bytes_variant() {
use serde::de::{self, Visitor};
use std::fmt;

#[derive(PartialEq, Eq, Debug, Clone)]
struct Exact16([u8; 16]);

impl Serialize for Exact16 {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_bytes(&self.0)
}
}

impl<'de> Deserialize<'de> for Exact16 {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct Exact16Visitor;

impl<'de> Visitor<'de> for Exact16Visitor {
type Value = Exact16;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str("exactly 16 bytes")
}

fn visit_bytes<E>(self, value: &[u8]) -> std::result::Result<Self::Value, E>
where
E: de::Error,
{
let bytes: [u8; 16] = value
.try_into()
.map_err(|_| E::invalid_length(value.len(), &self))?;
Ok(Exact16(bytes))
}

fn visit_byte_buf<E>(self, value: Vec<u8>) -> std::result::Result<Self::Value, E>
where
E: de::Error,
{
self.visit_bytes(&value)
}
}

deserializer.deserialize_bytes(Exact16Visitor)
}
}

#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)]
enum Message {
Noop,
Send(Exact16),
}

let mut samples = Samples::new();
let mut tracer = Tracer::new(TracerConfig::default());

tracer
.trace_value(&mut samples, &Message::Send(Exact16([7; 16])))
.unwrap();
tracer.trace_type::<Message>(&samples).unwrap();

let registry = tracer.registry().unwrap();
let variants = match registry.get("Message").unwrap() {
ContainerFormat::Enum(variants) => variants,
_ => panic!("should be an enum"),
};
assert_eq!(variants.len(), 2);
assert_eq!(
variants.get(&1).unwrap().value,
VariantFormat::NewType(Box::new(Format::Bytes))
);
}

#[test]
fn test_value_recording_for_structs() {
#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)]
Expand Down
Loading