Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/components/computation/WGSLShaders.ts
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,7 @@ export const createShaders = (precision: Precision) => {
@group(0) @binding(2) var<storage, read_write> outputData: array<${precision}>;
@group(0) @binding(3) var<uniform> params: Params;

${isNaNFunc}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The addition of ${isNaNFunc} allows this shader to compile, but the logic inside the main function (lines 555-556) is incorrect when NaNs are present. The shader currently skips NaNs in the summation loop but still divides by the total dimLength to calculate the mean. This will result in an incorrect average. You should implement a counter to track the number of valid (non-NaN) samples and use that for the division.

Also, note that the isNaN implementation (line 15) returns true for both NaN and Infinity. While often desirable in data processing, ensure this matches your requirements as the function name suggests it only checks for NaN.

@compute @workgroup_size(16, 16, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let zStride = params.zStride;
Expand Down Expand Up @@ -615,6 +616,7 @@ export const createShaders = (precision: Precision) => {
@group(0) @binding(2) var<storage, read_write> outputData: array<${precision}>;
@group(0) @binding(3) var<uniform> params: Params;

${isNaNFunc}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similar to the other reduction shaders, the logic in main (lines 667-668 and 682) is incorrect when NaNs are skipped. It uses dimLength for averaging and covariance calculation, which includes the skipped indices. A dynamic counter for valid samples should be used instead to ensure statistical correctness.

@compute @workgroup_size(16, 16, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let zStride = params.zStride;
Expand Down Expand Up @@ -697,6 +699,8 @@ export const createShaders = (precision: Precision) => {
@group(0) @binding(2) var<storage, read_write> outputData: array<${precision}>;
@group(0) @binding(3) var<uniform> params: Params;

${isNaNFunc}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The logic in this shader (line 774) uses dimLength as the sample count N. Since NaNs are skipped during the summation loops, N should be a count of the actual valid pairs processed to avoid incorrect correlation results.


@compute @workgroup_size(16, 16, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let zStride = params.zStride;
Expand Down Expand Up @@ -950,6 +954,7 @@ export const createShaders = (precision: Precision) => {
@group(0) @binding(2) var<storage, read_write> outputData: array<${precision}>;
@group(0) @binding(3) var<uniform> params: Params;

${isNaNFunc}
@compute @workgroup_size(4, 4, 4)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let zStride = params.zStride;
Expand Down Expand Up @@ -1055,7 +1060,7 @@ export const createShaders = (precision: Precision) => {
@group(0) @binding(1) var<storage, read> secondData: array<${precision}>;
@group(0) @binding(2) var<storage, read_write> outputData: array<${precision}>;
@group(0) @binding(3) var<uniform> params: Params;

${isNaNFunc}
@compute @workgroup_size(4, 4, 4)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let zStride = params.zStride;
Expand Down Expand Up @@ -1169,7 +1174,7 @@ export const createShaders = (precision: Precision) => {
@group(0) @binding(1) var<storage, read> secondData: array<${precision}>;
@group(0) @binding(2) var<storage, read_write> outputData: array<${precision}>;
@group(0) @binding(3) var<uniform> params: Params;

${isNaNFunc}
@compute @workgroup_size(4, 4, 4)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let zStride = params.zStride;
Expand Down Expand Up @@ -1411,7 +1416,7 @@ export const createShaders = (precision: Precision) => {
@group(0) @binding(0) var<storage, read> inputData: array<${precision}>;
@group(0) @binding(1) var<storage, read_write> outputData: array<f32>;
@group(0) @binding(2) var<uniform> params: Params;

${isNaNFunc}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The isNaN function is defined here but is not used anywhere in the CUMSUM3D shader body. If you don't intend to filter NaNs during the accumulation, this line should be removed to keep the shader code clean.

@compute @workgroup_size(4, 4, 4)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let zStride = params.zStride;
Expand Down
Loading