Skip to content

Commit 86403cb

Browse files
committed
chore: refactor for multiple type support
1 parent 3acc4f1 commit 86403cb

File tree

4 files changed

+309
-281
lines changed

4 files changed

+309
-281
lines changed

geometry.go

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
package pgxgeos
2+
3+
import (
4+
"context"
5+
"database/sql/driver"
6+
"encoding/hex"
7+
"errors"
8+
9+
"github.com/jackc/pgx/v5"
10+
"github.com/jackc/pgx/v5/pgtype"
11+
"github.com/twpayne/go-geos"
12+
)
13+
14+
// A geometryCodec implements [github.com/jackc/pgx/v5/pgtype.Codec] for
15+
// [*github.com/twpayne/go-geos.Geom] types.
16+
type geometryCodec struct {
17+
geosContext *geos.Context
18+
}
19+
20+
// A geometryBinaryEncodePlan implements
21+
// [github.com/jackc/pgx/v5/pgtype.EncodePlan] for
22+
// [*github.com/twpayne/go-geos.Geom] types in binary format.
23+
type geometryBinaryEncodePlan struct{}
24+
25+
// A geometryTextEncodePlan implements
26+
// [github.com/jackc/pgx/v5/pgtype.EncodePlan] for
27+
// [*github.com/twpayne/go-geos.Geom] types in text format.
28+
type geometryTextEncodePlan struct{}
29+
30+
// A geometryBinaryScanPlan implements [github.com/jackc/pgx/v5/pgtype.ScanPlan]
31+
// for [*github.com/twpayne/go-geos.Geom] types in binary format.
32+
type geometryBinaryScanPlan struct {
33+
geosContext *geos.Context
34+
}
35+
36+
// A geometryTextScanPlan implements [github.com/jackc/pgx/v5/pgtype.ScanPlan]
37+
// for [*github.com/twpayne/go-geos.Geom] types in text format.
38+
type geometryTextScanPlan struct {
39+
geosContext *geos.Context
40+
}
41+
42+
// FormatSupported implements
43+
// [github.com/jackc/pgx/v5/pgtype.Codec.FormatSupported].
44+
func (c *geometryCodec) FormatSupported(format int16) bool {
45+
switch format {
46+
case pgtype.BinaryFormatCode:
47+
return true
48+
case pgtype.TextFormatCode:
49+
return true
50+
default:
51+
return false
52+
}
53+
}
54+
55+
// PreferredFormat implements
56+
// [github.com/jackc/pgx/v5/pgtype.Codec.PreferredFormat].
57+
func (c *geometryCodec) PreferredFormat() int16 {
58+
return pgtype.BinaryFormatCode
59+
}
60+
61+
// PlanEncode implements [github.com/jackc/pgx/v5/pgtype.Codec.PlanEncode].
62+
func (c *geometryCodec) PlanEncode(m *pgtype.Map, old uint32, format int16, value any) pgtype.EncodePlan {
63+
if _, ok := value.(*geos.Geom); !ok {
64+
return nil
65+
}
66+
switch format {
67+
case pgtype.BinaryFormatCode:
68+
return geometryBinaryEncodePlan{}
69+
case pgtype.TextFormatCode:
70+
return geometryTextEncodePlan{}
71+
default:
72+
return nil
73+
}
74+
}
75+
76+
// PlanScan implements [github.com/jackc/pgx/v5/pgtype.Codec.PlanScan].
77+
func (c *geometryCodec) PlanScan(m *pgtype.Map, old uint32, format int16, target any) pgtype.ScanPlan {
78+
if _, ok := target.(**geos.Geom); !ok {
79+
return nil
80+
}
81+
switch format {
82+
case pgx.BinaryFormatCode:
83+
return &geometryBinaryScanPlan{
84+
geosContext: c.geosContext,
85+
}
86+
case pgx.TextFormatCode:
87+
return &geometryTextScanPlan{
88+
geosContext: c.geosContext,
89+
}
90+
default:
91+
return nil
92+
}
93+
}
94+
95+
// DecodeDatabaseSQLValue implements
96+
// [github.com/jackc/pgx/v5/pgtype.Codec.DecodeDatabaseSQLValue].
97+
func (c *geometryCodec) DecodeDatabaseSQLValue(m *pgtype.Map, oid uint32, format int16, src []byte) (driver.Value, error) {
98+
return nil, errors.ErrUnsupported
99+
}
100+
101+
// DecodeValue implements [github.com/jackc/pgx/v5/pgtype.Codec.DecodeValue].
102+
func (c *geometryCodec) DecodeValue(m *pgtype.Map, oid uint32, format int16, src []byte) (any, error) {
103+
switch format {
104+
case pgtype.TextFormatCode:
105+
var err error
106+
src, err = hex.DecodeString(string(src))
107+
if err != nil {
108+
return nil, err
109+
}
110+
fallthrough
111+
case pgtype.BinaryFormatCode:
112+
geom, err := c.geosContext.NewGeomFromWKB(src)
113+
return geom, err
114+
default:
115+
return nil, errors.ErrUnsupported
116+
}
117+
}
118+
119+
// Encode implements [github.com/jackc/pgx/v5/pgtype.EncodePlan.Encode].
120+
func (p geometryBinaryEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) {
121+
geom, ok := value.(*geos.Geom)
122+
if !ok {
123+
return buf, errors.ErrUnsupported
124+
}
125+
return append(buf, geom.ToEWKBWithSRID()...), nil
126+
}
127+
128+
// Encode implements [github.com/jackc/pgx/v5/pgtype.EncodePlan.Encode].
129+
func (p geometryTextEncodePlan) Encode(value any, buf []byte) (newBuf []byte, err error) {
130+
geom, ok := value.(*geos.Geom)
131+
if !ok {
132+
return buf, errors.ErrUnsupported
133+
}
134+
wkb := geom.ToEWKBWithSRID()
135+
return append(buf, []byte(hex.EncodeToString(wkb))...), nil
136+
}
137+
138+
// Scan implements [github.com/jackc/pgx/v5/pgtype.ScanPlan.Scan].
139+
func (p *geometryBinaryScanPlan) Scan(src []byte, target any) error {
140+
pgeom, ok := target.(**geos.Geom)
141+
if !ok {
142+
return errors.ErrUnsupported
143+
}
144+
if len(src) == 0 {
145+
*pgeom = nil
146+
return nil
147+
}
148+
geom, err := p.geosContext.NewGeomFromWKB(src)
149+
if err != nil {
150+
return err
151+
}
152+
(*pgeom).Destroy()
153+
*pgeom = geom
154+
return nil
155+
}
156+
157+
// Scan implements [github.com/jackc/pgx/v5/pgtype.ScanPlan.Scan].
158+
func (p *geometryTextScanPlan) Scan(src []byte, target any) error {
159+
pgeom, ok := target.(**geos.Geom)
160+
if !ok {
161+
return errors.ErrUnsupported
162+
}
163+
if len(src) == 0 {
164+
*pgeom = nil
165+
return nil
166+
}
167+
var err error
168+
src, err = hex.DecodeString(string(src))
169+
if err != nil {
170+
return err
171+
}
172+
geom, err := p.geosContext.NewGeomFromWKB(src)
173+
if err != nil {
174+
return err
175+
}
176+
(*pgeom).Destroy()
177+
*pgeom = geom
178+
return nil
179+
}
180+
181+
// registerGeom registers codecs for [*github.com/twpayne/go-geos.Geom] types on conn.
182+
func registerGeom(ctx context.Context, conn *pgx.Conn, geosContext *geos.Context) error {
183+
var geometryOID uint32
184+
err := conn.QueryRow(ctx, "select 'geometry'::text::regtype::oid").Scan(&geometryOID)
185+
if err != nil {
186+
return err
187+
}
188+
189+
if geosContext == nil {
190+
geosContext = geos.DefaultContext
191+
}
192+
193+
conn.TypeMap().RegisterType(&pgtype.Type{
194+
Codec: &geometryCodec{
195+
geosContext: geosContext,
196+
},
197+
Name: "geometry",
198+
OID: geometryOID,
199+
})
200+
201+
return nil
202+
}

geometry_test.go

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
package pgxgeos_test
2+
3+
import (
4+
"context"
5+
"strconv"
6+
"testing"
7+
8+
"github.com/alecthomas/assert/v2"
9+
"github.com/jackc/pgx/v5"
10+
"github.com/twpayne/go-geos"
11+
)
12+
13+
func TestCodecDecodeGeometryValue(t *testing.T) {
14+
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, tb testing.TB, conn *pgx.Conn) {
15+
tb.Helper()
16+
for _, format := range []int16{
17+
pgx.BinaryFormatCode,
18+
pgx.TextFormatCode,
19+
} {
20+
tb.(*testing.T).Run(strconv.Itoa(int(format)), func(t *testing.T) {
21+
original := mustNewGeomFromWKT(t, "POINT(1 2)").SetSRID(4326)
22+
rows, err := conn.Query(ctx, "select $1::geometry", pgx.QueryResultFormats{format}, original)
23+
assert.NoError(t, err)
24+
25+
for rows.Next() {
26+
values, err := rows.Values()
27+
assert.NoError(t, err)
28+
29+
assert.Equal(t, 1, len(values))
30+
v0, ok := values[0].(*geos.Geom)
31+
assert.True(t, ok)
32+
assert.True(t, original.Equals(v0))
33+
}
34+
35+
assert.NoError(t, rows.Err())
36+
})
37+
}
38+
})
39+
}
40+
41+
func TestCodecDecodeGeometryNullValue(t *testing.T) {
42+
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, tb testing.TB, conn *pgx.Conn) {
43+
tb.Helper()
44+
45+
type s struct {
46+
Geom *geos.Geom `db:"geom"`
47+
}
48+
49+
for _, format := range []int16{
50+
pgx.BinaryFormatCode,
51+
pgx.TextFormatCode,
52+
} {
53+
tb.(*testing.T).Run(strconv.Itoa(int(format)), func(t *testing.T) {
54+
tb.Helper()
55+
56+
rows, err := conn.Query(ctx, "select NULL::geometry AS geom", pgx.QueryResultFormats{format})
57+
assert.NoError(tb, err)
58+
59+
value, err := pgx.CollectExactlyOneRow(rows, pgx.RowToStructByName[s])
60+
assert.NoError(t, err)
61+
assert.Zero(t, value)
62+
})
63+
}
64+
})
65+
}
66+
67+
func TestCodecDecodeGeometryNull(t *testing.T) {
68+
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, tb testing.TB, conn *pgx.Conn) {
69+
tb.Helper()
70+
rows, err := conn.Query(ctx, "select $1::geometry", nil)
71+
assert.NoError(tb, err)
72+
73+
for rows.Next() {
74+
values, err := rows.Values()
75+
assert.NoError(tb, err)
76+
assert.Equal(tb, []any{nil}, values)
77+
}
78+
79+
assert.NoError(tb, rows.Err())
80+
})
81+
}
82+
83+
func TestCodecGeometryScanValue(t *testing.T) {
84+
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, tb testing.TB, conn *pgx.Conn) {
85+
tb.Helper()
86+
for _, format := range []int16{
87+
pgx.BinaryFormatCode,
88+
pgx.TextFormatCode,
89+
} {
90+
tb.(*testing.T).Run(strconv.Itoa(int(format)), func(t *testing.T) {
91+
var geom *geos.Geom
92+
err := conn.QueryRow(ctx, "select ST_SetSRID('POINT(1 2)'::geometry, 4326)", pgx.QueryResultFormats{format}).Scan(&geom)
93+
assert.NoError(t, err)
94+
assert.Equal(t, mustNewGeomFromWKT(t, "POINT(1 2)").SetSRID(4326).ToEWKBWithSRID(), geom.ToEWKBWithSRID())
95+
})
96+
}
97+
})
98+
}
99+
100+
func mustNewGeomFromWKT(tb testing.TB, wkt string) *geos.Geom {
101+
tb.Helper()
102+
geom, err := geos.NewGeomFromWKT(wkt)
103+
assert.NoError(tb, err)
104+
return geom
105+
}

0 commit comments

Comments
 (0)