diff --git a/ext/zstdruby/zstdruby.c b/ext/zstdruby/zstdruby.c index 6f6e881..512aa7b 100644 --- a/ext/zstdruby/zstdruby.c +++ b/ext/zstdruby/zstdruby.c @@ -87,13 +87,13 @@ static VALUE rb_compress_using_dict(int argc, VALUE *argv, VALUE self) static VALUE decompress_buffered(ZSTD_DCtx* dctx, const char* input_data, size_t input_size) { - VALUE output_string = rb_str_new(NULL, 0); - ZSTD_outBuffer output = { NULL, 0, 0 }; - ZSTD_inBuffer input = { input_data, input_size, 0 }; + VALUE result = rb_str_new(0, 0); + while (input.pos < input.size) { + ZSTD_outBuffer output = { NULL, 0, 0 }; output.size += ZSTD_DStreamOutSize(); - rb_str_resize(output_string, output.size); + VALUE output_string = rb_str_new(NULL, output.size); output.dst = RSTRING_PTR(output_string); size_t ret = zstd_stream_decompress(dctx, &output, &input, false); @@ -101,10 +101,10 @@ static VALUE decompress_buffered(ZSTD_DCtx* dctx, const char* input_data, size_t ZSTD_freeDCtx(dctx); rb_raise(rb_eRuntimeError, "%s: %s", "ZSTD_decompressStream failed", ZSTD_getErrorName(ret)); } + rb_str_cat(result, output.dst, output.pos); } - rb_str_resize(output_string, output.pos); ZSTD_freeDCtx(dctx); - return output_string; + return result; } static VALUE rb_decompress(int argc, VALUE *argv, VALUE self) @@ -134,7 +134,7 @@ static VALUE rb_decompress(int argc, VALUE *argv, VALUE self) VALUE output = rb_str_new(NULL, uncompressed_size); char* output_data = RSTRING_PTR(output); - size_t const decompress_size = zstd_decompress(dctx, output_data, uncompressed_size, input_data, input_size, false); + size_t const decompress_size = zstd_decompress(dctx, output_data, uncompressed_size, input_data, input_size, false); if (ZSTD_isError(decompress_size)) { rb_raise(rb_eRuntimeError, "%s: %s", "decompress error", ZSTD_getErrorName(decompress_size)); } diff --git a/spec/zstd-ruby_spec.rb b/spec/zstd-ruby_spec.rb index 8f80f65..ba39ea8 100644 --- a/spec/zstd-ruby_spec.rb +++ b/spec/zstd-ruby_spec.rb @@ -87,6 +87,14 @@ def to_str expect(decompressed.force_encoding('UTF-8')).to eq('あああ') end + it 'should work hash equal streaming compress' do + simple_compressed = Zstd.compress('あ') + stream = Zstd::StreamingCompress.new + stream << "あ" + streaming_compressed = stream.finish + expect(Zstd.decompress(simple_compressed).force_encoding('UTF-8').hash).to eq(Zstd.decompress(streaming_compressed).force_encoding('UTF-8').hash) + end + it 'should raise exception with unsupported object' do expect { Zstd.decompress(Object.new) }.to raise_error(TypeError) end @@ -111,4 +119,3 @@ def to_str end end end -