diff --git a/lib/util/security/cert_test.go b/lib/util/security/cert_test.go index 7f032ded..8e983c86 100644 --- a/lib/util/security/cert_test.go +++ b/lib/util/security/cert_test.go @@ -6,12 +6,16 @@ package security import ( "crypto/tls" "crypto/x509" + "encoding/pem" + "net" + "os" "path/filepath" "testing" "time" "github.com/pingcap/tiproxy/lib/config" "github.com/pingcap/tiproxy/lib/util/logger" + "github.com/pingcap/tiproxy/lib/util/waitgroup" "github.com/stretchr/testify/require" ) @@ -357,3 +361,100 @@ func TestSetConfig(t *testing.T) { require.NoError(t, err) require.Nil(t, tcfg) } + +// Test that a cert pool can store multiple CAs and every CA works. +func TestCertPool(t *testing.T) { + tmpdir := t.TempDir() + lg, _ := logger.CreateLoggerForTest(t) + caPath1 := filepath.Join(tmpdir, "c1", "ca") + keyPath1 := filepath.Join(tmpdir, "c1", "key") + certPath1 := filepath.Join(tmpdir, "c1", "cert") + caPath2 := filepath.Join(tmpdir, "c2", "ca") + keyPath2 := filepath.Join(tmpdir, "c2", "key") + certPath2 := filepath.Join(tmpdir, "c2", "cert") + + require.NoError(t, CreateTLSCertificates(lg, certPath1, keyPath1, caPath1, 0, DefaultCertExpiration)) + require.NoError(t, CreateTLSCertificates(lg, certPath2, keyPath2, caPath2, 0, DefaultCertExpiration)) + + serverCfg := config.TLSConfig{ + Cert: certPath1, + Key: keyPath1, + } + serverCert := NewCert(true) + serverCert.cfg.Store(&serverCfg) + serverTLS, err := serverCert.Reload(lg) + require.NoError(t, err) + + clientCfg := config.TLSConfig{ + CA: caPath2, + } + clientCert := NewCert(false) + clientCert.cfg.Store(&clientCfg) + clientTLS, err := clientCert.Reload(lg) + require.NoError(t, err) + // caPath2 fails to verify certPath1. + clientErr, serverErr := connectWithTLS(clientTLS, serverTLS) + require.Error(t, clientErr) + require.Error(t, serverErr) + + // Add both caPath1 and caPath2 to the cert pool and it succeeds to verify certPath1. + err = loadCA(caPath1, clientCert.ca.Load()) + require.NoError(t, err) + clientErr, serverErr = connectWithTLS(clientTLS, serverTLS) + require.NoError(t, clientErr) + require.NoError(t, serverErr) + + // The cert pool can also verify certPath2. + serverCfg = config.TLSConfig{ + Cert: certPath2, + Key: keyPath2, + } + serverCert.cfg.Store(&serverCfg) + serverTLS, err = serverCert.Reload(lg) + require.NoError(t, err) + clientErr, serverErr = connectWithTLS(clientTLS, serverTLS) + require.NoError(t, clientErr) + require.NoError(t, serverErr) +} + +func loadCA(caPath string, pool *x509.CertPool) error { + pemCerts, err := os.ReadFile(caPath) + if err != nil { + return err + } + for len(pemCerts) > 0 { + var block *pem.Block + block, pemCerts = pem.Decode(pemCerts) + if block == nil { + break + } + if block.Type != "CERTIFICATE" || len(block.Headers) != 0 { + continue + } + + certBytes := block.Bytes + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + continue + } + pool.AddCert(cert) + } + return nil +} + +func connectWithTLS(ctls, stls *tls.Config) (clientErr, serverErr error) { + client, server := net.Pipe() + var wg waitgroup.WaitGroup + wg.Run(func() { + tlsConn := tls.Client(client, ctls) + clientErr = tlsConn.Handshake() + _ = client.Close() + }) + wg.Run(func() { + tlsConn := tls.Server(server, stls) + serverErr = tlsConn.Handshake() + _ = server.Close() + }) + wg.Wait() + return +}