-
Notifications
You must be signed in to change notification settings - Fork 3
Added IsNan to shaders w/out boilerplate #636
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -492,6 +492,7 @@ export const createShaders = (precision: Precision) => { | |
| @group(0) @binding(2) var<storage, read_write> outputData: array<${precision}>; | ||
| @group(0) @binding(3) var<uniform> params: Params; | ||
|
|
||
| ${isNaNFunc} | ||
| @compute @workgroup_size(16, 16, 1) | ||
| fn main(@builtin(global_invocation_id) global_id: vec3<u32>) { | ||
| let zStride = params.zStride; | ||
|
|
@@ -615,6 +616,7 @@ export const createShaders = (precision: Precision) => { | |
| @group(0) @binding(2) var<storage, read_write> outputData: array<${precision}>; | ||
| @group(0) @binding(3) var<uniform> params: Params; | ||
|
|
||
| ${isNaNFunc} | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to the other reduction shaders, the logic in |
||
| @compute @workgroup_size(16, 16, 1) | ||
| fn main(@builtin(global_invocation_id) global_id: vec3<u32>) { | ||
| let zStride = params.zStride; | ||
|
|
@@ -697,6 +699,8 @@ export const createShaders = (precision: Precision) => { | |
| @group(0) @binding(2) var<storage, read_write> outputData: array<${precision}>; | ||
| @group(0) @binding(3) var<uniform> params: Params; | ||
|
|
||
| ${isNaNFunc} | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| @compute @workgroup_size(16, 16, 1) | ||
| fn main(@builtin(global_invocation_id) global_id: vec3<u32>) { | ||
| let zStride = params.zStride; | ||
|
|
@@ -950,6 +954,7 @@ export const createShaders = (precision: Precision) => { | |
| @group(0) @binding(2) var<storage, read_write> outputData: array<${precision}>; | ||
| @group(0) @binding(3) var<uniform> params: Params; | ||
|
|
||
| ${isNaNFunc} | ||
| @compute @workgroup_size(4, 4, 4) | ||
| fn main(@builtin(global_invocation_id) global_id: vec3<u32>) { | ||
| let zStride = params.zStride; | ||
|
|
@@ -1055,7 +1060,7 @@ export const createShaders = (precision: Precision) => { | |
| @group(0) @binding(1) var<storage, read> secondData: array<${precision}>; | ||
| @group(0) @binding(2) var<storage, read_write> outputData: array<${precision}>; | ||
| @group(0) @binding(3) var<uniform> params: Params; | ||
|
|
||
| ${isNaNFunc} | ||
| @compute @workgroup_size(4, 4, 4) | ||
| fn main(@builtin(global_invocation_id) global_id: vec3<u32>) { | ||
| let zStride = params.zStride; | ||
|
|
@@ -1169,7 +1174,7 @@ export const createShaders = (precision: Precision) => { | |
| @group(0) @binding(1) var<storage, read> secondData: array<${precision}>; | ||
| @group(0) @binding(2) var<storage, read_write> outputData: array<${precision}>; | ||
| @group(0) @binding(3) var<uniform> params: Params; | ||
|
|
||
| ${isNaNFunc} | ||
| @compute @workgroup_size(4, 4, 4) | ||
| fn main(@builtin(global_invocation_id) global_id: vec3<u32>) { | ||
| let zStride = params.zStride; | ||
|
|
@@ -1411,7 +1416,7 @@ export const createShaders = (precision: Precision) => { | |
| @group(0) @binding(0) var<storage, read> inputData: array<${precision}>; | ||
| @group(0) @binding(1) var<storage, read_write> outputData: array<f32>; | ||
| @group(0) @binding(2) var<uniform> params: Params; | ||
|
|
||
| ${isNaNFunc} | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| @compute @workgroup_size(4, 4, 4) | ||
| fn main(@builtin(global_invocation_id) global_id: vec3<u32>) { | ||
| let zStride = params.zStride; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The addition of
${isNaNFunc}allows this shader to compile, but the logic inside themainfunction (lines 555-556) is incorrect when NaNs are present. The shader currently skips NaNs in the summation loop but still divides by the totaldimLengthto calculate the mean. This will result in an incorrect average. You should implement a counter to track the number of valid (non-NaN) samples and use that for the division.Also, note that the
isNaNimplementation (line 15) returnstruefor both NaN and Infinity. While often desirable in data processing, ensure this matches your requirements as the function name suggests it only checks for NaN.