Skip to content

Commit 84c41e7

Browse files
committed
Add autocasts for structs
1 parent d548876 commit 84c41e7

File tree

5 files changed

+185
-50
lines changed

5 files changed

+185
-50
lines changed

compiler/rustc_codegen_llvm/src/abi.rs

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use std::borrow::Borrow;
12
use std::{cmp, iter};
23

34
use libc::c_uint;
@@ -6,6 +7,7 @@ use rustc_abi::{
67
X86Call,
78
};
89
use rustc_codegen_ssa::MemFlags;
10+
use rustc_codegen_ssa::common::TypeKind;
911
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
1012
use rustc_codegen_ssa::mir::place::{PlaceRef, PlaceValue};
1113
use rustc_codegen_ssa::traits::*;
@@ -21,7 +23,7 @@ use smallvec::SmallVec;
2123

2224
use crate::attributes::{self, llfn_attrs_from_instance};
2325
use crate::builder::Builder;
24-
use crate::context::CodegenCx;
26+
use crate::context::{CodegenCx, GenericCx, SCx};
2527
use crate::llvm::{self, Attribute, AttributePlace, Type, Value};
2628
use crate::llvm_util;
2729
use crate::type_of::LayoutLlvmExt;
@@ -367,6 +369,36 @@ pub(crate) trait FnAbiLlvmExt<'ll, 'tcx> {
367369
);
368370
}
369371

372+
impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
373+
pub(crate) fn equate_ty(&self, rust_ty: &'ll Type, llvm_ty: &'ll Type) -> bool {
374+
if rust_ty == llvm_ty {
375+
return true;
376+
}
377+
378+
// Some LLVM intrinsics return **non-packed** structs, but they can't be mimicked from Rust
379+
// due to auto field-alignment in non-packed structs (packed structs are represented in LLVM
380+
// as, well, packed structs, so they won't match with those either)
381+
if self.type_kind(llvm_ty) == TypeKind::Struct
382+
&& self.type_kind(rust_ty) == TypeKind::Struct
383+
{
384+
let rust_element_tys = self.struct_element_types(rust_ty);
385+
let llvm_element_tys = self.struct_element_types(llvm_ty);
386+
387+
if rust_element_tys.len() != llvm_element_tys.len() {
388+
return false;
389+
}
390+
391+
iter::zip(rust_element_tys, llvm_element_tys).all(
392+
|(rust_element_ty, llvm_element_ty)| {
393+
self.equate_ty(rust_element_ty, llvm_element_ty)
394+
},
395+
)
396+
} else {
397+
false
398+
}
399+
}
400+
}
401+
370402
impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
371403
fn llvm_return_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type {
372404
match &self.ret.mode {
@@ -456,7 +488,7 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
456488

457489
iter::once((rust_return_ty, llvm_return_ty))
458490
.chain(iter::zip(rust_argument_tys, llvm_argument_tys))
459-
.all(|(rust_ty, llvm_ty)| rust_ty == llvm_ty)
491+
.all(|(rust_ty, llvm_ty)| cx.equate_ty(rust_ty, llvm_ty))
460492
}
461493

462494
fn rust_signature(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type {

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 99 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ impl<'a, 'll> SBuilder<'a, 'll> {
6767
) -> &'ll Value {
6868
debug!("call {:?} with args ({:?})", llfn, args);
6969

70-
let args = self.check_call("call", llty, llfn, args);
7170
let funclet_bundle = funclet.map(|funclet| funclet.bundle());
7271
let mut bundles: SmallVec<[_; 2]> = SmallVec::new();
7372
if let Some(funclet_bundle) = funclet_bundle {
@@ -411,7 +410,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
411410
) -> &'ll Value {
412411
debug!("invoke {:?} with args ({:?})", llfn, args);
413412

414-
let args = self.check_call("invoke", llty, llfn, args);
413+
let args = self.cast_arguments("invoke", llty, llfn, args, fn_abi.is_some());
415414
let funclet_bundle = funclet.map(|funclet| funclet.bundle());
416415
let mut bundles: SmallVec<[_; 2]> = SmallVec::new();
417416
if let Some(funclet_bundle) = funclet_bundle {
@@ -443,8 +442,10 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
443442
};
444443
if let Some(fn_abi) = fn_abi {
445444
fn_abi.apply_attrs_callsite(self, invoke, llfn);
445+
self.cast_return(fn_abi, llfn, invoke)
446+
} else {
447+
invoke
446448
}
447-
invoke
448449
}
449450

450451
fn unreachable(&mut self) {
@@ -1384,7 +1385,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
13841385
) -> &'ll Value {
13851386
debug!("call {:?} with args ({:?})", llfn, args);
13861387

1387-
let args = self.check_call("call", llty, llfn, args);
1388+
let args = self.cast_arguments("call", llty, llfn, args, fn_abi.is_some());
13881389
let funclet_bundle = funclet.map(|funclet| funclet.bundle());
13891390
let mut bundles: SmallVec<[_; 2]> = SmallVec::new();
13901391
if let Some(funclet_bundle) = funclet_bundle {
@@ -1437,8 +1438,10 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
14371438

14381439
if let Some(fn_abi) = fn_abi {
14391440
fn_abi.apply_attrs_callsite(self, call, llfn);
1441+
self.cast_return(fn_abi, llfn, call)
1442+
} else {
1443+
call
14401444
}
1441-
call
14421445
}
14431446

14441447
fn tail_call(
@@ -1622,47 +1625,6 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
16221625
ret.expect("LLVM does not have support for catchret")
16231626
}
16241627

1625-
fn check_call<'b>(
1626-
&mut self,
1627-
typ: &str,
1628-
fn_ty: &'ll Type,
1629-
llfn: &'ll Value,
1630-
args: &'b [&'ll Value],
1631-
) -> Cow<'b, [&'ll Value]> {
1632-
assert!(
1633-
self.cx.type_kind(fn_ty) == TypeKind::Function,
1634-
"builder::{typ} not passed a function, but {fn_ty:?}"
1635-
);
1636-
1637-
let param_tys = self.cx.func_params_types(fn_ty);
1638-
1639-
let all_args_match = iter::zip(&param_tys, args.iter().map(|&v| self.cx.val_ty(v)))
1640-
.all(|(expected_ty, actual_ty)| *expected_ty == actual_ty);
1641-
1642-
if all_args_match {
1643-
return Cow::Borrowed(args);
1644-
}
1645-
1646-
let casted_args: Vec<_> = iter::zip(param_tys, args)
1647-
.enumerate()
1648-
.map(|(i, (expected_ty, &actual_val))| {
1649-
let actual_ty = self.cx.val_ty(actual_val);
1650-
if expected_ty != actual_ty {
1651-
debug!(
1652-
"type mismatch in function call of {:?}. \
1653-
Expected {:?} for param {}, got {:?}; injecting bitcast",
1654-
llfn, expected_ty, i, actual_ty
1655-
);
1656-
self.bitcast(actual_val, expected_ty)
1657-
} else {
1658-
actual_val
1659-
}
1660-
})
1661-
.collect();
1662-
1663-
Cow::Owned(casted_args)
1664-
}
1665-
16661628
pub(crate) fn va_arg(&mut self, list: &'ll Value, ty: &'ll Type) -> &'ll Value {
16671629
unsafe { llvm::LLVMBuildVAArg(self.llbuilder, list, ty, UNNAMED) }
16681630
}
@@ -1718,6 +1680,93 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
17181680
}
17191681
}
17201682
impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
1683+
fn autocast(
1684+
&mut self,
1685+
llfn: &'ll Value,
1686+
val: &'ll Value,
1687+
src_ty: &'ll Type,
1688+
dest_ty: &'ll Type,
1689+
is_argument: bool,
1690+
) -> &'ll Value {
1691+
let (rust_ty, llvm_ty) = if is_argument { (src_ty, dest_ty) } else { (dest_ty, src_ty) };
1692+
1693+
if rust_ty == llvm_ty {
1694+
return val;
1695+
}
1696+
1697+
match self.type_kind(llvm_ty) {
1698+
TypeKind::Struct => {
1699+
let mut ret = self.const_poison(dest_ty);
1700+
for (idx, (src_element_ty, dest_element_ty)) in
1701+
iter::zip(self.struct_element_types(src_ty), self.struct_element_types(dest_ty))
1702+
.enumerate()
1703+
{
1704+
let elt = self.extract_value(val, idx as u64);
1705+
let casted_elt =
1706+
self.autocast(llfn, elt, src_element_ty, dest_element_ty, is_argument);
1707+
ret = self.insert_value(ret, casted_elt, idx as u64);
1708+
}
1709+
ret
1710+
}
1711+
_ => unreachable!(),
1712+
}
1713+
}
1714+
1715+
fn cast_arguments<'b>(
1716+
&mut self,
1717+
typ: &str,
1718+
fn_ty: &'ll Type,
1719+
llfn: &'ll Value,
1720+
args: &'b [&'ll Value],
1721+
has_fnabi: bool,
1722+
) -> Cow<'b, [&'ll Value]> {
1723+
assert_eq!(
1724+
self.type_kind(fn_ty),
1725+
TypeKind::Function,
1726+
"{typ} not passed a function, but {fn_ty:?}"
1727+
);
1728+
1729+
let param_tys = self.func_params_types(fn_ty);
1730+
1731+
let mut casted_args = Cow::Borrowed(args);
1732+
1733+
for (idx, (dest_ty, &arg)) in iter::zip(param_tys, args).enumerate() {
1734+
let src_ty = self.val_ty(arg);
1735+
assert!(
1736+
self.equate_ty(src_ty, dest_ty),
1737+
"Cannot match `{dest_ty:?}` (expected) with `{src_ty:?}` (found) in `{llfn:?}`"
1738+
);
1739+
1740+
let casted_arg = self.autocast(llfn, arg, src_ty, dest_ty, true);
1741+
if arg != casted_arg {
1742+
assert!(
1743+
has_fnabi,
1744+
"Should inject autocasts in function call of {llfn:?}, but not able to get Rust signature"
1745+
);
1746+
1747+
casted_args.to_mut()[idx] = casted_arg;
1748+
}
1749+
}
1750+
1751+
casted_args
1752+
}
1753+
1754+
fn cast_return(
1755+
&mut self,
1756+
fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
1757+
llfn: &'ll Value,
1758+
ret: &'ll Value,
1759+
) -> &'ll Value {
1760+
let src_ty = self.val_ty(ret);
1761+
let dest_ty = fn_abi.llvm_return_type(self);
1762+
assert!(
1763+
self.equate_ty(dest_ty, src_ty),
1764+
"Cannot match `{src_ty:?}` (expected) with `{dest_ty:?}` (found) in `{llfn:?}`"
1765+
);
1766+
1767+
self.autocast(llfn, ret, src_ty, dest_ty, false)
1768+
}
1769+
17211770
pub(crate) fn landing_pad(
17221771
&mut self,
17231772
ty: &'ll Type,
@@ -1747,7 +1796,7 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
17471796
) -> &'ll Value {
17481797
debug!("invoke {:?} with args ({:?})", llfn, args);
17491798

1750-
let args = self.check_call("callbr", llty, llfn, args);
1799+
let args = self.cast_arguments("callbr", llty, llfn, args, fn_abi.is_some());
17511800
let funclet_bundle = funclet.map(|funclet| funclet.bundle());
17521801
let mut bundles: SmallVec<[_; 2]> = SmallVec::new();
17531802
if let Some(funclet_bundle) = funclet_bundle {
@@ -1780,8 +1829,10 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
17801829
};
17811830
if let Some(fn_abi) = fn_abi {
17821831
fn_abi.apply_attrs_callsite(self, callbr, llfn);
1832+
self.cast_return(fn_abi, llfn, callbr)
1833+
} else {
1834+
callbr
17831835
}
1784-
callbr
17851836
}
17861837

17871838
// Emits CFI pointer type membership tests.

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1671,6 +1671,9 @@ unsafe extern "C" {
16711671
Packed: Bool,
16721672
);
16731673

1674+
pub(crate) fn LLVMCountStructElementTypes(StructTy: &Type) -> c_uint;
1675+
pub(crate) fn LLVMGetStructElementTypes<'a>(StructTy: &'a Type, Dest: *mut &'a Type);
1676+
16741677
pub(crate) safe fn LLVMMetadataAsValue<'a>(C: &'a Context, MD: &'a Metadata) -> &'a Value;
16751678

16761679
pub(crate) safe fn LLVMSetUnnamedAddress(Global: &Value, UnnamedAddr: UnnamedAddr);

compiler/rustc_codegen_llvm/src/type_.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,16 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
8585
pub(crate) fn func_is_variadic(&self, ty: &'ll Type) -> bool {
8686
unsafe { llvm::LLVMIsFunctionVarArg(ty).is_true() }
8787
}
88+
89+
pub(crate) fn struct_element_types(&self, ty: &'ll Type) -> Vec<&'ll Type> {
90+
unsafe {
91+
let n_args = llvm::LLVMCountStructElementTypes(ty) as usize;
92+
let mut args = Vec::with_capacity(n_args);
93+
llvm::LLVMGetStructElementTypes(ty, args.as_mut_ptr());
94+
args.set_len(n_args);
95+
args
96+
}
97+
}
8898
}
8999
impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
90100
pub(crate) fn type_bool(&self) -> &'ll Type {
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
//@ compile-flags: -C opt-level=0
2+
//@ only-x86_64
3+
4+
#![feature(link_llvm_intrinsics, abi_unadjusted, simd_ffi, portable_simd)]
5+
#![crate_type = "lib"]
6+
7+
use std::simd::i64x2;
8+
9+
#[repr(C, packed)]
10+
pub struct Bar(u32, i64x2, i64x2, i64x2, i64x2, i64x2, i64x2);
11+
// CHECK: %Bar = type <{ i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> }>
12+
13+
// CHECK-LABEL: @struct_autocast
14+
#[no_mangle]
15+
pub unsafe fn struct_autocast(key_metadata: u32, key: i64x2) -> Bar {
16+
extern "unadjusted" {
17+
#[link_name = "llvm.x86.encodekey128"]
18+
fn foo(key_metadata: u32, key: i64x2) -> Bar;
19+
}
20+
21+
// CHECK: [[A:%[0-9]+]] = call { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } @llvm.x86.encodekey128(i32 {{.*}}, <2 x i64> {{.*}})
22+
// CHECK: [[B:%[0-9]+]] = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } [[A]], 0
23+
// CHECK: [[C:%[0-9]+]] = insertvalue %Bar poison, i32 [[B]], 0
24+
// CHECK: [[D:%[0-9]+]] = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } [[A]], 1
25+
// CHECK: [[E:%[0-9]+]] = insertvalue %Bar [[C]], <2 x i64> [[D]], 1
26+
// CHECK: [[F:%[0-9]+]] = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } [[A]], 2
27+
// CHECK: [[G:%[0-9]+]] = insertvalue %Bar [[E]], <2 x i64> [[F]], 2
28+
// CHECK: [[H:%[0-9]+]] = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } [[A]], 3
29+
// CHECK: [[I:%[0-9]+]] = insertvalue %Bar [[G]], <2 x i64> [[H]], 3
30+
// CHECK: [[J:%[0-9]+]] = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } [[A]], 4
31+
// CHECK: [[K:%[0-9]+]] = insertvalue %Bar [[I]], <2 x i64> [[J]], 4
32+
// CHECK: [[L:%[0-9]+]] = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } [[A]], 5
33+
// CHECK: [[M:%[0-9]+]] = insertvalue %Bar [[K]], <2 x i64> [[L]], 5
34+
// CHECK: [[N:%[0-9]+]] = extractvalue { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } [[A]], 6
35+
// CHECK: insertvalue %Bar [[M]], <2 x i64> [[N]], 6
36+
foo(key_metadata, key)
37+
}
38+
39+
// CHECK: declare { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } @llvm.x86.encodekey128(i32, <2 x i64>)

0 commit comments

Comments
 (0)