From 182ca58280428f96ec5ab67a7cbc71171630e293 Mon Sep 17 00:00:00 2001 From: xhe Date: Fri, 1 Dec 2023 16:41:40 +0800 Subject: [PATCH 1/9] proxy: remove sequence check Signed-off-by: xhe --- pkg/proxy/backend/authenticator.go | 2 +- pkg/proxy/backend/authenticator_test.go | 4 ++ pkg/proxy/backend/error.go | 4 +- pkg/proxy/net/compress.go | 20 +--------- pkg/proxy/net/compress_test.go | 53 ------------------------- pkg/proxy/net/packetio.go | 27 +------------ 6 files changed, 10 insertions(+), 100 deletions(-) diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index 98996160..77ed078c 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -214,7 +214,7 @@ loop: if err != nil { // tiproxy pp enabled, tidb pp disabled, tls disabled => invalid sequence // tiproxy pp disabled, tidb pp enabled, tls disabled => invalid sequence - if pktIdx == 0 && errors.Is(err, pnet.ErrInvalidSequence) { + if pktIdx == 0 && auth.proxyProtocol { return errors.Wrap(ErrBackendPPV2, err) } return err diff --git a/pkg/proxy/backend/authenticator_test.go b/pkg/proxy/backend/authenticator_test.go index 221068d0..c720b6ad 100644 --- a/pkg/proxy/backend/authenticator_test.go +++ b/pkg/proxy/backend/authenticator_test.go @@ -403,6 +403,10 @@ func TestProxyProtocol(t *testing.T) { cfgOverriders := getCfgCombinations(cfgs) for _, cfgs := range cfgOverriders { ts, clean := newTestSuite(t, tc, cfgs...) + // invalid sequence detection removed, backend will stuck if clients insists to send proxy header. + if !ts.mb.proxyProtocol && ts.mp.bcConfig.ProxyProtocol { + continue + } ts.authenticateFirstTime(t, func(t *testing.T, ts *testSuite) { // TiDB proxy-protocol can be set unfallbackable, but TiProxy proxy-protocol is always fallbackable. // So when backend enables proxy-protocol and proxy disables it, it still works well. diff --git a/pkg/proxy/backend/error.go b/pkg/proxy/backend/error.go index 666d1199..529837b3 100644 --- a/pkg/proxy/backend/error.go +++ b/pkg/proxy/backend/error.go @@ -106,8 +106,8 @@ func Error2Source(err error) ErrorSource { } } switch { - // ErrInvalidSequence and ErrMalformPacket may be wrapped with other errors such as ErrBackendHandshake. - case errors.Is(err, pnet.ErrInvalidSequence), errors.Is(err, gomysql.ErrMalformPacket): + // ErrMalformPacket may be wrapped with other errors such as ErrBackendHandshake. + case errors.Is(err, gomysql.ErrMalformPacket): // We assume the clients and TiDB are right and treat it as TiProxy bugs. return SrcProxyMalformed case errors.Is(err, ErrClientHandshake), errors.Is(err, ErrClientCap): diff --git a/pkg/proxy/net/compress.go b/pkg/proxy/net/compress.go index fb6a9be5..a8d11235 100644 --- a/pkg/proxy/net/compress.go +++ b/pkg/proxy/net/compress.go @@ -59,7 +59,6 @@ type compressedReadWriter struct { writeBuffer bytes.Buffer algorithm CompressAlgorithm logger *zap.Logger - rwStatus rwStatus zstdLevel zstd.EncoderLevel header []byte sequence uint8 @@ -71,7 +70,6 @@ func newCompressedReadWriter(rw packetReadWriter, algorithm CompressAlgorithm, z algorithm: algorithm, zstdLevel: zstd.EncoderLevelFromZstd(zstdLevel), logger: logger, - rwStatus: rwNone, header: make([]byte, 7), } } @@ -81,17 +79,6 @@ func (crw *compressedReadWriter) ResetSequence() { // Reset the compressed sequence before the next command. // Sequence wraps around once it hits 0xFF, so we need ResetSequence() to know that it's reset instead of overflow. crw.sequence = 0 - crw.rwStatus = rwNone -} - -// BeginRW implements packetReadWriter.BeginRW. -// Uncompressed sequence of MySQL doesn't follow the spec: it's set to the compressed sequence when -// the client/server begins reading or writing. -func (crw *compressedReadWriter) BeginRW(status rwStatus) { - if crw.rwStatus != status { - crw.packetReadWriter.SetSequence(crw.sequence) - crw.rwStatus = status - } } func (crw *compressedReadWriter) Read(p []byte) (n int, err error) { @@ -134,12 +121,7 @@ func (crw *compressedReadWriter) readFromConn() error { if err = ReadFull(crw.packetReadWriter, crw.header); err != nil { return err } - compressedSequence := crw.header[3] - if compressedSequence != crw.sequence { - return ErrInvalidSequence.GenWithStack( - "invalid compressed sequence, expected %d, actual %d", crw.sequence, compressedSequence) - } - crw.sequence++ + crw.sequence = crw.header[3] + 1 compressedLength := int(uint32(crw.header[0]) | uint32(crw.header[1])<<8 | uint32(crw.header[2])<<16) uncompressedLength := int(uint32(crw.header[4]) | uint32(crw.header[5])<<8 | uint32(crw.header[6])<<16) diff --git a/pkg/proxy/net/compress_test.go b/pkg/proxy/net/compress_test.go index 23c46840..98e27dec 100644 --- a/pkg/proxy/net/compress_test.go +++ b/pkg/proxy/net/compress_test.go @@ -125,57 +125,6 @@ func TestCompressPeekDiscard(t *testing.T) { }, 1) } -// Test that the uncompressed sequence is correct. -func TestCompressSequence(t *testing.T) { - lg, _ := logger.CreateLoggerForTest(t) - testkit.TestTCPConn(t, - func(t *testing.T, c net.Conn) { - crw := newCompressedReadWriter(newBasicReadWriter(c, DefaultConnBufferSize), CompressionZlib, 0, lg) - fillAndWrite(t, crw, 'a', 100) - fillAndWrite(t, crw, 'a', 100) - require.NoError(t, crw.Flush()) - require.Equal(t, uint8(2), crw.Sequence()) - // uncompressed sequence = compressed sequence - readAndCheck(t, crw, 'a', 100) - require.Equal(t, uint8(2), crw.Sequence()) - readAndCheck(t, crw, 'a', 100) - require.Equal(t, uint8(3), crw.Sequence()) - // uncompressed sequence = compressed sequence - fillAndWrite(t, crw, 'a', maxCompressedSize+1) - require.NoError(t, crw.Flush()) - require.Equal(t, uint8(3), crw.Sequence()) - // uncompressed sequence = compressed sequence - readAndCheck(t, crw, 'a', maxCompressedSize+1) - require.Equal(t, uint8(5), crw.Sequence()) - // flush empty buffer won't increase sequence - require.NoError(t, crw.Flush()) - require.NoError(t, crw.Flush()) - fillAndWrite(t, crw, 'a', 100) - require.NoError(t, crw.Flush()) - }, - func(t *testing.T, c net.Conn) { - crw := newCompressedReadWriter(newBasicReadWriter(c, DefaultConnBufferSize), CompressionZlib, 0, lg) - readAndCheck(t, crw, 'a', 100) - readAndCheck(t, crw, 'a', 100) - require.Equal(t, uint8(2), crw.Sequence()) - // uncompressed sequence = compressed sequence - fillAndWrite(t, crw, 'a', 100) - require.Equal(t, uint8(2), crw.Sequence()) - fillAndWrite(t, crw, 'a', 100) - require.Equal(t, uint8(3), crw.Sequence()) - require.NoError(t, crw.Flush()) - // uncompressed sequence = compressed sequence - readAndCheck(t, crw, 'a', maxCompressedSize+1) - require.Equal(t, uint8(3), crw.Sequence()) - // uncompressed sequence = compressed sequence - fillAndWrite(t, crw, 'a', maxCompressedSize+1) - require.NoError(t, crw.Flush()) - require.Equal(t, uint8(5), crw.Sequence()) - // flush empty buffer won't increase sequence - readAndCheck(t, crw, 'a', 100) - }, 1) -} - // Test that the compressed header is correctly filled. func TestCompressHeader(t *testing.T) { lg, _ := logger.CreateLoggerForTest(t) @@ -237,7 +186,6 @@ func TestReadWriteError(t *testing.T) { } func fillAndWrite(t *testing.T, crw *compressedReadWriter, b byte, length int) { - crw.BeginRW(rwWrite) data := fillData(b, length) _, err := crw.Write(data) require.NoError(t, err) @@ -253,7 +201,6 @@ func fillData(b byte, length int) []byte { } func readAndCheck(t *testing.T, crw *compressedReadWriter, b byte, length int) { - crw.BeginRW(rwRead) data := make([]byte, length) _, err := io.ReadFull(crw, data) require.NoError(t, err) diff --git a/pkg/proxy/net/packetio.go b/pkg/proxy/net/packetio.go index 201a6b33..f61db2ba 100644 --- a/pkg/proxy/net/packetio.go +++ b/pkg/proxy/net/packetio.go @@ -30,8 +30,6 @@ import ( "net" "time" - "github.com/pingcap/tidb/errno" - "github.com/pingcap/tidb/util/dbterror" "github.com/pingcap/tiproxy/lib/config" "github.com/pingcap/tiproxy/lib/util/errors" "github.com/pingcap/tiproxy/pkg/proxy/keepalive" @@ -40,10 +38,6 @@ import ( "go.uber.org/zap" ) -var ( - ErrInvalidSequence = dbterror.ClassServer.NewStd(errno.ErrInvalidSequence) -) - const ( DefaultConnBufferSize = 32 * 1024 ) @@ -74,8 +68,6 @@ type packetReadWriter interface { Sequence() uint8 // ResetSequence is called before executing a command. ResetSequence() - // BeginRW is called before reading or writing packets. - BeginRW(status rwStatus) } var _ packetReadWriter = (*basicReadWriter)(nil) @@ -143,9 +135,6 @@ func (brw *basicReadWriter) OutBytes() uint64 { return brw.outBytes } -func (brw *basicReadWriter) BeginRW(rwStatus) { -} - func (brw *basicReadWriter) ResetSequence() { brw.sequence = 0 } @@ -246,11 +235,7 @@ func (p *PacketIO) readOnePacket() ([]byte, bool, error) { if err := ReadFull(p.readWriter, p.header); err != nil { return nil, false, errors.Wrap(ErrReadConn, err) } - sequence, pktSequence := p.header[3], p.readWriter.Sequence() - if sequence != pktSequence { - return nil, false, ErrInvalidSequence.GenWithStack("invalid sequence, expected %d, actual %d", pktSequence, sequence) - } - p.readWriter.SetSequence(sequence + 1) + p.readWriter.SetSequence(p.header[3] + 1) length := int(p.header[0]) | int(p.header[1])<<8 | int(p.header[2])<<16 data := make([]byte, length) @@ -262,7 +247,6 @@ func (p *PacketIO) readOnePacket() ([]byte, bool, error) { // ReadPacket reads data and removes the header func (p *PacketIO) ReadPacket() (data []byte, err error) { - p.readWriter.BeginRW(rwRead) for more := true; more; { var buf []byte buf, more, err = p.readOnePacket() @@ -309,7 +293,6 @@ func (p *PacketIO) writeOnePacket(data []byte) (int, bool, error) { // WritePacket writes data without a header func (p *PacketIO) WritePacket(data []byte, flush bool) (err error) { - p.readWriter.BeginRW(rwWrite) for more := true; more; { var n int n, more, err = p.writeOnePacket(data) @@ -327,8 +310,6 @@ func (p *PacketIO) WritePacket(data []byte, flush bool) (err error) { func (p *PacketIO) ForwardUntil(dest *PacketIO, isEnd func(firstByte byte, firstPktLen int) (end, needData bool), process func(response []byte) error) error { - p.readWriter.BeginRW(rwRead) - dest.readWriter.BeginRW(rwWrite) p.limitReader.R = p.readWriter for { header, err := p.readWriter.Peek(5) @@ -350,11 +331,7 @@ func (p *PacketIO) ForwardUntil(dest *PacketIO, isEnd func(firstByte byte, first } } else { for { - sequence, pktSequence := header[3], p.readWriter.Sequence() - if sequence != pktSequence { - return p.wrapErr(ErrInvalidSequence.GenWithStack("invalid sequence, expected %d, actual %d", pktSequence, sequence)) - } - p.readWriter.SetSequence(sequence + 1) + p.readWriter.SetSequence(header[3] + 1) // Sequence may be different (e.g. with compression) so we can't just copy the data to the destination. dest.readWriter.SetSequence(dest.readWriter.Sequence() + 1) p.limitReader.N = int64(length + 4) From 3fe459ff2f0c48a6bbf337b2751ab58ab5296c91 Mon Sep 17 00:00:00 2001 From: xhe Date: Fri, 1 Dec 2023 16:47:21 +0800 Subject: [PATCH 2/9] relax check Signed-off-by: xhe --- pkg/proxy/backend/authenticator.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index 77ed078c..4f1bafc8 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -214,7 +214,7 @@ loop: if err != nil { // tiproxy pp enabled, tidb pp disabled, tls disabled => invalid sequence // tiproxy pp disabled, tidb pp enabled, tls disabled => invalid sequence - if pktIdx == 0 && auth.proxyProtocol { + if pktIdx == 0 { return errors.Wrap(ErrBackendPPV2, err) } return err From 429cb2937c071d7ed8c9733b6d8bbacaf0cb4264 Mon Sep 17 00:00:00 2001 From: xhe Date: Fri, 1 Dec 2023 17:00:36 +0800 Subject: [PATCH 3/9] remove unused Signed-off-by: xhe --- pkg/proxy/net/packetio.go | 8 -------- 1 file changed, 8 deletions(-) diff --git a/pkg/proxy/net/packetio.go b/pkg/proxy/net/packetio.go index f61db2ba..14b2bbab 100644 --- a/pkg/proxy/net/packetio.go +++ b/pkg/proxy/net/packetio.go @@ -42,14 +42,6 @@ const ( DefaultConnBufferSize = 32 * 1024 ) -type rwStatus int - -const ( - rwNone rwStatus = iota - rwRead - rwWrite -) - // packetReadWriter acts like a net.Conn with read and write buffer. type packetReadWriter interface { net.Conn From 148f285bef01a0606e2b24bcf3e908186665f6ae Mon Sep 17 00:00:00 2001 From: xhe Date: Fri, 1 Dec 2023 17:09:06 +0800 Subject: [PATCH 4/9] revert compress change Signed-off-by: xhe --- pkg/proxy/backend/error.go | 4 +-- pkg/proxy/net/compress.go | 20 ++++++++++++- pkg/proxy/net/compress_test.go | 53 ++++++++++++++++++++++++++++++++++ pkg/proxy/net/packetio.go | 23 +++++++++++++++ 4 files changed, 97 insertions(+), 3 deletions(-) diff --git a/pkg/proxy/backend/error.go b/pkg/proxy/backend/error.go index 529837b3..666d1199 100644 --- a/pkg/proxy/backend/error.go +++ b/pkg/proxy/backend/error.go @@ -106,8 +106,8 @@ func Error2Source(err error) ErrorSource { } } switch { - // ErrMalformPacket may be wrapped with other errors such as ErrBackendHandshake. - case errors.Is(err, gomysql.ErrMalformPacket): + // ErrInvalidSequence and ErrMalformPacket may be wrapped with other errors such as ErrBackendHandshake. + case errors.Is(err, pnet.ErrInvalidSequence), errors.Is(err, gomysql.ErrMalformPacket): // We assume the clients and TiDB are right and treat it as TiProxy bugs. return SrcProxyMalformed case errors.Is(err, ErrClientHandshake), errors.Is(err, ErrClientCap): diff --git a/pkg/proxy/net/compress.go b/pkg/proxy/net/compress.go index a8d11235..fb6a9be5 100644 --- a/pkg/proxy/net/compress.go +++ b/pkg/proxy/net/compress.go @@ -59,6 +59,7 @@ type compressedReadWriter struct { writeBuffer bytes.Buffer algorithm CompressAlgorithm logger *zap.Logger + rwStatus rwStatus zstdLevel zstd.EncoderLevel header []byte sequence uint8 @@ -70,6 +71,7 @@ func newCompressedReadWriter(rw packetReadWriter, algorithm CompressAlgorithm, z algorithm: algorithm, zstdLevel: zstd.EncoderLevelFromZstd(zstdLevel), logger: logger, + rwStatus: rwNone, header: make([]byte, 7), } } @@ -79,6 +81,17 @@ func (crw *compressedReadWriter) ResetSequence() { // Reset the compressed sequence before the next command. // Sequence wraps around once it hits 0xFF, so we need ResetSequence() to know that it's reset instead of overflow. crw.sequence = 0 + crw.rwStatus = rwNone +} + +// BeginRW implements packetReadWriter.BeginRW. +// Uncompressed sequence of MySQL doesn't follow the spec: it's set to the compressed sequence when +// the client/server begins reading or writing. +func (crw *compressedReadWriter) BeginRW(status rwStatus) { + if crw.rwStatus != status { + crw.packetReadWriter.SetSequence(crw.sequence) + crw.rwStatus = status + } } func (crw *compressedReadWriter) Read(p []byte) (n int, err error) { @@ -121,7 +134,12 @@ func (crw *compressedReadWriter) readFromConn() error { if err = ReadFull(crw.packetReadWriter, crw.header); err != nil { return err } - crw.sequence = crw.header[3] + 1 + compressedSequence := crw.header[3] + if compressedSequence != crw.sequence { + return ErrInvalidSequence.GenWithStack( + "invalid compressed sequence, expected %d, actual %d", crw.sequence, compressedSequence) + } + crw.sequence++ compressedLength := int(uint32(crw.header[0]) | uint32(crw.header[1])<<8 | uint32(crw.header[2])<<16) uncompressedLength := int(uint32(crw.header[4]) | uint32(crw.header[5])<<8 | uint32(crw.header[6])<<16) diff --git a/pkg/proxy/net/compress_test.go b/pkg/proxy/net/compress_test.go index 98e27dec..23c46840 100644 --- a/pkg/proxy/net/compress_test.go +++ b/pkg/proxy/net/compress_test.go @@ -125,6 +125,57 @@ func TestCompressPeekDiscard(t *testing.T) { }, 1) } +// Test that the uncompressed sequence is correct. +func TestCompressSequence(t *testing.T) { + lg, _ := logger.CreateLoggerForTest(t) + testkit.TestTCPConn(t, + func(t *testing.T, c net.Conn) { + crw := newCompressedReadWriter(newBasicReadWriter(c, DefaultConnBufferSize), CompressionZlib, 0, lg) + fillAndWrite(t, crw, 'a', 100) + fillAndWrite(t, crw, 'a', 100) + require.NoError(t, crw.Flush()) + require.Equal(t, uint8(2), crw.Sequence()) + // uncompressed sequence = compressed sequence + readAndCheck(t, crw, 'a', 100) + require.Equal(t, uint8(2), crw.Sequence()) + readAndCheck(t, crw, 'a', 100) + require.Equal(t, uint8(3), crw.Sequence()) + // uncompressed sequence = compressed sequence + fillAndWrite(t, crw, 'a', maxCompressedSize+1) + require.NoError(t, crw.Flush()) + require.Equal(t, uint8(3), crw.Sequence()) + // uncompressed sequence = compressed sequence + readAndCheck(t, crw, 'a', maxCompressedSize+1) + require.Equal(t, uint8(5), crw.Sequence()) + // flush empty buffer won't increase sequence + require.NoError(t, crw.Flush()) + require.NoError(t, crw.Flush()) + fillAndWrite(t, crw, 'a', 100) + require.NoError(t, crw.Flush()) + }, + func(t *testing.T, c net.Conn) { + crw := newCompressedReadWriter(newBasicReadWriter(c, DefaultConnBufferSize), CompressionZlib, 0, lg) + readAndCheck(t, crw, 'a', 100) + readAndCheck(t, crw, 'a', 100) + require.Equal(t, uint8(2), crw.Sequence()) + // uncompressed sequence = compressed sequence + fillAndWrite(t, crw, 'a', 100) + require.Equal(t, uint8(2), crw.Sequence()) + fillAndWrite(t, crw, 'a', 100) + require.Equal(t, uint8(3), crw.Sequence()) + require.NoError(t, crw.Flush()) + // uncompressed sequence = compressed sequence + readAndCheck(t, crw, 'a', maxCompressedSize+1) + require.Equal(t, uint8(3), crw.Sequence()) + // uncompressed sequence = compressed sequence + fillAndWrite(t, crw, 'a', maxCompressedSize+1) + require.NoError(t, crw.Flush()) + require.Equal(t, uint8(5), crw.Sequence()) + // flush empty buffer won't increase sequence + readAndCheck(t, crw, 'a', 100) + }, 1) +} + // Test that the compressed header is correctly filled. func TestCompressHeader(t *testing.T) { lg, _ := logger.CreateLoggerForTest(t) @@ -186,6 +237,7 @@ func TestReadWriteError(t *testing.T) { } func fillAndWrite(t *testing.T, crw *compressedReadWriter, b byte, length int) { + crw.BeginRW(rwWrite) data := fillData(b, length) _, err := crw.Write(data) require.NoError(t, err) @@ -201,6 +253,7 @@ func fillData(b byte, length int) []byte { } func readAndCheck(t *testing.T, crw *compressedReadWriter, b byte, length int) { + crw.BeginRW(rwRead) data := make([]byte, length) _, err := io.ReadFull(crw, data) require.NoError(t, err) diff --git a/pkg/proxy/net/packetio.go b/pkg/proxy/net/packetio.go index 14b2bbab..1e5fe865 100644 --- a/pkg/proxy/net/packetio.go +++ b/pkg/proxy/net/packetio.go @@ -30,6 +30,8 @@ import ( "net" "time" + "github.com/pingcap/tidb/errno" + "github.com/pingcap/tidb/util/dbterror" "github.com/pingcap/tiproxy/lib/config" "github.com/pingcap/tiproxy/lib/util/errors" "github.com/pingcap/tiproxy/pkg/proxy/keepalive" @@ -38,10 +40,22 @@ import ( "go.uber.org/zap" ) +var ( + ErrInvalidSequence = dbterror.ClassServer.NewStd(errno.ErrInvalidSequence) +) + const ( DefaultConnBufferSize = 32 * 1024 ) +type rwStatus int + +const ( + rwNone rwStatus = iota + rwRead + rwWrite +) + // packetReadWriter acts like a net.Conn with read and write buffer. type packetReadWriter interface { net.Conn @@ -60,6 +74,8 @@ type packetReadWriter interface { Sequence() uint8 // ResetSequence is called before executing a command. ResetSequence() + // BeginRW is called before reading or writing packets. + BeginRW(status rwStatus) } var _ packetReadWriter = (*basicReadWriter)(nil) @@ -127,6 +143,9 @@ func (brw *basicReadWriter) OutBytes() uint64 { return brw.outBytes } +func (brw *basicReadWriter) BeginRW(rwStatus) { +} + func (brw *basicReadWriter) ResetSequence() { brw.sequence = 0 } @@ -239,6 +258,7 @@ func (p *PacketIO) readOnePacket() ([]byte, bool, error) { // ReadPacket reads data and removes the header func (p *PacketIO) ReadPacket() (data []byte, err error) { + p.readWriter.BeginRW(rwRead) for more := true; more; { var buf []byte buf, more, err = p.readOnePacket() @@ -285,6 +305,7 @@ func (p *PacketIO) writeOnePacket(data []byte) (int, bool, error) { // WritePacket writes data without a header func (p *PacketIO) WritePacket(data []byte, flush bool) (err error) { + p.readWriter.BeginRW(rwWrite) for more := true; more; { var n int n, more, err = p.writeOnePacket(data) @@ -302,6 +323,8 @@ func (p *PacketIO) WritePacket(data []byte, flush bool) (err error) { func (p *PacketIO) ForwardUntil(dest *PacketIO, isEnd func(firstByte byte, firstPktLen int) (end, needData bool), process func(response []byte) error) error { + p.readWriter.BeginRW(rwRead) + dest.readWriter.BeginRW(rwWrite) p.limitReader.R = p.readWriter for { header, err := p.readWriter.Peek(5) From 3299ca34e4458b1305f0036b3d6ca01315f67baa Mon Sep 17 00:00:00 2001 From: xhe Date: Fri, 1 Dec 2023 17:25:45 +0800 Subject: [PATCH 5/9] address comments Signed-off-by: xhe --- conf/proxy.toml | 6 +++--- pkg/proxy/backend/error.go | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/conf/proxy.toml b/conf/proxy.toml index b688e429..0acd5c8e 100644 --- a/conf/proxy.toml +++ b/conf/proxy.toml @@ -3,12 +3,12 @@ [proxy] # addr = "0.0.0.0:6000" # tcp-keep-alive = true -# require-backend-tls = true +require-backend-tls = false # possible values: # "" => disable proxy protocol. # "v2" => accept proxy protocol if any, require backends to support proxy protocol. -# proxy-protocol = "" +proxy-protocol = "" # graceful-wait-before-shutdown is recommanded to be set to 0 when there's no other proxy(e.g. NLB) between the client and TiProxy. # possible values: @@ -26,7 +26,7 @@ graceful-close-conn-timeout = 15 # possible values: # "" => enable static routing. # "pd-addr:pd-port" => automatically tidb discovery. -# pd-addrs = "127.0.0.1:2379" +pd-addrs = "" # possible values: # 0 => no limitation. diff --git a/pkg/proxy/backend/error.go b/pkg/proxy/backend/error.go index 666d1199..a8ffd2ca 100644 --- a/pkg/proxy/backend/error.go +++ b/pkg/proxy/backend/error.go @@ -79,7 +79,7 @@ const ( SrcClientSQLErr // SrcProxyQuit includes: proxy graceful shutdown SrcProxyQuit - // SrcProxyMalformed includes: malformed packet; invalid sequence + // SrcProxyMalformed includes: malformed packet SrcProxyMalformed // SrcProxyNoBackend includes: no backends SrcProxyNoBackend From 903f69887e0a3b328a962690ae3930d5d5649cc4 Mon Sep 17 00:00:00 2001 From: xhe Date: Fri, 1 Dec 2023 17:32:10 +0800 Subject: [PATCH 6/9] remove config change Signed-off-by: xhe --- conf/proxy.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/conf/proxy.toml b/conf/proxy.toml index 0acd5c8e..b688e429 100644 --- a/conf/proxy.toml +++ b/conf/proxy.toml @@ -3,12 +3,12 @@ [proxy] # addr = "0.0.0.0:6000" # tcp-keep-alive = true -require-backend-tls = false +# require-backend-tls = true # possible values: # "" => disable proxy protocol. # "v2" => accept proxy protocol if any, require backends to support proxy protocol. -proxy-protocol = "" +# proxy-protocol = "" # graceful-wait-before-shutdown is recommanded to be set to 0 when there's no other proxy(e.g. NLB) between the client and TiProxy. # possible values: @@ -26,7 +26,7 @@ graceful-close-conn-timeout = 15 # possible values: # "" => enable static routing. # "pd-addr:pd-port" => automatically tidb discovery. -pd-addrs = "" +# pd-addrs = "127.0.0.1:2379" # possible values: # 0 => no limitation. From cdf33ef31ef25d12fd0019cd79a254f6ef5e2cc5 Mon Sep 17 00:00:00 2001 From: xhe Date: Fri, 1 Dec 2023 18:12:11 +0800 Subject: [PATCH 7/9] address comments Signed-off-by: xhe --- pkg/proxy/backend/authenticator.go | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index 4f1bafc8..deeaa5bb 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -9,6 +9,7 @@ import ( "encoding/binary" "fmt" "net" + "strings" "github.com/go-mysql-org/go-mysql/mysql" "github.com/pingcap/tidb/util/hack" @@ -212,11 +213,6 @@ loop: for { serverPkt, err := backendIO.ReadPacket() if err != nil { - // tiproxy pp enabled, tidb pp disabled, tls disabled => invalid sequence - // tiproxy pp disabled, tidb pp enabled, tls disabled => invalid sequence - if pktIdx == 0 { - return errors.Wrap(ErrBackendPPV2, err) - } return err } var packetErr *mysql.MyError @@ -235,6 +231,18 @@ loop: return err } if packetErr != nil { + // tiproxy pp enabled, tidb pp disabled, tls disabled => invalid sequence + // tiproxy pp disabled, tidb pp enabled, tls disabled => invalid sequence + for _, p := range []string{ + "packets out of order", + "(nvali): d sequence", + "invalid sequence", + "PROXY Protocol", + } { + if strings.Contains(packetErr.Message, p) { + return errors.Wrap(ErrBackendPPV2, packetErr) + } + } return errors.Wrap(ErrClientAuthFail, packetErr) } From 59d7ab3879d418b8845175e1cba43456faa12f8b Mon Sep 17 00:00:00 2001 From: xhe Date: Sat, 2 Dec 2023 10:09:15 +0800 Subject: [PATCH 8/9] address comments Signed-off-by: xhe --- pkg/proxy/backend/authenticator.go | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index deeaa5bb..7151d4b5 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -233,15 +233,13 @@ loop: if packetErr != nil { // tiproxy pp enabled, tidb pp disabled, tls disabled => invalid sequence // tiproxy pp disabled, tidb pp enabled, tls disabled => invalid sequence - for _, p := range []string{ - "packets out of order", - "(nvali): d sequence", - "invalid sequence", - "PROXY Protocol", - } { - if strings.Contains(packetErr.Message, p) { + if pktIdx == 0 { + if packetErr.Code == 1156 || packetErr.Code == 8052 { return errors.Wrap(ErrBackendPPV2, packetErr) } + if strings.Contains(packetErr.Message, "PROXY Protocol") { + return ErrBackendPPV2 + } } return errors.Wrap(ErrClientAuthFail, packetErr) } From 92eb72a033dcf191f63799058e040e65b83d6acf Mon Sep 17 00:00:00 2001 From: xhe Date: Fri, 22 Dec 2023 13:21:22 +0800 Subject: [PATCH 9/9] address comments Signed-off-by: xhe --- pkg/proxy/backend/authenticator.go | 33 ++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/pkg/proxy/backend/authenticator.go b/pkg/proxy/backend/authenticator.go index 7151d4b5..c753cd87 100644 --- a/pkg/proxy/backend/authenticator.go +++ b/pkg/proxy/backend/authenticator.go @@ -22,6 +22,7 @@ import ( const unknownAuthPlugin = "auth_unknown_plugin" const requiredFrontendCaps = pnet.ClientProtocol41 const defRequiredBackendCaps = pnet.ClientDeprecateEOF +const ER_INVALID_SEQUENCE = 8052 // SupportedServerCapabilities is the default supported capabilities. Other server capabilities are not supported. // TiDB supports ClientDeprecateEOF since v6.3.0. @@ -231,17 +232,7 @@ loop: return err } if packetErr != nil { - // tiproxy pp enabled, tidb pp disabled, tls disabled => invalid sequence - // tiproxy pp disabled, tidb pp enabled, tls disabled => invalid sequence - if pktIdx == 0 { - if packetErr.Code == 1156 || packetErr.Code == 8052 { - return errors.Wrap(ErrBackendPPV2, packetErr) - } - if strings.Contains(packetErr.Message, "PROXY Protocol") { - return ErrBackendPPV2 - } - } - return errors.Wrap(ErrClientAuthFail, packetErr) + return handleHandshakeError(pktIdx, packetErr) } pktIdx++ @@ -434,3 +425,23 @@ func setCompress(packetIO *pnet.PacketIO, capability pnet.Capability, zstdLevel } return packetIO.SetCompressionAlgorithm(algorithm, zstdLevel) } + +// handleHandshakeError tries to recognize the error and report more friendly messages. +func handleHandshakeError(pktIdx int, packetErr *mysql.MyError) error { + if pktIdx == 0 { + // PPV2 errors only appear in the first packet + + // mysql ERROR 1156: Got packets out of order (proxy ppv2 = true) + if packetErr.Code == mysql.ER_NET_PACKETS_OUT_OF_ORDER || + // tidb ERROR 8052: invalid sequence, received 10 while expecting 1 (proxy ppv2 = true, db ppv2 = false) + packetErr.Code == ER_INVALID_SEQUENCE { + return errors.Wrap(ErrBackendPPV2, packetErr) + } + // tidb ERROR 1105: invalid PROXY Protocol Header (proxy ppv2 = false, db ppv2 = true, db ppv2 fallback = false) + // 1105 is UNKNOWN_ERR, so we judge the error by messages + if strings.Contains(packetErr.Message, "PROXY Protocol") { + return ErrBackendPPV2 + } + } + return errors.Wrap(ErrClientAuthFail, packetErr) +}