Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Decompress request body when multi Content-Encoding sent on request headers #2555

Merged
merged 11 commits into from
Aug 6, 2023
Merged
91 changes: 72 additions & 19 deletions ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,31 +260,84 @@ func (c *Ctx) BaseURL() string {
return c.baseURI
}

// Body contains the raw body submitted in a POST request.
// BodyRaw contains the raw body submitted in a POST request.
// Returned value is only valid within the handler. Do not store any references.
// Make copies or use the Immutable setting instead.
func (c *Ctx) Body() []byte {
var err error
var encoding string
var body []byte
// faster than peek
c.Request().Header.VisitAll(func(key, value []byte) {
if c.app.getString(key) == HeaderContentEncoding {
encoding = c.app.getString(value)
func (c *Ctx) BodyRaw() []byte {
return c.fasthttp.Request.Body()
}

func (c *Ctx) tryDecodeBodyInOrder(
originalBody *[]byte,
encodings []string,
) ([]byte, uint8, error) {
var (
err error
body []byte
decodesRealized uint8
)

for index, encoding := range encodings {
decodesRealized++
switch encoding {
case StrGzip:
body, err = c.fasthttp.Request.BodyGunzip()
case StrBr, StrBrotli:
body, err = c.fasthttp.Request.BodyUnbrotli()
case StrDeflate:
body, err = c.fasthttp.Request.BodyInflate()
default:
decodesRealized--
if len(encodings) == 1 {
body = c.fasthttp.Request.Body()
}
return body, decodesRealized, nil
}
})

switch encoding {
case StrGzip:
body, err = c.fasthttp.Request.BodyGunzip()
case StrBr, StrBrotli:
body, err = c.fasthttp.Request.BodyUnbrotli()
case StrDeflate:
body, err = c.fasthttp.Request.BodyInflate()
default:
body = c.fasthttp.Request.Body()
if err != nil {
return nil, decodesRealized, err
}

// Only execute body raw update if it has a next iteration to try to decode
if index < len(encodings)-1 && decodesRealized > 0 {
if index == 0 {
tempBody := c.fasthttp.Request.Body()
*originalBody = make([]byte, len(tempBody))
copy(*originalBody, tempBody)
}
c.fasthttp.Request.SetBodyRaw(body)
}
}

return body, decodesRealized, nil
}

// Body contains the raw body submitted in a POST request.
// This method will decompress the body if the 'Content-Encoding' header is provided.
// It returns the original (or decompressed) body data which is valid only within the handler.
// Don't store direct references to the returned data.
// If you need to keep the body's data later, make a copy or use the Immutable option.
func (c *Ctx) Body() []byte {
var (
err error
body, originalBody []byte
encodingOrder = []string{"", "", ""}
)

// Split and get the encodings list, in order to attend the
// rule defined at: https://www.rfc-editor.org/rfc/rfc9110#section-8.4-5
encodingOrder = getSplicedStrList(c.Get(HeaderContentEncoding), encodingOrder)
if len(encodingOrder) == 0 {
return c.fasthttp.Request.Body()
}

var decodesRealized uint8
body, decodesRealized, err = c.tryDecodeBodyInOrder(&originalBody, encodingOrder)

// Ensure that the body will be the original
if originalBody != nil && decodesRealized > 0 {
c.fasthttp.Request.SetBodyRaw(originalBody)
}
if err != nil {
return []byte(err.Error())
}
Expand Down
225 changes: 195 additions & 30 deletions ctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"bufio"
"bytes"
"compress/gzip"
"compress/zlib"
"context"
"crypto/tls"
"encoding/xml"
Expand Down Expand Up @@ -323,47 +324,211 @@ func Test_Ctx_Body(t *testing.T) {
utils.AssertEqual(t, []byte("john=doe"), c.Body())
}

// go test -run Test_Ctx_Body_With_Compression
func Test_Ctx_Body_With_Compression(t *testing.T) {
t.Parallel()
func Benchmark_Ctx_Body(b *testing.B) {
const input = "john=doe"

app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.Set("Content-Encoding", "gzip")
var b bytes.Buffer
gz := gzip.NewWriter(&b)
_, err := gz.Write([]byte("john=doe"))
utils.AssertEqual(t, nil, err)
err = gz.Flush()
utils.AssertEqual(t, nil, err)
err = gz.Close()
utils.AssertEqual(t, nil, err)
c.Request().SetBody(b.Bytes())
utils.AssertEqual(t, []byte("john=doe"), c.Body())

c.Request().SetBody([]byte(input))
for i := 0; i < b.N; i++ {
_ = c.Body()
}

utils.AssertEqual(b, []byte(input), c.Body())
}

// go test -run Test_Ctx_Body_With_Compression
func Test_Ctx_Body_With_Compression(t *testing.T) {
t.Parallel()
tests := []struct {
name string
contentEncoding string
body []byte
expectedBody []byte
}{
{
name: "gzip",
contentEncoding: "gzip",
body: []byte("john=doe"),
expectedBody: []byte("john=doe"),
},
{
name: "unsupported_encoding",
contentEncoding: "undefined",
body: []byte("keeps_ORIGINAL"),
expectedBody: []byte("keeps_ORIGINAL"),
},
{
name: "gzip then unsupported",
contentEncoding: "gzip, undefined",
body: []byte("Go, be gzipped"),
expectedBody: []byte("Go, be gzipped"),
},
{
name: "invalid_deflate",
contentEncoding: "gzip,deflate",
body: []byte("I'm not correctly compressed"),
expectedBody: []byte(zlib.ErrHeader.Error()),
},
}

for _, testObject := range tests {
tCase := testObject // Duplicate object to ensure it will be unique across all runs
t.Run(tCase.name, func(t *testing.T) {
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.Set("Content-Encoding", tCase.contentEncoding)

if strings.Contains(tCase.contentEncoding, "gzip") {
var b bytes.Buffer
gz := gzip.NewWriter(&b)
_, err := gz.Write(tCase.body)
if err != nil {
t.Fatal(err)
}
if err = gz.Flush(); err != nil {
t.Fatal(err)
}
if err = gz.Close(); err != nil {
t.Fatal(err)
}
tCase.body = b.Bytes()
}

c.Request().SetBody(tCase.body)
body := c.Body()
utils.AssertEqual(t, tCase.expectedBody, body)

// Check if body raw is the same as previous before decompression
utils.AssertEqual(
t, tCase.body, c.Request().Body(),
"Body raw must be the same as set before",
)
})
}
}

// go test -v -run=^$ -bench=Benchmark_Ctx_Body_With_Compression -benchmem -count=4
func Benchmark_Ctx_Body_With_Compression(b *testing.B) {
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.Set("Content-Encoding", "gzip")
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)
_, err := gz.Write([]byte("john=doe"))
utils.AssertEqual(b, nil, err)
err = gz.Flush()
utils.AssertEqual(b, nil, err)
err = gz.Close()
utils.AssertEqual(b, nil, err)
encodingErr := errors.New("failed to encoding data")

var (
compressGzip = func(data []byte) ([]byte, error) {
var buf bytes.Buffer
writer := gzip.NewWriter(&buf)
if _, err := writer.Write(data); err != nil {
return nil, encodingErr
}
if err := writer.Flush(); err != nil {
return nil, encodingErr
}
if err := writer.Close(); err != nil {
return nil, encodingErr
}
return buf.Bytes(), nil
}
compressDeflate = func(data []byte) ([]byte, error) {
var buf bytes.Buffer
writer := zlib.NewWriter(&buf)
if _, err := writer.Write(data); err != nil {
return nil, encodingErr
}
if err := writer.Flush(); err != nil {
return nil, encodingErr
}
if err := writer.Close(); err != nil {
return nil, encodingErr
}
return buf.Bytes(), nil
}
)
compressionTests := []struct {
contentEncoding string
compressWriter func([]byte) ([]byte, error)
}{
{
contentEncoding: "gzip",
compressWriter: compressGzip,
},
{
contentEncoding: "gzip,invalid",
compressWriter: compressGzip,
},
{
contentEncoding: "deflate",
compressWriter: compressDeflate,
},
{
contentEncoding: "gzip,deflate",
compressWriter: func(data []byte) ([]byte, error) {
var (
buf bytes.Buffer
writer interface {
io.WriteCloser
Flush() error
}
err error
)

// deflate
{
writer = zlib.NewWriter(&buf)
if _, err = writer.Write(data); err != nil {
return nil, encodingErr
}
if err = writer.Flush(); err != nil {
return nil, encodingErr
}
if err = writer.Close(); err != nil {
return nil, encodingErr
}
}

c.Request().SetBody(buf.Bytes())
data = make([]byte, buf.Len())
copy(data, buf.Bytes())
buf.Reset()

// gzip
{
writer = gzip.NewWriter(&buf)
if _, err = writer.Write(data); err != nil {
return nil, encodingErr
}
if err = writer.Flush(); err != nil {
return nil, encodingErr
}
if err = writer.Close(); err != nil {
return nil, encodingErr
}
}

for i := 0; i < b.N; i++ {
_ = c.Body()
return buf.Bytes(), nil
},
},
}

utils.AssertEqual(b, []byte("john=doe"), c.Body())
for _, ct := range compressionTests {
b.Run(ct.contentEncoding, func(b *testing.B) {
app := New()
const input = "john=doe"
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)

c.Request().Header.Set("Content-Encoding", ct.contentEncoding)
compressedBody, err := ct.compressWriter([]byte(input))
utils.AssertEqual(b, nil, err)

c.Request().SetBody(compressedBody)
for i := 0; i < b.N; i++ {
_ = c.Body()
}

utils.AssertEqual(b, []byte(input), c.Body())
})
}
}

// go test -run Test_Ctx_BodyParser
Expand Down
35 changes: 35 additions & 0 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,41 @@ func acceptsOfferType(spec, offerType string) bool {
return false
}

// getSplicedStrList function takes a string and a string slice as an argument, divides the string into different
// elements divided by ',' and stores these elements in the string slice.
// It returns the populated string slice as an output.
//
// If the given slice hasn't enough space, it will allocate more and return.
func getSplicedStrList(headerValue string, dst []string) []string {
if headerValue == "" {
return nil
}

var (
index int
character rune
lastElementEndsAt uint8
insertIndex int
)
for index, character = range headerValue + "$" {
if character == ',' || index == len(headerValue) {
if insertIndex >= len(dst) {
oldSlice := dst
dst = make([]string, len(dst)+(len(dst)>>1)+2)
copy(dst, oldSlice)
}
dst[insertIndex] = utils.TrimLeft(headerValue[lastElementEndsAt:index], ' ')
lastElementEndsAt = uint8(index + 1)
insertIndex++
}
}

if len(dst) > insertIndex {
dst = dst[:insertIndex]
}
return dst
}

// getOffer return valid offer for header negotiation
func getOffer(header string, isAccepted func(spec, offer string) bool, offers ...string) string {
if len(offers) == 0 {
Expand Down
Loading