Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add full support for generics #11

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ serde = { version = "1" }
serde-byte-array = "0.1.2"
serde_bytes = { version = "0.11.12", default-features = false }
serde_cbor = { version = "0.11.0" }
serde_test = "1.0.176"
90 changes: 70 additions & 20 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,33 @@ extern crate proc_macro;
mod parse;

use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::parse_macro_input;
use proc_macro2::{Ident, Span};
use quote::{format_ident, quote, ToTokens};
use syn::{parse_macro_input, Lifetime, LifetimeParam, TypeParamBound};

use crate::parse::Input;

/// Wrapper around syn structs that don't implement `Copy` but we want to use at multiple places
#[derive(Clone, Copy)]
struct CopyWrapper<'a, T>(&'a T);

impl<'a, T: ToTokens> ToTokens for CopyWrapper<'a, T> {
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
self.0.to_tokens(tokens)
}

fn to_token_stream(&self) -> proc_macro2::TokenStream {
self.0.to_token_stream()
}

fn into_token_stream(self) -> proc_macro2::TokenStream
where
Self: Sized,
{
self.0.to_token_stream()
}
}

fn serialize_fields(fields: &[parse::Field], offset: usize) -> Vec<proc_macro2::TokenStream> {
fields
.iter()
Expand Down Expand Up @@ -85,10 +107,17 @@ pub fn derive_serialize(input: TokenStream) -> TokenStream {
let ident = input.ident;
let num_fields = count_serialized_fields(&input.fields);
let serialize_fields = serialize_fields(&input.fields, input.attrs.offset);
let lifetimes = &input.lifetimes;
let (_, ty_generics, where_clause) = input.generics.split_for_impl();
let mut generics_cl = input.generics.clone();
generics_cl.type_params_mut().for_each(|t| {
t.bounds
.push_value(TypeParamBound::Verbatim(quote!(serde::Serialize)));
});
let (impl_generics, _, _) = generics_cl.split_for_impl();

TokenStream::from(quote! {
impl<#(#lifetimes),*> serde::Serialize for #ident<#(#lifetimes),*> {
#[automatically_derived]
impl #impl_generics serde::Serialize for #ident #ty_generics #where_clause {
fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
where
S: serde::Serializer
Expand Down Expand Up @@ -128,9 +157,8 @@ fn unwrap_expected_fields(fields: &[parse::Field]) -> Vec<proc_macro2::TokenStre
let #ident = #ident.ok_or_else(|| serde::de::Error::missing_field(#label))?;
}
} else {
// TODO: still confused here, but the tests pass ;)
quote! {
// let #ident = #ident.or(None);
let #ident = #ident.unwrap_or_default();
}
}
})
Expand Down Expand Up @@ -168,12 +196,6 @@ fn all_fields(fields: &[parse::Field]) -> Vec<proc_macro2::TokenStream> {
.collect()
}

fn de_lifetime(lifetimes: &[syn::Lifetime]) -> proc_macro2::TokenStream {
quote! {
'de: #(#lifetimes)+*
}
}

#[proc_macro_derive(DeserializeIndexed, attributes(serde, serde_indexed))]
pub fn derive_deserialize(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as Input);
Expand All @@ -182,8 +204,34 @@ pub fn derive_deserialize(input: TokenStream) -> TokenStream {
let unwrap_expected_fields = unwrap_expected_fields(&input.fields);
let match_fields = match_fields(&input.fields, input.attrs.offset);
let all_fields = all_fields(&input.fields);
let de_lifetime = de_lifetime(&input.lifetimes);
let lifetimes = input.lifetimes;

let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let ty_generics = CopyWrapper(&ty_generics);

let mut generics_cl = input.generics.clone();
generics_cl.params.insert(
0,
syn::GenericParam::Lifetime(LifetimeParam {
attrs: Vec::new(),
lifetime: Lifetime {
apostrophe: Span::call_site(),
ident: Ident::new("de", Span::call_site()),
},
colon_token: None,
bounds: input
.generics
.lifetimes()
.map(|l| l.lifetime.clone())
.collect(),
}),
);
generics_cl.type_params_mut().for_each(|t| {
t.bounds
.push_value(TypeParamBound::Verbatim(quote!(serde::Deserialize<'de>)));
});

let (impl_generics_with_de, _, _) = generics_cl.split_for_impl();
let impl_generics_with_de = CopyWrapper(&impl_generics_with_de);

let the_loop = if !input.fields.is_empty() {
// NB: In the previous "none_fields", we use the actual struct's
Expand All @@ -204,16 +252,17 @@ pub fn derive_deserialize(input: TokenStream) -> TokenStream {
quote! {}
};

TokenStream::from(quote! {
impl<#de_lifetime, #(#lifetimes),*> serde::Deserialize<'de> for #ident<#(#lifetimes),*> {
let res = quote! {
#[automatically_derived]
impl #impl_generics_with_de serde::Deserialize<'de> for #ident #ty_generics #where_clause {
fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct IndexedVisitor<#(#lifetimes),*>(core::marker::PhantomData<#(&#lifetimes)* ()>);
struct IndexedVisitor #impl_generics (core::marker::PhantomData<#ident #ty_generics>);

impl<#de_lifetime, #(#lifetimes),*> serde::de::Visitor<'de> for IndexedVisitor<#(#lifetimes),*> {
type Value = #ident<#(#lifetimes),*>;
impl #impl_generics_with_de serde::de::Visitor<'de> for IndexedVisitor #ty_generics {
type Value = #ident #ty_generics;

fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
formatter.write_str(stringify!(#ident))
Expand All @@ -236,5 +285,6 @@ pub fn derive_deserialize(input: TokenStream) -> TokenStream {
deserializer.deserialize_map(IndexedVisitor(Default::default()))
}
}
})
};
TokenStream::from(res)
}
12 changes: 3 additions & 9 deletions src/parse.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use proc_macro2::Span;
use syn::meta::ParseNestedMeta;
use syn::parse::{Error, Parse, ParseStream, Result};
use syn::{Data, DeriveInput, Fields, Ident, Lifetime, LitInt, LitStr, Token};
use syn::{Data, DeriveInput, Fields, Generics, Ident, LitInt, LitStr, Token};

pub struct Input {
pub ident: Ident,
pub attrs: StructAttrs,
pub fields: Vec<Field>,
pub lifetimes: Vec<Lifetime>,
pub generics: Generics,
}

#[derive(Default)]
Expand Down Expand Up @@ -57,10 +57,6 @@ fn parse_attrs(attrs: &Vec<syn::Attribute>) -> Result<StructAttrs> {
Ok(struct_attrs)
}

fn lifetimes(generics: &syn::Generics) -> Vec<Lifetime> {
generics.lifetimes().map(|l| l.lifetime.clone()).collect()
}

impl Parse for Input {
fn parse(input: ParseStream) -> Result<Self> {
let call_site = Span::call_site();
Expand All @@ -84,15 +80,13 @@ impl Parse for Input {

let fields = fields_from_ast(&syn_fields.named)?;

let lifetimes = lifetimes(&derive_input.generics);

//serde::internals::ast calls `fields_from_ast(cx, &fields.named, attrs, container_default)`

Ok(Input {
ident: derive_input.ident,
attrs,
fields,
lifetimes,
generics: derive_input.generics,
})
}
}
Expand Down
Loading