Skip to content

Commit f22f82b

Browse files
djshow832xhebox
andauthored
backend, net: add more error sources (#407)
Signed-off-by: xhe <[email protected]> Co-authored-by: xhe <[email protected]>
1 parent 8712ca0 commit f22f82b

22 files changed

+332
-235
lines changed

pkg/manager/router/router.go

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,6 @@ import (
77
"time"
88

99
glist "github.com/bahlo/generic-list-go"
10-
"github.com/pingcap/tiproxy/lib/util/errors"
11-
)
12-
13-
var (
14-
ErrNoInstanceToSelect = errors.New("no instances to route")
1510
)
1611

1712
// ConnEventReceiver receives connection events.

pkg/proxy/backend/authenticator.go

Lines changed: 38 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,6 @@ import (
1818
"go.uber.org/zap"
1919
)
2020

21-
var (
22-
ErrCapabilityNegotiation = errors.New("capability negotiation failed")
23-
)
24-
2521
const unknownAuthPlugin = "auth_unknown_plugin"
2622
const requiredFrontendCaps = pnet.ClientProtocol41
2723
const defRequiredBackendCaps = pnet.ClientDeprecateEOF
@@ -76,10 +72,10 @@ func (auth *Authenticator) verifyBackendCaps(logger *zap.Logger, backendCapabili
7672
// The error cannot be sent to the client because the client only expects an initial handshake packet.
7773
// The only way is to log it and disconnect.
7874
logger.Error("require backend capabilities", zap.Stringer("common", commonCaps), zap.Stringer("required", requiredBackendCaps^commonCaps))
79-
return errors.Wrapf(ErrCapabilityNegotiation, "require %s from backend", requiredBackendCaps^commonCaps)
75+
return errors.Wrapf(ErrBackendCap, "require %s from backend", requiredBackendCaps^commonCaps)
8076
}
8177
if auth.requireBackendTLS && (backendCapability&pnet.ClientSSL == 0) {
82-
return pnet.WrapUserError(errors.New("backend doesn't enable TLS"), requireTiDBTLSErrMsg)
78+
return ErrBackendNoTLS
8379
}
8480
return nil
8581
}
@@ -106,7 +102,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
106102
frontendCapability := pnet.Capability(binary.LittleEndian.Uint32(pkt))
107103
if isSSL {
108104
if _, err = clientIO.ServerTLSHandshake(frontendTLSConfig); err != nil {
109-
return pnet.WrapUserError(err, err.Error())
105+
return errors.Wrap(ErrClientHandshake, err)
110106
}
111107
pkt, _, err = clientIO.ReadSSLRequestOrHandshakeResp()
112108
if err != nil {
@@ -125,7 +121,7 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
125121
if writeErr := clientIO.WriteErrPacket(mysql.NewDefaultError(mysql.ER_NOT_SUPPORTED_AUTH_MODE)); writeErr != nil {
126122
return writeErr
127123
}
128-
return errors.Wrapf(ErrCapabilityNegotiation, "require %s from frontend", requiredFrontendCaps&^commonCaps)
124+
return errors.Wrapf(ErrClientCap, "require %s from frontend", requiredFrontendCaps&^commonCaps)
129125
}
130126
commonCaps := frontendCapability & proxyCapability
131127
if frontendCapability^commonCaps != 0 {
@@ -147,10 +143,10 @@ func (auth *Authenticator) handshakeFirstTime(logger *zap.Logger, cctx ConnConte
147143
if errors.As(err, &warning) {
148144
logger.Warn("parse handshake response encounters error", zap.Error(err))
149145
} else if err != nil {
150-
return pnet.WrapUserError(err, parsePktErrMsg)
146+
return err
151147
}
152148
if err = handshakeHandler.HandleHandshakeResp(cctx, clientResp); err != nil {
153-
return pnet.WrapUserError(err, err.Error())
149+
return errors.Wrap(ErrProxyErr, err)
154150
}
155151
auth.user = clientResp.User
156152
auth.dbname = clientResp.DB
@@ -163,29 +159,28 @@ RECONNECT:
163159
// In case of testing, backendIO is passed manually that we don't want to bother with the routing logic.
164160
backendIO, err := getBackendIO(cctx, auth, clientResp)
165161
if err != nil {
166-
return pnet.WrapUserError(err, connectErrMsg)
162+
return err
167163
}
168164
backendIO.ResetSequence()
169165

170166
// write proxy header
171167
if err := auth.writeProxyProtocol(clientIO, backendIO); err != nil {
172-
return pnet.WrapUserError(err, handshakeErrMsg)
168+
return err
173169
}
174170

175171
// read backend initial handshake
176172
serverPkt, backendCapability, err := auth.readInitialHandshake(backendIO)
177173
if err != nil {
178-
if IsMySQLError(err) {
174+
if pnet.IsMySQLError(err) {
179175
if writeErr := clientIO.WritePacket(serverPkt, true); writeErr != nil {
180-
err = writeErr
176+
return writeErr
181177
}
182-
return err
183178
}
184-
return pnet.WrapUserError(err, handshakeErrMsg)
179+
return err
185180
}
186181

187182
if err := auth.verifyBackendCaps(logger, backendCapability); err != nil {
188-
return pnet.WrapUserError(err, capabilityErrMsg)
183+
return err
189184
}
190185

191186
if common := proxyCapability & backendCapability; (proxyCapability^common)&^pnet.ClientSSL != 0 {
@@ -207,7 +202,7 @@ RECONNECT:
207202
// Copy the auth data so that the backend can set correct `using password` in the error message.
208203
unknownAuthPlugin, clientResp.AuthData, 0,
209204
); err != nil {
210-
return pnet.WrapUserError(err, handshakeErrMsg)
205+
return err
211206
}
212207

213208
// forward other packets
@@ -220,16 +215,18 @@ loop:
220215
// tiproxy pp enabled, tidb pp disabled, tls disabled => invalid sequence
221216
// tiproxy pp disabled, tidb pp enabled, tls disabled => invalid sequence
222217
if pktIdx == 0 && errors.Is(err, pnet.ErrInvalidSequence) {
223-
return pnet.WrapUserError(err, checkPPV2ErrMsg)
218+
return errors.Wrap(ErrBackendPPV2, err)
224219
}
225220
return err
226221
}
227-
var packetErr error
222+
var packetErr *mysql.MyError
228223
if serverPkt[0] == pnet.ErrHeader.Byte() {
229224
packetErr = pnet.ParseErrorPacket(serverPkt)
230-
if handshakeHandler.HandleHandshakeErr(cctx, packetErr.(*mysql.MyError)) {
231-
logger.Warn("handle handshake error, start reconnect", zap.Error(err))
232-
backendIO.Close()
225+
if handshakeHandler.HandleHandshakeErr(cctx, packetErr) {
226+
logger.Warn("handle handshake error, start reconnect", zap.Error(packetErr))
227+
if closeErr := backendIO.Close(); closeErr != nil {
228+
logger.Warn("close backend error", zap.Error(closeErr))
229+
}
233230
goto RECONNECT
234231
}
235232
}
@@ -238,17 +235,17 @@ loop:
238235
return err
239236
}
240237
if packetErr != nil {
241-
return packetErr
238+
return errors.Wrap(ErrClientAuthFail, packetErr)
242239
}
243240

244241
pktIdx++
245242
switch serverPkt[0] {
246243
case pnet.OKHeader.Byte():
247244
if err := setCompress(clientIO, auth.capability, auth.zstdLevel); err != nil {
248-
return err
245+
return errors.Wrap(ErrClientHandshake, err)
249246
}
250247
if err := setCompress(backendIO, auth.capability&backendCapability, auth.zstdLevel); err != nil {
251-
return err
248+
return errors.Wrap(ErrBackendHandshake, err)
252249
}
253250
return nil
254251
default: // mysql.AuthSwitchRequest, ShaCommand
@@ -276,7 +273,7 @@ func forwardMsg(srcIO, destIO *pnet.PacketIO) (data []byte, err error) {
276273

277274
func (auth *Authenticator) handshakeSecondTime(logger *zap.Logger, clientIO, backendIO *pnet.PacketIO, backendTLSConfig *tls.Config, sessionToken string) error {
278275
if len(sessionToken) == 0 {
279-
return errors.New("session token is empty")
276+
return errors.Wrapf(ErrBackendHandshake, "session token is empty")
280277
}
281278

282279
// write proxy header
@@ -301,17 +298,20 @@ func (auth *Authenticator) handshakeSecondTime(logger *zap.Logger, clientIO, bac
301298
}
302299

303300
if err = auth.handleSecondAuthResult(backendIO); err == nil {
304-
return setCompress(backendIO, auth.capability&backendCapability, auth.zstdLevel)
301+
if err = setCompress(backendIO, auth.capability&backendCapability, auth.zstdLevel); err != nil {
302+
return errors.Wrap(ErrBackendHandshake, err)
303+
}
305304
}
306-
return err
305+
return errors.Wrap(ErrBackendHandshake, err)
307306
}
308307

309308
func (auth *Authenticator) readInitialHandshake(backendIO *pnet.PacketIO) (serverPkt []byte, capability pnet.Capability, err error) {
310309
if serverPkt, err = backendIO.ReadPacket(); err != nil {
310+
err = errors.Wrap(ErrBackendHandshake, err)
311311
return
312312
}
313313
if pnet.IsErrorPacket(serverPkt[0]) {
314-
err = pnet.ParseErrorPacket(serverPkt)
314+
err = errors.Wrap(ErrBackendHandshake, pnet.ParseErrorPacket(serverPkt))
315315
return
316316
}
317317
capability, _, _ = pnet.ParseInitialHandshake(serverPkt)
@@ -346,7 +346,7 @@ func (auth *Authenticator) writeAuthHandshake(
346346
var enableTLS bool
347347
if auth.requireBackendTLS {
348348
if backendTLSConfig == nil {
349-
return pnet.WrapUserError(errors.New("tiproxy doesn't enable TLS"), requireProxyTLSErrMsg)
349+
return ErrProxyNoTLS
350350
}
351351
enableTLS = true
352352
} else {
@@ -358,7 +358,7 @@ func (auth *Authenticator) writeAuthHandshake(
358358
pkt = pnet.MakeHandshakeResponse(resp)
359359
// write SSL Packet
360360
if err := backendIO.WritePacket(pkt[:32], true); err != nil {
361-
return err
361+
return errors.Wrap(ErrBackendHandshake, err)
362362
}
363363
// Send TLS / SSL request packet. The server must have supported TLS.
364364
tcfg := backendTLSConfig.Clone()
@@ -370,15 +370,18 @@ func (auth *Authenticator) writeAuthHandshake(
370370
if err := backendIO.ClientTLSHandshake(tcfg); err != nil {
371371
// tiproxy pp enabled, tidb pp disabled, tls enabled => tls handshake encounters unrecognized packet
372372
// tiproxy pp disabled, tidb pp enabled, tls enabled => tls handshake encounters unrecognized packet
373-
return pnet.WrapUserError(err, checkPPV2ErrMsg)
373+
return errors.Wrap(ErrBackendPPV2, err)
374374
}
375375
} else {
376376
resp.Capability &= ^pnet.ClientSSL
377377
pkt = pnet.MakeHandshakeResponse(resp)
378378
}
379379

380380
// write handshake resp
381-
return backendIO.WritePacket(pkt, true)
381+
if err := backendIO.WritePacket(pkt, true); err != nil {
382+
return errors.Wrap(ErrBackendHandshake, err)
383+
}
384+
return nil
382385
}
383386

384387
func (auth *Authenticator) handleSecondAuthResult(backendIO *pnet.PacketIO) error {
@@ -393,7 +396,7 @@ func (auth *Authenticator) handleSecondAuthResult(backendIO *pnet.PacketIO) erro
393396
case pnet.ErrHeader.Byte():
394397
return pnet.ParseErrorPacket(data)
395398
default: // mysql.AuthSwitchRequest, ShaCommand:
396-
return errors.Errorf("read unexpected command: %#x", data[0])
399+
return errors.Wrapf(mysql.ErrMalformPacket, "read unexpected command: %#x", data[0])
397400
}
398401
}
399402

pkg/proxy/backend/authenticator_test.go

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"testing"
99

1010
"github.com/pingcap/tidb/parser/mysql"
11-
"github.com/pingcap/tiproxy/lib/util/errors"
1211
pnet "github.com/pingcap/tiproxy/pkg/proxy/net"
1312
"github.com/stretchr/testify/require"
1413
)
@@ -70,10 +69,14 @@ func TestUnsupportedCapability(t *testing.T) {
7069
for _, cfgs := range cfgOverriders {
7170
ts, clean := newTestSuite(t, tc, cfgs...)
7271
ts.authenticateFirstTime(t, func(t *testing.T, _ *testSuite) {
73-
if ts.mb.backendConfig.capability&defRequiredBackendCaps != defRequiredBackendCaps {
74-
require.ErrorIs(t, ts.mp.err, ErrCapabilityNegotiation)
75-
} else if ts.mc.clientConfig.capability&requiredFrontendCaps != requiredFrontendCaps {
76-
require.ErrorIs(t, ts.mp.err, ErrCapabilityNegotiation)
72+
if ts.mc.clientConfig.capability&requiredFrontendCaps != requiredFrontendCaps {
73+
require.ErrorIs(t, ts.mp.err, ErrClientCap)
74+
require.Nil(t, ErrToClient(ts.mp.err))
75+
require.Equal(t, SrcClientHandshake, Error2Source(ts.mp.err))
76+
} else if ts.mb.backendConfig.capability&defRequiredBackendCaps != defRequiredBackendCaps {
77+
require.ErrorIs(t, ts.mp.err, ErrBackendCap)
78+
require.Equal(t, ErrBackendCap, ErrToClient(ts.mp.err))
79+
require.Equal(t, SrcBackendHandshake, Error2Source(ts.mp.err))
7780
} else {
7881
require.NoError(t, ts.mc.err)
7982
require.NoError(t, ts.mp.err)
@@ -311,31 +314,35 @@ func TestAuthFail(t *testing.T) {
311314
ts, clean := newTestSuite(t, tc, cfg)
312315
ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) {
313316
require.Equal(t, len(ts.mc.authData), len(ts.mb.authData))
317+
require.Equal(t, SrcClientAuthFail, Error2Source(ts.mp.err))
314318
})
315319
clean()
316320
}
317321
}
318322

319323
func TestRequireBackendTLS(t *testing.T) {
320324
tests := []struct {
321-
cfg cfgOverrider
322-
errMsg string
325+
cfg cfgOverrider
326+
err error
327+
src ErrorSource
323328
}{
324329
{
325330
cfg: func(cfg *testConfig) {
326331
cfg.proxyConfig.bcConfig.RequireBackendTLS = true
327332
cfg.proxyConfig.backendTLSConfig = nil
328333
cfg.backendConfig.capability |= pnet.ClientSSL
329334
},
330-
errMsg: requireProxyTLSErrMsg,
335+
err: ErrProxyNoTLS,
336+
src: SrcProxyErr,
331337
},
332338
{
333339
cfg: func(cfg *testConfig) {
334340
cfg.proxyConfig.bcConfig.RequireBackendTLS = true
335341
cfg.backendConfig.tlsConfig = nil
336342
cfg.backendConfig.capability &= ^pnet.ClientSSL
337343
},
338-
errMsg: requireTiDBTLSErrMsg,
344+
err: ErrBackendNoTLS,
345+
src: SrcBackendHandshake,
339346
},
340347
{
341348
cfg: func(cfg *testConfig) {
@@ -351,10 +358,9 @@ func TestRequireBackendTLS(t *testing.T) {
351358
for _, tt := range tests {
352359
ts, clean := newTestSuite(t, tc, tt.cfg)
353360
ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) {
354-
if len(tt.errMsg) > 0 {
355-
var userError *pnet.UserError
356-
require.True(t, errors.As(ts.mp.err, &userError))
357-
require.Equal(t, tt.errMsg, userError.UserMsg())
361+
if tt.err != nil {
362+
require.ErrorIs(t, ts.mp.err, tt.err)
363+
require.Equal(t, tt.src, Error2Source(ts.mp.err))
358364
} else {
359365
require.NoError(t, ts.mp.err)
360366
}
@@ -401,9 +407,9 @@ func TestProxyProtocol(t *testing.T) {
401407
// TiDB proxy-protocol can be set unfallbackable, but TiProxy proxy-protocol is always fallbackable.
402408
// So when backend enables proxy-protocol and proxy disables it, it still works well.
403409
if ts.mp.bcConfig.ProxyProtocol && !ts.mb.proxyProtocol {
404-
var userError *pnet.UserError
405-
require.True(t, errors.As(ts.mp.err, &userError))
406-
require.Equal(t, checkPPV2ErrMsg, userError.UserMsg())
410+
err := ErrToClient(ts.mp.err)
411+
require.Equal(t, ErrBackendPPV2, err)
412+
require.Equal(t, SrcBackendHandshake, Error2Source(err))
407413
} else {
408414
require.NoError(t, ts.mp.err)
409415
}

0 commit comments

Comments
 (0)