Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions crates/rustc_codegen_spirv/src/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ pub enum IntrinsicType {
pub struct SpecConstant {
pub id: u32,
pub default: Option<u32>,
pub array_count: Option<u32>,
}

// NOTE(eddyb) when adding new `#[spirv(...)]` attributes, the tests found inside
Expand Down Expand Up @@ -661,6 +662,8 @@ fn parse_spec_constant_attr(
Ok(SpecConstant {
id: id.ok_or_else(|| (arg.span(), "expected `spec_constant(id = ...)`".into()))?,
default,
// to be set later
array_count: None,
})
}

Expand Down
82 changes: 59 additions & 23 deletions crates/rustc_codegen_spirv/src/codegen_cx/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@ use rspirv::dr::Operand;
use rspirv::spirv::{
Capability, Decoration, Dim, ExecutionModel, FunctionControl, StorageClass, Word,
};
use rustc_abi::FieldsShape;
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods, MiscCodegenMethods as _};
use rustc_data_structures::fx::FxHashMap;
use rustc_errors::MultiSpan;
use rustc_hir as hir;
use rustc_middle::span_bug;
use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
use rustc_middle::ty::{self, Instance, Ty};
use rustc_span::Span;
use rustc_span::{DUMMY_SP, Span};
use rustc_target::callconv::{ArgAbi, FnAbi, PassMode};
use std::assert_matches::assert_matches;

Expand Down Expand Up @@ -395,23 +396,38 @@ impl<'tcx> CodegenCx<'tcx> {
// would've assumed it was actually an implicitly-`Input`.
let mut storage_class = Ok(storage_class);
if let Some(spec_constant) = attrs.spec_constant {
if ref_or_value_layout.ty != self.tcx.types.u32 {
let ty = ref_or_value_layout;
let valid_array_count = match ty.fields {
FieldsShape::Array { count, .. } => {
let element = ty.field(self, 0);
(element.ty == self.tcx.types.u32).then_some(u32::try_from(count).ok())
}
FieldsShape::Primitive => (ty.ty == self.tcx.types.u32).then_some(None),
_ => None,
};

if let Some(array_count) = valid_array_count {
if let Some(storage_class) = attrs.storage_class {
self.tcx.dcx().span_err(
storage_class.span,
"`#[spirv(spec_constant)]` cannot have a storage class",
);
} else {
assert_eq!(storage_class, Ok(StorageClass::Input));
assert!(!is_ref);
storage_class = Err(SpecConstant {
array_count,
..spec_constant.value
});
}
} else {
self.tcx.dcx().span_err(
hir_param.ty_span,
format!(
"unsupported `#[spirv(spec_constant)]` type `{}` (expected `{}`)",
ref_or_value_layout.ty, self.tcx.types.u32
"unsupported `#[spirv(spec_constant)]` type `{}` (expected `u32` or `[u32; N]`)",
ref_or_value_layout.ty
),
);
} else if let Some(storage_class) = attrs.storage_class {
self.tcx.dcx().span_err(
storage_class.span,
"`#[spirv(spec_constant)]` cannot have a storage class",
);
} else {
assert_eq!(storage_class, Ok(StorageClass::Input));
assert!(!is_ref);
storage_class = Err(spec_constant.value);
}
}

Expand Down Expand Up @@ -448,18 +464,38 @@ impl<'tcx> CodegenCx<'tcx> {
Ok(self.emit_global().id()),
Err("entry-point interface variable is not a `#[spirv(spec_constant)]`"),
),
Err(SpecConstant { id, default }) => {
let mut emit = self.emit_global();
let spec_const_id =
emit.spec_constant_bit32(value_spirv_type, default.unwrap_or(0));
emit.decorate(
spec_const_id,
Decoration::SpecId,
[Operand::LiteralBit32(id)],
);
Err(SpecConstant {
id,
default,
array_count,
}) => {
let u32_ty = SpirvType::Integer(32, false).def(DUMMY_SP, self);
let single = |id: u32| {
let mut emit = self.emit_global();
let spec_const_id = emit.spec_constant_bit32(u32_ty, default.unwrap_or(0));
emit.decorate(
spec_const_id,
Decoration::SpecId,
[Operand::LiteralBit32(id)],
);
spec_const_id
};
let param_word = if let Some(array_count) = array_count {
let array = (0..array_count).map(|i| single(id + i)).collect::<Vec<_>>();
let array_ty = SpirvType::Array {
element: u32_ty,
count: self.constant_u32(DUMMY_SP, array_count),
}
.def(DUMMY_SP, self);
bx.emit()
.composite_construct(array_ty, None, array)
.unwrap()
} else {
single(id)
};
(
Err("`#[spirv(spec_constant)]` is not an entry-point interface variable"),
Ok(spec_const_id),
Ok(param_word),
)
}
};
Expand Down
28 changes: 28 additions & 0 deletions tests/compiletests/ui/dis/spec_constant_array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Tests the various forms of `#[spirv(spec_constant)]`.

// build-pass
// ignore-spv1.0
// ignore-spv1.1
// ignore-spv1.2
// ignore-spv1.3
// ignore-vulkan1.0
// ignore-vulkan1.1

// compile-flags: -C llvm-args=--disassemble
// normalize-stderr-test "; .*\n" -> ""
// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> ""
// normalize-stderr-test "OpSource .*\n" -> ""
// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple"

// HACK(eddyb) `compiletest` handles `ui\dis\`, but not `ui\\dis\\`, on Windows.
// normalize-stderr-test "ui/dis/" -> "$$DIR/"

use spirv_std::spirv;

#[spirv(compute(threads(1)))]
pub fn main(
#[spirv(spec_constant(id = 42, default = 69))] array: [u32; 4],
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] out: &mut u32,
) {
*out = array[0] + array[1] + array[2] + array[3];
}
38 changes: 38 additions & 0 deletions tests/compiletests/ui/dis/spec_constant_array.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
OpCapability Shader
OpMemoryModel Logical Simple
OpEntryPoint GLCompute %1 "main" %2
OpExecutionMode %1 LocalSize 1 1 1
%3 = OpString "$DIR/spec_constant_array.rs"
OpDecorate %4 Block
OpMemberDecorate %4 0 Offset 0
OpDecorate %2 Binding 0
OpDecorate %2 DescriptorSet 0
OpDecorate %5 SpecId 42
OpDecorate %6 SpecId 43
OpDecorate %7 SpecId 44
OpDecorate %8 SpecId 45
%9 = OpTypeInt 32 0
%4 = OpTypeStruct %9
%10 = OpTypePointer StorageBuffer %4
%11 = OpTypeVoid
%12 = OpTypeFunction %11
%13 = OpTypePointer StorageBuffer %9
%2 = OpVariable %10 StorageBuffer
%14 = OpConstant %9 0
%5 = OpSpecConstant %9 69
%6 = OpSpecConstant %9 69
%7 = OpSpecConstant %9 69
%8 = OpSpecConstant %9 69
%1 = OpFunction %11 None %12
%15 = OpLabel
OpLine %3 25 4
%16 = OpInBoundsAccessChain %13 %2 %14
OpLine %3 27 11
%17 = OpIAdd %9 %5 %6
%18 = OpIAdd %9 %17 %7
OpLine %3 27 4
%19 = OpIAdd %9 %18 %8
OpStore %16 %19
OpNoLine
OpReturn
OpFunctionEnd