Skip to content

Commit

Permalink
Merge pull request #362 from armansito/pr-issue-33
Browse files Browse the repository at this point in the history
[shaders] Expose a binding index set per target language
  • Loading branch information
armansito authored Sep 15, 2023
2 parents 34d313d + 5ef483e commit d4192ec
Show file tree
Hide file tree
Showing 10 changed files with 198 additions and 64 deletions.
50 changes: 45 additions & 5 deletions crates/shaders/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,18 +93,58 @@ fn write_shaders(
wg_bufs
)?;
if cfg!(feature = "wgsl") {
writeln!(buf, " wgsl: Cow::Borrowed({:?}),", info.source)?;
}
if cfg!(feature = "msl") {
let indices = info
.bindings
.iter()
.map(|binding| binding.location.1)
.collect::<Vec<_>>();
writeln!(buf, " wgsl: WgslSource {{")?;
writeln!(
buf,
" msl: Cow::Borrowed({:?}),",
compile::msl::translate(info).unwrap()
" code: Cow::Borrowed({:?}),",
info.source
)?;
writeln!(
buf,
" binding_indices : Cow::Borrowed(&{:?}),",
indices
)?;
writeln!(buf, " }},")?;
}
if cfg!(feature = "msl") {
write_msl(buf, info)?;
}
writeln!(buf, " }},")?;
}
writeln!(buf, " }};")?;
writeln!(buf, "}}")?;
Ok(())
}

#[cfg(not(feature = "msl"))]
fn write_msl(_: &mut String, _: &ShaderInfo) -> Result<(), std::fmt::Error> {
Ok(())
}

#[cfg(feature = "msl")]
fn write_msl(buf: &mut String, info: &ShaderInfo) -> Result<(), std::fmt::Error> {
let mut index_iter = compile::msl::BindingIndexIterator::default();
let indices = info
.bindings
.iter()
.map(|binding| index_iter.next(binding.ty))
.collect::<Vec<_>>();
writeln!(buf, " msl: MslSource {{")?;
writeln!(
buf,
" code: Cow::Borrowed({:?}),",
compile::msl::translate(info).unwrap()
)?;
writeln!(
buf,
" binding_indices : Cow::Borrowed(&{:?}),",
indices
)?;
writeln!(buf, " }},")?;
Ok(())
}
1 change: 1 addition & 0 deletions crates/shaders/src/compile/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use {
pub mod permutations;
pub mod preprocess;

#[cfg(feature = "msl")]
pub mod msl;

use crate::types::{BindType, BindingInfo, WorkgroupBufferInfo};
Expand Down
64 changes: 44 additions & 20 deletions crates/shaders/src/compile/msl.rs
Original file line number Diff line number Diff line change
@@ -1,43 +1,42 @@
// Copyright 2023 The Vello authors
// SPDX-License-Identifier: Apache-2.0 OR MIT

use naga::back::msl;
use naga::back::msl as naga_msl;
use {
super::{BindType, ShaderInfo},
crate::types::msl::BindingIndex,
};

use super::{BindType, ShaderInfo};

pub fn translate(shader: &ShaderInfo) -> Result<String, msl::Error> {
let mut map = msl::EntryPointResourceMap::default();
let mut buffer_index = 0u8;
let mut image_index = 0u8;
let mut binding_map = msl::BindingMap::default();
pub fn translate(shader: &ShaderInfo) -> Result<String, naga_msl::Error> {
let mut map = naga_msl::EntryPointResourceMap::default();
let mut idx_iter = BindingIndexIterator::default();
let mut binding_map = naga_msl::BindingMap::default();
for resource in &shader.bindings {
let binding = naga::ResourceBinding {
group: resource.location.0,
binding: resource.location.1,
};
let mut target = msl::BindTarget::default();
match resource.ty {
BindType::Buffer | BindType::BufReadOnly | BindType::Uniform => {
target.buffer = Some(buffer_index);
buffer_index += 1;
let mut target = naga_msl::BindTarget::default();
match idx_iter.next(resource.ty) {
BindingIndex::Buffer(idx) => {
target.buffer = Some(idx);
}
BindType::Image | BindType::ImageRead => {
target.texture = Some(image_index);
image_index += 1;
BindingIndex::Texture(idx) => {
target.texture = Some(idx);
}
}
target.mutable = resource.ty.is_mutable();
binding_map.insert(binding, target);
}
map.insert(
"main".to_string(),
msl::EntryPointResources {
naga_msl::EntryPointResources {
resources: binding_map,
push_constant_buffer: None,
sizes_buffer: Some(30),
},
);
let options = msl::Options {
let options = naga_msl::Options {
lang_version: (2, 0),
per_entry_point_map: map,
inline_samplers: vec![],
Expand All @@ -46,11 +45,36 @@ pub fn translate(shader: &ShaderInfo) -> Result<String, msl::Error> {
bounds_check_policies: naga::proc::BoundsCheckPolicies::default(),
zero_initialize_workgroup_memory: false,
};
let (source, _) = msl::write_string(
let (source, _) = naga_msl::write_string(
&shader.module,
&shader.module_info,
&options,
&msl::PipelineOptions::default(),
&naga_msl::PipelineOptions::default(),
)?;
Ok(source)
}

#[derive(Default)]
pub struct BindingIndexIterator {
buffer_idx: u8,
tex_idx: u8,
}

impl BindingIndexIterator {
pub fn next(&mut self, ty: BindType) -> BindingIndex {
match ty {
BindType::Buffer | BindType::BufReadOnly | BindType::Uniform => {
let idx = self.buffer_idx;
self.buffer_idx += 1;
assert!(self.buffer_idx > 0);
BindingIndex::Buffer(idx)
}
BindType::Image | BindType::ImageRead => {
let idx = self.tex_idx;
self.tex_idx += 1;
assert!(self.tex_idx > 0);
BindingIndex::Texture(idx)
}
}
}
}
74 changes: 72 additions & 2 deletions crates/shaders/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ pub mod compile;

pub use types::{BindType, BindingInfo, WorkgroupBufferInfo};

#[cfg(feature = "msl")]
pub use types::msl;

use std::borrow::Cow;

#[derive(Clone, Debug)]
Expand All @@ -18,10 +21,77 @@ pub struct ComputeShader<'a> {
pub workgroup_buffers: Cow<'a, [WorkgroupBufferInfo]>,

#[cfg(feature = "wgsl")]
pub wgsl: Cow<'a, str>,
pub wgsl: WgslSource<'a>,

#[cfg(feature = "msl")]
pub msl: Cow<'a, str>,
pub msl: MslSource<'a>,
}

#[cfg(feature = "wgsl")]
#[derive(Clone, Debug)]
pub struct WgslSource<'a> {
pub code: Cow<'a, str>,

/// Contains the binding index of each resource listed in `ComputeShader::bindings`.
/// This is guaranteed to have the same element count as `ComputeShader::bindings`.
///
/// In WGSL, each index directly corresponds to the value of the corresponding
/// `@binding(..)` declaration in the shader source. The bind group index (i.e.
/// value of `@group(..)`) is always 0.
///
/// Example:
/// --------
///
/// // An unused binding (i.e. declaration is not reachable from the entry-point)
/// @group(0) @binding(0) var<uniform> foo: Foo;
///
/// // Used bindings:
/// @group(0) @binding(1) var<storage> buffer: Buffer;
/// @group(0) @binding(2) var tex: texture_2d<f32>;
/// ...
///
/// This results in the following bindings:
///
/// bindings: [BindType::Buffer, BindType::ImageRead],
/// ...
/// wgsl: WgslSource {
/// code: ...,
/// binding_indices: [1, 2],
/// },
pub binding_indices: Cow<'a, [u8]>,
}

#[cfg(feature = "msl")]
#[derive(Clone, Debug)]
pub struct MslSource<'a> {
pub code: Cow<'a, str>,

/// Contains the binding index of each resource listed in `ComputeShader::bindings`.
/// This is guaranteed to have the same element count as `ComputeShader::bindings`.
///
/// In MSL, each index is scoped to the index range of the corresponding resource type.
///
/// Example:
/// --------
///
/// // An unused binding (i.e. declaration is not reachable from the entry-point)
/// @group(0) @binding(0) var<uniform> foo: Foo;
///
/// // Used bindings:
/// @group(0) @binding(1) var<storage> buffer: Buffer;
/// @group(0) @binding(2) var tex: texture_2d<f32>;
/// ...
///
/// This results in the following bindings:
///
/// bindings: [BindType::Buffer, BindType::ImageRead],
/// ...
/// msl: MslSource {
/// code: ...,
/// // In MSL these would be declared as `[[buffer(0)]]` and `[[texture(0)]]`.
/// binding_indices: [msl::BindingIndex::Buffer(0), msl::BindingIndex::Texture(0)],
/// },
pub binding_indices: Cow<'a, [msl::BindingIndex]>,
}

pub trait PipelineHost {
Expand Down
20 changes: 20 additions & 0 deletions crates/shaders/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,23 @@ pub struct WorkgroupBufferInfo {
/// The order in which the workgroup variable is declared in the shader module.
pub index: u32,
}

#[cfg(feature = "msl")]
pub mod msl {
use std::fmt;

#[derive(Clone)]
pub enum BindingIndex {
Buffer(u8),
Texture(u8),
}

impl fmt::Debug for BindingIndex {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Self::Buffer(i) => write!(f, "msl::BindingIndex::Buffer({})", i),
Self::Texture(i) => write!(f, "msl::BindingIndex::Texture({})", i),
}
}
}
}
10 changes: 3 additions & 7 deletions shader/clip_reduce.wgsl
Original file line number Diff line number Diff line change
@@ -1,22 +1,18 @@
// SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense

#import config
#import bbox
#import clip

@group(0) @binding(0)
var<uniform> config: Config;

@group(0) @binding(1)
var<storage> clip_inp: array<ClipInp>;

@group(0) @binding(2)
@group(0) @binding(1)
var<storage> path_bboxes: array<PathBbox>;

@group(0) @binding(3)
@group(0) @binding(2)
var<storage, read_write> reduced: array<Bic>;

@group(0) @binding(4)
@group(0) @binding(3)
var<storage, read_write> clip_out: array<ClipEl>;

let WG_SIZE = 256u;
Expand Down
13 changes: 5 additions & 8 deletions shader/fine.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ struct Tile {
var<uniform> config: Config;

@group(0) @binding(1)
var<storage> tiles: array<Tile>;

@group(0) @binding(2)
var<storage> segments: array<Segment>;

#ifdef full
Expand All @@ -28,19 +25,19 @@ var<storage> segments: array<Segment>;

let GRADIENT_WIDTH = 512;

@group(0) @binding(2)
var<storage> ptcl: array<u32>;

@group(0) @binding(3)
var output: texture_storage_2d<rgba8unorm, write>;
var<storage> info: array<u32>;

@group(0) @binding(4)
var<storage> ptcl: array<u32>;
var output: texture_storage_2d<rgba8unorm, write>;

@group(0) @binding(5)
var gradients: texture_2d<f32>;

@group(0) @binding(6)
var<storage> info: array<u32>;

@group(0) @binding(7)
var image_atlas: texture_2d<f32>;

fn read_fill(cmd_ix: u32) -> CmdFill {
Expand Down
11 changes: 4 additions & 7 deletions shader/path_coarse_full.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,9 @@ var<uniform> config: Config;
var<storage> scene: array<u32>;

@group(0) @binding(2)
var<storage> tag_monoids: array<TagMonoid>;

@group(0) @binding(3)
var<storage> cubics: array<Cubic>;

@group(0) @binding(4)
@group(0) @binding(3)
var<storage> paths: array<Path>;

// We don't get this from import as it's the atomic version
Expand All @@ -30,13 +27,13 @@ struct AtomicTile {
segments: atomic<u32>,
}

@group(0) @binding(5)
@group(0) @binding(4)
var<storage, read_write> bump: BumpAllocators;

@group(0) @binding(6)
@group(0) @binding(5)
var<storage, read_write> tiles: array<AtomicTile>;

@group(0) @binding(7)
@group(0) @binding(6)
var<storage, read_write> segments: array<Segment>;

struct SubdivResult {
Expand Down
Loading

0 comments on commit d4192ec

Please sign in to comment.