Skip to content

Commit c8adb7e

Browse files
committed
work with predict inplace only
1 parent fd0aea0 commit c8adb7e

1 file changed

Lines changed: 21 additions & 12 deletions

File tree

src/composing/residual_sequence.rs

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -273,11 +273,11 @@ where
273273
impl<R1, R2, F: Float, D: Data<Elem = F>> PredictInplace<Arr2<D>, Array1<F>>
274274
for ResidualChain<R1, R2, F>
275275
where
276-
for<'a> R1: Predict<&'a Arr2<D>, Array1<F>>,
277-
for<'a> R2: Predict<&'a Arr2<D>, Array1<F>>,
276+
R1: PredictInplace<Arr2<D>, Array1<F>>,
277+
R2: PredictInplace<Arr2<D>, Array1<F>>,
278278
{
279279
fn predict_inplace<'a>(&'a self, x: &'a Arr2<D>, y: &mut Array1<F>) {
280-
y.assign(&self.base.predict(x));
280+
self.base.predict_inplace(x, y);
281281
y.add_assign(
282282
&self
283283
.corrector
@@ -411,9 +411,12 @@ mod tests {
411411
}
412412
}
413413

414-
impl<'a> Predict<&'a Array2<f64>, Array1<f64>> for MeanModel {
415-
fn predict(&self, x: &'a Array2<f64>) -> Array1<f64> {
416-
Array1::from_elem(x.nrows(), self.0)
414+
impl PredictInplace<Array2<f64>, Array1<f64>> for MeanModel {
415+
fn predict_inplace(&self, x: &Array2<f64>, y: &mut Array1<f64>) {
416+
y.assign(&Array1::from_elem(x.nrows(), self.0));
417+
}
418+
fn default_target(&self, x: &Array2<f64>) -> Array1<f64> {
419+
Array1::zeros(x.nrows())
417420
}
418421
}
419422

@@ -457,9 +460,12 @@ mod tests {
457460
}
458461
}
459462

460-
impl<'a> Predict<&'a Array2<f64>, Array1<f64>> for FixedModel {
461-
fn predict(&self, x: &'a Array2<f64>) -> Array1<f64> {
462-
Array1::from_elem(x.nrows(), self.0)
463+
impl PredictInplace<Array2<f64>, Array1<f64>> for FixedModel {
464+
fn predict_inplace(&self, x: &Array2<f64>, y: &mut Array1<f64>) {
465+
y.assign(&Array1::from_elem(x.nrows(), self.0));
466+
}
467+
fn default_target(&self, x: &Array2<f64>) -> Array1<f64> {
468+
Array1::zeros(x.nrows())
463469
}
464470
}
465471

@@ -517,9 +523,12 @@ mod tests {
517523
}
518524
}
519525

520-
impl<'a> Predict<&'a Array2<f64>, Array1<f64>> for FixedModel {
521-
fn predict(&self, x: &'a Array2<f64>) -> Array1<f64> {
522-
Array1::from_elem(x.nrows(), self.0)
526+
impl PredictInplace<Array2<f64>, Array1<f64>> for FixedModel {
527+
fn predict_inplace(&self, x: &Array2<f64>, y: &mut Array1<f64>) {
528+
y.assign(&Array1::from_elem(x.nrows(), self.0));
529+
}
530+
fn default_target(&self, x: &Array2<f64>) -> Array1<f64> {
531+
Array1::zeros(x.nrows())
523532
}
524533
}
525534

0 commit comments

Comments
 (0)