Skip to content

Commit 0780345

Browse files
committed
promote u8 to u32 if needed
1 parent 3f7cb64 commit 0780345

3 files changed

Lines changed: 91 additions & 1 deletion

File tree

crates/rustc_codegen_spirv/src/linker/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,11 @@ pub fn link(
492492
simple_passes::remove_non_uniform_decorations(sess, &mut output)?;
493493
}
494494

495+
{
496+
let _timer = sess.timer("link_promote_int8_to_int32");
497+
simple_passes::promote_int8_to_int32(&mut output);
498+
}
499+
495500
// NOTE(eddyb) SPIR-T pipeline is entirely limited to this block.
496501
{
497502
let (spv_words, module_or_err, lower_from_spv_timer) =

crates/rustc_codegen_spirv/src/linker/simple_passes.rs

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use super::{get_name, get_names};
2-
use rspirv::dr::{Block, Function, Module};
2+
use rspirv::dr::{Block, Function, Module, Operand};
33
use rspirv::spirv::{Decoration, ExecutionModel, Op, Word};
44
use rustc_codegen_spirv_types::Capability;
55
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
@@ -365,3 +365,60 @@ pub fn remove_non_uniform_decorations(_sess: &Session, module: &mut Module) -> s
365365
}
366366
Ok(())
367367
}
368+
369+
/// When `OpCapability Int8` is not declared, promote all `i8`/`u8` types to `i32`/`u32`.
370+
pub fn promote_int8_to_int32(module: &mut Module) {
371+
let has_int8 = module.capabilities.iter().any(|inst| {
372+
inst.class.opcode == Op::Capability
373+
&& inst.operands[0].unwrap_capability() == Capability::Int8
374+
});
375+
if has_int8 {
376+
return;
377+
}
378+
379+
let narrow_types: FxHashMap<Word, u32> = module
380+
.types_global_values
381+
.iter()
382+
.filter_map(|inst| {
383+
if inst.class.opcode == Op::TypeInt
384+
&& inst.operands[0].unwrap_literal_bit32() == 8
385+
{
386+
let signedness = inst.operands[1].unwrap_literal_bit32();
387+
Some((inst.result_id?, signedness))
388+
} else {
389+
None
390+
}
391+
})
392+
.collect();
393+
394+
if narrow_types.is_empty() {
395+
return;
396+
}
397+
398+
for inst in &mut module.types_global_values {
399+
// widen each 8-bit OpTypeInt to 32 bits
400+
if inst.class.opcode == Op::TypeInt {
401+
if let Some(id) = inst.result_id {
402+
if narrow_types.contains_key(&id) {
403+
inst.operands[0] = Operand::LiteralBit32(32);
404+
}
405+
}
406+
}
407+
408+
// fix OpConstant values: sign-extend signed 8-bit constants to 32 bits.
409+
if inst.class.opcode == Op::Constant {
410+
if let Some(ty) = inst.result_type {
411+
if let Some(&signedness) = narrow_types.get(&ty) {
412+
if let Operand::LiteralBit32(ref mut val) = inst.operands[0] {
413+
let narrow = *val as u8;
414+
*val = if signedness != 0 {
415+
(narrow as i8 as i32) as u32
416+
} else {
417+
narrow as u32
418+
};
419+
}
420+
}
421+
}
422+
}
423+
}
424+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// build-pass
2+
//PartialOrd on CustomPosition(u32) internally returns Option<Ordering>,
3+
//where Ordering is represented as i8 in Rust's layout.
4+
//This caused rust-gpu to emit OpTypeInt 8 declarations requiring OpCapability Int8
5+
#![no_std]
6+
7+
use spirv_std::{glam::Vec4, spirv};
8+
9+
pub struct ShaderInputs {
10+
pub x: CustomPosition,
11+
pub y: CustomPosition,
12+
}
13+
14+
#[derive(Debug, Clone, Copy, Ord, PartialOrd, Eq, PartialEq, )]
15+
pub struct CustomPosition(u32);
16+
17+
18+
#[spirv(vertex)]
19+
pub fn test_vs(
20+
#[spirv(push_constant)] inputs: &ShaderInputs,
21+
#[spirv(position)] out_pos: &mut Vec4
22+
) {
23+
let mut result:f32 = 0.;
24+
if inputs.x < inputs.y{
25+
result =1.0;
26+
}
27+
*out_pos = Vec4::new(inputs.x.0 as f32, inputs.y.0 as f32, result as f32, 1.0);
28+
}

0 commit comments

Comments
 (0)