@@ -80,6 +80,19 @@ pub trait SingleTargetRegression<F: Float, T: AsSingleTargets<Elem = F>>:
8080 . ok_or ( Error :: NotEnoughSamples )
8181 }
8282
83+ /// Symmetric mean absolute percentage error between two continuous variables
84+ /// sMAPE = 1/N * SUM(abs(y_hat - y) / ((abs(y) + abs(y_hat)) / 2))
85+ fn symmetric_mean_absolute_percentage_error ( & self , compare_to : & T ) -> Result < F > {
86+ let y = self . as_single_targets ( ) ;
87+ let y_hat = compare_to. as_single_targets ( ) ;
88+ let abs_diff = ( & y_hat - & y) . mapv ( |x| x. abs ( ) ) ;
89+ let abs_sum = ( y. mapv ( |x| x. abs ( ) ) + y_hat. mapv ( |x| x. abs ( ) ) ) + F :: cast ( 1e-10 ) ;
90+ ( abs_diff / abs_sum)
91+ . mapv ( |x| x * F :: cast ( 2.0 ) )
92+ . mean ( )
93+ . ok_or ( Error :: NotEnoughSamples )
94+ }
95+
8396 /// R squared coefficient, is the proportion of the variance in the dependent variable that is
8497 /// predictable from the independent variable
8598 // r2 = 1 - sum((pred_i - y_i)^2)/sum((mean_y - y_i)^2)
@@ -193,6 +206,16 @@ pub trait MultiTargetRegression<F: Float, T: AsMultiTargets<Elem = F>>:
193206 . collect ( )
194207 }
195208
209+ /// Symmetric mean absolute percentage error between two continuous variables
210+ /// sMAPE = 1/N * SUM(abs(y_hat - y) / ((abs(y) + abs(y_hat)) / 2))
211+ fn symmetric_mean_absolute_percentage_error ( & self , other : & T ) -> Result < Array1 < F > > {
212+ self . as_multi_targets ( )
213+ . axis_iter ( Axis ( 1 ) )
214+ . zip ( other. as_multi_targets ( ) . axis_iter ( Axis ( 1 ) ) )
215+ . map ( |( a, b) | a. symmetric_mean_absolute_percentage_error ( & b) )
216+ . collect ( )
217+ }
218+
196219 /// R squared coefficient, is the proportion of the variance in the dependent variable that is
197220 /// predictable from the independent variable
198221 fn r2 ( & self , other : & T ) -> Result < Array1 < F > > {
@@ -242,6 +265,10 @@ mod tests {
242265 assert_abs_diff_eq ! ( a. r2( & a) . unwrap( ) , 1.0f32 ) ;
243266 assert_abs_diff_eq ! ( a. explained_variance( & a) . unwrap( ) , 1.0f32 ) ;
244267 assert_abs_diff_eq ! ( a. mean_absolute_percentage_error( & a) . unwrap( ) , 0.0f32 ) ;
268+ assert_abs_diff_eq ! (
269+ a. symmetric_mean_absolute_percentage_error( & a) . unwrap( ) ,
270+ 0.0f32
271+ ) ;
245272 }
246273
247274 #[ test]
@@ -281,6 +308,18 @@ mod tests {
281308 ) ;
282309 }
283310
311+ #[ test]
312+ fn test_symmetric_mean_absolute_percentage_error ( ) {
313+ let a = array ! [ 0.5 , 0.1 , 0.2 , 0.3 , 0.4 ] ;
314+ let b = array ! [ 0.1 , 0.2 , 0.3 , 0.4 , 0.5 ] ;
315+
316+ assert_abs_diff_eq ! (
317+ a. symmetric_mean_absolute_percentage_error( & b) . unwrap( ) ,
318+ 0.5815873014693111 ,
319+ epsilon = 1e-5
320+ ) ;
321+ }
322+
284323 #[ test]
285324 fn test_max_error_for_single_targets ( ) {
286325 let records = array ! [ [ 0.0 , 0.0 ] , [ 0.1 , 0.1 ] , [ 0.2 , 0.2 ] , [ 0.3 , 0.3 ] , [ 0.4 , 0.4 ] ] ;
@@ -339,6 +378,23 @@ mod tests {
339378 assert_abs_diff_eq ! ( pct_err_from_arr1, pct_err_from_ds) ;
340379 }
341380
381+ #[ test]
382+ fn test_symmetric_mean_absolute_percentage_error_for_single_targets ( ) {
383+ let records = array ! [ [ 0.0 , 0.0 ] , [ 0.1 , 0.1 ] , [ 0.2 , 0.2 ] , [ 0.3 , 0.3 ] , [ 0.4 , 0.4 ] ] ;
384+ let targets = array ! [ 0.0 , 0.1 , 0.2 , 0.3 , 0.4 ] ;
385+ let st_dataset: DatasetBase < _ , _ > = ( records. view ( ) , targets) . into ( ) ;
386+ let prediction = array ! [ 0.1 , 0.3 , 0.2 , 0.5 , 0.7 ] ;
387+ let err_from_arr = prediction
388+ . symmetric_mean_absolute_percentage_error ( st_dataset. targets ( ) )
389+ . unwrap ( ) ;
390+ let prediction_ds: DatasetBase < _ , _ > = ( records. view ( ) , prediction) . into ( ) ;
391+ let err_from_ds = prediction_ds
392+ . symmetric_mean_absolute_percentage_error ( & st_dataset)
393+ . unwrap ( ) ;
394+ assert_abs_diff_eq ! ( err_from_arr, 0.8090909086184916 , epsilon = 1e-5 ) ;
395+ assert_abs_diff_eq ! ( err_from_arr, err_from_ds) ;
396+ }
397+
342398 #[ test]
343399 fn test_mean_squared_log_error_for_single_targets ( ) {
344400 let records = array ! [ [ 0.0 , 0.0 ] , [ 0.1 , 0.1 ] , [ 0.2 , 0.2 ] , [ 0.3 , 0.3 ] , [ 0.4 , 0.4 ] ] ;
0 commit comments