Skip to content

Commit

Permalink
add: Support for other bindings fo embedded structs
Browse files Browse the repository at this point in the history
  • Loading branch information
michal-laskowski committed Nov 13, 2024
1 parent ee247ed commit cc0971c
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 105 deletions.
19 changes: 11 additions & 8 deletions bind_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1098,17 +1098,18 @@ func Test_Bind_Body_Form_Embedded(b *testing.T) {
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{})
type EmbeddedDemo struct {
EmbeddedStrings []string `form:"embedded_strings"`
EmbeddedString string `form:"embedded_string"`
EmbeddedStrings []string `form:"embedded_strings"`
}

type Demo struct {
SomeString string `form:"some_string"`
SomeOtherString string `form:"some_other_string"`
Strings []string `form:"strings"`
String string `form:"some_string"`
OtherString string `form:"some_other_string"`
Strings []string `form:"strings"`
OtherStrings []string `form:"other_strings"`
EmbeddedDemo
}
body := []byte("SomeString=john%2Clong&SomeOtherString=long%2Cjohn&Strings=long%2Cjohn&EmbededStrings=john%2Clong&EmbededString=johny%2Cwalker")
body := []byte("some_string=john%2Clong&some_other_string=long&some_other_string=long&strings=long%2Cjohn&embedded_strings=john%2Clongest&embedded_string=johny%2Cwalker&other_strings=long&other_strings=johny")
c.Request().SetBody(body)
c.Request().Header.SetContentType(MIMEApplicationForm)
c.Request().Header.SetContentLength(len(body))
Expand All @@ -1117,11 +1118,13 @@ func Test_Bind_Body_Form_Embedded(b *testing.T) {
err = c.Bind().Body(d)

require.NoError(b, err)
require.Equal(b, "john,long", d.String)
require.Equal(b, []string{"long", "john"}, d.Strings)
require.Equal(b, []string{"john", "long"}, d.EmbeddedStrings)
//! only one value is taken
require.Equal(b, "long", d.OtherString)
require.Equal(b, []string{"long", "johny"}, d.OtherStrings)
require.Equal(b, "johny,walker", d.EmbeddedString)
require.Equal(b, "john,long", d.SomeString)
require.Equal(b, "long,john", d.SomeOtherString)
require.Equal(b, []string{"john", "longest"}, d.EmbeddedStrings)
}

// go test -v -run=^$ -bench=Benchmark_Bind_Body_Form -benchmem -count=4
Expand Down
12 changes: 1 addition & 11 deletions binder/cookie.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package binder

import (
"reflect"
"strings"

"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp"
)
Expand All @@ -26,14 +23,7 @@ func (b *cookieBinding) Bind(reqCtx *fasthttp.RequestCtx, out any) error {
k := utils.UnsafeString(key)
v := utils.UnsafeString(val)

if strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
values := strings.Split(v, ",")
for i := 0; i < len(values); i++ {
data[k] = append(data[k], values[i])
}
} else {
data[k] = append(data[k], v)
}
appendValue(data, v, out, k, b.Name())
})

if err != nil {
Expand Down
10 changes: 1 addition & 9 deletions binder/form.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package binder

import (
"reflect"
"strings"

"github.com/gofiber/utils/v2"
Expand Down Expand Up @@ -30,14 +29,7 @@ func (b *formBinding) Bind(reqCtx *fasthttp.RequestCtx, out any) error {
k, err = parseParamSquareBrackets(k)
}

if strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
values := strings.Split(v, ",")
for i := 0; i < len(values); i++ {
data[k] = append(data[k], values[i])
}
} else {
data[k] = append(data[k], v)
}
appendValue(data, v, out, k, b.Name())
})

if err != nil {
Expand Down
12 changes: 1 addition & 11 deletions binder/header.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package binder

import (
"reflect"
"strings"

"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp"
)
Expand All @@ -20,14 +17,7 @@ func (b *headerBinding) Bind(req *fasthttp.Request, out any) error {
k := utils.UnsafeString(key)
v := utils.UnsafeString(val)

if strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
values := strings.Split(v, ",")
for i := 0; i < len(values); i++ {
data[k] = append(data[k], values[i])
}
} else {
data[k] = append(data[k], v)
}
appendValue(data, v, out, k, b.Name())
})

return parse(b.Name(), out, data)
Expand Down
86 changes: 49 additions & 37 deletions binder/mapping.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,18 @@ func parseParamSquareBrackets(k string) (string, error) {
return bb.String(), nil
}

func equalFieldType(out any, kind reflect.Kind, key string) bool {
func appendValue(to map[string][]string, rawValue string, out any, k string, bindingName string) {
if strings.Contains(rawValue, ",") && equalFieldType(out, reflect.Slice, k, bindingName) {
values := strings.Split(rawValue, ",")
for i := 0; i < len(values); i++ {
to[k] = append(to[k], values[i])
}
} else {
to[k] = append(to[k], rawValue)
}
}

func equalFieldType(out any, kind reflect.Kind, key string, bindingName string) bool {
// Get type of interface
outTyp := reflect.TypeOf(out).Elem()
key = utils.ToLower(key)
Expand All @@ -196,53 +207,54 @@ func equalFieldType(out any, kind reflect.Kind, key string) bool {
if !structField.CanSet() {
continue
}

// Get field key data
typeField := outTyp.Field(i)
// Get type of field key
structFieldKind := structField.Kind()
// Does the field type equals input?
if structFieldKind != kind {
// Is the field an embedded struct?
if structFieldKind == reflect.Struct && typeField.Anonymous {
// Loop over embedded struct fields
for j := 0; j < structField.NumField(); j++ {
fNm := utils.ToLower(structField.Type().Field(j).Name)
if fNm != key {
//this is not the field that we are looking for
continue
}

structFieldField := structField.Field(j)

// Can this embedded field be changed?
if !structFieldField.CanSet() {
continue
}

// Is the embedded struct field type equal to the input?
if structFieldField.Kind() == kind {
return true
}
}
}

continue
}
// Get tag from field if exist
inputFieldName := typeField.Tag.Get(QueryBinder.Name())
if inputFieldName == "" {
inputFieldName = typeField.Name
} else {
inputFieldName = strings.Split(inputFieldName, ",")[0]
}
// Compare field/tag with provided key
if utils.ToLower(inputFieldName) == key {
return true
if getFieldKey(typeField, bindingName) == key {
return structFieldKind == kind
}

// Is the field an embedded struct?
if typeField.Anonymous {
// Loop over embedded struct fields
for j := 0; j < structField.NumField(); j++ {
if getFieldKey(structField.Type().Field(j), bindingName) != key {
// this is not the field that we are looking for
continue
}

structFieldField := structField.Field(j)

// Can this embedded field be changed?
if !structFieldField.CanSet() {
continue
}

// Is the embedded struct field type equal to the input?
return structFieldField.Kind() == kind
}
}
}
return false
}

// Get binding key for a field
func getFieldKey(typeField reflect.StructField, bindingName string) string {
// Get tag from field if exist
inputFieldName := typeField.Tag.Get(bindingName)
if inputFieldName == "" {
inputFieldName = typeField.Name
} else {
inputFieldName = strings.Split(inputFieldName, ",")[0]
}
// Compare field key
return utils.ToLower(inputFieldName)
}

// Get content type from content type header
func FilterFlags(content string) string {
for i, char := range content {
Expand Down
18 changes: 9 additions & 9 deletions binder/mapping_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,25 @@ import (

func Test_EqualFieldType(t *testing.T) {
var out int
require.False(t, equalFieldType(&out, reflect.Int, "key"))
require.False(t, equalFieldType(&out, reflect.Int, "key", "query"))

var dummy struct{ f string }
require.False(t, equalFieldType(&dummy, reflect.String, "key"))
require.False(t, equalFieldType(&dummy, reflect.String, "key", "query"))

var dummy2 struct{ f string }
require.False(t, equalFieldType(&dummy2, reflect.String, "f"))
require.False(t, equalFieldType(&dummy2, reflect.String, "f", "query"))

var user struct {
Name string
Address string `query:"address"`
Age int `query:"AGE"`
}
require.True(t, equalFieldType(&user, reflect.String, "name"))
require.True(t, equalFieldType(&user, reflect.String, "Name"))
require.True(t, equalFieldType(&user, reflect.String, "address"))
require.True(t, equalFieldType(&user, reflect.String, "Address"))
require.True(t, equalFieldType(&user, reflect.Int, "AGE"))
require.True(t, equalFieldType(&user, reflect.Int, "age"))
require.True(t, equalFieldType(&user, reflect.String, "name", "query"))
require.True(t, equalFieldType(&user, reflect.String, "Name", "query"))
require.True(t, equalFieldType(&user, reflect.String, "address", "query"))
require.True(t, equalFieldType(&user, reflect.String, "Address", "query"))
require.True(t, equalFieldType(&user, reflect.Int, "AGE", "query"))
require.True(t, equalFieldType(&user, reflect.Int, "age", "query"))
}

func Test_ParseParamSquareBrackets(t *testing.T) {
Expand Down
10 changes: 1 addition & 9 deletions binder/query.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package binder

import (
"reflect"
"strings"

"github.com/gofiber/utils/v2"
Expand Down Expand Up @@ -30,14 +29,7 @@ func (b *queryBinding) Bind(reqCtx *fasthttp.RequestCtx, out any) error {
k, err = parseParamSquareBrackets(k)
}

if strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
values := strings.Split(v, ",")
for i := 0; i < len(values); i++ {
data[k] = append(data[k], values[i])
}
} else {
data[k] = append(data[k], v)
}
appendValue(data, v, out, k, b.Name())
})

if err != nil {
Expand Down
12 changes: 1 addition & 11 deletions binder/resp_header.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package binder

import (
"reflect"
"strings"

"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp"
)
Expand All @@ -20,14 +17,7 @@ func (b *respHeaderBinding) Bind(resp *fasthttp.Response, out any) error {
k := utils.UnsafeString(key)
v := utils.UnsafeString(val)

if strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
values := strings.Split(v, ",")
for i := 0; i < len(values); i++ {
data[k] = append(data[k], values[i])
}
} else {
data[k] = append(data[k], v)
}
appendValue(data, v, out, k, b.Name())
})

return parse(b.Name(), out, data)
Expand Down

0 comments on commit cc0971c

Please sign in to comment.