Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend the WorldQuery macro to tuple structs #8119

Merged
merged 12 commits into from
Apr 4, 2023
105 changes: 71 additions & 34 deletions crates/bevy_ecs/macros/src/fetch.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use bevy_macro_utils::ensure_no_collision;
use proc_macro::TokenStream;
use proc_macro2::{Ident, Span};
use quote::{quote, ToTokens};
use quote::{format_ident, quote, ToTokens};
use syn::{
parse::{Parse, ParseStream},
parse_macro_input, parse_quote,
punctuated::Punctuated,
Attribute, Data, DataStruct, DeriveInput, Field, Fields,
Attribute, Data, DataStruct, DeriveInput, Field, Index,
};

use crate::bevy_ecs_path;
Expand Down Expand Up @@ -116,34 +116,49 @@ pub fn derive_world_query_impl(input: TokenStream) -> TokenStream {
fetch_struct_name.clone()
};

let marker_name =
ensure_no_collision(format_ident!("_world_query_derive_marker"), tokens.clone());

// Generate a name for the state struct that doesn't conflict
// with the struct definition.
let state_struct_name = Ident::new(&format!("{struct_name}State"), Span::call_site());
let state_struct_name = ensure_no_collision(state_struct_name, tokens);

let fields = match &ast.data {
Data::Struct(DataStruct {
fields: Fields::Named(fields),
..
}) => &fields.named,
_ => panic!("Expected a struct with named fields"),
let Data::Struct(DataStruct { fields, .. }) = &ast.data else {
return syn::Error::new(
Span::call_site(),
"#[derive(WorldQuery)]` only supports structs",
)
.into_compile_error()
.into()
};

let mut field_attrs = Vec::new();
let mut field_visibilities = Vec::new();
let mut field_idents = Vec::new();
let mut named_field_idents = Vec::new();
let mut field_types = Vec::new();
let mut read_only_field_types = Vec::new();

for field in fields {
for (i, field) in fields.iter().enumerate() {
let attrs = match read_world_query_field_info(field) {
Ok(WorldQueryFieldInfo { attrs }) => attrs,
Err(e) => return e.into_compile_error().into(),
};

let named_field_ident = field
.ident
.as_ref()
.cloned()
.unwrap_or_else(|| format_ident!("f{i}"));
let i = Index::from(i);
let field_ident = field
.ident
.as_ref()
.map_or(quote! { #i }, |i| quote! { #i });
field_idents.push(field_ident);
named_field_idents.push(named_field_ident);
field_attrs.push(attrs);
field_visibilities.push(field.vis.clone());
field_idents.push(field.ident.as_ref().unwrap().clone());
let field_ty = field.ty.clone();
field_types.push(quote!(#field_ty));
read_only_field_types.push(quote!(<#field_ty as #path::query::WorldQuery>::ReadOnly));
Expand Down Expand Up @@ -176,15 +191,34 @@ pub fn derive_world_query_impl(input: TokenStream) -> TokenStream {
&field_types
};

let item_struct = quote! {
#derive_macro_call
#[doc = "Automatically generated [`WorldQuery`] item type for [`"]
#[doc = stringify!(#struct_name)]
#[doc = "`], returned when iterating over query results."]
#[automatically_derived]
#visibility struct #item_struct_name #user_impl_generics_with_world #user_where_clauses_with_world {
#(#(#field_attrs)* #field_visibilities #field_idents: <#field_types as #path::query::WorldQuery>::Item<'__w>,)*
}
let item_struct = match fields {
syn::Fields::Named(_) => quote! {
#derive_macro_call
#[doc = "Automatically generated [`WorldQuery`] item type for [`"]
#[doc = stringify!(#struct_name)]
#[doc = "`], returned when iterating over query results."]
#[automatically_derived]
#visibility struct #item_struct_name #user_impl_generics_with_world #user_where_clauses_with_world {
#(#(#field_attrs)* #field_visibilities #field_idents: <#field_types as #path::query::WorldQuery>::Item<'__w>,)*
}
},
syn::Fields::Unnamed(_) => quote! {
#derive_macro_call
#[doc = "Automatically generated [`WorldQuery`] item type for [`"]
#[doc = stringify!(#struct_name)]
#[doc = "`], returned when iterating over query results."]
#[automatically_derived]
#visibility struct #item_struct_name #user_impl_generics_with_world #user_where_clauses_with_world(
#( #field_visibilities <#field_types as #path::query::WorldQuery>::Item<'__w>, )*
);
},
syn::Fields::Unit => quote! {
#[doc = "Automatically generated [`WorldQuery`] item type for [`"]
#[doc = stringify!(#struct_name)]
#[doc = "`], returned when iterating over query results."]
#[automatically_derived]
#visibility type #item_struct_name #user_ty_generics_with_world = #struct_name #user_ty_generics;
},
};

let query_impl = quote! {
Expand All @@ -194,7 +228,8 @@ pub fn derive_world_query_impl(input: TokenStream) -> TokenStream {
#[doc = "`], used to define the world data accessed by this query."]
#[automatically_derived]
#visibility struct #fetch_struct_name #user_impl_generics_with_world #user_where_clauses_with_world {
#(#field_idents: <#field_types as #path::query::WorldQuery>::Fetch<'__w>,)*
#(#named_field_idents: <#field_types as #path::query::WorldQuery>::Fetch<'__w>,)*
#marker_name: &'__w (),
}

// SAFETY: `update_component_access` and `update_archetype_component_access` are called on every field
Expand Down Expand Up @@ -223,14 +258,15 @@ pub fn derive_world_query_impl(input: TokenStream) -> TokenStream {
_this_run: #path::component::Tick,
) -> <Self as #path::query::WorldQuery>::Fetch<'__w> {
#fetch_struct_name {
#(#field_idents:
#(#named_field_idents:
<#field_types>::init_fetch(
_world,
&state.#field_idents,
&state.#named_field_idents,
_last_run,
_this_run,
),
)*
#marker_name: &(),
}
}

Expand All @@ -239,8 +275,9 @@ pub fn derive_world_query_impl(input: TokenStream) -> TokenStream {
) -> <Self as #path::query::WorldQuery>::Fetch<'__w> {
#fetch_struct_name {
#(
#field_idents: <#field_types>::clone_fetch(& _fetch. #field_idents),
#named_field_idents: <#field_types>::clone_fetch(& _fetch. #named_field_idents),
)*
#marker_name: &(),
}
}

Expand All @@ -256,7 +293,7 @@ pub fn derive_world_query_impl(input: TokenStream) -> TokenStream {
_archetype: &'__w #path::archetype::Archetype,
_table: &'__w #path::storage::Table
) {
#(<#field_types>::set_archetype(&mut _fetch.#field_idents, &_state.#field_idents, _archetype, _table);)*
#(<#field_types>::set_archetype(&mut _fetch.#named_field_idents, &_state.#named_field_idents, _archetype, _table);)*
}

/// SAFETY: we call `set_table` for each member that implements `Fetch`
Expand All @@ -266,7 +303,7 @@ pub fn derive_world_query_impl(input: TokenStream) -> TokenStream {
_state: &Self::State,
_table: &'__w #path::storage::Table
) {
#(<#field_types>::set_table(&mut _fetch.#field_idents, &_state.#field_idents, _table);)*
#(<#field_types>::set_table(&mut _fetch.#named_field_idents, &_state.#named_field_idents, _table);)*
}

/// SAFETY: we call `fetch` for each member that implements `Fetch`.
Expand All @@ -277,7 +314,7 @@ pub fn derive_world_query_impl(input: TokenStream) -> TokenStream {
_table_row: #path::storage::TableRow,
) -> <Self as #path::query::WorldQuery>::Item<'__w> {
Self::Item {
#(#field_idents: <#field_types>::fetch(&mut _fetch.#field_idents, _entity, _table_row),)*
#(#field_idents: <#field_types>::fetch(&mut _fetch.#named_field_idents, _entity, _table_row),)*
}
}

Expand All @@ -288,11 +325,11 @@ pub fn derive_world_query_impl(input: TokenStream) -> TokenStream {
_entity: #path::entity::Entity,
_table_row: #path::storage::TableRow,
) -> bool {
true #(&& <#field_types>::filter_fetch(&mut _fetch.#field_idents, _entity, _table_row))*
true #(&& <#field_types>::filter_fetch(&mut _fetch.#named_field_idents, _entity, _table_row))*
}

fn update_component_access(state: &Self::State, _access: &mut #path::query::FilteredAccess<#path::component::ComponentId>) {
#( <#field_types>::update_component_access(&state.#field_idents, _access); )*
#( <#field_types>::update_component_access(&state.#named_field_idents, _access); )*
}

fn update_archetype_component_access(
Expand All @@ -301,18 +338,18 @@ pub fn derive_world_query_impl(input: TokenStream) -> TokenStream {
_access: &mut #path::query::Access<#path::archetype::ArchetypeComponentId>
) {
#(
<#field_types>::update_archetype_component_access(&state.#field_idents, _archetype, _access);
<#field_types>::update_archetype_component_access(&state.#named_field_idents, _archetype, _access);
)*
}

fn init_state(world: &mut #path::world::World) -> #state_struct_name #user_ty_generics {
#state_struct_name {
#(#field_idents: <#field_types>::init_state(world),)*
#(#named_field_idents: <#field_types>::init_state(world),)*
}
}

fn matches_component_set(state: &Self::State, _set_contains_id: &impl Fn(#path::component::ComponentId) -> bool) -> bool {
true #(&& <#field_types>::matches_component_set(&state.#field_idents, _set_contains_id))*
true #(&& <#field_types>::matches_component_set(&state.#named_field_idents, _set_contains_id))*
}
}
};
Expand All @@ -328,7 +365,7 @@ pub fn derive_world_query_impl(input: TokenStream) -> TokenStream {
#[doc = "`]."]
#[automatically_derived]
#visibility struct #read_only_struct_name #user_impl_generics #user_where_clauses {
#( #field_visibilities #field_idents: #read_only_field_types, )*
#( #field_visibilities #named_field_idents: #read_only_field_types, )*
}

#readonly_state
Expand Down Expand Up @@ -374,7 +411,7 @@ pub fn derive_world_query_impl(input: TokenStream) -> TokenStream {
#[doc = "`], used for caching."]
#[automatically_derived]
#visibility struct #state_struct_name #user_impl_generics #user_where_clauses {
#(#field_idents: <#field_types as #path::query::WorldQuery>::State,)*
#(#named_field_idents: <#field_types as #path::query::WorldQuery>::State,)*
}

#mutable_impl
Expand Down
31 changes: 28 additions & 3 deletions crates/bevy_ecs/src/query/fetch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ use std::{cell::UnsafeCell, marker::PhantomData};
/// - Methods can be implemented for the query items.
/// - There is no hardcoded limit on the number of elements.
///
/// This trait can only be derived if each field also implements `WorldQuery`.
/// The derive macro only supports regular structs (structs with named fields).
/// This trait can only be derived for structs, if each field also implements `WorldQuery`.
///
/// ```
/// # use bevy_ecs::prelude::*;
Expand Down Expand Up @@ -1468,11 +1467,37 @@ unsafe impl<T: ?Sized> ReadOnlyWorldQuery for PhantomData<T> {}
#[cfg(test)]
mod tests {
use super::*;
use crate::{self as bevy_ecs, system::Query};
use crate::{
self as bevy_ecs,
system::{assert_is_system, Query},
};

#[derive(Component)]
pub struct A;

#[derive(Component)]
pub struct B;

// Tests that each variant of struct can be used as a `WorldQuery`.
#[test]
fn world_query_struct_variants() {
#[derive(WorldQuery)]
pub struct NamedQuery {
id: Entity,
a: &'static A,
}

#[derive(WorldQuery)]
pub struct TupleQuery(&'static A, &'static B);

#[derive(WorldQuery)]
pub struct UnitQuery;

fn my_system(_: Query<(NamedQuery, TupleQuery, UnitQuery)>) {}

assert_is_system(my_system);
}

// Compile test for https://github.com/bevyengine/bevy/pull/8030.
#[test]
fn world_query_phantom_data() {
Expand Down