@@ -11,14 +11,15 @@ use rspirv::dr::Operand;
1111use rspirv:: spirv:: {
1212 Capability , Decoration , Dim , ExecutionModel , FunctionControl , StorageClass , Word ,
1313} ;
14+ use rustc_abi:: FieldsShape ;
1415use rustc_codegen_ssa:: traits:: { BaseTypeCodegenMethods , BuilderMethods , MiscCodegenMethods as _} ;
1516use rustc_data_structures:: fx:: FxHashMap ;
1617use rustc_errors:: MultiSpan ;
1718use rustc_hir as hir;
1819use rustc_middle:: span_bug;
1920use rustc_middle:: ty:: layout:: { LayoutOf , TyAndLayout } ;
2021use rustc_middle:: ty:: { self , Instance , Ty } ;
21- use rustc_span:: Span ;
22+ use rustc_span:: { DUMMY_SP , Span } ;
2223use rustc_target:: callconv:: { ArgAbi , FnAbi , PassMode } ;
2324use std:: assert_matches:: assert_matches;
2425
@@ -395,23 +396,38 @@ impl<'tcx> CodegenCx<'tcx> {
395396 // would've assumed it was actually an implicitly-`Input`.
396397 let mut storage_class = Ok ( storage_class) ;
397398 if let Some ( spec_constant) = attrs. spec_constant {
398- if ref_or_value_layout. ty != self . tcx . types . u32 {
399+ let ty = ref_or_value_layout;
400+ let valid_array_count = match ty. fields {
401+ FieldsShape :: Array { count, .. } => {
402+ let element = ty. field ( self , 0 ) ;
403+ ( element. ty == self . tcx . types . u32 ) . then_some ( u32:: try_from ( count) . ok ( ) )
404+ }
405+ FieldsShape :: Primitive => ( ty. ty == self . tcx . types . u32 ) . then_some ( None ) ,
406+ _ => None ,
407+ } ;
408+
409+ if let Some ( array_count) = valid_array_count {
410+ if let Some ( storage_class) = attrs. storage_class {
411+ self . tcx . dcx ( ) . span_err (
412+ storage_class. span ,
413+ "`#[spirv(spec_constant)]` cannot have a storage class" ,
414+ ) ;
415+ } else {
416+ assert_eq ! ( storage_class, Ok ( StorageClass :: Input ) ) ;
417+ assert ! ( !is_ref) ;
418+ storage_class = Err ( SpecConstant {
419+ array_count,
420+ ..spec_constant. value
421+ } ) ;
422+ }
423+ } else {
399424 self . tcx . dcx ( ) . span_err (
400425 hir_param. ty_span ,
401426 format ! (
402- "unsupported `#[spirv(spec_constant)]` type `{}` (expected `{} `)" ,
403- ref_or_value_layout. ty, self . tcx . types . u32
427+ "unsupported `#[spirv(spec_constant)]` type `{}` (expected `u32` or `[u32; N] `)" ,
428+ ref_or_value_layout. ty
404429 ) ,
405430 ) ;
406- } else if let Some ( storage_class) = attrs. storage_class {
407- self . tcx . dcx ( ) . span_err (
408- storage_class. span ,
409- "`#[spirv(spec_constant)]` cannot have a storage class" ,
410- ) ;
411- } else {
412- assert_eq ! ( storage_class, Ok ( StorageClass :: Input ) ) ;
413- assert ! ( !is_ref) ;
414- storage_class = Err ( spec_constant. value ) ;
415431 }
416432 }
417433
@@ -448,18 +464,38 @@ impl<'tcx> CodegenCx<'tcx> {
448464 Ok ( self . emit_global ( ) . id ( ) ) ,
449465 Err ( "entry-point interface variable is not a `#[spirv(spec_constant)]`" ) ,
450466 ) ,
451- Err ( SpecConstant { id, default } ) => {
452- let mut emit = self . emit_global ( ) ;
453- let spec_const_id =
454- emit. spec_constant_bit32 ( value_spirv_type, default. unwrap_or ( 0 ) ) ;
455- emit. decorate (
456- spec_const_id,
457- Decoration :: SpecId ,
458- [ Operand :: LiteralBit32 ( id) ] ,
459- ) ;
467+ Err ( SpecConstant {
468+ id,
469+ default,
470+ array_count,
471+ } ) => {
472+ let u32_ty = SpirvType :: Integer ( 32 , false ) . def ( DUMMY_SP , self ) ;
473+ let single = |id : u32 | {
474+ let mut emit = self . emit_global ( ) ;
475+ let spec_const_id = emit. spec_constant_bit32 ( u32_ty, default. unwrap_or ( 0 ) ) ;
476+ emit. decorate (
477+ spec_const_id,
478+ Decoration :: SpecId ,
479+ [ Operand :: LiteralBit32 ( id) ] ,
480+ ) ;
481+ spec_const_id
482+ } ;
483+ let param_word = if let Some ( array_count) = array_count {
484+ let array = ( 0 ..array_count) . map ( |i| single ( id + i) ) . collect :: < Vec < _ > > ( ) ;
485+ let array_ty = SpirvType :: Array {
486+ element : u32_ty,
487+ count : self . constant_u32 ( DUMMY_SP , array_count) ,
488+ }
489+ . def ( DUMMY_SP , self ) ;
490+ bx. emit ( )
491+ . composite_construct ( array_ty, None , array)
492+ . unwrap ( )
493+ } else {
494+ single ( id)
495+ } ;
460496 (
461497 Err ( "`#[spirv(spec_constant)]` is not an entry-point interface variable" ) ,
462- Ok ( spec_const_id ) ,
498+ Ok ( param_word ) ,
463499 )
464500 }
465501 } ;
0 commit comments