cg_llvm: simplify llvm.masked.gather/scatter naming with opaque pointers

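With opaque pointers, there's no longer a need to generate a chain
of pointer types in the intrinsic name when arguments are pointers to
pointers: a vector of pointers is now always named `v<len>p0`, whatever
its pointee type (e.g. `v2p0` instead of `v2p0p0f32`).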
This commit is contained in:
Erik Desjardins 2023-07-29 16:56:27 -04:00
parent cf7788d54b
commit 55800123b7
3 changed files with 51 additions and 109 deletions

View File

@ -1307,49 +1307,34 @@ macro_rules! return_error {
// FIXME: use: // FIXME: use:
// https://github.com/llvm-mirror/llvm/blob/master/include/llvm/IR/Function.h#L182 // https://github.com/llvm-mirror/llvm/blob/master/include/llvm/IR/Function.h#L182
// https://github.com/llvm-mirror/llvm/blob/master/include/llvm/IR/Intrinsics.h#L81 // https://github.com/llvm-mirror/llvm/blob/master/include/llvm/IR/Intrinsics.h#L81
fn llvm_vector_str( fn llvm_vector_str(bx: &Builder<'_, '_, '_>, elem_ty: Ty<'_>, vec_len: u64) -> String {
elem_ty: Ty<'_>,
vec_len: u64,
no_pointers: usize,
bx: &Builder<'_, '_, '_>,
) -> String {
let p0s: String = "p0".repeat(no_pointers);
match *elem_ty.kind() { match *elem_ty.kind() {
ty::Int(v) => format!( ty::Int(v) => format!(
"v{}{}i{}", "v{}i{}",
vec_len, vec_len,
p0s,
// Normalize to prevent crash if v: IntTy::Isize // Normalize to prevent crash if v: IntTy::Isize
v.normalize(bx.target_spec().pointer_width).bit_width().unwrap() v.normalize(bx.target_spec().pointer_width).bit_width().unwrap()
), ),
ty::Uint(v) => format!( ty::Uint(v) => format!(
"v{}{}i{}", "v{}i{}",
vec_len, vec_len,
p0s,
// Normalize to prevent crash if v: UIntTy::Usize // Normalize to prevent crash if v: UIntTy::Usize
v.normalize(bx.target_spec().pointer_width).bit_width().unwrap() v.normalize(bx.target_spec().pointer_width).bit_width().unwrap()
), ),
ty::Float(v) => format!("v{}{}f{}", vec_len, p0s, v.bit_width()), ty::Float(v) => format!("v{}f{}", vec_len, v.bit_width()),
ty::RawPtr(_) => format!("v{}p0", vec_len),
_ => unreachable!(), _ => unreachable!(),
} }
} }
fn llvm_vector_ty<'ll>( fn llvm_vector_ty<'ll>(cx: &CodegenCx<'ll, '_>, elem_ty: Ty<'_>, vec_len: u64) -> &'ll Type {
cx: &CodegenCx<'ll, '_>, let elem_ty = match *elem_ty.kind() {
elem_ty: Ty<'_>,
vec_len: u64,
no_pointers: usize,
) -> &'ll Type {
// FIXME: use cx.layout_of(ty).llvm_type() ?
let mut elem_ty = match *elem_ty.kind() {
ty::Int(v) => cx.type_int_from_ty(v), ty::Int(v) => cx.type_int_from_ty(v),
ty::Uint(v) => cx.type_uint_from_ty(v), ty::Uint(v) => cx.type_uint_from_ty(v),
ty::Float(v) => cx.type_float_from_ty(v), ty::Float(v) => cx.type_float_from_ty(v),
ty::RawPtr(_) => cx.type_ptr(),
_ => unreachable!(), _ => unreachable!(),
}; };
if no_pointers > 0 {
elem_ty = cx.type_ptr();
}
cx.type_vector(elem_ty, vec_len) cx.type_vector(elem_ty, vec_len)
} }
@ -1404,47 +1389,26 @@ fn llvm_vector_ty<'ll>(
InvalidMonomorphization::ExpectedReturnType { span, name, in_ty, ret_ty } InvalidMonomorphization::ExpectedReturnType { span, name, in_ty, ret_ty }
); );
// This counts how many pointers
fn ptr_count(t: Ty<'_>) -> usize {
match t.kind() {
ty::RawPtr(p) => 1 + ptr_count(p.ty),
_ => 0,
}
}
// Non-ptr type
fn non_ptr(t: Ty<'_>) -> Ty<'_> {
match t.kind() {
ty::RawPtr(p) => non_ptr(p.ty),
_ => t,
}
}
// The second argument must be a simd vector with an element type that's a pointer // The second argument must be a simd vector with an element type that's a pointer
// to the element type of the first argument // to the element type of the first argument
let (_, element_ty0) = arg_tys[0].simd_size_and_type(bx.tcx()); let (_, element_ty0) = arg_tys[0].simd_size_and_type(bx.tcx());
let (_, element_ty1) = arg_tys[1].simd_size_and_type(bx.tcx()); let (_, element_ty1) = arg_tys[1].simd_size_and_type(bx.tcx());
let (pointer_count, underlying_ty) = match element_ty1.kind() {
ty::RawPtr(p) if p.ty == in_elem => (ptr_count(element_ty1), non_ptr(element_ty1)), require!(
_ => { matches!(
require!( element_ty1.kind(),
false, ty::RawPtr(p) if p.ty == in_elem && p.ty.kind() == element_ty0.kind()
InvalidMonomorphization::ExpectedElementType { ),
span, InvalidMonomorphization::ExpectedElementType {
name, span,
expected_element: element_ty1, name,
second_arg: arg_tys[1], expected_element: element_ty1,
in_elem, second_arg: arg_tys[1],
in_ty, in_elem,
mutability: ExpectedPointerMutability::Not, in_ty,
} mutability: ExpectedPointerMutability::Not,
);
unreachable!();
} }
}; );
assert!(pointer_count > 0);
assert_eq!(pointer_count - 1, ptr_count(element_ty0));
assert_eq!(underlying_ty, non_ptr(element_ty0));
// The element type of the third argument must be a signed integer type of any width: // The element type of the third argument must be a signed integer type of any width:
let (_, element_ty2) = arg_tys[2].simd_size_and_type(bx.tcx()); let (_, element_ty2) = arg_tys[2].simd_size_and_type(bx.tcx());
@ -1475,12 +1439,12 @@ fn non_ptr(t: Ty<'_>) -> Ty<'_> {
}; };
// Type of the vector of pointers: // Type of the vector of pointers:
let llvm_pointer_vec_ty = llvm_vector_ty(bx, underlying_ty, in_len, pointer_count); let llvm_pointer_vec_ty = llvm_vector_ty(bx, element_ty1, in_len);
let llvm_pointer_vec_str = llvm_vector_str(underlying_ty, in_len, pointer_count, bx); let llvm_pointer_vec_str = llvm_vector_str(bx, element_ty1, in_len);
// Type of the vector of elements: // Type of the vector of elements:
let llvm_elem_vec_ty = llvm_vector_ty(bx, underlying_ty, in_len, pointer_count - 1); let llvm_elem_vec_ty = llvm_vector_ty(bx, element_ty0, in_len);
let llvm_elem_vec_str = llvm_vector_str(underlying_ty, in_len, pointer_count - 1, bx); let llvm_elem_vec_str = llvm_vector_str(bx, element_ty0, in_len);
let llvm_intrinsic = let llvm_intrinsic =
format!("llvm.masked.gather.{}.{}", llvm_elem_vec_str, llvm_pointer_vec_str); format!("llvm.masked.gather.{}.{}", llvm_elem_vec_str, llvm_pointer_vec_str);
@ -1544,50 +1508,28 @@ fn non_ptr(t: Ty<'_>) -> Ty<'_> {
} }
); );
// This counts how many pointers
fn ptr_count(t: Ty<'_>) -> usize {
match t.kind() {
ty::RawPtr(p) => 1 + ptr_count(p.ty),
_ => 0,
}
}
// Non-ptr type
fn non_ptr(t: Ty<'_>) -> Ty<'_> {
match t.kind() {
ty::RawPtr(p) => non_ptr(p.ty),
_ => t,
}
}
// The second argument must be a simd vector with an element type that's a pointer // The second argument must be a simd vector with an element type that's a pointer
// to the element type of the first argument // to the element type of the first argument
let (_, element_ty0) = arg_tys[0].simd_size_and_type(bx.tcx()); let (_, element_ty0) = arg_tys[0].simd_size_and_type(bx.tcx());
let (_, element_ty1) = arg_tys[1].simd_size_and_type(bx.tcx()); let (_, element_ty1) = arg_tys[1].simd_size_and_type(bx.tcx());
let (_, element_ty2) = arg_tys[2].simd_size_and_type(bx.tcx()); let (_, element_ty2) = arg_tys[2].simd_size_and_type(bx.tcx());
let (pointer_count, underlying_ty) = match element_ty1.kind() {
ty::RawPtr(p) if p.ty == in_elem && p.mutbl.is_mut() => { require!(
(ptr_count(element_ty1), non_ptr(element_ty1)) matches!(
element_ty1.kind(),
ty::RawPtr(p)
if p.ty == in_elem && p.mutbl.is_mut() && p.ty.kind() == element_ty0.kind()
),
InvalidMonomorphization::ExpectedElementType {
span,
name,
expected_element: element_ty1,
second_arg: arg_tys[1],
in_elem,
in_ty,
mutability: ExpectedPointerMutability::Mut,
} }
_ => { );
require!(
false,
InvalidMonomorphization::ExpectedElementType {
span,
name,
expected_element: element_ty1,
second_arg: arg_tys[1],
in_elem,
in_ty,
mutability: ExpectedPointerMutability::Mut,
}
);
unreachable!();
}
};
assert!(pointer_count > 0);
assert_eq!(pointer_count - 1, ptr_count(element_ty0));
assert_eq!(underlying_ty, non_ptr(element_ty0));
// The element type of the third argument must be a signed integer type of any width: // The element type of the third argument must be a signed integer type of any width:
match element_ty2.kind() { match element_ty2.kind() {
@ -1619,12 +1561,12 @@ fn non_ptr(t: Ty<'_>) -> Ty<'_> {
let ret_t = bx.type_void(); let ret_t = bx.type_void();
// Type of the vector of pointers: // Type of the vector of pointers:
let llvm_pointer_vec_ty = llvm_vector_ty(bx, underlying_ty, in_len, pointer_count); let llvm_pointer_vec_ty = llvm_vector_ty(bx, element_ty1, in_len);
let llvm_pointer_vec_str = llvm_vector_str(underlying_ty, in_len, pointer_count, bx); let llvm_pointer_vec_str = llvm_vector_str(bx, element_ty1, in_len);
// Type of the vector of elements: // Type of the vector of elements:
let llvm_elem_vec_ty = llvm_vector_ty(bx, underlying_ty, in_len, pointer_count - 1); let llvm_elem_vec_ty = llvm_vector_ty(bx, element_ty0, in_len);
let llvm_elem_vec_str = llvm_vector_str(underlying_ty, in_len, pointer_count - 1, bx); let llvm_elem_vec_str = llvm_vector_str(bx, element_ty0, in_len);
let llvm_intrinsic = let llvm_intrinsic =
format!("llvm.masked.scatter.{}.{}", llvm_elem_vec_str, llvm_pointer_vec_str); format!("llvm.masked.scatter.{}.{}", llvm_elem_vec_str, llvm_pointer_vec_str);

View File

@ -23,7 +23,7 @@
#[no_mangle] #[no_mangle]
pub unsafe fn gather_f32x2(pointers: Vec2<*const f32>, mask: Vec2<i32>, pub unsafe fn gather_f32x2(pointers: Vec2<*const f32>, mask: Vec2<i32>,
values: Vec2<f32>) -> Vec2<f32> { values: Vec2<f32>) -> Vec2<f32> {
// CHECK: call <2 x float> @llvm.masked.gather.v2f32.{{.+}}(<2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> {{.*}}, <2 x float> {{.*}}) // CHECK: call <2 x float> @llvm.masked.gather.v2f32.v2p0(<2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> {{.*}}, <2 x float> {{.*}})
simd_gather(values, pointers, mask) simd_gather(values, pointers, mask)
} }
@ -31,6 +31,6 @@ pub unsafe fn gather_f32x2(pointers: Vec2<*const f32>, mask: Vec2<i32>,
#[no_mangle] #[no_mangle]
pub unsafe fn gather_pf32x2(pointers: Vec2<*const *const f32>, mask: Vec2<i32>, pub unsafe fn gather_pf32x2(pointers: Vec2<*const *const f32>, mask: Vec2<i32>,
values: Vec2<*const f32>) -> Vec2<*const f32> { values: Vec2<*const f32>) -> Vec2<*const f32> {
// CHECK: call <2 x ptr> @llvm.masked.gather.{{.+}}(<2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> {{.*}}, <2 x ptr> {{.*}}) // CHECK: call <2 x ptr> @llvm.masked.gather.v2p0.v2p0(<2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> {{.*}}, <2 x ptr> {{.*}})
simd_gather(values, pointers, mask) simd_gather(values, pointers, mask)
} }

View File

@ -23,7 +23,7 @@
#[no_mangle] #[no_mangle]
pub unsafe fn scatter_f32x2(pointers: Vec2<*mut f32>, mask: Vec2<i32>, pub unsafe fn scatter_f32x2(pointers: Vec2<*mut f32>, mask: Vec2<i32>,
values: Vec2<f32>) { values: Vec2<f32>) {
// CHECK: call void @llvm.masked.scatter.v2f32.v2p0{{.*}}(<2 x float> {{.*}}, <2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> {{.*}}) // CHECK: call void @llvm.masked.scatter.v2f32.v2p0(<2 x float> {{.*}}, <2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> {{.*}})
simd_scatter(values, pointers, mask) simd_scatter(values, pointers, mask)
} }
@ -32,6 +32,6 @@ pub unsafe fn scatter_f32x2(pointers: Vec2<*mut f32>, mask: Vec2<i32>,
#[no_mangle] #[no_mangle]
pub unsafe fn scatter_pf32x2(pointers: Vec2<*mut *const f32>, mask: Vec2<i32>, pub unsafe fn scatter_pf32x2(pointers: Vec2<*mut *const f32>, mask: Vec2<i32>,
values: Vec2<*const f32>) { values: Vec2<*const f32>) {
// CHECK: call void @llvm.masked.scatter.v2p0{{.*}}.v2p0{{.*}}(<2 x ptr> {{.*}}, <2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> {{.*}}) // CHECK: call void @llvm.masked.scatter.v2p0.v2p0(<2 x ptr> {{.*}}, <2 x ptr> {{.*}}, i32 {{.*}}, <2 x i1> {{.*}})
simd_scatter(values, pointers, mask) simd_scatter(values, pointers, mask)
} }