Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 104 additions & 46 deletions jxl/src/frame/modular/transforms/palette.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ const SMALL_CUBE_BITS: usize = 2;
// SMALL_CUBE ** 3
const LARGE_CUBE_OFFSET: usize = SMALL_CUBE * SMALL_CUBE * SMALL_CUBE;

#[inline(always)]
fn scale<const DENOM: usize>(value: usize, bit_depth: usize) -> i32 {
// return (value * ((1 << bit_depth) - 1)) / DENOM;
// We only call this function with SMALL_CUBE or LARGE_CUBE - 1 as DENOM,
Expand All @@ -37,9 +38,10 @@ fn scale<const DENOM: usize>(value: usize, bit_depth: usize) -> i32 {
// The purpose of this function is solely to extend the interpretation of
// palette indices to implicit values. If index < nb_deltas, indicating that the
// result is a delta palette entry, it is the responsibility of the caller to
// treat it as such.
fn get_palette_value(
palette: &Image<i32>,
/// Look up palette value. `pal_row` is the pre-fetched palette row for channel `c`.
#[inline(always)]
fn get_palette_value_with_row(

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please rename this back to `get_palette_value`.

pal_row: &[i32],
index: isize,
c: usize,
palette_size: usize,
Expand Down Expand Up @@ -161,11 +163,12 @@ fn get_palette_value(
}
scale::<{ LARGE_CUBE - 1 }>(index % LARGE_CUBE, bit_depth)
} else {
palette.row(c)[index]
pal_row[index]
}
}
}

#[inline(always)]
pub fn do_palette_step_general(
buf_in: &ModularChannel,
buf_pal: &ModularChannel,
Expand All @@ -184,31 +187,39 @@ pub fn do_palette_step_general(
// Avoid touching "empty" channels with non-zero height.
} else if num_deltas == 0 && predictor == Predictor::Zero {
for (chan_index, out) in buf_out.iter_mut().enumerate() {
let pal_row = palette.row(chan_index);
for y in 0..h {
let row_index = buf_in.data.row(y);
let row_out = out.data.row_mut(y);
#[allow(unsafe_code)]
for x in 0..w {
let index = row_index[x];
let palette_value = get_palette_value(
palette,
index as isize,
/*c=*/ chan_index,
/*palette_size=*/ num_colors,
/*bit_depth=*/ bit_depth,
);
row_out[x] = palette_value;
let idx = index as usize;
if idx < num_colors {
// SAFETY: idx < num_colors <= pal_row.len()
row_out[x] = unsafe { *pal_row.get_unchecked(idx) };
} else {
row_out[x] = get_palette_value_with_row(
pal_row,
index as isize,
chan_index,
num_colors,
bit_depth,
);
}
}
}
}
} else if predictor == Predictor::Weighted {
let w = buf_in.data.size().0;
for (chan_index, out) in buf_out.iter_mut().enumerate() {
let pal_row = palette.row(chan_index);
let mut wp_state = WeightedPredictorState::new(wp_header, w);
for y in 0..h {
let idx = buf_in.data.row(y);
for (x, &index) in idx.iter().enumerate() {
let palette_entry = get_palette_value(
palette,
let palette_entry = get_palette_value_with_row(
pal_row,
index as isize,
/*c=*/ chan_index,
/*palette_size=*/ num_colors + num_deltas,
Expand All @@ -230,11 +241,12 @@ pub fn do_palette_step_general(
}
} else {
for (chan_index, out) in buf_out.iter_mut().enumerate() {
let pal_row = palette.row(chan_index);
for y in 0..h {
let idx = buf_in.data.row(y);
for (x, &index) in idx.iter().enumerate() {
let palette_entry = get_palette_value(
palette,
let palette_entry = get_palette_value_with_row(
pal_row,
index as isize,
/*c=*/ chan_index,
/*palette_size=*/ num_colors + num_deltas,
Expand All @@ -254,6 +266,7 @@ pub fn do_palette_step_general(
}
}

#[inline(always)]
#[allow(clippy::too_many_arguments)]
fn get_prediction_data(
buf: &mut [&mut ModularChannel],
Expand Down Expand Up @@ -300,6 +313,7 @@ fn get_prediction_data(
)
}

#[inline(always)]
#[allow(clippy::too_many_arguments)]
pub fn do_palette_step_one_group(
buf_in: &ModularChannel,
Expand All @@ -319,35 +333,74 @@ pub fn do_palette_step_one_group(
let num_c = buf_out.len() / (grid_xsize * grid_ysize);
let (xsize, ysize) = buf_out[0].data.size();

for c in 0..num_c {
for y in 0..h {
let index_img = buf_in.data.row(y);
let palette_size = num_colors + num_deltas;

if num_deltas == 0 {

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In `do_palette_step_general`, this fast path is gated on the predictor also being `Predictor::Zero` (`num_deltas == 0 && predictor == Predictor::Zero`) - why is the same condition not required here?

If the fast path is basically equivalent, shouldn't we factor it out to a function?

What images trigger this fast path?

// Fast path: no delta palette entries, just direct lookups.
// Avoids prediction data computation entirely.
for c in 0..num_c {
let pal_row = palette.row(c);
let out_idx = c * grid_ysize * grid_xsize + grid_y * grid_xsize + grid_x;
for (x, &index) in index_img.iter().enumerate() {
let palette_entry = get_palette_value(
palette,
index as isize,
c,
/*palette_size=*/ num_colors + num_deltas,
/*bit_depth=*/ bit_depth,
);
let val = if index < num_deltas as i32 {
let pred = predictor.predict_one(
get_prediction_data(
buf_out, out_idx, grid_x, grid_y, grid_xsize, x, y, xsize, ysize,
),
/*wp_pred=*/ 0,
for y in 0..h {
let index_img = buf_in.data.row(y);
let out_row = buf_out[out_idx].data.row_mut(y);
#[allow(unsafe_code)]
for (x, &index) in index_img.iter().enumerate() {
// Fast path: direct palette lookup for valid indices (common case).
// Skip the multi-branch get_palette_value_with_row for the hot path.
let idx = index as usize;
if idx < palette_size {
// SAFETY: idx < palette_size <= pal_row.len() (palette is at least
// palette_size wide, validated during palette transform setup).
out_row[x] = unsafe { *pal_row.get_unchecked(idx) };
Comment thread
hjanuschka marked this conversation as resolved.
Outdated
} else {
// Rare case: implicit color cube or negative index
out_row[x] = get_palette_value_with_row(
pal_row,
index as isize,
c,
palette_size,
bit_depth,
);
}
}
}
}
} else {
for c in 0..num_c {
let pal_row = palette.row(c);
let out_idx = c * grid_ysize * grid_xsize + grid_y * grid_xsize + grid_x;
for y in 0..h {
let index_img = buf_in.data.row(y);
for (x, &index) in index_img.iter().enumerate() {
let palette_entry = get_palette_value_with_row(
pal_row,
index as isize,
c,
palette_size,
bit_depth,
);
(pred + palette_entry as i64) as i32
} else {
palette_entry
};
buf_out[out_idx].data.row_mut(y)[x] = val;
let val = if index < num_deltas as i32 {
// Delta palette prediction may need cross-grid neighbors.
// Always use get_prediction_data to preserve exact behavior.
let pred = predictor.predict_one(
get_prediction_data(
buf_out, out_idx, grid_x, grid_y, grid_xsize, x, y, xsize, ysize,
),
/*wp_pred=*/ 0,
);
(pred + palette_entry as i64) as i32
} else {
palette_entry
};
buf_out[out_idx].data.row_mut(y)[x] = val;
}
}
}
}
}

#[inline(always)]
#[allow(clippy::too_many_arguments)]
pub fn do_palette_step_group_row(
buf_in: &[&ModularChannel],
Expand All @@ -371,8 +424,10 @@ pub fn do_palette_step_group_row(
.sum();
let (xsize, ysize) = buf_out[0].data.size();

let palette_size = num_colors + num_deltas;
if predictor == Predictor::Weighted {
for c in 0..num_c {
let pal_row = palette.row(c);
let mut wp_state = WeightedPredictorState::new(wp_header, total_w);
let out_row_idx = c * grid_ysize * grid_xsize + grid_y * grid_xsize;
if grid_y > 0 {
Expand All @@ -387,14 +442,15 @@ pub fn do_palette_step_group_row(
let index_img = index_buf.data.row(y);
let out_idx = out_row_idx + grid_x;
for (x, &index) in index_img.iter().enumerate() {
let palette_entry = get_palette_value(
palette,
let palette_entry = get_palette_value_with_row(
pal_row,
index as isize,
c,
/*palette_size=*/ num_colors + num_deltas,
/*bit_depth=*/ bit_depth,
palette_size,
bit_depth,
);
let val = if index < num_deltas as i32 {
// Delta palette prediction may need cross-grid neighbors.
let prediction_data = get_prediction_data(
buf_out, out_idx, grid_x, grid_y, grid_xsize, x, y, xsize, ysize,
);
Expand All @@ -418,19 +474,21 @@ pub fn do_palette_step_group_row(
}
} else {
for c in 0..num_c {
let pal_row = palette.row(c);
for y in 0..h {
for (grid_x, index_buf) in buf_in.iter().enumerate().take(grid_xsize) {
let index_img = index_buf.data.row(y);
let out_idx = c * grid_ysize * grid_xsize + grid_y * grid_xsize + grid_x;
for (x, &index) in index_img.iter().enumerate() {
let palette_entry = get_palette_value(
palette,
let palette_entry = get_palette_value_with_row(
pal_row,
index as isize,
c,
/*palette_size=*/ num_colors + num_deltas,
/*bit_depth=*/ bit_depth,
palette_size,
bit_depth,
);
let val = if index < num_deltas as i32 {
// Delta palette prediction may need cross-grid neighbors.
let pred = predictor.predict_one(
get_prediction_data(
buf_out, out_idx, grid_x, grid_y, grid_xsize, x, y, xsize,
Expand Down
Loading