From 0da14edee166ef9ef97cab058053ce2a21b5c51d Mon Sep 17 00:00:00 2001 From: Andy Balholm Date: Wed, 6 Mar 2019 17:08:24 -0800 Subject: [PATCH] Make BrotliDecoderState into a Reader. --- brotli_test.go | 396 +++++++++++++++++++++++++++++++++++++++++++++++++ decode.go | 108 +++++++------- reader.go | 94 ++++++++++++ state.go | 18 ++- 4 files changed, 556 insertions(+), 60 deletions(-) create mode 100644 brotli_test.go create mode 100644 reader.go diff --git a/brotli_test.go b/brotli_test.go new file mode 100644 index 0000000..006fe6e --- /dev/null +++ b/brotli_test.go @@ -0,0 +1,396 @@ +// Copyright 2016 Google Inc. All Rights Reserved. +// +// Distributed under MIT license. +// See file LICENSE for detail or copy at https://opensource.org/licenses/MIT + +package brotli + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "math" + "math/rand" + "testing" + "time" +) + +func checkCompressedData(compressedData, wantOriginalData []byte) error { + uncompressed, err := Decode(compressedData) + if err != nil { + return fmt.Errorf("brotli decompress failed: %v", err) + } + if !bytes.Equal(uncompressed, wantOriginalData) { + if len(wantOriginalData) != len(uncompressed) { + return fmt.Errorf(""+ + "Data doesn't uncompress to the original value.\n"+ + "Length of original: %v\n"+ + "Length of uncompressed: %v", + len(wantOriginalData), len(uncompressed)) + } + for i := range wantOriginalData { + if wantOriginalData[i] != uncompressed[i] { + return fmt.Errorf(""+ + "Data doesn't uncompress to the original value.\n"+ + "Original at %v is %v\n"+ + "Uncompressed at %v is %v", + i, wantOriginalData[i], i, uncompressed[i]) + } + } + } + return nil +} + +func TestEncoderNoWrite(t *testing.T) { + out := bytes.Buffer{} + e := NewWriter(&out, WriterOptions{Quality: 5}) + if err := e.Close(); err != nil { + t.Errorf("Close()=%v, want nil", err) + } + // Check Write after close. + if _, err := e.Write([]byte("hi")); err == nil { + t.Errorf("No error after Close() + Write()") + } +} + +func TestEncoderEmptyWrite(t *testing.T) { + out := bytes.Buffer{} + e := NewWriter(&out, WriterOptions{Quality: 5}) + n, err := e.Write([]byte("")) + if n != 0 || err != nil { + t.Errorf("Write()=%v,%v, want 0, nil", n, err) + } + if err := e.Close(); err != nil { + t.Errorf("Close()=%v, want nil", err) + } +} + +func TestWriter(t *testing.T) { + // Test basic encoder usage. + input := []byte("

Hello world

") + out := bytes.Buffer{} + e := NewWriter(&out, WriterOptions{Quality: 1}) + in := bytes.NewReader([]byte(input)) + n, err := io.Copy(e, in) + if err != nil { + t.Errorf("Copy Error: %v", err) + } + if int(n) != len(input) { + t.Errorf("Copy() n=%v, want %v", n, len(input)) + } + if err := e.Close(); err != nil { + t.Errorf("Close Error after copied %d bytes: %v", n, err) + } + if err := checkCompressedData(out.Bytes(), input); err != nil { + t.Error(err) + } +} + +func TestEncoderStreams(t *testing.T) { + // Test that output is streamed. + // Adjust window size to ensure the encoder outputs at least enough bytes + // to fill the window. + const lgWin = 16 + windowSize := int(math.Pow(2, lgWin)) + input := make([]byte, 8*windowSize) + rand.Read(input) + out := bytes.Buffer{} + e := NewWriter(&out, WriterOptions{Quality: 11, LGWin: lgWin}) + halfInput := input[:len(input)/2] + in := bytes.NewReader(halfInput) + + n, err := io.Copy(e, in) + if err != nil { + t.Errorf("Copy Error: %v", err) + } + + // We've fed more data than the sliding window size. Check that some + // compressed data has been output. + if out.Len() == 0 { + t.Errorf("Output length is 0 after %d bytes written", n) + } + if err := e.Close(); err != nil { + t.Errorf("Close Error after copied %d bytes: %v", n, err) + } + if err := checkCompressedData(out.Bytes(), halfInput); err != nil { + t.Error(err) + } +} + +func TestEncoderLargeInput(t *testing.T) { + input := make([]byte, 1000000) + rand.Read(input) + out := bytes.Buffer{} + e := NewWriter(&out, WriterOptions{Quality: 5}) + in := bytes.NewReader(input) + + n, err := io.Copy(e, in) + if err != nil { + t.Errorf("Copy Error: %v", err) + } + if int(n) != len(input) { + t.Errorf("Copy() n=%v, want %v", n, len(input)) + } + if err := e.Close(); err != nil { + t.Errorf("Close Error after copied %d bytes: %v", n, err) + } + if err := checkCompressedData(out.Bytes(), input); err != nil { + t.Error(err) + } +} + +func TestEncoderFlush(t *testing.T) { + input := make([]byte, 1000) + rand.Read(input) + out := bytes.Buffer{} + e := NewWriter(&out, WriterOptions{Quality: 5}) + in := bytes.NewReader(input) + _, err := io.Copy(e, in) + if err != nil { + t.Fatalf("Copy Error: %v", err) + } + if err := e.Flush(); err != nil { + t.Fatalf("Flush(): %v", err) + } + if out.Len() == 0 { + t.Fatalf("0 bytes written after Flush()") + } + decompressed := make([]byte, 1000) + reader := NewReader(bytes.NewReader(out.Bytes())) + n, err := reader.Read(decompressed) + if n != len(decompressed) || err != nil { + t.Errorf("Expected <%v, nil>, but <%v, %v>", len(decompressed), n, err) + } + if !bytes.Equal(decompressed, input) { + t.Errorf(""+ + "Decompress after flush: %v\n"+ + "%q\n"+ + "want:\n%q", + err, decompressed, input) + } + if err := e.Close(); err != nil { + t.Errorf("Close(): %v", err) + } +} + +type readerWithTimeout struct { + io.Reader +} + +func (r readerWithTimeout) Read(p []byte) (int, error) { + type result struct { + n int + err error + } + ch := make(chan result) + go func() { + n, err := r.Reader.Read(p) + ch <- result{n, err} + }() + select { + case result := <-ch: + return result.n, result.err + case <-time.After(5 * time.Second): + return 0, fmt.Errorf("read timed out") + } +} + +func TestDecoderStreaming(t *testing.T) { + pr, pw := io.Pipe() + writer := NewWriter(pw, WriterOptions{Quality: 5, LGWin: 20}) + reader := readerWithTimeout{NewReader(pr)} + defer func() { + go ioutil.ReadAll(pr) // swallow the "EOF" token from writer.Close + if err := writer.Close(); err != nil { + t.Errorf("writer.Close: %v", err) + } + }() + + ch := make(chan []byte) + errch := make(chan error) + go func() { + for { + segment, ok := <-ch + if !ok { + return + } + if n, err := writer.Write(segment); err != nil || n != len(segment) { + errch <- fmt.Errorf("write=%v,%v, want %v,%v", n, err, len(segment), nil) + return + } + if err := writer.Flush(); err != nil { + errch <- fmt.Errorf("flush: %v", err) + return + } + } + }() + defer close(ch) + + segments := [...][]byte{ + []byte("first"), + []byte("second"), + []byte("third"), + } + for k, segment := range segments { + t.Run(fmt.Sprintf("Segment%d", k), func(t *testing.T) { + select { + case ch <- segment: + case err := <-errch: + t.Fatalf("write: %v", err) + case <-time.After(5 * time.Second): + t.Fatalf("timed out") + } + wantLen := len(segment) + got := make([]byte, wantLen) + if n, err := reader.Read(got); err != nil || n != wantLen || !bytes.Equal(got, segment) { + t.Fatalf("read[%d]=%q,%v,%v, want %q,%v,%v", k, got, n, err, segment, wantLen, nil) + } + }) + } +} + +func TestReader(t *testing.T) { + content := bytes.Repeat([]byte("hello world!"), 10000) + encoded, _ := Encode(content, WriterOptions{Quality: 5}) + r := NewReader(bytes.NewReader(encoded)) + var decodedOutput bytes.Buffer + n, err := io.Copy(&decodedOutput, r) + if err != nil { + t.Fatalf("Copy(): n=%v, err=%v", n, err) + } + if got := decodedOutput.Bytes(); !bytes.Equal(got, content) { + t.Errorf(""+ + "Reader output:\n"+ + "%q\n"+ + "want:\n"+ + "<%d bytes>", + got, len(content)) + } +} + +func TestDecode(t *testing.T) { + content := bytes.Repeat([]byte("hello world!"), 10000) + encoded, _ := Encode(content, WriterOptions{Quality: 5}) + decoded, err := Decode(encoded) + if err != nil { + t.Errorf("Decode: %v", err) + } + if !bytes.Equal(decoded, content) { + t.Errorf(""+ + "Decode content:\n"+ + "%q\n"+ + "want:\n"+ + "<%d bytes>", + decoded, len(content)) + } +} + +func TestQuality(t *testing.T) { + content := bytes.Repeat([]byte("hello world!"), 10000) + for q := 0; q < 12; q++ { + encoded, _ := Encode(content, WriterOptions{Quality: q}) + decoded, err := Decode(encoded) + if err != nil { + t.Errorf("Decode: %v", err) + } + if !bytes.Equal(decoded, content) { + t.Errorf(""+ + "Decode content:\n"+ + "%q\n"+ + "want:\n"+ + "<%d bytes>", + decoded, len(content)) + } + } +} + +func TestDecodeFuzz(t *testing.T) { + // Test that the decoder terminates with corrupted input. + content := bytes.Repeat([]byte("hello world!"), 100) + src := rand.NewSource(0) + encoded, err := Encode(content, WriterOptions{Quality: 5}) + if err != nil { + t.Fatalf("Encode(<%d bytes>, _) = _, %s", len(content), err) + } + if len(encoded) == 0 { + t.Fatalf("Encode(<%d bytes>, _) produced empty output", len(content)) + } + for i := 0; i < 100; i++ { + enc := append([]byte{}, encoded...) + for j := 0; j < 5; j++ { + enc[int(src.Int63())%len(enc)] = byte(src.Int63() % 256) + } + Decode(enc) + } +} + +func TestDecodeTrailingData(t *testing.T) { + content := bytes.Repeat([]byte("hello world!"), 100) + encoded, _ := Encode(content, WriterOptions{Quality: 5}) + _, err := Decode(append(encoded, 0)) + if err == nil { + t.Errorf("Expected 'excessive input' error") + } +} + +func TestEncodeDecode(t *testing.T) { + for _, test := range []struct { + data []byte + repeats int + }{ + {nil, 0}, + {[]byte("A"), 1}, + {[]byte("

Hello world

"), 10}, + {[]byte("

Hello world

"), 1000}, + } { + t.Logf("case %q x %d", test.data, test.repeats) + input := bytes.Repeat(test.data, test.repeats) + encoded, err := Encode(input, WriterOptions{Quality: 5}) + if err != nil { + t.Errorf("Encode: %v", err) + } + // Inputs are compressible, but may be too small to compress. + if maxSize := len(input)/2 + 20; len(encoded) >= maxSize { + t.Errorf(""+ + "Encode returned %d bytes, want <%d\n"+ + "Encoded=%q", + len(encoded), maxSize, encoded) + } + decoded, err := Decode(encoded) + if err != nil { + t.Errorf("Decode: %v", err) + } + if !bytes.Equal(decoded, input) { + var want string + if len(input) > 320 { + want = fmt.Sprintf("<%d bytes>", len(input)) + } else { + want = fmt.Sprintf("%q", input) + } + t.Errorf(""+ + "Decode content:\n"+ + "%q\n"+ + "want:\n"+ + "%s", + decoded, want) + } + } +} + +// Encode returns content encoded with Brotli. +func Encode(content []byte, options WriterOptions) ([]byte, error) { + var buf bytes.Buffer + writer := NewWriter(&buf, options) + _, err := writer.Write(content) + if closeErr := writer.Close(); err == nil { + err = closeErr + } + return buf.Bytes(), err +} + +// Decode decodes Brotli encoded data. +func Decode(encodedData []byte) ([]byte, error) { + r := NewReader(bytes.NewReader(encodedData)) + return ioutil.ReadAll(r) +} diff --git a/decode.go b/decode.go index 38b7153..3ef969d 100644 --- a/decode.go +++ b/decode.go @@ -114,7 +114,7 @@ var kCodeLengthPrefixLength = [16]byte{2, 2, 2, 3, 2, 2, 2, 4, 2, 2, 2, 3, 2, 2, var kCodeLengthPrefixValue = [16]byte{0, 4, 3, 2, 0, 4, 3, 1, 0, 4, 3, 2, 0, 4, 3, 5} -func BrotliDecoderSetParameter(state *BrotliDecoderState, p int, value uint32) bool { +func BrotliDecoderSetParameter(state *Reader, p int, value uint32) bool { if state.state != BROTLI_STATE_UNINITED { return false } @@ -136,9 +136,9 @@ func BrotliDecoderSetParameter(state *BrotliDecoderState, p int, value uint32) b } } -func BrotliDecoderCreateInstance() *BrotliDecoderState { - var state *BrotliDecoderState - state = new(BrotliDecoderState) +func BrotliDecoderCreateInstance() *Reader { + var state *Reader + state = new(Reader) if state == nil { return nil } @@ -151,7 +151,7 @@ func BrotliDecoderCreateInstance() *BrotliDecoderState { } /* Deinitializes and frees BrotliDecoderState instance. */ -func BrotliDecoderDestroyInstance(state *BrotliDecoderState) { +func BrotliDecoderDestroyInstance(state *Reader) { if state == nil { return } else { @@ -160,7 +160,7 @@ func BrotliDecoderDestroyInstance(state *BrotliDecoderState) { } /* Saves error code and converts it to BrotliDecoderResult. */ -func SaveErrorCode(s *BrotliDecoderState, e int) int { +func SaveErrorCode(s *Reader, e int) int { s.error_code = int(e) switch e { case BROTLI_DECODER_SUCCESS: @@ -179,7 +179,7 @@ func SaveErrorCode(s *BrotliDecoderState, e int) int { /* Decodes WBITS by reading 1 - 7 bits, or 0x11 for "Large Window Brotli". Precondition: bit-reader accumulator has at least 8 bits. */ -func DecodeWindowBits(s *BrotliDecoderState, br *BrotliBitReader) int { +func DecodeWindowBits(s *Reader, br *BrotliBitReader) int { var n uint32 var large_window bool = s.large_window s.large_window = false @@ -220,7 +220,7 @@ func DecodeWindowBits(s *BrotliDecoderState, br *BrotliBitReader) int { } /* Decodes a number in the range [0..255], by reading 1 - 11 bits. */ -func DecodeVarLenUint8(s *BrotliDecoderState, br *BrotliBitReader, value *uint32) int { +func DecodeVarLenUint8(s *Reader, br *BrotliBitReader, value *uint32) int { var bits uint32 switch s.substate_decode_uint8 { case BROTLI_STATE_DECODE_UINT8_NONE: @@ -268,7 +268,7 @@ func DecodeVarLenUint8(s *BrotliDecoderState, br *BrotliBitReader, value *uint32 } /* Decodes a metablock length and flags by reading 2 - 31 bits. */ -func DecodeMetaBlockLength(s *BrotliDecoderState, br *BrotliBitReader) int { +func DecodeMetaBlockLength(s *Reader, br *BrotliBitReader) int { var bits uint32 var i int for { @@ -538,7 +538,7 @@ func Log2Floor(x uint32) uint32 { /* Reads (s->symbol + 1) symbols. Totally 1..4 symbols are read, 1..11 bits each. The list of symbols MUST NOT contain duplicates. */ -func ReadSimpleHuffmanSymbols(alphabet_size uint32, max_symbol uint32, s *BrotliDecoderState) int { +func ReadSimpleHuffmanSymbols(alphabet_size uint32, max_symbol uint32, s *Reader) int { var br *BrotliBitReader = &s.br var max_bits uint32 = Log2Floor(alphabet_size - 1) var i uint32 = s.sub_loop_counter @@ -651,7 +651,7 @@ func ProcessRepeatedCodeLength(code_len uint32, repeat_delta uint32, alphabet_si } /* Reads and decodes symbol codelengths. */ -func ReadSymbolCodeLengths(alphabet_size uint32, s *BrotliDecoderState) int { +func ReadSymbolCodeLengths(alphabet_size uint32, s *Reader) int { var br *BrotliBitReader = &s.br var symbol uint32 = s.symbol var repeat uint32 = s.repeat @@ -700,7 +700,7 @@ func ReadSymbolCodeLengths(alphabet_size uint32, s *BrotliDecoderState) int { return BROTLI_DECODER_SUCCESS } -func SafeReadSymbolCodeLengths(alphabet_size uint32, s *BrotliDecoderState) int { +func SafeReadSymbolCodeLengths(alphabet_size uint32, s *Reader) int { var br *BrotliBitReader = &s.br var get_byte bool = false var p []HuffmanCode @@ -746,7 +746,7 @@ func SafeReadSymbolCodeLengths(alphabet_size uint32, s *BrotliDecoderState) int /* Reads and decodes 15..18 codes using static prefix code. Each code is 2..4 bits long. In total 30..72 bits are used. */ -func ReadCodeLengthCodeLengths(s *BrotliDecoderState) int { +func ReadCodeLengthCodeLengths(s *Reader) int { var br *BrotliBitReader = &s.br var num_codes uint32 = s.repeat var space uint32 = s.space @@ -804,7 +804,7 @@ func ReadCodeLengthCodeLengths(s *BrotliDecoderState) int { encoded with predefined entropy code. 32 - 74 bits are used. B.2) Decoded table is used to decode code lengths of symbols in resulting Huffman table. In worst case 3520 bits are read. */ -func ReadHuffmanCode(alphabet_size uint32, max_symbol uint32, table []HuffmanCode, opt_table_size *uint32, s *BrotliDecoderState) int { +func ReadHuffmanCode(alphabet_size uint32, max_symbol uint32, table []HuffmanCode, opt_table_size *uint32, s *Reader) int { var br *BrotliBitReader = &s.br /* Unnecessary masking, but might be good for safety. */ @@ -954,7 +954,7 @@ func ReadBlockLength(table []HuffmanCode, br *BrotliBitReader) uint32 { /* WARNING: if state is not BROTLI_STATE_READ_BLOCK_LENGTH_NONE, then reading can't be continued with ReadBlockLength. */ -func SafeReadBlockLength(s *BrotliDecoderState, result *uint32, table []HuffmanCode, br *BrotliBitReader) bool { +func SafeReadBlockLength(s *Reader, result *uint32, table []HuffmanCode, br *BrotliBitReader) bool { var index uint32 if s.substate_read_block_length == BROTLI_STATE_READ_BLOCK_LENGTH_NONE { if !SafeReadSymbol(table, br, &index) { @@ -992,7 +992,7 @@ func SafeReadBlockLength(s *BrotliDecoderState, result *uint32, table []HuffmanC Most of input values are 0 and 1. To reduce number of branches, we replace inner for loop with do-while. */ -func InverseMoveToFrontTransform(v []byte, v_len uint32, state *BrotliDecoderState) { +func InverseMoveToFrontTransform(v []byte, v_len uint32, state *Reader) { var mtf [256]byte var i int for i = 1; i < 256; i++ { @@ -1016,7 +1016,7 @@ func InverseMoveToFrontTransform(v []byte, v_len uint32, state *BrotliDecoderSta } /* Decodes a series of Huffman table using ReadHuffmanCode function. */ -func HuffmanTreeGroupDecode(group *HuffmanTreeGroup, s *BrotliDecoderState) int { +func HuffmanTreeGroupDecode(group *HuffmanTreeGroup, s *Reader) int { if s.substate_tree_group != BROTLI_STATE_TREE_GROUP_LOOP { s.next = group.codes s.htree_index = 0 @@ -1046,7 +1046,7 @@ func HuffmanTreeGroupDecode(group *HuffmanTreeGroup, s *BrotliDecoderState) int This table will be used for reading context map items. 3) Read context map items; "0" values could be run-length encoded. 4) Optionally, apply InverseMoveToFront transform to the resulting map. */ -func DecodeContextMap(context_map_size uint32, num_htrees *uint32, context_map_arg *[]byte, s *BrotliDecoderState) int { +func DecodeContextMap(context_map_size uint32, num_htrees *uint32, context_map_arg *[]byte, s *Reader) int { var br *BrotliBitReader = &s.br var result int = BROTLI_DECODER_SUCCESS @@ -1192,7 +1192,7 @@ func DecodeContextMap(context_map_size uint32, num_htrees *uint32, context_map_a /* Decodes a command or literal and updates block type ring-buffer. Reads 3..54 bits. */ -func DecodeBlockTypeAndLength(safe int, s *BrotliDecoderState, tree_type int) bool { +func DecodeBlockTypeAndLength(safe int, s *Reader, tree_type int) bool { var max_block_type uint32 = s.num_block_types[tree_type] var type_tree []HuffmanCode type_tree = s.block_type_trees[tree_type*BROTLI_HUFFMAN_MAX_SIZE_258:] @@ -1239,7 +1239,7 @@ func DecodeBlockTypeAndLength(safe int, s *BrotliDecoderState, tree_type int) bo return true } -func DetectTrivialLiteralBlockTypes(s *BrotliDecoderState) { +func DetectTrivialLiteralBlockTypes(s *Reader) { var i uint for i = 0; i < 8; i++ { s.trivial_literal_contexts[i] = 0 @@ -1263,7 +1263,7 @@ func DetectTrivialLiteralBlockTypes(s *BrotliDecoderState) { } } -func PrepareLiteralDecoding(s *BrotliDecoderState) { +func PrepareLiteralDecoding(s *Reader) { var context_mode byte var trivial uint var block_type uint32 = s.block_type_rb[1] @@ -1278,7 +1278,7 @@ func PrepareLiteralDecoding(s *BrotliDecoderState) { /* Decodes the block type and updates the state for literal context. Reads 3..54 bits. */ -func DecodeLiteralBlockSwitchInternal(safe int, s *BrotliDecoderState) bool { +func DecodeLiteralBlockSwitchInternal(safe int, s *Reader) bool { if !DecodeBlockTypeAndLength(safe, s, 0) { return false } @@ -1287,17 +1287,17 @@ func DecodeLiteralBlockSwitchInternal(safe int, s *BrotliDecoderState) bool { return true } -func DecodeLiteralBlockSwitch(s *BrotliDecoderState) { +func DecodeLiteralBlockSwitch(s *Reader) { DecodeLiteralBlockSwitchInternal(0, s) } -func SafeDecodeLiteralBlockSwitch(s *BrotliDecoderState) bool { +func SafeDecodeLiteralBlockSwitch(s *Reader) bool { return DecodeLiteralBlockSwitchInternal(1, s) } /* Block switch for insert/copy length. Reads 3..54 bits. */ -func DecodeCommandBlockSwitchInternal(safe int, s *BrotliDecoderState) bool { +func DecodeCommandBlockSwitchInternal(safe int, s *Reader) bool { if !DecodeBlockTypeAndLength(safe, s, 1) { return false } @@ -1306,17 +1306,17 @@ func DecodeCommandBlockSwitchInternal(safe int, s *BrotliDecoderState) bool { return true } -func DecodeCommandBlockSwitch(s *BrotliDecoderState) { +func DecodeCommandBlockSwitch(s *Reader) { DecodeCommandBlockSwitchInternal(0, s) } -func SafeDecodeCommandBlockSwitch(s *BrotliDecoderState) bool { +func SafeDecodeCommandBlockSwitch(s *Reader) bool { return DecodeCommandBlockSwitchInternal(1, s) } /* Block switch for distance codes. Reads 3..54 bits. */ -func DecodeDistanceBlockSwitchInternal(safe int, s *BrotliDecoderState) bool { +func DecodeDistanceBlockSwitchInternal(safe int, s *Reader) bool { if !DecodeBlockTypeAndLength(safe, s, 2) { return false } @@ -1326,15 +1326,15 @@ func DecodeDistanceBlockSwitchInternal(safe int, s *BrotliDecoderState) bool { return true } -func DecodeDistanceBlockSwitch(s *BrotliDecoderState) { +func DecodeDistanceBlockSwitch(s *Reader) { DecodeDistanceBlockSwitchInternal(0, s) } -func SafeDecodeDistanceBlockSwitch(s *BrotliDecoderState) bool { +func SafeDecodeDistanceBlockSwitch(s *Reader) bool { return DecodeDistanceBlockSwitchInternal(1, s) } -func UnwrittenBytes(s *BrotliDecoderState, wrap bool) uint { +func UnwrittenBytes(s *Reader, wrap bool) uint { var pos uint if wrap && s.pos > s.ringbuffer_size { pos = uint(s.ringbuffer_size) @@ -1348,7 +1348,7 @@ func UnwrittenBytes(s *BrotliDecoderState, wrap bool) uint { /* Dumps output. Returns BROTLI_DECODER_NEEDS_MORE_OUTPUT only if there is more output to push and either ring-buffer is as big as window size, or |force| is true. */ -func WriteRingBuffer(s *BrotliDecoderState, available_out *uint, next_out *[]byte, total_out *uint, force bool) int { +func WriteRingBuffer(s *Reader, available_out *uint, next_out *[]byte, total_out *uint, force bool) int { var start []byte start = s.ringbuffer[s.partial_pos_out&uint(s.ringbuffer_mask):] var to_write uint = UnwrittenBytes(s, true) @@ -1398,7 +1398,7 @@ func WriteRingBuffer(s *BrotliDecoderState, available_out *uint, next_out *[]byt return BROTLI_DECODER_SUCCESS } -func WrapRingBuffer(s *BrotliDecoderState) { +func WrapRingBuffer(s *Reader) { if s.should_wrap_ringbuffer != 0 { copy(s.ringbuffer, s.ringbuffer_end[:uint(s.pos)]) s.should_wrap_ringbuffer = 0 @@ -1412,7 +1412,7 @@ func WrapRingBuffer(s *BrotliDecoderState) { Last two bytes of ring-buffer are initialized to 0, so context calculation could be done uniformly for the first two and all other positions. */ -func BrotliEnsureRingBuffer(s *BrotliDecoderState) bool { +func BrotliEnsureRingBuffer(s *Reader) bool { var old_ringbuffer []byte = s.ringbuffer if s.ringbuffer_size == s.new_ringbuffer_size { return true @@ -1442,7 +1442,7 @@ func BrotliEnsureRingBuffer(s *BrotliDecoderState) bool { return true } -func CopyUncompressedBlockToOutput(available_out *uint, next_out *[]byte, total_out *uint, s *BrotliDecoderState) int { +func CopyUncompressedBlockToOutput(available_out *uint, next_out *[]byte, total_out *uint, s *Reader) int { /* TODO: avoid allocation for single uncompressed block. */ if !BrotliEnsureRingBuffer(s) { return BROTLI_DECODER_ERROR_ALLOC_RING_BUFFER_1 @@ -1508,7 +1508,7 @@ func CopyUncompressedBlockToOutput(available_out *uint, next_out *[]byte, total_ size than needed to reduce memory usage. When this method is called, metablock size and flags MUST be decoded. */ -func BrotliCalculateRingBufferSize(s *BrotliDecoderState) { +func BrotliCalculateRingBufferSize(s *Reader) { var window_size int = 1 << s.window_bits var new_ringbuffer_size int = window_size var min_size int @@ -1557,7 +1557,7 @@ func BrotliCalculateRingBufferSize(s *BrotliDecoderState) { } /* Reads 1..256 2-bit context modes. */ -func ReadContextModes(s *BrotliDecoderState) int { +func ReadContextModes(s *Reader) int { var br *BrotliBitReader = &s.br var i int = s.loop_counter @@ -1575,7 +1575,7 @@ func ReadContextModes(s *BrotliDecoderState) int { return BROTLI_DECODER_SUCCESS } -func TakeDistanceFromRingBuffer(s *BrotliDecoderState) { +func TakeDistanceFromRingBuffer(s *Reader) { if s.distance_code == 0 { s.dist_rb_idx-- s.distance_code = s.dist_rb[s.dist_rb_idx&3] @@ -1618,7 +1618,7 @@ func SafeReadBits(br *BrotliBitReader, n_bits uint32, val *uint32) bool { } /* Precondition: s->distance_code < 0. */ -func ReadDistanceInternal(safe int, s *BrotliDecoderState, br *BrotliBitReader) bool { +func ReadDistanceInternal(safe int, s *Reader, br *BrotliBitReader) bool { var distval int var memento BrotliBitReaderState var distance_tree []HuffmanCode = []HuffmanCode(s.distance_hgroup.htrees[s.dist_htree_index]) @@ -1679,15 +1679,15 @@ func ReadDistanceInternal(safe int, s *BrotliDecoderState, br *BrotliBitReader) return true } -func ReadDistance(s *BrotliDecoderState, br *BrotliBitReader) { +func ReadDistance(s *Reader, br *BrotliBitReader) { ReadDistanceInternal(0, s, br) } -func SafeReadDistance(s *BrotliDecoderState, br *BrotliBitReader) bool { +func SafeReadDistance(s *Reader, br *BrotliBitReader) bool { return ReadDistanceInternal(1, s, br) } -func ReadCommandInternal(safe int, s *BrotliDecoderState, br *BrotliBitReader, insert_length *int) bool { +func ReadCommandInternal(safe int, s *Reader, br *BrotliBitReader, insert_length *int) bool { var cmd_code uint32 var insert_len_extra uint32 = 0 var copy_length uint32 @@ -1726,11 +1726,11 @@ func ReadCommandInternal(safe int, s *BrotliDecoderState, br *BrotliBitReader, i return true } -func ReadCommand(s *BrotliDecoderState, br *BrotliBitReader, insert_length *int) { +func ReadCommand(s *Reader, br *BrotliBitReader, insert_length *int) { ReadCommandInternal(0, s, br, insert_length) } -func SafeReadCommand(s *BrotliDecoderState, br *BrotliBitReader, insert_length *int) bool { +func SafeReadCommand(s *Reader, br *BrotliBitReader, insert_length *int) bool { return ReadCommandInternal(1, s, br, insert_length) } @@ -1742,7 +1742,7 @@ func CheckInputAmount(safe int, br *BrotliBitReader, num uint) bool { return BrotliCheckInputAmount(br, num) } -func ProcessCommandsInternal(safe int, s *BrotliDecoderState) int { +func ProcessCommandsInternal(safe int, s *Reader) int { var pos int = s.pos var i int = s.loop_counter var result int = BROTLI_DECODER_SUCCESS @@ -2110,11 +2110,11 @@ saveStateAndReturn: return result } -func ProcessCommands(s *BrotliDecoderState) int { +func ProcessCommands(s *Reader) int { return ProcessCommandsInternal(0, s) } -func SafeProcessCommands(s *BrotliDecoderState) int { +func SafeProcessCommands(s *Reader) int { return ProcessCommandsInternal(1, s) } @@ -2136,7 +2136,7 @@ func BrotliMaxDistanceSymbol(ndirect uint32, npostfix uint32) uint32 { } func BrotliDecoderDecompress(encoded_size uint, encoded_buffer []byte, decoded_size *uint, decoded_buffer []byte) int { - var s BrotliDecoderState + var s Reader var result int var total_out uint = 0 var available_in uint = encoded_size @@ -2168,7 +2168,7 @@ func BrotliDecoderDecompress(encoded_size uint, encoded_buffer []byte, decoded_s buffer ahead of time - when result is "success" decoder MUST return all unused data back to input buffer; this is possible because the invariant is held on enter */ -func BrotliDecoderDecompressStream(s *BrotliDecoderState, available_in *uint, next_in *[]byte, available_out *uint, next_out *[]byte, total_out *uint) int { +func BrotliDecoderDecompressStream(s *Reader, available_in *uint, next_in *[]byte, available_out *uint, next_out *[]byte, total_out *uint) int { var result int = BROTLI_DECODER_SUCCESS var br *BrotliBitReader = &s.br @@ -2687,7 +2687,7 @@ func BrotliDecoderDecompressStream(s *BrotliDecoderState, available_in *uint, ne return SaveErrorCode(s, result) } -func BrotliDecoderHasMoreOutput(s *BrotliDecoderState) bool { +func BrotliDecoderHasMoreOutput(s *Reader) bool { /* After unrecoverable error remaining output is considered nonsensical. */ if int(s.error_code) < 0 { return false @@ -2696,7 +2696,7 @@ func BrotliDecoderHasMoreOutput(s *BrotliDecoderState) bool { return s.ringbuffer != nil && UnwrittenBytes(s, false) != 0 } -func BrotliDecoderTakeOutput(s *BrotliDecoderState, size *uint) []byte { +func BrotliDecoderTakeOutput(s *Reader, size *uint) []byte { var result []byte = nil var available_out uint if *size != 0 { @@ -2730,15 +2730,15 @@ func BrotliDecoderTakeOutput(s *BrotliDecoderState, size *uint) []byte { return result } -func BrotliDecoderIsUsed(s *BrotliDecoderState) bool { +func BrotliDecoderIsUsed(s *Reader) bool { return s.state != BROTLI_STATE_UNINITED || BrotliGetAvailableBits(&s.br) != 0 } -func BrotliDecoderIsFinished(s *BrotliDecoderState) bool { +func BrotliDecoderIsFinished(s *Reader) bool { return (s.state == BROTLI_STATE_DONE) && !BrotliDecoderHasMoreOutput(s) } -func BrotliDecoderGetErrorCode(s *BrotliDecoderState) int { +func BrotliDecoderGetErrorCode(s *Reader) int { return int(s.error_code) } diff --git a/reader.go b/reader.go new file mode 100644 index 0000000..d97cde6 --- /dev/null +++ b/reader.go @@ -0,0 +1,94 @@ +package brotli + +import ( + "errors" + "io" +) + +type decodeError int + +func (err decodeError) Error() string { + return "brotli: " + string(BrotliDecoderErrorString(int(err))) +} + +var errExcessiveInput = errors.New("brotli: excessive input") +var errInvalidState = errors.New("brotli: invalid state") +var errReaderClosed = errors.New("brotli: Reader is closed") + +// readBufSize is a "good" buffer size that avoids excessive round-trips +// between C and Go but doesn't waste too much memory on buffering. +// It is arbitrarily chosen to be equal to the constant used in io.Copy. +const readBufSize = 32 * 1024 + +// NewReader initializes new Reader instance. +func NewReader(src io.Reader) *Reader { + r := new(Reader) + BrotliDecoderStateInit(r) + r.src = src + r.buf = make([]byte, readBufSize) + return r +} + +func (r *Reader) Read(p []byte) (n int, err error) { + if !BrotliDecoderHasMoreOutput(r) && len(r.in) == 0 { + m, readErr := r.src.Read(r.buf) + if m == 0 { + // If readErr is `nil`, we just proxy underlying stream behavior. + return 0, readErr + } + r.in = r.buf[:m] + } + + if len(p) == 0 { + return 0, nil + } + + for { + var written uint + in_len := uint(len(r.in)) + out_len := uint(len(p)) + in_remaining := in_len + out_remaining := out_len + result := BrotliDecoderDecompressStream(r, &in_remaining, &r.in, &out_remaining, &p, nil) + written = out_len - out_remaining + n = int(written) + + switch result { + case BROTLI_DECODER_RESULT_SUCCESS: + if len(r.in) > 0 { + return n, errExcessiveInput + } + return n, nil + case BROTLI_DECODER_RESULT_ERROR: + return n, decodeError(BrotliDecoderGetErrorCode(r)) + case BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT: + if n == 0 { + return 0, io.ErrShortBuffer + } + return n, nil + case BROTLI_DECODER_NEEDS_MORE_INPUT: + } + + if len(r.in) != 0 { + return 0, errInvalidState + } + + // Calling r.src.Read may block. Don't block if we have data to return. + if n > 0 { + return n, nil + } + + // Top off the buffer. + encN, err := r.src.Read(r.buf) + if encN == 0 { + // Not enough data to complete decoding. + if err == io.EOF { + return 0, io.ErrUnexpectedEOF + } + return 0, err + } + r.in = r.buf[:encN] + } + + return n, nil +} diff --git a/state.go b/state.go index 9fbe66e..4b06601 100644 --- a/state.go +++ b/state.go @@ -1,5 +1,7 @@ package brotli +import "io" + /* Copyright 2015 Google Inc. All Rights Reserved. Distributed under MIT license. @@ -89,7 +91,11 @@ const ( BROTLI_STATE_READ_BLOCK_LENGTH_SUFFIX ) -type BrotliDecoderState struct { +type Reader struct { + src io.Reader + buf []byte // scratch space for reading from src + in []byte // current chunk to decode; usually aliases buf + state int loop_counter int br BrotliBitReader @@ -177,7 +183,7 @@ type BrotliDecoderState struct { trivial_literal_contexts [8]uint32 } -func BrotliDecoderStateInit(s *BrotliDecoderState) bool { +func BrotliDecoderStateInit(s *Reader) bool { s.error_code = 0 /* BROTLI_DECODER_NO_ERROR */ BrotliInitBitReader(&s.br) @@ -244,7 +250,7 @@ func BrotliDecoderStateInit(s *BrotliDecoderState) bool { return true } -func BrotliDecoderStateMetablockBegin(s *BrotliDecoderState) { +func BrotliDecoderStateMetablockBegin(s *Reader) { s.meta_block_remaining_len = 0 s.block_length[0] = 1 << 24 s.block_length[1] = 1 << 24 @@ -274,7 +280,7 @@ func BrotliDecoderStateMetablockBegin(s *BrotliDecoderState) { s.distance_hgroup.htrees = nil } -func BrotliDecoderStateCleanupAfterMetablock(s *BrotliDecoderState) { +func BrotliDecoderStateCleanupAfterMetablock(s *Reader) { s.context_modes = nil s.context_map = nil s.dist_context_map = nil @@ -283,14 +289,14 @@ func BrotliDecoderStateCleanupAfterMetablock(s *BrotliDecoderState) { s.distance_hgroup.htrees = nil } -func BrotliDecoderStateCleanup(s *BrotliDecoderState) { +func BrotliDecoderStateCleanup(s *Reader) { BrotliDecoderStateCleanupAfterMetablock(s) s.ringbuffer = nil s.block_type_trees = nil } -func BrotliDecoderHuffmanTreeGroupInit(s *BrotliDecoderState, group *HuffmanTreeGroup, alphabet_size uint32, max_symbol uint32, ntrees uint32) bool { +func BrotliDecoderHuffmanTreeGroupInit(s *Reader, group *HuffmanTreeGroup, alphabet_size uint32, max_symbol uint32, ntrees uint32) bool { var max_table_size uint = uint(kMaxHuffmanTableSize[(alphabet_size+31)>>5]) group.alphabet_size = uint16(alphabet_size) group.max_symbol = uint16(max_symbol)