#![allow(unused_imports)]
use std::{cmp, convert::TryFrom};
use proc_macro2::{Ident, Span, TokenStream, TokenTree};
use quote::{quote, quote_spanned, ToTokens};
use syn::{
parse::{Parse, ParseStream, Parser},
punctuated::Punctuated,
spanned::Spanned,
Result, *,
};
/// Early-returns an `Err(Error)` from the enclosing function or closure.
///
/// * `bail!(msg)` builds the error with `Error::new` at `Span::call_site()`.
///   `msg` is sliced with `[..]`, so it must be something like a `&str` or a
///   `String`.
/// * `bail!(msg => tokens)` builds the error with `Error::new_spanned`, so the
///   error is attributed to `tokens`.
///
/// Both forms accept an optional trailing comma.
macro_rules! bail {
($msg:expr $(,)?) => {
return Err(Error::new(Span::call_site(), &$msg[..]))
};
( $msg:expr => $span_to_blame:expr $(,)? ) => {
return Err(Error::new_spanned(&$span_to_blame, $msg))
};
}
/// Hooks that describe how a single trait is derived.
///
/// Each derivable trait in this file is represented by a unit struct that
/// implements `Derivable`. Only `ident` is mandatory; every other hook has a
/// default that produces nothing / accepts everything.
pub trait Derivable {
/// Path of the trait being implemented, rooted at `crate_name`.
fn ident(input: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path>;
/// Optional additional trait associated with this derive (defaults to
/// `None`).
// NOTE(review): how the caller uses the returned path is not visible in this
// file; presumably it emits an extra impl — confirm against the call site.
fn implies_trait(_crate_name: &TokenStream) -> Option<TokenStream> {
None
}
/// Tokens for compile-time assertions that must hold for the derive to be
/// accepted (defaults to an empty token stream).
fn asserts(
_input: &DeriveInput, _crate_name: &TokenStream,
) -> Result<TokenStream> {
Ok(quote!())
}
/// Validates the attributes on the input item (the implementations in this
/// file inspect `#[repr(...)]`); defaults to accepting everything.
fn check_attributes(_ty: &Data, _attributes: &[Attribute]) -> Result<()> {
Ok(())
}
/// Returns a pair of token streams (both empty by default).
///
/// In the implementations in this file the first element holds stand-alone
/// item definitions (e.g. generated `...Bits` types) and the second holds
/// associated items such as `type Bits = ...;` and method bodies.
// NOTE(review): presumably the second element is spliced into the generated
// `impl` block — confirm against the caller.
fn trait_impl(
_input: &DeriveInput, _crate_name: &TokenStream,
) -> Result<(TokenStream, TokenStream)> {
Ok((quote!(), quote!()))
}
/// Defaults to `true`; `TransparentWrapper` overrides it to `false`.
// NOTE(review): the consumer of this flag is outside this file.
fn requires_where_clause() -> bool {
true
}
/// Name of a helper attribute for explicit bounds, if the derive supports one
/// (defaults to `None`; `Zeroable` returns `"zeroable"`).
fn explicit_bounds_attribute_name() -> Option<&'static str> {
None
}
}
/// Derive logic for the `Pod` trait.
pub struct Pod;

impl Derivable for Pod {
  fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
    Ok(syn::parse_quote!(#crate_name::Pod))
  }

  /// Emits the compile-time checks for `Pod`.
  ///
  /// * Structs that are neither `#[repr(packed)]`/`#[repr(packed(1))]` nor
  ///   `#[repr(transparent)]` get a "no padding" assertion; such structs may
  ///   not be generic because the assertion cannot be written for them.
  /// * Every field is asserted to implement `Pod`.
  ///
  /// # Errors
  ///
  /// Fails for enums and unions, for generic non-packed structs, and when the
  /// `repr` attributes cannot be parsed.
  fn asserts(
    input: &DeriveInput, crate_name: &TokenStream,
  ) -> Result<TokenStream> {
    let repr = get_repr(&input.attrs)?;

    let completely_packed =
      repr.packed == Some(1) || repr.repr == Repr::Transparent;

    if !completely_packed {
      // Blame the first generic parameter so the error points at it.
      if let Some(first_param) = input.generics.params.first() {
        bail!("\
          Pod cannot be derived for non-packed types containing \
          generic parameters because the padding requirements can't be verified \
          for generic non-packed structs\
        " => first_param);
      }
    }

    match &input.data {
      Data::Struct(_) => {
        // Packed/transparent structs cannot contain padding, so the size
        // assertion is only needed otherwise.
        let assert_no_padding = if completely_packed {
          None
        } else {
          Some(generate_assert_no_padding(input)?)
        };
        let assert_fields_are_pod =
          generate_fields_are_trait(input, Self::ident(input, crate_name)?)?;

        Ok(quote!(
          #assert_no_padding
          #assert_fields_are_pod
        ))
      }
      Data::Enum(_) => bail!("Deriving Pod is not supported for enums"),
      Data::Union(_) => bail!("Deriving Pod is not supported for unions"),
    }
  }

  /// Requires the type to be `#[repr(C)]` or `#[repr(transparent)]`.
  fn check_attributes(_ty: &Data, attributes: &[Attribute]) -> Result<()> {
    let repr = get_repr(attributes)?;
    match repr.repr {
      Repr::C => Ok(()),
      Repr::Transparent => Ok(()),
      _ => {
        bail!("Pod requires the type to be #[repr(C)] or #[repr(transparent)]")
      }
    }
  }
}
pub struct AnyBitPattern;
impl Derivable for AnyBitPattern {
fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
Ok(syn::parse_quote!(#crate_name::AnyBitPattern))
}
fn implies_trait(crate_name: &TokenStream) -> Option<TokenStream> {
Some(quote!(#crate_name::Zeroable))
}
fn asserts(
input: &DeriveInput, crate_name: &TokenStream,
) -> Result<TokenStream> {
match &input.data {
Data::Union(_) => Ok(quote!()), // unions are always `AnyBitPattern`
Data::Struct(_) => {
generate_fields_are_trait(input, Self::ident(input, crate_name)?)
}
Data::Enum(_) => {
bail!("Deriving AnyBitPattern is not supported for enums")
}
}
}
}
pub struct Zeroable;
impl Derivable for Zeroable {
fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
Ok(syn::parse_quote!(#crate_name::Zeroable))
}
fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> {
let repr = get_repr(attributes)?;
match ty {
Data::Struct(_) => Ok(()),
Data::Enum(DataEnum { variants, .. }) => {
if !repr.repr.is_integer() {
bail!("Zeroable requires the enum to be an explicit #[repr(Int)]")
}
if variants.iter().any(|variant| !variant.fields.is_empty()) {
bail!("Only fieldless enums are supported for Zeroable")
}
let iter = VariantDiscriminantIterator::new(variants.iter());
let mut has_zero_variant = false;
for res in iter {
let discriminant = res?;
if discriminant == 0 {
has_zero_variant = true;
break;
}
}
if !has_zero_variant {
bail!("No variant's discriminant is 0")
}
Ok(())
}
Data::Union(_) => Ok(()),
}
}
fn asserts(
input: &DeriveInput, crate_name: &TokenStream,
) -> Result<TokenStream> {
match &input.data {
Data::Union(_) => Ok(quote!()), // unions are always `Zeroable`
Data::Struct(_) => {
generate_fields_are_trait(input, Self::ident(input, crate_name)?)
}
Data::Enum(_) => Ok(quote!()),
}
}
fn explicit_bounds_attribute_name() -> Option<&'static str> {
Some("zeroable")
}
}
pub struct NoUninit;
impl Derivable for NoUninit {
fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
Ok(syn::parse_quote!(#crate_name::NoUninit))
}
fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> {
let repr = get_repr(attributes)?;
match ty {
Data::Struct(_) => match repr.repr {
Repr::C | Repr::Transparent => Ok(()),
_ => bail!("NoUninit requires the struct to be #[repr(C)] or #[repr(transparent)]"),
},
Data::Enum(_) => if repr.repr.is_integer() {
Ok(())
} else {
bail!("NoUninit requires the enum to be an explicit #[repr(Int)]")
},
Data::Union(_) => bail!("NoUninit can only be derived on enums and structs")
}
}
fn asserts(
input: &DeriveInput, crate_name: &TokenStream,
) -> Result<TokenStream> {
if !input.generics.params.is_empty() {
bail!("NoUninit cannot be derived for structs containing generic parameters because the padding requirements can't be verified for generic structs");
}
match &input.data {
Data::Struct(DataStruct { .. }) => {
let assert_no_padding = generate_assert_no_padding(&input)?;
let assert_fields_are_no_padding =
generate_fields_are_trait(&input, Self::ident(input, crate_name)?)?;
Ok(quote!(
#assert_no_padding
#assert_fields_are_no_padding
))
}
Data::Enum(DataEnum { variants, .. }) => {
if variants.iter().any(|variant| !variant.fields.is_empty()) {
bail!("Only fieldless enums are supported for NoUninit")
} else {
Ok(quote!())
}
}
Data::Union(_) => bail!("NoUninit cannot be derived for unions"), /* shouldn't be possible since we already error in attribute check for this case */
}
}
fn trait_impl(
_input: &DeriveInput, _crate_name: &TokenStream,
) -> Result<(TokenStream, TokenStream)> {
Ok((quote!(), quote!()))
}
}
pub struct CheckedBitPattern;
impl Derivable for CheckedBitPattern {
fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
Ok(syn::parse_quote!(#crate_name::CheckedBitPattern))
}
fn check_attributes(ty: &Data, attributes: &[Attribute]) -> Result<()> {
let repr = get_repr(attributes)?;
match ty {
Data::Struct(_) => match repr.repr {
Repr::C | Repr::Transparent => Ok(()),
_ => bail!("CheckedBitPattern derive requires the struct to be #[repr(C)] or #[repr(transparent)]"),
},
Data::Enum(DataEnum { variants,.. }) => {
if !enum_has_fields(variants.iter()){
if repr.repr.is_integer() {
Ok(())
} else {
bail!("CheckedBitPattern requires the enum to be an explicit #[repr(Int)]")
}
} else if matches!(repr.repr, Repr::Rust) {
bail!("CheckedBitPattern requires an explicit repr annotation because `repr(Rust)` doesn't have a specified type layout")
} else {
Ok(())
}
}
Data::Union(_) => bail!("CheckedBitPattern can only be derived on enums and structs")
}
}
fn asserts(
input: &DeriveInput, crate_name: &TokenStream,
) -> Result<TokenStream> {
if !input.generics.params.is_empty() {
bail!("CheckedBitPattern cannot be derived for structs containing generic parameters");
}
match &input.data {
Data::Struct(DataStruct { .. }) => {
let assert_fields_are_maybe_pod =
generate_fields_are_trait(&input, Self::ident(input, crate_name)?)?;
Ok(assert_fields_are_maybe_pod)
}
Data::Enum(_) => Ok(quote!()), /* nothing needed, already guaranteed
* OK by NoUninit */
Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */
}
}
fn trait_impl(
input: &DeriveInput, crate_name: &TokenStream,
) -> Result<(TokenStream, TokenStream)> {
match &input.data {
Data::Struct(DataStruct { fields, .. }) => {
generate_checked_bit_pattern_struct(
&input.ident,
fields,
&input.attrs,
crate_name,
)
}
Data::Enum(DataEnum { variants, .. }) => {
generate_checked_bit_pattern_enum(input, variants, crate_name)
}
Data::Union(_) => bail!("Internal error in CheckedBitPattern derive"), /* shouldn't be possible since we already error in attribute check for this case */
}
}
}
/// Derive logic for the `TransparentWrapper` trait.
pub struct TransparentWrapper;

impl TransparentWrapper {
  /// Determines the wrapped type.
  ///
  /// An explicit `#[transparent(T)]` attribute wins. Otherwise the type of
  /// the only field is used; with zero fields or more than one field the
  /// result is `None`.
  fn get_wrapper_type(
    attributes: &[Attribute], fields: &Fields,
  ) -> Option<TokenStream> {
    if let Some(explicit) = get_simple_attr(attributes, "transparent") {
      return Some(explicit.to_token_stream());
    }

    let mut field_tys = get_field_types(fields);
    let only_ty = field_tys.next();
    if field_tys.next().is_some() {
      // can't guess param type if there is more than one field
      None
    } else {
      only_ty.map(|ty| ty.to_token_stream())
    }
  }
}
impl Derivable for TransparentWrapper {
fn ident(input: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
let fields = get_struct_fields(input)?;
let ty = match Self::get_wrapper_type(&input.attrs, &fields) {
Some(ty) => ty,
None => bail!(
"\
when deriving TransparentWrapper for a struct with more than one field \
you need to specify the transparent field using #[transparent(T)]\
"
),
};
Ok(syn::parse_quote!(#crate_name::TransparentWrapper<#ty>))
}
fn asserts(
input: &DeriveInput, crate_name: &TokenStream,
) -> Result<TokenStream> {
let (impl_generics, _ty_generics, where_clause) =
input.generics.split_for_impl();
let fields = get_struct_fields(input)?;
let wrapped_type = match Self::get_wrapper_type(&input.attrs, &fields) {
Some(wrapped_type) => wrapped_type.to_string(),
None => unreachable!(), /* other code will already reject this derive */
};
let mut wrapped_field_ty = None;
let mut nonwrapped_field_tys = vec![];
for field in fields.iter() {
let field_ty = &field.ty;
if field_ty.to_token_stream().to_string() == wrapped_type {
if wrapped_field_ty.is_some() {
bail!(
"TransparentWrapper can only have one field of the wrapped type"
);
}
wrapped_field_ty = Some(field_ty);
} else {
nonwrapped_field_tys.push(field_ty);
}
}
if let Some(wrapped_field_ty) = wrapped_field_ty {
Ok(quote!(
const _: () = {
#[repr(transparent)]
#[allow(clippy::multiple_bound_locations)]
struct AssertWrappedIsWrapped #impl_generics((u8, ::core::marker::PhantomData<#wrapped_field_ty>), #(#nonwrapped_field_tys),*) #where_clause;
fn assert_zeroable<Z: #crate_name::Zeroable>() {}
#[allow(clippy::multiple_bound_locations)]
fn check #impl_generics () #where_clause {
#(
assert_zeroable::<#nonwrapped_field_tys>();
)*
}
};
))
} else {
bail!("TransparentWrapper must have one field of the wrapped type")
}
}
fn check_attributes(_ty: &Data, attributes: &[Attribute]) -> Result<()> {
let repr = get_repr(attributes)?;
match repr.repr {
Repr::Transparent => Ok(()),
_ => {
bail!(
"TransparentWrapper requires the struct to be #[repr(transparent)]"
)
}
}
}
fn requires_where_clause() -> bool {
false
}
}
/// Derive logic for the `Contiguous` trait.
pub struct Contiguous;

impl Derivable for Contiguous {
  fn ident(_: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
    Ok(syn::parse_quote!(#crate_name::Contiguous))
  }

  /// Generates the associated items of the `Contiguous` impl.
  ///
  /// # Errors
  ///
  /// Fails when the input is not a fieldless enum with an integer `repr`,
  /// when a discriminant cannot be evaluated, or when the discriminants do
  /// not form a contiguous range (an enum without variants is treated as
  /// non-contiguous).
  fn trait_impl(
    input: &DeriveInput, _crate_name: &TokenStream,
  ) -> Result<(TokenStream, TokenStream)> {
    let repr = get_repr(&input.attrs)?;

    let integer_ty = if let Some(integer_ty) = repr.repr.as_integer() {
      integer_ty
    } else {
      bail!("Contiguous requires the enum to be #[repr(Int)]");
    };

    let variants = get_enum_variants(input)?;
    if enum_has_fields(variants.clone()) {
      return Err(Error::new_spanned(
        input,
        "Only fieldless enums are supported",
      ));
    }

    let mut variants_with_discriminator =
      VariantDiscriminantIterator::new(variants);

    let (min, max, count) = variants_with_discriminator.try_fold(
      (i64::MAX, i64::MIN, 0_i64),
      |(min, max, count), res| {
        let discriminator = res?;

        Ok::<_, Error>((
          i64::min(min, discriminator),
          i64::max(max, discriminator),
          count + 1,
        ))
      },
    )?;

    // `max - min` can overflow `i64`: with no variants the fold seeds are
    // still in place (`i64::MIN - i64::MAX`), and explicit discriminants may
    // be arbitrarily far apart. `checked_sub` turns both cases into a
    // regular "not contiguous" error instead of an arithmetic panic.
    let is_contiguous = count > 0 && max.checked_sub(min) == Some(count - 1);
    if !is_contiguous {
      bail! {
        "Contiguous requires the enum discriminants to be contiguous",
      }
    }

    let min_lit = LitInt::new(&min.to_string(), input.span());
    let max_lit = LitInt::new(&max.to_string(), input.span());

    // `from_integer` and `into_integer` are usually provided by the trait's
    // default implementation. We override this implementation because it
    // goes through `transmute_copy`, which can lead to inefficient assembly
    // as seen in https://github.com/Lokathor/bytemuck/issues/175 .

    Ok((
      quote!(),
      quote! {
        type Int = #integer_ty;

        #[allow(clippy::missing_docs_in_private_items)]
        const MIN_VALUE: #integer_ty = #min_lit;

        #[allow(clippy::missing_docs_in_private_items)]
        const MAX_VALUE: #integer_ty = #max_lit;

        #[inline]
        fn from_integer(value: Self::Int) -> Option<Self> {
          #[allow(clippy::manual_range_contains)]
          if Self::MIN_VALUE <= value && value <= Self::MAX_VALUE {
            Some(unsafe { ::core::mem::transmute(value) })
          } else {
            None
          }
        }

        #[inline]
        fn into_integer(self) -> Self::Int {
          self as #integer_ty
        }
      },
    ))
  }
}
/// Borrows the fields of a struct input; any other kind of item is an error.
fn get_struct_fields(input: &DeriveInput) -> Result<&Fields> {
  match &input.data {
    Data::Struct(DataStruct { fields, .. }) => Ok(fields),
    _ => bail!("deriving this trait is only supported for structs"),
  }
}
/// Returns an owned copy of the fields of a struct or union (union fields are
/// wrapped in `Fields::Named`); enums are an error.
fn get_fields(input: &DeriveInput) -> Result<Fields> {
  match &input.data {
    Data::Enum(_) => bail!("deriving this trait is not supported for enums"),
    Data::Union(data) => Ok(Fields::Named(data.fields.clone())),
    Data::Struct(data) => Ok(data.fields.clone()),
  }
}
/// Returns a cloneable iterator over the variants of an enum input; any other
/// kind of item is an error.
fn get_enum_variants<'a>(
  input: &'a DeriveInput,
) -> Result<impl Iterator<Item = &'a Variant> + Clone + 'a> {
  match &input.data {
    Data::Enum(DataEnum { variants, .. }) => Ok(variants.iter()),
    _ => bail!("deriving this trait is only supported for enums"),
  }
}
fn get_field_types<'a>(
fields: &'a Fields,
) -> impl Iterator<Item = &'a Type> + 'a {
fields.iter().map(|field| &field.ty)
}
fn generate_checked_bit_pattern_struct(
input_ident: &Ident, fields: &Fields, attrs: &[Attribute],
crate_name: &TokenStream,
) -> Result<(TokenStream, TokenStream)> {
let bits_ty = Ident::new(&format!("{}Bits", input_ident), input_ident.span());
let repr = get_repr(attrs)?;
let field_names = fields
.iter()
.enumerate()
.map(|(i, field)| {
field.ident.clone().unwrap_or_else(|| {
Ident::new(&format!("field{}", i), input_ident.span())
})
})
.collect::<Vec<_>>();
let field_tys = fields.iter().map(|field| &field.ty).collect::<Vec<_>>();
let field_name = &field_names[..];
let field_ty = &field_tys[..];
let derive_dbg =
quote!(#[cfg_attr(not(target_arch = "spirv"), derive(Debug))]);
Ok((
quote! {
#[doc = #GENERATED_TYPE_DOCUMENTATION]
#repr
#[derive(Clone, Copy, #crate_name::AnyBitPattern)]
#derive_dbg
#[allow(missing_docs)]
pub struct #bits_ty {
#(#field_name: <#field_ty as #crate_name::CheckedBitPattern>::Bits,)*
}
},
quote! {
type Bits = #bits_ty;
#[inline]
#[allow(clippy::double_comparisons, unused)]
fn is_valid_bit_pattern(bits: &#bits_ty) -> bool {
#(<#field_ty as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(&{ bits.#field_name }) && )* true
}
},
))
}
/// Picks the `CheckedBitPattern` enum generator depending on whether any
/// variant carries fields.
fn generate_checked_bit_pattern_enum(
  input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>,
  crate_name: &TokenStream,
) -> Result<(TokenStream, TokenStream)> {
  if !enum_has_fields(variants.iter()) {
    return generate_checked_bit_pattern_enum_without_fields(input, variants);
  }
  generate_checked_bit_pattern_enum_with_fields(input, variants, crate_name)
}
/// Generates the `CheckedBitPattern` items for a fieldless enum.
///
/// `Bits` is the enum's integer `repr`. `is_valid_bit_pattern` is
///
/// * `false` when the enum has no variants,
/// * a range check when the discriminants are contiguous,
/// * a `matches!` over every discriminant otherwise.
///
/// # Errors
///
/// Fails when a discriminant cannot be evaluated or the `repr` attributes
/// cannot be parsed.
///
/// # Panics
///
/// Panics if the enum has no integer `repr`; the attribute check is expected
/// to have rejected that case already.
fn generate_checked_bit_pattern_enum_without_fields(
  input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>,
) -> Result<(TokenStream, TokenStream)> {
  let span = input.span();
  let mut variants_with_discriminant =
    VariantDiscriminantIterator::new(variants.iter());

  let (min, max, count) = variants_with_discriminant.try_fold(
    (i64::MAX, i64::MIN, 0_i64),
    |(min, max, count), res| {
      let discriminant = res?;

      Ok::<_, Error>((
        i64::min(min, discriminant),
        i64::max(max, discriminant),
        count + 1,
      ))
    },
  )?;

  let check = if count == 0 {
    quote_spanned!(span => false)
  } else if max.checked_sub(min) == Some(count - 1) {
    // contiguous range
    //
    // `checked_sub` is used because explicit discriminants can be far enough
    // apart for `max - min` to overflow `i64`; such a range is never
    // contiguous, so it falls through to the per-variant check below.
    let min_lit = LitInt::new(&min.to_string(), span);
    let max_lit = LitInt::new(&max.to_string(), span);

    quote!(*bits >= #min_lit && *bits <= #max_lit)
  } else {
    // not contiguous range, check for each
    let variant_lits = VariantDiscriminantIterator::new(variants.iter())
      .map(|res| {
        let variant = res?;
        Ok(LitInt::new(&variant.to_string(), span))
      })
      .collect::<Result<Vec<_>>>()?;

    // count is at least 1
    let first = &variant_lits[0];
    let rest = &variant_lits[1..];

    quote!(matches!(*bits, #first #(| #rest )*))
  };

  let repr = get_repr(&input.attrs)?;
  let integer = repr.repr.as_integer().unwrap(); // should be checked in attr check already
  Ok((
    quote!(),
    quote! {
      type Bits = #integer;

      #[inline]
      #[allow(clippy::double_comparisons)]
      fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
        #check
      }
    },
  ))
}
/// Generates the `CheckedBitPattern` items for an enum in which at least one
/// variant has fields.
///
/// The first element of the returned pair holds the generated helper types
/// (a `<Name>Bits` type plus, depending on the `repr`, a variants union and
/// one struct per variant). The second element holds `type Bits` and
/// `is_valid_bit_pattern`.
///
/// The shape of the generated types depends on the enum's `repr`:
///
/// * `C` / `C` plus an integer: a struct with a `tag` and a `payload` union.
/// * `transparent`: a single tuple struct with the fields of the only variant.
/// * integer: a union whose members each start with the discriminant integer.
///
/// Panics (via `unreachable!`) for `Repr::Rust`.
fn generate_checked_bit_pattern_enum_with_fields(
input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>,
crate_name: &TokenStream,
) -> Result<(TokenStream, TokenStream)> {
let representation = get_repr(&input.attrs)?;
let vis = &input.vis;
let derive_dbg =
quote!(#[cfg_attr(not(target_arch = "spirv"), derive(Debug))]);
match representation.repr {
Repr::Rust => unreachable!(),
repr @ (Repr::C | Repr::CWithDiscriminant(_)) => {
// Plain `repr(C)` uses `c_int` as the tag type; `repr(C, Int)` uses `Int`.
let integer = match repr {
Repr::C => quote!(::core::ffi::c_int),
Repr::CWithDiscriminant(integer) => quote!(#integer),
_ => unreachable!(),
};
let input_ident = &input.ident;
let bits_repr = Representation { repr: Repr::C, ..representation };
// the enum manually re-configured as the actual tagged union it
// represents, thus circumventing the requirements rust imposes on
// the tag even when using #[repr(C)] enum layout
// see: https://doc.rust-lang.org/reference/type-layout.html#reprc-enums-with-fields
let bits_ty_ident =
Ident::new(&format!("{input_ident}Bits"), input.span());
// the variants union part of the tagged union. These get put into a union
// which gets the AnyBitPattern derive applied to it, thus checking
// that the fields of the union obey the requirements of AnyBitPattern.
// The types that actually go in the union are one more level of
// indirection deep: we generate new structs for each variant
// (`variant_struct_definitions`) which themselves have the
// `CheckedBitPattern` derive applied, thus generating
// `{variant_struct_ident}Bits` structs, which are the ones that go
// into this union.
let variants_union_ident =
Ident::new(&format!("{}Variants", input.ident), input.span());
let variant_struct_idents = variants.iter().map(|v| {
Ident::new(&format!("{input_ident}Variant{}", v.ident), v.span())
});
let variant_struct_definitions =
variant_struct_idents.clone().zip(variants.iter()).map(|(variant_struct_ident, v)| {
let fields = v.fields.iter().map(|v| &v.ty);
quote! {
#[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::CheckedBitPattern)]
#[repr(C)]
#vis struct #variant_struct_ident(#(#fields),*);
}
});
// One union member per variant, named after the variant and typed as the
// variant struct's generated `...Bits` type.
let union_fields = variant_struct_idents
.clone()
.zip(variants.iter())
.map(|(variant_struct_ident, v)| {
let variant_struct_bits_ident =
Ident::new(&format!("{variant_struct_ident}Bits"), input.span());
let field_ident = &v.ident;
quote! {
#field_ident: #variant_struct_bits_ident
}
});
// One `match` arm per variant: when the tag equals the variant's
// discriminant, validate the payload as that variant's struct.
let variant_checks = variant_struct_idents
.clone()
.zip(VariantDiscriminantIterator::new(variants.iter()))
.zip(variants.iter())
.map(|((variant_struct_ident, discriminant), v)| -> Result<_> {
let discriminant = discriminant?;
let discriminant = LitInt::new(&discriminant.to_string(), v.span());
let ident = &v.ident;
Ok(quote! {
#discriminant => {
let payload = unsafe { &bits.payload.#ident };
<#variant_struct_ident as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(payload)
}
})
})
.collect::<Result<Vec<_>>>()?;
Ok((
quote! {
#[doc = #GENERATED_TYPE_DOCUMENTATION]
#[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::AnyBitPattern)]
#derive_dbg
#bits_repr
#vis struct #bits_ty_ident {
tag: #integer,
payload: #variants_union_ident,
}
#[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::AnyBitPattern)]
#[repr(C)]
#[allow(non_snake_case)]
#vis union #variants_union_ident {
#(#union_fields,)*
}
#[cfg(not(target_arch = "spirv"))]
impl ::core::fmt::Debug for #variants_union_ident {
fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#variants_union_ident));
::core::fmt::DebugStruct::finish_non_exhaustive(&mut debug_struct)
}
}
#(#variant_struct_definitions)*
},
quote! {
type Bits = #bits_ty_ident;
#[inline]
#[allow(clippy::double_comparisons)]
fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
match bits.tag {
#(#variant_checks)*
_ => false,
}
}
},
))
}
Repr::Transparent => {
if variants.len() != 1 {
bail!("enums with more than one variant cannot be transparent")
}
let variant = &variants[0];
let bits_ty = Ident::new(&format!("{}Bits", input.ident), input.span());
let fields = variant.fields.iter().map(|v| &v.ty);
// The generated tuple struct derives `CheckedBitPattern` itself, and the
// enum's `Bits` and validity check are forwarded to it.
Ok((
quote! {
#[doc = #GENERATED_TYPE_DOCUMENTATION]
#[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::CheckedBitPattern)]
#[repr(C)]
#vis struct #bits_ty(#(#fields),*);
},
quote! {
type Bits = <#bits_ty as #crate_name::CheckedBitPattern>::Bits;
#[inline]
#[allow(clippy::double_comparisons)]
fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
<#bits_ty as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(bits)
}
},
))
}
Repr::Integer(integer) => {
let bits_repr = Representation { repr: Repr::C, ..representation };
let input_ident = &input.ident;
// the enum manually re-configured as the union it represents. such a
// union is the union of variants as a repr(c) struct with the
// discriminator type inserted at the beginning. in our case we
// union the `Bits` representation of each variant rather than the variant
// itself, which we generate via a nested `CheckedBitPattern` derive
// on the `variant_struct_definitions` generated below.
//
// see: https://doc.rust-lang.org/reference/type-layout.html#primitive-representation-of-enums-with-fields
let bits_ty_ident =
Ident::new(&format!("{input_ident}Bits"), input.span());
let variant_struct_idents = variants.iter().map(|v| {
Ident::new(&format!("{input_ident}Variant{}", v.ident), v.span())
});
let variant_struct_definitions =
variant_struct_idents.clone().zip(variants.iter()).map(|(variant_struct_ident, v)| {
let fields = v.fields.iter().map(|v| &v.ty);
// adding the discriminant repr integer as first field, as described above
quote! {
#[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::CheckedBitPattern)]
#[repr(C)]
#vis struct #variant_struct_ident(#integer, #(#fields),*);
}
});
let union_fields = variant_struct_idents
.clone()
.zip(variants.iter())
.map(|(variant_struct_ident, v)| {
let variant_struct_bits_ident =
Ident::new(&format!("{variant_struct_ident}Bits"), input.span());
let field_ident = &v.ident;
quote! {
#field_ident: #variant_struct_bits_ident
}
});
// Same per-variant `match` arms as in the `repr(C)` case, except that the
// variant data is read directly from the union (`bits.<Variant>`).
let variant_checks = variant_struct_idents
.clone()
.zip(VariantDiscriminantIterator::new(variants.iter()))
.zip(variants.iter())
.map(|((variant_struct_ident, discriminant), v)| -> Result<_> {
let discriminant = discriminant?;
let discriminant = LitInt::new(&discriminant.to_string(), v.span());
let ident = &v.ident;
Ok(quote! {
#discriminant => {
let payload = unsafe { &bits.#ident };
<#variant_struct_ident as #crate_name::CheckedBitPattern>::is_valid_bit_pattern(payload)
}
})
})
.collect::<Result<Vec<_>>>()?;
Ok((
quote! {
#[doc = #GENERATED_TYPE_DOCUMENTATION]
#[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::AnyBitPattern)]
#bits_repr
#[allow(non_snake_case)]
#vis union #bits_ty_ident {
__tag: #integer,
#(#union_fields,)*
}
#[cfg(not(target_arch = "spirv"))]
impl ::core::fmt::Debug for #bits_ty_ident {
fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
let mut debug_struct = ::core::fmt::Formatter::debug_struct(f, ::core::stringify!(#bits_ty_ident));
::core::fmt::DebugStruct::field(&mut debug_struct, "tag", unsafe { &self.__tag });
::core::fmt::DebugStruct::finish_non_exhaustive(&mut debug_struct)
}
}
#(#variant_struct_definitions)*
},
quote! {
type Bits = #bits_ty_ident;
#[inline]
#[allow(clippy::double_comparisons)]
fn is_valid_bit_pattern(bits: &Self::Bits) -> bool {
match unsafe { bits.__tag } {
#(#variant_checks)*
_ => false,
}
}
},
))
}
}
}
/// Check that a struct has no padding by asserting that the size of the struct
/// is equal to the sum of the size of it's fields
fn generate_assert_no_padding(input: &DeriveInput) -> Result<TokenStream> {
let struct_type = &input.ident;
let span = input.ident.span();
let fields = get_fields(input)?;
let mut field_types = get_field_types(&fields);
let size_sum = if let Some(first) = field_types.next() {
let size_first = quote_spanned!(span => ::core::mem::size_of::<#first>());
let size_rest =
quote_spanned!(span => #( + ::core::mem::size_of::<#field_types>() )*);
quote_spanned!(span => #size_first #size_rest)
} else {
quote_spanned!(span => 0)
};
Ok(quote_spanned! {span => const _: fn() = || {
#[doc(hidden)]
struct TypeWithoutPadding([u8; #size_sum]);
let _ = ::core::mem::transmute::<#struct_type, TypeWithoutPadding>;
};})
}
/// Check that all fields implement a given trait.
///
/// One anonymous `const` is emitted per field; each contains a function that
/// only type-checks when the field's type implements `trait_`.
fn generate_fields_are_trait(
  input: &DeriveInput, trait_: syn::Path,
) -> Result<TokenStream> {
  let (impl_generics, _ty_generics, where_clause) =
    input.generics.split_for_impl();
  let fields = get_fields(input)?;
  let span = input.span();

  let per_field_checks = get_field_types(&fields).map(|field_ty| {
    quote_spanned! {span => const _: fn() = || {
        #[allow(clippy::missing_const_for_fn)]
        #[doc(hidden)]
        fn check #impl_generics () #where_clause {
          fn assert_impl<T: #trait_>() {}
          assert_impl::<#field_ty>();
        }
      };
    }
  });

  Ok(per_field_checks.collect())
}
/// Returns the identifier at the very start of `tokens`, descending into
/// leading groups. Yields `None` if the first non-group token is not an
/// identifier or the stream is empty.
fn get_ident_from_stream(tokens: TokenStream) -> Option<Ident> {
  let mut stream = tokens;
  loop {
    match stream.into_iter().next()? {
      TokenTree::Group(group) => stream = group.stream(),
      TokenTree::Ident(ident) => return Some(ident),
      _ => return None,
    }
  }
}
/// Get a simple `#[foo(bar)]` attribute, returning `bar`.
///
/// Only outer list-style attributes named `attr_name` are considered; the
/// first one that starts with an identifier wins.
fn get_simple_attr(attributes: &[Attribute], attr_name: &str) -> Option<Ident> {
  attributes.iter().find_map(|attr| match (&attr.style, &attr.meta) {
    (AttrStyle::Outer, Meta::List(list)) if list.path.is_ident(attr_name) => {
      get_ident_from_stream(list.tokens.clone())
    }
    _ => None,
  })
}
/// Parses every `#[repr(...)]` attribute and merges them into a single
/// `Representation`.
///
/// Two different layout hints or two `packed` hints across attributes are
/// reported as conflicting; for `align` the largest value is kept.
fn get_repr(attributes: &[Attribute]) -> Result<Representation> {
  let mut merged = Representation::default();

  let repr_attrs =
    attributes.iter().filter(|attr| attr.path().is_ident("repr"));
  for attr in repr_attrs {
    let parsed = attr.parse_args::<Representation>()?;

    let repr = match (merged.repr, parsed.repr) {
      (current, Repr::Rust) => current,
      (Repr::Rust, incoming) => incoming,
      _ => bail!("conflicting representation hints"),
    };
    let packed = match (merged.packed, parsed.packed) {
      (current, None) => current,
      (None, incoming) => incoming,
      _ => bail!("conflicting representation hints"),
    };
    let align = match (merged.align, parsed.align) {
      (Some(current), Some(incoming)) => Some(cmp::max(current, incoming)),
      (current, incoming) => current.or(incoming),
    };

    merged = Representation { repr, packed, align };
  }

  Ok(merged)
}
// Instantiates `IntegerRepr` with one variant per primitive integer type that
// may appear in `#[repr(...)]`. The macro itself is defined right below; the
// `use mk_repr;` after the definition brings it into path-based scope, which
// is what allows this invocation to precede the definition textually.
mk_repr! {
U8 => u8,
I8 => i8,
U16 => u16,
I16 => i16,
U32 => u32,
I32 => i32,
U64 => u64,
I64 => i64,
I128 => i128,
U128 => u128,
Usize => usize,
Isize => isize,
}
// where `mk_repr!` is defined as follows.
//
// Given a list of `Variant => primitive` pairs it generates:
// * the `IntegerRepr` enum with one variant per pair,
// * `TryFrom<&str>`, mapping the primitive's name to its variant and
//   handing back the input string on failure,
// * `ToTokens`, emitting the primitive's name for each variant.
macro_rules! mk_repr {(
$(
$Xn:ident => $xn:ident
),* $(,)?
) => (
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum IntegerRepr {
$($Xn),*
}
impl<'a> TryFrom<&'a str> for IntegerRepr {
type Error = &'a str;
fn try_from(value: &'a str) -> std::result::Result<Self, &'a str> {
match value {
$(
stringify!($xn) => Ok(Self::$Xn),
)*
_ => Err(value),
}
}
}
impl ToTokens for IntegerRepr {
fn to_tokens(&self, tokens: &mut TokenStream) {
match self {
$(
Self::$Xn => tokens.extend(quote!($xn)),
)*
}
}
}
)}
use mk_repr;
/// The layout-selecting part of a `#[repr(...)]` attribute.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Repr {
  /// No explicit layout hint.
  Rust,
  /// `repr(C)`
  C,
  /// `repr(transparent)`
  Transparent,
  /// `repr(<int>)`
  Integer(IntegerRepr),
  /// `repr(C, <int>)` in either order.
  CWithDiscriminant(IntegerRepr),
}

impl Repr {
  /// `true` only for a plain integer repr (not for `C` plus integer).
  fn is_integer(&self) -> bool {
    self.as_integer().is_some()
  }

  /// The integer type of a plain integer repr, if that is what this is.
  fn as_integer(&self) -> Option<IntegerRepr> {
    match *self {
      Self::Integer(integer) => Some(integer),
      Self::Rust
      | Self::C
      | Self::Transparent
      | Self::CWithDiscriminant(_) => None,
    }
  }
}
/// Everything this file understands about a type's `#[repr(...)]` hints.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Representation {
  /// Value of `packed(N)`; a bare `packed` is stored as `Some(1)`.
  packed: Option<u32>,
  /// Value of `align(N)`.
  align: Option<u32>,
  /// The layout hint itself.
  repr: Repr,
}

impl Default for Representation {
  /// No hints at all: default Rust layout, neither packed nor aligned.
  fn default() -> Self {
    Representation { repr: Repr::Rust, align: None, packed: None }
  }
}
/// Parses the comma-separated contents of a single `#[repr(...)]` attribute.
///
/// Recognized hints are `C`, `transparent`, `packed`, `packed(N)`, `align(N)`
/// and the primitive integer names known to `IntegerRepr`. `C` combined with
/// an integer (in either order) becomes `Repr::CWithDiscriminant`; any other
/// combination of two layout hints is reported as a duplicate. Repeated
/// `align(N)` hints keep the largest value, while a repeated `packed` hint
/// overwrites the previous one.
impl Parse for Representation {
fn parse(input: ParseStream<'_>) -> Result<Representation> {
let mut ret = Representation::default();
while !input.is_empty() {
let keyword = input.parse::<Ident>()?;
// preemptively call `.to_string()` *once* (rather than on `is_ident()`)
let keyword_str = keyword.to_string();
let new_repr = match keyword_str.as_str() {
"C" => Repr::C,
"transparent" => Repr::Transparent,
"packed" => {
// `packed` without an argument means `packed(1)`.
ret.packed = Some(if input.peek(token::Paren) {
let contents;
parenthesized!(contents in input);
LitInt::base10_parse::<u32>(&contents.parse()?)?
} else {
1
});
let _: Option<Token![,]> = input.parse()?;
// `packed` does not touch `ret.repr`, so skip the merge below.
continue;
}
"align" => {
let contents;
parenthesized!(contents in input);
let new_align = LitInt::base10_parse::<u32>(&contents.parse()?)?;
ret.align = Some(
ret
.align
.map_or(new_align, |old_align| cmp::max(old_align, new_align)),
);
let _: Option<Token![,]> = input.parse()?;
// `align` does not touch `ret.repr`, so skip the merge below.
continue;
}
ident => {
let primitive = IntegerRepr::try_from(ident)
.map_err(|_| input.error("unrecognized representation hint"))?;
Repr::Integer(primitive)
}
};
ret.repr = match (ret.repr, new_repr) {
(Repr::Rust, new_repr) => {
// This is the first explicit repr.
new_repr
}
(Repr::C, Repr::Integer(integer))
| (Repr::Integer(integer), Repr::C) => {
// Both the C repr and an integer repr have been specified
// -> merge into a C with discriminant.
Repr::CWithDiscriminant(integer)
}
(_, _) => {
return Err(input.error("duplicate representation hint"));
}
};
let _: Option<Token![,]> = input.parse()?;
}
Ok(ret)
}
}
/// Renders the representation back into a `#[repr(...)]` attribute, listing
/// the layout hint first, then `packed(N)`, then `align(N)`.
impl ToTokens for Representation {
  fn to_tokens(&self, tokens: &mut TokenStream) {
    // Unsuffixed integer literal located at the call site.
    let int_lit =
      |value: u32| LitInt::new(&value.to_string(), Span::call_site());

    let mut hints = Punctuated::<TokenStream, Token![,]>::new();

    match self.repr {
      Repr::Rust => {}
      Repr::C => hints.push(quote!(C)),
      Repr::Transparent => hints.push(quote!(transparent)),
      Repr::Integer(primitive) => hints.push(quote!(#primitive)),
      Repr::CWithDiscriminant(primitive) => {
        hints.push(quote!(C));
        hints.push(quote!(#primitive));
      }
    }

    if let Some(packed) = self.packed {
      let lit = int_lit(packed);
      hints.push(quote!(packed(#lit)));
    }

    if let Some(align) = self.align {
      let lit = int_lit(align);
      hints.push(quote!(align(#lit)));
    }

    let attribute = quote!(
      #[repr(#hints)]
    );
    tokens.extend(attribute);
  }
}
/// `true` if at least one variant is declared with named or unnamed fields,
/// i.e. is not a unit variant.
fn enum_has_fields<'a>(
  mut variants: impl Iterator<Item = &'a Variant>,
) -> bool {
  variants.any(|variant| !matches!(variant.fields, Fields::Unit))
}
struct VariantDiscriminantIterator<'a, I: Iterator<Item = &'a Variant> + 'a> {
inner: I,
last_value: i64,
}
impl<'a, I: Iterator<Item = &'a Variant> + 'a>
VariantDiscriminantIterator<'a, I>
{
fn new(inner: I) -> Self {
VariantDiscriminantIterator { inner, last_value: -1 }
}
}
impl<'a, I: Iterator<Item = &'a Variant> + 'a> Iterator
for VariantDiscriminantIterator<'a, I>
{
type Item = Result<i64>;
fn next(&mut self) -> Option<Self::Item> {
let variant = self.inner.next()?;
if let Some((_, discriminant)) = &variant.discriminant {
let discriminant_value = match parse_int_expr(discriminant) {
Ok(value) => value,
Err(e) => return Some(Err(e)),
};
self.last_value = discriminant_value;
} else {
self.last_value += 1;
}
Some(Ok(self.last_value))
}
}
/// Evaluates a discriminant expression to an `i64`.
///
/// Supported forms are integer literals, byte literals and arithmetic
/// negation of a supported form; everything else is an error.
fn parse_int_expr(expr: &Expr) -> Result<i64> {
  match expr {
    Expr::Lit(ExprLit { lit: Lit::Byte(byte), .. }) => {
      Ok(i64::from(byte.value()))
    }
    Expr::Lit(ExprLit { lit: Lit::Int(int), .. }) => int.base10_parse::<i64>(),
    Expr::Unary(ExprUnary { op: UnOp::Neg(_), expr: operand, .. }) => {
      let magnitude = parse_int_expr(operand)?;
      Ok(-magnitude)
    }
    _ => bail!("Not an integer expression"),
  }
}
// Unit tests for `get_repr` / `Representation` parsing.
#[cfg(test)]
mod tests {
use syn::parse_quote;
use super::{get_repr, IntegerRepr, Repr, Representation};
// Each single hint is parsed into the matching `Representation` field while
// all other fields keep their defaults.
#[test]
fn parse_basic_repr() {
let attr = parse_quote!(#[repr(C)]);
let repr = get_repr(&[attr]).unwrap();
assert_eq!(repr, Representation { repr: Repr::C, ..Default::default() });
let attr = parse_quote!(#[repr(transparent)]);
let repr = get_repr(&[attr]).unwrap();
assert_eq!(
repr,
Representation { repr: Repr::Transparent, ..Default::default() }
);
let attr = parse_quote!(#[repr(u8)]);
let repr = get_repr(&[attr]).unwrap();
assert_eq!(
repr,
Representation {
repr: Repr::Integer(IntegerRepr::U8),
..Default::default()
}
);
// A bare `packed` is equivalent to `packed(1)`.
let attr = parse_quote!(#[repr(packed)]);
let repr = get_repr(&[attr]).unwrap();
assert_eq!(repr, Representation { packed: Some(1), ..Default::default() });
let attr = parse_quote!(#[repr(packed(1))]);
let repr = get_repr(&[attr]).unwrap();
assert_eq!(repr, Representation { packed: Some(1), ..Default::default() });
let attr = parse_quote!(#[repr(packed(2))]);
let repr = get_repr(&[attr]).unwrap();
assert_eq!(repr, Representation { packed: Some(2), ..Default::default() });
let attr = parse_quote!(#[repr(align(2))]);
let repr = get_repr(&[attr]).unwrap();
assert_eq!(repr, Representation { align: Some(2), ..Default::default() });
}
// Combinations: repeated `align` keeps the maximum (within one attribute and
// across several), and `C` plus an integer merges in either order.
#[test]
fn parse_advanced_repr() {
let attr = parse_quote!(#[repr(align(4), align(2))]);
let repr = get_repr(&[attr]).unwrap();
assert_eq!(repr, Representation { align: Some(4), ..Default::default() });
let attr1 = parse_quote!(#[repr(align(1))]);
let attr2 = parse_quote!(#[repr(align(4))]);
let attr3 = parse_quote!(#[repr(align(2))]);
let repr = get_repr(&[attr1, attr2, attr3]).unwrap();
assert_eq!(repr, Representation { align: Some(4), ..Default::default() });
let attr = parse_quote!(#[repr(C, u8)]);
let repr = get_repr(&[attr]).unwrap();
assert_eq!(
repr,
Representation {
repr: Repr::CWithDiscriminant(IntegerRepr::U8),
..Default::default()
}
);
let attr = parse_quote!(#[repr(u8, C)]);
let repr = get_repr(&[attr]).unwrap();
assert_eq!(
repr,
Representation {
repr: Repr::CWithDiscriminant(IntegerRepr::U8),
..Default::default()
}
);
}
}
/// Determines the path under which the `bytemuck` crate is reachable.
///
/// Defaults to `::bytemuck`. A `#[bytemuck(crate = "path")]` attribute on the
/// input overrides it; if several are present the last one wins.
///
/// # Panics
///
/// Panics if a `#[bytemuck(...)]` attribute cannot be parsed, e.g. when the
/// `crate` value is not an unsuffixed string literal containing a valid path.
pub fn bytemuck_crate_name(input: &DeriveInput) -> TokenStream {
  const ATTR_NAME: &str = "crate";

  let mut crate_name = quote!(::bytemuck);

  let bytemuck_attrs =
    input.attrs.iter().filter(|attr| attr.path().is_ident("bytemuck"));
  for attr in bytemuck_attrs {
    attr
      .parse_nested_meta(|meta| {
        if !meta.path.is_ident(ATTR_NAME) {
          return Ok(());
        }

        let expr: syn::Expr = meta.value()?.parse()?;
        // Peel off any invisible groups wrapping the value.
        let mut value = &expr;
        while let syn::Expr::Group(group) = value {
          value = &group.expr;
        }

        let lit = match value {
          syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(lit), .. }) => lit,
          _ => bail!(
            "Expected bytemuck `crate` attribute to be a string: `crate = \"...\"`",
          ),
        };

        let suffix = lit.suffix();
        if !suffix.is_empty() {
          bail!(format!("Unexpected suffix `{}` on string literal", suffix))
        }

        match lit.parse::<syn::Path>() {
          Ok(path) => crate_name = path.into_token_stream(),
          Err(_) => {
            bail!(format!("Failed to parse path: {:?}", lit.value()))
          }
        }

        Ok(())
      })
      .unwrap();
  }

  crate_name
}
/// Doc string attached via `#[doc = ...]` to the helper types emitted by the
/// `CheckedBitPattern` code generators in this file.
const GENERATED_TYPE_DOCUMENTATION: &str =
" `bytemuck`-generated type for internal purposes only.";