Skip to content

Commit 18a622d

Browse files
committed
feat: add smape metric
1 parent 12c6c73 commit 18a622d

1 file changed

Lines changed: 56 additions & 0 deletions

File tree

src/metrics_regression.rs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,19 @@ pub trait SingleTargetRegression<F: Float, T: AsSingleTargets<Elem = F>>:
8080
.ok_or(Error::NotEnoughSamples)
8181
}
8282

83+
/// Symmetric mean absolute percentage error between two continuous variables
84+
/// sMAPE = 1/N * SUM(abs(y_hat - y) / ((abs(y) + abs(y_hat)) / 2))
85+
fn symmetric_mean_absolute_percentage_error(&self, compare_to: &T) -> Result<F> {
86+
let y = self.as_single_targets();
87+
let y_hat = compare_to.as_single_targets();
88+
let abs_diff = (&y_hat - &y).mapv(|x| x.abs());
89+
let abs_sum = (y.mapv(|x| x.abs()) + y_hat.mapv(|x| x.abs())) + F::cast(1e-10);
90+
(abs_diff / abs_sum)
91+
.mapv(|x| x * F::cast(2.0))
92+
.mean()
93+
.ok_or(Error::NotEnoughSamples)
94+
}
95+
8396
/// R squared coefficient, is the proportion of the variance in the dependent variable that is
8497
/// predictable from the independent variable
8598
// r2 = 1 - sum((pred_i - y_i)^2)/sum((mean_y - y_i)^2)
@@ -193,6 +206,16 @@ pub trait MultiTargetRegression<F: Float, T: AsMultiTargets<Elem = F>>:
193206
.collect()
194207
}
195208

209+
/// Symmetric mean absolute percentage error between two continuous variables
210+
/// sMAPE = 1/N * SUM(abs(y_hat - y) / ((abs(y) + abs(y_hat)) / 2))
211+
fn symmetric_mean_absolute_percentage_error(&self, other: &T) -> Result<Array1<F>> {
212+
self.as_multi_targets()
213+
.axis_iter(Axis(1))
214+
.zip(other.as_multi_targets().axis_iter(Axis(1)))
215+
.map(|(a, b)| a.symmetric_mean_absolute_percentage_error(&b))
216+
.collect()
217+
}
218+
196219
/// R squared coefficient, is the proportion of the variance in the dependent variable that is
197220
/// predictable from the independent variable
198221
fn r2(&self, other: &T) -> Result<Array1<F>> {
@@ -242,6 +265,10 @@ mod tests {
242265
assert_abs_diff_eq!(a.r2(&a).unwrap(), 1.0f32);
243266
assert_abs_diff_eq!(a.explained_variance(&a).unwrap(), 1.0f32);
244267
assert_abs_diff_eq!(a.mean_absolute_percentage_error(&a).unwrap(), 0.0f32);
268+
assert_abs_diff_eq!(
269+
a.symmetric_mean_absolute_percentage_error(&a).unwrap(),
270+
0.0f32
271+
);
245272
}
246273

247274
#[test]
@@ -281,6 +308,18 @@ mod tests {
281308
);
282309
}
283310

311+
#[test]
312+
fn test_symmetric_mean_absolute_percentage_error() {
313+
let a = array![0.5, 0.1, 0.2, 0.3, 0.4];
314+
let b = array![0.1, 0.2, 0.3, 0.4, 0.5];
315+
316+
assert_abs_diff_eq!(
317+
a.symmetric_mean_absolute_percentage_error(&b).unwrap(),
318+
0.5815873014693111,
319+
epsilon = 1e-5
320+
);
321+
}
322+
284323
#[test]
285324
fn test_max_error_for_single_targets() {
286325
let records = array![[0.0, 0.0], [0.1, 0.1], [0.2, 0.2], [0.3, 0.3], [0.4, 0.4]];
@@ -339,6 +378,23 @@ mod tests {
339378
assert_abs_diff_eq!(pct_err_from_arr1, pct_err_from_ds);
340379
}
341380

381+
#[test]
382+
fn test_symmetric_mean_absolute_percentage_error_for_single_targets() {
383+
let records = array![[0.0, 0.0], [0.1, 0.1], [0.2, 0.2], [0.3, 0.3], [0.4, 0.4]];
384+
let targets = array![0.0, 0.1, 0.2, 0.3, 0.4];
385+
let st_dataset: DatasetBase<_, _> = (records.view(), targets).into();
386+
let prediction = array![0.1, 0.3, 0.2, 0.5, 0.7];
387+
let err_from_arr = prediction
388+
.symmetric_mean_absolute_percentage_error(st_dataset.targets())
389+
.unwrap();
390+
let prediction_ds: DatasetBase<_, _> = (records.view(), prediction).into();
391+
let err_from_ds = prediction_ds
392+
.symmetric_mean_absolute_percentage_error(&st_dataset)
393+
.unwrap();
394+
assert_abs_diff_eq!(err_from_arr, 0.8090909086184916, epsilon = 1e-5);
395+
assert_abs_diff_eq!(err_from_arr, err_from_ds);
396+
}
397+
342398
#[test]
343399
fn test_mean_squared_log_error_for_single_targets() {
344400
let records = array![[0.0, 0.0], [0.1, 0.1], [0.2, 0.2], [0.3, 0.3], [0.4, 0.4]];

0 commit comments

Comments
 (0)