Conversation
There was a problem hiding this comment.
Code Review
This pull request injects the ${isNaNFunc} helper into several WGSL shader templates to facilitate NaN handling. While this allows the shaders to compile, the feedback highlights significant logic errors in the reduction, covariance, and correlation shaders: although NaNs are skipped during summation, the final calculations still divide by the total dimension length rather than a count of valid samples, leading to incorrect statistical results. Additionally, the helper is injected into the CUMSUM3D shader without being used, and the isNaN implementation's behavior of treating Infinity as NaN should be verified against the intended requirements.
| @group(0) @binding(2) var<storage, read_write> outputData: array<${precision}>; | ||
| @group(0) @binding(3) var<uniform> params: Params; | ||
|
|
||
| ${isNaNFunc} |
There was a problem hiding this comment.
The addition of ${isNaNFunc} allows this shader to compile, but the logic inside the main function (lines 555-556) is incorrect when NaNs are present. The shader currently skips NaNs in the summation loop but still divides by the total dimLength to calculate the mean. This will result in an incorrect average. You should implement a counter to track the number of valid (non-NaN) samples and use that for the division.
Also, note that the isNaN implementation (line 15) returns true for both NaN and Infinity. While often desirable in data processing, ensure this matches your requirements as the function name suggests it only checks for NaN.
| @group(0) @binding(2) var<storage, read_write> outputData: array<${precision}>; | ||
| @group(0) @binding(3) var<uniform> params: Params; | ||
|
|
||
| ${isNaNFunc} |
There was a problem hiding this comment.
Similar to the other reduction shaders, the logic in main (lines 667-668 and 682) is incorrect when NaNs are skipped. It uses dimLength for averaging and covariance calculation, which includes the skipped indices. A dynamic counter for valid samples should be used instead to ensure statistical correctness.
| @group(0) @binding(2) var<storage, read_write> outputData: array<${precision}>; | ||
| @group(0) @binding(3) var<uniform> params: Params; | ||
|
|
||
| ${isNaNFunc} |
| @group(0) @binding(1) var<storage, read_write> outputData: array<f32>; | ||
| @group(0) @binding(2) var<uniform> params: Params; | ||
|
|
||
| ${isNaNFunc} |
|
Will address the concerns later |
The 2D shaders dont use a boilerplate which is where I was storing the isNan bitcheck function. So those shaders weren't working