Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Store only the IDs needed for Query iteration #12476

Merged
merged 19 commits into from
Mar 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 21 additions & 26 deletions crates/bevy_ecs/src/query/iter.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use crate::{
archetype::{Archetype, ArchetypeEntity, ArchetypeId, Archetypes},
archetype::{Archetype, ArchetypeEntity, Archetypes},
component::Tick,
entity::{Entities, Entity},
query::{ArchetypeFilter, DebugCheckedUnwrap, QueryState},
storage::{Table, TableId, TableRow, Tables},
query::{ArchetypeFilter, DebugCheckedUnwrap, QueryState, StorageId},
storage::{Table, TableRow, Tables},
world::unsafe_world_cell::UnsafeWorldCell,
};
use std::{borrow::Borrow, iter::FusedIterator, mem::MaybeUninit, ops::Range};
Expand Down Expand Up @@ -239,22 +239,20 @@ impl<'w, 's, D: QueryData, F: QueryFilter> Iterator for QueryIter<'w, 's, D, F>
let Some(item) = self.next() else { break };
accum = func(accum, item);
}
if D::IS_DENSE && F::IS_DENSE {
for table_id in self.cursor.table_id_iter.clone() {
for id in self.cursor.storage_id_iter.clone() {
if D::IS_DENSE && F::IS_DENSE {
// SAFETY: Matched table IDs are guaranteed to still exist.
let table = unsafe { self.tables.get(*table_id).debug_checked_unwrap() };
let table = unsafe { self.tables.get(id.table_id).debug_checked_unwrap() };
accum =
// SAFETY:
// - The fetched table matches both D and F
// - The provided range is equivalent to [0, table.entity_count)
// - The if block ensures that D::IS_DENSE and F::IS_DENSE are both true
unsafe { self.fold_over_table_range(accum, &mut func, table, 0..table.entity_count()) };
}
} else {
for archetype_id in self.cursor.archetype_id_iter.clone() {
} else {
let archetype =
// SAFETY: Matched archetype IDs are guaranteed to still exist.
unsafe { self.archetypes.get(*archetype_id).debug_checked_unwrap() };
unsafe { self.archetypes.get(id.archetype_id).debug_checked_unwrap() };
accum =
// SAFETY:
// - The fetched archetype matches both D and F
Expand Down Expand Up @@ -650,8 +648,7 @@ impl<'w, 's, D: ReadOnlyQueryData, F: QueryFilter, const K: usize> FusedIterator
}

struct QueryIterationCursor<'w, 's, D: QueryData, F: QueryFilter> {
table_id_iter: std::slice::Iter<'s, TableId>,
archetype_id_iter: std::slice::Iter<'s, ArchetypeId>,
storage_id_iter: std::slice::Iter<'s, StorageId>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does that mean that the table_entities and archetype_entities below could have a similar optiimization as storage_entities: &'w [StorageEntities]?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is what #5085 does, however the benefit may be limited as this struct doesn't really exist at runtime: it's never formally materialized to the stack or the heap under normal use cases and any inlined iteration will decompose the fetches and updates.

Compare this with the Vec in QueryState, which requires both stack and heap space due to being a persisted heap allocated backing for Query. There's real memory savings by using the union.

Doable but I'm not sure if the savings are worth the introduction of more unsafe and readability impact.

table_entities: &'w [Entity],
archetype_entities: &'w [ArchetypeEntity],
fetch: D::Fetch<'w>,
Expand All @@ -665,8 +662,7 @@ struct QueryIterationCursor<'w, 's, D: QueryData, F: QueryFilter> {
impl<D: QueryData, F: QueryFilter> Clone for QueryIterationCursor<'_, '_, D, F> {
fn clone(&self) -> Self {
Self {
table_id_iter: self.table_id_iter.clone(),
archetype_id_iter: self.archetype_id_iter.clone(),
storage_id_iter: self.storage_id_iter.clone(),
table_entities: self.table_entities,
archetype_entities: self.archetype_entities,
fetch: self.fetch.clone(),
Expand All @@ -687,8 +683,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIterationCursor<'w, 's, D, F> {
this_run: Tick,
) -> Self {
QueryIterationCursor {
table_id_iter: [].iter(),
archetype_id_iter: [].iter(),
storage_id_iter: [].iter(),
..Self::init(world, query_state, last_run, this_run)
}
}
Expand All @@ -709,8 +704,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIterationCursor<'w, 's, D, F> {
filter,
table_entities: &[],
archetype_entities: &[],
table_id_iter: query_state.matched_table_ids.iter(),
archetype_id_iter: query_state.matched_archetype_ids.iter(),
storage_id_iter: query_state.matched_storage_ids.iter(),
current_len: 0,
current_row: 0,
}
Expand Down Expand Up @@ -746,12 +740,13 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIterationCursor<'w, 's, D, F> {
/// Note that if `D::IS_ARCHETYPAL && F::IS_ARCHETYPAL`, the return value
/// will be **the exact count of remaining values**.
fn max_remaining(&self, tables: &'w Tables, archetypes: &'w Archetypes) -> usize {
let ids = self.storage_id_iter.clone();
let remaining_matched: usize = if Self::IS_DENSE {
let ids = self.table_id_iter.clone();
ids.map(|id| tables[*id].entity_count()).sum()
// SAFETY: The if check ensures that storage_id_iter stores TableIds
unsafe { ids.map(|id| tables[id.table_id].entity_count()).sum() }
} else {
let ids = self.archetype_id_iter.clone();
ids.map(|id| archetypes[*id].len()).sum()
// SAFETY: The if check ensures that storage_id_iter stores ArchetypeIds
unsafe { ids.map(|id| archetypes[id.archetype_id].len()).sum() }
};
remaining_matched + self.current_len - self.current_row
}
Expand All @@ -773,8 +768,8 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIterationCursor<'w, 's, D, F> {
loop {
// we are on the beginning of the query, or finished processing a table, so skip to the next
if self.current_row == self.current_len {
let table_id = self.table_id_iter.next()?;
let table = tables.get(*table_id).debug_checked_unwrap();
let table_id = self.storage_id_iter.next()?.table_id;
let table = tables.get(table_id).debug_checked_unwrap();
// SAFETY: `table` is from the world that `fetch/filter` were created for,
// `fetch_state`/`filter_state` are the states that `fetch/filter` were initialized with
unsafe {
Expand Down Expand Up @@ -809,8 +804,8 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIterationCursor<'w, 's, D, F> {
} else {
loop {
if self.current_row == self.current_len {
let archetype_id = self.archetype_id_iter.next()?;
let archetype = archetypes.get(*archetype_id).debug_checked_unwrap();
let archetype_id = self.storage_id_iter.next()?.archetype_id;
let archetype = archetypes.get(archetype_id).debug_checked_unwrap();
let table = tables.get(archetype.table_id()).debug_checked_unwrap();
// SAFETY: `archetype` and `tables` are from the world that `fetch/filter` were created for,
// `fetch_state`/`filter_state` are the states that `fetch/filter` were initialized with
Expand Down
18 changes: 8 additions & 10 deletions crates/bevy_ecs/src/query/par_iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,24 +160,22 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryParIter<'w, 's, D, F> {
thread_count > 0,
"Attempted to run parallel iteration over a query with an empty TaskPool"
);
let id_iter = self.state.matched_storage_ids.iter();
let max_size = if D::IS_DENSE && F::IS_DENSE {
// SAFETY: We only access table metadata.
let tables = unsafe { &self.world.world_metadata().storages().tables };
self.state
.matched_table_ids
.iter()
.map(|id| tables[*id].entity_count())
id_iter
// SAFETY: The if check ensures that matched_storage_ids stores TableIds
.map(|id| unsafe { tables[id.table_id].entity_count() })
.max()
.unwrap_or(0)
} else {
let archetypes = &self.world.archetypes();
self.state
.matched_archetype_ids
.iter()
.map(|id| archetypes[*id].len())
id_iter
// SAFETY: The if check ensures that matched_storage_ids stores ArchetypeIds
.map(|id| unsafe { archetypes[id.archetype_id].len() })
.max()
.unwrap_or(0)
};
let max_size = max_size.unwrap_or(0);

let batches = thread_count * self.batching_strategy.batches_per_thread;
// Round up to the nearest batch size.
Expand Down
131 changes: 80 additions & 51 deletions crates/bevy_ecs/src/query/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,25 @@ use super::{
QuerySingleError, ROQueryItem,
};

/// An ID for either a table or an archetype. Used for Query iteration.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if it would be useful to specify that this is used for optimizing query iteration; in the case where all components are in tables, we can iterate through table_ids directly instead of archetypes

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would also maybe add a comment on why this being a union is ok here, as unions are rarer than enums. It's because we know which variant to use depending on QueryState::IS_DENSE so we can skip storing the variant type?

///
/// Query iteration is exclusively dense (over tables) or archetypal (over archetypes) based on whether
/// both `D::IS_DENSE` and `F::IS_DENSE` are true or not.
///
/// This is a union instead of an enum as the usage is determined at compile time, as all [`StorageId`]s for
/// a [`QueryState`] will be all [`TableId`]s or all [`ArchetypeId`]s, and not a mixture of both. This
/// removes the need for discriminator to minimize memory usage and branching during iteration, but requires
/// a safety invariant be verified when disambiguating them.
///
/// # Safety
/// Must be initialized and accessed as a [`TableId`], if both generic parameters to the query are dense.
/// Must be initialized and accessed as an [`ArchetypeId`] otherwise.
#[derive(Clone, Copy)]
pub(super) union StorageId {
pub(super) table_id: TableId,
pub(super) archetype_id: ArchetypeId,
}

/// Provides scoped access to a [`World`] state according to a given [`QueryData`] and [`QueryFilter`].
#[repr(C)]
// SAFETY NOTE:
Expand All @@ -32,10 +51,8 @@ pub struct QueryState<D: QueryData, F: QueryFilter = ()> {
pub(crate) matched_tables: FixedBitSet,
pub(crate) matched_archetypes: FixedBitSet,
pub(crate) component_access: FilteredAccess<ComponentId>,
// NOTE: we maintain both a TableId bitset and a vec because iterating the vec is faster
pub(crate) matched_table_ids: Vec<TableId>,
// NOTE: we maintain both a ArchetypeId bitset and a vec because iterating the vec is faster
pub(crate) matched_archetype_ids: Vec<ArchetypeId>,
// NOTE: we maintain both a bitset and a vec because iterating the vec is faster
pub(super) matched_storage_ids: Vec<StorageId>,
pub(crate) fetch_state: D::State,
pub(crate) filter_state: F::State,
#[cfg(feature = "trace")]
Expand All @@ -46,8 +63,11 @@ impl<D: QueryData, F: QueryFilter> fmt::Debug for QueryState<D, F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("QueryState")
.field("world_id", &self.world_id)
.field("matched_table_count", &self.matched_table_ids.len())
.field("matched_archetype_count", &self.matched_archetype_ids.len())
.field("matched_table_count", &self.matched_tables.count_ones(..))
.field(
"matched_archetype_count",
&self.matched_archetypes.count_ones(..),
)
.finish_non_exhaustive()
}
}
Expand Down Expand Up @@ -101,13 +121,13 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
}

/// Returns the tables matched by this query.
pub fn matched_tables(&self) -> &[TableId] {
&self.matched_table_ids
pub fn matched_tables(&self) -> impl Iterator<Item = TableId> + '_ {
self.matched_tables.ones().map(TableId::from_usize)
}

/// Returns the archetypes matched by this query.
pub fn matched_archetypes(&self) -> &[ArchetypeId] {
&self.matched_archetype_ids
pub fn matched_archetypes(&self) -> impl Iterator<Item = ArchetypeId> + '_ {
self.matched_archetypes.ones().map(ArchetypeId::new)
}
}

Expand Down Expand Up @@ -158,8 +178,7 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
Self {
world_id: world.id(),
archetype_generation: ArchetypeGeneration::initial(),
matched_table_ids: Vec::new(),
matched_archetype_ids: Vec::new(),
matched_storage_ids: Vec::new(),
fetch_state,
filter_state,
component_access,
Expand All @@ -183,8 +202,7 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
let mut state = Self {
world_id: builder.world().id(),
archetype_generation: ArchetypeGeneration::initial(),
matched_table_ids: Vec::new(),
matched_archetype_ids: Vec::new(),
matched_storage_ids: Vec::new(),
fetch_state,
filter_state,
component_access: builder.access().clone(),
Expand Down Expand Up @@ -338,12 +356,20 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
let archetype_index = archetype.id().index();
if !self.matched_archetypes.contains(archetype_index) {
self.matched_archetypes.grow_and_insert(archetype_index);
self.matched_archetype_ids.push(archetype.id());
if !D::IS_DENSE || !F::IS_DENSE {
self.matched_storage_ids.push(StorageId {
archetype_id: archetype.id(),
});
}
}
let table_index = archetype.table_id().as_usize();
if !self.matched_tables.contains(table_index) {
self.matched_tables.grow_and_insert(table_index);
self.matched_table_ids.push(archetype.table_id());
if D::IS_DENSE && F::IS_DENSE {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still also need separate self.matched_tables and self.matched_archetypes bitsets?
It looks like only one of the two will be used, depending on the value of D::IS_DENSE and F::IS_DENSE

Maybe it's still needed for things like join?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep exactly, join needs them, and get_unchecked_manual needs specifically matched_archetypes (right now).

self.matched_storage_ids.push(StorageId {
table_id: archetype.table_id(),
});
}
}
true
} else {
Expand Down Expand Up @@ -424,8 +450,7 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
QueryState {
world_id: self.world_id,
archetype_generation: self.archetype_generation,
matched_table_ids: self.matched_table_ids.clone(),
matched_archetype_ids: self.matched_archetype_ids.clone(),
matched_storage_ids: self.matched_storage_ids.clone(),
fetch_state,
filter_state,
component_access: self.component_access.clone(),
Expand Down Expand Up @@ -515,24 +540,30 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
}

// take the intersection of the matched ids
let matched_tables: FixedBitSet = self
.matched_tables
.intersection(&other.matched_tables)
.collect();
let matched_table_ids: Vec<TableId> =
matched_tables.ones().map(TableId::from_usize).collect();
let matched_archetypes: FixedBitSet = self
.matched_archetypes
.intersection(&other.matched_archetypes)
.collect();
let matched_archetype_ids: Vec<ArchetypeId> =
matched_archetypes.ones().map(ArchetypeId::new).collect();
let mut matched_tables = self.matched_tables.clone();
let mut matched_archetypes = self.matched_archetypes.clone();
matched_tables.intersect_with(&other.matched_tables);
matched_archetypes.intersect_with(&other.matched_archetypes);
let matched_storage_ids = if NewD::IS_DENSE && NewF::IS_DENSE {
matched_tables
.ones()
.map(|id| StorageId {
table_id: TableId::from_usize(id),
})
.collect()
} else {
matched_archetypes
.ones()
.map(|id| StorageId {
archetype_id: ArchetypeId::new(id),
})
.collect()
};

QueryState {
world_id: self.world_id,
archetype_generation: self.archetype_generation,
matched_table_ids,
matched_archetype_ids,
matched_storage_ids,
fetch_state: new_fetch_state,
filter_state: new_filter_state,
component_access: joined_component_access,
Expand Down Expand Up @@ -1306,12 +1337,15 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
) {
// NOTE: If you are changing query iteration code, remember to update the following places, where relevant:
// QueryIter, QueryIterationCursor, QueryManyIter, QueryCombinationIter, QueryState::for_each_unchecked_manual, QueryState::par_for_each_unchecked_manual

bevy_tasks::ComputeTaskPool::get().scope(|scope| {
if D::IS_DENSE && F::IS_DENSE {
// SAFETY: We only access table data that has been registered in `self.archetype_component_access`.
let tables = unsafe { &world.storages().tables };
for table_id in &self.matched_table_ids {
let table = &tables[*table_id];
// SAFETY: We only access table data that has been registered in `self.archetype_component_access`.
let tables = unsafe { &world.storages().tables };
let archetypes = world.archetypes();
for storage_id in &self.matched_storage_ids {
if D::IS_DENSE && F::IS_DENSE {
let table_id = storage_id.table_id;
let table = &tables[table_id];
if table.is_empty() {
continue;
}
Expand All @@ -1320,39 +1354,34 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
while offset < table.entity_count() {
let mut func = func.clone();
let len = batch_size.min(table.entity_count() - offset);
let batch = offset..offset + len;
scope.spawn(async move {
#[cfg(feature = "trace")]
let _span = self.par_iter_span.enter();
let table = &world
.storages()
.tables
.get(*table_id)
.debug_checked_unwrap();
let batch = offset..offset + len;
let table =
&world.storages().tables.get(table_id).debug_checked_unwrap();
self.iter_unchecked_manual(world, last_run, this_run)
.for_each_in_table_range(&mut func, table, batch);
});
offset += batch_size;
}
}
} else {
let archetypes = world.archetypes();
for archetype_id in &self.matched_archetype_ids {
let mut offset = 0;
let archetype = &archetypes[*archetype_id];
} else {
let archetype_id = storage_id.archetype_id;
let archetype = &archetypes[archetype_id];
if archetype.is_empty() {
continue;
}

let mut offset = 0;
while offset < archetype.len() {
let mut func = func.clone();
let len = batch_size.min(archetype.len() - offset);
let batch = offset..offset + len;
scope.spawn(async move {
#[cfg(feature = "trace")]
let _span = self.par_iter_span.enter();
let archetype =
world.archetypes().get(*archetype_id).debug_checked_unwrap();
let batch = offset..offset + len;
world.archetypes().get(archetype_id).debug_checked_unwrap();
self.iter_unchecked_manual(world, last_run, this_run)
.for_each_in_archetype_range(&mut func, archetype, batch);
});
Expand Down