Skip to content

Commit 2d549e7

Browse files
committed
Add autocast for i1 vectors
1 parent fd8264e commit 2d549e7

File tree

3 files changed

+151
-17
lines changed

3 files changed

+151
-17
lines changed

compiler/rustc_codegen_llvm/src/abi.rs

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -369,26 +369,31 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
369369
return true;
370370
}
371371

372-
// Some LLVM intrinsics return **non-packed** structs, but they can't be mimicked from Rust
373-
// due to auto field-alignment in non-packed structs (packed structs are represented in LLVM
374-
// as, well, packed structs, so they won't match with those either)
375-
if self.type_kind(llvm_ty) == TypeKind::Struct
376-
&& self.type_kind(rust_ty) == TypeKind::Struct
377-
{
378-
let rust_element_tys = self.struct_element_types(rust_ty);
379-
let llvm_element_tys = self.struct_element_types(llvm_ty);
372+
match self.type_kind(llvm_ty) {
373+
// Some LLVM intrinsics return **non-packed** structs, but they can't be mimicked from Rust
374+
// due to auto field-alignment in non-packed structs (packed structs are represented in LLVM
375+
// as, well, packed structs, so they won't match with those either)
376+
TypeKind::Struct if self.type_kind(rust_ty) == TypeKind::Struct => {
377+
let rust_element_tys = self.struct_element_types(rust_ty);
378+
let llvm_element_tys = self.struct_element_types(llvm_ty);
379+
380+
if rust_element_tys.len() != llvm_element_tys.len() {
381+
return false;
382+
}
380383

381-
if rust_element_tys.len() != llvm_element_tys.len() {
382-
return false;
384+
iter::zip(rust_element_tys, llvm_element_tys).all(
385+
|(rust_element_ty, llvm_element_ty)| {
386+
self.equate_ty(rust_element_ty, llvm_element_ty)
387+
},
388+
)
383389
}
390+
TypeKind::Vector if self.element_type(llvm_ty) == self.type_i1() => {
391+
let element_count = self.vector_length(llvm_ty) as u64;
392+
let int_width = element_count.next_power_of_two().max(8);
384393

385-
iter::zip(rust_element_tys, llvm_element_tys).all(
386-
|(rust_element_ty, llvm_element_ty)| {
387-
self.equate_ty(rust_element_ty, llvm_element_ty)
388-
},
389-
)
390-
} else {
391-
false
394+
rust_ty == self.type_ix(int_width)
395+
}
396+
_ => false,
392397
}
393398
}
394399
}

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1686,6 +1686,46 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
16861686
}
16871687
}
16881688
impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
1689+
fn trunc_int_to_i1_vector(&mut self, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value {
1690+
let vector_length = self.vector_length(dest_ty) as u64;
1691+
let int_width = vector_length.next_power_of_two().max(8);
1692+
1693+
let bitcasted = self.bitcast(val, self.type_vector(self.type_i1(), int_width));
1694+
if vector_length == int_width {
1695+
bitcasted
1696+
} else {
1697+
let shuffle_mask =
1698+
(0..vector_length).map(|i| self.const_i32(i as i32)).collect::<Vec<_>>();
1699+
self.shuffle_vector(bitcasted, bitcasted, self.const_vector(&shuffle_mask))
1700+
}
1701+
}
1702+
1703+
fn zext_i1_vector_to_int(
1704+
&mut self,
1705+
mut val: &'ll Value,
1706+
src_ty: &'ll Type,
1707+
dest_ty: &'ll Type,
1708+
) -> &'ll Value {
1709+
let vector_length = self.vector_length(src_ty) as u64;
1710+
let int_width = vector_length.next_power_of_two().max(8);
1711+
1712+
if vector_length != int_width {
1713+
let shuffle_indices = match vector_length {
1714+
0 => unreachable!("zero length vectors are not allowed"),
1715+
1 => vec![0, 1, 1, 1, 1, 1, 1, 1],
1716+
2 => vec![0, 1, 2, 3, 2, 3, 2, 3],
1717+
3 => vec![0, 1, 2, 3, 4, 5, 3, 4],
1718+
4.. => (0..int_width as i32).collect(),
1719+
};
1720+
let shuffle_mask =
1721+
shuffle_indices.into_iter().map(|i| self.const_i32(i)).collect::<Vec<_>>();
1722+
val =
1723+
self.shuffle_vector(val, self.const_null(src_ty), self.const_vector(&shuffle_mask));
1724+
}
1725+
1726+
self.bitcast(val, dest_ty)
1727+
}
1728+
16891729
fn autocast(
16901730
&mut self,
16911731
llfn: &'ll Value,
@@ -1714,6 +1754,13 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
17141754
}
17151755
ret
17161756
}
1757+
TypeKind::Vector if self.element_type(llvm_ty) == self.type_i1() => {
1758+
if is_argument {
1759+
self.trunc_int_to_i1_vector(val, dest_ty)
1760+
} else {
1761+
self.zext_i1_vector_to_int(val, src_ty, dest_ty)
1762+
}
1763+
}
17171764
_ => unreachable!(),
17181765
}
17191766
}
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
//@ compile-flags: -C opt-level=0
2+
//@ only-x86_64
3+
4+
#![feature(link_llvm_intrinsics, abi_unadjusted, repr_simd, simd_ffi, portable_simd, f16)]
5+
#![crate_type = "lib"]
6+
7+
use std::simd::i64x2;
8+
9+
#[repr(simd)]
10+
pub struct Tile([i8; 1024]);
11+
12+
#[repr(C, packed)]
13+
pub struct Bar(u32, i64x2, i64x2, i64x2, i64x2, i64x2, i64x2);
14+
// CHECK: %Bar = type <{ i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> }>
15+
16+
#[repr(simd)]
17+
pub struct f16x8([f16; 8]);
18+
19+
// CHECK-LABEL: @struct_with_i1_vector_autocast
20+
#[no_mangle]
21+
pub unsafe fn struct_with_i1_vector_autocast(a: i64x2, b: i64x2) -> (u8, u8) {
22+
extern "unadjusted" {
23+
#[link_name = "llvm.x86.avx512.vp2intersect.q.128"]
24+
fn foo(a: i64x2, b: i64x2) -> (u8, u8);
25+
}
26+
27+
// CHECK: [[A:%[0-9]+]] = call { <2 x i1>, <2 x i1> } @llvm.x86.avx512.vp2intersect.q.128(<2 x i64> {{.*}}, <2 x i64> {{.*}})
28+
// CHECK: [[B:%[0-9]+]] = extractvalue { <2 x i1>, <2 x i1> } [[A]], 0
29+
// CHECK: [[C:%[0-9]+]] = shufflevector <2 x i1> [[B]], <2 x i1> zeroinitializer, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 2, i32 3, i32 2, i32 3>
30+
// CHECK: [[D:%[0-9]+]] = bitcast <8 x i1> [[C]] to i8
31+
// CHECK: [[E:%[0-9]+]] = insertvalue { i8, i8 } poison, i8 [[D]], 0
32+
// CHECK: [[F:%[0-9]+]] = extractvalue { <2 x i1>, <2 x i1> } [[A]], 1
33+
// CHECK: [[G:%[0-9]+]] = shufflevector <2 x i1> [[F]], <2 x i1> zeroinitializer, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 2, i32 3, i32 2, i32 3>
34+
// CHECK: [[H:%[0-9]+]] = bitcast <8 x i1> [[G]] to i8
35+
// CHECK: insertvalue { i8, i8 } [[E]], i8 [[H]], 1
36+
foo(a, b)
37+
}
38+
39+
// CHECK-LABEL: @struct_autocast
40+
#[no_mangle]
41+
pub unsafe fn struct_autocast(key_metadata: u32, key: i64x2) -> Bar {
42+
extern "unadjusted" {
43+
#[link_name = "llvm.x86.encodekey128"]
44+
fn foo(key_metadata: u32, key: i64x2) -> Bar;
45+
}
46+
47+
// CHECK: [[A:%[0-9]+]] = call { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } @llvm.x86.encodekey128(i32 {{.*}}, <2 x i64> {{.*}})
48+
// CHECK: [[B:%[0-9]+]] = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } [[A]], 0
49+
// CHECK: [[C:%[0-9]+]] = insertvalue %Bar poison, i32 [[B]], 0
50+
// CHECK: [[D:%[0-9]+]] = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } [[A]], 1
51+
// CHECK: [[E:%[0-9]+]] = insertvalue %Bar [[C]], <2 x i64> [[D]], 1
52+
// CHECK: [[F:%[0-9]+]] = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } [[A]], 2
53+
// CHECK: [[G:%[0-9]+]] = insertvalue %Bar [[E]], <2 x i64> [[F]], 2
54+
// CHECK: [[H:%[0-9]+]] = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } [[A]], 3
55+
// CHECK: [[I:%[0-9]+]] = insertvalue %Bar [[G]], <2 x i64> [[H]], 3
56+
// CHECK: [[J:%[0-9]+]] = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } [[A]], 4
57+
// CHECK: [[K:%[0-9]+]] = insertvalue %Bar [[I]], <2 x i64> [[J]], 4
58+
// CHECK: [[L:%[0-9]+]] = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } [[A]], 5
59+
// CHECK: [[M:%[0-9]+]] = insertvalue %Bar [[K]], <2 x i64> [[L]], 5
60+
// CHECK: [[N:%[0-9]+]] = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } [[A]], 6
61+
// CHECK: insertvalue %Bar [[M]], <2 x i64> [[N]], 6
62+
foo(key_metadata, key)
63+
}
64+
65+
// CHECK-LABEL: @i1_vector_autocast
66+
#[no_mangle]
67+
pub unsafe fn i1_vector_autocast(a: f16x8) -> u8 {
68+
extern "unadjusted" {
69+
#[link_name = "llvm.x86.avx512fp16.fpclass.ph.128"]
70+
fn foo(a: f16x8, b: i32) -> u8;
71+
}
72+
73+
// CHECK: [[A:%[0-9]+]] = call <8 x i1> @llvm.x86.avx512fp16.fpclass.ph.128(<8 x half> {{.*}}, i32 1)
74+
// CHECK: bitcast <8 x i1> [[A]] to i8
75+
foo(a, 1)
76+
}
77+
78+
// CHECK: declare { <2 x i1>, <2 x i1> } @llvm.x86.avx512.vp2intersect.q.128(<2 x i64>, <2 x i64>)
79+
80+
// CHECK: declare { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } @llvm.x86.encodekey128(i32, <2 x i64>)
81+
82+
// CHECK: declare <8 x i1> @llvm.x86.avx512fp16.fpclass.ph.128(<8 x half>, i32 immarg)

0 commit comments

Comments
 (0)