Skip to content

Commit

Permalink
hello-workgroups It doesn't work.
Browse files Browse the repository at this point in the history
  • Loading branch information
JustAnotherCodemonkey committed Jun 29, 2023
1 parent dca6a19 commit a15fc4a
Show file tree
Hide file tree
Showing 4 changed files with 231 additions and 0 deletions.
14 changes: 14 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 23 additions & 0 deletions examples/hello-workgroups/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
[package]
name = "wgpu-hello-workgroups-example"
version.workspace = true
license.workspace = true
edition.workspace = true
description = "wgpu hello workgroups example"
publish = false

[[bin]]
name = "hello-workgroups"
path = "src/main.rs"

[dependencies]
bytemuck.workspace = true
env_logger.workspace = true
futures-intrusive.workspace = true
log.workspace = true
pollster.workspace = true
wgpu.workspace = true

[target.'cfg(target_arch = "wasm32")'.dependencies]
console_error_panic_hook.workspace = true
console_log.workspace = true
176 changes: 176 additions & 0 deletions examples/hello-workgroups/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
use wgpu::util::DeviceExt;

async fn run() {
let mut local_a = [0i32; 100];
for (i, e) in local_a.iter_mut().enumerate() {
*e = i as i32;
}
log::info!("Input a: {local_a:?}");
let mut local_b = [0i32; 100];
for (i, e) in local_b.iter_mut().enumerate() {
*e = i as i32 * 2;
}
log::info!("Input b: {local_b:?}");

let instance = wgpu::Instance::default();
let adapter = instance
.request_adapter(&wgpu::RequestAdapterOptions::default())
.await
.unwrap();
let (device, queue) = adapter
.request_device(&wgpu::DeviceDescriptor::default(), None)
.await
.unwrap();

let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: None,
source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Borrowed(include_str!("shader.wgsl"))),
});

let storage_buffer_a = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: None,
contents: bytemuck::cast_slice(&local_a[..]),
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
});
let storage_buffer_b = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: None,
contents: bytemuck::cast_slice(&local_b[..]),
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
});
let output_staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: std::mem::size_of_val(&local_a) as u64,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});

let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: None,
entries: &[
wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
],
});
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout: &bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: storage_buffer_a.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: storage_buffer_b.as_entire_binding(),
},
],
});

let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: None,
bind_group_layouts: &[&bind_group_layout],
push_constant_ranges: &[],
});
let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: None,
layout: Some(&pipeline_layout),
module: &shader,
entry_point: "main",
});

//----------------------------------------------------------

let mut command_encoder =
device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
{
let mut compute_pass =
command_encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: None });
compute_pass.set_pipeline(&pipeline);
compute_pass.set_bind_group(0, &bind_group, &[]);
compute_pass.dispatch_workgroups(local_a.len() as u32, 1, 1);
}
queue.submit(Some(command_encoder.finish()));

//----------------------------------------------------------

get_data(
&mut local_a[..],
&storage_buffer_a,
&output_staging_buffer,
&device,
&queue,
)
.await;
get_data(
&mut local_b[..],
&storage_buffer_b,
&output_staging_buffer,
&device,
&queue,
)
.await;

log::info!("Output in A: {local_a:?}");
log::info!("Output in B: {local_b:?}");
}

async fn get_data<T: bytemuck::Pod>(
output: &mut [T],
storage_buffer: &wgpu::Buffer,
staging_buffer: &wgpu::Buffer,
device: &wgpu::Device,
queue: &wgpu::Queue,
) {
let mut command_encoder =
device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
command_encoder.copy_buffer_to_buffer(
&storage_buffer,
0,
&staging_buffer,
0,
std::mem::size_of_val(&output.len()) as u64,
);
queue.submit(Some(command_encoder.finish()));
let buffer_slice = staging_buffer.slice(..);
let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
buffer_slice.map_async(wgpu::MapMode::Read, move |r| sender.send(r).unwrap());
device.poll(wgpu::Maintain::Wait);
receiver.receive().await.unwrap().unwrap();
output.copy_from_slice(bytemuck::cast_slice(&buffer_slice.get_mapped_range()[..]));
staging_buffer.unmap();
}

fn main() {
#[cfg(not(target_arch = "wasm32"))]
{
env_logger::builder()
.filter_level(log::LevelFilter::Info)
.format_timestamp_nanos()
.init();
pollster::block_on(run());
}
#[cfg(target_arch = "wasm32")]
{
std::panic::set_hook(Box::new(console_error_panic_hook::hook));
console_log::init().expect("could not initialize logger");
wasm_bindgen_futures::spawn_local(run());
}
}
18 changes: 18 additions & 0 deletions examples/hello-workgroups/src/shader.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
@group(0)
@binding(0)
var<storage, read_write> a: array<i32>;

@group(0)
@binding(1)
var<storage, read_write> b: array<i32>;

@compute
@workgroup_size(2, 1, 1)
fn main(@builtin(local_invocation_id) lid: vec3<u32>, @builtin(workgroup_id) wid: vec3<u32>) {
a[0] += 1;
if lid.x == 0u {
a[wid.x] += 1;
} else if lid.x == 1u {
b[wid.x] += 1;
}
}

0 comments on commit a15fc4a

Please sign in to comment.