
Commit 5c8fbb4
builder: handle ptr_add-like GEPs (introduced by rust-lang/rust#118991).

eddyb committed Aug 8, 2024
1 parent a21330b commit 5c8fbb4
Showing 2 changed files with 244 additions and 131 deletions.
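
For orientation, a hedged sketch of what "ptr_add-like GEPs" means here (the term matches the HACK comment in the diff below: `ptr_add`, aka `inbounds_gep` with byte offsets): upstream rustc increasingly lowers field and offset projections to a plain byte offset on the pointer rather than a typed `struct_gep`, so even an ordinary field access like the one below can reach this builder as "advance the pointer by N bytes", which then has to be mapped back onto a structured `OpAccessChain`. The struct and field names are illustrative, not taken from the commit:

// Hedged example; the #[repr(C)] layout is an assumption for illustration.
#[repr(C)]
pub struct Light {
    pub color: [f32; 3], // bytes 0..12
    pub intensity: f32,  // bytes 12..16
}

// This field projection can now arrive at the backend as a GEP that only says
// "pointer + 12 bytes", with the `f32` leaf type implied by the use site.
pub fn intensity(light: &Light) -> &f32 {
    &light.intensity
}
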
264 changes: 243 additions & 21 deletions crates/rustc_codegen_spirv/src/builder/builder_methods.rs
@@ -30,6 +30,7 @@ use std::borrow::Cow;
use std::cell::Cell;
use std::convert::TryInto;
use std::iter::{self, empty};
use std::ops::RangeInclusive;

macro_rules! simple_op {
(
@@ -412,9 +413,12 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
// FIXME(eddyb) this isn't efficient, `recover_access_chain_from_offset`
// could instead be doing all the extra digging itself.
let mut indices = SmallVec::<[_; 8]>::new();
while let Some((inner_indices, inner_ty)) =
self.recover_access_chain_from_offset(leaf_ty, Size::ZERO, Some(size), None)
{
while let Some((inner_indices, inner_ty)) = self.recover_access_chain_from_offset(
leaf_ty,
Size::ZERO,
Some(size)..=Some(size),
None,
) {
indices.extend(inner_indices);
leaf_ty = inner_ty;
}
@@ -439,8 +443,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
}

/// If possible, return the appropriate `OpAccessChain` indices for going
/// from a pointer to `ty`, to a pointer to some leaf field/element of size
/// `leaf_size` (and optionally type `leaf_ty`), while adding `offset` bytes.
/// from a pointer to `ty`, to a pointer to some leaf field/element having
/// a size that fits `leaf_size_range` (and, optionally, the type `leaf_ty`),
/// while adding `offset` bytes.
///
/// That is, try to turn `((_: *T) as *u8).add(offset) as *Leaf` into a series
/// of struct field and array/vector element accesses.
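
A concrete (hedged) instance of the pattern in the doc comment above, with a `#[repr(C)]` layout and byte offset chosen purely for illustration:

// Hedged illustration; the type and the byte offset 8 are assumptions.
#[repr(C)]
struct Pair(u16, [u32; 4]); // field 0 at byte 0, field 1 at bytes 4..20

// `(p as *const u8).add(8) as *const u32` lands on `p.1[1]`: field 1 starts at
// byte 4 and each `u32` element is 4 bytes, so the recovered access chain is
// the index list [1, 1].
unsafe fn second_elem(p: *const Pair) -> *const u32 {
    unsafe { (p as *const u8).add(8) as *const u32 }
}
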
@@ -449,7 +454,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
mut ty: <Self as BackendTypes>::Type,
mut offset: Size,
// FIXME(eddyb) using `None` for "unsized" is a pretty bad design.
leaf_size_or_unsized: Option<Size>,
leaf_size_or_unsized_range: RangeInclusive<Option<Size>>,
leaf_ty: Option<<Self as BackendTypes>::Type>,
) -> Option<(SmallVec<[u32; 8]>, <Self as BackendTypes>::Type)> {
assert_ne!(Some(ty), leaf_ty);
@@ -460,7 +465,12 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
Sized(Size),
Unsized,
}
let leaf_size = leaf_size_or_unsized.map_or(MaybeSized::Unsized, MaybeSized::Sized);
let leaf_size_range = {
let r = leaf_size_or_unsized_range;
let [start, end] =
[r.start(), r.end()].map(|x| x.map_or(MaybeSized::Unsized, MaybeSized::Sized));
start..=end
};

// NOTE(eddyb) `ty` and `ty_kind`/`ty_size` should be kept in sync.
let mut ty_kind = self.lookup_type(ty);
@@ -493,7 +503,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
if MaybeSized::Sized(offset_in_field) < field_ty_size
// If the field is a zero sized type, check the
// expected size and type to get the correct entry
|| offset_in_field == Size::ZERO && leaf_size == MaybeSized::Sized(Size::ZERO) && leaf_ty == Some(field_ty)
|| offset_in_field == Size::ZERO
&& leaf_size_range.contains(&MaybeSized::Sized(Size::ZERO)) && leaf_ty == Some(field_ty)
{
Some((i, field_ty, field_ty_kind, field_ty_size, offset_in_field))
} else {
@@ -525,19 +536,211 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
}

// Avoid digging beyond the point the leaf could actually fit.
if ty_size < leaf_size {
if ty_size < *leaf_size_range.start() {
return None;
}

if offset == Size::ZERO
&& ty_size == leaf_size
&& leaf_size_range.contains(&ty_size)
&& leaf_ty.map_or(true, |leaf_ty| leaf_ty == ty)
{
return Some((indices, ty));
}
}
}

fn maybe_inbounds_gep(
&mut self,
ty: Word,
ptr: SpirvValue,
combined_indices: &[SpirvValue],
is_inbounds: bool,
) -> SpirvValue {
let (&ptr_base_index, indices) = combined_indices.split_first().unwrap();

// The first index is an offset to the pointer, the rest are actual members.
// https://llvm.org/docs/GetElementPtr.html
// "An OpAccessChain instruction is the equivalent of an LLVM getelementptr instruction where the first index element is zero."
// https://github.com/gpuweb/gpuweb/issues/33
let mut result_pointee_type = ty;
let indices: Vec<_> = indices
.iter()
.map(|index| {
result_pointee_type = match self.lookup_type(result_pointee_type) {
SpirvType::Array { element, .. } | SpirvType::RuntimeArray { element } => {
element
}
_ => self.fatal(format!(
"GEP not implemented for type {}",
self.debug_type(result_pointee_type)
)),
};
index.def(self)
})
.collect();

// Special-case field accesses through a `pointercast`, to access the
// right field in the original type, for the `Logical` addressing model.
let ptr = ptr.strip_ptrcasts();
let ptr_id = ptr.def(self);
let original_pointee_ty = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee } => pointee,
other => self.fatal(format!("gep called on non-pointer type: {other:?}")),
};

// HACK(eddyb) `struct_gep` itself is falling out of use, as it's being
// replaced upstream by `ptr_add` (aka `inbounds_gep` with byte offsets).
//
// FIXME(eddyb) get rid of everything other than:
// - constant byte offset (`ptr_add`?)
// - dynamic indexing of a single array
let const_ptr_offset = self
.builder
.lookup_const_u64(ptr_base_index)
.and_then(|idx| Some(idx * self.lookup_type(ty).sizeof(self)?));
if let Some(const_ptr_offset) = const_ptr_offset {
if let Some((base_indices, base_pointee_ty)) = self.recover_access_chain_from_offset(
original_pointee_ty,
const_ptr_offset,
Some(Size::ZERO)..=None,
None,
) {
// FIXME(eddyb) this condition is pretty limiting, but
// eventually it shouldn't matter if GEPs are going away.
if ty == base_pointee_ty || indices.is_empty() {
let result_pointee_type = if indices.is_empty() {
base_pointee_ty
} else {
result_pointee_type
};
let indices = base_indices
.into_iter()
.map(|idx| self.constant_u32(self.span(), idx).def(self))
.chain(indices)
.collect();
return self.emit_access_chain(
self.type_ptr_to(result_pointee_type),
ptr_id,
None,
indices,
is_inbounds,
);
}
}
}

let result_type = self.type_ptr_to(result_pointee_type);

// Check if `ptr_id` is defined by an `OpAccessChain`, and if it is,
// grab its base pointer and indices.
//
// FIXME(eddyb) this could get ridiculously expensive, at the very least
// it could use `.rev()`, hoping the base pointer was recently defined?
let maybe_original_access_chain = if ty == original_pointee_ty {
let emit = self.emit();
let module = emit.module_ref();
let func = &module.functions[emit.selected_function().unwrap()];
let base_ptr_and_combined_indices = func
.all_inst_iter()
.find(|inst| inst.result_id == Some(ptr_id))
.and_then(|ptr_def_inst| {
if matches!(
ptr_def_inst.class.opcode,
Op::AccessChain | Op::InBoundsAccessChain
) {
let base_ptr = ptr_def_inst.operands[0].unwrap_id_ref();
let indices = ptr_def_inst.operands[1..]
.iter()
.map(|op| op.unwrap_id_ref())
.collect::<Vec<_>>();
Some((base_ptr, indices))
} else {
None
}
});
base_ptr_and_combined_indices
} else {
None
};
if let Some((original_ptr, mut original_indices)) = maybe_original_access_chain {
// Transform the following:
// OpAccessChain original_ptr [a, b, c]
// OpPtrAccessChain ptr base [d, e, f]
// into
// OpAccessChain original_ptr [a, b, c + base, d, e, f]
// to remove the need for OpPtrAccessChain
let last = original_indices.last_mut().unwrap();
*last = self
.add(last.with_type(ptr_base_index.ty), ptr_base_index)
.def(self);
original_indices.extend(indices);
return self.emit_access_chain(
result_type,
original_ptr,
None,
original_indices,
is_inbounds,
);
}

// HACK(eddyb) temporary workaround for untyped pointers upstream.
// FIXME(eddyb) replace with untyped memory SPIR-V + `qptr` or similar.
let ptr = self.pointercast(ptr, self.type_ptr_to(ty));
let ptr_id = ptr.def(self);

self.emit_access_chain(
result_type,
ptr_id,
Some(ptr_base_index),
indices,
is_inbounds,
)
}

fn emit_access_chain(
&self,
result_type: <Self as BackendTypes>::Type,
pointer: Word,
ptr_base_index: Option<SpirvValue>,
indices: Vec<Word>,
is_inbounds: bool,
) -> SpirvValue {
let mut emit = self.emit();

let non_zero_ptr_base_index =
ptr_base_index.filter(|&idx| self.builder.lookup_const_u64(idx) != Some(0));
if let Some(ptr_base_index) = non_zero_ptr_base_index {
let result = if is_inbounds {
emit.in_bounds_ptr_access_chain(
result_type,
None,
pointer,
ptr_base_index.def(self),
indices,
)
} else {
emit.ptr_access_chain(
result_type,
None,
pointer,
ptr_base_index.def(self),
indices,
)
}
.unwrap();
self.zombie(result, "cannot offset a pointer to an arbitrary element");
result
} else {
if is_inbounds {
emit.in_bounds_access_chain(result_type, None, pointer, indices)
} else {
emit.access_chain(result_type, None, pointer, indices)
}
.unwrap()
}
.with_type(result_type)
}

fn fptoint_sat(
&mut self,
signed: bool,
@@ -1361,7 +1564,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
}

fn gep(&mut self, ty: Self::Type, ptr: Self::Value, indices: &[Self::Value]) -> Self::Value {
self.gep_help(ty, ptr, indices, false)
self.maybe_inbounds_gep(ty, ptr, indices, false)
}

fn inbounds_gep(
@@ -1370,7 +1573,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
ptr: Self::Value,
indices: &[Self::Value],
) -> Self::Value {
self.gep_help(ty, ptr, indices, true)
self.maybe_inbounds_gep(ty, ptr, indices, true)
}

fn struct_gep(&mut self, ty: Self::Type, ptr: Self::Value, idx: u64) -> Self::Value {
@@ -1395,6 +1598,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
"struct_gep not on struct, array, or vector type: {other:?}, index {idx}"
)),
};
let result_pointee_size = self.lookup_type(result_pointee_type).sizeof(self);
let result_type = self.type_ptr_to(result_pointee_type);

// Special-case field accesses through a `pointercast`, to access the
@@ -1407,7 +1611,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
if let Some((indices, _)) = self.recover_access_chain_from_offset(
original_pointee_ty,
offset,
self.lookup_type(result_pointee_type).sizeof(self),
result_pointee_size..=result_pointee_size,
Some(result_pointee_type),
) {
let original_ptr = ptr.def(self);
@@ -1586,9 +1790,12 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
// FIXME(eddyb) this isn't efficient, `recover_access_chain_from_offset`
// could instead be doing all the extra digging itself.
let mut indices = SmallVec::<[_; 8]>::new();
while let Some((inner_indices, inner_ty)) =
self.recover_access_chain_from_offset(leaf_ty, Size::ZERO, Some(size), None)
{
while let Some((inner_indices, inner_ty)) = self.recover_access_chain_from_offset(
leaf_ty,
Size::ZERO,
Some(size)..=Some(size),
None,
) {
indices.extend(inner_indices);
leaf_ty = inner_ty;
}
@@ -1716,9 +1923,17 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
return self.const_bitcast(ptr, dest_ty);
}

if ptr.ty == dest_ty {
return ptr;
}

// Strip a previous `pointercast`, to reveal the original pointer type.
let ptr = ptr.strip_ptrcasts();

if ptr.ty == dest_ty {
return ptr;
}

let ptr_pointee = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee } => pointee,
other => self.fatal(format!(
@@ -1731,12 +1946,12 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
"pointercast called on non-pointer dest type: {other:?}"
)),
};
if ptr.ty == dest_ty {
ptr
} else if let Some((indices, _)) = self.recover_access_chain_from_offset(
let dest_pointee_size = self.lookup_type(dest_pointee).sizeof(self);

if let Some((indices, _)) = self.recover_access_chain_from_offset(
ptr_pointee,
Size::ZERO,
self.lookup_type(dest_pointee).sizeof(self),
dest_pointee_size..=dest_pointee_size,
Some(dest_pointee),
) {
let indices = indices
@@ -2687,6 +2902,13 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
Store(ID, ID),
Load(ID, ID),
Call(ID, ID, SmallVec<[ID; 4]>),

// HACK(eddyb) this only exists for better error reporting,
// as `Result<Inst<...>, Op>` would only report one `Op`.
Unsupported(
// HACK(eddyb) only exists for `fmt::Debug` in case of error.
#[allow(dead_code)] Op,
),
}

let taken_inst_idx_range = Cell::new(func.blocks[block_idx].instructions.len())..;
@@ -2732,7 +2954,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
(Op::FunctionCall, Some(r), [f, args @ ..]) => {
Inst::Call(r, *f, args.iter().copied().collect())
}
_ => return None,
_ => Inst::Unsupported(inst.class.opcode),
},
)
});
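
The central signature change in this file is that `recover_access_chain_from_offset` now takes a leaf-size range instead of a single optional size, which is what lets the new `maybe_inbounds_gep` ask for "any leaf at this offset, sized or unsized" via `Some(Size::ZERO)..=None`. As a rough, self-contained sketch of what that offset-to-indices recovery does — toy types, a plain sized range, and scalar-only leaves (the real code can also stop at a matching aggregate), not the backend's real `SpirvType`/`MaybeSized` machinery:

use std::ops::RangeInclusive;

// Toy stand-ins for the backend's `SpirvType` (hedged sketch, not real code).
enum Ty {
    Scalar { size: u64 },
    Array { elem: Box<Ty>, stride: u64 }, // element count/bounds ignored here
    Struct { fields: Vec<(u64, Ty)> },    // (byte offset, field type), sorted by offset
}

/// Turn a byte `offset` into `OpAccessChain`-style indices, if the offset lands
/// exactly on a leaf whose size lies within `leaf_size`.
fn recover_indices(mut ty: &Ty, mut offset: u64, leaf_size: RangeInclusive<u64>) -> Option<Vec<u32>> {
    let mut indices = Vec::new();
    loop {
        match ty {
            Ty::Scalar { size } => {
                return (offset == 0 && leaf_size.contains(size)).then_some(indices);
            }
            Ty::Array { elem, stride } => {
                indices.push(u32::try_from(offset / stride).ok()?);
                offset %= stride;
                ty = &**elem;
            }
            Ty::Struct { fields } => {
                // Pick the last field that starts at or before `offset`.
                let (i, (field_offset, field_ty)) = fields
                    .iter()
                    .enumerate()
                    .take_while(|(_, (start, _))| *start <= offset)
                    .last()?;
                indices.push(i as u32);
                offset -= field_offset;
                ty = field_ty;
            }
        }
    }
}

In the diff above, the widest form of this query (`Some(Size::ZERO)..=None`) backs the new `ptr_add` path: a constant byte offset is first translated into such an index chain, and only when no structured chain can be recovered does `emit_access_chain` fall back to `OpPtrAccessChain`, which it immediately zombies ("cannot offset a pointer to an arbitrary element").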