Skip to content

Commit ff16bf1

Browse files
committed
Add autocasts for i1 vectors
1 parent 84c41e7 commit ff16bf1

File tree

3 files changed

+110
-18
lines changed

3 files changed

+110
-18
lines changed

compiler/rustc_codegen_llvm/src/abi.rs

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -375,26 +375,31 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
375375
return true;
376376
}
377377

378-
// Some LLVM intrinsics return **non-packed** structs, but they can't be mimicked from Rust
379-
// due to auto field-alignment in non-packed structs (packed structs are represented in LLVM
380-
// as, well, packed structs, so they won't match with those either)
381-
if self.type_kind(llvm_ty) == TypeKind::Struct
382-
&& self.type_kind(rust_ty) == TypeKind::Struct
383-
{
384-
let rust_element_tys = self.struct_element_types(rust_ty);
385-
let llvm_element_tys = self.struct_element_types(llvm_ty);
378+
match self.type_kind(llvm_ty) {
379+
// Some LLVM intrinsics return **non-packed** structs, but they can't be mimicked from Rust
380+
// due to auto field-alignment in non-packed structs (packed structs are represented in LLVM
381+
// as, well, packed structs, so they won't match with those either)
382+
TypeKind::Struct if self.type_kind(rust_ty) == TypeKind::Struct => {
383+
let rust_element_tys = self.struct_element_types(rust_ty);
384+
let llvm_element_tys = self.struct_element_types(llvm_ty);
385+
386+
if rust_element_tys.len() != llvm_element_tys.len() {
387+
return false;
388+
}
386389

387-
if rust_element_tys.len() != llvm_element_tys.len() {
388-
return false;
390+
iter::zip(rust_element_tys, llvm_element_tys).all(
391+
|(rust_element_ty, llvm_element_ty)| {
392+
self.equate_ty(rust_element_ty, llvm_element_ty)
393+
},
394+
)
389395
}
396+
TypeKind::Vector if self.element_type(llvm_ty) == self.type_i1() => {
397+
let element_count = self.vector_length(llvm_ty) as u64;
398+
let int_width = element_count.next_power_of_two().max(8);
390399

391-
iter::zip(rust_element_tys, llvm_element_tys).all(
392-
|(rust_element_ty, llvm_element_ty)| {
393-
self.equate_ty(rust_element_ty, llvm_element_ty)
394-
},
395-
)
396-
} else {
397-
false
400+
rust_ty == self.type_ix(int_width)
401+
}
402+
_ => false,
398403
}
399404
}
400405
}

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1680,6 +1680,46 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
16801680
}
16811681
}
16821682
impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
1683+
fn trunc_int_to_i1_vector(&mut self, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value {
1684+
let vector_length = self.vector_length(dest_ty) as u64;
1685+
let int_width = vector_length.next_power_of_two().max(8);
1686+
1687+
let bitcasted = self.bitcast(val, self.type_vector(self.type_i1(), int_width));
1688+
if vector_length == int_width {
1689+
bitcasted
1690+
} else {
1691+
let shuffle_mask =
1692+
(0..vector_length).map(|i| self.const_i32(i as i32)).collect::<Vec<_>>();
1693+
self.shuffle_vector(bitcasted, bitcasted, self.const_vector(&shuffle_mask))
1694+
}
1695+
}
1696+
1697+
fn zext_i1_vector_to_int(
1698+
&mut self,
1699+
mut val: &'ll Value,
1700+
src_ty: &'ll Type,
1701+
dest_ty: &'ll Type,
1702+
) -> &'ll Value {
1703+
let vector_length = self.vector_length(src_ty) as u64;
1704+
let int_width = vector_length.next_power_of_two().max(8);
1705+
1706+
if vector_length != int_width {
1707+
let shuffle_indices = match vector_length {
1708+
0 => unreachable!("zero length vectors are not allowed"),
1709+
1 => vec![0, 1, 1, 1, 1, 1, 1, 1],
1710+
2 => vec![0, 1, 2, 3, 2, 3, 2, 3],
1711+
3 => vec![0, 1, 2, 3, 4, 5, 3, 4],
1712+
4.. => (0..int_width as i32).collect(),
1713+
};
1714+
let shuffle_mask =
1715+
shuffle_indices.into_iter().map(|i| self.const_i32(i)).collect::<Vec<_>>();
1716+
val =
1717+
self.shuffle_vector(val, self.const_null(src_ty), self.const_vector(&shuffle_mask));
1718+
}
1719+
1720+
self.bitcast(val, dest_ty)
1721+
}
1722+
16831723
fn autocast(
16841724
&mut self,
16851725
llfn: &'ll Value,
@@ -1708,6 +1748,13 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
17081748
}
17091749
ret
17101750
}
1751+
TypeKind::Vector if self.element_type(llvm_ty) == self.type_i1() => {
1752+
if is_argument {
1753+
self.trunc_int_to_i1_vector(val, dest_ty)
1754+
} else {
1755+
self.zext_i1_vector_to_int(val, src_ty, dest_ty)
1756+
}
1757+
}
17111758
_ => unreachable!(),
17121759
}
17131760
}

tests/codegen-llvm/inject-autocast.rs

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//@ compile-flags: -C opt-level=0
22
//@ only-x86_64
33

4-
#![feature(link_llvm_intrinsics, abi_unadjusted, simd_ffi, portable_simd)]
4+
#![feature(link_llvm_intrinsics, abi_unadjusted, repr_simd, simd_ffi, portable_simd, f16)]
55
#![crate_type = "lib"]
66

77
use std::simd::i64x2;
@@ -10,6 +10,9 @@ use std::simd::i64x2;
1010
pub struct Bar(u32, i64x2, i64x2, i64x2, i64x2, i64x2, i64x2);
1111
// CHECK: %Bar = type <{ i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> }>
1212

13+
#[repr(simd)]
14+
pub struct f16x8([f16; 8]);
15+
1316
// CHECK-LABEL: @struct_autocast
1417
#[no_mangle]
1518
pub unsafe fn struct_autocast(key_metadata: u32, key: i64x2) -> Bar {
@@ -36,4 +39,41 @@ pub unsafe fn struct_autocast(key_metadata: u32, key: i64x2) -> Bar {
3639
foo(key_metadata, key)
3740
}
3841

42+
// CHECK-LABEL: @struct_with_i1_vector_autocast
43+
#[no_mangle]
44+
pub unsafe fn struct_with_i1_vector_autocast(a: i64x2, b: i64x2) -> (u8, u8) {
45+
extern "unadjusted" {
46+
#[link_name = "llvm.x86.avx512.vp2intersect.q.128"]
47+
fn foo(a: i64x2, b: i64x2) -> (u8, u8);
48+
}
49+
50+
// CHECK: [[A:%[0-9]+]] = call { <2 x i1>, <2 x i1> } @llvm.x86.avx512.vp2intersect.q.128(<2 x i64> {{.*}}, <2 x i64> {{.*}})
51+
// CHECK: [[B:%[0-9]+]] = extractvalue { <2 x i1>, <2 x i1> } [[A]], 0
52+
// CHECK: [[C:%[0-9]+]] = shufflevector <2 x i1> [[B]], <2 x i1> zeroinitializer, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 2, i32 3, i32 2, i32 3>
53+
// CHECK: [[D:%[0-9]+]] = bitcast <8 x i1> [[C]] to i8
54+
// CHECK: [[E:%[0-9]+]] = insertvalue { i8, i8 } poison, i8 [[D]], 0
55+
// CHECK: [[F:%[0-9]+]] = extractvalue { <2 x i1>, <2 x i1> } [[A]], 1
56+
// CHECK: [[G:%[0-9]+]] = shufflevector <2 x i1> [[F]], <2 x i1> zeroinitializer, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 2, i32 3, i32 2, i32 3>
57+
// CHECK: [[H:%[0-9]+]] = bitcast <8 x i1> [[G]] to i8
58+
// CHECK: insertvalue { i8, i8 } [[E]], i8 [[H]], 1
59+
foo(a, b)
60+
}
61+
62+
// CHECK-LABEL: @i1_vector_autocast
63+
#[no_mangle]
64+
pub unsafe fn i1_vector_autocast(a: f16x8) -> u8 {
65+
extern "unadjusted" {
66+
#[link_name = "llvm.x86.avx512fp16.fpclass.ph.128"]
67+
fn foo(a: f16x8, b: i32) -> u8;
68+
}
69+
70+
// CHECK: [[A:%[0-9]+]] = call <8 x i1> @llvm.x86.avx512fp16.fpclass.ph.128(<8 x half> {{.*}}, i32 1)
71+
// CHECK: bitcast <8 x i1> [[A]] to i8
72+
foo(a, 1)
73+
}
74+
3975
// CHECK: declare { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } @llvm.x86.encodekey128(i32, <2 x i64>)
76+
77+
// CHECK: declare { <2 x i1>, <2 x i1> } @llvm.x86.avx512.vp2intersect.q.128(<2 x i64>, <2 x i64>)
78+
79+
// CHECK: declare <8 x i1> @llvm.x86.avx512fp16.fpclass.ph.128(<8 x half>, i32 immarg)

0 commit comments

Comments
 (0)