Skip to content

Commit e51e1d2

Browse files
thecampagnardsKonstantin Sidorenko
authored andcommitted
fix: pgdialect range
1 parent b25423b commit e51e1d2

File tree

6 files changed

+214
-94
lines changed

6 files changed

+214
-94
lines changed

dialect/pgdialect/array.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -389,10 +389,6 @@ func arrayScanner(typ reflect.Type) schema.ScannerFunc {
389389
}
390390
}
391391

392-
if src == nil {
393-
return nil
394-
}
395-
396392
b, err := toBytes(src)
397393
if err != nil {
398394
return err

dialect/pgdialect/dialect.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -119,12 +119,6 @@ func (d *Dialect) onField(field *schema.Field) {
119119
return
120120
}
121121

122-
if field.Tag.HasOption("multirange") {
123-
field.Append = d.arrayAppender(field.StructField.Type)
124-
field.Scan = arrayScanner(field.StructField.Type)
125-
return
126-
}
127-
128122
switch field.DiscoveredSQLType {
129123
case sqltype.HSTORE:
130124
field.Append = d.hstoreAppender(field.StructField.Type)

dialect/pgdialect/range.go

Lines changed: 100 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,138 +1,156 @@
11
package pgdialect
22

33
import (
4+
"bytes"
45
"database/sql"
6+
"database/sql/driver"
57
"fmt"
68
"io"
7-
"time"
9+
"reflect"
810

9-
"github.com/uptrace/bun/internal"
1011
"github.com/uptrace/bun/schema"
1112
)
1213

1314
type MultiRange[T any] []Range[T]
1415

16+
var (
17+
_ sql.Scanner = (*MultiRange[any])(nil)
18+
_ driver.Valuer = (*MultiRange[any])(nil)
19+
)
20+
21+
func (m *MultiRange[T]) Scan(anySrc any) (err error) {
22+
return Array(m).Scan(anySrc)
23+
}
24+
25+
func (m MultiRange[T]) Value() (driver.Value, error) {
26+
return m.String(), nil
27+
}
28+
29+
func (m MultiRange[T]) String() string {
30+
if len(m) == 0 {
31+
return "{}"
32+
}
33+
var b []byte
34+
b = append(b, '{')
35+
for _, r := range m {
36+
b = append(b, unquote(appendElem(nil, r))...)
37+
b = append(b, ',')
38+
}
39+
b = append(b[:len(b)-1], '}')
40+
return string(b)
41+
}
42+
1543
type Range[T any] struct {
16-
Lower, Upper T
17-
LowerBound, UpperBound RangeBound
44+
Lower, Upper T
45+
LowerBound RangeLowerBound
46+
UpperBound RangeUpperBound
1847
}
1948

20-
type RangeBound byte
49+
var (
50+
_ driver.Valuer = (*Range[any])(nil)
51+
_ sql.Scanner = (*Range[any])(nil)
52+
)
53+
54+
type RangeLowerBound byte
55+
type RangeUpperBound byte
2156

2257
const (
23-
RangeBoundInclusiveLeft RangeBound = '['
24-
RangeBoundInclusiveRight RangeBound = ']'
25-
RangeBoundExclusiveLeft RangeBound = '('
26-
RangeBoundExclusiveRight RangeBound = ')'
58+
RangeBoundExclusiveLeft RangeLowerBound = '('
59+
RangeBoundExclusiveRight RangeUpperBound = ')'
60+
RangeBoundInclusiveLeft RangeLowerBound = '['
61+
RangeBoundInclusiveRight RangeUpperBound = ']'
62+
63+
RangeBoundDefaultLeft = RangeBoundInclusiveLeft
64+
RangeBoundDefaultRight = RangeBoundExclusiveRight
2765
)
2866

2967
func NewRange[T any](lower, upper T) Range[T] {
3068
return Range[T]{
3169
Lower: lower,
70+
LowerBound: RangeBoundDefaultLeft,
3271
Upper: upper,
33-
LowerBound: RangeBoundInclusiveLeft,
34-
UpperBound: RangeBoundExclusiveRight,
72+
UpperBound: RangeBoundDefaultRight,
3573
}
3674
}
3775

38-
var _ sql.Scanner = (*Range[any])(nil)
39-
4076
func (r *Range[T]) Scan(anySrc any) (err error) {
41-
src, ok := anySrc.([]byte)
42-
if !ok {
77+
var src []byte
78+
switch s := anySrc.(type) {
79+
case string:
80+
src = []byte(s)
81+
case []byte:
82+
src = s
83+
default:
4384
return fmt.Errorf("pgdialect: Range can't scan %T", anySrc)
4485
}
4586

87+
src = bytes.TrimSpace(src)
4688
if len(src) == 0 {
4789
return io.ErrUnexpectedEOF
4890
}
49-
r.LowerBound = RangeBound(src[0])
50-
src = src[1:]
5191

52-
src, err = scanElem(&r.Lower, src)
53-
if err != nil {
54-
return err
92+
if string(src) == "empty" {
93+
return nil
5594
}
5695

96+
// read bounds
97+
r.LowerBound = RangeLowerBound(src[0])
98+
r.UpperBound = RangeUpperBound(src[len(src)-1])
99+
src = src[1 : len(src)-1]
57100
if len(src) == 0 {
58101
return io.ErrUnexpectedEOF
59102
}
60-
if ch := src[0]; ch != ',' {
61-
return fmt.Errorf("got %q, wanted %q", ch, ',')
62-
}
63-
src = src[1:]
64103

65-
src, err = scanElem(&r.Upper, src)
66-
if err != nil {
67-
return err
68-
}
69-
70-
if len(src) == 0 {
104+
l, u, ok := bytes.Cut(src, []byte(","))
105+
if !ok {
71106
return io.ErrUnexpectedEOF
72107
}
73-
r.UpperBound = RangeBound(src[0])
74-
src = src[1:]
75108

76-
if len(src) > 0 {
77-
return fmt.Errorf("unread data: %q", src)
109+
scanner := schema.Scanner(reflect.TypeOf(r.Lower))
110+
if err := scanner(reflect.ValueOf(&r.Lower).Elem(), unquote(l)); err != nil {
111+
return err
112+
}
113+
if err := scanner(reflect.ValueOf(&r.Upper).Elem(), unquote(u)); err != nil {
114+
return err
78115
}
79116
return nil
80117
}
81118

82-
var _ schema.QueryAppender = (*Range[any])(nil)
83-
84-
func (r *Range[T]) AppendQuery(fmt schema.Formatter, buf []byte) ([]byte, error) {
85-
buf = append(buf, byte(r.LowerBound))
86-
buf = appendElem(buf, r.Lower)
87-
buf = append(buf, ',')
88-
buf = appendElem(buf, r.Upper)
89-
buf = append(buf, byte(r.UpperBound))
90-
return buf, nil
119+
func (r Range[T]) Value() (driver.Value, error) {
120+
return r.String(), nil
91121
}
92122

93-
func scanElem(ptr any, src []byte) ([]byte, error) {
94-
switch ptr := ptr.(type) {
95-
case *time.Time:
96-
src, str, err := readStringLiteral(src)
97-
if err != nil {
98-
return nil, err
99-
}
100-
101-
tm, err := internal.ParseTime(internal.String(str))
102-
if err != nil {
103-
return nil, err
104-
}
105-
*ptr = tm
106-
107-
return src, nil
108-
109-
case sql.Scanner:
110-
src, str, err := readStringLiteral(src)
111-
if err != nil {
112-
return nil, err
113-
}
114-
if err := ptr.Scan(str); err != nil {
115-
return nil, err
116-
}
117-
return src, nil
118-
119-
default:
120-
panic(fmt.Errorf("unsupported range type: %T", ptr))
123+
func (r Range[T]) String() string {
124+
if r.IsZero() {
125+
return "empty"
126+
}
127+
var rs []byte
128+
if r.LowerBound == 0 {
129+
rs = append(rs, byte(RangeBoundDefaultLeft))
130+
} else {
131+
rs = append(rs, byte(r.LowerBound))
121132
}
133+
rs = appendElem(rs, r.Lower)
134+
rs = append(rs, ',')
135+
rs = appendElem(rs, r.Upper)
136+
if r.UpperBound == 0 {
137+
rs = append(rs, byte(RangeBoundDefaultRight))
138+
} else {
139+
rs = append(rs, byte(r.UpperBound))
140+
}
141+
return string(rs)
122142
}
123143

124-
func readStringLiteral(src []byte) ([]byte, []byte, error) {
125-
p := newParser(src)
144+
func (r Range[T]) IsZero() bool {
145+
return r.LowerBound == 0 && r.UpperBound == 0
146+
}
126147

127-
if err := p.Skip('"'); err != nil {
128-
return nil, nil, err
148+
func unquote(s []byte) []byte {
149+
if len(s) == 0 {
150+
return s
129151
}
130-
131-
str, err := p.ReadSubstring('"')
132-
if err != nil {
133-
return nil, nil, err
152+
if s[0] == '"' && s[len(s)-1] == '"' {
153+
return bytes.ReplaceAll(s[1:len(s)-1], []byte("\\\""), []byte("\""))
134154
}
135-
136-
src = p.Remaining()
137-
return src, str, nil
155+
return s
138156
}

dialect/pgdialect/range_test.go

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
package pgdialect_test
2+
3+
import (
4+
"database/sql/driver"
5+
"testing"
6+
"time"
7+
8+
"github.com/stretchr/testify/assert"
9+
"github.com/uptrace/bun/dialect/pgdialect"
10+
)
11+
12+
func TestRange(t *testing.T) {
13+
t.Run("scan", func(t *testing.T) {
14+
for _, tt := range []struct {
15+
Name string
16+
Value string
17+
Expected any
18+
}{
19+
{
20+
Name: "daterange",
21+
Value: " [1995-11-01,1995-12-01) ",
22+
Expected: &pgdialect.Range[time.Time]{
23+
Lower: time.Date(1995, time.November, 1, 0, 0, 0, 0, time.UTC),
24+
LowerBound: pgdialect.RangeBoundInclusiveLeft,
25+
Upper: time.Date(1995, time.December, 1, 0, 0, 0, 0, time.UTC),
26+
UpperBound: pgdialect.RangeBoundExclusiveRight,
27+
},
28+
},
29+
{
30+
Name: "tstzrange",
31+
Value: `["1995-11-01 10:00:00+00","1995-12-01 10:00:00+00")`,
32+
Expected: &pgdialect.Range[time.Time]{
33+
Lower: time.Date(1995, time.November, 1, 10, 0, 0, 0, time.Local),
34+
LowerBound: pgdialect.RangeBoundInclusiveLeft,
35+
Upper: time.Date(1995, time.December, 1, 10, 0, 0, 0, time.Local),
36+
UpperBound: pgdialect.RangeBoundExclusiveRight,
37+
},
38+
},
39+
{
40+
Name: "empty",
41+
Value: "empty",
42+
Expected: &pgdialect.Range[time.Time]{},
43+
},
44+
} {
45+
t.Run(tt.Name, func(t *testing.T) {
46+
r := &pgdialect.Range[time.Time]{}
47+
assert.NoError(t, r.Scan(tt.Value))
48+
assert.Equal(t, tt.Expected, r)
49+
})
50+
}
51+
})
52+
53+
t.Run("append_query", func(t *testing.T) {
54+
for _, tt := range []struct {
55+
Name string
56+
Value driver.Valuer
57+
Expected string
58+
}{
59+
{
60+
Name: "daterange",
61+
Value: &pgdialect.Range[time.Time]{
62+
Lower: time.Date(1995, time.November, 1, 0, 0, 0, 0, time.Local),
63+
LowerBound: pgdialect.RangeBoundInclusiveLeft,
64+
Upper: time.Date(1995, time.December, 1, 0, 0, 0, 0, time.Local),
65+
UpperBound: pgdialect.RangeBoundExclusiveRight,
66+
},
67+
Expected: `["1995-11-01 00:00:00+00:00","1995-12-01 00:00:00+00:00")`,
68+
},
69+
{
70+
Name: "tstzrange",
71+
Value: &pgdialect.Range[time.Time]{
72+
Lower: time.Date(1995, time.November, 1, 10, 0, 0, 0, time.Local),
73+
LowerBound: pgdialect.RangeBoundInclusiveLeft,
74+
Upper: time.Date(1995, time.December, 1, 10, 0, 0, 0, time.Local),
75+
UpperBound: pgdialect.RangeBoundExclusiveRight,
76+
},
77+
Expected: `["1995-11-01 10:00:00+00:00","1995-12-01 10:00:00+00:00")`,
78+
},
79+
} {
80+
t.Run(tt.Name, func(t *testing.T) {
81+
out, err := tt.Value.Value()
82+
assert.NoError(t, err)
83+
assert.Equal(t, tt.Expected, out)
84+
})
85+
}
86+
})
87+
}

internal/dbtest/pg_test.go

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -873,10 +873,34 @@ func TestPostgresCustomTypeBytes(t *testing.T) {
873873
require.NoError(t, err)
874874
}
875875

876+
func TestPostgresRange(t *testing.T) {
877+
type Model struct {
878+
ID int64 `bun:",pk,autoincrement"`
879+
Value pgdialect.Range[time.Time] `bun:",type:tstzrange"`
880+
}
881+
882+
ctx := context.Background()
883+
884+
db := pg(t)
885+
t.Cleanup(func() { db.Close() })
886+
887+
mustResetModel(t, ctx, db, (*Model)(nil))
888+
889+
in := &Model{Value: pgdialect.NewRange(time.Unix(1000, 0), time.Unix(2000, 0))}
890+
_, err := db.NewInsert().Model(in).Exec(ctx)
891+
require.NoError(t, err)
892+
893+
out := new(Model)
894+
err = db.NewSelect().Model(out).Scan(ctx)
895+
require.NoError(t, err)
896+
897+
require.True(t, reflect.DeepEqual(in, out))
898+
}
899+
876900
func TestPostgresMultiRange(t *testing.T) {
877901
type Model struct {
878902
ID int64 `bun:",pk,autoincrement"`
879-
Value pgdialect.MultiRange[time.Time] `bun:",multirange,type:tstzmultirange"`
903+
Value pgdialect.MultiRange[time.Time] `bun:",type:tstzmultirange"`
880904
}
881905

882906
ctx := context.Background()
@@ -895,6 +919,8 @@ func TestPostgresMultiRange(t *testing.T) {
895919
out := new(Model)
896920
err = db.NewSelect().Model(out).Scan(ctx)
897921
require.NoError(t, err)
922+
923+
require.True(t, reflect.DeepEqual(in, out))
898924
}
899925

900926
type UserID struct {

0 commit comments

Comments
 (0)