From 3ca76252d68933c5d8d53b8e528043c169159073 Mon Sep 17 00:00:00 2001 From: Helmut Januschka Date: Mon, 16 Mar 2026 08:41:29 +0100 Subject: [PATCH 1/3] perf: add palette transform fast paths --- jxl/src/frame/modular/transforms/palette.rs | 150 ++++++++++++++------ 1 file changed, 104 insertions(+), 46 deletions(-) diff --git a/jxl/src/frame/modular/transforms/palette.rs b/jxl/src/frame/modular/transforms/palette.rs index ae7c71103..88df75238 100644 --- a/jxl/src/frame/modular/transforms/palette.rs +++ b/jxl/src/frame/modular/transforms/palette.rs @@ -24,6 +24,7 @@ const SMALL_CUBE_BITS: usize = 2; // SMALL_CUBE ** 3 const LARGE_CUBE_OFFSET: usize = SMALL_CUBE * SMALL_CUBE * SMALL_CUBE; +#[inline(always)] fn scale(value: usize, bit_depth: usize) -> i32 { // return (value * ((1 << bit_depth) - 1)) / DENOM; // We only call this function with SMALL_CUBE or LARGE_CUBE - 1 as DENOM, @@ -37,9 +38,10 @@ fn scale(value: usize, bit_depth: usize) -> i32 { // The purpose of this function is solely to extend the interpretation of // palette indices to implicit values. If index < nb_deltas, indicating that the // result is a delta palette entry, it is the responsibility of the caller to -// treat it as such. -fn get_palette_value( - palette: &Image, +/// Look up palette value. `pal_row` is the pre-fetched palette row for channel `c`. +#[inline(always)] +fn get_palette_value_with_row( + pal_row: &[i32], index: isize, c: usize, palette_size: usize, @@ -161,11 +163,12 @@ fn get_palette_value( } scale::<{ LARGE_CUBE - 1 }>(index % LARGE_CUBE, bit_depth) } else { - palette.row(c)[index] + pal_row[index] } } } +#[inline(always)] pub fn do_palette_step_general( buf_in: &ModularChannel, buf_pal: &ModularChannel, @@ -184,31 +187,39 @@ pub fn do_palette_step_general( // Avoid touching "empty" channels with non-zero height. } else if num_deltas == 0 && predictor == Predictor::Zero { for (chan_index, out) in buf_out.iter_mut().enumerate() { + let pal_row = palette.row(chan_index); for y in 0..h { let row_index = buf_in.data.row(y); let row_out = out.data.row_mut(y); + #[allow(unsafe_code)] for x in 0..w { let index = row_index[x]; - let palette_value = get_palette_value( - palette, - index as isize, - /*c=*/ chan_index, - /*palette_size=*/ num_colors, - /*bit_depth=*/ bit_depth, - ); - row_out[x] = palette_value; + let idx = index as usize; + if idx < num_colors { + // SAFETY: idx < num_colors <= pal_row.len() + row_out[x] = unsafe { *pal_row.get_unchecked(idx) }; + } else { + row_out[x] = get_palette_value_with_row( + pal_row, + index as isize, + chan_index, + num_colors, + bit_depth, + ); + } } } } } else if predictor == Predictor::Weighted { let w = buf_in.data.size().0; for (chan_index, out) in buf_out.iter_mut().enumerate() { + let pal_row = palette.row(chan_index); let mut wp_state = WeightedPredictorState::new(wp_header, w); for y in 0..h { let idx = buf_in.data.row(y); for (x, &index) in idx.iter().enumerate() { - let palette_entry = get_palette_value( - palette, + let palette_entry = get_palette_value_with_row( + pal_row, index as isize, /*c=*/ chan_index, /*palette_size=*/ num_colors + num_deltas, @@ -230,11 +241,12 @@ pub fn do_palette_step_general( } } else { for (chan_index, out) in buf_out.iter_mut().enumerate() { + let pal_row = palette.row(chan_index); for y in 0..h { let idx = buf_in.data.row(y); for (x, &index) in idx.iter().enumerate() { - let palette_entry = get_palette_value( - palette, + let palette_entry = get_palette_value_with_row( + pal_row, index as isize, /*c=*/ chan_index, /*palette_size=*/ num_colors + num_deltas, @@ -254,6 +266,7 @@ pub fn do_palette_step_general( } } +#[inline(always)] #[allow(clippy::too_many_arguments)] fn get_prediction_data( buf: &mut [&mut ModularChannel], @@ -300,6 +313,7 @@ fn get_prediction_data( ) } +#[inline(always)] #[allow(clippy::too_many_arguments)] pub fn do_palette_step_one_group( buf_in: &ModularChannel, @@ -319,35 +333,74 @@ pub fn do_palette_step_one_group( let num_c = buf_out.len() / (grid_xsize * grid_ysize); let (xsize, ysize) = buf_out[0].data.size(); - for c in 0..num_c { - for y in 0..h { - let index_img = buf_in.data.row(y); + let palette_size = num_colors + num_deltas; + + if num_deltas == 0 { + // Fast path: no delta palette entries, just direct lookups. + // Avoids prediction data computation entirely. + for c in 0..num_c { + let pal_row = palette.row(c); let out_idx = c * grid_ysize * grid_xsize + grid_y * grid_xsize + grid_x; - for (x, &index) in index_img.iter().enumerate() { - let palette_entry = get_palette_value( - palette, - index as isize, - c, - /*palette_size=*/ num_colors + num_deltas, - /*bit_depth=*/ bit_depth, - ); - let val = if index < num_deltas as i32 { - let pred = predictor.predict_one( - get_prediction_data( - buf_out, out_idx, grid_x, grid_y, grid_xsize, x, y, xsize, ysize, - ), - /*wp_pred=*/ 0, + for y in 0..h { + let index_img = buf_in.data.row(y); + let out_row = buf_out[out_idx].data.row_mut(y); + #[allow(unsafe_code)] + for (x, &index) in index_img.iter().enumerate() { + // Fast path: direct palette lookup for valid indices (common case). + // Skip the multi-branch get_palette_value_with_row for the hot path. + let idx = index as usize; + if idx < palette_size { + // SAFETY: idx < palette_size <= pal_row.len() (palette is at least + // palette_size wide, validated during palette transform setup). + out_row[x] = unsafe { *pal_row.get_unchecked(idx) }; + } else { + // Rare case: implicit color cube or negative index + out_row[x] = get_palette_value_with_row( + pal_row, + index as isize, + c, + palette_size, + bit_depth, + ); + } + } + } + } + } else { + for c in 0..num_c { + let pal_row = palette.row(c); + let out_idx = c * grid_ysize * grid_xsize + grid_y * grid_xsize + grid_x; + for y in 0..h { + let index_img = buf_in.data.row(y); + for (x, &index) in index_img.iter().enumerate() { + let palette_entry = get_palette_value_with_row( + pal_row, + index as isize, + c, + palette_size, + bit_depth, ); - (pred + palette_entry as i64) as i32 - } else { - palette_entry - }; - buf_out[out_idx].data.row_mut(y)[x] = val; + let val = if index < num_deltas as i32 { + // Delta palette prediction may need cross-grid neighbors. + // Always use get_prediction_data to preserve exact behavior. + let pred = predictor.predict_one( + get_prediction_data( + buf_out, out_idx, grid_x, grid_y, grid_xsize, x, y, xsize, ysize, + ), + /*wp_pred=*/ 0, + ); + (pred + palette_entry as i64) as i32 + } else { + palette_entry + }; + buf_out[out_idx].data.row_mut(y)[x] = val; + } } } } } +#[inline(always)] #[allow(clippy::too_many_arguments)] pub fn do_palette_step_group_row( buf_in: &[&ModularChannel], @@ -371,8 +424,10 @@ pub fn do_palette_step_group_row( .sum(); let (xsize, ysize) = buf_out[0].data.size(); + let palette_size = num_colors + num_deltas; if predictor == Predictor::Weighted { for c in 0..num_c { + let pal_row = palette.row(c); let mut wp_state = WeightedPredictorState::new(wp_header, total_w); let out_row_idx = c * grid_ysize * grid_xsize + grid_y * grid_xsize; if grid_y > 0 { @@ -387,14 +442,15 @@ pub fn do_palette_step_group_row( let index_img = index_buf.data.row(y); let out_idx = out_row_idx + grid_x; for (x, &index) in index_img.iter().enumerate() { - let palette_entry = get_palette_value( - palette, + let palette_entry = get_palette_value_with_row( + pal_row, index as isize, c, - /*palette_size=*/ num_colors + num_deltas, - /*bit_depth=*/ bit_depth, + palette_size, + bit_depth, ); let val = if index < num_deltas as i32 { + // Delta palette prediction may need cross-grid neighbors. let prediction_data = get_prediction_data( buf_out, out_idx, grid_x, grid_y, grid_xsize, x, y, xsize, ysize, ); @@ -418,19 +474,21 @@ pub fn do_palette_step_group_row( } } else { for c in 0..num_c { + let pal_row = palette.row(c); for y in 0..h { for (grid_x, index_buf) in buf_in.iter().enumerate().take(grid_xsize) { let index_img = index_buf.data.row(y); let out_idx = c * grid_ysize * grid_xsize + grid_y * grid_xsize + grid_x; for (x, &index) in index_img.iter().enumerate() { - let palette_entry = get_palette_value( - palette, + let palette_entry = get_palette_value_with_row( + pal_row, index as isize, c, - /*palette_size=*/ num_colors + num_deltas, - /*bit_depth=*/ bit_depth, + palette_size, + bit_depth, ); let val = if index < num_deltas as i32 { + // Delta palette prediction may need cross-grid neighbors. let pred = predictor.predict_one( get_prediction_data( buf_out, out_idx, grid_x, grid_y, grid_xsize, x, y, xsize, From 8d5cc774fec62a84c5dbd68f25be92d3ea3c8585 Mon Sep 17 00:00:00 2001 From: Helmut Januschka Date: Sun, 31 May 2026 21:17:10 +0200 Subject: [PATCH 2/3] palette: drop unsafe in fast paths, use assert-elided bounds checks Address review feedback from veluca93: replace the get_unchecked calls in the no-delta palette fast paths with safe indexing, gated by an assert!(palette_size <= pal_row.len()) outside the inner loop. Combined with the existing 'if idx < palette_size' guard, the compiler can prove in-bounds access and elide the check while still keeping the code unsafe-free. Also drop #[inline(always)] from the large body functions (do_palette_step_general, do_palette_step_one_group, do_palette_step_group_row, get_prediction_data); keep it only on the small helpers (scale, get_palette_value_with_row). The original commit's benchmark showed a regression on all four files (-0.74% .. -3.07%), and aggressive inlining of these large bodies is a plausible contributor via icache pressure. The structural wins (hoisting palette.row(c) out of the inner loops, splitting the no-delta case) are retained. --- jxl/src/frame/modular/transforms/palette.rs | 24 +++++++++------------ 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/jxl/src/frame/modular/transforms/palette.rs b/jxl/src/frame/modular/transforms/palette.rs index 88df75238..b268d24c1 100644 --- a/jxl/src/frame/modular/transforms/palette.rs +++ b/jxl/src/frame/modular/transforms/palette.rs @@ -168,7 +168,6 @@ fn get_palette_value_with_row( } } -#[inline(always)] pub fn do_palette_step_general( buf_in: &ModularChannel, buf_pal: &ModularChannel, @@ -188,16 +187,18 @@ pub fn do_palette_step_general( } else if num_deltas == 0 && predictor == Predictor::Zero { for (chan_index, out) in buf_out.iter_mut().enumerate() { let pal_row = palette.row(chan_index); + // Asserting palette_size <= pal_row.len() once lets the compiler elide the + // bounds check on `pal_row[idx]` inside the loop given the `idx < num_colors` + // guard. + assert!(num_colors <= pal_row.len()); for y in 0..h { let row_index = buf_in.data.row(y); let row_out = out.data.row_mut(y); - #[allow(unsafe_code)] for x in 0..w { let index = row_index[x]; let idx = index as usize; if idx < num_colors { - // SAFETY: idx < num_colors <= pal_row.len() - row_out[x] = unsafe { *pal_row.get_unchecked(idx) }; + row_out[x] = pal_row[idx]; } else { row_out[x] = get_palette_value_with_row( pal_row, @@ -266,7 +267,6 @@ pub fn do_palette_step_general( } } -#[inline(always)] #[allow(clippy::too_many_arguments)] fn get_prediction_data( buf: &mut [&mut ModularChannel], @@ -313,7 +313,6 @@ fn get_prediction_data( ) } -#[inline(always)] #[allow(clippy::too_many_arguments)] pub fn do_palette_step_one_group( buf_in: &ModularChannel, @@ -340,21 +339,19 @@ pub fn do_palette_step_one_group( // Avoids prediction data computation entirely. for c in 0..num_c { let pal_row = palette.row(c); + // Asserting once lets the compiler elide the bounds check inside the loop + // given the `idx < palette_size` guard. + assert!(palette_size <= pal_row.len()); let out_idx = c * grid_ysize * grid_xsize + grid_y * grid_xsize + grid_x; for y in 0..h { let index_img = buf_in.data.row(y); let out_row = buf_out[out_idx].data.row_mut(y); - #[allow(unsafe_code)] for (x, &index) in index_img.iter().enumerate() { - // Fast path: direct palette lookup for valid indices (common case). - // Skip the multi-branch get_palette_value_with_row for the hot path. let idx = index as usize; if idx < palette_size { - // SAFETY: idx < palette_size <= pal_row.len() (palette is at least - // palette_size wide, validated during palette transform setup). - out_row[x] = unsafe { *pal_row.get_unchecked(idx) }; + out_row[x] = pal_row[idx]; } else { - // Rare case: implicit color cube or negative index + // Rare case: implicit color cube or negative index. out_row[x] = get_palette_value_with_row( pal_row, index as isize, @@ -400,7 +397,6 @@ pub fn do_palette_step_one_group( } } -#[inline(always)] #[allow(clippy::too_many_arguments)] pub fn do_palette_step_group_row( buf_in: &[&ModularChannel], From 7bf76ae75928dc91f2064ca2fed0735cf87dfd66 Mon Sep 17 00:00:00 2001 From: Helmut Januschka Date: Thu, 4 Jun 2026 12:01:19 +0200 Subject: [PATCH 3/3] palette: rename helper, generalize fast path, factor common loop Address review feedback from veluca93: * Rename get_palette_value_with_row back to get_palette_value. The function's role hasn't changed; it just now takes a pre-fetched palette row (`&[i32]`) instead of looking it up via `palette.row(c)` on every call. Updated the doc comment to reflect that. * Drop the predictor == Predictor::Zero half of the fast-path gate in do_palette_step_general. When num_deltas == 0, the `if index < num_deltas as i32` branch in the weighted/general arms is never taken, so predictor.predict_one is never called and the WP state's update_errors writes are dead. The fast path is therefore correct for any predictor when num_deltas == 0, matching what do_palette_step_one_group was already doing. * Factor the common inner loop into apply_palette_lookup_row. Both fast paths now call it, removing the duplicated direct-lookup / out-of-range-fallback / assert pattern. Behavior unchanged. 639 unit tests pass. --- jxl/src/frame/modular/transforms/palette.rs | 113 +++++++++----------- 1 file changed, 50 insertions(+), 63 deletions(-) diff --git a/jxl/src/frame/modular/transforms/palette.rs b/jxl/src/frame/modular/transforms/palette.rs index b268d24c1..332f0fda0 100644 --- a/jxl/src/frame/modular/transforms/palette.rs +++ b/jxl/src/frame/modular/transforms/palette.rs @@ -38,9 +38,12 @@ fn scale(value: usize, bit_depth: usize) -> i32 { // The purpose of this function is solely to extend the interpretation of // palette indices to implicit values. If index < nb_deltas, indicating that the // result is a delta palette entry, it is the responsibility of the caller to -/// Look up palette value. `pal_row` is the pre-fetched palette row for channel `c`. +// treat it as such. +// +// `pal_row` is the pre-fetched palette row for channel `c` (i.e. equivalent to +// `palette.row(c)` for some `palette: &Image`). #[inline(always)] -fn get_palette_value_with_row( +fn get_palette_value( pal_row: &[i32], index: isize, c: usize, @@ -168,6 +171,35 @@ fn get_palette_value_with_row( } } +/// Apply a pure palette lookup (no delta prediction) to one row. +/// +/// The hot path is the `idx < palette_size` branch (direct lookup); out-of-range +/// indices fall back to `get_palette_value` for the implicit color-cube cases. +/// Used by both `do_palette_step_general` and `do_palette_step_one_group` when +/// `num_deltas == 0`, since in that case the delta-prediction branch is never +/// taken and any predictor state is dead. +#[inline(always)] +fn apply_palette_lookup_row( + pal_row: &[i32], + row_in: &[i32], + row_out: &mut [i32], + c: usize, + palette_size: usize, + bit_depth: usize, +) { + // Asserting once lets the compiler elide the bounds check on `pal_row[idx]` + // inside the loop given the `idx < palette_size` guard. + assert!(palette_size <= pal_row.len()); + for (x, &index) in row_in.iter().enumerate() { + let idx = index as usize; + if idx < palette_size { + row_out[x] = pal_row[idx]; + } else { + row_out[x] = get_palette_value(pal_row, index as isize, c, palette_size, bit_depth); + } + } +} + pub fn do_palette_step_general( buf_in: &ModularChannel, buf_pal: &ModularChannel, @@ -184,31 +216,18 @@ pub fn do_palette_step_general( if w == 0 { // Nothing to do. // Avoid touching "empty" channels with non-zero height. - } else if num_deltas == 0 && predictor == Predictor::Zero { + } else if num_deltas == 0 { + // Fast path: no delta entries means the delta-prediction branch in the + // weighted/general arms below is never taken, so predictor invocations + // and WP state updates are dead work. This is independent of `predictor`. for (chan_index, out) in buf_out.iter_mut().enumerate() { let pal_row = palette.row(chan_index); - // Asserting palette_size <= pal_row.len() once lets the compiler elide the - // bounds check on `pal_row[idx]` inside the loop given the `idx < num_colors` - // guard. - assert!(num_colors <= pal_row.len()); for y in 0..h { let row_index = buf_in.data.row(y); let row_out = out.data.row_mut(y); - for x in 0..w { - let index = row_index[x]; - let idx = index as usize; - if idx < num_colors { - row_out[x] = pal_row[idx]; - } else { - row_out[x] = get_palette_value_with_row( - pal_row, - index as isize, - chan_index, - num_colors, - bit_depth, - ); - } - } + apply_palette_lookup_row( + pal_row, row_index, row_out, chan_index, num_colors, bit_depth, + ); } } } else if predictor == Predictor::Weighted { @@ -219,7 +238,7 @@ pub fn do_palette_step_general( for y in 0..h { let idx = buf_in.data.row(y); for (x, &index) in idx.iter().enumerate() { - let palette_entry = get_palette_value_with_row( + let palette_entry = get_palette_value( pal_row, index as isize, /*c=*/ chan_index, @@ -246,7 +265,7 @@ pub fn do_palette_step_general( for y in 0..h { let idx = buf_in.data.row(y); for (x, &index) in idx.iter().enumerate() { - let palette_entry = get_palette_value_with_row( + let palette_entry = get_palette_value( pal_row, index as isize, /*c=*/ chan_index, @@ -339,28 +358,11 @@ pub fn do_palette_step_one_group( // Avoids prediction data computation entirely. for c in 0..num_c { let pal_row = palette.row(c); - // Asserting once lets the compiler elide the bounds check inside the loop - // given the `idx < palette_size` guard. - assert!(palette_size <= pal_row.len()); let out_idx = c * grid_ysize * grid_xsize + grid_y * grid_xsize + grid_x; for y in 0..h { let index_img = buf_in.data.row(y); let out_row = buf_out[out_idx].data.row_mut(y); - for (x, &index) in index_img.iter().enumerate() { - let idx = index as usize; - if idx < palette_size { - out_row[x] = pal_row[idx]; - } else { - // Rare case: implicit color cube or negative index. - out_row[x] = get_palette_value_with_row( - pal_row, - index as isize, - c, - palette_size, - bit_depth, - ); - } - } + apply_palette_lookup_row(pal_row, index_img, out_row, c, palette_size, bit_depth); } } } else { @@ -370,13 +372,8 @@ pub fn do_palette_step_one_group( for y in 0..h { let index_img = buf_in.data.row(y); for (x, &index) in index_img.iter().enumerate() { - let palette_entry = get_palette_value_with_row( - pal_row, - index as isize, - c, - palette_size, - bit_depth, - ); + let palette_entry = + get_palette_value(pal_row, index as isize, c, palette_size, bit_depth); let val = if index < num_deltas as i32 { // Delta palette prediction may need cross-grid neighbors. // Always use get_prediction_data to preserve exact behavior. @@ -438,13 +435,8 @@ pub fn do_palette_step_group_row( let index_img = index_buf.data.row(y); let out_idx = out_row_idx + grid_x; for (x, &index) in index_img.iter().enumerate() { - let palette_entry = get_palette_value_with_row( - pal_row, - index as isize, - c, - palette_size, - bit_depth, - ); + let palette_entry = + get_palette_value(pal_row, index as isize, c, palette_size, bit_depth); let val = if index < num_deltas as i32 { // Delta palette prediction may need cross-grid neighbors. let prediction_data = get_prediction_data( @@ -476,13 +468,8 @@ pub fn do_palette_step_group_row( let index_img = index_buf.data.row(y); let out_idx = c * grid_ysize * grid_xsize + grid_y * grid_xsize + grid_x; for (x, &index) in index_img.iter().enumerate() { - let palette_entry = get_palette_value_with_row( - pal_row, - index as isize, - c, - palette_size, - bit_depth, - ); + let palette_entry = + get_palette_value(pal_row, index as isize, c, palette_size, bit_depth); let val = if index < num_deltas as i32 { // Delta palette prediction may need cross-grid neighbors. let pred = predictor.predict_one(