chromium/third_party/rust/serde_json_lenient/v0_2/wrapper/visitor.rs

// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

use crate::{ContextPointer, Functions};
use serde::de::{DeserializeSeed, Deserializer, Error, MapAccess, SeqAccess, Visitor};
use std::convert::TryFrom;
use std::fmt;
use std::pin::Pin;

/// Watches to ensure recursion does not go too deep during deserialization.
///
/// The wrapped `usize` is the number of additional nesting levels that may
/// still be entered. `recurse` hands out a checker with one level fewer and
/// reports an error once the count has already reached zero.
struct RecursionDepthCheck(usize);

impl RecursionDepthCheck {
    /// Recurse a level and return an error if we've recursed too far.
    fn recurse<E: Error>(&self) -> Result<RecursionDepthCheck, E> {
        match self.0.checked_sub(1) {
            Some(recursion_limit) => Ok(RecursionDepthCheck(recursion_limit)),
            None => Err(Error::custom("recursion limit exceeded")),
        }
    }
}

/// What type of aggregate JSON type is being deserialized.
///
/// Every value seen by `ValueVisitor` is stored into one of these targets via
/// the matching `Functions` callback: `list_append_*` for `List` and
/// `dict_set_*` for `Dict`.
pub enum DeserializationTarget<'c> {
    /// Deserialize by appending to a list.
    ///
    /// `ctx` is the context handed to the `list_append_*` callbacks.
    List { ctx: Pin<&'c mut ContextPointer> },
    /// Deserialize by setting a dictionary key.
    ///
    /// `ctx` is the context handed to the `dict_set_*` callbacks and `key` is
    /// the dictionary key passed along with the value.
    Dict { ctx: Pin<&'c mut ContextPointer>, key: String },
}

/// A deserializer and visitor type that is used to visit each value in the JSON
/// input when it is deserialized.
///
/// Normally serde deserialization instantiates a new object, but this visitor
/// is designed to call back into C++ for creating the deserialized objects. To
/// achieve this we use a feature of serde called "stateful deserialization"
/// (https://docs.serde.rs/serde/de/trait.DeserializeSeed.html).
///
/// A visitor is consumed by the single value it visits; nested lists and
/// dictionaries create a fresh `ValueVisitor` for each of their children.
pub struct ValueVisitor<'c> {
    /// Callbacks used to store each visited value.
    fns: &'static Functions,
    /// Where the value visited by this instance is stored (list or dict key).
    aggregate: DeserializationTarget<'c>,
    /// Remaining nesting budget; consulted before visiting each child of a
    /// list or dictionary.
    recursion_depth_check: RecursionDepthCheck,
}

impl<'c> ValueVisitor<'c> {
    /// Creates the visitor for the top-level value of the JSON input.
    ///
    /// * `fns` - callbacks used to store every deserialized value.
    /// * `target` - the list or dictionary slot that receives the top-level
    ///   value.
    /// * `max_depth` - maximum nesting depth, counting the top level itself.
    ///
    /// A `max_depth` of 0 is handled like 1: the subtraction saturates at zero
    /// remaining levels instead of underflowing (which would panic with
    /// overflow checks enabled, or wrap to `usize::MAX` and effectively
    /// disable the limit without them).
    pub fn new(
        fns: &'static Functions,
        target: DeserializationTarget<'c>,
        max_depth: usize,
    ) -> Self {
        // The `max_depth` includes the top level of the JSON input, which is where parsing
        // starts. We subtract 1 to count the top level now, saturating so that a zero
        // limit cannot underflow.
        let remaining_depth = max_depth.saturating_sub(1);
        Self {
            fns,
            aggregate: target,
            recursion_depth_check: RecursionDepthCheck(remaining_depth),
        }
    }
}

impl<'de, 'c> Visitor<'de> for ValueVisitor<'c> {
    // The deserialized objects are built through the `Functions` callbacks, so
    // the visitor itself has nothing to return.
    type Value = ();

    /// Describes the accepted input for serde's error messages.
    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("any valid JSON")
    }

    /// Stores a 32-bit integer into the current list or dictionary slot.
    fn visit_i32<E: serde::de::Error>(self, value: i32) -> Result<Self::Value, E> {
        let Self { fns, aggregate, .. } = self;
        match aggregate {
            DeserializationTarget::Dict { ctx, key } => fns.dict_set_i32(ctx, &key, value),
            DeserializationTarget::List { ctx } => fns.list_append_i32(ctx, value),
        };
        Ok(())
    }

    /// Widens to `i32` and stores it as an integer.
    fn visit_i8<E: serde::de::Error>(self, value: i8) -> Result<Self::Value, E> {
        self.visit_i32(i32::from(value))
    }

    /// Stores a boolean into the current list or dictionary slot.
    fn visit_bool<E: serde::de::Error>(self, value: bool) -> Result<Self::Value, E> {
        let Self { fns, aggregate, .. } = self;
        match aggregate {
            DeserializationTarget::Dict { ctx, key } => fns.dict_set_bool(ctx, &key, value),
            DeserializationTarget::List { ctx } => fns.list_append_bool(ctx, value),
        };
        Ok(())
    }

    /// Stores a signed 64-bit integer, as `i32` when it fits and as a double
    /// otherwise.
    fn visit_i64<E: serde::de::Error>(self, value: i64) -> Result<Self::Value, E> {
        // Integer values that are > 32 bits large are returned as doubles instead. See
        // JSONReaderTest.LargerIntIsLossy for a related test.
        if let Ok(narrow) = i32::try_from(value) {
            self.visit_i32(narrow)
        } else {
            self.visit_f64(value as f64)
        }
    }

    /// Stores an unsigned 64-bit integer, as `i32` when it fits and as a
    /// double otherwise.
    fn visit_u64<E: serde::de::Error>(self, value: u64) -> Result<Self::Value, E> {
        // See visit_i64 comment.
        if let Ok(narrow) = i32::try_from(value) {
            self.visit_i32(narrow)
        } else {
            self.visit_f64(value as f64)
        }
    }

    /// Stores a double into the current list or dictionary slot.
    fn visit_f64<E: serde::de::Error>(self, value: f64) -> Result<Self::Value, E> {
        let Self { fns, aggregate, .. } = self;
        match aggregate {
            DeserializationTarget::Dict { ctx, key } => fns.dict_set_f64(ctx, &key, value),
            DeserializationTarget::List { ctx } => fns.list_append_f64(ctx, value),
        };
        Ok(())
    }

    /// Stores a string into the current list or dictionary slot.
    fn visit_str<E: serde::de::Error>(self, value: &str) -> Result<Self::Value, E> {
        let Self { fns, aggregate, .. } = self;
        match aggregate {
            DeserializationTarget::Dict { ctx, key } => fns.dict_set_str(ctx, &key, value),
            DeserializationTarget::List { ctx } => fns.list_append_str(ctx, value),
        };
        Ok(())
    }

    /// Borrowed strings are stored exactly like transient ones.
    fn visit_borrowed_str<E: serde::de::Error>(self, value: &'de str) -> Result<Self::Value, E> {
        self.visit_str(value)
    }

    /// Owned strings are stored exactly like borrowed ones.
    fn visit_string<E: serde::de::Error>(self, value: String) -> Result<Self::Value, E> {
        self.visit_str(value.as_str())
    }

    /// Stores a "none" value into the current list or dictionary slot.
    fn visit_none<E: serde::de::Error>(self) -> Result<Self::Value, E> {
        let Self { fns, aggregate, .. } = self;
        match aggregate {
            DeserializationTarget::Dict { ctx, key } => fns.dict_set_none(ctx, &key),
            DeserializationTarget::List { ctx } => fns.list_append_none(ctx),
        };
        Ok(())
    }

    /// The unit value is stored the same way as "none".
    fn visit_unit<E: serde::de::Error>(self) -> Result<Self::Value, E> {
        self.visit_none()
    }

    /// Creates a nested dictionary in the current slot and visits every entry
    /// of the map into it, spending one nesting level per entry visited.
    fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
    where
        M: MapAccess<'de>,
    {
        let Self { fns, aggregate, recursion_depth_check: depth } = self;
        // TODO(danakj): base::Value::Dict doesn't expose a way to reserve space, so we
        // don't bother using `access.size_hint()` here, unlike when creating a
        // list.
        let mut dict_ctx = match aggregate {
            DeserializationTarget::Dict { ctx, key } => fns.dict_set_dict(ctx, &key),
            DeserializationTarget::List { ctx } => fns.list_append_dict(ctx),
        };
        loop {
            let key = match access.next_key::<String>()? {
                Some(key) => key,
                None => return Ok(()),
            };
            // The depth check runs after the key has been read and before the
            // value is visited.
            let entry_visitor = ValueVisitor {
                fns,
                aggregate: DeserializationTarget::Dict { ctx: dict_ctx.as_mut(), key },
                recursion_depth_check: depth.recurse()?,
            };
            access.next_value_seed(entry_visitor)?;
        }
    }

    /// Creates a nested list in the current slot (passing the size hint, or 0
    /// when there is none) and visits every element of the sequence into it.
    fn visit_seq<S>(self, mut access: S) -> Result<Self::Value, S::Error>
    where
        S: SeqAccess<'de>,
    {
        let Self { fns, aggregate, recursion_depth_check: depth } = self;
        let reserve = access.size_hint().unwrap_or(0);
        let mut list_ctx = match aggregate {
            DeserializationTarget::Dict { ctx, key } => fns.dict_set_list(ctx, &key, reserve),
            DeserializationTarget::List { ctx } => fns.list_append_list(ctx, reserve),
        };
        loop {
            // The depth check runs before every request for the next element,
            // including the final request that finds the sequence exhausted.
            let element_visitor = ValueVisitor {
                fns,
                aggregate: DeserializationTarget::List { ctx: list_ctx.as_mut() },
                recursion_depth_check: depth.recurse()?,
            };
            if access.next_element_seed(element_visitor)?.is_none() {
                return Ok(());
            }
        }
    }
}

impl<'de, 'c> DeserializeSeed<'de> for ValueVisitor<'c> {
    // We call out to C++ to construct the deserialized type, so no output from
    // here.
    type Value = ();

    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
    where
        D: Deserializer<'de>,
    {
        deserializer.deserialize_any(self)
    }
}