Skip to content

Commit 90b40ec

Browse files
committed
SpecConstant: add arrayed spec constants
1 parent e767f24 commit 90b40ec

File tree

4 files changed

+121
-23
lines changed

4 files changed

+121
-23
lines changed

crates/rustc_codegen_spirv/src/attr.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ pub enum IntrinsicType {
7575
pub struct SpecConstant {
7676
pub id: u32,
7777
pub default: Option<u32>,
78+
pub array_count: Option<u32>,
7879
}
7980

8081
// NOTE(eddyb) when adding new `#[spirv(...)]` attributes, the tests found inside
@@ -661,6 +662,8 @@ fn parse_spec_constant_attr(
661662
Ok(SpecConstant {
662663
id: id.ok_or_else(|| (arg.span(), "expected `spec_constant(id = ...)`".into()))?,
663664
default,
665+
// to be set later
666+
array_count: None,
664667
})
665668
}
666669

crates/rustc_codegen_spirv/src/codegen_cx/entry.rs

Lines changed: 59 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@ use rspirv::dr::Operand;
1111
use rspirv::spirv::{
1212
Capability, Decoration, Dim, ExecutionModel, FunctionControl, StorageClass, Word,
1313
};
14+
use rustc_abi::FieldsShape;
1415
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods, MiscCodegenMethods as _};
1516
use rustc_data_structures::fx::FxHashMap;
1617
use rustc_errors::MultiSpan;
1718
use rustc_hir as hir;
1819
use rustc_middle::span_bug;
1920
use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
2021
use rustc_middle::ty::{self, Instance, Ty};
21-
use rustc_span::Span;
22+
use rustc_span::{DUMMY_SP, Span};
2223
use rustc_target::callconv::{ArgAbi, FnAbi, PassMode};
2324
use std::assert_matches::assert_matches;
2425

@@ -395,23 +396,38 @@ impl<'tcx> CodegenCx<'tcx> {
395396
// would've assumed it was actually an implicitly-`Input`.
396397
let mut storage_class = Ok(storage_class);
397398
if let Some(spec_constant) = attrs.spec_constant {
398-
if ref_or_value_layout.ty != self.tcx.types.u32 {
399+
let ty = ref_or_value_layout;
400+
let valid_array_count = match ty.fields {
401+
FieldsShape::Array { count, .. } => {
402+
let element = ty.field(self, 0);
403+
(element.ty == self.tcx.types.u32).then_some(u32::try_from(count).ok())
404+
}
405+
FieldsShape::Primitive => (ty.ty == self.tcx.types.u32).then_some(None),
406+
_ => None,
407+
};
408+
409+
if let Some(array_count) = valid_array_count {
410+
if let Some(storage_class) = attrs.storage_class {
411+
self.tcx.dcx().span_err(
412+
storage_class.span,
413+
"`#[spirv(spec_constant)]` cannot have a storage class",
414+
);
415+
} else {
416+
assert_eq!(storage_class, Ok(StorageClass::Input));
417+
assert!(!is_ref);
418+
storage_class = Err(SpecConstant {
419+
array_count,
420+
..spec_constant.value
421+
});
422+
}
423+
} else {
399424
self.tcx.dcx().span_err(
400425
hir_param.ty_span,
401426
format!(
402-
"unsupported `#[spirv(spec_constant)]` type `{}` (expected `{}`)",
403-
ref_or_value_layout.ty, self.tcx.types.u32
427+
"unsupported `#[spirv(spec_constant)]` type `{}` (expected `u32` pr `[u32; N]`)",
428+
ref_or_value_layout.ty
404429
),
405430
);
406-
} else if let Some(storage_class) = attrs.storage_class {
407-
self.tcx.dcx().span_err(
408-
storage_class.span,
409-
"`#[spirv(spec_constant)]` cannot have a storage class",
410-
);
411-
} else {
412-
assert_eq!(storage_class, Ok(StorageClass::Input));
413-
assert!(!is_ref);
414-
storage_class = Err(spec_constant.value);
415431
}
416432
}
417433

@@ -448,18 +464,38 @@ impl<'tcx> CodegenCx<'tcx> {
448464
Ok(self.emit_global().id()),
449465
Err("entry-point interface variable is not a `#[spirv(spec_constant)]`"),
450466
),
451-
Err(SpecConstant { id, default }) => {
452-
let mut emit = self.emit_global();
453-
let spec_const_id =
454-
emit.spec_constant_bit32(value_spirv_type, default.unwrap_or(0));
455-
emit.decorate(
456-
spec_const_id,
457-
Decoration::SpecId,
458-
[Operand::LiteralBit32(id)],
459-
);
467+
Err(SpecConstant {
468+
id,
469+
default,
470+
array_count,
471+
}) => {
472+
let u32_ty = SpirvType::Integer(32, false).def(DUMMY_SP, self);
473+
let single = |id: u32| {
474+
let mut emit = self.emit_global();
475+
let spec_const_id = emit.spec_constant_bit32(u32_ty, default.unwrap_or(0));
476+
emit.decorate(
477+
spec_const_id,
478+
Decoration::SpecId,
479+
[Operand::LiteralBit32(id)],
480+
);
481+
spec_const_id
482+
};
483+
let param_word = if let Some(array_count) = array_count {
484+
let array = (0..array_count).map(|i| single(id + i)).collect::<Vec<_>>();
485+
let array_ty = SpirvType::Array {
486+
element: u32_ty,
487+
count: self.constant_u32(DUMMY_SP, array_count),
488+
}
489+
.def(DUMMY_SP, self);
490+
bx.emit()
491+
.composite_construct(array_ty, None, array)
492+
.unwrap()
493+
} else {
494+
single(id)
495+
};
460496
(
461497
Err("`#[spirv(spec_constant)]` is not an entry-point interface variable"),
462-
Ok(spec_const_id),
498+
Ok(param_word),
463499
)
464500
}
465501
};
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// Tests the various forms of `#[spirv(spec_constant)]`.
2+
3+
// build-pass
4+
// compile-flags: -C llvm-args=--disassemble
5+
// normalize-stderr-test "; .*\n" -> ""
6+
// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> ""
7+
// normalize-stderr-test "OpSource .*\n" -> ""
8+
// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple"
9+
10+
// HACK(eddyb) `compiletest` handles `ui\dis\`, but not `ui\\dis\\`, on Windows.
11+
// normalize-stderr-test "ui/dis/" -> "$$DIR/"
12+
13+
use spirv_std::spirv;
14+
15+
#[spirv(compute(threads(1)))]
16+
pub fn main(
17+
#[spirv(spec_constant(id = 42, default = 69))] array: [u32; 4],
18+
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] out: &mut u32,
19+
) {
20+
*out = array[0] + array[1] + array[2] + array[3];
21+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
OpCapability Shader
2+
OpMemoryModel Logical Simple
3+
OpEntryPoint GLCompute %1 "main" %2
4+
OpExecutionMode %1 LocalSize 1 1 1
5+
%3 = OpString "$DIR/spec_constant_array.rs"
6+
OpDecorate %4 Block
7+
OpMemberDecorate %4 0 Offset 0
8+
OpDecorate %2 Binding 0
9+
OpDecorate %2 DescriptorSet 0
10+
OpDecorate %5 SpecId 42
11+
OpDecorate %6 SpecId 43
12+
OpDecorate %7 SpecId 44
13+
OpDecorate %8 SpecId 45
14+
%9 = OpTypeInt 32 0
15+
%4 = OpTypeStruct %9
16+
%10 = OpTypePointer StorageBuffer %4
17+
%11 = OpTypeVoid
18+
%12 = OpTypeFunction %11
19+
%13 = OpTypePointer StorageBuffer %9
20+
%2 = OpVariable %10 StorageBuffer
21+
%14 = OpConstant %9 0
22+
%5 = OpSpecConstant %9 69
23+
%6 = OpSpecConstant %9 69
24+
%7 = OpSpecConstant %9 69
25+
%8 = OpSpecConstant %9 69
26+
%1 = OpFunction %11 None %12
27+
%15 = OpLabel
28+
OpLine %3 18 4
29+
%16 = OpInBoundsAccessChain %13 %2 %14
30+
OpLine %3 20 11
31+
%17 = OpIAdd %9 %5 %6
32+
%18 = OpIAdd %9 %17 %7
33+
OpLine %3 20 4
34+
%19 = OpIAdd %9 %18 %8
35+
OpStore %16 %19
36+
OpNoLine
37+
OpReturn
38+
OpFunctionEnd

0 commit comments

Comments
 (0)