Improve par_iter and Parallel #12904

Merged
merged 14 commits on Apr 23, 2024
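
This PR replaces the internal `par_for_each_unchecked_manual` with `par_fold_init_unchecked_manual` and adds `QueryParIter::for_each_init`, which pairs each parallel batch with a value produced by an `init` closure. A minimal sketch of the resulting usage pattern (illustrative only, assuming the `bevy_utils::Parallel` API with `borrow_local_mut` and `iter_mut` as used in the doc example further down; `Health` and `count_entities` are hypothetical names):

```rust
use bevy_ecs::prelude::*;
use bevy_utils::Parallel;

#[derive(Component)]
struct Health(u32);

// Count matching entities with one thread-local counter per worker thread,
// avoiding contention on a shared atomic or mutex.
fn count_entities(query: Query<&Health>) {
    let mut counts: Parallel<usize> = Parallel::default();
    query.par_iter().for_each_init(
        // Called once per spawned batch, not once per item.
        || counts.borrow_local_mut(),
        |count, _health| **count += 1,
    );
    // After the parallel loop, sum the per-thread counters serially.
    let _total: usize = counts.iter_mut().map(|c| *c).sum();
}
```
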
50 changes: 1 addition & 49 deletions crates/bevy_ecs/src/query/iter.rs
@@ -41,54 +41,6 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIter<'w, 's, D, F> {
}
}

/// Executes the equivalent of [`Iterator::for_each`] over a contiguous segment
/// from a table.
///
/// # Safety
/// - all `rows` must be in `[0, table.entity_count)`.
/// - `table` must match D and F
/// - Both `D::IS_DENSE` and `F::IS_DENSE` must be true.
#[inline]
#[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))]
pub(super) unsafe fn for_each_in_table_range<Func>(
&mut self,
func: &mut Func,
table: &'w Table,
rows: Range<usize>,
) where
Func: FnMut(D::Item<'w>),
{
// SAFETY: Caller assures that D::IS_DENSE and F::IS_DENSE are true, that table matches D and F
// and all indices in rows are in range.
unsafe {
self.fold_over_table_range((), &mut |_, item| func(item), table, rows);
}
}

/// Executes the equivalent of [`Iterator::for_each`] over a contiguous segment
/// from an archetype.
///
/// # Safety
/// - all `indices` must be in `[0, archetype.len())`.
/// - `archetype` must match D and F
/// - Either `D::IS_DENSE` or `F::IS_DENSE` must be false.
#[inline]
#[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))]
pub(super) unsafe fn for_each_in_archetype_range<Func>(
&mut self,
func: &mut Func,
archetype: &'w Archetype,
rows: Range<usize>,
) where
Func: FnMut(D::Item<'w>),
{
// SAFETY: Caller assures that either D::IS_DENSE or F::IS_DENSE are false, that archetype matches D and F
// and all indices in rows are in range.
unsafe {
self.fold_over_archetype_range((), &mut |_, item| func(item), archetype, rows);
}
}

/// Executes the equivalent of [`Iterator::fold`] over a contiguous segment
/// from a table.
///
@@ -752,7 +704,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIterationCursor<'w, 's, D, F> {
}

// NOTE: If you are changing query iteration code, remember to update the following places, where relevant:
// QueryIter, QueryIterationCursor, QueryManyIter, QueryCombinationIter, QueryState::par_for_each_unchecked_manual
// QueryIter, QueryIterationCursor, QueryManyIter, QueryCombinationIter, QueryState::par_fold_init_unchecked_manual
/// # Safety
/// `tables` and `archetypes` must belong to the same world that the [`QueryIterationCursor`]
/// was initialized for.
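The helpers deleted at the top of this file (`for_each_in_table_range`, `for_each_in_archetype_range`) were thin wrappers that forwarded to the corresponding `fold_over_*_range` methods with a `()` accumulator, so the parallel code now calls the fold variants directly. The same equivalence in plain iterator terms, as a sketch:

```rust
// `for_each` is just `fold` with a `()` accumulator, which is exactly how
// the deleted helpers forwarded to `fold_over_table_range` /
// `fold_over_archetype_range`.
fn for_each_via_fold<I: Iterator>(iter: I, mut func: impl FnMut(I::Item)) {
    iter.fold((), |(), item| func(item));
}

fn main() {
    for_each_via_fold([1, 2, 3].into_iter(), |x| println!("{x}"));
}
```
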
52 changes: 49 additions & 3 deletions crates/bevy_ecs/src/query/par_iter.rs
@@ -109,8 +109,52 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryParIter<'w, 's, D, F> {
/// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool
#[inline]
pub fn for_each<FN: Fn(QueryItem<'w, D>) + Send + Sync + Clone>(self, func: FN) {
self.for_each_init(|| {}, |_, item| func(item));
}

/// Runs `func` on each query result in parallel, passing it a value produced by `init`.
///
/// The `init` function is called only when a value needs to be paired with the
/// group of items handled by one of Bevy's tasks, which makes it useful for
/// initializing a thread-local value for each task.
///
/// # Example
///
/// ```
/// use bevy_utils::Parallel;
/// use crate::{bevy_ecs::prelude::Component, bevy_ecs::system::Query};
/// #[derive(Component)]
/// struct T;
/// fn system(query: Query<&T>) {
///     let mut queue: Parallel<usize> = Parallel::default();
///     // `queue.borrow_local_mut()` gets or creates a thread-local value for each task/thread.
///     query.par_iter().for_each_init(|| queue.borrow_local_mut(), |local_queue, _item| {
///         **local_queue += 1;
///     });
///
///     // Collect the values from every thread.
///     let entity_count: usize = queue.iter_mut().map(|v| *v).sum();
/// }
/// ```
///
/// # Panics
/// If the [`ComputeTaskPool`] is not initialized. If using this from a query that is being
/// initialized and run from the ECS scheduler, this should never panic.
///
/// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool
#[inline]
pub fn for_each_init<FN, INIT, T>(self, init: INIT, func: FN)
where
FN: Fn(&mut T, QueryItem<'w, D>) + Send + Sync + Clone,
INIT: Fn() -> T + Sync + Send + Clone,
{
let func = |mut init, item| {
func(&mut init, item);
init
};
#[cfg(any(target_arch = "wasm32", not(feature = "multi-threaded")))]
{
let init = init();
// SAFETY:
// This method can only be called once per instance of QueryParIter,
// which ensures that mutable queries cannot be executed multiple times at once.
@@ -120,25 +164,27 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryParIter<'w, 's, D, F> {
unsafe {
self.state
.iter_unchecked_manual(self.world, self.last_run, self.this_run)
.for_each(func);
.fold(init, func);
}
}
#[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))]
{
let thread_count = bevy_tasks::ComputeTaskPool::get().thread_num();
if thread_count <= 1 {
let init = init();
// SAFETY: See the safety comment above.
unsafe {
self.state
.iter_unchecked_manual(self.world, self.last_run, self.this_run)
.for_each(func);
.fold(init, func);
}
} else {
// Need a batch size of at least 1.
let batch_size = self.get_batch_size(thread_count).max(1);
// SAFETY: See the safety comment above.
unsafe {
self.state.par_for_each_unchecked_manual(
self.state.par_fold_init_unchecked_manual(
init,
self.world,
batch_size,
func,
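Note how `for_each_init` bridges the two callback shapes: the public API takes `Fn(&mut T, Item)`, while `fold` and `par_fold_init_unchecked_manual` expect `Fn(T, Item) -> T`, so the wrapper closure above moves the accumulator in, lets the callback mutate it, and hands it back. A standalone sketch of that adapter:

```rust
// Adapt a `FnMut(&mut T, Item)` callback into the `FnMut(T, Item) -> T`
// shape expected by `fold`: take the accumulator by value, mutate, return.
fn as_fold<T, Item, F>(mut func: F) -> impl FnMut(T, Item) -> T
where
    F: FnMut(&mut T, Item),
{
    move |mut acc, item| {
        func(&mut acc, item);
        acc
    }
}

fn main() {
    let sum = [1, 2, 3].into_iter().fold(0, as_fold(|acc: &mut i32, x| *acc += x));
    assert_eq!(sum, 6);
}
```
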
31 changes: 21 additions & 10 deletions crates/bevy_ecs/src/query/state.rs
@@ -1374,19 +1374,20 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
///
/// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool
#[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))]
pub(crate) unsafe fn par_for_each_unchecked_manual<
'w,
FN: Fn(D::Item<'w>) + Send + Sync + Clone,
>(
pub(crate) unsafe fn par_fold_init_unchecked_manual<'w, T, FN, INIT>(
&self,
init_accum: INIT,
world: UnsafeWorldCell<'w>,
batch_size: usize,
func: FN,
last_run: Tick,
this_run: Tick,
) {
) where
FN: Fn(T, D::Item<'w>) -> T + Send + Sync + Clone,
INIT: Fn() -> T + Sync + Send + Clone,
{
// NOTE: If you are changing query iteration code, remember to update the following places, where relevant:
// QueryIter, QueryIterationCursor, QueryManyIter, QueryCombinationIter, QueryState::for_each_unchecked_manual, QueryState::par_for_each_unchecked_manual
// QueryIter, QueryIterationCursor, QueryManyIter, QueryCombinationIter, QueryState::par_fold_init_unchecked_manual
use arrayvec::ArrayVec;

bevy_tasks::ComputeTaskPool::get().scope(|scope| {
@@ -1403,19 +1404,27 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
}
let queue = std::mem::take(queue);
let mut func = func.clone();
let init_accum = init_accum.clone();
scope.spawn(async move {
#[cfg(feature = "trace")]
let _span = self.par_iter_span.enter();
let mut iter = self.iter_unchecked_manual(world, last_run, this_run);
let mut accum = init_accum();
for storage_id in queue {
if D::IS_DENSE && F::IS_DENSE {
let id = storage_id.table_id;
let table = &world.storages().tables.get(id).debug_checked_unwrap();
iter.for_each_in_table_range(&mut func, table, 0..table.entity_count());
accum = iter.fold_over_table_range(
accum,
&mut func,
table,
0..table.entity_count(),
);
} else {
let id = storage_id.archetype_id;
let archetype = world.archetypes().get(id).debug_checked_unwrap();
iter.for_each_in_archetype_range(
accum = iter.fold_over_archetype_range(
accum,
&mut func,
archetype,
0..archetype.len(),
@@ -1429,21 +1438,23 @@
let submit_single = |count, storage_id: StorageId| {
for offset in (0..count).step_by(batch_size) {
let mut func = func.clone();
let init_accum = init_accum.clone();
let len = batch_size.min(count - offset);
let batch = offset..offset + len;
scope.spawn(async move {
#[cfg(feature = "trace")]
let _span = self.par_iter_span.enter();
let accum = init_accum();
if D::IS_DENSE && F::IS_DENSE {
let id = storage_id.table_id;
let table = world.storages().tables.get(id).debug_checked_unwrap();
self.iter_unchecked_manual(world, last_run, this_run)
.for_each_in_table_range(&mut func, table, batch);
.fold_over_table_range(accum, &mut func, table, batch);
} else {
let id = storage_id.archetype_id;
let archetype = world.archetypes().get(id).debug_checked_unwrap();
self.iter_unchecked_manual(world, last_run, this_run)
.for_each_in_archetype_range(&mut func, archetype, batch);
.fold_over_archetype_range(accum, &mut func, archetype, batch);
}
});
}
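One design point worth calling out in this file: `init_accum` is cloned into every spawned task and invoked once per batch, so each batch folds into its own fresh accumulator and the per-batch results are never merged or returned. That is why `for_each_init` pairs it with side-effecting thread-local state (such as `Parallel`) rather than producing a combined value. A simplified sketch of that shape, using `std::thread::scope` as a stand-in for the `ComputeTaskPool`:

```rust
use std::thread;

// Sketch only: each batch gets its own accumulator from `init_accum`,
// mirroring `par_fold_init_unchecked_manual`; nothing merges the results.
fn par_fold_sketch<T, F, INIT>(batches: Vec<Vec<u32>>, init_accum: INIT, func: F)
where
    T: Send,
    F: Fn(T, u32) -> T + Send + Sync + Clone,
    INIT: Fn() -> T + Send + Sync + Clone,
{
    thread::scope(|scope| {
        for batch in batches {
            let init_accum = init_accum.clone();
            let func = func.clone();
            scope.spawn(move || {
                // One fresh accumulator per batch; dropped when the task ends.
                let _accum = batch.into_iter().fold(init_accum(), &func);
            });
        }
    });
}
```
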
28 changes: 14 additions & 14 deletions crates/bevy_pbr/src/render/mesh.rs
@@ -284,8 +284,10 @@ pub fn extract_meshes(
)>,
>,
) {
meshes_query.par_iter().for_each(
|(
meshes_query.par_iter().for_each_init(
|| thread_local_queues.borrow_local_mut(),
|queue,
(
entity,
view_visibility,
transform,
@@ -317,18 +319,16 @@
previous_transform: (&previous_transform).into(),
flags: flags.bits(),
};
thread_local_queues.scope(|queue| {
queue.push((
entity,
RenderMeshInstance {
mesh_asset_id: handle.id(),
transforms,
shadow_caster: !not_shadow_caster,
material_bind_group_id: AtomicMaterialBindGroupId::default(),
automatic_batching: !no_automatic_batching,
},
));
});
queue.push((
entity,
RenderMeshInstance {
mesh_asset_id: handle.id(),
transforms,
shadow_caster: !not_shadow_caster,
material_bind_group_id: AtomicMaterialBindGroupId::default(),
automatic_batching: !no_automatic_batching,
},
));
},
);

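The `extract_meshes` change above hoists thread-local access out of the hot loop: the old code re-entered the thread-local cell via `thread_local_queues.scope(..)` for every mesh, while `for_each_init` now borrows the local queue once per batch and reuses it. A sketch of the difference, assuming `Parallel::scope` and `Parallel::borrow_local_mut` behave as they are used elsewhere in this diff:

```rust
use bevy_utils::Parallel;

fn push_items(items: &[u32], queues: &Parallel<Vec<u32>>) {
    // Before: look up the thread-local queue once per item.
    for &item in items {
        queues.scope(|queue| queue.push(item));
    }

    // After: borrow the thread-local queue once, then push the whole batch.
    let mut queue = queues.borrow_local_mut();
    for &item in items {
        queue.push(item);
    }
}
```
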
81 changes: 41 additions & 40 deletions crates/bevy_render/src/view/visibility/mod.rs
@@ -395,52 +395,53 @@ pub fn check_visibility(
let view_mask = maybe_view_mask.copied().unwrap_or_default();

visible_entities.entities.clear();
visible_aabb_query.par_iter_mut().for_each(|query_item| {
let (
entity,
inherited_visibility,
mut view_visibility,
maybe_entity_mask,
maybe_model_aabb,
transform,
no_frustum_culling,
) = query_item;

// Skip computing visibility for entities that are configured to be hidden.
// ViewVisibility has already been reset in `reset_view_visibility`.
if !inherited_visibility.get() {
return;
}
visible_aabb_query.par_iter_mut().for_each_init(
|| thread_queues.borrow_local_mut(),
|queue, query_item| {
let (
entity,
inherited_visibility,
mut view_visibility,
maybe_entity_mask,
maybe_model_aabb,
transform,
no_frustum_culling,
) = query_item;

// Skip computing visibility for entities that are configured to be hidden.
// ViewVisibility has already been reset in `reset_view_visibility`.
if !inherited_visibility.get() {
return;
}

let entity_mask = maybe_entity_mask.copied().unwrap_or_default();
if !view_mask.intersects(&entity_mask) {
return;
}
let entity_mask = maybe_entity_mask.copied().unwrap_or_default();
if !view_mask.intersects(&entity_mask) {
return;
}

// If we have an aabb, do frustum culling
if !no_frustum_culling {
if let Some(model_aabb) = maybe_model_aabb {
let model = transform.affine();
let model_sphere = Sphere {
center: model.transform_point3a(model_aabb.center),
radius: transform.radius_vec3a(model_aabb.half_extents),
};
// Do quick sphere-based frustum culling
if !frustum.intersects_sphere(&model_sphere, false) {
return;
}
// Do aabb-based frustum culling
if !frustum.intersects_obb(model_aabb, &model, true, false) {
return;
// If we have an aabb, do frustum culling
if !no_frustum_culling {
if let Some(model_aabb) = maybe_model_aabb {
let model = transform.affine();
let model_sphere = Sphere {
center: model.transform_point3a(model_aabb.center),
radius: transform.radius_vec3a(model_aabb.half_extents),
};
// Do quick sphere-based frustum culling
if !frustum.intersects_sphere(&model_sphere, false) {
return;
}
// Do aabb-based frustum culling
if !frustum.intersects_obb(model_aabb, &model, true, false) {
return;
}
}
}
}

view_visibility.set();
thread_queues.scope(|queue| {
view_visibility.set();
queue.push(entity);
});
});
},
);

visible_entities.entities.clear();
thread_queues.drain_into(&mut visible_entities.entities);