-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathparser.go
89 lines (74 loc) · 2 KB
/
parser.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
package parplan
import (
"sort"
pc_parser "github.com/pingcap/parser"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/charset"
"github.com/squareup/pranadb/errors"
driver "github.com/squareup/pranadb/tidb/types/parser_driver"
)
// Parser turns SQL text into a TiDB (pingcap) AST. Obtain one via NewParser.
type Parser struct {
	parser *pc_parser.Parser // the wrapped pingcap parser instance
}

// NewParser builds a Parser around a fresh pingcap parser. It calls
// SetStrictDoubleTypeCheck(false) and EnableWindowFunc(false) on that parser
// before returning it.
func NewParser() *Parser {
	tidbParser := pc_parser.New()
	tidbParser.SetStrictDoubleTypeCheck(false)
	tidbParser.EnableWindowFunc(false)
	return &Parser{parser: tidbParser}
}
// Parse parses sql, which must contain exactly one statement. It returns a
// handle to that statement's AST and the number of parameter markers
// (*driver.ParamMarkerExpr nodes) it contains.
//
// The AST walk may visit parameter markers in a different order from the one
// in which they appear in sql. The markers are therefore collected first and
// sorted by Offset. They are then numbered 0..paramCount-1 with SetOrder, in
// the order they appear in the source.
//
// Warnings reported by the underlying parser are printed to stderr (via the
// builtin println) and do not cause failure. An error is returned if parsing
// fails or if sql does not contain exactly one statement.
func (p *Parser) Parse(sql string) (stmt AstHandle, paramCount int, err error) {
	stmtNodes, warns, err := p.parser.Parse(sql, charset.CharsetUTF8, "")
	if err != nil {
		return AstHandle{}, 0, errors.WithStack(err)
	}
	// Ranging over a nil slice is a no-op, so no nil guard is needed.
	// Error() must be called explicitly. Passing an error interface straight
	// to the builtin println prints its (type, data) pointer pair, not the
	// message.
	for _, warn := range warns {
		println("sql parse warning:", warn.Error())
	}
	if len(stmtNodes) != 1 {
		return AstHandle{}, 0, errors.Errorf("expected 1 statement got %d", len(stmtNodes))
	}
	stmtNode := stmtNodes[0]

	// Collect the markers in visit order, then restore source order before
	// numbering them.
	vis := &pmVisitor{}
	stmtNode.Accept(vis)
	pms := vis.pms
	sort.Sort(&pmSorter{pms: pms})
	for i, pme := range pms {
		pme.SetOrder(i)
	}
	return AstHandle{stmt: stmtNode}, len(pms), nil
}
// AstHandle wraps the underlying TiDB AST so that TiDB types do not leak too
// far into the rest of the code. Values are produced by Parser.Parse.
type AstHandle struct {
	stmt ast.StmtNode // the single parsed statement
}
// pmVisitor is passed to ast.Node.Accept. It records every
// *driver.ParamMarkerExpr node in the order the walk leaves them.
type pmVisitor struct {
	pms []ast.ParamMarkerExpr
}

// Enter never asks to skip children, so the entire tree is walked.
func (v *pmVisitor) Enter(in ast.Node) (ast.Node, bool) {
	const skipChildren = false
	return in, skipChildren
}

// Leave appends the node to v.pms when it is a parameter marker. It returns
// the node unchanged and always reports true so the walk continues.
func (v *pmVisitor) Leave(in ast.Node) (ast.Node, bool) {
	if marker, isMarker := in.(*driver.ParamMarkerExpr); isMarker {
		v.pms = append(v.pms, marker)
	}
	const keepWalking = true
	return in, keepWalking
}
// pmSorter implements sort.Interface and orders parameter markers by ascending
// Offset. Every element must be a *driver.ParamMarkerExpr; Less panics on any
// other concrete type.
type pmSorter struct {
	pms []ast.ParamMarkerExpr
}

// Len reports how many markers are being sorted.
func (s *pmSorter) Len() int { return len(s.pms) }

// Less reports whether marker i has a smaller Offset than marker j.
func (s *pmSorter) Less(i, j int) bool {
	left := s.pms[i].(*driver.ParamMarkerExpr)
	right := s.pms[j].(*driver.ParamMarkerExpr)
	return left.Offset < right.Offset
}

// Swap exchanges the markers at positions i and j.
func (s *pmSorter) Swap(i, j int) {
	s.pms[j], s.pms[i] = s.pms[i], s.pms[j]
}