Skip to content

Commit

Permalink
Expand Yoke to cover cases where attrs can borrow
Browse files Browse the repository at this point in the history
  • Loading branch information
Manishearth committed Oct 5, 2023
1 parent d2669e1 commit 9621292
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 10 deletions.
67 changes: 61 additions & 6 deletions utils/yoke/derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ use proc_macro::TokenStream;
use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::quote;
use syn::spanned::Spanned;
use syn::{parse_macro_input, parse_quote, DeriveInput, Ident, Lifetime, Type, WherePredicate};
use syn::punctuated::Punctuated;
use syn::{parse_macro_input, Token, MetaList, parse_quote, DeriveInput, Ident, Lifetime, Type, WherePredicate};
use synstructure::Structure;
use std::collections::{HashMap, HashSet};

mod visitor;

Expand All @@ -33,6 +35,28 @@ pub fn yokeable_derive(input: TokenStream) -> TokenStream {
TokenStream::from(yokeable_derive_impl(&input))
}


// Collects all idents from #[yoke(may_borrow(A, B, C, D))]
// needed since #[yoke(may_borrow)] doesn't work yet
// (https://github.com/rust-lang/rust/issues/114393)
fn get_may_borrow_attr(attrs: &[syn::Attribute]) -> Result<HashSet<Ident>, Span> {
let mut params = HashSet::new();
for attr in attrs {
if let Ok(list) = attr.parse_args::<MetaList>() {
if list.path.is_ident("may_borrow") {
if let Ok(list) =
list.parse_args_with(Punctuated::<Ident, Token![,]>::parse_terminated)
{
params.extend(list.into_iter())
} else {
return Err(attr.span());
}
}
}
}
Ok(params)
}

fn yokeable_derive_impl(input: &DeriveInput) -> TokenStream2 {
let tybounds = input
.generics
Expand All @@ -56,7 +80,33 @@ fn yokeable_derive_impl(input: &DeriveInput) -> TokenStream2 {
.map(|ty| parse_quote!(#ty: 'static))
.collect();
let lts = input.generics.lifetimes().count();
if lts == 0 {

let may_borrow_attrs = match get_may_borrow_attr(&input.attrs) {
Ok(mb) => mb,
Err(span) => {
return syn::Error::new(
span,
"#[yoke(may_borrow)] on the struct takes in a comma separated list of type parameters, like so: `#[zerofrom(may_borrow(A, B, C, D)]`",
).to_compile_error();
}
};

let generics_env: HashMap<Ident, bool> = tybounds
.iter()
.map(|param| {
(
param.ident.clone(),
// Can't check
may_borrow_attrs.contains(&param.ident),
)
})
.collect();

// Do any of the generics potentially borrow?
let generics_may_borrow = generics_env.values().any(|x| *x);


if lts == 0 && !generics_may_borrow {
let name = &input.ident;
quote! {
// This is safe because there are no lifetime parameters.
Expand All @@ -83,7 +133,7 @@ fn yokeable_derive_impl(input: &DeriveInput) -> TokenStream2 {
}
}
} else {
if lts != 1 {
if lts > 1 {
return syn::Error::new(
input.generics.span(),
"derive(Yokeable) cannot have multiple lifetime parameters",
Expand All @@ -99,9 +149,14 @@ fn yokeable_derive_impl(input: &DeriveInput) -> TokenStream2 {
}
false
});
// Due to the possibility of generics_may_borrow, we might reach here with no lifetimes on self
let (maybe_static_lifetime, maybe_a_lifetime) = if lts == 0 {
(quote!(), quote!())
} else {
(quote!('static,), quote!('a,))
};
if manual_covariance {
let mut structure = Structure::new(input);
let generics_env = typarams.iter().cloned().collect();
let static_bounds: Vec<WherePredicate> = typarams
.iter()
.map(|ty| parse_quote!(#ty: 'static))
Expand Down Expand Up @@ -169,10 +224,10 @@ fn yokeable_derive_impl(input: &DeriveInput) -> TokenStream2 {
}
});
return quote! {
unsafe impl<'a, #(#tybounds),*> yoke::Yokeable<'a> for #name<'static, #(#typarams),*>
unsafe impl<'a, #(#tybounds),*> yoke::Yokeable<'a> for #name<#maybe_static_lifetime #(#typarams),*>
where #(#static_bounds,)*
#(#yoke_bounds,)* {
type Output = #name<'a, #(#typarams),*>;
type Output = #name<#maybe_a_lifetime #(#typarams),*>;
#[inline]
fn transform(&'a self) -> &'a Self::Output {
// These are just type asserts, we don't need them for anything
Expand Down
11 changes: 7 additions & 4 deletions utils/yoke/derive/src/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@

//! Visitor for determining whether a type has type and non-static lifetime parameters
use std::collections::HashSet;
use std::collections::HashMap;
use syn::visit::{visit_lifetime, visit_type, visit_type_path, Visit};
use syn::{Ident, Lifetime, Type, TypePath};

struct TypeVisitor<'a> {
/// The type parameters in scope
typarams: &'a HashSet<Ident>,
typarams: &'a HashMap<Ident, bool>,
/// Whether we found a type parameter
found_typarams: bool,
/// Whether we found a non-'static lifetime parameter
Expand All @@ -29,8 +29,11 @@ impl<'a, 'ast> Visit<'ast> for TypeVisitor<'a> {
// generics in ty.path because the visitor will eventually visit those
// types on its own
if let Some(ident) = ty.path.get_ident() {
if self.typarams.contains(ident) {
if let Some(maybe_lt) = self.typarams.get(ident) {
self.found_typarams = true;
if *maybe_lt {
self.found_lifetimes = true;
}
}
}

Expand All @@ -40,7 +43,7 @@ impl<'a, 'ast> Visit<'ast> for TypeVisitor<'a> {

/// Checks if a type has type or lifetime parameters, given the local context of
/// named type parameters. Returns (has_type_params, has_lifetime_params)
pub fn check_type_for_parameters(ty: &Type, typarams: &HashSet<Ident>) -> (bool, bool) {
pub fn check_type_for_parameters(ty: &Type, typarams: &HashMap<Ident, bool>) -> (bool, bool) {
let mut visit = TypeVisitor {
typarams,
found_typarams: false,
Expand Down

0 comments on commit 9621292

Please sign in to comment.