From ba815facdd2d4278c158b24d060bd9c4123eda81 Mon Sep 17 00:00:00 2001 From: kimikage Date: Wed, 29 May 2024 02:18:16 +0900 Subject: [PATCH] Improve type handling to work on 32-bit systems --- src/Inflate.jl | 346 +++++++++++++++++++++++++++++-------------------- 1 file changed, 203 insertions(+), 143 deletions(-) diff --git a/src/Inflate.jl b/src/Inflate.jl index c6e0008..aa40be1 100644 --- a/src/Inflate.jl +++ b/src/Inflate.jl @@ -33,11 +33,11 @@ module Inflate export inflate, inflate_zlib, inflate_gzip, InflateStream, InflateZlibStream, InflateGzipStream -# Huffman codes are internally represented by Vector{Vector{Int}}, +# Huffman codes are internally represented by Vector{Vector{UInt16}}, # where code[k] are a vector of the values with code words of length # k. Codes are assigned in order from shorter to longer codes and in # the order listed. E.g. -# [[], [2, 7], [1, 3, 5], [6, 4]] +# Vector{UInt16}[[], [2, 7], [1, 3, 5], [6, 4]] # would be the code # 00 - 2 # 01 - 7 @@ -47,31 +47,23 @@ export inflate, inflate_zlib, inflate_gzip, # 1110 - 6 # 1111 - 4 -const fixed_literal_or_length_table = (Vector{Int})[Int[], - Int[], - Int[], - Int[], - Int[], - Int[], - Int[256:279;], - Int[0:143; 280:287], - Int[144:255;]] - -const fixed_distance_table = (Vector{Int})[Int[], - Int[], - Int[], - Int[], - Int[0:31;]] +const fixed_literal_or_length_table = Vector{UInt16}[ + [], [], [], [], [], [], + [0x100:0x117;], # 7-bit + [0x000:0x08f; 0x118:0x11f], # 8-bit + [0x090:0x0ff;]] # 9-bit + +const fixed_distance_table = Vector{UInt16}[[], [], [], [], [0x000:0x01f;]] abstract type AbstractInflateData end mutable struct InflateData <: AbstractInflateData bytes::Vector{UInt8} - current_byte::Int + current_byte::UInt8 bytepos::Int bitpos::Int - literal_or_length_code::Vector{Vector{Int}} - distance_code::Vector{Vector{Int}} + literal_or_length_code::Vector{Vector{UInt16}} + distance_code::Vector{Vector{UInt16}} update_input_crc::Bool crc::UInt32 end @@ -81,6 +73,9 @@ function InflateData(source::Vector{UInt8}) fixed_distance_table, false, init_crc()) end +""" + get_input_byte(data::InflateData) -> UInt8 +""" function get_input_byte(data::InflateData) byte = data.bytes[data.bytepos] data.bytepos += 1 @@ -92,34 +87,47 @@ end # This isn't called when reading Gzip header, so no need to # consider updating crc. +""" + get_input_bytes(data::InflateData, n::Int) -> SubArray{UInt8, 1} +""" function get_input_bytes(data::InflateData, n::Int) bytes = @view data.bytes[data.bytepos:(data.bytepos + n - 1)] data.bytepos += n return bytes end +""" + getbit(data::AbstractInflateData) -> Bool +""" function getbit(data::AbstractInflateData) if data.bitpos == 0 - data.current_byte = Int(get_input_byte(data)) - end - b = data.current_byte & 1 - data.bitpos += 1 - if data.bitpos == 8 - data.bitpos = 0 - else - data.current_byte >>= 1 + data.current_byte = get_input_byte(data) end + b = data.current_byte % Bool + data.current_byte >>= 0x1 + data.bitpos = (data.bitpos + 1) & 0x7 return b end +""" + getbits(data::AbstractInflateData, n::Int) -> UInt32 +""" function getbits(data::AbstractInflateData, n::Int) - b = 0 - for i = 0:(n-1) - b |= getbit(data) << i + b = UInt32(0) + for i = 0x00:UInt8(n-1) + b |= UInt32(getbit(data)) << i end return b end +""" + getbitsint(data::AbstractInflateData, n::Int) -> Int +""" +function getbitsint(data::AbstractInflateData, n::Int) + n < 32 || error("too long bit length `n`") + return getbits(data, n) % Int +end + function skip_bits_to_byte_boundary(data::AbstractInflateData) data.bitpos = 0 return @@ -127,57 +135,75 @@ end # It is the responsibility of the caller to make sure that bitpos is # at zero, e.g. by calling skip_bits_to_byte_boundary. +""" + get_aligned_byte(data::AbstractInflateData) -> UInt8 +""" function get_aligned_byte(data::AbstractInflateData) return get_input_byte(data) end +""" + get_value_from_code(data::AbstractInflateData, code) -> UInt16 +""" function get_value_from_code(data::AbstractInflateData, - code::Vector{Vector{Int}}) + code::Vector{Vector{UInt16}}) v = 0 - for i = 1:length(code) - v = (v << 1) | getbit(data) - if v < length(code[i]) - return code[i][1 + v] + for c in code + v = (v << 0x1) | getbit(data) + if v < length(c) + return c[1 + v] end - v -= length(code[i]) + v -= length(c) end error("incomplete code table") end +""" + get_literal_or_length(data::AbstractInflateData) -> UInt16 +""" function get_literal_or_length(data::AbstractInflateData) return get_value_from_code(data, data.literal_or_length_code) end -const base_length = [11, 13, 15, 17, 19, 23, 27, 31, 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227] -const extra_length_bits = [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5] +const base_length = Tuple(UInt8[11:2:17; 19:4:31; 35:8:59; 67:16:115; 131:32:227]) +""" + getlength(data::AbstractInflateData, v::Int) -> Int +""" function getlength(data::AbstractInflateData, v::Int) if v <= 264 return v - 254 elseif v <= 284 - return base_length[v - 264] + getbits(data, extra_length_bits[v - 264]) + extra_length_bits = (v - 264 + 3) >> 0x2 + return base_length[v - 264] + getbitsint(data, extra_length_bits) else return 258 end end +""" + getdist(data::AbstractInflateData) -> Int +""" function getdist(data::AbstractInflateData) - b = get_value_from_code(data, data.distance_code) + b = Int(get_value_from_code(data, data.distance_code)) if b <= 3 return b + 1 else - extra_bits = fld(b - 2, 2) - return 1 + ((2 + b % 2) << extra_bits) + getbits(data, extra_bits) + extra_bits = (b - 2) >> 0x1 + return 1 + ((2 + b & 0x1) << extra_bits) + getbitsint(data, extra_bits) end end +""" + transform_code_lengths_to_code(code_lengths::Vector{Int}) -> Vector{Vector{UInt16}} +""" function transform_code_lengths_to_code(code_lengths::Vector{Int}) - code = Vector{Int}[] + code = Vector{UInt16}[] for i = 1:length(code_lengths) n = code_lengths[i] if n > 0 while n > length(code) - push!(code, Int[]) + push!(code, UInt16[]) end push!(code[n], i - 1) end @@ -185,34 +211,34 @@ function transform_code_lengths_to_code(code_lengths::Vector{Int}) return code end -const order = [16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15] +const order = (16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15) function read_code_tables(data::AbstractInflateData) - hlit = getbits(data, 5) + 257 - hdist = getbits(data, 5) + 1 - hclen = getbits(data, 4) + 4 + hlit = getbitsint(data, 5) + 257 + hdist = getbitsint(data, 5) + 1 + hclen = getbitsint(data, 4) + 4 code_length_code_lengths = zeros(Int, 19) for i = 1:hclen - code_length_code_lengths[1 + order[i]] = getbits(data, 3) + code_length_code_lengths[1 + order[i]] = getbits(data, 3) % Int end code_length_code = transform_code_lengths_to_code(code_length_code_lengths) code_lengths = zeros(Int, hlit + hdist) i = 1 while i <= hlit + hdist - c = get_value_from_code(data, code_length_code) + c = get_value_from_code(data, code_length_code) % Int n = 1 l = 0 if c < 16 l = c elseif c == 16 - n = 3 + getbits(data, 2) + n = 3 + getbitsint(data, 2) l = code_lengths[i-1] elseif c == 17 - n = 3 + getbits(data, 3) + n = 3 + getbitsint(data, 3) else # A code of length 19 can only yield values between 0 and # 18 so we can only get here if c == 18. - n = 11 + getbits(data, 7) + n = 11 + getbitsint(data, 7) end code_lengths[i:(i+n-1)] .= l i += n @@ -225,21 +251,21 @@ function _inflate(data::InflateData) out = UInt8[] final_block = false while !final_block - final_block = getbits(data, 1) == 1 - compression_mode = getbits(data, 2) - if compression_mode == 0 + final_block = getbit(data) + compression_mode = getbits(data, 2) % UInt8 + if compression_mode === 0x0 skip_bits_to_byte_boundary(data) len = getbits(data, 16) nlen = getbits(data, 16) if len ⊻ nlen != 0xffff error("corrupted data") end - append!(out, get_input_bytes(data, len)) + append!(out, get_input_bytes(data, len % Int)) continue - elseif compression_mode == 1 + elseif compression_mode === 0x1 data.literal_or_length_code = fixed_literal_or_length_table data.distance_code = fixed_distance_table - elseif compression_mode == 2 + elseif compression_mode === 0x2 read_code_tables(data) else error("invalid block compression mode 3") @@ -247,12 +273,12 @@ function _inflate(data::InflateData) while true v = get_literal_or_length(data) - if v < 256 - push!(out, UInt8(v)) - elseif v == 256 + if v < 0x100 + push!(out, v % UInt8) + elseif v == 0x100 break else - length = getlength(data, v) + length = getlength(data, v % Int) distance = getdist(data) if length <= distance append!(out, @view out[(end - distance + 1):(end - distance + length)]) @@ -268,67 +294,90 @@ function _inflate(data::InflateData) return out end +""" + init_adler() -> UInt32 +""" function init_adler() - return (0, 1) + return 0x0000_0001 end -function update_adler(adler::Tuple{Int, Int}, x::UInt8) - s2, s1 = adler +""" + update_adler(adler::UInt32, x::UInt8) -> UInt32 +""" +function update_adler(adler::UInt32, x::UInt8) + s2, s1 = adler >> 0x10, adler & 0xffff s1 += x - if s1 >= 65521 - s1 -= 65521 + if s1 >= UInt32(65521) + s1 -= UInt32(65521) end s2 += s1 - if s2 >= 65521 - s2 -= 65521 + if s2 >= UInt32(65521) + s2 -= UInt32(65521) end - return (s2, s1) -end - -function finish_adler(adler) - s2, s1 = adler - return (UInt32(s2) << 16) | UInt32(s1) + return (s2 << 0x10) | s1 end +""" + compute_adler_checksum(x::Vector{UInt8}) -> UInt32 +""" @inline function compute_adler_checksum(x::Vector{UInt8}) adler = init_adler() for b in x adler = update_adler(adler, b) end - return finish_adler(adler) + return adler end -const crc_table = zeros(UInt32, 256) +""" + read_stored_adler(data::AbstractInflateData) -> UInt32 +""" +function read_stored_adler(data::AbstractInflateData) + skip_bits_to_byte_boundary(data) + stored_adler = UInt32(0) + for i = 1:4 + stored_adler = stored_adler << 0x8 | get_aligned_byte(data) + end + return stored_adler +end -function make_crc_table() - for n = 1:256 - c = UInt32(n - 1) - for k = 1:8 - if (c & 0x00000001) != 0 - c = 0xedb88320 ⊻ (c >> 1) - else - c >>= 1 - end +function make_crc_table(n) + c = UInt32(n) + for k = 1:8 + if (c & 0x00000001) != 0 + c = 0xedb88320 ⊻ (c >> 0x1) + else + c >>= 0x1 end - crc_table[n] = c end + return c end +const crc_table = [make_crc_table(n) for n = 0x00:0xff] + +""" + init_crc() -> UInt32 +""" function init_crc() - if crc_table[1] == 0 - make_crc_table() - end return 0xffffffff end +""" + update_crc(c::UInt32, x::UInt8) -> UInt32 +""" @inline function update_crc(c::UInt32, x::UInt8) - @inbounds return crc_table[1 + ((c ⊻ x) & 0xff)] ⊻ (c >> 8) + @inbounds return crc_table[1 + ((c % UInt8) ⊻ x)] ⊻ (c >> 0x8) end -function finish_crc(c) +""" + finish_crc(c::UInt32) -> UInt32 +""" +function finish_crc(c::UInt32) return c ⊻ 0xffffffff end +""" + crc(x::Vector{UInt8}) -> UInt32 +""" function crc(x::Vector{UInt8}) c = init_crc() for b in x @@ -337,12 +386,15 @@ function crc(x::Vector{UInt8}) return finish_crc(c) end +""" + read_zero_terminated_data(data::AbstractInflateData) -> Vector{UInt8} +""" function read_zero_terminated_data(data::AbstractInflateData) s = UInt8[] while true c = get_aligned_byte(data) push!(s, c) - if c == 0 + if c === 0x00 break end end @@ -353,19 +405,19 @@ function read_zlib_header(data::AbstractInflateData) CMF = get_aligned_byte(data) FLG = get_aligned_byte(data) CM = CMF & 0x0f - CINFO = CMF >> 4 - FLEVEL = FLG >> 6 - FDICT = (FLG >> 5) & 0x01 - if CM != 8 + CINFO = CMF >> 0x4 + FLEVEL = FLG >> 0x6 + FDICT = (FLG >> 0x5) & 0x01 + if CM !== 0x8 error("unsupported compression method") end - if CINFO > 7 + if CINFO > 0x7 error("invalid LZ77 window size") end - if FDICT != 0 + if FDICT !== 0x00 error("preset dictionary not supported") end - if mod((UInt(CMF) << 8) | FLG, 31) != 0 + if mod((UInt16(CMF) << 0x8) | FLG, 31) != 0 error("header checksum error") end end @@ -374,11 +426,11 @@ function read_gzip_header(data::AbstractInflateData, headers, compute_crc) data.update_input_crc = compute_crc ID1 = get_aligned_byte(data) ID2 = get_aligned_byte(data) - if ID1 != 0x1f || ID2 != 0x8b + if ID1 !== 0x1f || ID2 !== 0x8b error("not gzipped data") end CM = get_aligned_byte(data) - if CM != 8 + if CM !== 0x8 error("unsupported compression method") end FLG = get_aligned_byte(data) @@ -391,7 +443,7 @@ function read_gzip_header(data::AbstractInflateData, headers, compute_crc) headers["os"] = OS end - if (FLG & 0x04) != 0 # FLG.FEXTRA + if (FLG & 0x04) !== 0x0 # FLG.FEXTRA xlen = getbits(data, 16) if headers != nothing headers["fextra"] = zeros(UInt8, xlen) @@ -405,30 +457,30 @@ function read_gzip_header(data::AbstractInflateData, headers, compute_crc) end end - if (FLG & 0x08) != 0 # FLG.FNAME + if (FLG & 0x08) !== 0x0 # FLG.FNAME name = read_zero_terminated_data(data) if headers != nothing headers["fname"] = String(name[1:end-1]) end end - if (FLG & 0x10) != 0 # FLG.FCOMMENT + if (FLG & 0x10) !== 0x0 # FLG.FCOMMENT comment = read_zero_terminated_data(data) if headers != nothing headers["fcomment"] = String(comment[1:end-1]) end end - if (FLG & 0xe0) != 0 + if (FLG & 0xe0) !== 0x0 error("reserved FLG bit set") end data.update_input_crc = false - if (FLG & 0x02) != 0 # FLG.FHCRC - crc16 = getbits(data, 16) + if (FLG & 0x02) !== 0x0 # FLG.FHCRC + crc16 = getbits(data, 16) % UInt16 if compute_crc header_crc = finish_crc(data.crc) - if crc16 != (header_crc & 0xffff) + if crc16 !== (header_crc % UInt16) error("corrupted data, header crc check failed") end end @@ -465,11 +517,7 @@ function inflate_zlib(source::Vector{UInt8}; ignore_checksum = false) out = _inflate(data) - skip_bits_to_byte_boundary(data) - stored_adler = 0 - for i = [24, 16, 8, 0] - stored_adler |= Int(get_aligned_byte(data)) << i - end + stored_adler = read_stored_adler(data) if !ignore_checksum && compute_adler_checksum(out) != stored_adler error("corrupted data, adler checksum error") end @@ -516,7 +564,7 @@ function inflate_gzip(source::Vector{UInt8}; headers = nothing, end """ - inflate_gzip(filename::AbstractString) + inflate_gzip(filename::AbstractString) -> String Convenience wrapper for reading a gzip compressed text file. The result is returned as a string. @@ -536,10 +584,10 @@ mutable struct StreamingInflateData <: AbstractInflateData stream::IO input_buffer::Vector{UInt8} input_buffer_pos::Int - current_byte::Int + current_byte::UInt8 bitpos::Int - literal_or_length_code::Vector{Vector{Int}} - distance_code::Vector{Vector{Int}} + literal_or_length_code::Vector{Vector{UInt16}} + distance_code::Vector{Vector{UInt16}} output_buffer::Vector{UInt8} write_pos::Int read_pos::Int @@ -559,9 +607,12 @@ function StreamingInflateData(stream::IO) true, 0, -2, false, false, init_crc()) end +""" + get_input_byte(data::StreamingInflateData) -> UInt8 +""" function get_input_byte(data::StreamingInflateData) if data.input_buffer_pos > length(data.input_buffer) - data.input_buffer = read(data.stream, 65536) + data.input_buffer = read(data.stream, buffer_size) data.input_buffer_pos = 1 end byte = data.input_buffer[data.input_buffer_pos] @@ -574,9 +625,12 @@ end # This isn't called when reading Gzip header, so no need to # consider updating crc. +""" + get_input_bytes(data::StreamingInflateData, n) -> SubArray{UInt8, 1} +""" function get_input_bytes(data::StreamingInflateData, n) if data.input_buffer_pos > length(data.input_buffer) - data.input_buffer = read(data.stream, 65536) + data.input_buffer = read(data.stream, buffer_size) data.input_buffer_pos = 1 end n = min(n, length(data.input_buffer) - data.input_buffer_pos + 1) @@ -619,7 +673,7 @@ Reference: [RFC 1950](https://www.ietf.org/rfc/rfc1950.txt) """ mutable struct InflateZlibStream <: AbstractInflateStream data::StreamingInflateData - adler::Tuple{Int, Int} + adler::UInt32 compute_adler::Bool end @@ -675,12 +729,8 @@ end read_trailer(stream::InflateStream) = nothing function read_trailer(stream::InflateZlibStream) - computed_adler = finish_adler(stream.adler) - skip_bits_to_byte_boundary(stream.data) - stored_adler = 0 - for i = [24, 16, 8, 0] - stored_adler |= Int(get_aligned_byte(stream.data)) << i - end + computed_adler = stream.adler + stored_adler = read_stored_adler(stream.data) if stream.compute_adler && computed_adler != stored_adler error("corrupted data, adler checksum error") end @@ -690,15 +740,18 @@ function read_trailer(stream::InflateGzipStream) crc = finish_crc(stream.crc) skip_bits_to_byte_boundary(stream.data) crc32 = getbits(stream.data, 32) - if stream.compute_crc && crc32 != crc + if stream.compute_crc && crc32 !== crc error("corrupted data, crc check failed") end isize = getbits(stream.data, 32) - if isize != stream.num_bytes % UInt32 + if isize !== stream.num_bytes % UInt32 error("corrupted data, length check failed") end end +""" + read_output_byte(data::StreamingInflateData) -> UInt8 +""" @inline function read_output_byte(data::StreamingInflateData) @inbounds byte = data.output_buffer[data.read_pos] data.read_pos += 1 @@ -728,6 +781,9 @@ function read_output_bytes!(data::StreamingInflateData, out, i) return n end +""" + read_output_byte(stream::AbstractInflateStream) -> UInt8 +""" function read_output_byte(stream::AbstractInflateStream) byte = read_output_byte(stream.data) getbyte(stream) @@ -808,9 +864,12 @@ function write_to_buffer(stream::InflateGzipStream, x::AbstractVector{UInt8}) stream.num_bytes += length(x) end +""" + getbyte(stream::AbstractInflateStream) +""" function getbyte(stream::AbstractInflateStream) if stream.data.write_pos != stream.data.read_pos - return + return nothing end if stream.data.pending_bytes > 0 @@ -834,16 +893,16 @@ function getbyte(stream::AbstractInflateStream) stream.data.pending_bytes -= n write_to_buffer(stream, @view stream.data.output_buffer[pos:(pos + n - 1)]) end - return + return nothing end if stream.data.waiting_for_new_block if stream.data.reading_final_block - return + return nothing end - stream.data.reading_final_block = getbits(stream.data, 1) == 1 - compression_mode = getbits(stream.data, 2) - if compression_mode == 0 + stream.data.reading_final_block = getbit(stream.data) + compression_mode = getbits(stream.data, 2) % UInt8 + if compression_mode === 0x0 skip_bits_to_byte_boundary(stream.data) len = getbits(stream.data, 16) nlen = getbits(stream.data, 16) @@ -853,11 +912,11 @@ function getbyte(stream::AbstractInflateStream) stream.data.distance = -1 stream.data.pending_bytes = len getbyte(stream) - return - elseif compression_mode == 1 + return nothing + elseif compression_mode === 0x1 stream.data.literal_or_length_code = fixed_literal_or_length_table stream.data.distance_code = fixed_distance_table - elseif compression_mode == 2 + elseif compression_mode === 0x2 read_code_tables(stream.data) else error("invalid block compression mode 3") @@ -866,16 +925,17 @@ function getbyte(stream::AbstractInflateStream) end v = get_literal_or_length(stream.data) - if v < 256 + if v < 0x100 write_to_buffer(stream, UInt8(v)) - elseif v == 256 + elseif v == 0x100 stream.data.waiting_for_new_block = true getbyte(stream) else - stream.data.pending_bytes = getlength(stream.data, v) + stream.data.pending_bytes = getlength(stream.data, v % Int) stream.data.distance = getdist(stream.data) getbyte(stream) end + return nothing end function Base.eof(stream::AbstractInflateStream)