Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions jxl/src/frame/modular/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -961,6 +961,12 @@ impl FullModularImage {
Ok(())
}

/// Reports whether every buffer that maps to an output channel has a grid.
///
/// Buffers whose `output_channel_idx` is `None` are ignored. The result is `false` as soon
/// as one buffer with an output channel has a `grid_kind` of `ModularGridKind::None`, and
/// `true` otherwise (including when there are no buffers at all).
pub fn outputs_are_gridded(&self) -> bool {
    // Look for a counter-example: an output buffer without a grid.
    let has_ungridded_output = self.buffer_info.iter().any(|buf| {
        buf.info.output_channel_idx.is_some() && buf.grid_kind == ModularGridKind::None
    });
    !has_ungridded_output
}

pub fn channel_range(&self) -> Range<usize> {
if self.modular_color_channels != 0 {
0..self.buffers_for_channels.len()
Expand Down
28 changes: 28 additions & 0 deletions jxl/src/frame/modular/transforms/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,20 @@ impl TransformStepChunk {
}
buffers[buf_in[0]].buffer_grid[in_grid].mark_used(is_final);
buffers[buf_in[1]].buffer_grid[res_grid].mark_used(is_final);
// Release the weak neighbor grids read above (next average and previous
// decoded), which are counted as uses in the transform graph.
let (gx, gy) = self.grid_pos;
if gx + 1 < buffers[*buf_out].grid_shape.0 {
let next_avg_grid =
buffers[buf_in[0]].get_grid_idx(out_grid_kind, (gx + 1, gy));
if next_avg_grid != in_grid {
buffers[buf_in[0]].buffer_grid[next_avg_grid].mark_used(is_final);
}
}
if gx > 0 {
let prev_out_grid = buffers[*buf_out].get_grid_idx(out_grid_kind, (gx - 1, gy));
buffers[*buf_out].buffer_grid[prev_out_grid].mark_used(is_final);
}
}
TransformStep::VSqueeze {
buf_in,
Expand Down Expand Up @@ -491,6 +505,20 @@ impl TransformStepChunk {
}
buffers[buf_in[0]].buffer_grid[in_grid].mark_used(is_final);
buffers[buf_in[1]].buffer_grid[res_grid].mark_used(is_final);
// Release the weak neighbor grids read above (next average and previous
// decoded), which are counted as uses in the transform graph.
let (gx, gy) = self.grid_pos;
if gy + 1 < buffers[*buf_out].grid_shape.1 {
let next_avg_grid =
buffers[buf_in[0]].get_grid_idx(out_grid_kind, (gx, gy + 1));
if next_avg_grid != in_grid {
buffers[buf_in[0]].buffer_grid[next_avg_grid].mark_used(is_final);
}
}
if gy > 0 {
let prev_out_grid = buffers[*buf_out].get_grid_idx(out_grid_kind, (gx, gy - 1));
buffers[*buf_out].buffer_grid[prev_out_grid].mark_used(is_final);
}
}
};

Expand Down
66 changes: 57 additions & 9 deletions jxl/src/frame/render.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,14 +204,26 @@ impl Frame {
)?;
}

// STEP 2: ensure that groups that will be re-rendered are marked as such.
// VarDCT data to be rendered.
// STEP 2: ensure that VarDCT groups that will be re-rendered are marked as such.
for (g, _) in groups.iter() {
self.groups_to_flush.insert(*g);
pipeline!(self, p, p.mark_group_to_rerender(*g));
}
// Modular data to be re-rendered.
{

let group_count = self.header.size_groups();
let interleave_modular = !do_flush
&& !self.was_flushed_once
&& self.header.encoding == Encoding::Modular
&& group_count.0 >= 8
&& group_count.1 >= 8
&& self
.lf_global
.as_ref()
.unwrap()
.modular_global
.outputs_are_gridded();

if !interleave_modular {
let modular_global = &mut self.lf_global.as_mut().unwrap().modular_global;
for (group, passes) in groups.iter() {
for (pass, _) in passes.iter() {
Expand All @@ -226,18 +238,52 @@ impl Frame {
modular_global.process_output(&self.header, true, &mut pass_to_pipeline)?;
}

// STEP 3: decode the groups, eagerly rendering VarDCT channels and noise.
for (group, mut passes) in groups {
if interleave_modular {
let modular_global = &mut self.lf_global.as_mut().unwrap().modular_global;
for (pass, _) in passes.iter() {
modular_global.mark_group_to_be_read(2 + *pass, group);
}
let mut pass_to_pipeline = |_, group, _, _| {
self.groups_to_flush.insert(group);
pipeline!(self, p, p.mark_group_to_rerender(group));
Ok(())
};
modular_global.process_output(&self.header, true, &mut pass_to_pipeline)?;
}

if self.decode_hf_group(group, &mut passes, &mut buffer_splitter, do_flush)? {
self.changed_since_last_flush
.insert((group, RenderUnit::VarDCT));
}

if interleave_modular {
pipeline!(self, p, p.set_allow_pending_buffer_replacement(true));
let modular_global = &mut self.lf_global.as_mut().unwrap().modular_global;
let mut pass_to_pipeline = |chan, group, complete, image: Option<Image<i32>>| {
self.changed_since_last_flush
.insert((group, RenderUnit::Modular(chan)));
pipeline!(
self,
p,
p.set_buffer_for_group(
chan,
group,
complete,
image.unwrap(),
&mut buffer_splitter
)?
);
Ok(())
};
let result =
modular_global.process_output(&self.header, false, &mut pass_to_pipeline);
pipeline!(self, p, p.set_allow_pending_buffer_replacement(false));
result?;
}
}

// STEP 4: process all modular transforms that can now be processed,
// flushing buffers that will not be used again, if either we are forcing a render now
// or we are done with the file.
if self.incomplete_groups == 0 || do_flush {
if !interleave_modular && (self.incomplete_groups == 0 || do_flush) {
let modular_global = &mut self.lf_global.as_mut().unwrap().modular_global;
let mut pass_to_pipeline = |chan, group, complete, image: Option<Image<i32>>| {
self.changed_since_last_flush
Expand All @@ -256,7 +302,9 @@ impl Frame {
Ok(())
};
modular_global.process_output(&self.header, false, &mut pass_to_pipeline)?;
}

if self.incomplete_groups == 0 || do_flush {
// STEP 5: re-render VarDCT/noise data in rendered groups for which it was
// not rendered, or re-send to pipeline modular channels that were not
// updated in those groups.
Expand Down
12 changes: 9 additions & 3 deletions jxl/src/render/low_memory_pipeline/group_scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,13 @@ pub(super) struct InputBuffer {
}

impl InputBuffer {
pub(super) fn set_buffer(&mut self, chan: usize, buf: OwnedRawImage) {
assert!(self.data[chan].is_none());
pub(super) fn set_buffer(&mut self, chan: usize, buf: OwnedRawImage, replace: bool) {
if self.data[chan].is_none() {
self.ready_channels += 1;
} else {
assert!(replace);
}
self.data[chan] = Some(buf);
self.ready_channels += 1;
}

pub(super) fn new(num_channels: usize) -> Self {
Expand Down Expand Up @@ -113,6 +116,9 @@ impl LowMemoryRenderPipeline {
}

/// Stashes `image` in the scratch list for `channel` and buffer `kind`.
///
/// Scratch lists are indexed as `3 * channel + kind`. While
/// `allow_pending_buffer_replacement` is set, images of kind `0` are not stored and are
/// simply dropped.
fn store_scratch_buffer(&mut self, channel: usize, kind: usize, image: OwnedRawImage) {
    let discard = self.allow_pending_buffer_replacement && kind == 0;
    if !discard {
        let slot = 3 * channel + kind;
        self.scratch_channel_buffers[slot].push(image);
    }
}

Expand Down
10 changes: 9 additions & 1 deletion jxl/src/render/low_memory_pipeline/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ pub struct LowMemoryRenderPipeline {
// could be reused to store group data for that channel.
// Indexed by [3*channel] = center, [3*channel+1] = topbottom, [3*channel+2] = leftright.
scratch_channel_buffers: Vec<Vec<OwnedRawImage>>,
allow_pending_buffer_replacement: bool,
}

impl RenderPipeline for LowMemoryRenderPipeline {
Expand Down Expand Up @@ -282,9 +283,14 @@ impl RenderPipeline for LowMemoryRenderPipeline {
opaque_alpha_buffers,
sorted_buffer_indices,
scratch_channel_buffers: (0..nc * 3).map(|_| vec![]).collect(),
allow_pending_buffer_replacement: false,
})
}

/// Stores `allow` in this pipeline's `allow_pending_buffer_replacement` flag.
fn set_allow_pending_buffer_replacement(&mut self, allow: bool) {
self.allow_pending_buffer_replacement = allow;
}

#[instrument(skip_all, err)]
fn get_buffer<T: ImageDataType>(&mut self, channel: usize) -> Result<Image<T>> {
if let Some(b) = self.maybe_get_scratch_buffer(channel, 0) {
Expand All @@ -309,7 +315,9 @@ impl RenderPipeline for LowMemoryRenderPipeline {
channel,
T::DATA_TYPE_ID,
);
self.input_buffers[group_id].set_buffer(channel, buf.into_raw());
let replace = self.allow_pending_buffer_replacement
&& !self.shared.group_chan_complete[group_id][channel];
self.input_buffers[group_id].set_buffer(channel, buf.into_raw(), replace);
self.shared.group_chan_complete[group_id][channel] = complete;

self.render_with_new_group(group_id, buffer_splitter)?;
Expand Down
2 changes: 2 additions & 0 deletions jxl/src/render/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ pub(crate) trait RenderPipeline: Sized {
/// pass, a new buffer, or a re-used buffer from e.g. previously decoded frames.
fn get_buffer<T: ImageDataType>(&mut self, channel: usize) -> Result<Image<T>>;

/// Default implementation is a no-op: `_allow` is ignored. Pipelines that track this
/// setting provide their own implementation.
fn set_allow_pending_buffer_replacement(&mut self, _allow: bool) {}

/// Gives back the buffer for a channel and group to the render pipeline, marking whether
/// this will be the last time that this function is called for this group.
fn set_buffer_for_group<T: ImageDataType>(
Expand Down
1 change: 1 addition & 0 deletions jxl_cli/benches/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ fn decode_benches(c: &mut Criterion) {
false,
None,
false,
false,
)
.unwrap();
})
Expand Down
5 changes: 3 additions & 2 deletions jxl_cli/src/dec/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ pub fn decode_frames<In: JxlBitstreamInputExt>(
linear_output: bool,
render_interval: Option<usize>,
allow_partial_files: bool,
store_partial_renders: bool,
) -> Result<(DecodeOutput, Duration)> {
let start = Instant::now();

Expand Down Expand Up @@ -282,7 +283,7 @@ pub fn decode_frames<In: JxlBitstreamInputExt>(
// render and retry.
if render_interval.is_some() && input.available_bytes()? > 0 {
has_rendered_data |= fallback.flush_pixels(&mut output_bufs)?;
if has_rendered_data {
if has_rendered_data && store_partial_renders {
partial_renders.push(
outputs
.iter()
Expand Down Expand Up @@ -332,7 +333,7 @@ pub fn decode_frames<In: JxlBitstreamInputExt>(
// render and retry.
if render_interval.is_some() && input.available_bytes()? > 0 {
has_rendered_data |= fallback.flush_pixels(&mut output_bufs)?;
if has_rendered_data {
if has_rendered_data && store_partial_renders {
partial_renders.push(
outputs
.iter()
Expand Down
2 changes: 2 additions & 0 deletions jxl_cli/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ mod tests {
false,
None,
false,
false,
)
.unwrap()
.0
Expand Down Expand Up @@ -188,6 +189,7 @@ mod tests {
false,
None,
false,
false,
)
.unwrap();
}
Expand Down
2 changes: 2 additions & 0 deletions jxl_cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ fn main() -> Result<()> {
let linear_output = matches!(output_format, Some(OutputFormat::Exr));
#[cfg(not(feature = "exr"))]
let linear_output = false;
let store_partial_renders = output_format.is_some() && opt.render_interval.is_some();
let (mut output, duration) = dec::decode_frames(
$input,
options(skip_preview),
Expand All @@ -176,6 +177,7 @@ fn main() -> Result<()> {
linear_output,
opt.render_interval,
opt.allow_partial_files,
store_partial_renders,
)?;
if opt.preview {
output.frames.truncate(1);
Expand Down
Loading