diff --git a/bin-proto-derive/src/attr.rs b/bin-proto-derive/src/attr.rs index ea27648..b731f5c 100644 --- a/bin-proto-derive/src/attr.rs +++ b/bin-proto-derive/src/attr.rs @@ -1,19 +1,25 @@ -use proc_macro2::{Span, TokenStream}; -use syn::{parenthesized, punctuated::Punctuated, spanned::Spanned, token::Comma, Error, Result}; +use std::fmt; + +use proc_macro2::TokenStream; +use syn::{parenthesized, punctuated::Punctuated, spanned::Spanned, Error, Result, Token}; #[derive(Default)] pub struct Attrs { pub discriminant_type: Option, pub discriminant: Option, - pub ctx: Option, + pub ctx: Option, pub ctx_generics: Option>, - pub ctx_bounds: Option>, pub write_value: Option, pub bits: Option, pub flexible_array_member: bool, pub tag: Option, } +pub enum Ctx { + Concrete(syn::Type), + Bounds(Vec), +} + pub enum Tag { External(syn::Expr), Prepend { @@ -22,179 +28,119 @@ pub enum Tag { }, } -impl Attrs { - #[allow(clippy::too_many_lines)] - pub fn validate_enum(&self, span: Span) -> Result<()> { - if self.discriminant_type.is_none() { - return Err(Error::new( - span, - "expected discriminant_type attribute for enum", - )); - } - if self.discriminant.is_some() { - return Err(Error::new( - span, - "unexpected discriminant attribute for enum", - )); - } - if self.ctx.is_some() && self.ctx_bounds.is_some() { - return Err(Error::new( - span, - "cannot specify ctx and ctx_bounds simultaneously", - )); - } - if self.write_value.is_some() { - return Err(Error::new( - span, - "unexpected write_value attribute for enum", - )); - } - if self.flexible_array_member { - return Err(Error::new( - span, - "unexpected flexible_array_member attribute for enum", - )); - } - if self.tag.is_some() { - return Err(Error::new(span, "unexpected tag attribute for enum")); - } - Ok(()) - } +#[derive(Clone, Copy, PartialEq, Eq)] +pub enum AttrKind { + Enum, + Struct, + Variant, + Field, +} - pub fn validate_variant(&self, span: Span) -> Result<()> { - if self.discriminant_type.is_some() { - return Err(Error::new( - span, - "unexpected discriminant_type attribute for variant", - )); - } - if self.ctx.is_some() { - return Err(Error::new(span, "unexpected ctx attribute for variant")); - } - if self.ctx_bounds.is_some() { - return Err(Error::new( - span, - "unexpected ctx_bounds attribute for variant", - )); - } - if self.write_value.is_some() { - return Err(Error::new( - span, - "unexpected write_value attribute for variant", - )); - } - if self.bits.is_some() { - return Err(Error::new(span, "unexpected bits attribute for variant")); +impl fmt::Display for AttrKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AttrKind::Enum => write!(f, "enum"), + AttrKind::Struct => write!(f, "struct"), + AttrKind::Variant => write!(f, "variant"), + AttrKind::Field => write!(f, "field"), } - if self.flexible_array_member { - return Err(Error::new( - span, - "unexpected flexible_array_member attribute for variant", - )); - } - if self.tag.is_some() { - return Err(Error::new(span, "unexpected tag attribute for variant")); - } - Ok(()) } +} - pub fn validate_field(&self, span: Span) -> Result<()> { - if self.discriminant_type.is_some() { - return Err(Error::new( - span, - "unexpected discriminant_type attribute for field", - )); - } - if self.discriminant.is_some() { - return Err(Error::new( - span, - "unexpected discriminant attribute for field", - )); - } - if self.ctx.is_some() { - return Err(Error::new(span, "unexpected ctx attribute for variant")); - } - if self.ctx_bounds.is_some() { - return Err(Error::new( - span, - "unexpected ctx_bounds attribute for variant", - )); - } - if [ - self.bits.is_some(), - self.flexible_array_member, - self.tag.is_some(), - ] - .iter() - .filter(|b| **b) - .count() - > 1 - { - return Err(Error::new( - span, - "bits, flexible_array_member, and tag are mutually-exclusive attributes", - )); +macro_rules! validate_attr_kind { + ($pat:pat, $kind:expr, $meta:expr, $attr:expr) => { + if let Some(kind) = $kind { + if !matches!(kind, $pat) { + return Err($meta.error(format!( + "attribute '{}' cannot be applied to {}", + $attr, kind + ))); + } } - Ok(()) - } + }; +} +impl Attrs { pub fn ctx_ty(&self) -> TokenStream { - self.ctx - .as_ref() - .map(|ctx| quote!(#ctx)) - .unwrap_or(quote!(__Ctx)) + if let Some(Ctx::Concrete(ctx)) = &self.ctx { + quote!(#ctx) + } else { + quote!(__Ctx) + } } -} -impl TryFrom<&[syn::Attribute]> for Attrs { - type Error = syn::Error; - - fn try_from(attrs: &[syn::Attribute]) -> Result { + #[allow(clippy::too_many_lines)] + pub fn for_kind(attrs: &[syn::Attribute], kind: Option) -> Result { let mut attribs = Attrs::default(); + let mut tag = None; let mut tag_type = None; let mut tag_value = None; + let mut ctx = None; + let mut ctx_bounds = None; + for attr in attrs { if attr.path().is_ident("protocol") { attr.parse_nested_meta(|meta| { if meta.path.is_ident("flexible_array_member") { + validate_attr_kind!(AttrKind::Field, kind, meta, "flexible_array_member"); attribs.flexible_array_member = true; } else if meta.path.is_ident("discriminant_type") { + validate_attr_kind!(AttrKind::Enum, kind, meta, "discriminant_type"); attribs.discriminant_type = Some(meta.value()?.parse()?); } else if meta.path.is_ident("discriminant") { + validate_attr_kind!(AttrKind::Variant, kind, meta, "discriminant"); attribs.discriminant = Some(meta.value()?.parse()?); } else if meta.path.is_ident("ctx") { - attribs.ctx = Some(meta.value()?.parse()?); + validate_attr_kind!(AttrKind::Enum | AttrKind::Struct, kind, meta, "ctx"); + ctx = Some(meta.value()?.parse()?); } else if meta.path.is_ident("ctx_generics") { + validate_attr_kind!( + AttrKind::Enum | AttrKind::Struct, + kind, + meta, + "ctx_generics" + ); let content; parenthesized!(content in meta.input); attribs.ctx_generics = Some( - Punctuated::::parse_separated_nonempty( + Punctuated::::parse_separated_nonempty( &content, )? .into_iter() .collect(), ); } else if meta.path.is_ident("ctx_bounds") { + validate_attr_kind!( + AttrKind::Enum | AttrKind::Struct, + kind, + meta, + "ctx_bounds" + ); let content; parenthesized!(content in meta.input); - attribs.ctx_bounds = Some( - Punctuated::::parse_separated_nonempty( + ctx_bounds = Some( + Punctuated::::parse_separated_nonempty( &content, )? .into_iter() .collect(), ); } else if meta.path.is_ident("bits") { + validate_attr_kind!(AttrKind::Enum | AttrKind::Field, kind, meta, "bits"); attribs.bits = Some(meta.value()?.parse()?); } else if meta.path.is_ident("write_value") { + validate_attr_kind!(AttrKind::Field, kind, meta, "write_value"); attribs.write_value = Some(meta.value()?.parse()?); } else if meta.path.is_ident("tag") { + validate_attr_kind!(AttrKind::Field, kind, meta, "tag"); tag = Some(meta.value()?.parse()?); } else if meta.path.is_ident("tag_type") { + validate_attr_kind!(AttrKind::Field, kind, meta, "tag_type"); tag_type = Some(meta.value()?.parse()?); } else if meta.path.is_ident("tag_value") { + validate_attr_kind!(AttrKind::Field, kind, meta, "tag_value"); tag_value = Some(meta.value()?.parse()?); } else { return Err(meta.error("unrecognized protocol")); @@ -216,6 +162,29 @@ impl TryFrom<&[syn::Attribute]> for Attrs { _ => return Err(Error::new(attrs[0].span(), "TODO")), } + match (ctx, ctx_bounds) { + (Some(ctx), None) => attribs.ctx = Some(Ctx::Concrete(ctx)), + (None, Some(ctx_bounds)) => attribs.ctx = Some(Ctx::Bounds(ctx_bounds)), + (None, None) => {} + _ => return Err(Error::new(attrs[0].span(), "TODO")), + } + + if [ + attribs.bits.is_some(), + attribs.flexible_array_member, + attribs.tag.is_some(), + ] + .iter() + .filter(|b| **b) + .count() + > 1 + { + return Err(Error::new( + attrs[0].span(), + "bits, flexible_array_member, and tag are mutually-exclusive attributes", + )); + } + Ok(attribs) } } diff --git a/bin-proto-derive/src/codegen/mod.rs b/bin-proto-derive/src/codegen/mod.rs index c7a3f0b..ce0fe9b 100644 --- a/bin-proto-derive/src/codegen/mod.rs +++ b/bin-proto-derive/src/codegen/mod.rs @@ -1,7 +1,7 @@ pub mod enums; pub mod trait_impl; -use crate::attr::{Attrs, Tag}; +use crate::attr::{AttrKind, Attrs, Tag}; use proc_macro2::TokenStream; use syn::{spanned::Spanned, Error}; @@ -54,13 +54,10 @@ fn read_named_fields(fields_named: &syn::FieldsNamed) -> (TokenStream, TokenStre } fn read(field: &syn::Field) -> TokenStream { - let attribs = match Attrs::try_from(field.attrs.as_slice()) { + let attribs = match Attrs::for_kind(field.attrs.as_slice(), Some(AttrKind::Field)) { Ok(attribs) => attribs, Err(e) => return e.to_compile_error(), }; - if let Err(e) = attribs.validate_field(field.span()) { - return e.to_compile_error(); - }; if let Some(field_width) = attribs.bits { quote!(::bin_proto::BitFieldRead::read(__io_reader, __byte_order, __ctx, #field_width)) @@ -95,7 +92,7 @@ fn read(field: &syn::Field) -> TokenStream { } fn write(field: &syn::Field, field_name: &TokenStream) -> TokenStream { - let attribs = match Attrs::try_from(field.attrs.as_slice()) { + let attribs = match Attrs::for_kind(field.attrs.as_slice(), Some(AttrKind::Field)) { Ok(attribs) => attribs, Err(e) => return e.to_compile_error(), }; diff --git a/bin-proto-derive/src/codegen/trait_impl.rs b/bin-proto-derive/src/codegen/trait_impl.rs index b8255be..6e8babe 100644 --- a/bin-proto-derive/src/codegen/trait_impl.rs +++ b/bin-proto-derive/src/codegen/trait_impl.rs @@ -1,4 +1,4 @@ -use crate::attr::Attrs; +use crate::attr::{Attrs, Ctx}; use proc_macro2::{Span, TokenStream}; use syn::{parse_quote, punctuated::Punctuated, Token}; @@ -17,7 +17,7 @@ pub fn impl_trait_for( typ: &TraitImplType, ) -> TokenStream { let name = &ast.ident; - let attribs = match Attrs::try_from(ast.attrs.as_slice()) { + let attribs = match Attrs::for_kind(ast.attrs.as_slice(), None) { Ok(attribs) => attribs, Err(e) => return e.to_compile_error(), }; @@ -59,23 +59,26 @@ pub fn impl_trait_for( | TraitImplType::TaggedRead(_) | TraitImplType::UntaggedWrite ) { - trait_generics.push(if let Some(ctx) = attribs.ctx { - if let Some(ctx_generics) = attribs.ctx_generics { - generics.params.extend(ctx_generics); - } + if let Some(ctx_generics) = attribs.ctx_generics { + generics.params.extend(ctx_generics); + } + + trait_generics.push(if let Some(Ctx::Concrete(ctx)) = attribs.ctx { quote!(#ctx) } else { let ident = syn::Ident::new("__Ctx", Span::call_site()); + let bounds = if let Some(Ctx::Bounds(bounds)) = attribs.ctx { + bounds.into_iter().collect() + } else { + Punctuated::new() + }; generics .params .push(syn::GenericParam::Type(syn::TypeParam { attrs: Vec::new(), ident: ident.clone(), colon_token: None, - bounds: attribs - .ctx_bounds - .map(|ctx_bounds| ctx_bounds.into_iter().collect()) - .unwrap_or_default(), + bounds, eq_token: None, default: None, })); diff --git a/bin-proto-derive/src/lib.rs b/bin-proto-derive/src/lib.rs index fea6615..304083d 100644 --- a/bin-proto-derive/src/lib.rs +++ b/bin-proto-derive/src/lib.rs @@ -8,7 +8,7 @@ mod attr; mod codegen; mod plan; -use attr::Attrs; +use attr::{AttrKind, Attrs}; use codegen::trait_impl::{impl_trait_for, TraitImplType}; use proc_macro2::TokenStream; use syn::parse_macro_input; @@ -46,7 +46,7 @@ fn impl_for_struct( strukt: &syn::DataStruct, protocol_type: Operation, ) -> TokenStream { - let attribs = match Attrs::try_from(ast.attrs.as_slice()) { + let attribs = match Attrs::for_kind(ast.attrs.as_slice(), Some(AttrKind::Struct)) { Ok(attribs) => attribs, Err(e) => return e.to_compile_error(), }; @@ -100,7 +100,7 @@ fn impl_for_enum( Ok(plan) => plan, Err(e) => return e.to_compile_error(), }; - let attribs = match Attrs::try_from(ast.attrs.as_slice()) { + let attribs = match Attrs::for_kind(ast.attrs.as_slice(), Some(AttrKind::Enum)) { Ok(attribs) => attribs, Err(e) => return e.to_compile_error(), }; diff --git a/bin-proto-derive/src/plan.rs b/bin-proto-derive/src/plan.rs index 4a73b65..3517893 100644 --- a/bin-proto-derive/src/plan.rs +++ b/bin-proto-derive/src/plan.rs @@ -1,4 +1,4 @@ -use crate::attr::Attrs; +use crate::attr::{AttrKind, Attrs}; use syn::{spanned::Spanned, Error, Result}; pub struct Enum { @@ -14,8 +14,7 @@ pub struct EnumVariant { impl Enum { pub fn try_new(ast: &syn::DeriveInput, e: &syn::DataEnum) -> Result { - let attrs = Attrs::try_from(ast.attrs.as_slice())?; - attrs.validate_enum(ast.span())?; + let attrs = Attrs::for_kind(ast.attrs.as_slice(), Some(AttrKind::Enum))?; let plan = Self { discriminant_ty: attrs.discriminant_type.unwrap(), @@ -23,8 +22,7 @@ impl Enum { .variants .iter() .map(|variant| { - let attrs = Attrs::try_from(variant.attrs.as_slice())?; - attrs.validate_variant(variant.span())?; + let attrs = Attrs::for_kind(variant.attrs.as_slice(), Some(AttrKind::Variant))?; let discriminant_value = match variant.discriminant.as_ref().map(|a| &a.1) { Some(expr_lit) => expr_lit.clone(), diff --git a/bin-proto/Cargo.toml b/bin-proto/Cargo.toml index 9201485..d45b452 100644 --- a/bin-proto/Cargo.toml +++ b/bin-proto/Cargo.toml @@ -22,4 +22,4 @@ derive = ["bin-proto-derive"] [dependencies] bin-proto-derive = { version = "0.6.0", path = "../bin-proto-derive", optional = true } bitstream-io = "2.3.0" -thiserror = "2.0.2" +thiserror = "2.0.3" diff --git a/bin-proto/tests/ctx.rs b/bin-proto/tests/ctx.rs index 96a7f8c..06b825a 100644 --- a/bin-proto/tests/ctx.rs +++ b/bin-proto/tests/ctx.rs @@ -1,3 +1,5 @@ +use std::marker::PhantomData; + use bin_proto::{ByteOrder, ProtocolRead, ProtocolWrite}; trait Boolean { @@ -10,6 +12,12 @@ impl Boolean for bool { } } +trait TraitWithGeneric<'a, T> +where + T: Boolean, +{ +} + trait CtxTrait { fn call(&mut self); } @@ -73,6 +81,14 @@ struct CtxCheckStructWrapperWithGenericsConcreteBool(CtxCheck); #[protocol(ctx = CtxStructWithGenerics<'a, T>, ctx_generics('a, T: Boolean))] struct CtxCheckStructWrapperWithGenerics(CtxCheck); +#[derive(Debug, ProtocolRead, ProtocolWrite)] +#[protocol(ctx_bounds(TraitWithGeneric<'a, bool>, CtxTrait), ctx_generics('a))] +struct CtxCheckBoundsWithGenericsConcreteBool(CtxCheck); + +#[derive(Debug, ProtocolRead, ProtocolWrite)] +#[protocol(ctx_bounds(TraitWithGeneric<'a, T>, CtxTrait), ctx_generics('a))] +struct CtxCheckBoundsWithGenerics(CtxCheck, PhantomData); + #[derive(Debug, ProtocolRead, ProtocolWrite)] #[protocol(ctx_bounds(CtxTrait))] struct CtxCheckTraitWrapper(CtxCheck);