diff --git a/crates/bevy_ecs/src/query/iter.rs b/crates/bevy_ecs/src/query/iter.rs index 2cfcf37031c1d..ae6eab9162321 100644 --- a/crates/bevy_ecs/src/query/iter.rs +++ b/crates/bevy_ecs/src/query/iter.rs @@ -41,54 +41,6 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIter<'w, 's, D, F> { } } - /// Executes the equivalent of [`Iterator::for_each`] over a contiguous segment - /// from a table. - /// - /// # Safety - /// - all `rows` must be in `[0, table.entity_count)`. - /// - `table` must match D and F - /// - Both `D::IS_DENSE` and `F::IS_DENSE` must be true. - #[inline] - #[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))] - pub(super) unsafe fn for_each_in_table_range( - &mut self, - func: &mut Func, - table: &'w Table, - rows: Range, - ) where - Func: FnMut(D::Item<'w>), - { - // SAFETY: Caller assures that D::IS_DENSE and F::IS_DENSE are true, that table matches D and F - // and all indices in rows are in range. - unsafe { - self.fold_over_table_range((), &mut |_, item| func(item), table, rows); - } - } - - /// Executes the equivalent of [`Iterator::for_each`] over a contiguous segment - /// from an archetype. - /// - /// # Safety - /// - all `indices` must be in `[0, archetype.len())`. - /// - `archetype` must match D and F - /// - Either `D::IS_DENSE` or `F::IS_DENSE` must be false. - #[inline] - #[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))] - pub(super) unsafe fn for_each_in_archetype_range( - &mut self, - func: &mut Func, - archetype: &'w Archetype, - rows: Range, - ) where - Func: FnMut(D::Item<'w>), - { - // SAFETY: Caller assures that either D::IS_DENSE or F::IS_DENSE are false, that archetype matches D and F - // and all indices in rows are in range. - unsafe { - self.fold_over_archetype_range((), &mut |_, item| func(item), archetype, rows); - } - } - /// Executes the equivalent of [`Iterator::fold`] over a contiguous segment /// from an table. 
/// @@ -752,7 +704,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIterationCursor<'w, 's, D, F> { } // NOTE: If you are changing query iteration code, remember to update the following places, where relevant: - // QueryIter, QueryIterationCursor, QueryManyIter, QueryCombinationIter, QueryState::par_for_each_unchecked_manual + // QueryIter, QueryIterationCursor, QueryManyIter, QueryCombinationIter, QueryState::par_fold_init_unchecked_manual /// # Safety /// `tables` and `archetypes` must belong to the same world that the [`QueryIterationCursor`] /// was initialized for. diff --git a/crates/bevy_ecs/src/query/par_iter.rs b/crates/bevy_ecs/src/query/par_iter.rs index 164165cf9e01e..00e474e81b2ac 100644 --- a/crates/bevy_ecs/src/query/par_iter.rs +++ b/crates/bevy_ecs/src/query/par_iter.rs @@ -35,8 +35,52 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryParIter<'w, 's, D, F> { /// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool #[inline] pub fn for_each) + Send + Sync + Clone>(self, func: FN) { + self.for_each_init(|| {}, |_, item| func(item)); + } + + /// Runs `func` on each query result in parallel on a value returned by `init`. + /// + /// `init` may be called multiple times per thread, and the values returned may be discarded between tasks on any given thread. + /// Callers should avoid using this function as if it were a parallel version + /// of [`Iterator::fold`]. 
+ /// + /// # Example + /// + /// ``` + /// use bevy_utils::Parallel; + /// use crate::{bevy_ecs::prelude::Component, bevy_ecs::system::Query}; + /// #[derive(Component)] + /// struct T; + /// fn system(query: Query<&T>){ + /// let mut queue: Parallel = Parallel::default(); + /// // queue.borrow_local_mut() will get or create a thread_local queue for each task/thread; + /// query.par_iter().for_each_init(|| queue.borrow_local_mut(),|local_queue,item| { + /// **local_queue += 1; + /// }); + /// + /// // collect value from every thread + /// let entity_count: usize = queue.iter_mut().map(|v| *v).sum(); + /// } + /// ``` + /// + /// # Panics + /// If the [`ComputeTaskPool`] is not initialized. If using this from a query that is being + /// initialized and run from the ECS scheduler, this should never panic. + /// + /// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool + #[inline] + pub fn for_each_init(self, init: INIT, func: FN) + where + FN: Fn(&mut T, QueryItem<'w, D>) + Send + Sync + Clone, + INIT: Fn() -> T + Sync + Send + Clone, + { + let func = |mut init, item| { + func(&mut init, item); + init + }; #[cfg(any(target_arch = "wasm32", not(feature = "multi-threaded")))] { + let init = init(); // SAFETY: // This method can only be called once per instance of QueryParIter, // which ensures that mutable queries cannot be executed multiple times at once. @@ -46,25 +90,27 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryParIter<'w, 's, D, F> { unsafe { self.state .iter_unchecked_manual(self.world, self.last_run, self.this_run) - .for_each(func); + .fold(init, func); } } #[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))] { let thread_count = bevy_tasks::ComputeTaskPool::get().thread_num(); if thread_count <= 1 { + let init = init(); // SAFETY: See the safety comment above. unsafe { self.state .iter_unchecked_manual(self.world, self.last_run, self.this_run) - .for_each(func); + .fold(init, func); } } else { // Need a batch size of at least 1. 
let batch_size = self.get_batch_size(thread_count).max(1); // SAFETY: See the safety comment above. unsafe { - self.state.par_for_each_unchecked_manual( + self.state.par_fold_init_unchecked_manual( + init, self.world, batch_size, func, diff --git a/crates/bevy_ecs/src/query/state.rs b/crates/bevy_ecs/src/query/state.rs index 69a4e5778d1a1..5e52178837072 100644 --- a/crates/bevy_ecs/src/query/state.rs +++ b/crates/bevy_ecs/src/query/state.rs @@ -1394,19 +1394,20 @@ impl QueryState { /// /// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool #[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))] - pub(crate) unsafe fn par_for_each_unchecked_manual< - 'w, - FN: Fn(D::Item<'w>) + Send + Sync + Clone, - >( + pub(crate) unsafe fn par_fold_init_unchecked_manual<'w, T, FN, INIT>( &self, + init_accum: INIT, world: UnsafeWorldCell<'w>, batch_size: usize, func: FN, last_run: Tick, this_run: Tick, - ) { + ) where + FN: Fn(T, D::Item<'w>) -> T + Send + Sync + Clone, + INIT: Fn() -> T + Sync + Send + Clone, + { // NOTE: If you are changing query iteration code, remember to update the following places, where relevant: - // QueryIter, QueryIterationCursor, QueryManyIter, QueryCombinationIter, QueryState::for_each_unchecked_manual, QueryState::par_for_each_unchecked_manual + // QueryIter, QueryIterationCursor, QueryManyIter, QueryCombinationIter, QueryState::par_fold_init_unchecked_manual use arrayvec::ArrayVec; bevy_tasks::ComputeTaskPool::get().scope(|scope| { @@ -1423,19 +1424,27 @@ impl QueryState { } let queue = std::mem::take(queue); let mut func = func.clone(); + let init_accum = init_accum.clone(); scope.spawn(async move { #[cfg(feature = "trace")] let _span = self.par_iter_span.enter(); let mut iter = self.iter_unchecked_manual(world, last_run, this_run); + let mut accum = init_accum(); for storage_id in queue { if D::IS_DENSE && F::IS_DENSE { let id = storage_id.table_id; let table = &world.storages().tables.get(id).debug_checked_unwrap(); - 
iter.for_each_in_table_range(&mut func, table, 0..table.entity_count()); + accum = iter.fold_over_table_range( + accum, + &mut func, + table, + 0..table.entity_count(), + ); } else { let id = storage_id.archetype_id; let archetype = world.archetypes().get(id).debug_checked_unwrap(); - iter.for_each_in_archetype_range( + accum = iter.fold_over_archetype_range( + accum, &mut func, archetype, 0..archetype.len(), @@ -1449,21 +1458,23 @@ impl QueryState { let submit_single = |count, storage_id: StorageId| { for offset in (0..count).step_by(batch_size) { let mut func = func.clone(); + let init_accum = init_accum.clone(); let len = batch_size.min(count - offset); let batch = offset..offset + len; scope.spawn(async move { #[cfg(feature = "trace")] let _span = self.par_iter_span.enter(); + let accum = init_accum(); if D::IS_DENSE && F::IS_DENSE { let id = storage_id.table_id; let table = world.storages().tables.get(id).debug_checked_unwrap(); self.iter_unchecked_manual(world, last_run, this_run) - .for_each_in_table_range(&mut func, table, batch); + .fold_over_table_range(accum, &mut func, table, batch); } else { let id = storage_id.archetype_id; let archetype = world.archetypes().get(id).debug_checked_unwrap(); self.iter_unchecked_manual(world, last_run, this_run) - .for_each_in_archetype_range(&mut func, archetype, batch); + .fold_over_archetype_range(accum, &mut func, archetype, batch); } }); } diff --git a/crates/bevy_pbr/src/render/mesh.rs b/crates/bevy_pbr/src/render/mesh.rs index ed5a4d7549434..959c77ed6881a 100644 --- a/crates/bevy_pbr/src/render/mesh.rs +++ b/crates/bevy_pbr/src/render/mesh.rs @@ -595,8 +595,10 @@ pub fn extract_meshes_for_cpu_building( )>, >, ) { - meshes_query.par_iter().for_each( - |( + meshes_query.par_iter().for_each_init( + || render_mesh_instance_queues.borrow_local_mut(), + |queue, + ( entity, view_visibility, transform, @@ -621,23 +623,19 @@ pub fn extract_meshes_for_cpu_building( no_automatic_batching, ); - 
render_mesh_instance_queues.scope(|queue| { - let transform = transform.affine(); - queue.push(( - entity, - RenderMeshInstanceCpu { - transforms: MeshTransforms { - transform: (&transform).into(), - previous_transform: (&previous_transform - .map(|t| t.0) - .unwrap_or(transform)) - .into(), - flags: mesh_flags.bits(), - }, - shared, + let transform = transform.affine(); + queue.push(( + entity, + RenderMeshInstanceCpu { + transforms: MeshTransforms { + transform: (&transform).into(), + previous_transform: (&previous_transform.map(|t| t.0).unwrap_or(transform)) + .into(), + flags: mesh_flags.bits(), }, - )); - }); + shared, + }, + )); }, ); @@ -683,8 +681,10 @@ pub fn extract_meshes_for_gpu_building( )>, >, ) { - meshes_query.par_iter().for_each( - |( + meshes_query.par_iter().for_each_init( + || render_mesh_instance_queues.borrow_local_mut(), + |queue, + ( entity, view_visibility, transform, @@ -713,17 +713,15 @@ pub fn extract_meshes_for_gpu_building( let lightmap_uv_rect = lightmap::pack_lightmap_uv_rect(lightmap.map(|lightmap| lightmap.uv_rect)); - render_mesh_instance_queues.scope(|queue| { - queue.push(( - entity, - RenderMeshInstanceGpuBuilder { - shared, - transform: (&transform.affine()).into(), - lightmap_uv_rect, - mesh_flags, - }, - )); - }); + queue.push(( + entity, + RenderMeshInstanceGpuBuilder { + shared, + transform: (&transform.affine()).into(), + lightmap_uv_rect, + mesh_flags, + }, + )); }, ); diff --git a/crates/bevy_render/src/view/visibility/mod.rs b/crates/bevy_render/src/view/visibility/mod.rs index a57a5f8414b0e..aa2982ca423a1 100644 --- a/crates/bevy_render/src/view/visibility/mod.rs +++ b/crates/bevy_render/src/view/visibility/mod.rs @@ -453,52 +453,53 @@ pub fn check_visibility( let view_mask = maybe_view_mask.copied().unwrap_or_default(); - visible_aabb_query.par_iter_mut().for_each(|query_item| { - let ( - entity, - inherited_visibility, - mut view_visibility, - maybe_entity_mask, - maybe_model_aabb, - transform, - no_frustum_culling, 
- ) = query_item; - - // Skip computing visibility for entities that are configured to be hidden. - // ViewVisibility has already been reset in `reset_view_visibility`. - if !inherited_visibility.get() { - return; - } + visible_aabb_query.par_iter_mut().for_each_init( + || thread_queues.borrow_local_mut(), + |queue, query_item| { + let ( + entity, + inherited_visibility, + mut view_visibility, + maybe_entity_mask, + maybe_model_aabb, + transform, + no_frustum_culling, + ) = query_item; + + // Skip computing visibility for entities that are configured to be hidden. + // ViewVisibility has already been reset in `reset_view_visibility`. + if !inherited_visibility.get() { + return; + } - let entity_mask = maybe_entity_mask.copied().unwrap_or_default(); - if !view_mask.intersects(&entity_mask) { - return; - } + let entity_mask = maybe_entity_mask.copied().unwrap_or_default(); + if !view_mask.intersects(&entity_mask) { + return; + } - // If we have an aabb, do frustum culling - if !no_frustum_culling { - if let Some(model_aabb) = maybe_model_aabb { - let model = transform.affine(); - let model_sphere = Sphere { - center: model.transform_point3a(model_aabb.center), - radius: transform.radius_vec3a(model_aabb.half_extents), - }; - // Do quick sphere-based frustum culling - if !frustum.intersects_sphere(&model_sphere, false) { - return; - } - // Do aabb-based frustum culling - if !frustum.intersects_obb(model_aabb, &model, true, false) { - return; + // If we have an aabb, do frustum culling + if !no_frustum_culling { + if let Some(model_aabb) = maybe_model_aabb { + let model = transform.affine(); + let model_sphere = Sphere { + center: model.transform_point3a(model_aabb.center), + radius: transform.radius_vec3a(model_aabb.half_extents), + }; + // Do quick sphere-based frustum culling + if !frustum.intersects_sphere(&model_sphere, false) { + return; + } + // Do aabb-based frustum culling + if !frustum.intersects_obb(model_aabb, &model, true, false) { + return; + } } } - } - 
view_visibility.set(); - thread_queues.scope(|queue| { + view_visibility.set(); queue.push(entity); - }); - }); + }, + ); visible_entities.clear::(); thread_queues.drain_into(visible_entities.get_mut::()); diff --git a/crates/bevy_utils/src/parallel_queue.rs b/crates/bevy_utils/src/parallel_queue.rs index a143c3cf528d3..af1b038d794ab 100644 --- a/crates/bevy_utils/src/parallel_queue.rs +++ b/crates/bevy_utils/src/parallel_queue.rs @@ -1,4 +1,4 @@ -use core::cell::Cell; +use std::{cell::RefCell, ops::DerefMut}; use thread_local::ThreadLocal; /// A cohesive set of thread-local values of a given type. @@ -6,9 +6,10 @@ use thread_local::ThreadLocal; /// Mutable references can be fetched if `T: Default` via [`Parallel::scope`]. #[derive(Default)] pub struct Parallel { - locals: ThreadLocal>, + locals: ThreadLocal>, } +/// A scope guard of a `Parallel`; when this struct is dropped, the value is written back to its `Parallel` impl Parallel { /// Gets a mutable iterator over all of the per-thread queues. pub fn iter_mut(&mut self) -> impl Iterator { @@ -26,12 +27,17 @@ impl Parallel { /// /// If there is no thread-local value, it will be initialized to its default. pub fn scope(&self, f: impl FnOnce(&mut T) -> R) -> R { - let cell = self.locals.get_or_default(); - let mut value = cell.take(); - let ret = f(&mut value); - cell.set(value); + let mut cell = self.locals.get_or_default().borrow_mut(); + let ret = f(cell.deref_mut()); ret } + + /// Mutably borrows the thread-local value. + /// + /// If there is no thread-local value, it will be initialized to its default. + pub fn borrow_local_mut(&self) -> impl DerefMut + '_ { + self.locals.get_or_default().borrow_mut() + } } impl Parallel