|
1 | 1 | use linfa::dataset::{AsSingleTargets, DatasetBase, Labels}; |
2 | 2 | use linfa::traits::{Fit, FitWith, PredictInplace}; |
3 | 3 | use linfa::{Float, Label}; |
4 | | -use ndarray::{Array1, ArrayBase, ArrayView2, Axis, Data, Ix2}; |
| 4 | +use ndarray::{Array1, Array2, s, ArrayBase, ArrayView2, Axis, Data, Ix2}; |
5 | 5 | use std::collections::HashMap; |
6 | 6 | use std::hash::Hash; |
7 | 7 |
|
@@ -248,6 +248,66 @@ where |
248 | 248 | } |
249 | 249 | } |
250 | 250 |
|
| 251 | +impl<F, L> MultinomialNb<F, L> |
| 252 | +where |
| 253 | + F: Float, |
| 254 | + L: Label + Ord + Clone + Hash, |
| 255 | +{ |
| 256 | + /// Compute class probabilities for each row |
| 257 | + pub fn predict_proba(&self, x: ArrayView2<F>) -> (Array2<F>, Vec<L>) { |
| 258 | + let log_likelihood = self.joint_log_likelihood(x); |
| 259 | + let n_samples = x.nrows(); |
| 260 | + let n_classes = log_likelihood.len(); |
| 261 | + |
| 262 | + // Preserve deterministic class order |
| 263 | + let mut classes: Vec<L> = self.class_info.keys().cloned().collect(); |
| 264 | + classes.sort(); |
| 265 | + |
| 266 | + let mut log_prob_mat = Array2::<F>::zeros((n_samples, n_classes)); |
| 267 | + |
| 268 | + for (j, class) in classes.iter().enumerate() { |
| 269 | + let class_log = log_likelihood.get(class).unwrap(); |
| 270 | + log_prob_mat.slice_mut(s![.., j]).assign(class_log); |
| 271 | + } |
| 272 | + |
| 273 | + // Apply softmax to each row |
| 274 | + for mut row in log_prob_mat.axis_iter_mut(Axis(0)) { |
| 275 | + let max = row.fold(F::neg_infinity(), |a, &b| a.max(b)); |
| 276 | + let exp_sum = row.mapv(|v| (v - max).exp()).sum(); |
| 277 | + row.mapv_inplace(|v| (v - max).exp() / exp_sum); |
| 278 | + } |
| 279 | + |
| 280 | + (log_prob_mat, classes) |
| 281 | + } |
| 282 | + |
| 283 | + /// Compute unnormalized log-probabilities for each sample and class |
| 284 | + pub fn predict_log_proba(&self, x: ArrayView2<F>) -> (Array2<F>, Vec<L>) { |
| 285 | + let log_likelihood = self.joint_log_likelihood(x); |
| 286 | + let n_samples = x.nrows(); |
| 287 | + let n_classes = log_likelihood.len(); |
| 288 | + |
| 289 | + let mut classes: Vec<L> = self.class_info.keys().cloned().collect(); |
| 290 | + classes.sort(); |
| 291 | + |
| 292 | + let mut log_prob_mat = Array2::<F>::zeros((n_samples, n_classes)); |
| 293 | + |
| 294 | + for (j, class) in classes.iter().enumerate() { |
| 295 | + let class_log = log_likelihood.get(class).unwrap(); |
| 296 | + log_prob_mat.column_mut(j).assign(class_log); |
| 297 | + } |
| 298 | + |
| 299 | + // log-sum-exp for normalization |
| 300 | + for mut row in log_prob_mat.axis_iter_mut(Axis(0)) { |
| 301 | + let max = row.fold(F::neg_infinity(), |a, &b| a.max(b)); |
| 302 | + let logsumexp = row.mapv(|v| (v - max).exp()).sum().ln() + max; |
| 303 | + row.mapv_inplace(|v| v - logsumexp); |
| 304 | + } |
| 305 | + |
| 306 | + (log_prob_mat, classes) |
| 307 | + } |
| 308 | + |
| 309 | +} |
| 310 | + |
251 | 311 | #[cfg(test)] |
252 | 312 | mod tests { |
253 | 313 | use super::{MultinomialNb, NaiveBayes, Result}; |
@@ -368,4 +428,37 @@ mod tests { |
368 | 428 |
|
369 | 429 | Ok(()) |
370 | 430 | } |
| 431 | + |
#[test]
fn test_predict_proba_and_log_proba() -> Result<()> {
    use ndarray::array;

    let x = array![[2., 1.], [1., 3.], [0., 5.]];
    let y = array![0, 0, 1];

    let dataset = DatasetView::new(x.view(), y.view());
    let model = MultinomialNb::params().fit(&dataset)?;

    let (proba, classes) = model.predict_proba(x.view());
    let (log_proba, log_classes) = model.predict_log_proba(x.view());

    // Both methods must report the same labels, in ascending order.
    assert_eq!(classes, log_classes);
    assert_eq!(classes, vec![0, 1]);

    // One row per sample, one column per class.
    assert_eq!(proba.dim(), (x.nrows(), classes.len()));
    assert_eq!(log_proba.dim(), proba.dim());

    for (p_row, lp_row) in proba.outer_iter().zip(log_proba.outer_iter()) {
        let mut sum = 0.0_f64;
        for (&p, &lp) in p_row.iter().zip(lp_row.iter()) {
            let p: f64 = p;
            let lp: f64 = lp;
            // Every entry is a valid probability ...
            assert!((0.0..=1.0).contains(&p));
            // ... and matches the exponentiated log-probability.
            assert!((p - lp.exp()).abs() < 1e-6);
            sum += p;
        }
        // Each row is a probability distribution over the classes.
        assert!((sum - 1.0).abs() < 1e-6);
    }

    Ok(())
}
371 | 462 | } |
| 463 | + |
| 464 | + |
0 commit comments