Skip to content

Commit

Permalink
Ensure safety of indirect dispatch (#5714)
Browse files Browse the repository at this point in the history
by injecting a compute shader that validates the content of the indirect buffer
  • Loading branch information
teoxoy authored Oct 14, 2024
1 parent c0e3972 commit 7f708ed
Show file tree
Hide file tree
Showing 15 changed files with 913 additions and 22 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ By @bradwerth [#6216](https://github.com/gfx-rs/wgpu/pull/6216).
- Call `flush_mapped_ranges` when unmapping write-mapped buffers. By @teoxoy in [#6089](https://github.com/gfx-rs/wgpu/pull/6089).
- When mapping buffers for reading, mark buffers as initialized only when they have `MAP_WRITE` usage. By @teoxoy in [#6178](https://github.com/gfx-rs/wgpu/pull/6178).
- Add a separate pipeline constants error. By @teoxoy in [#6094](https://github.com/gfx-rs/wgpu/pull/6094).
- Ensure safety of indirect dispatch by injecting a compute shader that validates the content of the indirect buffer. By @teoxoy in [#5714](https://github.com/gfx-rs/wgpu/pull/5714)

#### GLES / OpenGL

Expand Down
241 changes: 241 additions & 0 deletions tests/tests/dispatch_workgroups_indirect.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
use wgpu_test::{gpu_test, FailureCase, GpuTestConfiguration, TestParameters, TestingContext};

/// Make sure that the num_workgroups builtin works properly (it requires a workaround on D3D12).
#[gpu_test]
static NUM_WORKGROUPS_BUILTIN: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
TestParameters::default()
.features(wgpu::Features::PUSH_CONSTANTS)
.downlevel_flags(
wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION,
)
.limits(wgpu::Limits {
max_push_constant_size: 4,
..wgpu::Limits::downlevel_defaults()
})
.expect_fail(FailureCase::backend(wgt::Backends::DX12)),
)
.run_async(|ctx| async move {
let num_workgroups = [1, 2, 3];
let res = run_test(&ctx, &num_workgroups, false).await;
assert_eq!(res, num_workgroups);
});

/// Make sure that we discard (don't run) the dispatch if its size exceeds the device limit.
#[gpu_test]
static DISCARD_DISPATCH: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
TestParameters::default()
.features(wgpu::Features::PUSH_CONSTANTS)
.downlevel_flags(
wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION,
)
.limits(wgpu::Limits {
max_compute_workgroups_per_dimension: 10,
max_push_constant_size: 4,
..wgpu::Limits::downlevel_defaults()
})
.expect_fail(FailureCase::backend(wgt::Backends::DX12)),
)
.run_async(|ctx| async move {
let max = ctx.device.limits().max_compute_workgroups_per_dimension;

let res = run_test(&ctx, &[max, max, max], false).await;
assert_eq!(res, [max; 3]);

let res = run_test(&ctx, &[max + 1, 1, 1], false).await;
assert_eq!(res, [0; 3]);

let res = run_test(&ctx, &[1, max + 1, 1], false).await;
assert_eq!(res, [0; 3]);

let res = run_test(&ctx, &[1, 1, max + 1], false).await;
assert_eq!(res, [0; 3]);
});

/// Make sure that resetting the bind groups set by the validation code works properly.
#[gpu_test]
static RESET_BIND_GROUPS: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
TestParameters::default()
.features(wgpu::Features::PUSH_CONSTANTS)
.downlevel_flags(
wgpu::DownlevelFlags::COMPUTE_SHADERS | wgpu::DownlevelFlags::INDIRECT_EXECUTION,
)
.limits(wgpu::Limits {
max_push_constant_size: 4,
..wgpu::Limits::downlevel_defaults()
}),
)
.run_async(|ctx| async move {
ctx.device.push_error_scope(wgpu::ErrorFilter::Validation);

let _ = run_test(&ctx, &[0, 0, 0], true).await;

let error = pollster::block_on(ctx.device.pop_error_scope());
assert!(error.map_or(false, |error| {
format!("{error}").contains("The current set ComputePipeline with '' label expects a BindGroup to be set at index 0")
}));
});

async fn run_test(
ctx: &TestingContext,
num_workgroups: &[u32; 3],
forget_to_set_bind_group: bool,
) -> [u32; 3] {
const SHADER_SRC: &str = "
struct TestOffsetPc {
inner: u32,
}
// `test_offset.inner` should always be 0; we test that resetting the push constant set by the validation code works properly.
var<push_constant> test_offset: TestOffsetPc;
@group(0) @binding(0)
var<storage, read_write> out: array<u32, 3>;
@compute @workgroup_size(1)
fn main(@builtin(num_workgroups) num_workgroups: vec3u, @builtin(workgroup_id) workgroup_id: vec3u) {
if (all(workgroup_id == vec3u())) {
out[0] = num_workgroups.x + test_offset.inner;
out[1] = num_workgroups.y + test_offset.inner;
out[2] = num_workgroups.z + test_offset.inner;
}
}
";

let module = ctx
.device
.create_shader_module(wgpu::ShaderModuleDescriptor {
label: None,
source: wgpu::ShaderSource::Wgsl(SHADER_SRC.into()),
});

let bgl = ctx
.device
.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: None,
entries: &[wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgt::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
}],
});

let layout = ctx
.device
.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: None,
bind_group_layouts: &[&bgl],
push_constant_ranges: &[wgt::PushConstantRange {
stages: wgt::ShaderStages::COMPUTE,
range: 0..4,
}],
});

let pipeline = ctx
.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: None,
layout: Some(&layout),
module: &module,
entry_point: Some("main"),
compilation_options: Default::default(),
cache: None,
});

let out_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: 12,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});

let readback_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: 12,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});

let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout: &pipeline.get_bind_group_layout(0),
entries: &[wgpu::BindGroupEntry {
binding: 0,
resource: out_buffer.as_entire_binding(),
}],
});

let mut res = None;

for (indirect_offset, indirect_buffer_size) in [
// internal src buffer binding size will be buffer.size
(0, 12),
(4, 4 + 12),
(4, 8 + 12),
(256 * 2 - 4 - 12, 256 * 2 - 4),
// internal src buffer binding size will be 256 * 2 + x
(0, 256 * 2 * 2 + 4),
(256, 256 * 2 * 2 + 8),
(256 + 4, 256 * 2 * 2 + 12),
(256 * 2 + 16, 256 * 2 * 2 + 16),
(256 * 2 * 2, 256 * 2 * 2 + 32),
(256 + 12, 256 * 2 * 2 + 64),
] {
let indirect_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: indirect_buffer_size,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::INDIRECT,
mapped_at_creation: false,
});

ctx.queue.write_buffer(
&indirect_buffer,
indirect_offset,
bytemuck::bytes_of(num_workgroups),
);

let mut encoder = ctx.device.create_command_encoder(&Default::default());
{
let mut compute_pass = encoder.begin_compute_pass(&Default::default());
compute_pass.set_pipeline(&pipeline);
compute_pass.set_push_constants(0, &[0, 0, 0, 0]);
if !forget_to_set_bind_group {
compute_pass.set_bind_group(0, Some(&bind_group), &[]);
}
compute_pass.dispatch_workgroups_indirect(&indirect_buffer, indirect_offset);
}

encoder.copy_buffer_to_buffer(&out_buffer, 0, &readback_buffer, 0, 12);

ctx.queue.submit(Some(encoder.finish()));

readback_buffer
.slice(..)
.map_async(wgpu::MapMode::Read, |_| {});

ctx.async_poll(wgpu::Maintain::wait())
.await
.panic_on_timeout();

let view = readback_buffer.slice(..).get_mapped_range();

let current_res = *bytemuck::from_bytes(&view);
drop(view);
readback_buffer.unmap();

if let Some(past_res) = res {
assert_eq!(past_res, current_res);
} else {
res = Some(current_res);
}
}

res.unwrap()
}
1 change: 1 addition & 0 deletions tests/tests/root.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ mod clear_texture;
mod compute_pass_ownership;
mod create_surface_error;
mod device;
mod dispatch_workgroups_indirect;
mod encoder;
mod external_texture;
mod float32_filterable;
Expand Down
4 changes: 4 additions & 0 deletions wgpu-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ renderdoc = ["hal/renderdoc"]
## to the validation carried out at public APIs in all builds.
strict_asserts = ["wgt/strict_asserts"]

## Validates indirect draw/dispatch calls. This will also enable naga's
## WGSL frontend since we use a WGSL compute shader to do the validation.
indirect-validation = ["naga/wgsl-in"]

## Enables serialization via `serde` on common wgpu types.
serde = ["dep:serde", "wgt/serde", "arrayvec/serde"]

Expand Down
20 changes: 16 additions & 4 deletions wgpu-core/src/command/bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,13 +200,17 @@ mod compat {
entries: (0..hal::MAX_BIND_GROUPS).map(|_| Entry::empty()).collect(),
}
}
fn make_range(&self, start_index: usize) -> Range<usize> {

pub fn num_valid_entries(&self) -> usize {
// find first incompatible entry
let end = self
.entries
self.entries
.iter()
.position(|e| e.is_incompatible())
.unwrap_or(self.entries.len());
.unwrap_or(self.entries.len())
}

fn make_range(&self, start_index: usize) -> Range<usize> {
let end = self.num_valid_entries();
start_index..end.max(start_index)
}

Expand Down Expand Up @@ -406,6 +410,14 @@ impl Binder {
.map(move |index| payloads[index].group.as_ref().unwrap())
}

#[cfg(feature = "indirect-validation")]
pub(super) fn list_valid<'a>(&'a self) -> impl Iterator<Item = (usize, &'a EntryPayload)> + '_ {
self.payloads
.iter()
.take(self.manager.num_valid_entries())
.enumerate()
}

pub(super) fn check_compatibility<T: Labeled>(
&self,
pipeline: &T,
Expand Down
Loading

0 comments on commit 7f708ed

Please sign in to comment.