- name: Select benchmark corpus (autoresearch 8 images)
  id: select-images
  run: |
    # -p: do not fail when the directory is already present.
    mkdir -p testimages

    # Keep this aligned with the autoresearch image set so CI numbers stay
    # comparable while avoiding benchmark timeouts.
    # Each entry is "<path in repo>:<file name inside testimages/>".
    image_map=(
      "jxl/resources/test/conformance_test_images/sunset_logo.jxl:sunset_logo.jxl"
      "jxl/resources/test/conformance_test_images/bike.jxl:bike.jxl"
      "jxl/resources/test/green_queen_modular_e3.jxl:green_queen_modular_e3.jxl"
      "jxl/resources/test/green_queen_vardct_e3.jxl:green_queen_vardct_e3.jxl"
      "jxl/resources/test/conformance_test_images/bicycles.jxl:bicycles.jxl"
      "jxl/resources/test/conformance_test_images/delta_palette.jxl:delta_palette.jxl"
      "jxl/resources/test/conformance_test_images/lz77_flower.jxl:lz77_flower.jxl"
      "jxl/resources/test/conformance_test_images/patches.jxl:patches_lossless.jxl"
    )

    for entry in "${image_map[@]}"; do
      src="${entry%%:*}"   # text before the first ':'
      dest="${entry##*:}"  # text after the last ':'
      if [ ! -f "$src" ]; then
        echo "Missing benchmark image: $src" >&2
        exit 1
      fi
      # -f keeps the step re-runnable, matching `mkdir -p` above: an existing
      # link from a previous run is replaced instead of aborting the step.
      ln -sf "$GITHUB_WORKSPACE/$src" "testimages/$dest"
    done

    echo "Selected $(find testimages -maxdepth 1 -name '*.jxl' | wc -l) images for benchmark"
+62,11 @@ impl CodestreamParser { .front() .is_some_and(|s| s.len <= self.ready_section_data) { - let s = self.sections.pop_front().unwrap(); + let mut s = self.sections.pop_front().unwrap(); self.ready_section_data -= s.len; + // Add 8 zero-padding bytes so BitReader::refill() always takes the + // fast path (avoids refill_slow for small/tail sections). + s.data.extend_from_slice(&[0u8; 8]); match s.section { Section::LfGlobal => { @@ -87,32 +90,31 @@ impl CodestreamParser { let pixel_format = self.pixel_format.as_ref().unwrap(); let complete_lf_global; - let (lf_global, lf_global_is_complete) = if let Some(d) = self.lf_global_section.take() { - complete_lf_global = d; - ( - Some(&complete_lf_global.data[..complete_lf_global.len]), - true, - ) - } else if do_flush - && self - .sections - .front() - .is_some_and(|s| s.section == Section::LfGlobal) - && 2 * self.ready_section_data > 3 * self.section_state.lf_global_flush_len - && frame_header.encoding == Encoding::Modular - && matches!( - frame_header.frame_type, - FrameType::RegularFrame | FrameType::LFFrame - ) - { - self.section_state.lf_global_flush_len = self.ready_section_data; - ( - Some(&self.sections[0].data[..self.ready_section_data]), - false, - ) - } else { - (None, false) - }; + // lf_global_real_len: the actual data length (excluding padding bytes) + let (lf_global, lf_global_real_len, lf_global_is_complete) = + if let Some(d) = self.lf_global_section.take() { + complete_lf_global = d; + // Use full data slice (includes 8-byte padding from dequeue) + let real_len = complete_lf_global.len; + (Some(&complete_lf_global.data[..]), real_len, true) + } else if do_flush + && self + .sections + .front() + .is_some_and(|s| s.section == Section::LfGlobal) + && 2 * self.ready_section_data > 3 * self.section_state.lf_global_flush_len + && frame_header.encoding == Encoding::Modular + && matches!( + frame_header.frame_type, + FrameType::RegularFrame | FrameType::LFFrame + ) + { + 
self.section_state.lf_global_flush_len = self.ready_section_data; + let rsd = self.ready_section_data; + (Some(&self.sections[0].data[..rsd]), rsd, false) + } else { + (None, 0, false) + }; 'process: { if frame_header.num_groups() == 1 && frame_header.passes.num_passes == 1 { @@ -121,7 +123,11 @@ impl CodestreamParser { break 'process; }; assert!(self.sections.is_empty() || !lf_global_is_complete); - let mut br = BitReader::new(buf); + let mut br = if lf_global_is_complete { + BitReader::new_padded(buf, lf_global_real_len) + } else { + BitReader::new(buf) + }; let res = (|| -> Result<()> { frame.decode_lf_global(&mut br, !lf_global_is_complete)?; frame.decode_lf_group(0, &mut br)?; @@ -148,7 +154,12 @@ impl CodestreamParser { } } else { if let Some(buf) = lf_global { - match frame.decode_lf_global(&mut BitReader::new(buf), !lf_global_is_complete) { + let mut br = if lf_global_is_complete { + BitReader::new_padded(buf, lf_global_real_len) + } else { + BitReader::new(buf) + }; + match frame.decode_lf_global(&mut br, !lf_global_is_complete) { Ok(_) => { self.section_state.lf_global_done = true; processed_section = true; @@ -168,7 +179,10 @@ impl CodestreamParser { let Section::Lf { group } = lf_section.section else { unreachable!() }; - frame.decode_lf_group(group, &mut BitReader::new(&lf_section.data))?; + frame.decode_lf_group( + group, + &mut BitReader::new_padded(&lf_section.data, lf_section.len), + )?; processed_section = true; self.section_state.remaining_lf -= 1; } @@ -178,7 +192,10 @@ impl CodestreamParser { } if let Some(hf_global) = self.hf_global_section.take() { - frame.decode_hf_global(&mut BitReader::new(&hf_global.data))?; + frame.decode_hf_global(&mut BitReader::new_padded( + &hf_global.data, + hf_global.len, + ))?; frame.finalize_lf()?; self.section_state.hf_global_done = true; processed_section = true; @@ -202,7 +219,7 @@ impl CodestreamParser { break; }; self.section_state.completed_passes[g] += 1; - sections.push((pass, BitReader::new(&s.data))); + 
sections.push((pass, BitReader::new_padded(&s.data, s.len))); } if !sections.is_empty() { group_readers.push((g, sections)); diff --git a/jxl/src/bit_reader.rs b/jxl/src/bit_reader.rs index 31a1772de..1d5c03ac7 100644 --- a/jxl/src/bit_reader.rs +++ b/jxl/src/bit_reader.rs @@ -45,8 +45,23 @@ impl<'a> BitReader<'a> { } } + /// Constructs a BitReader for data with zero-padding appended. + /// `data` must contain at least `real_len + 8` bytes, with the last 8 bytes + /// being zero padding. `initial_bits` is set to `real_len * 8` for error checking. + /// This ensures `refill()` always takes the fast path (no refill_slow calls). + pub fn new_padded(data: &[u8], real_len: usize) -> BitReader<'_> { + debug_assert!(data.len() >= real_len + 8); + BitReader { + data, + bit_buf: 0, + bits_in_buf: 0, + total_bits_read: 0, + initial_bits: real_len * 8, + } + } + /// Reads `num` bits from the buffer without consuming them. - #[inline] + #[inline(always)] pub fn peek(&mut self, num: usize) -> u64 { debug_assert!(num <= MAX_BITS_PER_CALL); if self.bits_in_buf < num { @@ -66,7 +81,7 @@ impl<'a> BitReader<'a> { Ok(()) } - #[inline] + #[inline(always)] pub fn consume_optimistic(&mut self, num: usize) { self.bit_buf >>= num; self.bits_in_buf = self.bits_in_buf.saturating_sub(num); @@ -84,7 +99,7 @@ impl<'a> BitReader<'a> { /// assert!(br.read(1).is_err()); /// # Ok::<(), jxl::error::Error>(()) /// ``` - #[inline] + #[inline(always)] pub fn read(&mut self, num: usize) -> Result { let ret = self.peek(num); self.consume(num)?; @@ -97,7 +112,7 @@ impl<'a> BitReader<'a> { self.read(num) } - #[inline] + #[inline(always)] pub fn read_optimistic(&mut self, num: usize) -> u64 { let ret = self.peek(num); self.consume_optimistic(num); @@ -201,7 +216,7 @@ impl<'a> BitReader<'a> { Ok(()) } - #[inline] + #[inline(always)] fn refill(&mut self) { // See Refill() in C++ code. 
if self.data.len() >= 8 { diff --git a/jxl/src/entropy_coding/ans.rs b/jxl/src/entropy_coding/ans.rs index 9da5cbee0..ebeca2703 100644 --- a/jxl/src/entropy_coding/ans.rs +++ b/jxl/src/entropy_coding/ans.rs @@ -352,7 +352,7 @@ impl AnsHistogram { } impl AnsHistogram { - #[inline] + #[inline(always)] pub fn read(&self, br: &mut BitReader, state: &mut u32) -> u32 { let idx = *state & 0xfff; let i = (idx >> self.log_bucket_size) as usize; @@ -433,9 +433,13 @@ impl AnsReader { Ok(Self(initial_state)) } - #[inline] + #[inline(always)] + #[allow(unsafe_code)] pub fn read(&mut self, codes: &AnsCodes, br: &mut BitReader, ctx: usize) -> u32 { - codes.histograms[ctx].read(br, &mut self.0) + debug_assert!(ctx < codes.histograms.len()); + // SAFETY: ctx is a validated cluster ID from the context map, + // checked during Histograms::decode() to be < histograms.len(). + unsafe { codes.histograms.get_unchecked(ctx) }.read(br, &mut self.0) } pub fn check_final_state(self) -> Result<()> { diff --git a/jxl/src/entropy_coding/decode.rs b/jxl/src/entropy_coding/decode.rs index e57a3211a..69602f46e 100644 --- a/jxl/src/entropy_coding/decode.rs +++ b/jxl/src/entropy_coding/decode.rs @@ -28,6 +28,7 @@ pub fn decode_varint16(br: &mut BitReader) -> Result { } } +#[inline(always)] pub fn unpack_signed(unsigned: u32) -> i32 { ((unsigned >> 1) ^ ((!unsigned) & 1).wrapping_sub(1)) as i32 } @@ -101,7 +102,7 @@ impl Lz77State { ( 8, 4), ( 6, 7), (-6, 7), ( 7, 6), (-7, 6), ( 8, 5), ( 7, 7), (-7, 7), ( 8, 6), ( 8, 7), ]; - #[inline] + #[inline(always)] fn apply_copy(&mut self, distance_sym: u32, num_to_copy: u32) { let distance_sub_1 = if self.dist_multiplier == 0 { distance_sym @@ -118,11 +119,13 @@ impl Lz77State { self.num_to_copy = num_to_copy; } - #[inline] + #[inline(always)] + #[allow(unsafe_code)] fn push_decoded_symbol(&mut self, token: u32) { let offset = (self.num_decoded & Self::WINDOW_MASK) as usize; - if let Some(slot) = self.window.get_mut(offset) { - *slot = token; + if offset < 
self.window.len() { + // SAFETY: offset < self.window.len() checked above + unsafe { *self.window.get_unchecked_mut(offset) = token }; } else { debug_assert_eq!(self.window.len(), offset); self.window.push(token); @@ -130,10 +133,15 @@ impl Lz77State { self.num_decoded += 1; } - #[inline] + #[inline(always)] + #[allow(unsafe_code)] fn pull_symbol(&mut self) -> Option { if let Some(next_num_to_copy) = self.num_to_copy.checked_sub(1) { - let sym = self.window[(self.copy_pos & Self::WINDOW_MASK) as usize]; + let idx = (self.copy_pos & Self::WINDOW_MASK) as usize; + // SAFETY: copy_pos & WINDOW_MASK is always < 1 << LOG_WINDOW_SIZE, + // and the window has capacity 1 << LOG_WINDOW_SIZE. As long as num_decoded > 0 + // (checked by apply_copy), the window has enough entries. + let sym = unsafe { *self.window.get_unchecked(idx) }; self.copy_pos += 1; self.num_to_copy = next_num_to_copy; Some(sym) @@ -152,7 +160,7 @@ struct RleState { } impl RleState { - #[inline] + #[inline(always)] fn push_token( &mut self, token: u32, @@ -173,7 +181,7 @@ impl RleState { } } - #[inline] + #[inline(always)] fn pull_symbol(&mut self) -> Option { if self.repeat_count > 0 { self.repeat_count -= 1; @@ -322,6 +330,7 @@ impl SymbolReader { } #[inline(always)] + #[allow(unsafe_code)] pub fn read_unsigned_clustered_inline( &mut self, histograms: &Histograms, @@ -334,7 +343,10 @@ impl SymbolReader { Codes::Huffman(hc) => hc.read(br, cluster), Codes::Ans(ans) => self.ans_reader.read(ans, br, cluster), }; - histograms.uint_configs[cluster].read(token, br) + debug_assert!(cluster < histograms.uint_configs.len()); + // SAFETY: cluster is a validated cluster ID from the context map, + // which is checked during Histograms::decode() to be < uint_configs.len(). 
+ unsafe { histograms.uint_configs.get_unchecked(cluster) }.read(token, br) } SymbolReaderState::Lz77(lz77_state) => { @@ -393,8 +405,14 @@ impl SymbolReader { Codes::Ans(ans) => self.ans_reader.read(ans, br, cluster), }; rle_state.push_token(token, histograms, br, cluster); - if let Some(sym) = rle_state.pull_symbol() { - sym + if rle_state.repeat_count > 0 { + rle_state.repeat_count -= 1; + if let Some(sym) = rle_state.last_sym { + sym + } else { + self.errors.lz77_repeat = true; + 0 + } } else { self.errors.lz77_repeat = true; 0 @@ -469,6 +487,30 @@ impl SymbolReader { unpack_signed(unsigned) } + /// Fast path for reads without LZ77. Skips the LZ77 state match. + /// Matching libjxl's template approach. + /// + /// # Preconditions + /// - `self.state` must be `SymbolReaderState::None` (no LZ77) + #[inline(always)] + #[allow(unsafe_code)] + pub fn read_signed_clustered_no_lz77( + &mut self, + histograms: &Histograms, + br: &mut BitReader, + cluster: usize, + ) -> i32 { + debug_assert!(matches!(self.state, SymbolReaderState::None)); + let token = match &histograms.codes { + Codes::Huffman(hc) => hc.read(br, cluster), + Codes::Ans(ans) => self.ans_reader.read(ans, br, cluster), + }; + debug_assert!(cluster < histograms.uint_configs.len()); + // SAFETY: cluster is a validated cluster ID. 
+ let unsigned = unsafe { histograms.uint_configs.get_unchecked(cluster) }.read(token, br); + unpack_signed(unsigned) + } + pub fn check_final_state(self, histograms: &Histograms, br: &mut BitReader) -> Result<()> { self.errors.check_for_error()?; br.check_for_error()?; @@ -622,8 +664,14 @@ impl Histograms { }) } + #[inline(always)] + #[allow(unsafe_code)] pub fn map_context_to_cluster(&self, context: usize) -> usize { - self.context_map[context] as usize + debug_assert!(context < self.context_map.len()); + // SAFETY: context < context_map.len() is guaranteed by the caller - + // contexts are bounded by the number of leaf nodes in the tree + // or num_ac_contexts, both validated during decode. + unsafe { *self.context_map.get_unchecked(context) as usize } } pub fn num_histograms(&self) -> usize { @@ -640,6 +688,11 @@ impl Histograms { pub fn can_use_config_420_fast_path(&self) -> bool { !self.lz77_params.enabled && self.uint_configs.iter().all(|cfg| cfg.is_config_420()) } + + /// Returns true if LZ77 is disabled, enabling the no-LZ77 fast path. 
+ pub fn has_no_lz77(&self) -> bool { + !self.lz77_params.enabled + } } #[cfg(test)] diff --git a/jxl/src/entropy_coding/huffman.rs b/jxl/src/entropy_coding/huffman.rs index 0790b3360..e69fed3cc 100644 --- a/jxl/src/entropy_coding/huffman.rs +++ b/jxl/src/entropy_coding/huffman.rs @@ -441,18 +441,26 @@ impl Table { Ok(Table { entries }) } - #[inline] + #[inline(always)] + #[allow(unsafe_code)] pub fn read(&self, br: &mut BitReader) -> u32 { let mut pos = br.peek(TABLE_BITS) as usize; - let mut n_bits = self.entries[pos].bits as usize; + // SAFETY: pos = peek(TABLE_BITS) which returns at most (1< TABLE_BITS { br.consume_optimistic(TABLE_BITS); n_bits -= TABLE_BITS; - pos += self.entries[pos].value as usize; - pos += br.peek(n_bits) as usize; + pos = pos + entry.value as usize + br.peek(n_bits) as usize; + // SAFETY: For 2nd-level tables, build() ensures the table is large enough + // to hold all entries pointed to by 1st-level value + peek(n_bits). + let entry = unsafe { *self.entries.get_unchecked(pos) }; + br.consume_optimistic(entry.bits as usize); + return entry.value as u32; } - br.consume_optimistic(self.entries[pos].bits as usize); - self.entries[pos].value as u32 + br.consume_optimistic(n_bits); + entry.value as u32 } } @@ -477,9 +485,13 @@ impl HuffmanCodes { Ok(HuffmanCodes { tables }) } - #[inline] + #[inline(always)] + #[allow(unsafe_code)] pub fn read(&self, br: &mut BitReader, ctx: usize) -> u32 { - self.tables[ctx].read(br) + debug_assert!(ctx < self.tables.len()); + // SAFETY: ctx is always < self.tables.len() because it comes from a validated + // context map (cluster ID), which was checked during Histograms::decode(). 
/// Returns `v` limited to the range `[0.0, 1.0]` when `clamp` is true, and
/// `v` unchanged when it is false. NaN passes through in both cases.
#[inline(always)]
fn maybe_clamp(v: f32, clamp: bool) -> f32 {
    if !clamp {
        return v;
    }
    v.clamp(0.0, 1.0)
}
+/// allocation. Each inner Vec is reused across calls (only resized/zeroed as needed). +pub fn perform_blending_with_tmp, V: AsMut<[f32]>>( + bg: &mut [V], + fg: &[T], + color_blending: &PatchBlending, + ec_blending: &[PatchBlending], + extra_channel_info: &[ExtraChannelInfo], + reusable_tmp: Option<&mut Vec>>, +) { + let num_ec = extra_channel_info.len(); + let xsize = bg[0].as_mut().len(); + + // Fast path: if color is None (keep bg) and all ec are None, nothing to do. + if color_blending.mode == PatchBlendMode::None + && ec_blending.iter().all(|b| b.mode == PatchBlendMode::None) + { + return; + } + + // Fast path: Replace color + Replace/None ec -> copy fg directly to bg, no tmp needed. + if color_blending.mode == PatchBlendMode::Replace { + let all_simple = ec_blending[..num_ec] + .iter() + .all(|b| b.mode == PatchBlendMode::Replace || b.mode == PatchBlendMode::None); + if all_simple { + for c in 0..3 { + bg[c].as_mut().copy_from_slice(fg[c].as_ref()); + } + for i in 0..num_ec { + match ec_blending[i].mode { + PatchBlendMode::Replace => { + bg[3 + i].as_mut().copy_from_slice(fg[3 + i].as_ref()); + } + PatchBlendMode::None => {} // keep bg + _ => unreachable!(), + } + } + return; + } + } + let has_alpha = extra_channel_info .iter() .any(|info| info.ec_type == ExtraChannel::Alpha); - let num_ec = extra_channel_info.len(); - let xsize = bg[0].as_mut().len(); - let mut tmp = vec![vec![0.0f32; xsize]; 3 + num_ec]; + let num_channels = 3 + num_ec; + + // Reuse pre-allocated buffer or allocate fresh. + let mut owned_tmp; + let tmp: &mut Vec> = if let Some(buf) = reusable_tmp { + // Ensure we have enough channels and each is the right size. + // No need to zero -- every code path below fully overwrites the used elements. 
+ buf.resize_with(num_channels, || Vec::with_capacity(xsize)); + for ch in buf.iter_mut().take(num_channels) { + ch.resize(xsize, 0.0); + } + buf + } else { + owned_tmp = vec![vec![0.0f32; xsize]; num_channels]; + &mut owned_tmp + }; for i in 0..num_ec { let alpha = ec_blending[i].alpha_channel; diff --git a/jxl/src/features/noise.rs b/jxl/src/features/noise.rs index 5770a1934..760ff3ccb 100644 --- a/jxl/src/features/noise.rs +++ b/jxl/src/features/noise.rs @@ -17,6 +17,7 @@ impl Noise { } Ok(noise) } + #[inline(always)] pub fn strength(&self, vx: f32) -> f32 { let k_scale = (self.lut.len() - 2) as f32; let scaled_vx = f32::max(0.0, vx * k_scale); diff --git a/jxl/src/features/patches.rs b/jxl/src/features/patches.rs index d981ef3d9..61946d9ac 100644 --- a/jxl/src/features/patches.rs +++ b/jxl/src/features/patches.rs @@ -631,6 +631,7 @@ impl PatchesDictionary { Ok(patches_dict) } + #[inline(always)] pub fn set_patches_for_row(&self, y: usize, patches_for_row_result: &mut Vec) { patches_for_row_result.clear(); if self.num_patches.len() <= y || self.num_patches[y] == 0 { @@ -678,6 +679,7 @@ impl PatchesDictionary { patches_for_row_result.sort(); } + #[inline(always)] pub fn add_one_row( &self, row: &mut [&mut [f32]], diff --git a/jxl/src/frame/block_context_map.rs b/jxl/src/frame/block_context_map.rs index 9051f5965..2655dbb9a 100644 --- a/jxl/src/frame/block_context_map.rs +++ b/jxl/src/frame/block_context_map.rs @@ -124,6 +124,7 @@ impl BlockContextMap { } } } + #[inline(always)] pub fn block_context(&self, lf_idx: usize, qf: u32, shape_id: usize, c: usize) -> usize { let mut qf_idx: usize = 0; for t in &self.qf_thresholds { @@ -137,6 +138,7 @@ impl BlockContextMap { idx = idx * self.num_lf_contexts + lf_idx; self.context_map[idx] as usize } + #[inline(always)] pub fn nonzero_context(&self, nonzeros: usize, block_context: usize) -> usize { let context: usize = if nonzeros < 8 { nonzeros @@ -147,6 +149,7 @@ impl BlockContextMap { }; context * self.num_contexts + 
block_context } + #[inline(always)] pub fn zero_density_context_offset(&self, block_context: usize) -> usize { self.num_contexts * NON_ZERO_BUCKETS + ZERO_DENSITY_CONTEXT_COUNT * block_context } diff --git a/jxl/src/frame/color_correlation_map.rs b/jxl/src/frame/color_correlation_map.rs index 0c513b794..0fe86e939 100644 --- a/jxl/src/frame/color_correlation_map.rs +++ b/jxl/src/frame/color_correlation_map.rs @@ -73,18 +73,22 @@ impl ColorCorrelationParams { } } + #[inline(always)] pub fn y_to_x(&self, factor: i32) -> f32 { self.base_correlation_x + (factor as f32) / (self.color_factor as f32) } + #[inline(always)] pub fn y_to_x_lf(&self) -> f32 { self.y_to_x(self.ytox_lf) } + #[inline(always)] pub fn y_to_b(&self, factor: i32) -> f32 { self.base_correlation_b + (factor as f32) / (self.color_factor as f32) } + #[inline(always)] pub fn y_to_b_lf(&self) -> f32 { self.y_to_b(self.ytob_lf) } diff --git a/jxl/src/frame/decode.rs b/jxl/src/frame/decode.rs index f58b1044d..c7dde66bf 100644 --- a/jxl/src/frame/decode.rs +++ b/jxl/src/frame/decode.rs @@ -547,6 +547,8 @@ impl Frame { Ok(()) } + #[inline(always)] + #[allow(unsafe_code)] pub fn render_noise_for_group( &mut self, group: usize, @@ -614,23 +616,20 @@ impl Frame { continue; } - // Fill all 3 channels with this subregion's noise, sharing the RNG + // Fill all 3 channels with this subregion's noise, sharing the RNG. + // Reinterpret the u64 batch as u32 pairs to avoid per-element branching. 
for buf in &mut bufs { for y in 0..sub_ysize { let row = buf.row_mut(sub_y0 + y); for batch_index in 0..sub_xsize.div_ceil(FLOATS_PER_BATCH) { rng.fill(&mut batch); + // SAFETY: [u64; N] and [u32; 2*N] have the same layout + let batch_u32: &[u32; FLOATS_PER_BATCH] = + unsafe { &*batch.as_ptr().cast() }; let batch_size = (sub_xsize - batch_index * FLOATS_PER_BATCH).min(FLOATS_PER_BATCH); - for i in 0..batch_size { + for (i, &bits) in batch_u32.iter().take(batch_size).enumerate() { let x = sub_x0 + FLOATS_PER_BATCH * batch_index + i; - let k = i / 2; - let high_bytes = i % 2 != 0; - let bits = if high_bytes { - ((batch[k] & 0xFFFFFFFF00000000) >> 32) as u32 - } else { - (batch[k] & 0xFFFFFFFF) as u32 - }; row[x] = bits_to_float(bits); } } diff --git a/jxl/src/frame/group.rs b/jxl/src/frame/group.rs index b7d8021b3..b0707cc76 100644 --- a/jxl/src/frame/group.rs +++ b/jxl/src/frame/group.rs @@ -47,10 +47,10 @@ impl VarDctBuffers { /// Reset buffers to zero for reuse. pub fn reset(&mut self) { - self.scratch.fill(0.0); - for buf in &mut self.transform_buffer { - buf.fill(0.0); - } + // scratch does NOT need zeroing: each block's LF coefficients are fully written + // by copy_from_slice before transform_to_pixels reads them. + // transform_buffer does NOT need zeroing: dequant_block fully overwrites + // all num_coeffs entries before transform_to_pixels reads them. 
self.coeffs_storage.fill(0); } } @@ -61,7 +61,7 @@ impl Default for VarDctBuffers { } } -#[inline] +#[inline(always)] fn predict_num_nonzeros(nzeros_map: &Image, bx: usize, by: usize) -> usize { if bx == 0 { if by == 0 { @@ -365,6 +365,7 @@ impl<'a, 'b> PassInfo<'a, 'b> { #[allow(clippy::too_many_arguments)] #[allow(clippy::type_complexity)] +#[allow(unsafe_code)] pub fn decode_vardct_group( group: usize, passes: &mut [(usize, BitReader)], @@ -572,8 +573,13 @@ pub fn decode_vardct_group( reader.read_signed_inline(&pass_info.histograms, br, ctx) << *shift; prev = if coeff != 0 { 1 } else { 0 }; nonzeros -= prev; - let coeff_index = permutation[k] as usize; - current_coeffs[coeff_index] += coeff; + // SAFETY: permutation[k] is validated to be < num_coeffs during + // coeff_order decoding, and current_coeffs.len() == num_coeffs. + // k < num_coeffs by loop bounds. + let coeff_index = unsafe { *permutation.get_unchecked(k) } as usize; + // SAFETY: coeff_index comes from permutation[k], which is validated + // to be < num_coeffs, and current_coeffs.len() == num_coeffs. + unsafe { *current_coeffs.get_unchecked_mut(coeff_index) += coeff }; } if nonzeros != 0 { return Err(Error::EndOfBlockResidualNonZeros(nonzeros)); diff --git a/jxl/src/frame/modular/borrowed_buffers.rs b/jxl/src/frame/modular/borrowed_buffers.rs index 29c93efa4..d55f1bed0 100644 --- a/jxl/src/frame/modular/borrowed_buffers.rs +++ b/jxl/src/frame/modular/borrowed_buffers.rs @@ -27,8 +27,14 @@ pub fn with_buffers( let b = &buf.buffer_grid[grid]; let mut data = b.data.borrow_mut(); if data.is_none() { + // SAFETY: The modular decode loop fully writes every pixel in the data region. + // The padding region is zeroed for correct boundary behavior in prediction. + #[allow(unsafe_code)] + let img = unsafe { + Image::new_uninit_with_zeroed_padding(b.size, IMAGE_OFFSET, IMAGE_PADDING)? 
+ }; *data = Some(ModularChannel { - data: Image::new_with_padding(b.size, IMAGE_OFFSET, IMAGE_PADDING)?, + data: img, auxiliary_data: None, shift: buf.info.shift, bit_depth: buf.info.bit_depth, diff --git a/jxl/src/frame/modular/decode/channel.rs b/jxl/src/frame/modular/decode/channel.rs index 398eb204c..5ccbef4ba 100644 --- a/jxl/src/frame/modular/decode/channel.rs +++ b/jxl/src/frame/modular/decode/channel.rs @@ -47,11 +47,15 @@ fn decode_modular_channel_small( const { assert!(IMAGE_OFFSET.1 == 2) }; + let mut property_buffer: Vec = vec![0; num_properties]; + property_buffer[0] = chan as i32; + property_buffer[1] = stream_id as i32; + for y in 0..size.1 { precompute_references(buffers, chan, y, &mut references); - let mut property_buffer: Vec = vec![0; num_properties]; - property_buffer[0] = chan as i32; - property_buffer[1] = stream_id as i32; + // Reset property 9 (local gradient depends on previous row's value) + property_buffer[9] = 0; + wp_state.set_row(y, size.0); let [row, row_top, row_toptop] = buffers[chan].data.distinct_full_rows_mut([y + 2, y + 1, y]); let row = &mut row[IMAGE_OFFSET.0..IMAGE_OFFSET.0 + size.0]; @@ -80,8 +84,6 @@ fn decode_modular_channel_small( } pub(super) trait ModularChannelDecoder { - const NEEDS_TOP: bool; - const NEEDS_TOPTOP: bool; fn init_row(&mut self, buffers: &mut [&mut ModularChannel], chan: usize, y: usize); fn decode_one( &mut self, @@ -92,6 +94,20 @@ pub(super) trait ModularChannelDecoder { br: &mut BitReader, histograms: &Histograms, ) -> i32; + /// Interior variant: x > 0 and x < xsize-1 guaranteed, y >= 2. + /// Default: delegates to decode_one. Override for WP trees to skip edge checks. 
+ #[inline(always)] + fn decode_one_interior( + &mut self, + prediction_data: PredictionData, + pos: (usize, usize), + xsize: usize, + reader: &mut SymbolReader, + br: &mut BitReader, + histograms: &Histograms, + ) -> i32 { + self.decode_one(prediction_data, pos, xsize, reader, br, histograms) + } } #[inline(never)] @@ -137,18 +153,19 @@ fn decode_modular_channel_impl( row[x] = val; } } else { - for (x, r) in row.iter_mut().enumerate().skip(2).take(size.0 - 4) { - prediction_data = prediction_data.update_for_interior_row( - row_top, - row_toptop, - x, - last, - D::NEEDS_TOP, - D::NEEDS_TOPTOP, + #[allow(unsafe_code)] + for x in 2..size.0 - 2 { + prediction_data.update_for_interior_row(row_top, row_toptop, x, last); + let val = decoder.decode_one_interior( + prediction_data, + (x, y), + size.0, + reader, + br, + histograms, ); - let val = - decoder.decode_one(prediction_data, (x, y), size.0, reader, br, histograms); - *r = val; + // SAFETY: x is in [2, size.0 - 2), and row.len() == size.0. 
+ unsafe { *row.get_unchecked_mut(x) = val }; last = val; } } @@ -195,6 +212,12 @@ pub(super) fn decode_modular_channel( TreeSpecialCase::NoWp(t) => { decode_modular_channel_impl(buffers, chan, t, reader, br, &tree.histograms) } + TreeSpecialCase::NoWpNoLz77(t) => { + decode_modular_channel_impl(buffers, chan, t, reader, br, &tree.histograms) + } + TreeSpecialCase::NoWpConfig420(t) => { + decode_modular_channel_impl(buffers, chan, t, reader, br, &tree.histograms) + } TreeSpecialCase::WpOnlyConfig420(t) => { decode_modular_channel_impl(buffers, chan, t, reader, br, &tree.histograms) } @@ -207,6 +230,12 @@ pub(super) fn decode_modular_channel( TreeSpecialCase::General(t) => { decode_modular_channel_impl(buffers, chan, t, reader, br, &tree.histograms) } + TreeSpecialCase::GeneralNoLz77(t) => { + decode_modular_channel_impl(buffers, chan, t, reader, br, &tree.histograms) + } + TreeSpecialCase::GeneralConfig420(t) => { + decode_modular_channel_impl(buffers, chan, t, reader, br, &tree.histograms) + } }?; br.check_for_error() } diff --git a/jxl/src/frame/modular/decode/common.rs b/jxl/src/frame/modular/decode/common.rs index e3107afea..80813ab31 100644 --- a/jxl/src/frame/modular/decode/common.rs +++ b/jxl/src/frame/modular/decode/common.rs @@ -47,6 +47,9 @@ pub(super) fn precompute_references( y: usize, references: &mut Image, ) { + if references.size().0 == 0 { + return; + } references.fill(0); let mut offset = 0; let num_extra_props = references.size().0; @@ -82,6 +85,7 @@ pub(super) fn precompute_references( } } +#[inline(always)] pub(super) fn make_pixel(dec: i32, mul: u32, guess: i64) -> i32 { (guess + (mul as i64) * (dec as i64)) as i32 } diff --git a/jxl/src/frame/modular/decode/specialized_trees.rs b/jxl/src/frame/modular/decode/specialized_trees.rs index 8ffebfe8a..1fb50c9e1 100644 --- a/jxl/src/frame/modular/decode/specialized_trees.rs +++ b/jxl/src/frame/modular/decode/specialized_trees.rs @@ -17,7 +17,8 @@ use crate::{ }, predict::{PredictionData, 
WeightedPredictorState, clamped_gradient}, tree::{ - FlatTreeNode, NUM_NONREF_PROPERTIES, PROPERTIES_PER_PREVCHAN, TreeNode, predict_flat, + FlatTreeNode, NUM_NONREF_PROPERTIES, PROPERTIES_PER_PREVCHAN, TreeNode, + predict_flat_no_wp, predict_flat_with_wp, predict_flat_with_wp_interior, }, }, headers::modular::GroupHeader, @@ -28,6 +29,8 @@ pub struct NoWpTree { flat_nodes: Vec, references: Image, property_buffer: Vec, + /// Bitmask of properties used by the tree (bit i = property i is used). + used_properties: u32, } impl NoWpTree { @@ -48,45 +51,123 @@ impl NoWpTree { property_buffer[0] = channel as i32; property_buffer[1] = stream as i32; - let flat_nodes = Tree::build_flat_tree(&nodes)?; + let (flat_nodes, used_properties) = Tree::build_flat_tree(&nodes)?; Ok(Self { flat_nodes, references, property_buffer, + used_properties, }) } } impl ModularChannelDecoder for NoWpTree { - const NEEDS_TOP: bool = true; - const NEEDS_TOPTOP: bool = true; - fn init_row(&mut self, buffers: &mut [&mut ModularChannel], chan: usize, y: usize) { precompute_references(buffers, chan, y, &mut self.references); - self.property_buffer[2..].fill(0); + // Only need to zero property 9 (local gradient) since property 8 + // depends on the previous value of property 9. Skip if neither is used. 
+ if self.used_properties & 0x0300 != 0 { + self.property_buffer[9] = 0; + } } + #[inline(always)] fn decode_one( &mut self, prediction_data: PredictionData, pos: (usize, usize), - xsize: usize, + _xsize: usize, reader: &mut SymbolReader, br: &mut BitReader, histograms: &Histograms, ) -> i32 { - let prediction_result = predict_flat( + let prediction_result = predict_flat_no_wp( &self.flat_nodes, prediction_data, - xsize, - None, pos.0, pos.1, &self.references, &mut self.property_buffer, + self.used_properties, + ); + // Use inlined variant for hot interior loop (matches C++ ReadHybridUintClusteredInlined) + let dec = + reader.read_signed_clustered_inline(histograms, br, prediction_result.context as usize); + make_pixel(dec, prediction_result.multiplier, prediction_result.guess) + } +} + +/// NoWp tree variant without LZ77. Skips SymbolReaderState match per pixel. +pub struct NoWpTreeNoLz77(NoWpTree); + +impl ModularChannelDecoder for NoWpTreeNoLz77 { + fn init_row(&mut self, buffers: &mut [&mut ModularChannel], chan: usize, y: usize) { + self.0.init_row(buffers, chan, y); + } + + #[inline(always)] + fn decode_one( + &mut self, + prediction_data: PredictionData, + pos: (usize, usize), + _xsize: usize, + reader: &mut SymbolReader, + br: &mut BitReader, + histograms: &Histograms, + ) -> i32 { + let prediction_result = predict_flat_no_wp( + &self.0.flat_nodes, + prediction_data, + pos.0, + pos.1, + &self.0.references, + &mut self.0.property_buffer, + self.0.used_properties, + ); + let dec = reader.read_signed_clustered_no_lz77( + histograms, + br, + prediction_result.context as usize, + ); + make_pixel(dec, prediction_result.multiplier, prediction_result.guess) + } +} + +/// NoWp tree variant using the faster config_420 entropy decoder. +/// Used when all HybridUint configs are 420 and there's no LZ77. 
+pub struct NoWpTreeConfig420(NoWpTree); + +impl ModularChannelDecoder for NoWpTreeConfig420 { + fn init_row(&mut self, buffers: &mut [&mut ModularChannel], chan: usize, y: usize) { + self.0.init_row(buffers, chan, y); + } + + #[inline(always)] + fn decode_one( + &mut self, + prediction_data: PredictionData, + pos: (usize, usize), + _xsize: usize, + reader: &mut SymbolReader, + br: &mut BitReader, + histograms: &Histograms, + ) -> i32 { + let prediction_result = predict_flat_no_wp( + &self.0.flat_nodes, + prediction_data, + pos.0, + pos.1, + &self.0.references, + &mut self.0.property_buffer, + self.0.used_properties, + ); + // Use the specialized config_420 fast path for entropy decoding + let dec = reader.read_signed_clustered_config_420( + histograms, + br, + prediction_result.context as usize, ); - let dec = reader.read_signed_clustered(histograms, br, prediction_result.context as usize); make_pixel(dec, prediction_result.multiplier, prediction_result.guess) } } @@ -113,14 +194,102 @@ impl GeneralTree { } } +/// GeneralTree variant without LZ77. Skips the LZ77 state check per pixel. +/// Used when LZ77 is disabled but uint configs aren't all 420. 
+pub struct GeneralTreeNoLz77 { + no_wp_tree: NoWpTree, + wp_state: WeightedPredictorState, +} + +impl GeneralTreeNoLz77 { + fn new( + nodes: Vec, + max_property_count: usize, + header: &GroupHeader, + channel: usize, + stream: usize, + xsize: usize, + ) -> Result { + let wp_state = WeightedPredictorState::new(&header.wp_header, xsize); + Ok(Self { + no_wp_tree: NoWpTree::new(nodes, max_property_count, channel, stream, xsize)?, + wp_state, + }) + } +} + impl ModularChannelDecoder for GeneralTree { - const NEEDS_TOP: bool = true; - const NEEDS_TOPTOP: bool = true; + fn init_row(&mut self, buffers: &mut [&mut ModularChannel], chan: usize, y: usize) { + self.no_wp_tree.init_row(buffers, chan, y); + let xsize = buffers[chan].data.size().0; + self.wp_state.set_row(y, xsize); + } + + #[inline(always)] + fn decode_one( + &mut self, + prediction_data: PredictionData, + pos: (usize, usize), + xsize: usize, + reader: &mut SymbolReader, + br: &mut BitReader, + histograms: &Histograms, + ) -> i32 { + let prediction_result = predict_flat_with_wp( + &self.no_wp_tree.flat_nodes, + prediction_data, + xsize, + &mut self.wp_state, + pos.0, + pos.1, + &self.no_wp_tree.references, + &mut self.no_wp_tree.property_buffer, + self.no_wp_tree.used_properties, + ); + let dec = + reader.read_signed_clustered_inline(histograms, br, prediction_result.context as usize); + let val = make_pixel(dec, prediction_result.multiplier, prediction_result.guess); + self.wp_state.update_errors(val, pos, xsize); + val + } + #[inline(always)] + fn decode_one_interior( + &mut self, + prediction_data: PredictionData, + pos: (usize, usize), + xsize: usize, + reader: &mut SymbolReader, + br: &mut BitReader, + histograms: &Histograms, + ) -> i32 { + let prediction_result = predict_flat_with_wp_interior( + &self.no_wp_tree.flat_nodes, + prediction_data, + xsize, + &mut self.wp_state, + pos.0, + pos.1, + &self.no_wp_tree.references, + &mut self.no_wp_tree.property_buffer, + self.no_wp_tree.used_properties, + ); + 
let dec = + reader.read_signed_clustered_inline(histograms, br, prediction_result.context as usize); + let val = make_pixel(dec, prediction_result.multiplier, prediction_result.guess); + self.wp_state.update_errors(val, pos, xsize); + val + } +} + +impl ModularChannelDecoder for GeneralTreeNoLz77 { fn init_row(&mut self, buffers: &mut [&mut ModularChannel], chan: usize, y: usize) { self.no_wp_tree.init_row(buffers, chan, y); + let xsize = buffers[chan].data.size().0; + self.wp_state.set_row(y, xsize); } + #[inline(always)] fn decode_one( &mut self, prediction_data: PredictionData, @@ -130,17 +299,147 @@ impl ModularChannelDecoder for GeneralTree { br: &mut BitReader, histograms: &Histograms, ) -> i32 { - let prediction_result = predict_flat( + let prediction_result = predict_flat_with_wp( &self.no_wp_tree.flat_nodes, prediction_data, xsize, - Some(&mut self.wp_state), + &mut self.wp_state, pos.0, pos.1, &self.no_wp_tree.references, &mut self.no_wp_tree.property_buffer, + self.no_wp_tree.used_properties, + ); + // No-LZ77 fast path: skip SymbolReaderState match (compile-time elimination) + let dec = reader.read_signed_clustered_no_lz77( + histograms, + br, + prediction_result.context as usize, + ); + let val = make_pixel(dec, prediction_result.multiplier, prediction_result.guess); + self.wp_state.update_errors(val, pos, xsize); + val + } + + #[inline(always)] + fn decode_one_interior( + &mut self, + prediction_data: PredictionData, + pos: (usize, usize), + xsize: usize, + reader: &mut SymbolReader, + br: &mut BitReader, + histograms: &Histograms, + ) -> i32 { + let prediction_result = predict_flat_with_wp_interior( + &self.no_wp_tree.flat_nodes, + prediction_data, + xsize, + &mut self.wp_state, + pos.0, + pos.1, + &self.no_wp_tree.references, + &mut self.no_wp_tree.property_buffer, + self.no_wp_tree.used_properties, + ); + let dec = reader.read_signed_clustered_no_lz77( + histograms, + br, + prediction_result.context as usize, + ); + let val = make_pixel(dec, 
prediction_result.multiplier, prediction_result.guess); + self.wp_state.update_errors(val, pos, xsize); + val + } +} + +/// GeneralTree variant using the faster config_420 entropy decoder. +pub struct GeneralTreeConfig420 { + no_wp_tree: NoWpTree, + wp_state: WeightedPredictorState, +} + +impl GeneralTreeConfig420 { + fn new( + nodes: Vec, + max_property_count: usize, + header: &GroupHeader, + channel: usize, + stream: usize, + xsize: usize, + ) -> Result { + let wp_state = WeightedPredictorState::new(&header.wp_header, xsize); + Ok(Self { + no_wp_tree: NoWpTree::new(nodes, max_property_count, channel, stream, xsize)?, + wp_state, + }) + } +} + +impl ModularChannelDecoder for GeneralTreeConfig420 { + fn init_row(&mut self, buffers: &mut [&mut ModularChannel], chan: usize, y: usize) { + self.no_wp_tree.init_row(buffers, chan, y); + let xsize = buffers[chan].data.size().0; + self.wp_state.set_row(y, xsize); + } + + #[inline(always)] + fn decode_one( + &mut self, + prediction_data: PredictionData, + pos: (usize, usize), + xsize: usize, + reader: &mut SymbolReader, + br: &mut BitReader, + histograms: &Histograms, + ) -> i32 { + let prediction_result = predict_flat_with_wp( + &self.no_wp_tree.flat_nodes, + prediction_data, + xsize, + &mut self.wp_state, + pos.0, + pos.1, + &self.no_wp_tree.references, + &mut self.no_wp_tree.property_buffer, + self.no_wp_tree.used_properties, + ); + let dec = reader.read_signed_clustered_config_420( + histograms, + br, + prediction_result.context as usize, + ); + let val = make_pixel(dec, prediction_result.multiplier, prediction_result.guess); + self.wp_state.update_errors(val, pos, xsize); + val + } + + #[inline(always)] + fn decode_one_interior( + &mut self, + prediction_data: PredictionData, + pos: (usize, usize), + xsize: usize, + reader: &mut SymbolReader, + br: &mut BitReader, + histograms: &Histograms, + ) -> i32 { + let prediction_result = predict_flat_with_wp_interior( + &self.no_wp_tree.flat_nodes, + prediction_data, + xsize, + 
&mut self.wp_state, + pos.0, + pos.1, + &self.no_wp_tree.references, + &mut self.no_wp_tree.property_buffer, + self.no_wp_tree.used_properties, + ); + let dec = reader.read_signed_clustered_config_420( + histograms, + br, + prediction_result.context as usize, ); - let dec = reader.read_signed_clustered(histograms, br, prediction_result.context as usize); let val = make_pixel(dec, prediction_result.multiplier, prediction_result.guess); self.wp_state.update_errors(val, pos, xsize); val @@ -225,14 +524,13 @@ impl WpOnlyLookupConfig420 { } impl ModularChannelDecoder for WpOnlyLookupConfig420 { - const NEEDS_TOP: bool = true; - const NEEDS_TOPTOP: bool = true; - - fn init_row(&mut self, _buffers: &mut [&mut ModularChannel], _chan: usize, _y: usize) { - // nothing to do + fn init_row(&mut self, buffers: &mut [&mut ModularChannel], chan: usize, y: usize) { + let xsize = buffers[chan].data.size().0; + self.wp_state.set_row(y, xsize); } #[inline(always)] + #[allow(unsafe_code)] fn decode_one( &mut self, prediction_data: PredictionData, @@ -245,9 +543,34 @@ impl ModularChannelDecoder for WpOnlyLookupConfig420 { let (wp_pred, property) = self .wp_state .predict_and_property(pos, xsize, &prediction_data); - let ctx = self.lut[(property as i64 - LUT_MIN_SPLITVAL as i64) - .clamp(0, LUT_TABLE_SIZE as i64 - 1) as usize]; - // Use the specialized 420 fast path + let clamped = property.clamp(LUT_MIN_SPLITVAL, LUT_MAX_SPLITVAL); + let idx = (clamped - LUT_MIN_SPLITVAL) as usize; + // SAFETY: clamped is constrained to the LUT range, so idx is in-bounds. 
+ let ctx = unsafe { *self.lut.get_unchecked(idx) }; + let dec = reader.read_signed_clustered_config_420(histograms, br, ctx as usize); + let val = dec.wrapping_add(wp_pred as i32); + self.wp_state.update_errors(val, pos, xsize); + val + } + + #[inline(always)] + #[allow(unsafe_code)] + fn decode_one_interior( + &mut self, + prediction_data: PredictionData, + pos: (usize, usize), + xsize: usize, + reader: &mut SymbolReader, + br: &mut BitReader, + histograms: &Histograms, + ) -> i32 { + let (wp_pred, property) = + self.wp_state + .predict_and_property_interior(pos.0, xsize, &prediction_data); + let clamped = property.clamp(LUT_MIN_SPLITVAL, LUT_MAX_SPLITVAL); + let idx = (clamped - LUT_MIN_SPLITVAL) as usize; + // SAFETY: clamped is constrained to the LUT range, so idx is in-bounds. + let ctx = unsafe { *self.lut.get_unchecked(idx) }; let dec = reader.read_signed_clustered_config_420(histograms, br, ctx as usize); let val = dec.wrapping_add(wp_pred as i32); self.wp_state.update_errors(val, pos, xsize); @@ -292,9 +615,6 @@ fn make_gradient_lut_config_420( } impl ModularChannelDecoder for GradientLookupConfig420 { - const NEEDS_TOP: bool = true; - const NEEDS_TOPTOP: bool = false; - fn init_row(&mut self, _: &mut [&mut ModularChannel], _: usize, _: usize) {} #[inline(always)] @@ -333,9 +653,6 @@ pub struct SingleGradientOnly { } impl ModularChannelDecoder for SingleGradientOnly { - const NEEDS_TOP: bool = true; - const NEEDS_TOPTOP: bool = false; - fn init_row(&mut self, _: &mut [&mut ModularChannel], _: usize, _: usize) {} #[inline(always)] @@ -348,9 +665,13 @@ impl ModularChannelDecoder for SingleGradientOnly { br: &mut BitReader, histograms: &Histograms, ) -> i32 { - let pred = Predictor::Gradient.predict_one(prediction_data, 0); + let pred = clamped_gradient( + prediction_data.left as i64, + prediction_data.top as i64, + prediction_data.topleft as i64, + ); let dec = reader.read_signed_clustered_inline(histograms, br, self.clustered_ctx); - make_pixel(dec, 1, 
pred) + dec.wrapping_add(pred as i32) } } @@ -359,9 +680,6 @@ pub struct NoTree { } impl ModularChannelDecoder for NoTree { - const NEEDS_TOP: bool = false; - const NEEDS_TOPTOP: bool = false; - fn init_row(&mut self, _: &mut [&mut ModularChannel], _: usize, _: usize) {} #[inline(always)] @@ -383,10 +701,14 @@ impl ModularChannelDecoder for NoTree { pub enum TreeSpecialCase { NoTree(NoTree), NoWp(NoWpTree), + NoWpNoLz77(NoWpTreeNoLz77), + NoWpConfig420(NoWpTreeConfig420), WpOnlyConfig420(WpOnlyLookupConfig420), GradientLookupConfig420(GradientLookupConfig420), SingleGradientOnly(SingleGradientOnly), General(GeneralTree), + GeneralNoLz77(GeneralTreeNoLz77), + GeneralConfig420(GeneralTreeConfig420), } pub fn specialize_tree( @@ -496,6 +818,28 @@ pub fn specialize_tree( if let Some(gl) = make_gradient_lut_config_420(&pruned_tree, &tree.histograms) { return Ok(TreeSpecialCase::GradientLookupConfig420(gl)); } + // Use config_420 fast path when all entropy configs are 420 (no LZ77) + if tree.histograms.can_use_config_420_fast_path() { + return Ok(TreeSpecialCase::NoWpConfig420(NoWpTreeConfig420( + NoWpTree::new( + pruned_tree, + tree.max_property_count(), + channel, + stream, + xsize, + )?, + ))); + } + // Use no-LZ77 fast path when LZ77 is disabled + if tree.histograms.has_no_lz77() { + return Ok(TreeSpecialCase::NoWpNoLz77(NoWpTreeNoLz77(NoWpTree::new( + pruned_tree, + tree.max_property_count(), + channel, + stream, + xsize, + )?))); + } return Ok(TreeSpecialCase::NoWp(NoWpTree::new( pruned_tree, tree.max_property_count(), @@ -505,6 +849,32 @@ pub fn specialize_tree( )?)); } + // Use config_420 fast path for general (WP) trees when all entropy configs are 420 + if tree.histograms.can_use_config_420_fast_path() { + return Ok(TreeSpecialCase::GeneralConfig420( + GeneralTreeConfig420::new( + pruned_tree, + tree.max_property_count(), + header, + channel, + stream, + xsize, + )?, + )); + } + + // Use no-LZ77 fast path when LZ77 is disabled but configs aren't all 420 + if 
tree.histograms.has_no_lz77() { + return Ok(TreeSpecialCase::GeneralNoLz77(GeneralTreeNoLz77::new( + pruned_tree, + tree.max_property_count(), + header, + channel, + stream, + xsize, + )?)); + } + Ok(TreeSpecialCase::General(GeneralTree::new( pruned_tree, tree.max_property_count(), diff --git a/jxl/src/frame/modular/mod.rs b/jxl/src/frame/modular/mod.rs index c5d665f32..860cf2c12 100644 --- a/jxl/src/frame/modular/mod.rs +++ b/jxl/src/frame/modular/mod.rs @@ -142,13 +142,19 @@ impl ModularChannel { Self::new_with_shift(size, Some((0, 0)), bit_depth) } + #[allow(unsafe_code)] fn new_with_shift( size: (usize, usize), shift: Option<(usize, usize)>, bit_depth: BitDepth, ) -> Result { + // SAFETY: The modular decode loop (decode_modular_channel_impl) fully writes + // every pixel in the data region. The padding region is zeroed for correct + // boundary behavior in prediction. + let data = + unsafe { Image::new_uninit_with_zeroed_padding(size, IMAGE_OFFSET, IMAGE_PADDING)? }; Ok(ModularChannel { - data: Image::new_with_padding(size, IMAGE_OFFSET, IMAGE_PADDING)?, + data, auxiliary_data: None, shift, bit_depth, diff --git a/jxl/src/frame/modular/predict.rs b/jxl/src/frame/modular/predict.rs index d5907c373..56ef16583 100644 --- a/jxl/src/frame/modular/predict.rs +++ b/jxl/src/frame/modular/predict.rs @@ -9,11 +9,9 @@ use crate::{ image::Image, util::floor_log2_nonzero, }; -use num_derive::FromPrimitive; -use num_traits::FromPrimitive; #[repr(u8)] -#[derive(Debug, FromPrimitive, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Predictor { Zero = 0, West = 1, @@ -43,11 +41,29 @@ impl Predictor { } } +impl Predictor { + /// Fast conversion from u32. Since Predictor is #[repr(u8)] with + /// contiguous values 0..=13, a bounds check + transmute replaces + /// the num_derive FromPrimitive chain (from_u32 -> from_u64 -> from_i64). 
+ #[inline(always)] + #[allow(unsafe_code)] + pub fn from_u32_fast(value: u32) -> Option { + if value <= Predictor::AverageAll as u32 { + // SAFETY: Predictor is #[repr(u8)] with contiguous discriminants 0..=13, + // and we verified value is in range. + Some(unsafe { std::mem::transmute::(value as u8) }) + } else { + None + } + } +} + impl TryFrom for Predictor { type Error = Error; + #[inline(always)] fn try_from(value: u32) -> Result { - Self::from_u32(value).ok_or(Error::InvalidPredictor(value)) + Self::from_u32_fast(value).ok_or(Error::InvalidPredictor(value)) } } @@ -63,36 +79,29 @@ pub struct PredictionData { } impl PredictionData { - #[inline] + #[inline(always)] + #[allow(unsafe_code)] pub fn update_for_interior_row( - self, + &mut self, row_top: &[i32], row_toptop: &[i32], x: usize, cur: i32, - needs_top: bool, - needs_toptop: bool, - ) -> PredictionData { + ) { debug_assert!(x > 1); debug_assert!(x + 2 < row_top.len()); - let left = cur; - let top = self.topright; - let topleft = self.top; - let topright = self.toprightright; - let leftleft = self.left; - let toptop = if needs_toptop { row_toptop[x] } else { 0 }; - let toprightright = if needs_top { row_top[x + 2] } else { 0 }; - Self { - left, - top, - toptop, - topleft, - topright, - leftleft, - toprightright, - } + self.leftleft = self.left; + self.topleft = self.top; + self.top = self.topright; + self.topright = self.toprightright; + self.left = cur; + // SAFETY: x < row_toptop.len() because x < width and row_toptop has width elements. + self.toptop = unsafe { *row_toptop.get_unchecked(x) }; + // SAFETY: x + 2 < row_top.len() asserted above. 
+ self.toprightright = unsafe { *row_top.get_unchecked(x + 2) }; } + #[inline(always)] pub fn get_rows(row: &[i32], row_top: &[i32], row_toptop: &[i32], x: usize, y: usize) -> Self { let left = if x > 0 { row[x - 1] @@ -126,6 +135,7 @@ impl PredictionData { } } + #[inline(always)] pub fn get(rect: &Image, x: usize, y: usize) -> Self { Self::get_rows( rect.row(y), @@ -136,6 +146,7 @@ impl PredictionData { ) } + #[inline(always)] #[allow(clippy::too_many_arguments)] pub fn get_with_neighbors( rect: &Image, @@ -250,6 +261,7 @@ impl PredictionData { } } +#[inline(always)] pub fn clamped_gradient(left: i64, top: i64, topleft: i64) -> i64 { // Same code/logic as libjxl. let min = left.min(top); @@ -262,7 +274,7 @@ pub fn clamped_gradient(left: i64, top: i64, topleft: i64) -> i64 { impl Predictor { pub const NUM_PREDICTORS: u32 = Predictor::AverageAll as u32 + 1; - #[inline] + #[inline(always)] pub fn predict_one( &self, PredictionData { @@ -302,6 +314,7 @@ impl Predictor { } } + #[inline(always)] fn select(left: i64, top: i64, topleft: i64) -> i64 { let p = left + top - topleft; if (p - left).abs() < (p - top).abs() { @@ -332,29 +345,46 @@ fn add_bits(x: i32) -> i64 { } #[inline(always)] +#[allow(unsafe_code)] fn error_weight(x: u32, maxweight: u32) -> u32 { - let shift = floor_log2_nonzero(x as u64 + 1) as i32 - 5; - if shift < 0 { - 4u32 + maxweight * DIVLOOKUP[x as usize & 63] - } else { - 4u32 + ((maxweight * DIVLOOKUP[(x as usize >> shift) & 63]) >> shift) - } + // Branchless version matching libjxl: clamp shift to 0 instead of branching + // Use u32 lzcnt directly (avoids zero-extend to u64) + let log2 = 31u32 ^ (x + 1).leading_zeros(); + let shift = (log2 as i32 - 5).max(0) as u32; + // SAFETY: x >> shift is < 64 because: + // - if shift == 0: log2(x+1) <= 5, so x + 1 < 64 and x < 64 + // - if shift > 0: x < 2^(log2+1) and shift == log2 - 5, so x >> shift < 2^6 = 64 + 4u32 + ((maxweight * unsafe { *DIVLOOKUP.get_unchecked((x >> shift) as usize) }) >> shift) } #[inline(always)] +#[allow(unsafe_code)]
fn weighted_average(pixels: &[i64; NUM_PREDICTORS], weights: &mut [u32; NUM_PREDICTORS]) -> i64 { - let log_weight = floor_log2_nonzero(weights.iter().fold(0u64, |sum, el| sum + *el as u64)); - let weight_sum = weights.iter_mut().fold(0, |sum, el| { - *el >>= log_weight - 4; - sum + *el - }); - let sum = weights - .iter() - .enumerate() - .fold(((weight_sum >> 1) - 1) as i64, |sum, (i, weight)| { - sum + pixels[i] * *weight as i64 - }); - (sum * DIVLOOKUP[(weight_sum - 1) as usize] as i64) >> 24 + // Sum weights as u32 (always fits in practice). + let sum32 = weights[0] + .wrapping_add(weights[1]) + .wrapping_add(weights[2]) + .wrapping_add(weights[3]); + let log_weight = if sum32 > 0 { + 31u32 ^ sum32.leading_zeros() + } else { + floor_log2_nonzero(weights.iter().fold(0u64, |sum, el| sum + *el as u64)) + }; + let shift = log_weight - 4; + weights[0] >>= shift; + weights[1] >>= shift; + weights[2] >>= shift; + weights[3] >>= shift; + let weight_sum = weights[0] + weights[1] + weights[2] + weights[3]; + let sum = ((weight_sum >> 1) - 1) as i64 + + pixels[0] * weights[0] as i64 + + pixels[1] * weights[1] as i64 + + pixels[2] * weights[2] as i64 + + pixels[3] * weights[3] as i64; + debug_assert!((weight_sum - 1) < 64, "weight_sum={}", weight_sum); + // SAFETY: weight_sum <= 64 after the shift-right by (log_weight - 4). + let div = unsafe { *DIVLOOKUP.get_unchecked((weight_sum - 1) as usize) }; + (sum * div as i64) >> 24 } #[derive(Debug)] @@ -366,11 +396,22 @@ pub struct WeightedPredictorState { pred_errors_buffer: Vec, error: Vec, wp_header: WeightedHeader, + /// Pre-computed max weights from wp_header, avoids match+unwrap per iteration. 
+ maxweights: [u32; NUM_PREDICTORS], + /// Precomputed row offsets (set once per row via set_row) + cur_row: usize, + prev_row: usize, } impl WeightedPredictorState { pub fn new(wp_header: &WeightedHeader, xsize: usize) -> WeightedPredictorState { let num_errors = (xsize + 2) * 2; + let maxweights = [ + wp_header.w(0).unwrap(), + wp_header.w(1).unwrap(), + wp_header.w(2).unwrap(), + wp_header.w(3).unwrap(), + ]; WeightedPredictorState { prediction: [0; NUM_PREDICTORS], pred: 0, @@ -380,25 +421,33 @@ impl WeightedPredictorState { pred_errors_buffer: vec![0; num_errors * NUM_PREDICTORS], error: vec![0; num_errors], wp_header: wp_header.clone(), + maxweights, + cur_row: 0, + prev_row: 0, } } /// Get all predictor errors for a given position (contiguous in memory) #[inline(always)] + #[allow(unsafe_code)] fn get_errors_at_pos(&self, pos: usize) -> &[u32; NUM_PREDICTORS] { let start = pos * NUM_PREDICTORS; - self.pred_errors_buffer[start..start + NUM_PREDICTORS] - .try_into() - .unwrap() + debug_assert!(start + NUM_PREDICTORS <= self.pred_errors_buffer.len()); + // SAFETY: start + NUM_PREDICTORS <= buffer.len() because pos < num_errors + // (which is (xsize+2)*2), and buffer.len() = num_errors * NUM_PREDICTORS. + unsafe { &*(self.pred_errors_buffer.as_ptr().add(start) as *const [u32; NUM_PREDICTORS]) } } /// Get mutable reference to all predictor errors for a given position #[inline(always)] + #[allow(unsafe_code, dead_code)] fn get_errors_at_pos_mut(&mut self, pos: usize) -> &mut [u32; NUM_PREDICTORS] { let start = pos * NUM_PREDICTORS; - (&mut self.pred_errors_buffer[start..start + NUM_PREDICTORS]) - .try_into() - .unwrap() + debug_assert!(start + NUM_PREDICTORS <= self.pred_errors_buffer.len()); + // SAFETY: same invariant as get_errors_at_pos. 
+ unsafe { + &mut *(self.pred_errors_buffer.as_mut_ptr().add(start) as *mut [u32; NUM_PREDICTORS]) + } } pub fn save_state(&self, wp_image: &mut Image, xsize: usize) { @@ -411,44 +460,223 @@ impl WeightedPredictorState { self.error[xsize + 2..].copy_from_slice(wp_image.row(0)); } + /// Precompute row offsets for the current y. Call once per row before the x loop. #[inline(always)] - pub fn update_errors(&mut self, correct_val: i32, pos: (usize, usize), xsize: usize) { - let (cur_row, prev_row) = if pos.1 & 1 != 0 { - (0, xsize + 2) + pub fn set_row(&mut self, y: usize, xsize: usize) { + if y & 1 != 0 { + self.cur_row = 0; + self.prev_row = xsize + 2; } else { - (xsize + 2, 0) - }; - let val = add_bits(correct_val); - self.error[cur_row + pos.0] = (self.pred - val) as i32; - - // Compute errors for all predictors - let mut errs = [0u32; NUM_PREDICTORS]; - for (err, &pred) in errs.iter_mut().zip(self.prediction.iter()) { - *err = (((pred - val).abs() + PREDICTION_ROUND) >> PRED_EXTRA_BITS) as u32; + self.cur_row = xsize + 2; + self.prev_row = 0; } + } - // Write to current position (contiguous access) - *self.get_errors_at_pos_mut(cur_row + pos.0) = errs; - - // Update previous row position (contiguous access) - let prev_errors = self.get_errors_at_pos_mut(prev_row + pos.0 + 1); - for i in 0..NUM_PREDICTORS { - prev_errors[i] = prev_errors[i].wrapping_add(errs[i]); + #[allow(unsafe_code)] + pub fn update_errors(&mut self, correct_val: i32, pos: (usize, usize), xsize: usize) { + let _ = xsize; // xsize now precomputed in set_row + let cur_row = self.cur_row; + let prev_row = self.prev_row; + let val = add_bits(correct_val); + // SAFETY: cur_row + pos.0 < (xsize+2)*2 = error.len() since pos.0 < xsize + // and cur_row is either 0 or xsize+2. + unsafe { *self.error.get_unchecked_mut(cur_row + pos.0) = (self.pred - val) as i32 }; + + // Compute errors for all predictors and write to cur + accumulate to prev in one pass. 
+ // Unrolled to avoid loop overhead and help register allocation. + let cur_start = (cur_row + pos.0) * NUM_PREDICTORS; + let prev_start = (prev_row + pos.0 + 1) * NUM_PREDICTORS; + debug_assert!(cur_start + NUM_PREDICTORS <= self.pred_errors_buffer.len()); + debug_assert!(prev_start + NUM_PREDICTORS <= self.pred_errors_buffer.len()); + let buf = self.pred_errors_buffer.as_mut_ptr(); + // SAFETY: the debug_assert!s above only run in debug builds; release relies on pos.0 < xsize + // with cur_row/prev_row in {0, xsize + 2}, which keeps both ranges inside pred_errors_buffer. + unsafe { + let e0 = + (((self.prediction[0] - val).abs() + PREDICTION_ROUND) >> PRED_EXTRA_BITS) as u32; + let e1 = + (((self.prediction[1] - val).abs() + PREDICTION_ROUND) >> PRED_EXTRA_BITS) as u32; + let e2 = + (((self.prediction[2] - val).abs() + PREDICTION_ROUND) >> PRED_EXTRA_BITS) as u32; + let e3 = + (((self.prediction[3] - val).abs() + PREDICTION_ROUND) >> PRED_EXTRA_BITS) as u32; + // Write current position + *buf.add(cur_start) = e0; + *buf.add(cur_start + 1) = e1; + *buf.add(cur_start + 2) = e2; + *buf.add(cur_start + 3) = e3; + // Accumulate to previous row + *buf.add(prev_start) = (*buf.add(prev_start)).wrapping_add(e0); + *buf.add(prev_start + 1) = (*buf.add(prev_start + 1)).wrapping_add(e1); + *buf.add(prev_start + 2) = (*buf.add(prev_start + 2)).wrapping_add(e2); + *buf.add(prev_start + 3) = (*buf.add(prev_start + 3)).wrapping_add(e3); } } #[inline(always)] + #[allow(unsafe_code)] pub fn predict_and_property( &mut self, pos: (usize, usize), xsize: usize, data: &PredictionData, ) -> (i64, i32) { - let (cur_row, prev_row) = if pos.1 & 1 != 0 { - (0, xsize + 2) + self.predict_impl(pos, xsize, data, true) + } + + /// Predict without computing the WP property (property 15). + /// Use when the tree doesn't split on property 15.
+ #[inline(always)] + #[allow(unsafe_code)] + pub fn predict_no_property( + &mut self, + pos: (usize, usize), + xsize: usize, + data: &PredictionData, + ) -> i64 { + self.predict_impl(pos, xsize, data, false).0 + } + + /// Interior predict with WP property. No edge checks. + #[inline(always)] + #[allow(unsafe_code)] + pub fn predict_and_property_interior( + &mut self, + x: usize, + xsize: usize, + data: &PredictionData, + ) -> (i64, i32) { + self.predict_interior::<true>(x, xsize, data) + } + + /// Interior predict without WP property. No edge checks. + #[inline(always)] + #[allow(unsafe_code)] + pub fn predict_no_property_interior( + &mut self, + x: usize, + xsize: usize, + data: &PredictionData, + ) -> i64 { + self.predict_interior::<false>(x, xsize, data).0 + } + + /// Interior version: no edge checks. x > 0 and x < xsize-1 guaranteed. + #[inline(always)] + #[allow(unsafe_code)] + pub fn predict_interior<const COMPUTE_PROPERTY: bool>( + &mut self, + x: usize, + _xsize: usize, + data: &PredictionData, + ) -> (i64, i32) { + let cur_row = self.cur_row; + let prev_row = self.prev_row; + let pos_n = prev_row + x; + // No edge checks: x > 0 and x < xsize-1 guaranteed in interior + let pos_ne = pos_n + 1; + let pos_nw = pos_n - 1; + // Direct pointer access to error buffers -- avoids get_errors_at_pos overhead + let base = self.pred_errors_buffer.as_ptr(); + let off_n = pos_n * NUM_PREDICTORS; + let off_ne = pos_ne * NUM_PREDICTORS; + let off_nw = pos_nw * NUM_PREDICTORS; + let mut weights = [0u32; NUM_PREDICTORS]; + // SAFETY: pos_n/pos_ne/pos_nw are valid row positions, so each computed offset + // points to 4 contiguous predictor-error entries within pred_errors_buffer.
+ unsafe { + weights[0] = error_weight( + (*base.add(off_n)) + .wrapping_add(*base.add(off_ne)) + .wrapping_add(*base.add(off_nw)), + self.maxweights[0], + ); + weights[1] = error_weight( + (*base.add(off_n + 1)) + .wrapping_add(*base.add(off_ne + 1)) + .wrapping_add(*base.add(off_nw + 1)), + self.maxweights[1], + ); + weights[2] = error_weight( + (*base.add(off_n + 2)) + .wrapping_add(*base.add(off_ne + 2)) + .wrapping_add(*base.add(off_nw + 2)), + self.maxweights[2], + ); + weights[3] = error_weight( + (*base.add(off_n + 3)) + .wrapping_add(*base.add(off_ne + 3)) + .wrapping_add(*base.add(off_nw + 3)), + self.maxweights[3], + ); + } + let n = add_bits(data.top); + let w = add_bits(data.left); + let ne = add_bits(data.topright); + let nw = add_bits(data.topleft); + let nn = add_bits(data.toptop); + + let err_base = self.error.as_ptr(); + // SAFETY: x > 0 in the interior path, and pos_n/pos_nw/pos_ne are all valid + // indexes within error.len() = (xsize + 2) * 2. + let (te_w, te_n, te_nw, te_ne) = unsafe { + ( + *err_base.add(cur_row + x - 1) as i64, + *err_base.add(pos_n) as i64, + *err_base.add(pos_nw) as i64, + *err_base.add(pos_ne) as i64, + ) + }; + let sum_wn = te_n + te_w; + + let p = if COMPUTE_PROPERTY { + let mut p = te_w; + if te_n.abs() > p.abs() { + p = te_n; + } + if te_nw.abs() > p.abs() { + p = te_nw; + } + if te_ne.abs() > p.abs() { + p = te_ne; + } + p } else { - (xsize + 2, 0) + 0 }; + + self.prediction[0] = w + ne - n; + self.prediction[1] = n - (((sum_wn + te_ne) * self.wp_header.p1c as i64) >> 5); + self.prediction[2] = w - (((sum_wn + te_nw) * self.wp_header.p2c as i64) >> 5); + self.prediction[3] = n + - ((te_nw * (self.wp_header.p3ca as i64) + + (te_n * (self.wp_header.p3cb as i64)) + + (te_ne * (self.wp_header.p3cc as i64)) + + ((nn - n) * (self.wp_header.p3cd as i64)) + + ((nw - w) * (self.wp_header.p3ce as i64))) + >> 5); + + self.pred = weighted_average(&self.prediction, &mut weights); + + if ((te_n ^ te_w) | (te_n ^ te_nw)) <= 0 { + 
let mx = w.max(ne.max(n)); + let mn = w.min(ne.min(n)); + self.pred = mn.max(mx.min(self.pred)); + } + ((self.pred + PREDICTION_ROUND) >> PRED_EXTRA_BITS, p as i32) + } + + #[inline(always)] + #[allow(unsafe_code)] + fn predict_impl( + &mut self, + pos: (usize, usize), + xsize: usize, + data: &PredictionData, + compute_property: bool, + ) -> (i64, i32) { + let cur_row = self.cur_row; + let prev_row = self.prev_row; let pos_n = prev_row + pos.0; let pos_ne = if pos.0 < xsize - 1 { pos_n + 1 } else { pos_n }; let pos_nw = if pos.0 > 0 { pos_n - 1 } else { pos_n }; @@ -463,7 +691,7 @@ impl WeightedPredictorState { errors_n[i] .wrapping_add(errors_ne[i]) .wrapping_add(errors_nw[i]), - self.wp_header.w(i).unwrap(), + self.maxweights[i], ); } let n = add_bits(data.top); @@ -475,23 +703,35 @@ impl WeightedPredictorState { let te_w = if pos.0 == 0 { 0 } else { - self.error[cur_row + pos.0 - 1] as i64 + // SAFETY: when pos.0 > 0, cur_row + pos.0 - 1 is a valid error index. + unsafe { *self.error.get_unchecked(cur_row + pos.0 - 1) as i64 } + }; + // SAFETY: pos_n/pos_nw/pos_ne are clamped to valid neighborhood positions + // in error.len() = (xsize + 2) * 2. 
+ let (te_n, te_nw, te_ne) = unsafe { + ( + *self.error.get_unchecked(pos_n) as i64, + *self.error.get_unchecked(pos_nw) as i64, + *self.error.get_unchecked(pos_ne) as i64, + ) }; - let te_n = self.error[pos_n] as i64; - let te_nw = self.error[pos_nw] as i64; let sum_wn = te_n + te_w; - let te_ne = self.error[pos_ne] as i64; - let mut p = te_w; - if te_n.abs() > p.abs() { - p = te_n; - } - if te_nw.abs() > p.abs() { - p = te_nw; - } - if te_ne.abs() > p.abs() { - p = te_ne; - } + let p = if compute_property { + let mut p = te_w; + if te_n.abs() > p.abs() { + p = te_n; + } + if te_nw.abs() > p.abs() { + p = te_nw; + } + if te_ne.abs() > p.abs() { + p = te_ne; + } + p + } else { + 0 + }; self.prediction[0] = w + ne - n; self.prediction[1] = n - (((sum_wn + te_ne) * self.wp_header.p1c as i64) >> 5); @@ -542,6 +782,7 @@ mod tests { ysize: usize, ) -> (i64, i32) { let pos = (rng.next() as usize % xsize, rng.next() as usize % ysize); + state.set_row(pos.1, xsize); let res = state.predict_and_property( pos, xsize, diff --git a/jxl/src/frame/modular/transforms/apply.rs b/jxl/src/frame/modular/transforms/apply.rs index 2c74441aa..214b5984e 100644 --- a/jxl/src/frame/modular/transforms/apply.rs +++ b/jxl/src/frame/modular/transforms/apply.rs @@ -562,7 +562,7 @@ fn meta_apply_single_transform( let num_channels = transform.num_channels as usize; let num_colors = transform.num_colors as usize; let num_deltas = transform.num_deltas as usize; - let pred = Predictor::from_u32(transform.predictor_id) + let pred = Predictor::from_u32_fast(transform.predictor_id) .expect("header decoding should ensure a valid predictor"); check_equal_channels(channels, begin_channel, num_channels)?; // We already checked the bit_depth for all channels from `begin_channel` is diff --git a/jxl/src/frame/modular/transforms/palette.rs b/jxl/src/frame/modular/transforms/palette.rs index ae7c71103..88df75238 100644 --- a/jxl/src/frame/modular/transforms/palette.rs +++ 
b/jxl/src/frame/modular/transforms/palette.rs @@ -24,6 +24,7 @@ const SMALL_CUBE_BITS: usize = 2; // SMALL_CUBE ** 3 const LARGE_CUBE_OFFSET: usize = SMALL_CUBE * SMALL_CUBE * SMALL_CUBE; +#[inline(always)] fn scale(value: usize, bit_depth: usize) -> i32 { // return (value * ((1 << bit_depth) - 1)) / DENOM; // We only call this function with SMALL_CUBE or LARGE_CUBE - 1 as DENOM, @@ -37,9 +38,10 @@ fn scale(value: usize, bit_depth: usize) -> i32 { // The purpose of this function is solely to extend the interpretation of // palette indices to implicit values. If index < nb_deltas, indicating that the // result is a delta palette entry, it is the responsibility of the caller to -// treat it as such. -fn get_palette_value( - palette: &Image, +/// Look up palette value. `pal_row` is the pre-fetched palette row for channel `c`. +#[inline(always)] +fn get_palette_value_with_row( + pal_row: &[i32], index: isize, c: usize, palette_size: usize, @@ -161,11 +163,12 @@ fn get_palette_value( } scale::<{ LARGE_CUBE - 1 }>(index % LARGE_CUBE, bit_depth) } else { - palette.row(c)[index] + pal_row[index] } } } +#[inline(always)] pub fn do_palette_step_general( buf_in: &ModularChannel, buf_pal: &ModularChannel, @@ -184,31 +187,39 @@ pub fn do_palette_step_general( // Avoid touching "empty" channels with non-zero height. 
} else if num_deltas == 0 && predictor == Predictor::Zero { for (chan_index, out) in buf_out.iter_mut().enumerate() { + let pal_row = palette.row(chan_index); for y in 0..h { let row_index = buf_in.data.row(y); let row_out = out.data.row_mut(y); + #[allow(unsafe_code)] for x in 0..w { let index = row_index[x]; - let palette_value = get_palette_value( - palette, - index as isize, - /*c=*/ chan_index, - /*palette_size=*/ num_colors, - /*bit_depth=*/ bit_depth, - ); - row_out[x] = palette_value; + let idx = index as usize; + if idx < num_colors { + // SAFETY: idx < num_colors <= pal_row.len() + row_out[x] = unsafe { *pal_row.get_unchecked(idx) }; + } else { + row_out[x] = get_palette_value_with_row( + pal_row, + index as isize, + chan_index, + num_colors, + bit_depth, + ); + } } } } } else if predictor == Predictor::Weighted { let w = buf_in.data.size().0; for (chan_index, out) in buf_out.iter_mut().enumerate() { + let pal_row = palette.row(chan_index); let mut wp_state = WeightedPredictorState::new(wp_header, w); for y in 0..h { let idx = buf_in.data.row(y); for (x, &index) in idx.iter().enumerate() { - let palette_entry = get_palette_value( - palette, + let palette_entry = get_palette_value_with_row( + pal_row, index as isize, /*c=*/ chan_index, /*palette_size=*/ num_colors + num_deltas, @@ -230,11 +241,12 @@ pub fn do_palette_step_general( } } else { for (chan_index, out) in buf_out.iter_mut().enumerate() { + let pal_row = palette.row(chan_index); for y in 0..h { let idx = buf_in.data.row(y); for (x, &index) in idx.iter().enumerate() { - let palette_entry = get_palette_value( - palette, + let palette_entry = get_palette_value_with_row( + pal_row, index as isize, /*c=*/ chan_index, /*palette_size=*/ num_colors + num_deltas, @@ -254,6 +266,7 @@ pub fn do_palette_step_general( } } +#[inline(always)] #[allow(clippy::too_many_arguments)] fn get_prediction_data( buf: &mut [&mut ModularChannel], @@ -300,6 +313,7 @@ fn get_prediction_data( ) } +#[inline(always)] 
#[allow(clippy::too_many_arguments)] pub fn do_palette_step_one_group( buf_in: &ModularChannel, @@ -319,35 +333,74 @@ pub fn do_palette_step_one_group( let num_c = buf_out.len() / (grid_xsize * grid_ysize); let (xsize, ysize) = buf_out[0].data.size(); - for c in 0..num_c { - for y in 0..h { - let index_img = buf_in.data.row(y); + let palette_size = num_colors + num_deltas; + + if num_deltas == 0 { + // Fast path: no delta palette entries, just direct lookups. + // Avoids prediction data computation entirely. + for c in 0..num_c { + let pal_row = palette.row(c); let out_idx = c * grid_ysize * grid_xsize + grid_y * grid_xsize + grid_x; - for (x, &index) in index_img.iter().enumerate() { - let palette_entry = get_palette_value( - palette, - index as isize, - c, - /*palette_size=*/ num_colors + num_deltas, - /*bit_depth=*/ bit_depth, - ); - let val = if index < num_deltas as i32 { - let pred = predictor.predict_one( - get_prediction_data( - buf_out, out_idx, grid_x, grid_y, grid_xsize, x, y, xsize, ysize, - ), - /*wp_pred=*/ 0, + for y in 0..h { + let index_img = buf_in.data.row(y); + let out_row = buf_out[out_idx].data.row_mut(y); + #[allow(unsafe_code)] + for (x, &index) in index_img.iter().enumerate() { + // Fast path: direct palette lookup for valid indices (common case). + // Skip the multi-branch get_palette_value_with_row for the hot path. + let idx = index as usize; + if idx < palette_size { + // SAFETY: idx < palette_size <= pal_row.len() (palette is at least + // palette_size wide, validated during palette transform setup). 
+ out_row[x] = unsafe { *pal_row.get_unchecked(idx) }; + } else { + // Rare case: implicit color cube or negative index + out_row[x] = get_palette_value_with_row( + pal_row, + index as isize, + c, + palette_size, + bit_depth, + ); + } + } + } + } + } else { + for c in 0..num_c { + let pal_row = palette.row(c); + let out_idx = c * grid_ysize * grid_xsize + grid_y * grid_xsize + grid_x; + for y in 0..h { + let index_img = buf_in.data.row(y); + for (x, &index) in index_img.iter().enumerate() { + let palette_entry = get_palette_value_with_row( + pal_row, + index as isize, + c, + palette_size, + bit_depth, ); - (pred + palette_entry as i64) as i32 - } else { - palette_entry - }; - buf_out[out_idx].data.row_mut(y)[x] = val; + let val = if index < num_deltas as i32 { + // Delta palette prediction may need cross-grid neighbors. + // Always use get_prediction_data to preserve exact behavior. + let pred = predictor.predict_one( + get_prediction_data( + buf_out, out_idx, grid_x, grid_y, grid_xsize, x, y, xsize, ysize, + ), + /*wp_pred=*/ 0, + ); + (pred + palette_entry as i64) as i32 + } else { + palette_entry + }; + buf_out[out_idx].data.row_mut(y)[x] = val; + } } } } } +#[inline(always)] #[allow(clippy::too_many_arguments)] pub fn do_palette_step_group_row( buf_in: &[&ModularChannel], @@ -371,8 +424,10 @@ pub fn do_palette_step_group_row( .sum(); let (xsize, ysize) = buf_out[0].data.size(); + let palette_size = num_colors + num_deltas; if predictor == Predictor::Weighted { for c in 0..num_c { + let pal_row = palette.row(c); let mut wp_state = WeightedPredictorState::new(wp_header, total_w); let out_row_idx = c * grid_ysize * grid_xsize + grid_y * grid_xsize; if grid_y > 0 { @@ -387,14 +442,15 @@ pub fn do_palette_step_group_row( let index_img = index_buf.data.row(y); let out_idx = out_row_idx + grid_x; for (x, &index) in index_img.iter().enumerate() { - let palette_entry = get_palette_value( - palette, + let palette_entry = get_palette_value_with_row( + pal_row, index as 
isize, c, - /*palette_size=*/ num_colors + num_deltas, - /*bit_depth=*/ bit_depth, + palette_size, + bit_depth, ); let val = if index < num_deltas as i32 { + // Delta palette prediction may need cross-grid neighbors. let prediction_data = get_prediction_data( buf_out, out_idx, grid_x, grid_y, grid_xsize, x, y, xsize, ysize, ); @@ -418,19 +474,21 @@ pub fn do_palette_step_group_row( } } else { for c in 0..num_c { + let pal_row = palette.row(c); for y in 0..h { for (grid_x, index_buf) in buf_in.iter().enumerate().take(grid_xsize) { let index_img = index_buf.data.row(y); let out_idx = c * grid_ysize * grid_xsize + grid_y * grid_xsize + grid_x; for (x, &index) in index_img.iter().enumerate() { - let palette_entry = get_palette_value( - palette, + let palette_entry = get_palette_value_with_row( + pal_row, index as isize, c, - /*palette_size=*/ num_colors + num_deltas, - /*bit_depth=*/ bit_depth, + palette_size, + bit_depth, ); let val = if index < num_deltas as i32 { + // Delta palette prediction may need cross-grid neighbors. let pred = predictor.predict_one( get_prediction_data( buf_out, out_idx, grid_x, grid_y, grid_xsize, x, y, xsize, diff --git a/jxl/src/frame/modular/transforms/squeeze.rs b/jxl/src/frame/modular/transforms/squeeze.rs index 5ec330b39..ab1532f02 100644 --- a/jxl/src/frame/modular/transforms/squeeze.rs +++ b/jxl/src/frame/modular/transforms/squeeze.rs @@ -453,6 +453,7 @@ simd_function!( } ); +#[inline(always)] pub fn do_hsqueeze_step( in_avg: &ImageRect<'_, i32>, in_res: &ImageRect<'_, i32>, @@ -648,6 +649,7 @@ simd_function!( } ); +#[inline(always)] pub fn do_vsqueeze_step( in_avg: &ImageRect<'_, i32>, in_res: &ImageRect<'_, i32>, diff --git a/jxl/src/frame/modular/tree.rs b/jxl/src/frame/modular/tree.rs index b5f0022a9..50f3e7036 100644 --- a/jxl/src/frame/modular/tree.rs +++ b/jxl/src/frame/modular/tree.rs @@ -34,12 +34,12 @@ pub enum TreeNode { /// Flattened tree node for optimized traversal (matches C++ FlatDecisionNode). 
/// Stores parent + info about both children to evaluate 3 nodes per iteration. -// TODO(hjanuschka): investigate performance of using a Rust enum here, and whether -// separating internal nodes and leaves into two arrays could save a branch. +/// Leaf nodes store `Predictor` directly to avoid runtime integer-to-enum conversion. #[derive(Debug, Clone, Copy)] pub(super) struct FlatTreeNode { property0: i32, // Property to test, -1 if leaf - splitval0_or_predictor: i32, // Split value, or predictor if leaf + splitval0: i32, // Split value (only used for split nodes) + predictor: Predictor, // Stored for leaf nodes (avoids runtime conversion) splitvals_or_multiplier: [i32; 2], // Child splitvals, or multiplier if leaf child_id: u32, // Index to first grandchild, or context if leaf properties_or_offset: [i16; 2], // Child properties, or offset if leaf @@ -50,12 +50,31 @@ impl FlatTreeNode { fn leaf(predictor: Predictor, offset: i32, multiplier: u32, context: u32) -> Self { Self { property0: -1, - splitval0_or_predictor: predictor as i32, + splitval0: 0, + predictor, splitvals_or_multiplier: [multiplier as i32, 0], child_id: context, properties_or_offset: [offset as i16, 0], } } + + #[inline] + fn split( + property0: i32, + splitval0: i32, + child_id: u32, + properties: [i16; 2], + splitvals: [i32; 2], + ) -> Self { + Self { + property0, + splitval0, + predictor: Predictor::Zero, // unused for split nodes + splitvals_or_multiplier: splitvals, + child_id, + properties_or_offset: properties, + } + } } pub struct Tree { @@ -210,18 +229,20 @@ const NUM_TREE_CONTEXTS: usize = 6; // Also, the first two properties (the static properties) should be already set by the caller. // All other properties should be 0 on the first call in a row. -/// Computes properties for tree traversal. Shared between flat and non-flat prediction. -/// Returns the weighted predictor prediction value. 
-#[inline] -fn compute_properties( +/// Computes all properties needed by the tree, guided by `used_mask`. +/// Bit `i` in `used_mask` means property `i` is referenced by some split node. +/// Properties not in the mask are skipped. +/// Note: property 9 (local gradient) must always be computed if property 8 is used, +/// since property 8 depends on the *previous* value of property 9. +#[inline(always)] +fn compute_properties_common( prediction_data: PredictionData, - xsize: usize, - wp_state: Option<&mut WeightedPredictorState>, - x: usize, - y: usize, references: &Image, property_buffer: &mut [i32], -) -> i64 { + x: usize, + y: usize, + used_mask: u32, +) { let PredictionData { left, top, @@ -232,26 +253,169 @@ fn compute_properties( toprightright: _, } = prediction_data; - // Position - property_buffer[2] = y as i32; - property_buffer[3] = x as i32; + // Properties 0,1 (channel, stream) are set once in init, never change. + + // Only compute properties that the tree actually splits on. + // used_mask is constant per channel, so branches predict perfectly. + // TESTED: unconditional (all 13 stores) was -7.3% (cache pressure). + // TESTED: hybrid (7 unconditional + 6 gated) was also worse. + // Full mask-based approach is best: branch cost < store savings. 
+ + if used_mask & 0x000C != 0 { + if used_mask & (1 << 2) != 0 { + property_buffer[2] = y as i32; + } + if used_mask & (1 << 3) != 0 { + property_buffer[3] = x as i32; + } + } + if used_mask & 0x0030 != 0 { + if used_mask & (1 << 4) != 0 { + property_buffer[4] = top.wrapping_abs(); + } + if used_mask & (1 << 5) != 0 { + property_buffer[5] = left.wrapping_abs(); + } + } + if used_mask & 0x00C0 != 0 { + if used_mask & (1 << 6) != 0 { + property_buffer[6] = top; + } + if used_mask & (1 << 7) != 0 { + property_buffer[7] = left; + } + } + if used_mask & 0x0300 != 0 { + property_buffer[8] = left.wrapping_sub(property_buffer[9]); + property_buffer[9] = left.wrapping_add(top).wrapping_sub(topleft); + } + if used_mask & 0x7C00 != 0 { + if used_mask & (1 << 10) != 0 { + property_buffer[10] = left.wrapping_sub(topleft); + } + if used_mask & (1 << 11) != 0 { + property_buffer[11] = topleft.wrapping_sub(top); + } + if used_mask & (1 << 12) != 0 { + property_buffer[12] = top.wrapping_sub(topright); + } + if used_mask & (1 << 13) != 0 { + property_buffer[13] = top.wrapping_sub(toptop); + } + if used_mask & (1 << 14) != 0 { + property_buffer[14] = left.wrapping_sub(leftleft); + } + } + + // Reference properties - only copy if used (this involves a memcpy). + if used_mask >> NUM_NONREF_PROPERTIES as u32 != 0 { + let num_refs = references.size().0; + if num_refs != 0 { + let ref_properties = &mut property_buffer[NUM_NONREF_PROPERTIES..]; + ref_properties[..num_refs].copy_from_slice(&references.row(x)[..num_refs]); + } + } +} + +/// Computes properties without weighted predictor. Returns 0 for wp_pred. 
+#[inline(always)] +fn compute_properties_no_wp( + prediction_data: PredictionData, + references: &Image, + property_buffer: &mut [i32], + x: usize, + y: usize, + used_mask: u32, +) { + compute_properties_common( + prediction_data, + references, + property_buffer, + x, + y, + used_mask, + ); + if used_mask & (1 << 15) != 0 { + property_buffer[15] = 0; + } +} - // Neighbours - property_buffer[4] = top.wrapping_abs(); - property_buffer[5] = left.wrapping_abs(); - property_buffer[6] = top; - property_buffer[7] = left; +/// Computes properties with weighted predictor. Returns the WP prediction value. +#[inline(always)] +#[allow(clippy::too_many_arguments)] +fn compute_properties_with_wp( + prediction_data: PredictionData, + xsize: usize, + wp_state: &mut WeightedPredictorState, + references: &Image, + property_buffer: &mut [i32], + x: usize, + y: usize, + used_mask: u32, +) -> i64 { + compute_properties_common( + prediction_data, + references, + property_buffer, + x, + y, + used_mask, + ); + // If property 15 (WP max error) is used by the tree, compute it. + // Otherwise skip the expensive abs() comparisons. + if used_mask & (1 << 15) != 0 { + let (wp_pred, wp_prop) = wp_state.predict_and_property((x, y), xsize, &prediction_data); + property_buffer[15] = wp_prop; + wp_pred + } else { + wp_state.predict_no_property((x, y), xsize, &prediction_data) + } +} - // Local gradient - property_buffer[8] = left.wrapping_sub(property_buffer[9]); - property_buffer[9] = left.wrapping_add(top).wrapping_sub(topleft); +/// Interior version: no edge checks for WP predictor. x > 0 and x < xsize-1 guaranteed. 
+#[inline(always)] +#[allow(clippy::too_many_arguments, unsafe_code)] +fn compute_properties_with_wp_interior( + prediction_data: PredictionData, + xsize: usize, + wp_state: &mut WeightedPredictorState, + references: &Image, + property_buffer: &mut [i32], + x: usize, + y: usize, + used_mask: u32, +) -> i64 { + compute_properties_common( + prediction_data, + references, + property_buffer, + x, + y, + used_mask, + ); + if used_mask & (1 << 15) != 0 { + let (wp_pred, wp_prop) = wp_state.predict_and_property_interior(x, xsize, &prediction_data); + property_buffer[15] = wp_prop; + wp_pred + } else { + wp_state.predict_no_property_interior(x, xsize, &prediction_data) + } +} - // FFV1 context properties - property_buffer[10] = left.wrapping_sub(topleft); - property_buffer[11] = topleft.wrapping_sub(top); - property_buffer[12] = top.wrapping_sub(topright); - property_buffer[13] = top.wrapping_sub(toptop); - property_buffer[14] = left.wrapping_sub(leftleft); +/// Computes all properties for tree traversal (non-flat path). +/// Returns the weighted predictor prediction value. +#[inline] +fn compute_properties( + prediction_data: PredictionData, + xsize: usize, + wp_state: Option<&mut WeightedPredictorState>, + x: usize, + y: usize, + references: &Image, + property_buffer: &mut [i32], +) -> i64 { + // Non-flat path: compute all properties (cold path, no mask optimization) + compute_properties_common(prediction_data, references, property_buffer, x, y, u32::MAX); // Weighted predictor property. let (wp_pred, wp_prop) = wp_state @@ -259,13 +423,6 @@ fn compute_properties( .unwrap_or((0, 0)); property_buffer[15] = wp_prop; - // Reference properties. 
- let num_refs = references.size().0; - if num_refs != 0 { - let ref_properties = &mut property_buffer[NUM_NONREF_PROPERTIES..]; - ref_properties[..num_refs].copy_from_slice(&references.row(x)[..num_refs]); - } - wp_pred } @@ -341,8 +498,9 @@ pub(super) fn predict( } /// Optimized prediction using flat tree (matches C++ context_predict.h:351-371). +/// Note: prefer predict_flat_no_wp / predict_flat_with_wp for better codegen. #[inline] -#[allow(clippy::too_many_arguments)] +#[allow(clippy::too_many_arguments, unsafe_code, dead_code)] pub(super) fn predict_flat( flat_tree: &[FlatTreeNode], prediction_data: PredictionData, @@ -363,45 +521,229 @@ pub(super) fn predict_flat( property_buffer, ); - // Flat tree traversal + // Flat tree traversal -- this is the hottest loop in modular decoding. + // Uses targeted unsafe for bounds-check elimination matching C++ performance. + // Macro matches C++ TRAVERSE_THE_TREE pattern, unrolled 2x per loop iteration. let mut pos = 0; + + macro_rules! traverse { + ($pos:expr) => {{ + // SAFETY: pos is always in bounds -- the flat tree is built via BFS from a + // validated tree, and each split node computes pos = child_id + offset (0..3) + // which always lands within the flat_nodes array. + let node = unsafe { flat_tree.get_unchecked($pos) }; + if node.property0 < 0 { + // Leaf node -- predictor is stored directly, no conversion needed + let pred = node.predictor.predict_one(prediction_data, wp_pred); + return PredictionResult { + guess: pred + node.properties_or_offset[0] as i32 as i64, + multiplier: node.splitvals_or_multiplier[0] as u32, + context: node.child_id, + }; + } + // Split node: C++ logic from context_predict.h:361-365 + // SAFETY: property indices are bounded by max_property_count which is validated + // against property_buffer.len() when the tree is built and the buffer is allocated. 
+ let p0 = unsafe { + *property_buffer.get_unchecked(node.property0 as usize) <= node.splitval0 + }; + // SAFETY: properties_or_offset[0] is validated to be a valid property index. + let off0 = unsafe { + (*property_buffer.get_unchecked(node.properties_or_offset[0] as usize) + <= node.splitvals_or_multiplier[0]) as u32 + }; + // SAFETY: properties_or_offset[1] is validated to be a valid property index. + let off1 = unsafe { + 2 | (*property_buffer.get_unchecked(node.properties_or_offset[1] as usize) + <= node.splitvals_or_multiplier[1]) as u32 + }; + (node.child_id + if p0 { off1 } else { off0 }) as usize + }}; + } + loop { - let node = &flat_tree[pos]; + pos = traverse!(pos); + pos = traverse!(pos); + } +} + +/// Specialized predict_flat without weighted predictor (avoids Option overhead). +#[inline(always)] +#[allow(unsafe_code)] +pub(super) fn predict_flat_no_wp( + flat_tree: &[FlatTreeNode], + prediction_data: PredictionData, + x: usize, + y: usize, + references: &Image, + property_buffer: &mut [i32], + used_mask: u32, +) -> PredictionResult { + compute_properties_no_wp( + prediction_data, + references, + property_buffer, + x, + y, + used_mask, + ); - if node.property0 < 0 { - // Leaf node - let predictor = Predictor::try_from(node.splitval0_or_predictor as u32).unwrap(); - let offset = node.properties_or_offset[0] as i32; - let multiplier = node.splitvals_or_multiplier[0] as u32; - let context = node.child_id; + let mut pos = 0; + macro_rules! traverse { + ($pos:expr) => {{ + // SAFETY: pos always points to a valid node in flat_tree by construction. + let node = unsafe { flat_tree.get_unchecked($pos) }; + if node.property0 < 0 { + let pred = node.predictor.predict_one(prediction_data, 0); + return PredictionResult { + guess: pred + node.properties_or_offset[0] as i32 as i64, + multiplier: node.splitvals_or_multiplier[0] as u32, + context: node.child_id, + }; + } + // SAFETY: node.property0 is a validated property index for property_buffer. 
+ let p0 = unsafe { + *property_buffer.get_unchecked(node.property0 as usize) <= node.splitval0 + }; + // SAFETY: properties_or_offset[0] is a validated property index. + let off0 = unsafe { + (*property_buffer.get_unchecked(node.properties_or_offset[0] as usize) + <= node.splitvals_or_multiplier[0]) as u32 + }; + // SAFETY: properties_or_offset[1] is a validated property index. + let off1 = unsafe { + 2 | (*property_buffer.get_unchecked(node.properties_or_offset[1] as usize) + <= node.splitvals_or_multiplier[1]) as u32 + }; + (node.child_id + if p0 { off1 } else { off0 }) as usize + }}; + } + loop { + pos = traverse!(pos); + pos = traverse!(pos); + } +} - let pred = predictor.predict_one(prediction_data, wp_pred); +/// Specialized predict_flat with weighted predictor (avoids Option overhead). +#[inline(always)] +#[allow(clippy::too_many_arguments, unsafe_code)] +pub(super) fn predict_flat_with_wp( + flat_tree: &[FlatTreeNode], + prediction_data: PredictionData, + xsize: usize, + wp_state: &mut WeightedPredictorState, + x: usize, + y: usize, + references: &Image, + property_buffer: &mut [i32], + used_mask: u32, +) -> PredictionResult { + let wp_pred = compute_properties_with_wp( + prediction_data, + xsize, + wp_state, + references, + property_buffer, + x, + y, + used_mask, + ); - return PredictionResult { - guess: pred + offset as i64, - multiplier, - context, + let mut pos = 0; + macro_rules! traverse { + ($pos:expr) => {{ + // SAFETY: pos always points to a valid node in flat_tree by construction. + let node = unsafe { flat_tree.get_unchecked($pos) }; + if node.property0 < 0 { + let pred = node.predictor.predict_one(prediction_data, wp_pred); + return PredictionResult { + guess: pred + node.properties_or_offset[0] as i32 as i64, + multiplier: node.splitvals_or_multiplier[0] as u32, + context: node.child_id, + }; + } + // SAFETY: node.property0 is a validated property index for property_buffer. 
+ let p0 = unsafe { + *property_buffer.get_unchecked(node.property0 as usize) <= node.splitval0 }; - } + // SAFETY: properties_or_offset[0] is a validated property index. + let off0 = unsafe { + (*property_buffer.get_unchecked(node.properties_or_offset[0] as usize) + <= node.splitvals_or_multiplier[0]) as u32 + }; + // SAFETY: properties_or_offset[1] is a validated property index. + let off1 = unsafe { + 2 | (*property_buffer.get_unchecked(node.properties_or_offset[1] as usize) + <= node.splitvals_or_multiplier[1]) as u32 + }; + (node.child_id + if p0 { off1 } else { off0 }) as usize + }}; + } + loop { + pos = traverse!(pos); + pos = traverse!(pos); + } +} - // Split node: C++ logic from context_predict.h:361-365 - let p0 = property_buffer[node.property0 as usize] <= node.splitval0_or_predictor; - let off0 = if property_buffer[node.properties_or_offset[0] as usize] - <= node.splitvals_or_multiplier[0] - { - 1 - } else { - 0 - }; - let off1 = if property_buffer[node.properties_or_offset[1] as usize] - <= node.splitvals_or_multiplier[1] - { - 3 - } else { - 2 - }; +/// Interior version of predict_flat_with_wp. No edge checks for WP predictor. +/// x > 0 and x < xsize-1 guaranteed. +#[inline(always)] +#[allow(clippy::too_many_arguments, unsafe_code)] +pub(super) fn predict_flat_with_wp_interior( + flat_tree: &[FlatTreeNode], + prediction_data: PredictionData, + xsize: usize, + wp_state: &mut WeightedPredictorState, + x: usize, + y: usize, + references: &Image, + property_buffer: &mut [i32], + used_mask: u32, +) -> PredictionResult { + let wp_pred = compute_properties_with_wp_interior( + prediction_data, + xsize, + wp_state, + references, + property_buffer, + x, + y, + used_mask, + ); - pos = (node.child_id + if p0 { off1 } else { off0 }) as usize; + let mut pos = 0; + macro_rules! traverse { + ($pos:expr) => {{ + // SAFETY: pos always points to a valid node in flat_tree by construction. 
+ let node = unsafe { flat_tree.get_unchecked($pos) }; + if node.property0 < 0 { + let pred = node.predictor.predict_one(prediction_data, wp_pred); + return PredictionResult { + guess: pred + node.properties_or_offset[0] as i32 as i64, + multiplier: node.splitvals_or_multiplier[0] as u32, + context: node.child_id, + }; + } + // SAFETY: node.property0 is a validated property index for property_buffer. + let p0 = unsafe { + *property_buffer.get_unchecked(node.property0 as usize) <= node.splitval0 + }; + // SAFETY: properties_or_offset[0] is a validated property index. + let off0 = unsafe { + (*property_buffer.get_unchecked(node.properties_or_offset[0] as usize) + <= node.splitvals_or_multiplier[0]) as u32 + }; + // SAFETY: properties_or_offset[1] is a validated property index. + let off1 = unsafe { + 2 | (*property_buffer.get_unchecked(node.properties_or_offset[1] as usize) + <= node.splitvals_or_multiplier[1]) as u32 + }; + (node.child_id + if p0 { off1 } else { off0 }) as usize + }}; + } + loop { + pos = traverse!(pos); + pos = traverse!(pos); } } @@ -488,14 +830,17 @@ impl Tree { /// Build flat tree using BFS traversal (matches C++ encoding.cc:81-144). /// Each flat node stores parent + both children info to reduce branches. - pub(super) fn build_flat_tree(nodes: &[TreeNode]) -> Result> { + /// Returns (flat_nodes, used_properties_mask). + /// The mask has bit `i` set if property `i` is used in any split node. 
+ pub(super) fn build_flat_tree(nodes: &[TreeNode]) -> Result<(Vec, u32)> { use std::collections::VecDeque; if nodes.is_empty() { - return Ok(vec![]); + return Ok((vec![], 0)); } let mut flat_nodes = Vec::new_with_capacity(nodes.len())?; + let mut used_mask: u32 = 0; let mut queue: VecDeque = VecDeque::new(); queue.push_back(0); // Start with root @@ -515,24 +860,20 @@ impl Tree { left, right, } => { + used_mask |= 1u32 << (*property as u32); // childID points to first of 4 grandchildren in output let child_id = (flat_nodes.len() + queue.len() + 1) as u32; - let mut flat = FlatTreeNode { - property0: *property as i32, - splitval0_or_predictor: *val, - splitvals_or_multiplier: [0, 0], - child_id, - properties_or_offset: [0, 0], - }; + let mut properties = [0i16; 2]; + let mut splitvals = [0i32; 2]; // Process left (i=0) and right (i=1) children for (i, &child_idx) in [*left as usize, *right as usize].iter().enumerate() { match &nodes[child_idx] { TreeNode::Leaf { .. } => { // Child is leaf: set property=0 and enqueue leaf twice - flat.properties_or_offset[i] = 0; - flat.splitvals_or_multiplier[i] = 0; + properties[i] = 0; + splitvals[i] = 0; queue.push_back(child_idx); queue.push_back(child_idx); } @@ -543,20 +884,27 @@ impl Tree { right: cr, } => { // Child is split: store property/splitval and enqueue grandchildren - flat.properties_or_offset[i] = *cp as i16; - flat.splitvals_or_multiplier[i] = *cv; + used_mask |= 1u32 << (*cp as u32); + properties[i] = *cp as i16; + splitvals[i] = *cv; queue.push_back(*cl as usize); queue.push_back(*cr as usize); } } } - flat_nodes.push(flat); + flat_nodes.push(FlatTreeNode::split( + *property as i32, + *val, + child_id, + properties, + splitvals, + )); } } } - Ok(flat_nodes) + Ok((flat_nodes, used_mask)) } pub fn max_property_count(&self) -> usize { diff --git a/jxl/src/image/internal.rs b/jxl/src/image/internal.rs index 0c77b4b30..3c9c03854 100644 --- a/jxl/src/image/internal.rs +++ b/jxl/src/image/internal.rs @@ -122,7 +122,7 
@@ impl RawImageBuffer { } } - #[inline] + #[inline(always)] pub(super) fn byte_size(&self) -> (usize, usize) { (self.bytes_per_row, self.num_rows) } diff --git a/jxl/src/image/output_buffer.rs b/jxl/src/image/output_buffer.rs index 57c9f1280..9b417b172 100644 --- a/jxl/src/image/output_buffer.rs +++ b/jxl/src/image/output_buffer.rs @@ -58,48 +58,19 @@ impl<'a> JxlOutputBuffer<'a> { num_rows: usize, bytes_per_row: usize, ) -> Self { - Self::new_uninit_with_stride(buf, num_rows, bytes_per_row, bytes_per_row) - } - - pub fn new(buf: &'a mut [u8], num_rows: usize, bytes_per_row: usize) -> Self { - Self::new_with_stride(buf, num_rows, bytes_per_row, bytes_per_row) - } - - /// Creates a new JxlOutputBuffer from a slice of uninit data. - /// It is guaranteed that `buf` will never be used to write uninitalized data. - pub fn new_uninit_with_stride( - buf: &'a mut [MaybeUninit], - num_rows: usize, - bytes_per_row: usize, - byte_stride: usize, - ) -> Self { - assert_ne!(num_rows, 0); - assert!( - buf.len() - >= byte_stride - .checked_mul(num_rows - 1) - .unwrap() - .checked_add(bytes_per_row) - .unwrap() - ); + assert!(buf.len() >= bytes_per_row * num_rows); // SAFETY: The assert above guarantees that `buf` has enough space to satisfy the first // safety requirement, and the rest follow from borrowing from a &mut []. - unsafe { Self::new_from_ptr(buf.as_mut_ptr(), num_rows, bytes_per_row, byte_stride) } + unsafe { Self::new_from_ptr(buf.as_mut_ptr(), num_rows, bytes_per_row, bytes_per_row) } } - pub fn new_with_stride( - buf: &'a mut [u8], - num_rows: usize, - bytes_per_row: usize, - byte_stride: usize, - ) -> Self { - Self::new_uninit_with_stride( + pub fn new(buf: &'a mut [u8], num_rows: usize, bytes_per_row: usize) -> Self { + Self::new_uninit( // SAFETY: `new_uninit` guarantees that no uninit data is ever written to the passed-in // slice. Moreover, `T` and `MaybeUninit` have the same memory layout. 
unsafe { std::slice::from_raw_parts_mut(buf.as_mut_ptr().cast(), buf.len()) }, num_rows, bytes_per_row, - byte_stride, ) } @@ -119,7 +90,7 @@ impl<'a> JxlOutputBuffer<'a> { unsafe { self.inner.row_mut(row) } } - #[inline] + #[inline(always)] pub fn write_bytes(&mut self, row: usize, col: usize, bytes: &[u8]) { // SAFETY: We never use the returned slice to write uninit data, and we have write access // to the data. diff --git a/jxl/src/image/raw.rs b/jxl/src/image/raw.rs index 9c689e788..8cf76f962 100644 --- a/jxl/src/image/raw.rs +++ b/jxl/src/image/raw.rs @@ -46,6 +46,49 @@ impl OwnedRawImage { }) } + /// Like `new_zeroed_with_padding`, but only zeroes the padding region. + /// The main data area (byte_size) is left uninitialized. + /// + /// # Safety + /// The caller must ensure that the main data area is fully written before reading. + /// The offset/padding regions are zeroed for correct boundary behavior. + #[allow(unsafe_code)] + pub unsafe fn new_uninit_with_zeroed_padding( + byte_size: (usize, usize), + offset: (usize, usize), + mut padding: (usize, usize), + ) -> Result { + if !(padding.0 + byte_size.0).is_multiple_of(CACHE_LINE_BYTE_SIZE) { + padding.0 += CACHE_LINE_BYTE_SIZE - (padding.0 + byte_size.0) % CACHE_LINE_BYTE_SIZE; + } + let total_cols = byte_size.0 + padding.0; + let total_rows = byte_size.1 + padding.1; + let mut data = RawImageBuffer::try_allocate((total_cols, total_rows), true)?; + + // Zero only the padding regions: + // 1. Top offset rows (full width each) + for y in 0..offset.1 { + // SAFETY: y is in 0..total_rows, and data has total_rows rows. + let row = unsafe { data.row_mut(y) }; + row.fill(std::mem::MaybeUninit::new(0)); + } + // 2. Left offset columns + right padding columns for data rows + for y in offset.1..total_rows { + // SAFETY: y is in 0..total_rows, and data has total_rows rows. 
+ let row = unsafe { data.row_mut(y) }; + // Left offset region + row[..offset.0].fill(std::mem::MaybeUninit::new(0)); + // Right padding region: [byte_size.0..total_cols] + row[byte_size.0..total_cols].fill(std::mem::MaybeUninit::new(0)); + } + + Ok(Self { + data, + offset, + padding, + }) + } + pub fn get_rect_including_padding_mut(&mut self, rect: Rect) -> RawImageRectMut<'_> { RawImageRectMut { // Safety note: we are lending exclusive ownership to RawImageRectMut. diff --git a/jxl/src/image/typed.rs b/jxl/src/image/typed.rs index bdd07839d..e75d27462 100644 --- a/jxl/src/image/typed.rs +++ b/jxl/src/image/typed.rs @@ -36,6 +36,30 @@ impl Image { Ok(Self::from_raw(img)) } + /// Like `new_with_padding`, but only zeroes the padding region. + /// The main data area is left uninitialized -- caller must write all pixels before reading. + /// + /// # Safety + /// The caller must fully write the main data area before reading from it. + #[allow(unsafe_code)] + pub unsafe fn new_uninit_with_zeroed_padding( + size: (usize, usize), + offset: (usize, usize), + padding: (usize, usize), + ) -> Result> { + let s = T::DATA_TYPE_ID.size(); + // SAFETY: this function has the same safety contract as + // OwnedRawImage::new_uninit_with_zeroed_padding, and we only scale dimensions. + let img = unsafe { + OwnedRawImage::new_uninit_with_zeroed_padding( + (size.0 * s, size.1), + (offset.0 * s, offset.1), + (padding.0 * s, padding.1), + )? + }; + Ok(Self::from_raw(img)) + } + #[instrument(ret, err)] pub fn new(size: (usize, usize)) -> Result> { Self::new_with_padding(size, (0, 0), (0, 0)) diff --git a/jxl/src/render/channels.rs b/jxl/src/render/channels.rs index 1dbd1afbf..0cd61f4d6 100644 --- a/jxl/src/render/channels.rs +++ b/jxl/src/render/channels.rs @@ -3,7 +3,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -use crate::util::SmallVec; +use crate::util::StackVec; /// Multi-row channel accessor for immutable access. 
/// @@ -11,9 +11,10 @@ use crate::util::SmallVec; /// and `channels[ch][row]` returns `&[T]` (pixels for a specific row). /// /// This eliminates nested Vec collections while maintaining the same indexing syntax. +/// Uses StackVec (no enum discriminant) for zero-overhead access in the hot render loop. pub struct Channels<'a, T> { // The number of input rows should be maximized by the EPF0 stage, which has 21. - pub(crate) row_data: SmallVec<&'a [T], 32>, + pub(crate) row_data: StackVec<&'a [T], 32>, num_channels: usize, pub(crate) rows_per_channel: usize, } @@ -25,8 +26,9 @@ impl<'a, T> Channels<'a, T> { /// * `row_data` - Flat vector of all rows for all channels (length = num_channels * rows_per_channel) /// * `num_channels` - Number of channels /// * `rows_per_channel` - Number of rows per channel (typically 2*BORDER+1) + #[inline(always)] pub fn new( - row_data: SmallVec<&'a [T], 32>, + row_data: StackVec<&'a [T], 32>, num_channels: usize, rows_per_channel: usize, ) -> Self { @@ -43,16 +45,19 @@ impl<'a, T> Channels<'a, T> { } /// Returns the number of channels. + #[inline(always)] pub fn len(&self) -> usize { self.num_channels } /// Returns true if there are no channels. + #[inline(always)] pub fn is_empty(&self) -> bool { self.num_channels == 0 } /// Returns an iterator over channel slices. + #[inline(always)] pub fn iter(&self) -> impl Iterator { (0..self.num_channels).map(move |ch| &self[ch]) } @@ -62,6 +67,7 @@ impl<'a, T> Channels<'a, T> { impl<'a, T> std::ops::Index for Channels<'a, T> { type Output = [&'a [T]]; + #[inline(always)] fn index(&self, ch: usize) -> &[&'a [T]] { let start = ch * self.rows_per_channel; &self.row_data[start..start + self.rows_per_channel] @@ -74,7 +80,7 @@ impl<'a, T> std::ops::Index for Channels<'a, T> { /// and `channels[ch][row]` returns `&mut [T]` (pixels for a specific row). pub struct ChannelsMut<'a, T> { // The number of output rows should be maximized by the Upsample8 stage, which has 8. 
- pub(crate) row_data: SmallVec<&'a mut [T], 8>, + pub(crate) row_data: StackVec<&'a mut [T], 8>, num_channels: usize, pub(crate) rows_per_channel: usize, } @@ -86,8 +92,9 @@ impl<'a, T> ChannelsMut<'a, T> { /// * `row_data` - Flat vector of all mutable rows for all channels /// * `num_channels` - Number of channels /// * `rows_per_channel` - Number of rows per channel (typically 1 << SHIFT) + #[inline(always)] pub fn new( - row_data: SmallVec<&'a mut [T], 8>, + row_data: StackVec<&'a mut [T], 8>, num_channels: usize, rows_per_channel: usize, ) -> Self { @@ -104,11 +111,13 @@ impl<'a, T> ChannelsMut<'a, T> { } /// Returns the number of channels. + #[inline(always)] pub fn len(&self) -> usize { self.num_channels } /// Returns true if there are no channels. + #[inline(always)] pub fn is_empty(&self) -> bool { self.num_channels == 0 } @@ -116,6 +125,7 @@ impl<'a, T> ChannelsMut<'a, T> { /// Splits the first 3 channels into separate mutable slices. /// Returns a tuple containing mutable references to each channel's rows. #[allow(clippy::type_complexity)] + #[inline(always)] pub fn split_first_3_mut( &mut self, ) -> (&mut [&'a mut [T]], &mut [&'a mut [T]], &mut [&'a mut [T]]) { @@ -133,6 +143,7 @@ impl<'a, T> ChannelsMut<'a, T> { /// Returns a mutable iterator over all channels. /// Each item is a mutable slice of rows for that channel. 
+ #[inline(always)] pub fn iter_mut(&mut self) -> impl Iterator { let rpc = self.rows_per_channel; self.row_data.chunks_mut(rpc) @@ -143,6 +154,7 @@ impl<'a, T> ChannelsMut<'a, T> { impl<'a, T> std::ops::Index for ChannelsMut<'a, T> { type Output = [&'a mut [T]]; + #[inline(always)] fn index(&self, ch: usize) -> &[&'a mut [T]] { let start = ch * self.rows_per_channel; &self.row_data[start..start + self.rows_per_channel] @@ -151,6 +163,7 @@ impl<'a, T> std::ops::Index for ChannelsMut<'a, T> { /// Implement mutable indexing: &mut channels[ch] returns &mut [&mut [T]] impl<'a, T> std::ops::IndexMut for ChannelsMut<'a, T> { + #[inline(always)] fn index_mut(&mut self, ch: usize) -> &mut [&'a mut [T]] { let start = ch * self.rows_per_channel; &mut self.row_data[start..start + self.rows_per_channel] diff --git a/jxl/src/render/low_memory_pipeline/helpers.rs b/jxl/src/render/low_memory_pipeline/helpers.rs index 7f8214ff3..0f3843844 100644 --- a/jxl/src/render/low_memory_pipeline/helpers.rs +++ b/jxl/src/render/low_memory_pipeline/helpers.rs @@ -9,26 +9,29 @@ use crate::render::low_memory_pipeline::render_group::ChannelVec; /// Panics if any of the indices are out of bounds or /// (idx[i].0, idx[i].1) == (idx[j].0, idx[j].1) for i != j or indices are not /// sorted lexicographically. +#[allow(unsafe_code)] pub(super) fn get_distinct_indices<'a, T>( vals: &'a mut [impl AsMut<[T]>], idx: &[(usize, usize, usize)], ) -> ChannelVec<&'a mut T> { - let mut answer_buffer = ChannelVec::new(); - for _ in 0..idx.len() { - answer_buffer.push(None); + // Build output directly using pointer math to avoid Option overhead. + // idx is sorted lexicographically by (a, b), so we iterate vals in order. + let mut result: ChannelVec<&'a mut T> = ChannelVec::new(); + // Pre-fill with dummy values that will be overwritten. + // We need the result in idx[i].2 order, so we use a position-indexed buffer. 
+ let n = idx.len(); + let mut ptrs: ChannelVec<*mut T> = ChannelVec::new(); + for _ in 0..n { + ptrs.push(std::ptr::null_mut()); } - // TODO(veluca): in theory, we don't really need to first create a vector of - // `Option`s that then get `unwrap`-ed separately. Currently, this function - // uses somewhere between 0.5 and 1.5% of the total runtime; if that number - // increases, it might be worth investigating how to speed this up. let mut targets = idx.iter(); let mut target = targets.next().unwrap(); 'outer: for (aa, bufs) in vals.iter_mut().enumerate() { for (bb, buf) in bufs.as_mut().iter_mut().enumerate() { let (a, b, pos) = target; if aa == *a && bb == *b { - answer_buffer[*pos] = Some(buf); + ptrs[*pos] = buf as *mut T; if let Some(t) = targets.next() { target = t; } else { @@ -38,8 +41,12 @@ pub(super) fn get_distinct_indices<'a, T>( } } - answer_buffer - .iter_mut() - .map(|x| std::mem::take(x).expect("Not all elements were found")) - .collect() + for p in ptrs.iter() { + debug_assert!(!p.is_null(), "Not all elements were found"); + // SAFETY: Each pointer was obtained from a distinct &mut T in vals, + // and the lexicographic sort + distinctness guarantees no aliasing. + result.push(unsafe { &mut **p }); + } + + result } diff --git a/jxl/src/render/low_memory_pipeline/render_group.rs b/jxl/src/render/low_memory_pipeline/render_group.rs index 6f9b65b67..6760d0e8f 100644 --- a/jxl/src/render/low_memory_pipeline/render_group.rs +++ b/jxl/src/render/low_memory_pipeline/render_group.rs @@ -13,14 +13,14 @@ use crate::{ internal::{ChannelInfo, Stage}, low_memory_pipeline::{helpers::get_distinct_indices, run_stage::ExtraInfo}, }, - util::{ShiftRightCeil, SmallVec, mirror, tracing_wrappers::*}, + util::{ShiftRightCeil, StackVec, mirror, tracing_wrappers::*}, }; use super::{LowMemoryRenderPipeline, row_buffers::RowBuffer}; // Most images have at most 7 channels (RGBA + noise extra channels). // 8 gives a bit extra leeway and makes the size a power of two. 
-pub(super) type ChannelVec = SmallVec; +pub(super) type ChannelVec = StackVec; fn apply_x_padding( input_type: DataTypeTag, diff --git a/jxl/src/render/low_memory_pipeline/row_buffers.rs b/jxl/src/render/low_memory_pipeline/row_buffers.rs index 4cf01155d..969bcb476 100644 --- a/jxl/src/render/low_memory_pipeline/row_buffers.rs +++ b/jxl/src/render/low_memory_pipeline/row_buffers.rs @@ -10,7 +10,7 @@ use crate::{ image::{DataTypeTag, ImageDataType}, render::MAX_BORDER, util::{ - CACHE_LINE_BYTE_SIZE, CacheLine, SmallVec, num_per_cache_line, slice_from_cachelines, + CACHE_LINE_BYTE_SIZE, CacheLine, StackVec, num_per_cache_line, slice_from_cachelines, slice_from_cachelines_mut, }, }; @@ -68,14 +68,14 @@ impl RowBuffer { Ok(result) } - #[inline] + #[inline(always)] pub fn get_row(&self, row: usize) -> &[T] { let row_idx = row & (self.num_rows - 1); let start = row_idx * self.row_stride; slice_from_cachelines(&self.buffer[start..start + self.row_stride]) } - #[inline] + #[inline(always)] pub fn get_row_mut(&mut self, row: usize) -> &mut [T] { let row_idx = row & (self.num_rows - 1); let stride = self.row_stride; @@ -83,11 +83,12 @@ impl RowBuffer { slice_from_cachelines_mut(&mut self.buffer[start..start + stride]) } + #[inline(always)] pub fn get_rows_mut( &mut self, y: Range, xoffset: usize, - ) -> SmallVec<&mut [T], 8> { + ) -> StackVec<&mut [T], 8> { assert!(y.clone().count() <= self.num_rows); let first_row_idx = y.start & (self.num_rows - 1); let stride = self.row_stride; @@ -104,6 +105,29 @@ impl RowBuffer { .collect() } + /// Push rows directly into an existing StackVec, avoiding temporary allocation. 
+ pub fn push_rows_mut<'a, T: ImageDataType>( + &'a mut self, + y: Range, + xoffset: usize, + out: &mut StackVec<&'a mut [T], 8>, + ) { + assert!(y.clone().count() <= self.num_rows); + let first_row_idx = y.start & (self.num_rows - 1); + let stride = self.row_stride; + let start = first_row_idx * stride; + let num_pre = (y.clone().count() + first_row_idx).saturating_sub(self.num_rows); + let num_post = y.clone().count() - num_pre; + let buf = &mut self.buffer[..]; + let (pre, post) = buf.split_at_mut(start); + for chunk in post.chunks_exact_mut(stride).take(num_post) { + out.push(&mut slice_from_cachelines_mut(chunk)[xoffset..]); + } + for chunk in pre.chunks_exact_mut(stride).take(num_pre) { + out.push(&mut slice_from_cachelines_mut(chunk)[xoffset..]); + } + } + pub const fn x0_offset() -> usize { assert!(num_per_cache_line::() >= MAX_BORDER); num_per_cache_line::() diff --git a/jxl/src/render/low_memory_pipeline/run_stage.rs b/jxl/src/render/low_memory_pipeline/run_stage.rs index 5acced8b3..ba9686f24 100644 --- a/jxl/src/render/low_memory_pipeline/run_stage.rs +++ b/jxl/src/render/low_memory_pipeline/run_stage.rs @@ -9,9 +9,8 @@ use crate::{ render::{ Channels, ChannelsMut, RunInPlaceStage, internal::{PipelineBuffer, RunInOutStage}, - low_memory_pipeline::render_group::ChannelVec, }, - util::{ShiftRightCeil, SmallVec, mirror, tracing_wrappers::*}, + util::{ShiftRightCeil, StackVec, mirror, tracing_wrappers::*}, }; use super::{ @@ -56,10 +55,10 @@ impl RunInPlaceStage for T { let xpre = if start_of_row { 0 } else { out_extra_x }; let xstart = x0 - xpre; let xend = x0 + xsize + if end_of_row { 0 } else { out_extra_x }; - let mut rows: ChannelVec<_> = buffers - .iter_mut() - .map(|x| &mut x.get_row_mut::(current_row)[xstart..]) - .collect(); + let mut rows: StackVec<&mut [T::Type], 8> = StackVec::new(); + for x in buffers.iter_mut() { + rows.push(&mut x.get_row_mut::(current_row)[xstart..]); + } self.process_row_chunk( (group_x0 - xpre, current_row), @@ -103,10 
+102,10 @@ impl RunInOutStage for T { out_extra_x.shrc(T::SHIFT.0) }; - // Build flat input rows: all rows for all channels in one Vec + // Build flat input rows: all rows for all channels in one StackVec let input_rows_per_channel = (2 * Self::BORDER.1 + 1) as usize; let num_channels = input_buffers.len(); - let mut input_row_data = SmallVec::new(); + let mut input_row_data: StackVec<&[T::InputT], 32> = StackVec::new(); for x in input_buffers.iter() { for iy in -ibordery..=ibordery { input_row_data.push( @@ -117,10 +116,10 @@ impl RunInOutStage for T { } let input_rows = Channels::new(input_row_data, num_channels, input_rows_per_channel); - // Build flat output rows: all rows for all channels in one Vec + // Build flat output rows: all rows for all channels in one StackVec let output_rows_per_channel = 1 << T::SHIFT.1; let num_output_channels = output_buffers.len(); - let mut output_row_data = SmallVec::new(); + let mut output_row_data: StackVec<&mut [T::OutputT], 8> = StackVec::new(); // optimize for the common case of a single output row per channel. if output_rows_per_channel == 1 { // Use OutputT's x0_offset, not InputT's - they differ for type conversions (e.g., f32→u8). 
@@ -132,11 +131,11 @@ impl RunInOutStage for T { } } else { for x in output_buffers.iter_mut() { - let rows = x.get_rows_mut::( + x.push_rows_mut::( (current_row << T::SHIFT.1)..((current_row + 1) << T::SHIFT.1), RowBuffer::x0_offset::() - (xpre << T::SHIFT.0), + &mut output_row_data, ); - output_row_data.extend_sv(rows); } } let mut output_rows = ChannelsMut::new( diff --git a/jxl/src/render/simple_pipeline/run_stage.rs b/jxl/src/render/simple_pipeline/run_stage.rs index bfaea9943..5b73ddc4b 100644 --- a/jxl/src/render/simple_pipeline/run_stage.rs +++ b/jxl/src/render/simple_pipeline/run_stage.rs @@ -13,7 +13,7 @@ use crate::{ RenderPipelineInOutStage, RenderPipelineInPlaceStage, RunInOutStage, RunInPlaceStage, internal::PipelineBuffer, }, - util::{SmallVec, mirror, round_up_size_to_cache_line, tracing_wrappers::*}, + util::{StackVec, mirror, round_up_size_to_cache_line, tracing_wrappers::*}, }; impl PipelineBuffer for Image { @@ -150,7 +150,7 @@ impl RunInOutStage> for T { // Build flat input rows: all rows for all channels in one Vec let num_input_channels = buffer_in.len(); let input_rows_per_channel = buffer_in[0].len(); - let mut input_row_data = SmallVec::new(); + let mut input_row_data: StackVec<&[_], 32> = StackVec::new(); for ch_buf in buffer_in.iter() { for row in ch_buf.iter() { input_row_data.push(row as &[_]); @@ -165,7 +165,7 @@ impl RunInOutStage> for T { // Build flat output rows: all rows for all channels in one Vec let num_output_channels = buffer_out.len(); let output_rows_per_channel = buffer_out[0].len(); - let mut output_row_data = SmallVec::new(); + let mut output_row_data: StackVec<&mut [_], 8> = StackVec::new(); for ch_buf in buffer_out.iter_mut() { for row in ch_buf.iter_mut() { output_row_data.push(row as &mut [_]); diff --git a/jxl/src/render/stages/blending.rs b/jxl/src/render/stages/blending.rs index d53eda748..33ee2fae5 100644 --- a/jxl/src/render/stages/blending.rs +++ b/jxl/src/render/stages/blending.rs @@ -8,13 +8,12 @@ use 
std::sync::Arc; use crate::{ error::Result, features::{ - blending::perform_blending, + blending::perform_blending_with_tmp, patches::{PatchBlendMode, PatchBlending}, }, frame::ReferenceFrame, headers::{FileHeader, extra_channels::ExtraChannelInfo, frame_header::*}, render::RenderPipelineInPlaceStage, - util::slice, }; pub struct BlendingStage { @@ -27,6 +26,11 @@ pub struct BlendingStage { pub zeros: Vec, } +/// Per-thread state for BlendingStage: pre-allocated tmp buffer reused across calls. +struct BlendingState { + tmp: Vec>, +} + impl From<&BlendingInfo> for PatchBlending { fn from(info: &BlendingInfo) -> Self { let mode = match info.mode { @@ -76,12 +80,16 @@ impl RenderPipelineInPlaceStage for BlendingStage { c < 3 + self.extra_channels.len() } + fn init_local_state(&self, _thread_index: usize) -> Result>> { + Ok(Some(Box::new(BlendingState { tmp: Vec::new() }))) + } + fn process_row_chunk( &self, position: (usize, usize), xsize: usize, row: &mut [&mut [f32]], - _state: Option<&mut dyn std::any::Any>, + state: Option<&mut dyn std::any::Any>, ) { let num_ec = self.extra_channels.len(); let fg_y0 = self.frame_origin.1 + position.1 as isize; @@ -109,47 +117,55 @@ impl RenderPipelineInPlaceStage for BlendingStage { let bg_x1: usize = bg_x1 as usize; let fg_y0: usize = fg_y0 as usize; - // TODO(szabadka): Allocate a buffer for this when building the stage instead of when - // executing it. - let mut out = row - .iter_mut() - .map(|s| &mut s[..xsize]) - .collect::>(); - - let mut fg = vec![self.zeros.as_slice(); 3 + num_ec]; - - for (c, fg_ptr) in fg.iter_mut().enumerate().take(3) { - if self.reference_frames[self.blending_info.source as usize].is_some() { - *fg_ptr = &(self.reference_frames[self.blending_info.source as usize] - .as_ref() - .unwrap() - .frame[c] - .row(fg_y0)[fg_x0..fg_x1]); + // Use stack-allocated arrays to avoid per-row heap allocation. + // Max 16 channels (3 color + 13 extra) is well above JPEG XL's practical limit. 
+ const MAX_CHANNELS: usize = 16; + let total_channels = 3 + num_ec; + debug_assert!(total_channels <= MAX_CHANNELS); + + // Build fg references on stack. + let zeros_slice: &[f32] = self.zeros.as_slice(); + let mut fg_buf: [&[f32]; MAX_CHANNELS] = [zeros_slice; MAX_CHANNELS]; + + if let Some(ref rf) = self.reference_frames[self.blending_info.source as usize] { + for (c, fg) in fg_buf.iter_mut().enumerate().take(3) { + *fg = &rf.frame[c].row(fg_y0)[fg_x0..fg_x1]; } } for i in 0..num_ec { - if self.reference_frames[self.ec_blending_info[i].source as usize].is_some() { - fg[3 + i] = &(self.reference_frames[self.ec_blending_info[i].source as usize] - .as_ref() - .unwrap() - .frame[3 + i] - .row(fg_y0)[fg_x0..fg_x1]); + if let Some(ref rf) = self.reference_frames[self.ec_blending_info[i].source as usize] { + fg_buf[3 + i] = &rf.frame[3 + i].row(fg_y0)[fg_x0..fg_x1]; } } let blending_info = PatchBlending::from(&self.blending_info); - let ec_blending_info: Vec = self - .ec_blending_info - .iter() - .map(PatchBlending::from) - .collect(); - - perform_blending( - &mut slice!(&mut out, .., bg_x0..bg_x1), - &fg, + // ec_blending_info on stack (max 13 extra channels). + let mut ec_blend_buf: [PatchBlending; MAX_CHANNELS] = [PatchBlending { + mode: PatchBlendMode::None, + alpha_channel: 0, + clamp: false, + }; MAX_CHANNELS]; + for (i, bi) in self.ec_blending_info.iter().enumerate() { + ec_blend_buf[i] = PatchBlending::from(bi); + } + + // Use pre-allocated tmp buffer from state to avoid per-call heap allocation. + let reusable_tmp = + state.and_then(|s| s.downcast_mut::().map(|bs| &mut bs.tmp)); + + // Build bg slice references on stack instead of using slice! macro (which allocates a Vec). 
+ let mut bg_slices: [&mut [f32]; MAX_CHANNELS] = Default::default(); + for (i, r) in row[..total_channels].iter_mut().enumerate() { + bg_slices[i] = &mut r[bg_x0..bg_x1]; + } + + perform_blending_with_tmp( + &mut bg_slices[..total_channels], + &fg_buf[..total_channels], &blending_info, - &ec_blending_info, + &ec_blend_buf[..num_ec], &self.extra_channels, + reusable_tmp, ); } } diff --git a/jxl/src/util/fast_math.rs b/jxl/src/util/fast_math.rs index 6fa51988b..bb8922e62 100644 --- a/jxl/src/util/fast_math.rs +++ b/jxl/src/util/fast_math.rs @@ -12,7 +12,7 @@ use std::f32::consts::{PI, SQRT_2}; const POW2F_NUMER_COEFFS: [f32; 3] = [1.01749063e1, 4.88687798e1, 9.85506591e1]; const POW2F_DENOM_COEFFS: [f32; 4] = [2.10242958e-1, -2.22328856e-2, -1.94414990e1, 9.85506633e1]; -#[inline] +#[inline(always)] pub fn fast_cos(x: f32) -> f32 { // Step 1: range reduction to [0, 2pi) let pi2 = PI * 2.0; @@ -41,7 +41,7 @@ pub fn fast_cos(x: f32) -> f32 { } } -#[inline] +#[inline(always)] pub fn fast_erff(x: f32) -> f32 { // Formula from // https://en.wikipedia.org/wiki/Error_function#Numerical_approximations @@ -74,7 +74,7 @@ pub fn fast_erff_simd(d: D, x: D::F32Vec) -> D::F32Vec { result.copysign(x) } -#[inline] +#[inline(always)] pub fn fast_pow2f(x: f32) -> f32 { let x_floor = x.floor(); let exp = f32::from_bits(((x_floor as i32 + 127) as u32) << 23); @@ -122,7 +122,7 @@ const LOG2F_Q: [f32; 3] = [ 1.7409343003366853e-1, ]; -#[inline] +#[inline(always)] pub fn fast_log2f(x: f32) -> f32 { let x_bits = x.to_bits() as i32; let exp_bits = x_bits.wrapping_sub(0x3f2aaaab); @@ -147,16 +147,17 @@ pub fn fast_log2f_simd(d: D, x: D::F32Vec) -> D::F32Vec { } // Max relative error: ~3e-5 -#[inline] +#[inline(always)] pub fn fast_powf(base: f32, exp: f32) -> f32 { fast_pow2f(fast_log2f(base) * exp) } -#[inline] +#[inline(always)] pub fn fast_powf_simd(d: D, base: D::F32Vec, exp: D::F32Vec) -> D::F32Vec { fast_pow2f_simd(d, fast_log2f_simd(d, base) * exp) } +#[inline(always)] pub fn 
floor_log2_nonzero(x: u64) -> u32 { (u64::BITS as usize - 1) as u32 ^ x.leading_zeros() } diff --git a/jxl/src/util/mirror.rs b/jxl/src/util/mirror.rs index 4d7e6c2fc..656d7e561 100644 --- a/jxl/src/util/mirror.rs +++ b/jxl/src/util/mirror.rs @@ -4,6 +4,7 @@ // license that can be found in the LICENSE file. /// Mirror-reflects a value v to fit in a [0; s) range. +#[inline(always)] pub fn mirror(mut v: isize, s: usize) -> usize { // TODO(veluca): consider speeding this up if needed. loop { diff --git a/jxl/src/util/mod.rs b/jxl/src/util/mod.rs index 514820bcd..9b2994b68 100644 --- a/jxl/src/util/mod.rs +++ b/jxl/src/util/mod.rs @@ -19,6 +19,7 @@ pub mod ndarray; mod rational_poly; mod shift_right_ceil; mod smallvec; +mod stack_vec; pub mod tracing_wrappers; mod vec_helpers; mod xorshift128plus; @@ -36,5 +37,6 @@ pub(crate) use ndarray::*; pub use rational_poly::*; pub use shift_right_ceil::*; pub use smallvec::*; +pub use stack_vec::*; pub use vec_helpers::*; pub use xorshift128plus::*; diff --git a/jxl/src/util/rational_poly.rs b/jxl/src/util/rational_poly.rs index 12f97fb5e..9d7c5762e 100644 --- a/jxl/src/util/rational_poly.rs +++ b/jxl/src/util/rational_poly.rs @@ -9,7 +9,7 @@ use jxl_simd::{F32SimdVec, SimdDescriptor}; /// /// # Panics /// Panics if either `P` or `Q` is zero. 
-#[inline] +#[inline(always)] pub fn eval_rational_poly(x: f32, p: [f32; P], q: [f32; Q]) -> f32 { let yp = p.into_iter().rev().reduce(|yp, p| yp * x + p).unwrap(); let yq = q.into_iter().rev().reduce(|yq, q| yq * x + q).unwrap(); diff --git a/jxl/src/util/shift_right_ceil.rs b/jxl/src/util/shift_right_ceil.rs index 3d756b2b1..06be84637 100644 --- a/jxl/src/util/shift_right_ceil.rs +++ b/jxl/src/util/shift_right_ceil.rs @@ -14,6 +14,7 @@ pub trait ShiftRightCeil: Copy { impl + Sub + From> ShiftRightCeil for S { + #[inline(always)] fn shrc(self, rhs: T) -> Self where Self: Shr + Shl, diff --git a/jxl/src/util/smallvec.rs b/jxl/src/util/smallvec.rs index 579f75cf5..2bb2c838e 100644 --- a/jxl/src/util/smallvec.rs +++ b/jxl/src/util/smallvec.rs @@ -27,6 +27,7 @@ pub enum SmallVec { impl Deref for SmallVec { type Target = [T]; + #[inline(always)] fn deref(&self) -> &[T] { match self { SmallVec::Stack { len, data } => { @@ -41,6 +42,7 @@ impl Deref for SmallVec { } impl DerefMut for SmallVec { + #[inline(always)] fn deref_mut(&mut self) -> &mut [T] { match self { SmallVec::Stack { len, data } => { @@ -67,7 +69,7 @@ impl Default for SmallVec { } impl SmallVec { - #[inline] + #[inline(always)] pub fn new() -> Self { Self::Stack { // Safety note: len == 0 makes the safety invariant trivially true. @@ -76,7 +78,7 @@ impl SmallVec { } } - #[inline] + #[inline(always)] pub fn is_empty(&self) -> bool { match self { Self::Stack { len, .. } => *len == 0, @@ -84,7 +86,7 @@ impl SmallVec { } } - #[inline] + #[inline(always)] pub fn len(&self) -> usize { match self { Self::Stack { len, .. 
} => *len, @@ -140,7 +142,7 @@ impl SmallVec { } } - #[inline] + #[inline(always)] pub fn push(&mut self, val: T) { if self.len() + 1 > N { self.move_to_heap(); @@ -199,7 +201,7 @@ impl SmallVec { } impl FromIterator for SmallVec { - #[inline] + #[inline(always)] fn from_iter>(iter: I) -> Self { let mut ret = Self::new(); ret.extend(iter); diff --git a/jxl/src/util/stack_vec.rs b/jxl/src/util/stack_vec.rs new file mode 100644 index 000000000..d01b8ffe8 --- /dev/null +++ b/jxl/src/util/stack_vec.rs @@ -0,0 +1,113 @@ +// Copyright (c) the JPEG XL Project Authors. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +#![allow(unsafe_code)] + +use core::slice; +use std::mem::MaybeUninit; +use std::ops::{Deref, DerefMut}; + +/// A fixed-capacity, stack-only vector. No heap fallback, no enum discriminant. +/// Unlike SmallVec, every operation avoids a Stack/Heap match branch. +/// Use when the maximum size is known at compile time and guaranteed to fit. +pub struct StackVec { + len: usize, + data: [MaybeUninit; N], +} + +impl StackVec { + #[inline(always)] + pub fn new() -> Self { + Self { + len: 0, + data: [const { MaybeUninit::uninit() }; N], + } + } + + #[inline(always)] + pub fn len(&self) -> usize { + self.len + } + + #[inline(always)] + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + #[inline(always)] + pub fn push(&mut self, val: T) { + debug_assert!( + self.len < N, + "StackVec overflow: len={}, cap={}", + self.len, + N + ); + // SAFETY: we just checked len < N (in debug), and the caller must ensure capacity. 
+ unsafe { + self.data.get_unchecked_mut(self.len).write(val); + } + self.len += 1; + } + + #[inline(always)] + pub fn clear(&mut self) { + // Drop existing elements + for i in 0..self.len { + // SAFETY: elements 0..len are initialized + unsafe { self.data[i].assume_init_drop() }; + } + self.len = 0; + } + + #[inline(always)] + pub fn extend>(&mut self, iter: I) { + for val in iter { + self.push(val); + } + } +} + +impl Default for StackVec { + #[inline(always)] + fn default() -> Self { + Self::new() + } +} + +impl Deref for StackVec { + type Target = [T]; + + #[inline(always)] + fn deref(&self) -> &[T] { + // SAFETY: the first `len` elements are initialized. + unsafe { slice::from_raw_parts(self.data.as_ptr().cast::(), self.len) } + } +} + +impl DerefMut for StackVec { + #[inline(always)] + fn deref_mut(&mut self) -> &mut [T] { + // SAFETY: the first `len` elements are initialized. + unsafe { slice::from_raw_parts_mut(self.data.as_mut_ptr().cast::(), self.len) } + } +} + +impl Drop for StackVec { + fn drop(&mut self) { + for i in 0..self.len { + // SAFETY: by invariant, elements 0..len are initialized. + unsafe { self.data[i].assume_init_drop() }; + } + } +} + +impl FromIterator for StackVec { + #[inline(always)] + fn from_iter>(iter: I) -> Self { + let mut ret = Self::new(); + ret.extend(iter); + ret + } +} diff --git a/jxl/src/util/xorshift128plus.rs b/jxl/src/util/xorshift128plus.rs index bf55805a1..5b323721e 100644 --- a/jxl/src/util/xorshift128plus.rs +++ b/jxl/src/util/xorshift128plus.rs @@ -47,6 +47,7 @@ impl Xorshift128Plus { Self { s0, s1 } } + #[inline(always)] pub fn fill(&mut self, random_bits: &mut [u64; Self::N]) { for ((s0, s1), random_bits) in self .s0