Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 94 additions & 53 deletions jxl/src/frame/modular/transforms/palette.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ const SMALL_CUBE_BITS: usize = 2;
// SMALL_CUBE ** 3
const LARGE_CUBE_OFFSET: usize = SMALL_CUBE * SMALL_CUBE * SMALL_CUBE;

#[inline(always)]
fn scale<const DENOM: usize>(value: usize, bit_depth: usize) -> i32 {
// return (value * ((1 << bit_depth) - 1)) / DENOM;
// We only call this function with SMALL_CUBE or LARGE_CUBE - 1 as DENOM,
Expand All @@ -38,8 +39,12 @@ fn scale<const DENOM: usize>(value: usize, bit_depth: usize) -> i32 {
// palette indices to implicit values. If index < nb_deltas, indicating that the
// result is a delta palette entry, it is the responsibility of the caller to
// treat it as such.
//
// `pal_row` is the pre-fetched palette row for channel `c` (i.e. equivalent to
// `palette.row(c)` for some `palette: &Image<i32>`).
#[inline(always)]
fn get_palette_value(
palette: &Image<i32>,
pal_row: &[i32],
index: isize,
c: usize,
palette_size: usize,
Expand Down Expand Up @@ -161,7 +166,36 @@ fn get_palette_value(
}
scale::<{ LARGE_CUBE - 1 }>(index % LARGE_CUBE, bit_depth)
} else {
palette.row(c)[index]
pal_row[index]
}
}
}

/// Fills `row_out` from `row_in` by plain palette lookup, with no delta prediction.
///
/// `pal_row` is the palette row for channel `c`. An index that, after the
/// `as usize` cast, is below `palette_size` is read straight from `pal_row`
/// (the expected common case). Every other index is handed to
/// `get_palette_value`, which resolves the implicit (out-of-palette) entries.
///
/// # Panics
///
/// Panics if `palette_size > pal_row.len()`, or if `row_out` is shorter than
/// `row_in`.
#[inline(always)]
fn apply_palette_lookup_row(
    pal_row: &[i32],
    row_in: &[i32],
    row_out: &mut [i32],
    c: usize,
    palette_size: usize,
    bit_depth: usize,
) {
    // One up-front check: together with the `pos < palette_size` test below it
    // proves `pal_row[pos]` is in bounds, so the per-pixel check can be dropped.
    assert!(palette_size <= pal_row.len());
    for (x, index) in row_in.iter().copied().enumerate() {
        // A negative index casts to a very large `usize`, so it fails the range
        // test and is routed to the fallback with its original signed value.
        let pos = index as usize;
        row_out[x] = if pos < palette_size {
            pal_row[pos]
        } else {
            get_palette_value(pal_row, index as isize, c, palette_size, bit_depth)
        };
    }
}
Expand All @@ -182,33 +216,30 @@ pub fn do_palette_step_general(
if w == 0 {
// Nothing to do.
// Avoid touching "empty" channels with non-zero height.
} else if num_deltas == 0 && predictor == Predictor::Zero {
} else if num_deltas == 0 {
// Fast path: no delta entries means the delta-prediction branch in the
// weighted/general arms below is never taken, so predictor invocations
// and WP state updates are dead work. This is independent of `predictor`.
for (chan_index, out) in buf_out.iter_mut().enumerate() {
let pal_row = palette.row(chan_index);
for y in 0..h {
let row_index = buf_in.data.row(y);
let row_out = out.data.row_mut(y);
for x in 0..w {
let index = row_index[x];
let palette_value = get_palette_value(
palette,
index as isize,
/*c=*/ chan_index,
/*palette_size=*/ num_colors,
/*bit_depth=*/ bit_depth,
);
row_out[x] = palette_value;
}
apply_palette_lookup_row(
pal_row, row_index, row_out, chan_index, num_colors, bit_depth,
);
}
}
} else if predictor == Predictor::Weighted {
let w = buf_in.data.size().0;
for (chan_index, out) in buf_out.iter_mut().enumerate() {
let pal_row = palette.row(chan_index);
let mut wp_state = WeightedPredictorState::new(wp_header, w);
for y in 0..h {
let idx = buf_in.data.row(y);
for (x, &index) in idx.iter().enumerate() {
let palette_entry = get_palette_value(
palette,
pal_row,
index as isize,
/*c=*/ chan_index,
/*palette_size=*/ num_colors + num_deltas,
Expand All @@ -230,11 +261,12 @@ pub fn do_palette_step_general(
}
} else {
for (chan_index, out) in buf_out.iter_mut().enumerate() {
let pal_row = palette.row(chan_index);
for y in 0..h {
let idx = buf_in.data.row(y);
for (x, &index) in idx.iter().enumerate() {
let palette_entry = get_palette_value(
palette,
pal_row,
index as isize,
/*c=*/ chan_index,
/*palette_size=*/ num_colors + num_deltas,
Expand Down Expand Up @@ -319,30 +351,44 @@ pub fn do_palette_step_one_group(
let num_c = buf_out.len() / (grid_xsize * grid_ysize);
let (xsize, ysize) = buf_out[0].data.size();

for c in 0..num_c {
for y in 0..h {
let index_img = buf_in.data.row(y);
let palette_size = num_colors + num_deltas;

if num_deltas == 0 {

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In do_palette_step_general, this is gated by the predictor also being zero - why is this not the same here?

If the fast path is basically equivalent, shouldn't we factor it out to a function?

What images trigger this fast path?

// Fast path: no delta palette entries, just direct lookups.
// Avoids prediction data computation entirely.
for c in 0..num_c {
let pal_row = palette.row(c);
let out_idx = c * grid_ysize * grid_xsize + grid_y * grid_xsize + grid_x;
for (x, &index) in index_img.iter().enumerate() {
let palette_entry = get_palette_value(
palette,
index as isize,
c,
/*palette_size=*/ num_colors + num_deltas,
/*bit_depth=*/ bit_depth,
);
let val = if index < num_deltas as i32 {
let pred = predictor.predict_one(
get_prediction_data(
buf_out, out_idx, grid_x, grid_y, grid_xsize, x, y, xsize, ysize,
),
/*wp_pred=*/ 0,
);
(pred + palette_entry as i64) as i32
} else {
palette_entry
};
buf_out[out_idx].data.row_mut(y)[x] = val;
for y in 0..h {
let index_img = buf_in.data.row(y);
let out_row = buf_out[out_idx].data.row_mut(y);
apply_palette_lookup_row(pal_row, index_img, out_row, c, palette_size, bit_depth);
}
}
} else {
for c in 0..num_c {
let pal_row = palette.row(c);
let out_idx = c * grid_ysize * grid_xsize + grid_y * grid_xsize + grid_x;
for y in 0..h {
let index_img = buf_in.data.row(y);
for (x, &index) in index_img.iter().enumerate() {
let palette_entry =
get_palette_value(pal_row, index as isize, c, palette_size, bit_depth);
let val = if index < num_deltas as i32 {
// Delta palette prediction may need cross-grid neighbors.
// Always use get_prediction_data to preserve exact behavior.
let pred = predictor.predict_one(
get_prediction_data(
buf_out, out_idx, grid_x, grid_y, grid_xsize, x, y, xsize, ysize,
),
/*wp_pred=*/ 0,
);
(pred + palette_entry as i64) as i32
} else {
palette_entry
};
buf_out[out_idx].data.row_mut(y)[x] = val;
}
}
}
}
Expand Down Expand Up @@ -371,8 +417,10 @@ pub fn do_palette_step_group_row(
.sum();
let (xsize, ysize) = buf_out[0].data.size();

let palette_size = num_colors + num_deltas;
if predictor == Predictor::Weighted {
for c in 0..num_c {
let pal_row = palette.row(c);
let mut wp_state = WeightedPredictorState::new(wp_header, total_w);
let out_row_idx = c * grid_ysize * grid_xsize + grid_y * grid_xsize;
if grid_y > 0 {
Expand All @@ -387,14 +435,10 @@ pub fn do_palette_step_group_row(
let index_img = index_buf.data.row(y);
let out_idx = out_row_idx + grid_x;
for (x, &index) in index_img.iter().enumerate() {
let palette_entry = get_palette_value(
palette,
index as isize,
c,
/*palette_size=*/ num_colors + num_deltas,
/*bit_depth=*/ bit_depth,
);
let palette_entry =
get_palette_value(pal_row, index as isize, c, palette_size, bit_depth);
let val = if index < num_deltas as i32 {
// Delta palette prediction may need cross-grid neighbors.
let prediction_data = get_prediction_data(
buf_out, out_idx, grid_x, grid_y, grid_xsize, x, y, xsize, ysize,
);
Expand All @@ -418,19 +462,16 @@ pub fn do_palette_step_group_row(
}
} else {
for c in 0..num_c {
let pal_row = palette.row(c);
for y in 0..h {
for (grid_x, index_buf) in buf_in.iter().enumerate().take(grid_xsize) {
let index_img = index_buf.data.row(y);
let out_idx = c * grid_ysize * grid_xsize + grid_y * grid_xsize + grid_x;
for (x, &index) in index_img.iter().enumerate() {
let palette_entry = get_palette_value(
palette,
index as isize,
c,
/*palette_size=*/ num_colors + num_deltas,
/*bit_depth=*/ bit_depth,
);
let palette_entry =
get_palette_value(pal_row, index as isize, c, palette_size, bit_depth);
let val = if index < num_deltas as i32 {
// Delta palette prediction may need cross-grid neighbors.
let pred = predictor.predict_one(
get_prediction_data(
buf_out, out_idx, grid_x, grid_y, grid_xsize, x, y, xsize,
Expand Down
Loading