We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent a4fdfe9 commit 29ecf44Copy full SHA for 29ecf44
1 file changed
so3lr/cli/so3lr_eval.py
@@ -96,9 +96,10 @@ def calculate_metrics(
96
97
diff = jnp.asarray(y_pred) - jnp.asarray(y_true)
98
99
- abs_sum = jnp.abs(diff[mask]).sum()
100
- sq_sum = (diff[mask] ** 2).sum()
101
- count = jnp.asarray(mask).sum()
+ diff_masked = diff[mask]
+ abs_sum = jnp.abs(diff_masked).sum()
+ sq_sum = (diff_masked ** 2).sum()
102
+ count = diff_masked.size
103
104
# Store as Python scalars to avoid tracer issues later
105
metrics[f"{target}_abs_sum"] = float(abs_sum)
0 commit comments