chromium/third_party/rust/chromium_crates_io/vendor/rstest_macros-0.17.0/src/utils.rs

/// Contains assorted helper functions used across other modules
///
use quote::format_ident;
use std::collections::{HashMap, HashSet};

use crate::refident::MaybeIdent;
use syn::{Attribute, Expr, FnArg, Generics, Ident, ItemFn, ReturnType, Type, WherePredicate};

/// Iterate over the identifiers of the arguments declared by `test`.
///
/// Arguments for which `MaybeIdent::maybe_ident` yields `None` are skipped.
pub(crate) fn fn_args_idents(test: &ItemFn) -> impl Iterator<Item = &Ident> {
    test.sig.inputs.iter().filter_map(MaybeIdent::maybe_ident)
}

/// Tell whether one of the arguments of `fn_decl` is named `ident`.
///
pub(crate) fn fn_args_has_ident(fn_decl: &ItemFn, ident: &Ident) -> bool {
    for candidate in fn_args_idents(fn_decl) {
        if candidate == ident {
            return true;
        }
    }
    false
}

/// Iterate over every argument declared in the signature of `item_fn`.
///
pub(crate) fn fn_args(item_fn: &ItemFn) -> impl Iterator<Item = &FnArg> {
    let declared = &item_fn.sig.inputs;
    declared.iter()
}

/// Return `true` when the last segment of the path of `attr` equals `segment`.
///
/// An attribute whose path has no segments never matches.
pub(crate) fn attr_ends_with(attr: &Attribute, segment: &syn::PathSegment) -> bool {
    // `Punctuated::last` reads the final element directly instead of
    // walking the whole segment iterator to reach it.
    attr.path.segments.last() == Some(segment)
}

/// Return `true` when the first segment of the path of `attr` equals `segment`.
///
/// An attribute whose path has no segments never matches.
pub(crate) fn attr_starts_with(attr: &Attribute, segment: &syn::PathSegment) -> bool {
    attr.path
        .segments
        .first()
        .map_or(false, |head| head == segment)
}

/// Return `true` when the path of `attr` is the single identifier `name`.
///
/// `name` is turned into an identifier with `format_ident!` before the
/// comparison.
pub(crate) fn attr_is(attr: &Attribute, name: &str) -> bool {
    let expected = format_ident!("{}", name);
    attr.path.is_ident(&expected)
}

/// Return `true` when the path of `attr` is the single identifier given by
/// at least one entry of `names`.
///
/// Entries are checked in order and the search stops at the first match.
pub(crate) fn attr_in(attr: &Attribute, names: &[&str]) -> bool {
    for name in names {
        if attr.path.is_ident(&format_ident!("{}", name)) {
            return true;
        }
    }
    false
}

/// Check whether an expression is a string literal.
pub(crate) trait IsLiteralExpression {
    /// Return `true` only when the expression is a string literal
    /// (`syn::Lit::Str`); every other expression, including other kinds of
    /// literal, yields `false`.
    fn is_literal(&self) -> bool;
}

/// Blanket implementation for every type that can be borrowed as an `Expr`.
impl<E: AsRef<Expr>> IsLiteralExpression for E {
    fn is_literal(&self) -> bool {
        let expr: &Expr = self.as_ref();
        match expr {
            Expr::Lit(literal) => matches!(literal.lit, syn::Lit::Str(_)),
            _ => false,
        }
    }
}

// Recursive worker for `is_used`: follow `references` starting from `id`
// until an ident contained in `ends` is reached.
//
// - `visited`: idents already explored; an ident found here is not explored
//   again, which keeps the search finite when references form a cycle.
// - `references`: for each ident, the set of idents to follow next.
// - `ends`: the idents that terminate the search successfully.
//
// Returns `true` as soon as `id`, or anything reachable from it, is in `ends`.
fn _is_used(
    visited: &mut HashSet<Ident>,
    id: &Ident,
    references: &HashMap<Ident, HashSet<Ident>>,
    ends: &HashSet<Ident>,
) -> bool {
    // `insert` returns `false` when the ident was already in the set, so a
    // single call both tests and records the visit.
    if !visited.insert(id.clone()) {
        return false;
    }
    if ends.contains(id) {
        return true;
    }
    match references.get(id) {
        Some(referrers) => referrers
            .iter()
            .any(|referrer| _is_used(visited, referrer, references, ends)),
        None => false,
    }
}

// Tell whether `id` is in `ends` or reaches, by following `references`
// recursively, an ident that is.
fn is_used(id: &Ident, references: &HashMap<Ident, HashSet<Ident>>, ends: &HashSet<Ident>) -> bool {
    let mut seen = HashSet::new();
    _is_used(&mut seen, id, references, ends)
}

// The ident of a where predicate is the ident of what is being bounded:
// the leading ident of the bounded type (when `first_type_path_segment_ident`
// finds one) or the name of the bounded lifetime. Equality predicates have none.
impl MaybeIdent for syn::WherePredicate {
    fn maybe_ident(&self) -> Option<&Ident> {
        match self {
            WherePredicate::Type(type_predicate) => {
                first_type_path_segment_ident(&type_predicate.bounded_ty)
            }
            WherePredicate::Lifetime(lifetime_predicate) => {
                Some(&lifetime_predicate.lifetime.ident)
            }
            WherePredicate::Eq(_) => None,
        }
    }
}

// Accumulates the idents found while visiting parts of the syntax tree
// through its `syn::visit::Visit` implementation.
#[derive(Default)]
struct SearchSimpleTypeName(HashSet<Ident>);

impl SearchSimpleTypeName {
    // Consume the collector and hand back the collected idents.
    fn take(self) -> HashSet<Ident> {
        self.0
    }

    // Visit every argument yielded by `inputs`.
    fn visit_inputs<'a>(&mut self, inputs: impl Iterator<Item = &'a FnArg>) {
        use syn::visit::Visit;
        for fn_arg in inputs {
            self.visit_fn_arg(fn_arg);
        }
    }

    // Visit a function return type.
    fn visit_output(&mut self, output: &ReturnType) {
        use syn::visit::Visit;
        self.visit_return_type(output);
    }

    // Build a fresh collector holding what is found in a single type parameter.
    fn collect_from_type_param(tp: &syn::TypeParam) -> Self {
        use syn::visit::Visit;
        let mut collector = Self::default();
        collector.visit_type_param(tp);
        collector
    }

    // Build a fresh collector holding what is found in a single where predicate.
    fn collect_from_where_predicate(wp: &syn::WherePredicate) -> Self {
        use syn::visit::Visit;
        let mut collector = Self::default();
        collector.visit_where_predicate(wp);
        collector
    }
}

// Records every path made of a single ident and every lifetime name, then
// delegates to the default traversal so nested nodes are visited as well.
impl<'ast> syn::visit::Visit<'ast> for SearchSimpleTypeName {
    fn visit_path(&mut self, path: &'ast syn::Path) {
        // Extending with an `Option` inserts the ident only when there is one.
        self.0.extend(path.get_ident().cloned());
        syn::visit::visit_path(self, path);
    }

    fn visit_lifetime(&mut self, lifetime: &'ast syn::Lifetime) {
        let name = lifetime.ident.clone();
        self.0.insert(name);
        syn::visit::visit_lifetime(self, lifetime);
    }
}

// Take the generics definitions and the where clauses and return a map
// from simple types (lifetime names or types made of just a name) to the
// set of all simple types that use them in some constraint.
fn extract_references_map(generics: &Generics) -> HashMap<Ident, HashSet<Ident>> {
    let mut references: HashMap<Ident, HashSet<Ident>> = HashMap::new();

    // Each ident found inside a type parameter is referenced by that parameter.
    for tp in generics.type_params() {
        for id in SearchSimpleTypeName::collect_from_type_param(tp).take() {
            references.entry(id).or_default().insert(tp.ident.clone());
        }
    }

    // Each ident found inside a where predicate is referenced by the ident
    // the predicate is about; predicates without such an ident are ignored.
    let predicates = generics
        .where_clause
        .iter()
        .flat_map(|wc| wc.predicates.iter());
    for wp in predicates {
        if let Some(owner) = wp.maybe_ident() {
            for id in SearchSimpleTypeName::collect_from_where_predicate(wp).take() {
                references.entry(id).or_default().insert(owner.clone());
            }
        }
    }

    references
}

// Return a hash set that contains all types and lifetimes referenced
// in input/output expressed by a single ident.
fn references_ident_types<'a>(
    generics: &Generics,
    inputs: impl Iterator<Item = &'a FnArg>,
    output: &ReturnType,
) -> HashSet<Ident> {
    let mut used: SearchSimpleTypeName = Default::default();
    used.visit_output(output);
    used.visit_inputs(inputs);
    let references = extract_references_map(generics);
    let mut used = used.take();
    let input_output = used.clone();
    // Extend the input output collected ref with the transitive ones:
    used.extend(
        generics
            .params
            .iter()
            .filter_map(MaybeIdent::maybe_ident)
            .filter(|&id| is_used(id, &references, &input_output))
            .cloned(),
    );
    used
}

// Return `wc` keeping only the predicates whose ident (see
// `MaybeIdent::maybe_ident`) is contained in `valids`. Predicates without
// an ident are dropped.
fn filtered_predicates(wc: syn::WhereClause, valids: &HashSet<Ident>) -> syn::WhereClause {
    // The clause is owned, so its predicates can be filtered by value
    // without cloning the whole list first.
    let syn::WhereClause {
        where_token,
        predicates,
    } = wc;
    syn::WhereClause {
        where_token,
        predicates: predicates
            .into_iter()
            .filter(|wp| wp.maybe_ident().map_or(false, |id| valids.contains(id)))
            .collect(),
    }
}

// Lazily keep only the generic parameters whose ident is contained in
// `valids`. Parameters without an ident are dropped.
fn filtered_generics<'a>(
    params: impl Iterator<Item = syn::GenericParam> + 'a,
    valids: &'a HashSet<Ident>,
) -> impl Iterator<Item = syn::GenericParam> + 'a {
    params.filter(move |param| {
        param
            .maybe_ident()
            .map_or(false, |id| valids.contains(id))
    })
}

//noinspection RsTypeCheck
pub(crate) fn generics_clean_up<'a>(
    original: &Generics,
    inputs: impl Iterator<Item = &'a FnArg>,
    output: &ReturnType,
) -> syn::Generics {
    let used = references_ident_types(original, inputs, output);
    let mut result: Generics = original.clone();
    result.params = filtered_generics(result.params.into_iter(), &used).collect();
    result.where_clause = result.where_clause.map(|wc| filtered_predicates(wc, &used));
    result
}

// Return the ident of the first path segment of `t`, provided that `t` is a
// path type with no qualified self (`<T as Trait>::...`), no leading `::`
// and no arguments on that first segment. Every other type yields `None`.
// For a type that is just an ident the result is that ident; for a path with
// more segments it is the base type, which is useful to find the type an
// associated type refers to.
fn first_type_path_segment_ident(t: &Type) -> Option<&Ident> {
    let type_path = match t {
        Type::Path(tp) => tp,
        _ => return None,
    };
    if type_path.qself.is_some() || type_path.path.leading_colon.is_some() {
        return None;
    }
    let head = type_path.path.segments.first()?;
    match head.arguments {
        syn::PathArguments::None => Some(&head.ident),
        _ => None,
    }
}

/// Return the `mut` token of `arg` when it is a typed argument whose pattern
/// is a plain identifier declared mutable.
///
/// Receivers, other kinds of pattern and immutable bindings yield `None`.
pub(crate) fn fn_arg_mutability(arg: &FnArg) -> Option<syn::token::Mut> {
    if let FnArg::Typed(typed) = arg {
        if let syn::Pat::Ident(binding) = typed.pat.as_ref() {
            return binding.mutability;
        }
    }
    None
}

// Unit tests for the argument helpers and for the generics clean up logic.
#[cfg(test)]
mod test {
    use syn::parse_quote;

    use super::*;
    use crate::test::{assert_eq, *};

    // `fn_args_idents` yields the argument names in declaration order.
    #[test]
    fn fn_args_idents_should() {
        let item_fn = parse_quote! {
            fn the_functon(first: u32, second: u32) {}
        };

        let mut args = fn_args_idents(&item_fn);

        assert_eq!("first", args.next().unwrap().to_string());
        assert_eq!("second", args.next().unwrap().to_string());
    }

    // `fn_args_has_ident` finds a declared argument and rejects an unknown name.
    #[test]
    fn fn_args_has_ident_should() {
        let item_fn = parse_quote! {
            fn the_functon(first: u32, second: u32) {}
        };

        assert!(fn_args_has_ident(&item_fn, &ident("first")));
        assert!(!fn_args_has_ident(&item_fn, &ident("third")));
    }

    // Each case pairs a function declaration with the exact set of idents
    // (types, consts and lifetime names without the leading `'`) that
    // `references_ident_types` must report for it.
    // NOTE(review): the second-to-last `transitive_lifetime` case repeats the
    // first `transitive_lifetime` case verbatim.
    #[rstest]
    #[case::base("fn foo<A, B, C>(a: A) -> B {}", &["A", "B"])]
    #[case::use_const_in_array("fn foo<A, const B: usize, C>(a: A) -> [u32; B] {}", &["A", "B", "u32"])]
    #[case::in_type_args("fn foo<A, const B: usize, C>(a: A) -> SomeType<B> {}", &["A", "B"])]
    #[case::in_type_args("fn foo<A, const B: usize, C>(a: SomeType<A>, b: SomeType<B>) {}", &["A", "B"])]
    #[case::pointers("fn foo<A, B, C>(a: *const A, b: &B) {}", &["A", "B"])]
    #[case::lifetime("fn foo<'a, A, B, C>(a: A, b: &'a B) {}", &["a", "A", "B"])]
    #[case::transitive_lifetime("fn foo<'a, A, B, C>(a: A, b: B) where B: Iterator<Item=A> + 'a {}", &["a", "A", "B"])]
    #[case::associated("fn foo<'a, A:Copy, C>(b: impl Iterator<Item=A> + 'a) {}", &["a", "A"])]
    #[case::transitive_in_defs("fn foo<A:Copy, B: Iterator<Item=A>>(b: B) {}", &["A", "B"])]
    #[case::transitive_in_where("fn foo<A:Copy, B>(b: B) where B: Iterator<Item=A> {}", &["A", "B"])]
    #[case::transitive_const("fn foo<const A: usize, B, C>(b: B) where B: Some<A> {}", &["A", "B"])]
    #[case::transitive_lifetime("fn foo<'a, A, B, C>(a: A, b: B) where B: Iterator<Item=A> + 'a {}", &["a", "A", "B"])]
    #[case::transitive_lifetime(r#"fn foo<'a, 'b, 'c, 'd, A, B, C>
        (a: A, b: B) 
        where B: Iterator<Item=A> + 'c, 
        'c: 'a + 'b {}"#, &["a", "b", "c", "A", "B"])]
    fn references_ident_types_should(#[case] f: &str, #[case] expected: &[&str]) {
        let f: ItemFn = f.ast();
        let used = references_ident_types(&f.sig.generics, f.sig.inputs.iter(), &f.sig.output);

        let expected = to_idents!(expected)
            .into_iter()
            .collect::<std::collections::HashSet<_>>();

        assert_eq!(expected, used);
    }

    // Each case pairs a function declaration with the declaration whose
    // generics are expected once `generics_clean_up` has run on the first one.
    #[rstest]
    #[case::remove_not_in_output(
        r#"fn test<R: AsRef<str>, B, F, H: Iterator<Item=u32>>() -> (H, B, String, &str)
                        where F: ToString,
                        B: Borrow<u32>
                        {}"#,
        r#"fn test<B, H: Iterator<Item=u32>>() -> (H, B, String, &str)
                        where B: Borrow<u32>
                {}"#
    )]
    #[case::not_remove_used_in_arguments(
        r#"fn test<R: AsRef<str>, B, F, H: Iterator<Item=u32>>
                    (h: H, it: impl Iterator<Item=R>, j: &[B])
                    where F: ToString,
                    B: Borrow<u32>
                {}"#,
        r#"fn test<R: AsRef<str>, B, H: Iterator<Item=u32>>
                    (h: H, it: impl Iterator<Item=R>, j: &[B])
                    where
                    B: Borrow<u32>
                {}"#
    )]
    #[case::dont_remove_transitive(
        r#"fn test<A, B, C, D, const F: usize, O>(a: A) where 
            B: AsRef<C>,
            A: Iterator<Item=[B; F]>,
            D: ArsRef<O> {}"#,
        r#"fn test<A, B, C, const F: usize>(a: A) where 
            B: AsRef<C>,
            A: Iterator<Item=[B; F]> {}"#
    )]
    #[case::remove_unused_lifetime(
        "fn test<'a, 'b, 'c, 'd, 'e, 'f, 'g, A>(a: &'a uint32, b: impl AsRef<A> + 'b) where 'b: 'c + 'd, A: Copy + 'e, 'f: 'g {}",
        "fn test<'a, 'b, 'c, 'd, 'e, A>(a: &'a uint32, b: impl AsRef<A> + 'b) where 'b: 'c + 'd, A: Copy + 'e {}"
    )]
    #[case::remove_unused_const(
        r#"fn test<const A: usize, const B: usize, const C: usize, const D: usize, T, O>
            (a: [u32; A], b: SomeType<B>, c: T) where 
            T: Iterator<Item=[i32; C]>,
            O: AsRef<D> 
            {}"#,
        r#"fn test<const A: usize, const B: usize, const C: usize, T>
            (a: [u32; A], b: SomeType<B>, c: T) where 
            T: Iterator<Item=[i32; C]>
            {}"#
    )]
    fn generics_cleaner(#[case] code: &str, #[case] expected: &str) {
        // Should remove all generics parameters that are not present in output
        let item_fn: ItemFn = code.ast();

        let expected: ItemFn = expected.ast();

        let cleaned = generics_clean_up(
            &item_fn.sig.generics,
            item_fn.sig.inputs.iter(),
            &item_fn.sig.output,
        );

        // Only the generics (parameters and where clause) are compared.
        assert_eq!(expected.sig.generics, cleaned);
    }
}