chromium/third_party/rust/chromium_crates_io/vendor/rstest_macros-0.17.0/src/parse/future.rs

use quote::{format_ident, ToTokens};
use syn::{visit_mut::VisitMut, FnArg, Ident, ItemFn, PatType, Type};

use crate::{error::ErrorsVec, refident::MaybeType, utils::attr_is};

use super::{arguments::FutureArg, extract_argument_attrs};

pub(crate) fn extract_futures(
    item_fn: &mut ItemFn,
) -> Result<(Vec<(Ident, FutureArg)>, bool), ErrorsVec> {
    let mut extractor = FutureFunctionExtractor::default();
    extractor.visit_item_fn_mut(item_fn);
    extractor.take()
}

/// Access to a function argument's type, restricted to the kinds of types that
/// `#[future]` accepts for generating an `impl Future`.
pub(crate) trait MaybeFutureImplType {
    /// Borrows the argument's type, or returns `None` when the argument has no
    /// type usable with `#[future]`.
    fn as_future_impl_type(&self) -> Option<&Type>;

    /// Mutable counterpart of [`MaybeFutureImplType::as_future_impl_type`].
    fn as_mut_future_impl_type(&mut self) -> Option<&mut Type>;
}

/// Receivers (`self`, `&self`, ...) never yield a type; typed arguments yield
/// their type only when `can_impl_future` accepts it.
impl MaybeFutureImplType for FnArg {
    fn as_future_impl_type(&self) -> Option<&Type> {
        if let FnArg::Typed(PatType { ty, .. }) = self {
            let inner: &Type = ty.as_ref();
            if can_impl_future(inner) {
                return Some(inner);
            }
        }
        None
    }

    fn as_mut_future_impl_type(&mut self) -> Option<&mut Type> {
        if let FnArg::Typed(PatType { ty, .. }) = self {
            let inner: &mut Type = ty.as_mut();
            if can_impl_future(inner) {
                return Some(inner);
            }
        }
        None
    }
}

fn can_impl_future(ty: &Type) -> bool {
    use Type::*;
    !matches!(
        ty,
        Group(_)
            | ImplTrait(_)
            | Infer(_)
            | Macro(_)
            | Never(_)
            | Slice(_)
            | TraitObject(_)
            | Verbatim(_)
    )
}

/// Visitor state used while walking a test function: it removes the
/// `#[awt]` / `#[future]` attributes and records what they requested.
#[derive(Default)]
struct FutureFunctionExtractor {
    /// `true` when the function carried exactly one `#[awt]` attribute.
    awt: bool,
    /// Arguments marked with `#[future]`, paired with the requested kind.
    futures: Vec<(Ident, FutureArg)>,
    /// Every problem met while visiting; handed back all together by `take`.
    errors: Vec<syn::Error>,
}

impl FutureFunctionExtractor {
    /// Consumes the visitor and turns its state into a result.
    ///
    /// If any error was recorded, all of them are returned; otherwise the
    /// collected future arguments and the `#[awt]` flag are returned.
    pub(crate) fn take(self) -> Result<(Vec<(Ident, FutureArg)>, bool), ErrorsVec> {
        let Self {
            futures,
            awt,
            errors,
        } = self;
        if !errors.is_empty() {
            return Err(errors.into());
        }
        Ok((futures, awt))
    }
}

impl VisitMut for FutureFunctionExtractor {
    fn visit_item_fn_mut(&mut self, node: &mut ItemFn) {
        let attrs = std::mem::take(&mut node.attrs);
        let (awts, remain): (Vec<_>, Vec<_>) = attrs.into_iter().partition(|a| attr_is(a, "awt"));
        self.awt = match awts.len().cmp(&1) {
            std::cmp::Ordering::Equal => true,
            std::cmp::Ordering::Greater => {
                self.errors.extend(awts.into_iter().skip(1).map(|a| {
                    syn::Error::new_spanned(
                        a.into_token_stream(),
                        "Cannot use #[awt] more than once.".to_owned(),
                    )
                }));
                false
            }
            std::cmp::Ordering::Less => false,
        };
        node.attrs = remain;
        syn::visit_mut::visit_item_fn_mut(self, node);
    }

    fn visit_fn_arg_mut(&mut self, node: &mut FnArg) {
        if matches!(node, FnArg::Receiver(_)) {
            return;
        }
        match extract_argument_attrs(
            node,
            |a| attr_is(a, "future"),
            |arg, name| {
                let kind = if arg.tokens.is_empty() {
                    FutureArg::Define
                } else {
                    match arg.parse_args::<Option<Ident>>()? {
                        Some(awt) if awt == format_ident!("awt") => FutureArg::Await,
                        None => FutureArg::Define,
                        Some(invalid) => {
                            return Err(syn::Error::new_spanned(
                                arg.parse_args::<Option<Ident>>()?.into_token_stream(),
                                format!("Invalid '{invalid}' #[future(...)] arg."),
                            ));
                        }
                    }
                };
                Ok((arg, name.clone(), kind))
            },
        )
        .collect::<Result<Vec<_>, _>>()
        {
            Ok(futures) => match futures.len().cmp(&1) {
                std::cmp::Ordering::Equal => match node.as_future_impl_type() {
                    Some(_) => self.futures.push((futures[0].1.clone(), futures[0].2)),
                    None => self.errors.push(syn::Error::new_spanned(
                        node.maybe_type().unwrap().into_token_stream(),
                        "This type cannot used to generate impl Future.".to_owned(),
                    )),
                },
                std::cmp::Ordering::Greater => {
                    self.errors
                        .extend(futures.iter().skip(1).map(|(attr, _ident, _type)| {
                            syn::Error::new_spanned(
                                attr.into_token_stream(),
                                "Cannot use #[future] more than once.".to_owned(),
                            )
                        }));
                }
                std::cmp::Ordering::Less => {}
            },
            Err(e) => {
                self.errors.push(e);
            }
        };
    }
}

#[cfg(test)]
mod should {
    //! Unit tests for `extract_futures`: attribute removal, collected future
    //! arguments, the `#[awt]` flag and the reported errors.
    use super::*;
    use crate::test::{assert_eq, *};
    use rstest_test::assert_in;

    #[rstest]
    #[case("fn simple(a: u32) {}")]
    #[case("fn more(a: u32, b: &str) {}")]
    #[case("fn gen<S: AsRef<str>>(a: u32, b: S) {}")]
    #[case("fn attr(#[case] a: u32, #[values(1,2)] b: i32) {}")]
    fn not_change_anything_if_no_future_attribute_found(#[case] item_fn: &str) {
        let mut item_fn: ItemFn = item_fn.ast();
        let orig = item_fn.clone();

        let (futures, awt) = extract_futures(&mut item_fn).unwrap();

        assert_eq!(orig, item_fn);
        assert!(futures.is_empty());
        assert!(!awt);
    }

    #[rstest]
    #[case::simple("fn f(#[future] a: u32) {}", "fn f(a: u32) {}", &[("a", FutureArg::Define)], false)]
    #[case::global_awt("#[awt] fn f(a: u32) {}", "fn f(a: u32) {}", &[], true)]
    #[case::simple_awaited("fn f(#[future(awt)] a: u32) {}", "fn f(a: u32) {}", &[("a", FutureArg::Await)], false)]
    #[case::simple_awaited_and_global("#[awt] fn f(#[future(awt)] a: u32) {}", "fn f(a: u32) {}", &[("a", FutureArg::Await)], true)]
    #[case::more_than_one(
        "fn f(#[future] a: u32, #[future(awt)] b: String, #[future()] c: std::collections::HashMap<usize, String>) {}",
        r#"fn f(a: u32,
                b: String,
                c: std::collections::HashMap<usize, String>) {}"#,
        &[("a", FutureArg::Define), ("b", FutureArg::Await), ("c", FutureArg::Define)],
        false,
    )]
    #[case::just_one(
        "fn f(a: u32, #[future] b: String) {}",
        r#"fn f(a: u32, b: String) {}"#,
        &[("b", FutureArg::Define)],
        false,
    )]
    #[case::just_one_awaited(
        "fn f(a: u32, #[future(awt)] b: String) {}",
        r#"fn f(a: u32, b: String) {}"#,
        &[("b", FutureArg::Await)],
        false,
    )]
    fn extract(
        #[case] item_fn: &str,
        #[case] expected: &str,
        #[case] expected_futures: &[(&str, FutureArg)],
        #[case] expected_awt: bool,
    ) {
        let mut item_fn: ItemFn = item_fn.ast();
        let expected: ItemFn = expected.ast();

        let (futures, awt) = extract_futures(&mut item_fn).unwrap();

        assert_eq!(expected, item_fn);
        assert_eq!(
            futures,
            expected_futures
                .iter()
                .map(|(id, a)| (ident(id), *a))
                .collect::<Vec<_>>()
        );
        assert_eq!(expected_awt, awt);
    }

    #[rstest]
    #[case::base(r#"#[awt] fn f(a: u32) {}"#, r#"fn f(a: u32) {}"#)]
    #[case::two(
        r#"
        #[awt]
        #[awt]
        fn f(a: u32) {}
        "#,
        r#"fn f(a: u32) {}"#
    )]
    #[case::inner(
        r#"
        #[one]
        #[awt]
        #[two]
        fn f(a: u32) {}
        "#,
        r#"
        #[one]
        #[two]
        fn f(a: u32) {}
        "#
    )]
    fn remove_all_awt_attributes(#[case] item_fn: &str, #[case] expected: &str) {
        let mut item_fn: ItemFn = item_fn.ast();
        let expected: ItemFn = expected.ast();

        // Attributes must be stripped even when extraction reports errors.
        let _ = extract_futures(&mut item_fn);

        assert_eq!(item_fn, expected);
    }

    #[rstest]
    #[case::no_more_than_one("fn f(#[future] #[future] a: u32) {}", "more than once")]
    #[case::no_impl("fn f(#[future] a: impl AsRef<str>) {}", "generate impl Future")]
    #[case::no_slice("fn f(#[future] a: [i32]) {}", "generate impl Future")]
    #[case::invalid_arg("fn f(#[future(other)] a: [i32]) {}", "Invalid 'other'")]
    #[case::invalid_arg_on_valid_type("fn f(#[future(other)] a: u32) {}", "Invalid 'other'")]
    #[case::no_more_than_one_awt("#[awt] #[awt] fn f(a: u32) {}", "more than once")]
    fn raise_error(#[case] item_fn: &str, #[case] message: &str) {
        let mut item_fn: ItemFn = item_fn.ast();

        let err = extract_futures(&mut item_fn).unwrap_err();

        assert_in!(format!("{:?}", err), message);
    }
}