Skip to content

Commit

Permalink
Improve attribute validation.
Browse files Browse the repository at this point in the history
  • Loading branch information
wojciech-graj committed Nov 10, 2024
1 parent 30d9501 commit 51dc21f
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 152 deletions.
223 changes: 96 additions & 127 deletions bin-proto-derive/src/attr.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
use proc_macro2::{Span, TokenStream};
use syn::{parenthesized, punctuated::Punctuated, spanned::Spanned, token::Comma, Error, Result};
use std::fmt;

use proc_macro2::TokenStream;
use syn::{parenthesized, punctuated::Punctuated, spanned::Spanned, Error, Result, Token};

#[derive(Default)]
pub struct Attrs {
pub discriminant_type: Option<syn::Type>,
pub discriminant: Option<syn::Expr>,
pub ctx: Option<syn::Type>,
pub ctx: Option<Ctx>,
pub ctx_generics: Option<Vec<syn::GenericParam>>,
pub ctx_bounds: Option<Vec<syn::TypeParamBound>>,
pub write_value: Option<syn::Expr>,
pub bits: Option<syn::Expr>,
pub flexible_array_member: bool,
pub tag: Option<Tag>,
}

pub enum Ctx {
Concrete(syn::Type),
Bounds(Vec<syn::TypeParamBound>),
}

pub enum Tag {
External(syn::Expr),
Prepend {
Expand All @@ -22,179 +28,119 @@ pub enum Tag {
},
}

impl Attrs {
#[allow(clippy::too_many_lines)]
pub fn validate_enum(&self, span: Span) -> Result<()> {
if self.discriminant_type.is_none() {
return Err(Error::new(
span,
"expected discriminant_type attribute for enum",
));
}
if self.discriminant.is_some() {
return Err(Error::new(
span,
"unexpected discriminant attribute for enum",
));
}
if self.ctx.is_some() && self.ctx_bounds.is_some() {
return Err(Error::new(
span,
"cannot specify ctx and ctx_bounds simultaneously",
));
}
if self.write_value.is_some() {
return Err(Error::new(
span,
"unexpected write_value attribute for enum",
));
}
if self.flexible_array_member {
return Err(Error::new(
span,
"unexpected flexible_array_member attribute for enum",
));
}
if self.tag.is_some() {
return Err(Error::new(span, "unexpected tag attribute for enum"));
}
Ok(())
}
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum AttrKind {
Enum,
Struct,
Variant,
Field,
}

pub fn validate_variant(&self, span: Span) -> Result<()> {
if self.discriminant_type.is_some() {
return Err(Error::new(
span,
"unexpected discriminant_type attribute for variant",
));
}
if self.ctx.is_some() {
return Err(Error::new(span, "unexpected ctx attribute for variant"));
}
if self.ctx_bounds.is_some() {
return Err(Error::new(
span,
"unexpected ctx_bounds attribute for variant",
));
}
if self.write_value.is_some() {
return Err(Error::new(
span,
"unexpected write_value attribute for variant",
));
}
if self.bits.is_some() {
return Err(Error::new(span, "unexpected bits attribute for variant"));
impl fmt::Display for AttrKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
AttrKind::Enum => write!(f, "enum"),
AttrKind::Struct => write!(f, "struct"),
AttrKind::Variant => write!(f, "variant"),
AttrKind::Field => write!(f, "field"),
}
if self.flexible_array_member {
return Err(Error::new(
span,
"unexpected flexible_array_member attribute for variant",
));
}
if self.tag.is_some() {
return Err(Error::new(span, "unexpected tag attribute for variant"));
}
Ok(())
}
}

pub fn validate_field(&self, span: Span) -> Result<()> {
if self.discriminant_type.is_some() {
return Err(Error::new(
span,
"unexpected discriminant_type attribute for field",
));
}
if self.discriminant.is_some() {
return Err(Error::new(
span,
"unexpected discriminant attribute for field",
));
}
if self.ctx.is_some() {
return Err(Error::new(span, "unexpected ctx attribute for variant"));
}
if self.ctx_bounds.is_some() {
return Err(Error::new(
span,
"unexpected ctx_bounds attribute for variant",
));
}
if [
self.bits.is_some(),
self.flexible_array_member,
self.tag.is_some(),
]
.iter()
.filter(|b| **b)
.count()
> 1
{
return Err(Error::new(
span,
"bits, flexible_array_member, and tag are mutually-exclusive attributes",
));
macro_rules! validate_attr_kind {
($pat:pat, $kind:expr, $meta:expr, $attr:expr) => {
if let Some(kind) = $kind {
if !matches!(kind, $pat) {
return Err($meta.error(format!(
"attribute '{}' cannot be applied to {}",
$attr, kind
)));
}
}
Ok(())
}
};
}

impl Attrs {
pub fn ctx_ty(&self) -> TokenStream {
self.ctx
.as_ref()
.map(|ctx| quote!(#ctx))
.unwrap_or(quote!(__Ctx))
if let Some(Ctx::Concrete(ctx)) = &self.ctx {
quote!(#ctx)
} else {
quote!(__Ctx)
}
}
}

impl TryFrom<&[syn::Attribute]> for Attrs {
type Error = syn::Error;

fn try_from(attrs: &[syn::Attribute]) -> Result<Self> {
#[allow(clippy::too_many_lines)]
pub fn for_kind(attrs: &[syn::Attribute], kind: Option<AttrKind>) -> Result<Self> {
let mut attribs = Attrs::default();

let mut tag = None;
let mut tag_type = None;
let mut tag_value = None;

let mut ctx = None;
let mut ctx_bounds = None;

for attr in attrs {
if attr.path().is_ident("protocol") {
attr.parse_nested_meta(|meta| {
if meta.path.is_ident("flexible_array_member") {
validate_attr_kind!(AttrKind::Field, kind, meta, "flexible_array_member");
attribs.flexible_array_member = true;
} else if meta.path.is_ident("discriminant_type") {
validate_attr_kind!(AttrKind::Enum, kind, meta, "discriminant_type");
attribs.discriminant_type = Some(meta.value()?.parse()?);
} else if meta.path.is_ident("discriminant") {
validate_attr_kind!(AttrKind::Variant, kind, meta, "discriminant");
attribs.discriminant = Some(meta.value()?.parse()?);
} else if meta.path.is_ident("ctx") {
attribs.ctx = Some(meta.value()?.parse()?);
validate_attr_kind!(AttrKind::Enum | AttrKind::Struct, kind, meta, "ctx");
ctx = Some(meta.value()?.parse()?);
} else if meta.path.is_ident("ctx_generics") {
validate_attr_kind!(
AttrKind::Enum | AttrKind::Struct,
kind,
meta,
"ctx_generics"
);
let content;
parenthesized!(content in meta.input);
attribs.ctx_generics = Some(
Punctuated::<syn::GenericParam, Comma>::parse_separated_nonempty(
Punctuated::<syn::GenericParam, Token![,]>::parse_separated_nonempty(
&content,
)?
.into_iter()
.collect(),
);
} else if meta.path.is_ident("ctx_bounds") {
validate_attr_kind!(
AttrKind::Enum | AttrKind::Struct,
kind,
meta,
"ctx_bounds"
);
let content;
parenthesized!(content in meta.input);
attribs.ctx_bounds = Some(
Punctuated::<syn::TypeParamBound, Comma>::parse_separated_nonempty(
ctx_bounds = Some(
Punctuated::<syn::TypeParamBound, Token![,]>::parse_separated_nonempty(
&content,
)?
.into_iter()
.collect(),
);
} else if meta.path.is_ident("bits") {
validate_attr_kind!(AttrKind::Enum | AttrKind::Field, kind, meta, "bits");
attribs.bits = Some(meta.value()?.parse()?);
} else if meta.path.is_ident("write_value") {
validate_attr_kind!(AttrKind::Field, kind, meta, "write_value");
attribs.write_value = Some(meta.value()?.parse()?);
} else if meta.path.is_ident("tag") {
validate_attr_kind!(AttrKind::Field, kind, meta, "tag");
tag = Some(meta.value()?.parse()?);
} else if meta.path.is_ident("tag_type") {
validate_attr_kind!(AttrKind::Field, kind, meta, "tag_type");
tag_type = Some(meta.value()?.parse()?);
} else if meta.path.is_ident("tag_value") {
validate_attr_kind!(AttrKind::Field, kind, meta, "tag_value");
tag_value = Some(meta.value()?.parse()?);
} else {
return Err(meta.error("unrecognized protocol"));
Expand All @@ -216,6 +162,29 @@ impl TryFrom<&[syn::Attribute]> for Attrs {
_ => return Err(Error::new(attrs[0].span(), "TODO")),
}

match (ctx, ctx_bounds) {
(Some(ctx), None) => attribs.ctx = Some(Ctx::Concrete(ctx)),
(None, Some(ctx_bounds)) => attribs.ctx = Some(Ctx::Bounds(ctx_bounds)),
(None, None) => {}
_ => return Err(Error::new(attrs[0].span(), "TODO")),
}

if [
attribs.bits.is_some(),
attribs.flexible_array_member,
attribs.tag.is_some(),
]
.iter()
.filter(|b| **b)
.count()
> 1
{
return Err(Error::new(
attrs[0].span(),
"bits, flexible_array_member, and tag are mutually-exclusive attributes",
));
}

Ok(attribs)
}
}
9 changes: 3 additions & 6 deletions bin-proto-derive/src/codegen/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
pub mod enums;
pub mod trait_impl;

use crate::attr::{Attrs, Tag};
use crate::attr::{AttrKind, Attrs, Tag};
use proc_macro2::TokenStream;
use syn::{spanned::Spanned, Error};

Expand Down Expand Up @@ -54,13 +54,10 @@ fn read_named_fields(fields_named: &syn::FieldsNamed) -> (TokenStream, TokenStre
}

fn read(field: &syn::Field) -> TokenStream {
let attribs = match Attrs::try_from(field.attrs.as_slice()) {
let attribs = match Attrs::for_kind(field.attrs.as_slice(), Some(AttrKind::Field)) {
Ok(attribs) => attribs,
Err(e) => return e.to_compile_error(),
};
if let Err(e) = attribs.validate_field(field.span()) {
return e.to_compile_error();
};

if let Some(field_width) = attribs.bits {
quote!(::bin_proto::BitFieldRead::read(__io_reader, __byte_order, __ctx, #field_width))
Expand Down Expand Up @@ -95,7 +92,7 @@ fn read(field: &syn::Field) -> TokenStream {
}

fn write(field: &syn::Field, field_name: &TokenStream) -> TokenStream {
let attribs = match Attrs::try_from(field.attrs.as_slice()) {
let attribs = match Attrs::for_kind(field.attrs.as_slice(), Some(AttrKind::Field)) {
Ok(attribs) => attribs,
Err(e) => return e.to_compile_error(),
};
Expand Down
23 changes: 13 additions & 10 deletions bin-proto-derive/src/codegen/trait_impl.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::attr::Attrs;
use crate::attr::{Attrs, Ctx};

use proc_macro2::{Span, TokenStream};
use syn::{parse_quote, punctuated::Punctuated, Token};
Expand All @@ -17,7 +17,7 @@ pub fn impl_trait_for(
typ: &TraitImplType,
) -> TokenStream {
let name = &ast.ident;
let attribs = match Attrs::try_from(ast.attrs.as_slice()) {
let attribs = match Attrs::for_kind(ast.attrs.as_slice(), None) {
Ok(attribs) => attribs,
Err(e) => return e.to_compile_error(),
};
Expand Down Expand Up @@ -59,23 +59,26 @@ pub fn impl_trait_for(
| TraitImplType::TaggedRead(_)
| TraitImplType::UntaggedWrite
) {
trait_generics.push(if let Some(ctx) = attribs.ctx {
if let Some(ctx_generics) = attribs.ctx_generics {
generics.params.extend(ctx_generics);
}
if let Some(ctx_generics) = attribs.ctx_generics {
generics.params.extend(ctx_generics);
}

trait_generics.push(if let Some(Ctx::Concrete(ctx)) = attribs.ctx {
quote!(#ctx)
} else {
let ident = syn::Ident::new("__Ctx", Span::call_site());
let bounds = if let Some(Ctx::Bounds(bounds)) = attribs.ctx {
bounds.into_iter().collect()
} else {
Punctuated::new()
};
generics
.params
.push(syn::GenericParam::Type(syn::TypeParam {
attrs: Vec::new(),
ident: ident.clone(),
colon_token: None,
bounds: attribs
.ctx_bounds
.map(|ctx_bounds| ctx_bounds.into_iter().collect())
.unwrap_or_default(),
bounds,
eq_token: None,
default: None,
}));
Expand Down
Loading

0 comments on commit 51dc21f

Please sign in to comment.