Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions jxl/src/frame/modular/decode/specialized_trees.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@ use crate::{
channel::ModularChannelDecoder,
common::{make_pixel, precompute_references},
},
flat_tree::{FlatTreeNode, predict_flat},
predict::{PredictionData, WeightedPredictorState, clamped_gradient},
tree::{
FlatTreeNode, NUM_NONREF_PROPERTIES, PROPERTIES_PER_PREVCHAN, TreeNode, predict_flat,
},
tree::{NUM_NONREF_PROPERTIES, PROPERTIES_PER_PREVCHAN, TreeNode},
},
headers::modular::GroupHeader,
image::Image,
Expand All @@ -27,7 +26,7 @@ use crate::{
pub struct NoWpTree {
flat_nodes: Vec<FlatTreeNode>,
references: Image<i32>,
property_buffer: Vec<i32>,
property_buffer: Box<[i32; 256]>,
single_value: Option<i32>,
}

Expand All @@ -44,8 +43,7 @@ impl NoWpTree {
.saturating_sub(NUM_NONREF_PROPERTIES)
.next_multiple_of(PROPERTIES_PER_PREVCHAN);
let references = Image::<i32>::new((num_ref_props, xsize))?;
let num_properties = NUM_NONREF_PROPERTIES + num_ref_props;
let mut property_buffer: Vec<i32> = vec![0; num_properties];
let mut property_buffer = Box::new([0; 256]);

property_buffer[0] = channel as i32;
property_buffer[1] = stream as i32;
Expand Down
168 changes: 168 additions & 0 deletions jxl/src/frame/modular/flat_tree.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

use std::fmt::Debug;

use super::{Predictor, predict::WeightedPredictorState};
use crate::{
error::Result,
frame::modular::{
Tree,
predict::PredictionData,
tree::{PredictionResult, TreeNode, compute_properties},
},
image::Image,
util::NewWithCapacity,
};

/// Flattened tree node for optimized traversal.
/// Stores parent + info about both children to evaluate 3 nodes per iteration.
#[derive(Debug, Clone, Copy)]
pub(super) enum FlatTreeNode {
Split {
properties: [u8; 3],
splitvals: [i32; 3],
child_id: u32,
},
Leaf {
predictor: Predictor,
multiplier: u32,
context: u32,
offset: i32,
},
}

#[inline]
#[allow(clippy::too_many_arguments, unsafe_code)]
pub(super) fn predict_flat(
flat_tree: &[FlatTreeNode],
prediction_data: PredictionData,
xsize: usize,
wp_state: Option<&mut WeightedPredictorState>,
x: usize,
y: usize,
references: &Image<i32>,
property_buffer: &mut [i32; 256],
) -> PredictionResult {
let wp_pred = compute_properties(
prediction_data,
xsize,
wp_state,
x,
y,
references,
property_buffer,
);

let mut pos = 0;
loop {
// Removing this bound check doesn't seem to have a significant effect.
let node = flat_tree[pos];
match node {
FlatTreeNode::Split {
properties,
splitvals,
child_id,
} => {
// This bound check is elided by virtue of `property_buffer` having 256 elements.
let props = properties.map(|x| property_buffer[x as usize]);
let p0 = props[0] <= splitvals[0];
let p1 = props[1] <= splitvals[1];
let p2 = props[2] <= splitvals[2];
pos = child_id as usize + if p0 { 2 | p2 as usize } else { p1 as usize };
}
FlatTreeNode::Leaf {
predictor,
multiplier,
context,
offset,
} => {
let pred = predictor.predict_one(prediction_data, wp_pred);
return PredictionResult {
guess: pred + offset as i64,
multiplier,
context,
};
}
};
}
}

impl Tree {
/// Build flat tree using BFS traversal.
/// Each flat node stores parent + both children info to reduce branches.
pub(super) fn build_flat_tree(nodes: &[TreeNode]) -> Result<Vec<FlatTreeNode>> {
use std::collections::VecDeque;

if nodes.is_empty() {
return Ok(vec![]);
}

let mut flat_nodes = Vec::new_with_capacity(nodes.len())?;
let mut queue: VecDeque<usize> = VecDeque::new();
queue.push_back(0); // Start with root

while let Some(cur_idx) = queue.pop_front() {
match nodes[cur_idx] {
TreeNode::Leaf {
predictor,
offset,
multiplier,
id,
} => {
flat_nodes.push(FlatTreeNode::Leaf {
predictor,
offset,
multiplier,
context: id,
});
}
TreeNode::Split {
property,
val,
left,
right,
} => {
// childID points to first of 4 grandchildren in output
let child_id = (flat_nodes.len() + queue.len() + 1) as u32;

let mut splitvals = [val, 0, 0];
let mut properties = [property, 0, 0];

// Process left (i=0) and right (i=1) children
for (i, &child_idx) in [left as usize, right as usize].iter().enumerate() {
match &nodes[child_idx] {
TreeNode::Leaf { .. } => {
// Child is leaf: enqueue leaf twice
queue.push_back(child_idx);
queue.push_back(child_idx);
}
TreeNode::Split {
property: cp,
val: cv,
left: cl,
right: cr,
} => {
// Child is split: store property/splitval and enqueue grandchildren
properties[i + 1] = *cp;
splitvals[i + 1] = *cv;
queue.push_back(*cl as usize);
queue.push_back(*cr as usize);
}
}
}

flat_nodes.push(FlatTreeNode::Split {
properties,
splitvals,
child_id,
});
}
}
}

Ok(flat_nodes)
}
}
1 change: 1 addition & 0 deletions jxl/src/frame/modular/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use jxl_transforms::transform_map::*;

mod borrowed_buffers;
pub(crate) mod decode;
mod flat_tree;
mod predict;
mod transforms;
mod tree;
Expand Down
Loading
Loading