|
1 | 1 | use super::{get_name, get_names}; |
2 | | -use rspirv::dr::{Block, Function, Module}; |
| 2 | +use rspirv::dr::{Block, Function, Module, Operand}; |
3 | 3 | use rspirv::spirv::{Decoration, ExecutionModel, Op, Word}; |
4 | 4 | use rustc_codegen_spirv_types::Capability; |
5 | 5 | use rustc_data_structures::fx::{FxHashMap, FxHashSet}; |
@@ -365,3 +365,60 @@ pub fn remove_non_uniform_decorations(_sess: &Session, module: &mut Module) -> s |
365 | 365 | } |
366 | 366 | Ok(()) |
367 | 367 | } |
| 368 | + |
| 369 | +/// When `OpCapability Int8` is not declared, promote all `i8`/`u8` types to `i32`/`u32`. |
| 370 | +pub fn promote_int8_to_int32(module: &mut Module) { |
| 371 | + let has_int8 = module.capabilities.iter().any(|inst| { |
| 372 | + inst.class.opcode == Op::Capability |
| 373 | + && inst.operands[0].unwrap_capability() == Capability::Int8 |
| 374 | + }); |
| 375 | + if has_int8 { |
| 376 | + return; |
| 377 | + } |
| 378 | + |
| 379 | + let narrow_types: FxHashMap<Word, u32> = module |
| 380 | + .types_global_values |
| 381 | + .iter() |
| 382 | + .filter_map(|inst| { |
| 383 | + if inst.class.opcode == Op::TypeInt |
| 384 | + && inst.operands[0].unwrap_literal_bit32() == 8 |
| 385 | + { |
| 386 | + let signedness = inst.operands[1].unwrap_literal_bit32(); |
| 387 | + Some((inst.result_id?, signedness)) |
| 388 | + } else { |
| 389 | + None |
| 390 | + } |
| 391 | + }) |
| 392 | + .collect(); |
| 393 | + |
| 394 | + if narrow_types.is_empty() { |
| 395 | + return; |
| 396 | + } |
| 397 | + |
| 398 | + for inst in &mut module.types_global_values { |
| 399 | + // widen each 8-bit OpTypeInt to 32 bits |
| 400 | + if inst.class.opcode == Op::TypeInt { |
| 401 | + if let Some(id) = inst.result_id { |
| 402 | + if narrow_types.contains_key(&id) { |
| 403 | + inst.operands[0] = Operand::LiteralBit32(32); |
| 404 | + } |
| 405 | + } |
| 406 | + } |
| 407 | + |
| 408 | + // fix OpConstant values: sign-extend signed 8-bit constants to 32 bits. |
| 409 | + if inst.class.opcode == Op::Constant { |
| 410 | + if let Some(ty) = inst.result_type { |
| 411 | + if let Some(&signedness) = narrow_types.get(&ty) { |
| 412 | + if let Operand::LiteralBit32(ref mut val) = inst.operands[0] { |
| 413 | + let narrow = *val as u8; |
| 414 | + *val = if signedness != 0 { |
| 415 | + (narrow as i8 as i32) as u32 |
| 416 | + } else { |
| 417 | + narrow as u32 |
| 418 | + }; |
| 419 | + } |
| 420 | + } |
| 421 | + } |
| 422 | + } |
| 423 | + } |
| 424 | +} |
0 commit comments