
Commit 9ada31e

Added missing predict_proba() and predict_log_proba() to MultinomialNb

1 parent 5272ad1 · commit 9ada31e

1 file changed: 94 additions & 1 deletion

algorithms/linfa-bayes/src/multinomial_nb.rs

@@ -1,7 +1,7 @@
 use linfa::dataset::{AsSingleTargets, DatasetBase, Labels};
 use linfa::traits::{Fit, FitWith, PredictInplace};
 use linfa::{Float, Label};
-use ndarray::{Array1, ArrayBase, ArrayView2, Axis, Data, Ix2};
+use ndarray::{s, Array1, Array2, ArrayBase, ArrayView2, Axis, Data, Ix2};
 use std::collections::HashMap;
 use std::hash::Hash;

@@ -248,6 +248,66 @@ where
     }
 }
 
+impl<F, L> MultinomialNb<F, L>
+where
+    F: Float,
+    L: Label + Ord + Clone + Hash,
+{
+    /// Compute class probabilities for each row; returns the probability
+    /// matrix together with the class order of its columns
+    pub fn predict_proba(&self, x: ArrayView2<F>) -> (Array2<F>, Vec<L>) {
+        let log_likelihood = self.joint_log_likelihood(x);
+        let n_samples = x.nrows();
+        let n_classes = log_likelihood.len();
+
+        // Sort the labels so the column order is deterministic
+        let mut classes: Vec<L> = self.class_info.keys().cloned().collect();
+        classes.sort();
+
+        let mut log_prob_mat = Array2::<F>::zeros((n_samples, n_classes));
+
+        // Column j holds the joint log-likelihood of class j for every sample
+        for (j, class) in classes.iter().enumerate() {
+            let class_log = log_likelihood.get(class).unwrap();
+            log_prob_mat.slice_mut(s![.., j]).assign(class_log);
+        }
+
+        // Numerically stable softmax per row: subtracting the row maximum
+        // keeps every exponent <= 0 and avoids overflow
+        for mut row in log_prob_mat.axis_iter_mut(Axis(0)) {
+            let max = row.fold(F::neg_infinity(), |a, &b| a.max(b));
+            let exp_sum = row.mapv(|v| (v - max).exp()).sum();
+            row.mapv_inplace(|v| (v - max).exp() / exp_sum);
+        }
+
+        (log_prob_mat, classes)
+    }
+
+    /// Compute normalized log-probabilities for each sample and class, so
+    /// that each row exponentiates to a distribution summing to one
+    pub fn predict_log_proba(&self, x: ArrayView2<F>) -> (Array2<F>, Vec<L>) {
+        let log_likelihood = self.joint_log_likelihood(x);
+        let n_samples = x.nrows();
+        let n_classes = log_likelihood.len();
+
+        let mut classes: Vec<L> = self.class_info.keys().cloned().collect();
+        classes.sort();
+
+        let mut log_prob_mat = Array2::<F>::zeros((n_samples, n_classes));
+
+        for (j, class) in classes.iter().enumerate() {
+            let class_log = log_likelihood.get(class).unwrap();
+            log_prob_mat.column_mut(j).assign(class_log);
+        }
+
+        // Normalize each row in log space via the log-sum-exp trick
+        for mut row in log_prob_mat.axis_iter_mut(Axis(0)) {
+            let max = row.fold(F::neg_infinity(), |a, &b| a.max(b));
+            let logsumexp = row.mapv(|v| (v - max).exp()).sum().ln() + max;
+            row.mapv_inplace(|v| v - logsumexp);
+        }
+
+        (log_prob_mat, classes)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::{MultinomialNb, NaiveBayes, Result};
@@ -368,4 +428,37 @@ mod tests {
 
         Ok(())
     }
+
+    #[test]
+    fn test_predict_proba_and_log_proba() -> Result<()> {
+        use ndarray::array;
+
+        let x = array![[2., 1.], [1., 3.], [0., 5.]];
+        let y = array![0usize, 0, 1];
+
+        let dataset = DatasetView::new(x.view(), y.view());
+
+        let model = MultinomialNb::params().fit(&dataset)?;
+
+        let (proba, classes) = model.predict_proba(x.view());
+        let (log_proba, log_classes) = model.predict_log_proba(x.view());
+
+        // Both methods must report the same class order
+        assert_eq!(classes, log_classes);
+
+        for i in 0..x.nrows() {
+            let mut sum = 0.0_f64;
+            for j in 0..classes.len() {
+                let p: f64 = proba[[i, j]];
+                let lp: f64 = log_proba[[i, j]].exp();
+                // Probabilities lie in [0, 1] and agree with exp(log-probabilities)
+                assert!(p >= 0.0 && p <= 1.0);
+                assert!((p - lp).abs() < 1e-6);
+                sum += p;
+            }
+            // Each row is a distribution over the classes
+            assert!((sum - 1.0).abs() < 1e-6);
+        }
+
+        Ok(())
+    }
 }
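A note on the normalization in both new methods: each subtracts the row maximum before calling exp(). The standalone sketch below is not part of the commit; it uses plain f64 slices instead of ndarray and a hypothetical log_sum_exp helper, purely to show why that shift matters for large-magnitude log-likelihoods.

// Standalone sketch, not from the commit: the log-sum-exp trick.
fn log_sum_exp(row: &[f64]) -> f64 {
    let max = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    // Shifting by `max` keeps every exponent <= 0, so exp() cannot
    // overflow and the largest term is exactly 1.0.
    row.iter().map(|v| (v - max).exp()).sum::<f64>().ln() + max
}

fn main() {
    let log_likelihoods = [-1000.0, -1001.0, -1002.0];

    // Naive evaluation underflows: exp(-1000.0) == 0.0, so ln(0.0) == -inf.
    let naive = log_likelihoods.iter().map(|v| v.exp()).sum::<f64>().ln();
    assert!(naive.is_infinite());

    // Shifted evaluation stays finite: ln(e^0 + e^-1 + e^-2) - 1000.
    let stable = log_sum_exp(&log_likelihoods);
    assert!((stable + 999.5924).abs() < 1e-3);
}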

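For callers of the new API, a hedged usage sketch, written as if it were one more test in this module so the same imports (DatasetView, MultinomialNb, Result) apply; the argmax recovery at the end is illustrative and not part of the commit.

#[test]
fn most_probable_labels() -> Result<()> {
    use ndarray::array;

    let x = array![[2., 1.], [1., 3.], [0., 5.]];
    let y = array![0usize, 0, 1];
    let dataset = DatasetView::new(x.view(), y.view());
    let model = MultinomialNb::params().fit(&dataset)?;

    // Columns of `proba` follow the sorted class order in `classes`,
    // so the per-row argmax recovers the most probable label.
    let (proba, classes) = model.predict_proba(x.view());
    let predicted: Vec<usize> = proba
        .outer_iter()
        .map(|row| {
            let (best, _) = row
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .unwrap();
            classes[best]
        })
        .collect();

    assert_eq!(predicted, vec![0, 0, 1]);
    Ok(())
}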