Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 8 additions & 14 deletions codegen/masm/src/emit/int32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,15 +296,12 @@ impl OpEmitter<'_> {
/// Execution traps if the value cannot fit in the unsigned N-bit range.
pub fn int32_to_uint(&mut self, n: u32, span: SourceSpan) {
assert_valid_integer_size!(n, 1, 32);
// Mask the value and ensure that the unused bits above the N-bit range are 0
let reserved = 32 - n;
let mask = (2u32.pow(reserved) - 1) << n;
// Copy the input
self.emit(masm::Instruction::Dup1, span);
// Apply the mask
// A 32-bit target has no unused high bits, so use a zero mask without shifting by 32.
let mask = if n == 32 { 0 } else { u32::MAX << n };
// The range check must inspect the top value, not another live value below it.
self.emit(masm::Instruction::Dup0, span);
self.emit_push(mask, span);
self.emit(masm::Instruction::U32And, span);
// Assert the masked value is all 0s
self.emit(
Self::assertz_with_message_inst(
format!("value does not fit in unsigned {n}-bit range"),
Expand All @@ -319,15 +316,12 @@ impl OpEmitter<'_> {
/// Places a boolean on top of the stack indicating if the conversion was successful
pub fn try_int32_to_uint(&mut self, n: u32, span: SourceSpan) {
assert_valid_integer_size!(n, 1, 32);
// Mask the value and ensure that the unused bits above the N-bit range are 0
let reserved = 32 - n;
let mask = (2u32.pow(reserved) - 1) << n;
// Copy the input
self.emit(masm::Instruction::Dup1, span);
// Apply the mask
// A 32-bit target has no unused high bits, so use a zero mask without shifting by 32.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// A 32-bit target has no unused high bits, so use a zero mask without shifting by 32.
// If the target bit width is 32, then use an empty mask

let mask = if n == 32 { 0 } else { u32::MAX << n };
// The range check must inspect the top value, not another live value below it.
self.emit(masm::Instruction::Dup0, span);
self.emit_push(mask, span);
self.emit(masm::Instruction::U32And, span);
// Assert the masked value is all 0s
self.emit(masm::Instruction::EqImm(Felt::ZERO.into()), span);
}

Expand Down
47 changes: 47 additions & 0 deletions codegen/masm/src/emit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1906,6 +1906,53 @@ mod tests {
assert_eq!(emitter.stack()[0], Type::I32);
}

#[test]

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is essentially tautological - IMO it isn't useful, so I'd remove it. Tests should be behavioral (i.e. test that the output is correct for specific inputs), and when based on specific prior failures, we can specifically exercise those as a form of regression testing.

fn op_emitter_full_width_uint_range_checks_use_zero_mask() {
let span = SourceSpan::UNKNOWN;

let mut block = Vec::default();
let context = Rc::new(Context::default());
let mut stack = OperandStack::new(context.clone());
let mut invoked = BTreeSet::default();
let mut emitter = OpEmitter::new(&mut invoked, &mut block, &mut stack);

emitter.push(Type::U32);
emitter.int32_to_uint(32, span);

{
let ops = emitter.current_block();
assert_eq!(ops.len(), 4);
assert_eq!(&ops[0], &Op::Inst(Span::new(span, masm::Instruction::Dup0)));
assert_eq!(&ops[1], &push!(0u32));
assert_eq!(&ops[2], &Op::Inst(Span::new(span, masm::Instruction::U32And)));
assert!(matches!(
&ops[3],
Op::Inst(inst)
if matches!(inst.inner(), masm::Instruction::AssertzWithError(_))
));
}

let mut block = Vec::default();
let mut stack = OperandStack::new(context);
let mut invoked = BTreeSet::default();
let mut emitter = OpEmitter::new(&mut invoked, &mut block, &mut stack);

emitter.push(Type::U32);
emitter.try_int32_to_uint(32, span);

{
let ops = emitter.current_block();
assert_eq!(ops.len(), 4);
assert_eq!(&ops[0], &Op::Inst(Span::new(span, masm::Instruction::Dup0)));
assert_eq!(&ops[1], &push!(0u32));
assert_eq!(&ops[2], &Op::Inst(Span::new(span, masm::Instruction::U32And)));
assert_eq!(
&ops[3],
&Op::Inst(Span::new(span, masm::Instruction::EqImm(Felt::ZERO.into())))
);
}
}

#[test]
fn op_emitter_u32_inttoptr_test() {
let mut block = Vec::default();
Expand Down
210 changes: 210 additions & 0 deletions tests/integration/src/codegen/int32.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
use std::{
panic::{self, AssertUnwindSafe},
rc::Rc,
sync::Arc,
};

use miden_mast_package::Package;
use midenc_dialect_arith::ArithOpBuilder;
use midenc_dialect_hir::HirOpBuilder;
use midenc_hir::{Context, Felt, SourceSpan, Type, ValueRef, dialects::builtin::BuiltinOpBuilder};

use crate::testing::{compile_test_module, eval_package};

const HIGH_BIT_VALUE: u32 = 1 << 31;

fn compile_guarded_int32_cast(source_ty: Type, target_ty: Type) -> (Arc<Package>, Rc<Context>) {
let span = SourceSpan::default();
let cast_target_ty = target_ty.clone();

compile_test_module(
[source_ty.clone(), source_ty.clone(), source_ty],
[target_ty],
move |builder| {
let block = builder.current_block();
let expected_guard = block.borrow().arguments()[0] as ValueRef;
let live_guard = block.borrow().arguments()[1] as ValueRef;
let value = block.borrow().arguments()[2] as ValueRef;

let narrowed = builder.cast(value, cast_target_ty.clone(), span).unwrap();

// Use both guards after the cast so they stay live below the value while the
// narrowing check is emitted. The check must inspect `value`, not either guard.
builder.assert_eq(live_guard, expected_guard, span).unwrap();
builder.ret(Some(narrowed), span).unwrap();
},
)
}

fn compile_guarded_u8_overflowing_add() -> (Arc<Package>, Rc<Context>) {
let span = SourceSpan::default();

compile_test_module([Type::U32, Type::U32, Type::U8, Type::U8], [Type::I1], |builder| {
let block = builder.current_block();
let expected_guard = block.borrow().arguments()[0] as ValueRef;
let live_guard = block.borrow().arguments()[1] as ValueRef;
let lhs = block.borrow().arguments()[2] as ValueRef;
let rhs = block.borrow().arguments()[3] as ValueRef;

let (overflowed, _sum) = builder.add_overflowing(lhs, rhs, span).unwrap();
// Keep both guards live below the sum so overflowing arithmetic must validate the sum,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The wording of this is pretty unclear, because it is referring to a specific failure mode of the original code that is fixed. I would remove this comment, and instead provide a comment that explains the purpose of the guard values (which is essentially to have stuff on the operand stack that would cause an error if values on the operand stack were consumed incorrectly).

// not a live value below it.
builder.assert_eq(live_guard, expected_guard, span).unwrap();
builder.ret(Some(overflowed), span).unwrap();
})
}

fn try_eval_guarded_cast(
package: &Package,
context: &Context,
args: [u32; 3],
) -> Result<u32, String> {
let args = args.map(|arg| Felt::new_unchecked(u64::from(arg)));
panic::catch_unwind(AssertUnwindSafe(|| {
eval_package::<u32, _, _>(package, None, &args, context.session(), |_| Ok(()))
}))
.map_err(panic_payload_to_string)?
.map_err(|err| format!("{err:?}"))
}

fn panic_payload_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
if let Some(message) = payload.downcast_ref::<String>() {
message.clone()
} else if let Some(message) = payload.downcast_ref::<&str>() {
message.to_string()
} else {
"unknown panic".to_string()
}
}

fn eval_guarded_u8_overflowing_add(package: &Package, context: &Context, args: [u32; 4]) -> u32 {
let args = args.map(|arg| Felt::new_unchecked(u64::from(arg)));
eval_package::<u32, _, _>(package, None, &args, context.session(), |_| Ok(())).unwrap()
}

#[track_caller]
fn assert_cast_succeeds(
package: &Package,
context: &Context,
source_name: &str,
target_name: &str,
args: [u32; 3],
expected: u32,
) {
let actual = try_eval_guarded_cast(package, context, args).unwrap_or_else(|err| {
panic!(
"expected checked {source_name}-to-{target_name} cast of {} to succeed, got: {err}",
args[2],
)
});

assert_eq!(
actual, expected,
"checked {source_name}-to-{target_name} cast returned the wrong value"
);
}

#[track_caller]
fn assert_cast_traps(
package: &Package,
context: &Context,
source_name: &str,
target_name: &str,
args: [u32; 3],
) {
match try_eval_guarded_cast(package, context, args) {
Ok(actual) => panic!(
"expected checked {source_name}-to-{target_name} cast of {} to trap, but returned \
{actual}",
args[2]
),
Err(err) => assert!(
err.contains("does not fit in unsigned"),
"expected checked {source_name}-to-{target_name} cast of {} to fail the unsigned \
range check, got: {err}",
args[2]
),
}
}

#[track_caller]
fn assert_overflowing_add_flag(
package: &Package,
context: &Context,
args: [u32; 4],
expected_overflowed: bool,
) {
let actual = eval_guarded_u8_overflowing_add(package, context, args);

assert_eq!(
actual,
u32::from(expected_overflowed),
"overflow flag for guarded u8 overflowing add was incorrect"
);
}

#[track_caller]
fn assert_guarded_int32_cast(
source_ty: Type,
source_name: &str,
target_ty: Type,
target_name: &str,
max: u32,
first_invalid: u32,
) {
// Keep the high-bit guard representable as an i32 value while still setting a bit
// outside every narrower unsigned target range covered by this test.
let (package, context) = compile_guarded_int32_cast(source_ty, target_ty);

assert_cast_succeeds(
&package,
&context,
source_name,
target_name,
[HIGH_BIT_VALUE, HIGH_BIT_VALUE, 0],
0,
);
assert_cast_succeeds(
&package,
&context,
source_name,
target_name,
[HIGH_BIT_VALUE, HIGH_BIT_VALUE, max],
max,
);
assert_cast_traps(&package, &context, source_name, target_name, [0, 0, first_invalid]);
assert_cast_traps(&package, &context, source_name, target_name, [0, 0, HIGH_BIT_VALUE]);
}

#[test]
fn checked_int32_to_unsigned_narrowing_checks_the_cast_operand() {
for (source_ty, source_name) in [(Type::U32, "u32"), (Type::I32, "i32")] {
for (target_ty, target_name, max, first_invalid) in [
(Type::I1, "i1", 1u32, 2u32),
(Type::U8, "u8", u32::from(u8::MAX), u32::from(u8::MAX) + 1),
(Type::U16, "u16", u32::from(u16::MAX), u32::from(u16::MAX) + 1),
] {
assert_guarded_int32_cast(
source_ty.clone(),
source_name,
target_ty,
target_name,
max,
first_invalid,
);
}
}
}

#[test]
fn overflowing_u8_add_checks_the_sum_being_narrowed() {
let (package, context) = compile_guarded_u8_overflowing_add();

assert_overflowing_add_flag(
&package,
&context,
[HIGH_BIT_VALUE, HIGH_BIT_VALUE, u32::from(u8::MAX) - 1, 1],
false,
);
assert_overflowing_add_flag(&package, &context, [0, 0, u32::from(u8::MAX), 1], true);
}
1 change: 1 addition & 0 deletions tests/integration/src/codegen/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod control_flow;
mod int32;
mod memory;
mod wasm;