@@ -18,10 +18,6 @@ import (
18
18
"go.uber.org/zap"
19
19
)
20
20
21
- var (
22
- ErrCapabilityNegotiation = errors .New ("capability negotiation failed" )
23
- )
24
-
25
21
const unknownAuthPlugin = "auth_unknown_plugin"
26
22
const requiredFrontendCaps = pnet .ClientProtocol41
27
23
const defRequiredBackendCaps = pnet .ClientDeprecateEOF
@@ -76,10 +72,10 @@ func (auth *Authenticator) verifyBackendCaps(logger *zap.Logger, backendCapabili
76
72
// The error cannot be sent to the client because the client only expects an initial handshake packet.
77
73
// The only way is to log it and disconnect.
78
74
logger .Error ("require backend capabilities" , zap .Stringer ("common" , commonCaps ), zap .Stringer ("required" , requiredBackendCaps ^ commonCaps ))
79
- return errors .Wrapf (ErrCapabilityNegotiation , "require %s from backend" , requiredBackendCaps ^ commonCaps )
75
+ return errors .Wrapf (ErrBackendCap , "require %s from backend" , requiredBackendCaps ^ commonCaps )
80
76
}
81
77
if auth .requireBackendTLS && (backendCapability & pnet .ClientSSL == 0 ) {
82
- return pnet . WrapUserError ( errors . New ( "backend doesn't enable TLS" ), requireTiDBTLSErrMsg )
78
+ return ErrBackendNoTLS
83
79
}
84
80
return nil
85
81
}
@@ -106,7 +102,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
106
102
frontendCapability := pnet .Capability (binary .LittleEndian .Uint32 (pkt ))
107
103
if isSSL {
108
104
if _ , err = clientIO .ServerTLSHandshake (frontendTLSConfig ); err != nil {
109
- return pnet . WrapUserError ( err , err . Error () )
105
+ return errors . Wrap ( ErrClientHandshake , err )
110
106
}
111
107
pkt , _ , err = clientIO .ReadSSLRequestOrHandshakeResp ()
112
108
if err != nil {
@@ -125,7 +121,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
125
121
if writeErr := clientIO .WriteErrPacket (mysql .NewDefaultError (mysql .ER_NOT_SUPPORTED_AUTH_MODE )); writeErr != nil {
126
122
return writeErr
127
123
}
128
- return errors .Wrapf (ErrCapabilityNegotiation , "require %s from frontend" , requiredFrontendCaps &^commonCaps )
124
+ return errors .Wrapf (ErrClientCap , "require %s from frontend" , requiredFrontendCaps &^commonCaps )
129
125
}
130
126
commonCaps := frontendCapability & proxyCapability
131
127
if frontendCapability ^ commonCaps != 0 {
@@ -147,10 +143,10 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
147
143
if errors .As (err , & warning ) {
148
144
logger .Warn ("parse handshake response encounters error" , zap .Error (err ))
149
145
} else if err != nil {
150
- return pnet . WrapUserError ( err , parsePktErrMsg )
146
+ return err
151
147
}
152
148
if err = handshakeHandler .HandleHandshakeResp (cctx , clientResp ); err != nil {
153
- return pnet . WrapUserError ( err , err . Error () )
149
+ return errors . Wrap ( ErrProxyErr , err )
154
150
}
155
151
auth .user = clientResp .User
156
152
auth .dbname = clientResp .DB
@@ -163,29 +159,28 @@ RECONNECT:
163
159
// In case of testing, backendIO is passed manually that we don't want to bother with the routing logic.
164
160
backendIO , err := getBackendIO (cctx , auth , clientResp )
165
161
if err != nil {
166
- return pnet . WrapUserError ( err , connectErrMsg )
162
+ return err
167
163
}
168
164
backendIO .ResetSequence ()
169
165
170
166
// write proxy header
171
167
if err := auth .writeProxyProtocol (clientIO , backendIO ); err != nil {
172
- return pnet . WrapUserError ( err , handshakeErrMsg )
168
+ return err
173
169
}
174
170
175
171
// read backend initial handshake
176
172
serverPkt , backendCapability , err := auth .readInitialHandshake (backendIO )
177
173
if err != nil {
178
- if IsMySQLError (err ) {
174
+ if pnet . IsMySQLError (err ) {
179
175
if writeErr := clientIO .WritePacket (serverPkt , true ); writeErr != nil {
180
- err = writeErr
176
+ return writeErr
181
177
}
182
- return err
183
178
}
184
- return pnet . WrapUserError ( err , handshakeErrMsg )
179
+ return err
185
180
}
186
181
187
182
if err := auth .verifyBackendCaps (logger , backendCapability ); err != nil {
188
- return pnet . WrapUserError ( err , capabilityErrMsg )
183
+ return err
189
184
}
190
185
191
186
if common := proxyCapability & backendCapability ; (proxyCapability ^ common )&^pnet .ClientSSL != 0 {
@@ -207,7 +202,7 @@ RECONNECT:
207
202
// Copy the auth data so that the backend can set correct `using password` in the error message.
208
203
unknownAuthPlugin , clientResp .AuthData , 0 ,
209
204
); err != nil {
210
- return pnet . WrapUserError ( err , handshakeErrMsg )
205
+ return err
211
206
}
212
207
213
208
// forward other packets
@@ -220,16 +215,18 @@ loop:
220
215
// tiproxy pp enabled, tidb pp disabled, tls disabled => invalid sequence
221
216
// tiproxy pp disabled, tidb pp enabled, tls disabled => invalid sequence
222
217
if pktIdx == 0 && errors .Is (err , pnet .ErrInvalidSequence ) {
223
- return pnet . WrapUserError ( err , checkPPV2ErrMsg )
218
+ return errors . Wrap ( ErrBackendPPV2 , err )
224
219
}
225
220
return err
226
221
}
227
- var packetErr error
222
+ var packetErr * mysql. MyError
228
223
if serverPkt [0 ] == pnet .ErrHeader .Byte () {
229
224
packetErr = pnet .ParseErrorPacket (serverPkt )
230
- if handshakeHandler .HandleHandshakeErr (cctx , packetErr .(* mysql.MyError )) {
231
- logger .Warn ("handle handshake error, start reconnect" , zap .Error (err ))
232
- backendIO .Close ()
225
+ if handshakeHandler .HandleHandshakeErr (cctx , packetErr ) {
226
+ logger .Warn ("handle handshake error, start reconnect" , zap .Error (packetErr ))
227
+ if closeErr := backendIO .Close (); closeErr != nil {
228
+ logger .Warn ("close backend error" , zap .Error (closeErr ))
229
+ }
233
230
goto RECONNECT
234
231
}
235
232
}
@@ -238,17 +235,17 @@ loop:
238
235
return err
239
236
}
240
237
if packetErr != nil {
241
- return packetErr
238
+ return errors . Wrap ( ErrClientAuthFail , packetErr )
242
239
}
243
240
244
241
pktIdx ++
245
242
switch serverPkt [0 ] {
246
243
case pnet .OKHeader .Byte ():
247
244
if err := setCompress (clientIO , auth .capability , auth .zstdLevel ); err != nil {
248
- return err
245
+ return errors . Wrap ( ErrClientHandshake , err )
249
246
}
250
247
if err := setCompress (backendIO , auth .capability & backendCapability , auth .zstdLevel ); err != nil {
251
- return err
248
+ return errors . Wrap ( ErrBackendHandshake , err )
252
249
}
253
250
return nil
254
251
default : // mysql.AuthSwitchRequest, ShaCommand
@@ -276,7 +273,7 @@ func forwardMsg(srcIO, destIO *pnet.PacketIO) (data []byte, err error) {
276
273
277
274
func (auth * Authenticator ) handshakeSecondTime (logger * zap.Logger , clientIO , backendIO * pnet.PacketIO , backendTLSConfig * tls.Config , sessionToken string ) error {
278
275
if len (sessionToken ) == 0 {
279
- return errors .New ( "session token is empty" )
276
+ return errors .Wrapf ( ErrBackendHandshake , "session token is empty" )
280
277
}
281
278
282
279
// write proxy header
@@ -301,17 +298,20 @@ func (auth *Authenticator) handshakeSecondTime(logger *zap.Logger, clientIO, bac
301
298
}
302
299
303
300
if err = auth .handleSecondAuthResult (backendIO ); err == nil {
304
- return setCompress (backendIO , auth .capability & backendCapability , auth .zstdLevel )
301
+ if err = setCompress (backendIO , auth .capability & backendCapability , auth .zstdLevel ); err != nil {
302
+ return errors .Wrap (ErrBackendHandshake , err )
303
+ }
305
304
}
306
- return err
305
+ return errors . Wrap ( ErrBackendHandshake , err )
307
306
}
308
307
309
308
func (auth * Authenticator ) readInitialHandshake (backendIO * pnet.PacketIO ) (serverPkt []byte , capability pnet.Capability , err error ) {
310
309
if serverPkt , err = backendIO .ReadPacket (); err != nil {
310
+ err = errors .Wrap (ErrBackendHandshake , err )
311
311
return
312
312
}
313
313
if pnet .IsErrorPacket (serverPkt [0 ]) {
314
- err = pnet .ParseErrorPacket (serverPkt )
314
+ err = errors . Wrap ( ErrBackendHandshake , pnet .ParseErrorPacket (serverPkt ) )
315
315
return
316
316
}
317
317
capability , _ , _ = pnet .ParseInitialHandshake (serverPkt )
@@ -346,7 +346,7 @@ func (auth *Authenticator) writeAuthHandshake(
346
346
var enableTLS bool
347
347
if auth .requireBackendTLS {
348
348
if backendTLSConfig == nil {
349
- return pnet . WrapUserError ( errors . New ( "tiproxy doesn't enable TLS" ), requireProxyTLSErrMsg )
349
+ return ErrProxyNoTLS
350
350
}
351
351
enableTLS = true
352
352
} else {
@@ -358,7 +358,7 @@ func (auth *Authenticator) writeAuthHandshake(
358
358
pkt = pnet .MakeHandshakeResponse (resp )
359
359
// write SSL Packet
360
360
if err := backendIO .WritePacket (pkt [:32 ], true ); err != nil {
361
- return err
361
+ return errors . Wrap ( ErrBackendHandshake , err )
362
362
}
363
363
// Send TLS / SSL request packet. The server must have supported TLS.
364
364
tcfg := backendTLSConfig .Clone ()
@@ -370,15 +370,18 @@ func (auth *Authenticator) writeAuthHandshake(
370
370
if err := backendIO .ClientTLSHandshake (tcfg ); err != nil {
371
371
// tiproxy pp enabled, tidb pp disabled, tls enabled => tls handshake encounters unrecognized packet
372
372
// tiproxy pp disabled, tidb pp enabled, tls enabled => tls handshake encounters unrecognized packet
373
- return pnet . WrapUserError ( err , checkPPV2ErrMsg )
373
+ return errors . Wrap ( ErrBackendPPV2 , err )
374
374
}
375
375
} else {
376
376
resp .Capability &= ^ pnet .ClientSSL
377
377
pkt = pnet .MakeHandshakeResponse (resp )
378
378
}
379
379
380
380
// write handshake resp
381
- return backendIO .WritePacket (pkt , true )
381
+ if err := backendIO .WritePacket (pkt , true ); err != nil {
382
+ return errors .Wrap (ErrBackendHandshake , err )
383
+ }
384
+ return nil
382
385
}
383
386
384
387
func (auth * Authenticator ) handleSecondAuthResult (backendIO * pnet.PacketIO ) error {
@@ -393,7 +396,7 @@ func (auth *Authenticator) handleSecondAuthResult(backendIO *pnet.PacketIO) erro
393
396
case pnet .ErrHeader .Byte ():
394
397
return pnet .ParseErrorPacket (data )
395
398
default : // mysql.AuthSwitchRequest, ShaCommand:
396
- return errors .Errorf ( "read unexpected command: %#x" , data [0 ])
399
+ return errors .Wrapf ( mysql . ErrMalformPacket , "read unexpected command: %#x" , data [0 ])
397
400
}
398
401
}
399
402
0 commit comments