Skip to content

Commit a9d86a3

Browse files
committed
Add autocasts for bf16 and bf16xN
1 parent ff16bf1 commit a9d86a3

File tree

5 files changed

+37
-6
lines changed

5 files changed

+37
-6
lines changed

compiler/rustc_codegen_llvm/src/abi.rs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,8 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
376376
}
377377

378378
match self.type_kind(llvm_ty) {
379+
TypeKind::BFloat => rust_ty == self.type_i16(),
380+
379381
// Some LLVM intrinsics return **non-packed** structs, but they can't be mimicked from Rust
380382
// due to auto field-alignment in non-packed structs (packed structs are represented in LLVM
381383
// as, well, packed structs, so they won't match with those either)
@@ -393,11 +395,18 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
393395
},
394396
)
395397
}
396-
TypeKind::Vector if self.element_type(llvm_ty) == self.type_i1() => {
398+
TypeKind::Vector => {
397399
let element_count = self.vector_length(llvm_ty) as u64;
398-
let int_width = element_count.next_power_of_two().max(8);
400+
let llvm_element_ty = self.element_type(llvm_ty);
399401

400-
rust_ty == self.type_ix(int_width)
402+
if llvm_element_ty == self.type_bf16() {
403+
rust_ty == self.type_vector(self.type_i16(), element_count)
404+
} else if llvm_element_ty == self.type_i1() {
405+
let int_width = element_count.next_power_of_two().max(8);
406+
rust_ty == self.type_ix(int_width)
407+
} else {
408+
false
409+
}
401410
}
402411
_ => false,
403412
}

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1755,7 +1755,7 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
17551755
self.zext_i1_vector_to_int(val, src_ty, dest_ty)
17561756
}
17571757
}
1758-
_ => unreachable!(),
1758+
_ => self.bitcast(val, dest_ty), // for `bf16(xN)` <-> `u16(xN)`
17591759
}
17601760
}
17611761

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -974,6 +974,9 @@ unsafe extern "C" {
974974
pub(crate) fn LLVMDoubleTypeInContext(C: &Context) -> &Type;
975975
pub(crate) fn LLVMFP128TypeInContext(C: &Context) -> &Type;
976976

977+
// Operations on non-IEEE real types
978+
pub(crate) fn LLVMBFloatTypeInContext(C: &Context) -> &Type;
979+
977980
// Operations on function types
978981
pub(crate) fn LLVMFunctionType<'a>(
979982
ReturnType: &'a Type,

compiler/rustc_codegen_llvm/src/type_.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,10 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
174174
)
175175
}
176176
}
177+
178+
/// Returns LLVM's `bfloat` (bf16) type, obtained from `LLVMBFloatTypeInContext` for `self.llcx()`.
pub(crate) fn type_bf16(&self) -> &'ll Type {
179+
unsafe { llvm::LLVMBFloatTypeInContext(self.llcx()) }
180+
}
177181
}
178182

179183
impl<'ll, CX: Borrow<SCx<'ll>>> BaseTypeCodegenMethods for GenericCx<'ll, CX> {
@@ -247,7 +251,7 @@ impl<'ll, CX: Borrow<SCx<'ll>>> BaseTypeCodegenMethods for GenericCx<'ll, CX> {
247251

248252
fn float_width(&self, ty: &'ll Type) -> usize {
249253
match self.type_kind(ty) {
250-
TypeKind::Half => 16,
254+
TypeKind::Half | TypeKind::BFloat => 16,
251255
TypeKind::Float => 32,
252256
TypeKind::Double => 64,
253257
TypeKind::X86_FP80 => 80,

tests/codegen-llvm/inject-autocast.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#![feature(link_llvm_intrinsics, abi_unadjusted, repr_simd, simd_ffi, portable_simd, f16)]
55
#![crate_type = "lib"]
66

7-
use std::simd::i64x2;
7+
use std::simd::{f32x4, i16x8, i64x2};
88

99
#[repr(C, packed)]
1010
pub struct Bar(u32, i64x2, i64x2, i64x2, i64x2, i64x2, i64x2);
@@ -72,8 +72,23 @@ pub unsafe fn i1_vector_autocast(a: f16x8) -> u8 {
7272
foo(a, 1)
7373
}
7474

75+
// CHECK-LABEL: @bf16_vector_autocast
76+
// Exercises the bf16 vector autocast: the intrinsic is declared on the Rust side as returning
// `i16x8`, while the expected IR below has the call return `<8 x bfloat>` followed by a bitcast
// to `<8 x i16>`.
#[no_mangle]
77+
pub unsafe fn bf16_vector_autocast(a: f32x4) -> i16x8 {
78+
extern "unadjusted" {
79+
#[link_name = "llvm.x86.vcvtneps2bf16128"]
80+
// The Rust-side return type is deliberately an integer vector rather than a bfloat vector.
fn foo(a: f32x4) -> i16x8;
81+
}
82+

83+
// CHECK: [[A:%[0-9]+]] = call <8 x bfloat> @llvm.x86.vcvtneps2bf16128(<4 x float> {{.*}})
84+
// CHECK: bitcast <8 x bfloat> [[A]] to <8 x i16>
85+
foo(a)
86+
}
7588
// CHECK: declare { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } @llvm.x86.encodekey128(i32, <2 x i64>)
7689

7790
// CHECK: declare { <2 x i1>, <2 x i1> } @llvm.x86.avx512.vp2intersect.q.128(<2 x i64>, <2 x i64>)
7891

7992
// CHECK: declare <8 x i1> @llvm.x86.avx512fp16.fpclass.ph.128(<8 x half>, i32 immarg)
93+
94+
// CHECK: declare <8 x bfloat> @llvm.x86.vcvtneps2bf16128(<4 x float>)

0 commit comments

Comments
 (0)