diff --git a/utils/yoke/derive/src/lib.rs b/utils/yoke/derive/src/lib.rs index 146f6cb5443..87e911b284f 100644 --- a/utils/yoke/derive/src/lib.rs +++ b/utils/yoke/derive/src/lib.rs @@ -8,8 +8,10 @@ use proc_macro::TokenStream; use proc_macro2::{Span, TokenStream as TokenStream2}; use quote::quote; use syn::spanned::Spanned; -use syn::{parse_macro_input, parse_quote, DeriveInput, Ident, Lifetime, Type, WherePredicate}; +use syn::punctuated::Punctuated; +use syn::{parse_macro_input, Token, MetaList, parse_quote, DeriveInput, Ident, Lifetime, Type, WherePredicate}; use synstructure::Structure; +use std::collections::{HashMap, HashSet}; mod visitor; @@ -33,6 +35,28 @@ pub fn yokeable_derive(input: TokenStream) -> TokenStream { TokenStream::from(yokeable_derive_impl(&input)) } + +// Collects all idents from #[yoke(may_borrow(A, B, C, D))] +// needed since #[yoke(may_borrow)] doesn't work yet +// (https://github.com/rust-lang/rust/issues/114393) +fn get_may_borrow_attr(attrs: &[syn::Attribute]) -> Result, Span> { + let mut params = HashSet::new(); + for attr in attrs { + if let Ok(list) = attr.parse_args::() { + if list.path.is_ident("may_borrow") { + if let Ok(list) = + list.parse_args_with(Punctuated::::parse_terminated) + { + params.extend(list.into_iter()) + } else { + return Err(attr.span()); + } + } + } + } + Ok(params) +} + fn yokeable_derive_impl(input: &DeriveInput) -> TokenStream2 { let tybounds = input .generics @@ -56,7 +80,33 @@ fn yokeable_derive_impl(input: &DeriveInput) -> TokenStream2 { .map(|ty| parse_quote!(#ty: 'static)) .collect(); let lts = input.generics.lifetimes().count(); - if lts == 0 { + + let may_borrow_attrs = match get_may_borrow_attr(&input.attrs) { + Ok(mb) => mb, + Err(span) => { + return syn::Error::new( + span, + "#[yoke(may_borrow)] on the struct takes in a comma separated list of type parameters, like so: `#[zerofrom(may_borrow(A, B, C, D)]`", + ).to_compile_error(); + } + }; + + let generics_env: HashMap = tybounds + .iter() + .map(|param| { + ( + param.ident.clone(), + // Can't check + may_borrow_attrs.contains(¶m.ident), + ) + }) + .collect(); + + // Do any of the generics potentially borrow? + let generics_may_borrow = generics_env.values().any(|x| *x); + + + if lts == 0 && !generics_may_borrow { let name = &input.ident; quote! { // This is safe because there are no lifetime parameters. @@ -83,7 +133,7 @@ fn yokeable_derive_impl(input: &DeriveInput) -> TokenStream2 { } } } else { - if lts != 1 { + if lts > 1 { return syn::Error::new( input.generics.span(), "derive(Yokeable) cannot have multiple lifetime parameters", @@ -99,9 +149,14 @@ fn yokeable_derive_impl(input: &DeriveInput) -> TokenStream2 { } false }); + // Due to the possibility of generics_may_borrow, we might reach here with no lifetimes on self + let (maybe_static_lifetime, maybe_a_lifetime) = if lts == 0 { + (quote!(), quote!()) + } else { + (quote!('static,), quote!('a,)) + }; if manual_covariance { let mut structure = Structure::new(input); - let generics_env = typarams.iter().cloned().collect(); let static_bounds: Vec = typarams .iter() .map(|ty| parse_quote!(#ty: 'static)) @@ -169,10 +224,10 @@ fn yokeable_derive_impl(input: &DeriveInput) -> TokenStream2 { } }); return quote! { - unsafe impl<'a, #(#tybounds),*> yoke::Yokeable<'a> for #name<'static, #(#typarams),*> + unsafe impl<'a, #(#tybounds),*> yoke::Yokeable<'a> for #name<#maybe_static_lifetime #(#typarams),*> where #(#static_bounds,)* #(#yoke_bounds,)* { - type Output = #name<'a, #(#typarams),*>; + type Output = #name<#maybe_a_lifetime #(#typarams),*>; #[inline] fn transform(&'a self) -> &'a Self::Output { // These are just type asserts, we don't need them for anything diff --git a/utils/yoke/derive/src/visitor.rs b/utils/yoke/derive/src/visitor.rs index daca1da13fd..18fff4b3e78 100644 --- a/utils/yoke/derive/src/visitor.rs +++ b/utils/yoke/derive/src/visitor.rs @@ -4,13 +4,13 @@ //! Visitor for determining whether a type has type and non-static lifetime parameters -use std::collections::HashSet; +use std::collections::HashMap; use syn::visit::{visit_lifetime, visit_type, visit_type_path, Visit}; use syn::{Ident, Lifetime, Type, TypePath}; struct TypeVisitor<'a> { /// The type parameters in scope - typarams: &'a HashSet, + typarams: &'a HashMap, /// Whether we found a type parameter found_typarams: bool, /// Whether we found a non-'static lifetime parameter @@ -29,8 +29,11 @@ impl<'a, 'ast> Visit<'ast> for TypeVisitor<'a> { // generics in ty.path because the visitor will eventually visit those // types on its own if let Some(ident) = ty.path.get_ident() { - if self.typarams.contains(ident) { + if let Some(maybe_lt) = self.typarams.get(ident) { self.found_typarams = true; + if *maybe_lt { + self.found_lifetimes = true; + } } } @@ -40,7 +43,7 @@ impl<'a, 'ast> Visit<'ast> for TypeVisitor<'a> { /// Checks if a type has type or lifetime parameters, given the local context of /// named type parameters. Returns (has_type_params, has_lifetime_params) -pub fn check_type_for_parameters(ty: &Type, typarams: &HashSet) -> (bool, bool) { +pub fn check_type_for_parameters(ty: &Type, typarams: &HashMap) -> (bool, bool) { let mut visit = TypeVisitor { typarams, found_typarams: false,