Skip to content

Commit fe2707c

Browse files
authored
Merge pull request #411 from neal/bugfix/unmarshal-unexpected-null
Fix unmarshal null values for non-pointer fields
2 parents 8580601 + baefa5c commit fe2707c

File tree

2 files changed

+41
-10
lines changed

2 files changed

+41
-10
lines changed

gen/decoder.go

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,22 +63,34 @@ func (g *Generator) genTypeDecoder(t reflect.Type, out string, tags fieldTags, i
6363

6464
unmarshalerIface := reflect.TypeOf((*easyjson.Unmarshaler)(nil)).Elem()
6565
if reflect.PtrTo(t).Implements(unmarshalerIface) {
66-
fmt.Fprintln(g.out, ws+"("+out+").UnmarshalEasyJSON(in)")
66+
fmt.Fprintln(g.out, ws+"if in.IsNull() {")
67+
fmt.Fprintln(g.out, ws+" in.Skip()")
68+
fmt.Fprintln(g.out, ws+"} else {")
69+
fmt.Fprintln(g.out, ws+" ("+out+").UnmarshalEasyJSON(in)")
70+
fmt.Fprintln(g.out, ws+"}")
6771
return nil
6872
}
6973

7074
unmarshalerIface = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()
7175
if reflect.PtrTo(t).Implements(unmarshalerIface) {
72-
fmt.Fprintln(g.out, ws+"if data := in.Raw(); in.Ok() {")
73-
fmt.Fprintln(g.out, ws+" in.AddError( ("+out+").UnmarshalJSON(data) )")
76+
fmt.Fprintln(g.out, ws+"if in.IsNull() {")
77+
fmt.Fprintln(g.out, ws+" in.Skip()")
78+
fmt.Fprintln(g.out, ws+"} else {")
79+
fmt.Fprintln(g.out, ws+" if data := in.Raw(); in.Ok() {")
80+
fmt.Fprintln(g.out, ws+" in.AddError( ("+out+").UnmarshalJSON(data) )")
81+
fmt.Fprintln(g.out, ws+" }")
7482
fmt.Fprintln(g.out, ws+"}")
7583
return nil
7684
}
7785

7886
unmarshalerIface = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
7987
if reflect.PtrTo(t).Implements(unmarshalerIface) {
80-
fmt.Fprintln(g.out, ws+"if data := in.UnsafeBytes(); in.Ok() {")
81-
fmt.Fprintln(g.out, ws+" in.AddError( ("+out+").UnmarshalText(data) )")
88+
fmt.Fprintln(g.out, ws+"if in.IsNull() {")
89+
fmt.Fprintln(g.out, ws+" in.Skip()")
90+
fmt.Fprintln(g.out, ws+"} else {")
91+
fmt.Fprintln(g.out, ws+" if data := in.UnsafeBytes(); in.Ok() {")
92+
fmt.Fprintln(g.out, ws+" in.AddError( ("+out+").UnmarshalText(data) )")
93+
fmt.Fprintln(g.out, ws+" }")
8294
fmt.Fprintln(g.out, ws+"}")
8395
return nil
8496
}
@@ -110,13 +122,21 @@ func (g *Generator) genTypeDecoderNoCheck(t reflect.Type, out string, tags field
110122
ws := strings.Repeat(" ", indent)
111123
// Check whether type is primitive, needs to be done after interface check.
112124
if dec := customDecoders[t.String()]; dec != "" {
113-
fmt.Fprintln(g.out, ws+out+" = "+dec)
125+
fmt.Fprintln(g.out, ws+"if in.IsNull() {")
126+
fmt.Fprintln(g.out, ws+" in.Skip()")
127+
fmt.Fprintln(g.out, ws+"} else {")
128+
fmt.Fprintln(g.out, ws+" "+out+" = "+dec)
129+
fmt.Fprintln(g.out, ws+"}")
114130
return nil
115131
} else if dec := primitiveStringDecoders[t.Kind()]; dec != "" && tags.asString {
116132
if tags.intern && t.Kind() == reflect.String {
117133
dec = "in.StringIntern()"
118134
}
119-
fmt.Fprintln(g.out, ws+out+" = "+g.getType(t)+"("+dec+")")
135+
fmt.Fprintln(g.out, ws+"if in.IsNull() {")
136+
fmt.Fprintln(g.out, ws+" in.Skip()")
137+
fmt.Fprintln(g.out, ws+"} else {")
138+
fmt.Fprintln(g.out, ws+" "+out+" = "+g.getType(t)+"("+dec+")")
139+
fmt.Fprintln(g.out, ws+"}")
120140
return nil
121141
} else if dec := primitiveDecoders[t.Kind()]; dec != "" {
122142
if tags.intern && t.Kind() == reflect.String {
@@ -125,7 +145,11 @@ func (g *Generator) genTypeDecoderNoCheck(t reflect.Type, out string, tags field
125145
if tags.noCopy && t.Kind() == reflect.String {
126146
dec = "in.UnsafeString()"
127147
}
128-
fmt.Fprintln(g.out, ws+out+" = "+g.getType(t)+"("+dec+")")
148+
fmt.Fprintln(g.out, ws+"if in.IsNull() {")
149+
fmt.Fprintln(g.out, ws+" in.Skip()")
150+
fmt.Fprintln(g.out, ws+"} else {")
151+
fmt.Fprintln(g.out, ws+" "+out+" = "+g.getType(t)+"("+dec+")")
152+
fmt.Fprintln(g.out, ws+"}")
129153
return nil
130154
}
131155

tests/basic_test.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -330,14 +330,21 @@ func TestNil(t *testing.T) {
330330
}
331331

332332
func TestUnmarshalNull(t *testing.T) {
333-
p := primitiveTypesValue
333+
p := PrimitiveTypes{
334+
String: str,
335+
Ptr: &str,
336+
}
334337

335-
data := `{"Ptr":null}`
338+
data := `{"String":null,"Ptr":null}`
336339

337340
if err := easyjson.Unmarshal([]byte(data), &p); err != nil {
338341
t.Errorf("easyjson.Unmarshal() error: %v", err)
339342
}
340343

344+
if p.String != str {
345+
t.Errorf("Wanted %q, got %q", str, p.String)
346+
}
347+
341348
if p.Ptr != nil {
342349
t.Errorf("Wanted nil, got %q", *p.Ptr)
343350
}

0 commit comments

Comments
 (0)