Skip to content

Commit

Permalink
Add --lz77-mode argument (#272)
Browse files Browse the repository at this point in the history
* Add `--lz77-mode` argument

* jxl-oxide-cli: Reword argument description
  • Loading branch information
tirr-c committed Mar 4, 2024
1 parent 06c7c87 commit 18af5ce
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 33 deletions.
8 changes: 8 additions & 0 deletions crates/jxl-bitstream/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,11 @@ impl std::ops::DerefMut for Name {
&mut self.0
}
}

#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default)]
pub enum Lz77Mode {
#[default]
Auto,
Standard,
Legacy,
}
17 changes: 16 additions & 1 deletion crates/jxl-bitstream/src/memory.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{Bundle, Error, Result};
use crate::{Bundle, Error, Lz77Mode, Result};

/// Bitstream reader with borrowed in-memory buffer.
///
Expand All @@ -11,6 +11,10 @@ pub struct Bitstream<'buf> {
buf: u64,
num_read_bits: usize,
remaining_buf_bits: usize,
/// LZ77 dist_multiplier mode.
///
/// It shouldn't be here, this is a hack to avoid API breakage.
lz77_mode: Lz77Mode,
}

impl std::fmt::Debug for Bitstream<'_> {
Expand Down Expand Up @@ -40,6 +44,7 @@ impl<'buf> Bitstream<'buf> {
buf: 0,
num_read_bits: 0,
remaining_buf_bits: 0,
lz77_mode: Lz77Mode::Auto,
}
}

Expand All @@ -48,6 +53,16 @@ impl<'buf> Bitstream<'buf> {
pub fn num_read_bits(&self) -> usize {
self.num_read_bits
}

#[inline]
pub fn lz77_mode(&self) -> Lz77Mode {
self.lz77_mode
}

#[inline]
pub fn set_lz77_mode(&mut self, lz77_mode: Lz77Mode) {
self.lz77_mode = lz77_mode;
}
}

impl Bitstream<'_> {
Expand Down
18 changes: 13 additions & 5 deletions crates/jxl-frame/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::collections::BTreeMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use jxl_bitstream::{read_bits, Bitstream, Bundle};
use jxl_bitstream::{read_bits, Bitstream, Bundle, Lz77Mode};
use jxl_grid::AllocTracker;
use jxl_image::ImageHeader;

Expand Down Expand Up @@ -51,6 +51,7 @@ pub struct Frame {
all_group_offsets: AllGroupOffsets,
reading_data_index: usize,
pass_shifts: BTreeMap<u32, (i32, i32)>,
lz77_mode: Lz77Mode,
}

#[derive(Debug, Default)]
Expand Down Expand Up @@ -198,6 +199,7 @@ impl Bundle<FrameContext<'_>> for Frame {
all_group_offsets: AllGroupOffsets::default(),
reading_data_index: 0,
pass_shifts,
lz77_mode: bitstream.lz77_mode(),
})
}
}
Expand Down Expand Up @@ -270,6 +272,7 @@ impl Frame {
let group = self.data.first()?;
let loaded = self.reading_data_index != 0;
let mut bitstream = Bitstream::new(&group.bytes);
bitstream.set_lz77_mode(self.lz77_mode);
let lf_global = LfGlobal::parse(
&mut bitstream,
LfGlobalParams::new(
Expand Down Expand Up @@ -299,6 +302,7 @@ impl Frame {
let allow_partial = group.bytes.len() < group.toc_group.size as usize;

let mut bitstream = Bitstream::new(&group.bytes);
bitstream.set_lz77_mode(self.lz77_mode);
LfGlobal::parse(
&mut bitstream,
LfGlobalParams::new(
Expand Down Expand Up @@ -330,6 +334,7 @@ impl Frame {
let group = self.data.first()?;
let loaded = self.reading_data_index != 0;
let mut bitstream = Bitstream::new(&group.bytes);
bitstream.set_lz77_mode(self.lz77_mode);
let offset = self.all_group_offsets.lf_group.load(Ordering::Relaxed);
if offset == 0 {
let lf_global = self.try_parse_lf_global::<S>().unwrap();
Expand Down Expand Up @@ -376,6 +381,7 @@ impl Frame {
let allow_partial = group.bytes.len() < group.toc_group.size as usize;

let mut bitstream = Bitstream::new(&group.bytes);
bitstream.set_lz77_mode(self.lz77_mode);
let result = LfGroup::parse(
&mut bitstream,
LfGroupParams {
Expand Down Expand Up @@ -410,6 +416,7 @@ impl Frame {
let group = self.data.first()?;
let loaded = self.reading_data_index != 0;
let mut bitstream = Bitstream::new(&group.bytes);
bitstream.set_lz77_mode(self.lz77_mode);
let offset = self.all_group_offsets.hf_global.load(Ordering::Relaxed);
let lf_global = if cached_lf_global.is_none() && (offset == 0 || !is_modular) {
match self.try_parse_lf_global()? {
Expand Down Expand Up @@ -499,6 +506,7 @@ impl Frame {
}

let mut bitstream = Bitstream::new(&group.bytes);
bitstream.set_lz77_mode(self.lz77_mode);
let lf_global = if cached_lf_global.is_none() {
match self.try_parse_lf_global()? {
Ok(lf_global) => Some(lf_global),
Expand Down Expand Up @@ -536,6 +544,7 @@ impl Frame {
let group = self.data.first()?;
let loaded = self.reading_data_index != 0;
let mut bitstream = Bitstream::new(&group.bytes);
bitstream.set_lz77_mode(self.lz77_mode);
let mut offset = self.all_group_offsets.pass_group.load(Ordering::Relaxed);
if offset == 0 {
let hf_global = self.try_parse_hf_global::<i32>(None)?;
Expand All @@ -560,10 +569,9 @@ impl Frame {
let group = self.data.get(idx)?;
let partial = group.bytes.len() < group.toc_group.size as usize;

Ok(PassGroupBitstream {
bitstream: Bitstream::new(&group.bytes),
partial,
})
let mut bitstream = Bitstream::new(&group.bytes);
bitstream.set_lz77_mode(self.lz77_mode);
Ok(PassGroupBitstream { bitstream, partial })
})
}
}
Expand Down
72 changes: 47 additions & 25 deletions crates/jxl-modular/src/image.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use jxl_bitstream::{unpack_signed, Bitstream};
use jxl_bitstream::{unpack_signed, Bitstream, Lz77Mode};
use jxl_coding::{DecoderRleMode, DecoderWithLz77, RleToken};
use jxl_grid::{AllocTracker, CutGrid, SimpleGrid};

Expand Down Expand Up @@ -529,39 +529,61 @@ impl<'dest, S: Sample> TransformedModularSubimage<'dest, S> {
.max()
.unwrap_or(0);

if dist_multiplier == dist_multiplier_legacy {
let lz77_decoder = decoder.as_with_lz77().unwrap();
self.decode_image_lz77(bitstream, stream_index, lz77_decoder, dist_multiplier)?;
decoder.finalize()?;
return Ok(());
}
let lz77_mode = if dist_multiplier == dist_multiplier_legacy {
Lz77Mode::Standard
} else {
bitstream.lz77_mode()
};

let tmp_bitstream = bitstream.clone();
let tmp_decoder = decoder.clone();

// Try standard method first
let lz77_decoder = decoder.as_with_lz77().unwrap();
let result =
self.decode_image_lz77(bitstream, stream_index, lz77_decoder, dist_multiplier);
if result.is_ok() && decoder.finalize().is_ok() {
return Ok(());
}
match lz77_mode {
Lz77Mode::Standard => {
self.decode_image_lz77(bitstream, stream_index, lz77_decoder, dist_multiplier)?;
decoder.finalize()?;
}
Lz77Mode::Legacy => {
self.decode_image_lz77(
bitstream,
stream_index,
lz77_decoder,
dist_multiplier_legacy,
)?;
decoder.finalize()?;
}
Lz77Mode::Auto => {
// Try standard method first
let result = self.decode_image_lz77(
bitstream,
stream_index,
lz77_decoder,
dist_multiplier,
);
if result.is_ok() && decoder.finalize().is_ok() {
return Ok(());
}

// Decode error with standard method
tracing::warn!("Invalid LZ77 stream, trying legacy method");
*bitstream = tmp_bitstream;
let mut decoder = tmp_decoder;
// Decode error with standard method
tracing::warn!("Invalid LZ77 stream, trying legacy method");
*bitstream = tmp_bitstream;
let mut decoder = tmp_decoder;

let lz77_decoder = decoder.as_with_lz77().unwrap();
self.decode_image_lz77(
bitstream,
stream_index,
lz77_decoder,
dist_multiplier_legacy,
)?;
decoder.finalize()?;
}
}

let lz77_decoder = decoder.as_with_lz77().unwrap();
self.decode_image_lz77(
bitstream,
stream_index,
lz77_decoder,
dist_multiplier_legacy,
)?;
decoder.finalize()?;
return Ok(());
}

let mut no_lz77_decoder = decoder.as_no_lz77().unwrap();

let mut predictor = PredictorState::new();
Expand Down
25 changes: 24 additions & 1 deletion crates/jxl-oxide-cli/src/commands/decode.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::path::PathBuf;

use clap::Parser;
use jxl_oxide::{CropInfo, EnumColourEncoding};
use jxl_oxide::{CropInfo, EnumColourEncoding, Lz77Mode};

#[derive(Debug, Parser)]
#[non_exhaustive]
Expand Down Expand Up @@ -70,6 +70,9 @@ pub struct DecodeArgs {
#[cfg(feature = "rayon")]
#[arg(short = 'j', long)]
pub num_threads: Option<usize>,
/// (unstable) LZ77 mode to use.
#[arg(value_enum, long, default_value_t = Lz77ModeArg::Auto)]
pub lz77_mode: Lz77ModeArg,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
Expand Down Expand Up @@ -126,3 +129,23 @@ fn parse_crop_info(s: &str) -> Result<CropInfo, std::num::ParseIntError> {
top: y,
})
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
pub enum Lz77ModeArg {
/// Try both Standard and Legacy method.
Auto,
/// Use spec-conforming Standard method.
Std,
/// Use Legacy method that emulates old libjxl.
Legacy,
}

impl From<Lz77ModeArg> for Lz77Mode {
fn from(value: Lz77ModeArg) -> Self {
match value {
Lz77ModeArg::Auto => Lz77Mode::Auto,
Lz77ModeArg::Std => Lz77Mode::Standard,
Lz77ModeArg::Legacy => Lz77Mode::Legacy,
}
}
}
4 changes: 3 additions & 1 deletion crates/jxl-oxide-cli/src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ pub fn handle_decode(args: DecodeArgs) -> Result<()> {
#[cfg(not(feature = "rayon"))]
let pool = JxlThreadPool::none();

let mut image_builder = JxlImage::builder().pool(pool.clone());
let mut image_builder = JxlImage::builder()
.pool(pool.clone())
.lz77_mode(args.lz77_mode.into());
if args.approx_memory_limit != 0 {
let tracker = AllocTracker::with_limit(args.approx_memory_limit);
image_builder = image_builder.alloc_tracker(tracker);
Expand Down
14 changes: 14 additions & 0 deletions crates/jxl-oxide/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ use jxl_bitstream::{Bitstream, Bundle};
use jxl_frame::FrameContext;
use jxl_render::{IndexedFrame, RenderContext};

pub use jxl_bitstream::Lz77Mode;
pub use jxl_color::header as color;
pub use jxl_color::{
ColorEncodingWithProfile, ColorManagementSystem, EnumColourEncoding, NullCms, RenderingIntent,
Expand Down Expand Up @@ -186,6 +187,7 @@ fn default_pool() -> JxlThreadPool {
pub struct JxlImageBuilder {
pool: Option<JxlThreadPool>,
tracker: Option<AllocTracker>,
lz77_mode: Lz77Mode,
}

impl JxlImageBuilder {
Expand All @@ -201,13 +203,20 @@ impl JxlImageBuilder {
self
}

/// Sets LZ77 `dist_multiplier` mode to use.
pub fn lz77_mode(mut self, lz77_mode: Lz77Mode) -> Self {
self.lz77_mode = lz77_mode;
self
}

/// Consumes the builder, and creates an empty, uninitialized JPEG XL image decoder.
pub fn build_uninit(self) -> UninitializedJxlImage {
UninitializedJxlImage {
pool: self.pool.unwrap_or_else(default_pool),
tracker: self.tracker,
reader: ContainerDetectingReader::new(),
buffer: Vec::new(),
lz77_mode: self.lz77_mode,
}
}

Expand Down Expand Up @@ -285,6 +294,7 @@ pub struct UninitializedJxlImage {
tracker: Option<AllocTracker>,
reader: ContainerDetectingReader,
buffer: Vec<u8>,
lz77_mode: Lz77Mode,
}

impl UninitializedJxlImage {
Expand All @@ -310,6 +320,7 @@ impl UninitializedJxlImage {
/// was given.
pub fn try_init(mut self) -> Result<InitializeResult> {
let mut bitstream = Bitstream::new(&self.buffer);
bitstream.set_lz77_mode(self.lz77_mode);
let image_header = match ImageHeader::parse(&mut bitstream, ()) {
Ok(x) => x,
Err(e) if e.unexpected_eof() => {
Expand Down Expand Up @@ -395,6 +406,7 @@ impl UninitializedJxlImage {
buffer: Vec::new(),
buffer_offset: bytes_read,
frame_offsets: Vec::new(),
lz77_mode: self.lz77_mode,
};
image.feed_bytes_inner(&self.buffer)?;

Expand Down Expand Up @@ -422,6 +434,7 @@ pub struct JxlImage {
buffer: Vec<u8>,
buffer_offset: usize,
frame_offsets: Vec<usize>,
lz77_mode: Lz77Mode,
}

impl JxlImage {
Expand Down Expand Up @@ -473,6 +486,7 @@ impl JxlImage {
let mut buf = &*self.buffer;
while !buf.is_empty() {
let mut bitstream = Bitstream::new(buf);
bitstream.set_lz77_mode(self.lz77_mode);
let frame = match self.ctx.load_frame_header(&mut bitstream) {
Ok(x) => x,
Err(e) if e.unexpected_eof() => {
Expand Down

0 comments on commit 18af5ce

Please sign in to comment.