Skip to content

Commit ee21162

Browse files
jrpriceDawn LUCI CQ
authored andcommitted
[ir] Validate intermediate workgroup size products
When checking that the total workgroup size is less that UINT32_MAX, we were only checking the final x*y*z product. This may overflow a uint64_t and wrap around to be a valid uint32_t value, so we need to check the intermediate products instead. Fixed: 463283605 Change-Id: Ie4cb2354bc6693230b5b591d152ec3d95b0469c4 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/277935 Reviewed-by: Peter McNeeley <petermcneeley@google.com> Commit-Queue: James Price <jrprice@google.com> Commit-Queue: Peter McNeeley <petermcneeley@google.com> Auto-Submit: James Price <jrprice@google.com>
1 parent b06bb48 commit ee21162

2 files changed

Lines changed: 24 additions & 5 deletions

File tree

src/tint/lang/core/ir/validator.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2970,6 +2970,12 @@ void Validator::CheckWorkgroupSize(const Function* func) {
29702970
return;
29712971
}
29722972
total_size *= c->Value()->ValueAs<uint64_t>();
2973+
2974+
constexpr uint64_t kMaxGridSize = 0xffffffff;
2975+
if (total_size > kMaxGridSize) {
2976+
AddError(func) << "workgroup grid size cannot exceed 0x" << std::hex
2977+
<< kMaxGridSize;
2978+
}
29732979
continue;
29742980
}
29752981

@@ -3007,11 +3013,6 @@ void Validator::CheckWorkgroupSize(const Function* func) {
30073013

30083014
AddError(func) << "@workgroup_size must be an InstructionResult or a Constant";
30093015
}
3010-
3011-
constexpr uint64_t kMaxGridSize = 0xffffffff;
3012-
if (total_size > kMaxGridSize) {
3013-
AddError(func) << "workgroup grid size cannot exceed 0x" << std::hex << kMaxGridSize;
3014-
}
30153016
}
30163017

30173018
void Validator::CheckPositionPresentForVertexOutput(const Function* ep) {

src/tint/lang/core/ir/validator_function_test.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2498,6 +2498,24 @@ TEST_F(IR_ValidatorTest, Function_WorkgroupSize_ParamsTooLarge) {
24982498
)")) << res.Failure();
24992499
}
25002500

2501+
// Test the case where the intermediate workgroup product overflows a uint64_t and wraps back around
2502+
// to be a valid uint32_t value.
2503+
TEST_F(IR_ValidatorTest, Function_WorkgroupSize_ParamsTooLarge_U64Overflow) {
2504+
auto* f = ComputeEntryPoint();
2505+
f->SetWorkgroupSize(
2506+
{b.Constant(1526726656_i), b.Constant(1526726656_i), b.Constant(1526726656_i)});
2507+
2508+
b.Append(f->Block(), [&] { b.Unreachable(); });
2509+
2510+
auto res = ir::Validate(mod);
2511+
ASSERT_NE(res, Success);
2512+
EXPECT_THAT(res.Failure().reason,
2513+
testing::HasSubstr(R"(:1:1 error: workgroup grid size cannot exceed 0xffffffff
2514+
%f = @compute @workgroup_size(1526726656i, 1526726656i, 1526726656i) func():void {
2515+
^^
2516+
)")) << res.Failure();
2517+
}
2518+
25012519
TEST_F(IR_ValidatorTest, Function_WorkgroupSize_OverrideWithoutAllowOverrides) {
25022520
auto* o = b.Override(ty.u32());
25032521
auto* f = ComputeEntryPoint();

0 commit comments

Comments
 (0)