diff --git a/compiler/rustc_const_eval/src/interpret/intrinsics.rs b/compiler/rustc_const_eval/src/interpret/intrinsics.rs index 698742fe98ceb..44da27a43db0a 100644 --- a/compiler/rustc_const_eval/src/interpret/intrinsics.rs +++ b/compiler/rustc_const_eval/src/interpret/intrinsics.rs @@ -413,48 +413,33 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> { sym::simd_insert => { let index = u64::from(self.read_scalar(&args[1])?.to_u32()?); let elem = &args[2]; - let input = &args[0]; - let (len, e_ty) = input.layout.ty.simd_size_and_type(*self.tcx); + let (input, input_len) = self.operand_to_simd(&args[0])?; + let (dest, dest_len) = self.place_to_simd(dest)?; + assert_eq!(input_len, dest_len, "Return vector length must match input length"); assert!( - index < len, - "Index `{}` must be in bounds of vector type `{}`: `[0, {})`", + index < dest_len, + "Index `{}` must be in bounds of vector with length `{}`", index, - e_ty, - len - ); - assert_eq!( - input.layout, dest.layout, - "Return type `{}` must match vector type `{}`", - dest.layout.ty, input.layout.ty - ); - assert_eq!( - elem.layout.ty, e_ty, - "Scalar element type `{}` must match vector element type `{}`", - elem.layout.ty, e_ty + dest_len ); - for i in 0..len { - let place = self.place_index(dest, i)?; - let value = if i == index { *elem } else { self.operand_index(input, i)? 
}; - self.copy_op(&value, &place)?; + for i in 0..dest_len { + let place = self.mplace_index(&dest, i)?; + let value = + if i == index { *elem } else { self.mplace_index(&input, i)?.into() }; + self.copy_op(&value, &place.into())?; } } sym::simd_extract => { let index = u64::from(self.read_scalar(&args[1])?.to_u32()?); - let (len, e_ty) = args[0].layout.ty.simd_size_and_type(*self.tcx); + let (input, input_len) = self.operand_to_simd(&args[0])?; assert!( - index < len, - "index `{}` is out-of-bounds of vector type `{}` with length `{}`", + index < input_len, + "index `{}` must be in bounds of vector with length `{}`", index, - e_ty, - len - ); - assert_eq!( - e_ty, dest.layout.ty, - "Return type `{}` must match vector element type `{}`", - dest.layout.ty, e_ty + input_len ); - self.copy_op(&self.operand_index(&args[0], index)?, dest)?; + self.copy_op(&self.mplace_index(&input, index)?.into(), dest)?; } sym::likely | sym::unlikely | sym::black_box => { // These just return their argument diff --git a/compiler/rustc_const_eval/src/interpret/operand.rs b/compiler/rustc_const_eval/src/interpret/operand.rs index b6682b13ed216..de9e94ce2ac0c 100644 --- a/compiler/rustc_const_eval/src/interpret/operand.rs +++ b/compiler/rustc_const_eval/src/interpret/operand.rs @@ -437,6 +437,18 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> { }) } + /// Converts a repr(simd) operand into a place where `mplace_index` accesses the SIMD elements. + /// Also returns the number of elements. + pub fn operand_to_simd( + &self, + base: &OpTy<'tcx, M::PointerTag>, + ) -> InterpResult<'tcx, (MPlaceTy<'tcx, M::PointerTag>, u64)> { + // Basically we just transmute this place into an array following simd_size_and_type. + // This only works in memory, but repr(simd) types should never be immediates anyway. + assert!(base.layout.ty.is_simd()); + self.mplace_to_simd(&base.assert_mem_place()) + } + /// Read from a local. 
Will not actually access the local if reading from a ZST. /// Will not access memory, instead an indirect `Operand` is returned. /// diff --git a/compiler/rustc_const_eval/src/interpret/place.rs b/compiler/rustc_const_eval/src/interpret/place.rs index d425b84bdaf26..d7f2853fc86f5 100644 --- a/compiler/rustc_const_eval/src/interpret/place.rs +++ b/compiler/rustc_const_eval/src/interpret/place.rs @@ -200,7 +200,7 @@ impl<'tcx, Tag: Provenance> MPlaceTy<'tcx, Tag> { } } else { // Go through the layout. There are lots of types that support a length, - // e.g., SIMD types. + // e.g., SIMD types. (But not all repr(simd) types even have FieldsShape::Array!) match self.layout.fields { FieldsShape::Array { count, .. } => Ok(count), _ => bug!("len not supported on sized type {:?}", self.layout.ty), @@ -533,6 +533,22 @@ where }) } + /// Converts a repr(simd) place into a place where `mplace_index` accesses the SIMD elements. + /// Also returns the number of elements. + pub fn mplace_to_simd( + &self, + base: &MPlaceTy<'tcx, M::PointerTag>, + ) -> InterpResult<'tcx, (MPlaceTy<'tcx, M::PointerTag>, u64)> { + // Basically we just transmute this place into an array following simd_size_and_type. + // (Transmuting is okay since this is an in-memory place. We also double-check the size + // stays the same.) + let (len, e_ty) = base.layout.ty.simd_size_and_type(*self.tcx); + let array = self.tcx.mk_array(e_ty, len); + let layout = self.layout_of(array)?; + assert_eq!(layout.size, base.layout.size); + Ok((MPlaceTy { layout, ..*base }, len)) + } + /// Gets the place of a field inside the place, and also the field's type. /// Just a convenience function, but used quite a bit. /// This is the only projection that might have a side-effect: We cannot project @@ -594,6 +610,16 @@ where }) } + /// Converts a repr(simd) place into a place where `mplace_index` accesses the SIMD elements. + /// Also returns the number of elements. 
+ pub fn place_to_simd( + &mut self, + base: &PlaceTy<'tcx, M::PointerTag>, + ) -> InterpResult<'tcx, (MPlaceTy<'tcx, M::PointerTag>, u64)> { + let mplace = self.force_allocation(base)?; + self.mplace_to_simd(&mplace) + } + /// Computes a place. You should only use this if you intend to write into this /// place; for reading, a more efficient alternative is `eval_place_for_read`. pub fn eval_place( diff --git a/compiler/rustc_middle/src/ty/sty.rs b/compiler/rustc_middle/src/ty/sty.rs index 610f9bd8f82d7..c79e25f4781c8 100644 --- a/compiler/rustc_middle/src/ty/sty.rs +++ b/compiler/rustc_middle/src/ty/sty.rs @@ -1805,10 +1805,13 @@ impl<'tcx> TyS<'tcx> { pub fn simd_size_and_type(&self, tcx: TyCtxt<'tcx>) -> (u64, Ty<'tcx>) { match self.kind() { Adt(def, substs) => { + assert!(def.repr.simd(), "`simd_size_and_type` called on non-SIMD type"); let variant = def.non_enum_variant(); let f0_ty = variant.fields[0].ty(tcx, substs); match f0_ty.kind() { + // If the first field is an array, we assume it is the only field and its + // elements are the SIMD components. Array(f0_elem_ty, f0_len) => { // FIXME(repr_simd): https://github.com/rust-lang/rust/pull/78863#discussion_r522784112 // The way we evaluate the `N` in `[T; N]` here only works since we use @@ -1816,6 +1819,8 @@ impl<'tcx> TyS<'tcx> { // if we use it in generic code. See the `simd-array-trait` ui test. (f0_len.eval_usize(tcx, ParamEnv::empty()) as u64, f0_elem_ty) } + // Otherwise, the fields of this Adt are the SIMD components (and we assume they + // all have the same type). 
_ => (variant.fields.len() as u64, f0_ty), } } diff --git a/src/test/ui/consts/const-eval/simd/insert_extract.rs b/src/test/ui/consts/const-eval/simd/insert_extract.rs index cae8fcf1068ad..a1d6c5e51b498 100644 --- a/src/test/ui/consts/const-eval/simd/insert_extract.rs +++ b/src/test/ui/consts/const-eval/simd/insert_extract.rs @@ -7,7 +7,9 @@ #[repr(simd)] struct i8x1(i8); #[repr(simd)] struct u16x2(u16, u16); -#[repr(simd)] struct f32x4(f32, f32, f32, f32); +// Make some of them array types to ensure those also work. +#[repr(simd)] struct i8x1_arr([i8; 1]); +#[repr(simd)] struct f32x4([f32; 4]); extern "platform-intrinsic" { #[rustc_const_stable(feature = "foo", since = "1.3.37")] @@ -25,6 +27,14 @@ fn main() { assert_eq!(X0, 42); assert_eq!(Y0, 42); } + { + const U: i8x1_arr = i8x1_arr([13]); + const V: i8x1_arr = unsafe { simd_insert(U, 0_u32, 42_i8) }; + const X0: i8 = V.0[0]; + const Y0: i8 = unsafe { simd_extract(V, 0) }; + assert_eq!(X0, 42); + assert_eq!(Y0, 42); + } { const U: u16x2 = u16x2(13, 14); const V: u16x2 = unsafe { simd_insert(U, 1_u32, 42_u16) }; @@ -38,12 +48,12 @@ fn main() { assert_eq!(Y1, 42); } { - const U: f32x4 = f32x4(13., 14., 15., 16.); + const U: f32x4 = f32x4([13., 14., 15., 16.]); const V: f32x4 = unsafe { simd_insert(U, 1_u32, 42_f32) }; - const X0: f32 = V.0; - const X1: f32 = V.1; - const X2: f32 = V.2; - const X3: f32 = V.3; + const X0: f32 = V.0[0]; + const X1: f32 = V.0[1]; + const X2: f32 = V.0[2]; + const X3: f32 = V.0[3]; const Y0: f32 = unsafe { simd_extract(V, 0) }; const Y1: f32 = unsafe { simd_extract(V, 1) }; const Y2: f32 = unsafe { simd_extract(V, 2) };