Skip to content

Commit 29ecf44

Browse files
committed
error metrics fix: printed force mae/rmse was 3x higher than it is last few weeks
1 parent a4fdfe9 commit 29ecf44

1 file changed

Lines changed: 4 additions & 3 deletions

File tree

so3lr/cli/so3lr_eval.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,10 @@ def calculate_metrics(
9696

9797
diff = jnp.asarray(y_pred) - jnp.asarray(y_true)
9898

99-
abs_sum = jnp.abs(diff[mask]).sum()
100-
sq_sum = (diff[mask] ** 2).sum()
101-
count = jnp.asarray(mask).sum()
99+
diff_masked = diff[mask]
100+
abs_sum = jnp.abs(diff_masked).sum()
101+
sq_sum = (diff_masked ** 2).sum()
102+
count = diff_masked.size
102103

103104
# Store as Python scalars to avoid tracer issues later
104105
metrics[f"{target}_abs_sum"] = float(abs_sum)

0 commit comments

Comments
 (0)