diff --git a/jxl/src/frame/modular/mod.rs b/jxl/src/frame/modular/mod.rs index 8167de94c..6058d0829 100644 --- a/jxl/src/frame/modular/mod.rs +++ b/jxl/src/frame/modular/mod.rs @@ -961,6 +961,12 @@ impl FullModularImage { Ok(()) } + pub fn outputs_are_gridded(&self) -> bool { + self.buffer_info + .iter() + .all(|b| b.info.output_channel_idx.is_none() || b.grid_kind != ModularGridKind::None) + } + pub fn channel_range(&self) -> Range { if self.modular_color_channels != 0 { 0..self.buffers_for_channels.len() diff --git a/jxl/src/frame/modular/transforms/apply.rs b/jxl/src/frame/modular/transforms/apply.rs index 33fe5bd2a..101fc92df 100644 --- a/jxl/src/frame/modular/transforms/apply.rs +++ b/jxl/src/frame/modular/transforms/apply.rs @@ -382,6 +382,20 @@ impl TransformStepChunk { } buffers[buf_in[0]].buffer_grid[in_grid].mark_used(is_final); buffers[buf_in[1]].buffer_grid[res_grid].mark_used(is_final); + // Release the weak neighbor grids read above (next average and previous + // decoded), which are counted as uses in the transform graph. + let (gx, gy) = self.grid_pos; + if gx + 1 < buffers[*buf_out].grid_shape.0 { + let next_avg_grid = + buffers[buf_in[0]].get_grid_idx(out_grid_kind, (gx + 1, gy)); + if next_avg_grid != in_grid { + buffers[buf_in[0]].buffer_grid[next_avg_grid].mark_used(is_final); + } + } + if gx > 0 { + let prev_out_grid = buffers[*buf_out].get_grid_idx(out_grid_kind, (gx - 1, gy)); + buffers[*buf_out].buffer_grid[prev_out_grid].mark_used(is_final); + } } TransformStep::VSqueeze { buf_in, @@ -491,6 +505,20 @@ impl TransformStepChunk { } buffers[buf_in[0]].buffer_grid[in_grid].mark_used(is_final); buffers[buf_in[1]].buffer_grid[res_grid].mark_used(is_final); + // Release the weak neighbor grids read above (next average and previous + // decoded), which are counted as uses in the transform graph. + let (gx, gy) = self.grid_pos; + if gy + 1 < buffers[*buf_out].grid_shape.1 { + let next_avg_grid = + buffers[buf_in[0]].get_grid_idx(out_grid_kind, (gx, gy + 1)); + if next_avg_grid != in_grid { + buffers[buf_in[0]].buffer_grid[next_avg_grid].mark_used(is_final); + } + } + if gy > 0 { + let prev_out_grid = buffers[*buf_out].get_grid_idx(out_grid_kind, (gx, gy - 1)); + buffers[*buf_out].buffer_grid[prev_out_grid].mark_used(is_final); + } } }; diff --git a/jxl/src/frame/render.rs b/jxl/src/frame/render.rs index b5782e92b..9b7f1672f 100644 --- a/jxl/src/frame/render.rs +++ b/jxl/src/frame/render.rs @@ -204,14 +204,26 @@ impl Frame { )?; } - // STEP 2: ensure that groups that will be re-rendered are marked as such. - // VarDCT data to be rendered. + // STEP 2: ensure that VarDCT groups that will be re-rendered are marked as such. for (g, _) in groups.iter() { self.groups_to_flush.insert(*g); pipeline!(self, p, p.mark_group_to_rerender(*g)); } - // Modular data to be re-rendered. - { + + let group_count = self.header.size_groups(); + let interleave_modular = !do_flush + && !self.was_flushed_once + && self.header.encoding == Encoding::Modular + && group_count.0 >= 8 + && group_count.1 >= 8 + && self + .lf_global + .as_ref() + .unwrap() + .modular_global + .outputs_are_gridded(); + + if !interleave_modular { let modular_global = &mut self.lf_global.as_mut().unwrap().modular_global; for (group, passes) in groups.iter() { for (pass, _) in passes.iter() { @@ -226,18 +238,52 @@ impl Frame { modular_global.process_output(&self.header, true, &mut pass_to_pipeline)?; } - // STEP 3: decode the groups, eagerly rendering VarDCT channels and noise. for (group, mut passes) in groups { + if interleave_modular { + let modular_global = &mut self.lf_global.as_mut().unwrap().modular_global; + for (pass, _) in passes.iter() { + modular_global.mark_group_to_be_read(2 + *pass, group); + } + let mut pass_to_pipeline = |_, group, _, _| { + self.groups_to_flush.insert(group); + pipeline!(self, p, p.mark_group_to_rerender(group)); + Ok(()) + }; + modular_global.process_output(&self.header, true, &mut pass_to_pipeline)?; + } + if self.decode_hf_group(group, &mut passes, &mut buffer_splitter, do_flush)? { self.changed_since_last_flush .insert((group, RenderUnit::VarDCT)); } + + if interleave_modular { + pipeline!(self, p, p.set_allow_pending_buffer_replacement(true)); + let modular_global = &mut self.lf_global.as_mut().unwrap().modular_global; + let mut pass_to_pipeline = |chan, group, complete, image: Option>| { + self.changed_since_last_flush + .insert((group, RenderUnit::Modular(chan))); + pipeline!( + self, + p, + p.set_buffer_for_group( + chan, + group, + complete, + image.unwrap(), + &mut buffer_splitter + )? + ); + Ok(()) + }; + let result = + modular_global.process_output(&self.header, false, &mut pass_to_pipeline); + pipeline!(self, p, p.set_allow_pending_buffer_replacement(false)); + result?; + } } - // STEP 4: process all modular transforms that can now be processed, - // flushing buffers that will not be used again, if either we are forcing a render now - // or we are done with the file. - if self.incomplete_groups == 0 || do_flush { + if !interleave_modular && (self.incomplete_groups == 0 || do_flush) { let modular_global = &mut self.lf_global.as_mut().unwrap().modular_global; let mut pass_to_pipeline = |chan, group, complete, image: Option>| { self.changed_since_last_flush @@ -256,7 +302,9 @@ impl Frame { Ok(()) }; modular_global.process_output(&self.header, false, &mut pass_to_pipeline)?; + } + if self.incomplete_groups == 0 || do_flush { // STEP 5: re-render VarDCT/noise data in rendered groups for which it was // not rendered, or re-send to pipeline modular channels that were not // updated in those groups. diff --git a/jxl/src/render/low_memory_pipeline/group_scheduler.rs b/jxl/src/render/low_memory_pipeline/group_scheduler.rs index abc810ef8..9575292b2 100644 --- a/jxl/src/render/low_memory_pipeline/group_scheduler.rs +++ b/jxl/src/render/low_memory_pipeline/group_scheduler.rs @@ -26,10 +26,13 @@ pub(super) struct InputBuffer { } impl InputBuffer { - pub(super) fn set_buffer(&mut self, chan: usize, buf: OwnedRawImage) { - assert!(self.data[chan].is_none()); + pub(super) fn set_buffer(&mut self, chan: usize, buf: OwnedRawImage, replace: bool) { + if self.data[chan].is_none() { + self.ready_channels += 1; + } else { + assert!(replace); + } self.data[chan] = Some(buf); - self.ready_channels += 1; } pub(super) fn new(num_channels: usize) -> Self { @@ -113,6 +116,9 @@ impl LowMemoryRenderPipeline { } fn store_scratch_buffer(&mut self, channel: usize, kind: usize, image: OwnedRawImage) { + if kind == 0 && self.allow_pending_buffer_replacement { + return; + } self.scratch_channel_buffers[channel * 3 + kind].push(image) } diff --git a/jxl/src/render/low_memory_pipeline/mod.rs b/jxl/src/render/low_memory_pipeline/mod.rs index 761175a12..808a676a2 100644 --- a/jxl/src/render/low_memory_pipeline/mod.rs +++ b/jxl/src/render/low_memory_pipeline/mod.rs @@ -61,6 +61,7 @@ pub struct LowMemoryRenderPipeline { // could be reused to store group data for that channel. // Indexed by [3*channel] = center, [3*channel+1] = topbottom, [3*channel+2] = leftright. scratch_channel_buffers: Vec>, + allow_pending_buffer_replacement: bool, } impl RenderPipeline for LowMemoryRenderPipeline { @@ -282,9 +283,14 @@ impl RenderPipeline for LowMemoryRenderPipeline { opaque_alpha_buffers, sorted_buffer_indices, scratch_channel_buffers: (0..nc * 3).map(|_| vec![]).collect(), + allow_pending_buffer_replacement: false, }) } + fn set_allow_pending_buffer_replacement(&mut self, allow: bool) { + self.allow_pending_buffer_replacement = allow; + } + #[instrument(skip_all, err)] fn get_buffer(&mut self, channel: usize) -> Result> { if let Some(b) = self.maybe_get_scratch_buffer(channel, 0) { @@ -309,7 +315,9 @@ impl RenderPipeline for LowMemoryRenderPipeline { channel, T::DATA_TYPE_ID, ); - self.input_buffers[group_id].set_buffer(channel, buf.into_raw()); + let replace = self.allow_pending_buffer_replacement + && !self.shared.group_chan_complete[group_id][channel]; + self.input_buffers[group_id].set_buffer(channel, buf.into_raw(), replace); self.shared.group_chan_complete[group_id][channel] = complete; self.render_with_new_group(group_id, buffer_splitter)?; diff --git a/jxl/src/render/mod.rs b/jxl/src/render/mod.rs index 5748513ba..63657285c 100644 --- a/jxl/src/render/mod.rs +++ b/jxl/src/render/mod.rs @@ -124,6 +124,8 @@ pub(crate) trait RenderPipeline: Sized { /// pass, a new buffer, or a re-used buffer from i.e. previously decoded frames. fn get_buffer(&mut self, channel: usize) -> Result>; + fn set_allow_pending_buffer_replacement(&mut self, _allow: bool) {} + /// Gives back the buffer for a channel and group to the render pipeline, marking whether /// this will be the last time that this function is called for this group. fn set_buffer_for_group( diff --git a/jxl_cli/benches/decode.rs b/jxl_cli/benches/decode.rs index a066c146e..f15a4bccd 100644 --- a/jxl_cli/benches/decode.rs +++ b/jxl_cli/benches/decode.rs @@ -68,6 +68,7 @@ fn decode_benches(c: &mut Criterion) { false, None, false, + false, ) .unwrap(); }) diff --git a/jxl_cli/src/dec/mod.rs b/jxl_cli/src/dec/mod.rs index d8e5e737f..b1d6b4f09 100644 --- a/jxl_cli/src/dec/mod.rs +++ b/jxl_cli/src/dec/mod.rs @@ -139,6 +139,7 @@ pub fn decode_frames( linear_output: bool, render_interval: Option, allow_partial_files: bool, + store_partial_renders: bool, ) -> Result<(DecodeOutput, Duration)> { let start = Instant::now(); @@ -282,7 +283,7 @@ pub fn decode_frames( // render and retry. if render_interval.is_some() && input.available_bytes()? > 0 { has_rendered_data |= fallback.flush_pixels(&mut output_bufs)?; - if has_rendered_data { + if has_rendered_data && store_partial_renders { partial_renders.push( outputs .iter() @@ -332,7 +333,7 @@ pub fn decode_frames( // render and retry. if render_interval.is_some() && input.available_bytes()? > 0 { has_rendered_data |= fallback.flush_pixels(&mut output_bufs)?; - if has_rendered_data { + if has_rendered_data && store_partial_renders { partial_renders.push( outputs .iter() diff --git a/jxl_cli/src/lib.rs b/jxl_cli/src/lib.rs index 4f4b627ba..bf71087e0 100644 --- a/jxl_cli/src/lib.rs +++ b/jxl_cli/src/lib.rs @@ -77,6 +77,7 @@ mod tests { false, None, false, + false, ) .unwrap() .0 @@ -188,6 +189,7 @@ mod tests { false, None, false, + false, ) .unwrap(); } diff --git a/jxl_cli/src/main.rs b/jxl_cli/src/main.rs index d8abf491c..64faf4e94 100644 --- a/jxl_cli/src/main.rs +++ b/jxl_cli/src/main.rs @@ -164,6 +164,7 @@ fn main() -> Result<()> { let linear_output = matches!(output_format, Some(OutputFormat::Exr)); #[cfg(not(feature = "exr"))] let linear_output = false; + let store_partial_renders = output_format.is_some() && opt.render_interval.is_some(); let (mut output, duration) = dec::decode_frames( $input, options(skip_preview), @@ -176,6 +177,7 @@ fn main() -> Result<()> { linear_output, opt.render_interval, opt.allow_partial_files, + store_partial_renders, )?; if opt.preview { output.frames.truncate(1);