diff --git a/jxl/src/frame/modular/decode/specialized_trees.rs b/jxl/src/frame/modular/decode/specialized_trees.rs index 5eac54624..d126ae003 100644 --- a/jxl/src/frame/modular/decode/specialized_trees.rs +++ b/jxl/src/frame/modular/decode/specialized_trees.rs @@ -15,10 +15,9 @@ use crate::{ channel::ModularChannelDecoder, common::{make_pixel, precompute_references}, }, + flat_tree::{FlatTreeNode, predict_flat}, predict::{PredictionData, WeightedPredictorState, clamped_gradient}, - tree::{ - FlatTreeNode, NUM_NONREF_PROPERTIES, PROPERTIES_PER_PREVCHAN, TreeNode, predict_flat, - }, + tree::{NUM_NONREF_PROPERTIES, PROPERTIES_PER_PREVCHAN, TreeNode}, }, headers::modular::GroupHeader, image::Image, @@ -27,7 +26,7 @@ use crate::{ pub struct NoWpTree { flat_nodes: Vec, references: Image, - property_buffer: Vec, + property_buffer: Box<[i32; 256]>, single_value: Option, } @@ -44,8 +43,7 @@ impl NoWpTree { .saturating_sub(NUM_NONREF_PROPERTIES) .next_multiple_of(PROPERTIES_PER_PREVCHAN); let references = Image::::new((num_ref_props, xsize))?; - let num_properties = NUM_NONREF_PROPERTIES + num_ref_props; - let mut property_buffer: Vec = vec![0; num_properties]; + let mut property_buffer = Box::new([0; 256]); property_buffer[0] = channel as i32; property_buffer[1] = stream as i32; diff --git a/jxl/src/frame/modular/flat_tree.rs b/jxl/src/frame/modular/flat_tree.rs new file mode 100644 index 000000000..5ec36c7d8 --- /dev/null +++ b/jxl/src/frame/modular/flat_tree.rs @@ -0,0 +1,168 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +use std::fmt::Debug; + +use super::{Predictor, predict::WeightedPredictorState}; +use crate::{ + error::Result, + frame::modular::{ + Tree, + predict::PredictionData, + tree::{PredictionResult, TreeNode, compute_properties}, + }, + image::Image, + util::NewWithCapacity, +}; + +/// Flattened tree node for optimized traversal. +/// Stores parent + info about both children to evaluate 3 nodes per iteration. +#[derive(Debug, Clone, Copy)] +pub(super) enum FlatTreeNode { + Split { + properties: [u8; 3], + splitvals: [i32; 3], + child_id: u32, + }, + Leaf { + predictor: Predictor, + multiplier: u32, + context: u32, + offset: i32, + }, +} + +#[inline] +#[allow(clippy::too_many_arguments)] +pub(super) fn predict_flat( + flat_tree: &[FlatTreeNode], + prediction_data: PredictionData, + xsize: usize, + wp_state: Option<&mut WeightedPredictorState>, + x: usize, + y: usize, + references: &Image, + property_buffer: &mut [i32; 256], +) -> PredictionResult { + let wp_pred = compute_properties( + prediction_data, + xsize, + wp_state, + x, + y, + references, + property_buffer, + ); + + let mut pos = 0; + loop { + // Removing this bound check doesn't seem to have a significant effect. + let node = flat_tree[pos]; + match node { + FlatTreeNode::Split { + properties, + splitvals, + child_id, + } => { + // This bound check is elided because `property_buffer` has 256 elements and every index is a `u8`. + let props = properties.map(|x| property_buffer[x as usize]); + let p0 = props[0] <= splitvals[0]; + let p1 = props[1] <= splitvals[1]; + let p2 = props[2] <= splitvals[2]; + pos = child_id as usize + if p0 { 2 | p2 as usize } else { p1 as usize }; + } + FlatTreeNode::Leaf { + predictor, + multiplier, + context, + offset, + } => { + let pred = predictor.predict_one(prediction_data, wp_pred); + return PredictionResult { + guess: pred + offset as i64, + multiplier, + context, + }; + } + }; + } +} + +impl Tree { + /// Build flat tree using BFS traversal. 
+ /// Each flat node stores parent + both children info to reduce branches. + pub(super) fn build_flat_tree(nodes: &[TreeNode]) -> Result> { + use std::collections::VecDeque; + + if nodes.is_empty() { + return Ok(vec![]); + } + + let mut flat_nodes = Vec::new_with_capacity(nodes.len())?; + let mut queue: VecDeque = VecDeque::new(); + queue.push_back(0); // Start with root + + while let Some(cur_idx) = queue.pop_front() { + match nodes[cur_idx] { + TreeNode::Leaf { + predictor, + offset, + multiplier, + id, + } => { + flat_nodes.push(FlatTreeNode::Leaf { + predictor, + offset, + multiplier, + context: id, + }); + } + TreeNode::Split { + property, + val, + left, + right, + } => { + // `child_id` is the output index of the first of this node's 4 grandchild slots: this node lands at `flat_nodes.len()` and the `queue.len()` pending entries come before them. + let child_id = (flat_nodes.len() + queue.len() + 1) as u32; + + let mut splitvals = [val, 0, 0]; + let mut properties = [property, 0, 0]; + + // Process left (i=0) and right (i=1) children; child i's own test goes into slot i + 1 + for (i, &child_idx) in [left as usize, right as usize].iter().enumerate() { + match &nodes[child_idx] { + TreeNode::Leaf { .. 
} => { + // Child is leaf: enqueue leaf twice + queue.push_back(child_idx); + queue.push_back(child_idx); + } + TreeNode::Split { + property: cp, + val: cv, + left: cl, + right: cr, + } => { + // Child is split: store property/splitval and enqueue grandchildren + properties[i + 1] = *cp; + splitvals[i + 1] = *cv; + queue.push_back(*cl as usize); + queue.push_back(*cr as usize); + } + } + } + + flat_nodes.push(FlatTreeNode::Split { + properties, + splitvals, + child_id, + }); + } + } + } + + Ok(flat_nodes) + } +} diff --git a/jxl/src/frame/modular/mod.rs b/jxl/src/frame/modular/mod.rs index ffd158872..843c1d933 100644 --- a/jxl/src/frame/modular/mod.rs +++ b/jxl/src/frame/modular/mod.rs @@ -32,6 +32,7 @@ use jxl_transforms::transform_map::*; mod borrowed_buffers; pub(crate) mod decode; +mod flat_tree; mod predict; mod transforms; mod tree; diff --git a/jxl/src/frame/modular/tree.rs b/jxl/src/frame/modular/tree.rs index fb0349afe..cad5e5deb 100644 --- a/jxl/src/frame/modular/tree.rs +++ b/jxl/src/frame/modular/tree.rs @@ -8,12 +8,11 @@ use std::fmt::Debug; use super::{Predictor, predict::WeightedPredictorState}; use crate::{ bit_reader::BitReader, - entropy_coding::decode::Histograms, - entropy_coding::decode::SymbolReader, + entropy_coding::decode::{Histograms, SymbolReader}, error::{Error, Result}, frame::modular::predict::PredictionData, image::Image, - util::{NewWithCapacity, tracing_wrappers::*}, + util::tracing_wrappers::*, }; #[derive(Debug, Clone, Copy)] @@ -32,32 +31,6 @@ pub enum TreeNode { }, } -/// Flattened tree node for optimized traversal (matches C++ FlatDecisionNode). -/// Stores parent + info about both children to evaluate 3 nodes per iteration. -// TODO(hjanuschka): investigate performance of using a Rust enum here, and whether -// separating internal nodes and leaves into two arrays could save a branch. 
-#[derive(Debug, Clone, Copy)] -pub(super) struct FlatTreeNode { - property0: i32, // Property to test, -1 if leaf - splitval0_or_predictor: i32, // Split value, or predictor if leaf - splitvals_or_multiplier: [i32; 2], // Child splitvals, or multiplier if leaf - child_id: u32, // Index to first grandchild, or context if leaf - properties_or_offset: [i16; 2], // Child properties, or offset if leaf -} - -impl FlatTreeNode { - #[inline] - fn leaf(predictor: Predictor, offset: i32, multiplier: u32, context: u32) -> Self { - Self { - property0: -1, - splitval0_or_predictor: predictor as i32, - splitvals_or_multiplier: [multiplier as i32, 0], - child_id: context, - properties_or_offset: [offset as i16, 0], - } - } -} - pub struct Tree { pub nodes: Vec, pub histograms: Histograms, @@ -213,7 +186,7 @@ const NUM_TREE_CONTEXTS: usize = 6; /// Computes properties for tree traversal. Shared between flat and non-flat prediction. /// Returns the weighted predictor prediction value. #[inline] -fn compute_properties( +pub(super) fn compute_properties( prediction_data: PredictionData, xsize: usize, wp_state: Option<&mut WeightedPredictorState>, @@ -342,71 +315,6 @@ pub(super) fn predict( } } -/// Optimized prediction using flat tree (matches C++ context_predict.h:351-371). 
-#[inline] -#[allow(clippy::too_many_arguments)] -pub(super) fn predict_flat( - flat_tree: &[FlatTreeNode], - prediction_data: PredictionData, - xsize: usize, - wp_state: Option<&mut WeightedPredictorState>, - x: usize, - y: usize, - references: &Image, - property_buffer: &mut [i32], -) -> PredictionResult { - let wp_pred = compute_properties( - prediction_data, - xsize, - wp_state, - x, - y, - references, - property_buffer, - ); - - // Flat tree traversal - let mut pos = 0; - loop { - let node = &flat_tree[pos]; - - if node.property0 < 0 { - // Leaf node - let predictor = Predictor::try_from(node.splitval0_or_predictor as u32).unwrap(); - let offset = node.properties_or_offset[0] as i32; - let multiplier = node.splitvals_or_multiplier[0] as u32; - let context = node.child_id; - - let pred = predictor.predict_one(prediction_data, wp_pred); - - return PredictionResult { - guess: pred + offset as i64, - multiplier, - context, - }; - } - - // Split node: C++ logic from context_predict.h:361-365 - let p0 = property_buffer[node.property0 as usize] <= node.splitval0_or_predictor; - let off0 = if property_buffer[node.properties_or_offset[0] as usize] - <= node.splitvals_or_multiplier[0] - { - 1 - } else { - 0 - }; - let off1 = if property_buffer[node.properties_or_offset[1] as usize] - <= node.splitvals_or_multiplier[1] - { - 3 - } else { - 2 - }; - - pos = (node.child_id + if p0 { off1 } else { off0 }) as usize; - } -} - impl Tree { #[instrument(level = "debug", skip(br), err)] pub fn read(br: &mut BitReader, size_limit: usize) -> Result { @@ -488,79 +396,6 @@ impl Tree { }) } - /// Build flat tree using BFS traversal (matches C++ encoding.cc:81-144). - /// Each flat node stores parent + both children info to reduce branches. 
- pub(super) fn build_flat_tree(nodes: &[TreeNode]) -> Result> { - use std::collections::VecDeque; - - if nodes.is_empty() { - return Ok(vec![]); - } - - let mut flat_nodes = Vec::new_with_capacity(nodes.len())?; - let mut queue: VecDeque = VecDeque::new(); - queue.push_back(0); // Start with root - - while let Some(cur_idx) = queue.pop_front() { - match &nodes[cur_idx] { - TreeNode::Leaf { - predictor, - offset, - multiplier, - id, - } => { - flat_nodes.push(FlatTreeNode::leaf(*predictor, *offset, *multiplier, *id)); - } - TreeNode::Split { - property, - val, - left, - right, - } => { - // childID points to first of 4 grandchildren in output - let child_id = (flat_nodes.len() + queue.len() + 1) as u32; - - let mut flat = FlatTreeNode { - property0: *property as i32, - splitval0_or_predictor: *val, - splitvals_or_multiplier: [0, 0], - child_id, - properties_or_offset: [0, 0], - }; - - // Process left (i=0) and right (i=1) children - for (i, &child_idx) in [*left as usize, *right as usize].iter().enumerate() { - match &nodes[child_idx] { - TreeNode::Leaf { .. } => { - // Child is leaf: set property=0 and enqueue leaf twice - flat.properties_or_offset[i] = 0; - flat.splitvals_or_multiplier[i] = 0; - queue.push_back(child_idx); - queue.push_back(child_idx); - } - TreeNode::Split { - property: cp, - val: cv, - left: cl, - right: cr, - } => { - // Child is split: store property/splitval and enqueue grandchildren - flat.properties_or_offset[i] = *cp as i16; - flat.splitvals_or_multiplier[i] = *cv; - queue.push_back(*cl as usize); - queue.push_back(*cr as usize); - } - } - } - - flat_nodes.push(flat); - } - } - } - - Ok(flat_nodes) - } - pub fn max_property_count(&self) -> usize { self.nodes .iter()