Skip to content

Commit d40bfdb

Browse files
authored
security: add unit test for cert pool (#496)
1 parent 1a2c0c8 commit d40bfdb

File tree

1 file changed

+101
-0
lines changed

1 file changed

+101
-0
lines changed

lib/util/security/cert_test.go

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,16 @@ package security
66
import (
77
"crypto/tls"
88
"crypto/x509"
9+
"encoding/pem"
10+
"net"
11+
"os"
912
"path/filepath"
1013
"testing"
1114
"time"
1215

1316
"github.com/pingcap/tiproxy/lib/config"
1417
"github.com/pingcap/tiproxy/lib/util/logger"
18+
"github.com/pingcap/tiproxy/lib/util/waitgroup"
1519
"github.com/stretchr/testify/require"
1620
)
1721

@@ -357,3 +361,100 @@ func TestSetConfig(t *testing.T) {
357361
require.NoError(t, err)
358362
require.Nil(t, tcfg)
359363
}
364+
365+
// Test that a cert pool can store multiple CAs and every CA works.
366+
func TestCertPool(t *testing.T) {
367+
tmpdir := t.TempDir()
368+
lg, _ := logger.CreateLoggerForTest(t)
369+
caPath1 := filepath.Join(tmpdir, "c1", "ca")
370+
keyPath1 := filepath.Join(tmpdir, "c1", "key")
371+
certPath1 := filepath.Join(tmpdir, "c1", "cert")
372+
caPath2 := filepath.Join(tmpdir, "c2", "ca")
373+
keyPath2 := filepath.Join(tmpdir, "c2", "key")
374+
certPath2 := filepath.Join(tmpdir, "c2", "cert")
375+
376+
require.NoError(t, CreateTLSCertificates(lg, certPath1, keyPath1, caPath1, 0, DefaultCertExpiration))
377+
require.NoError(t, CreateTLSCertificates(lg, certPath2, keyPath2, caPath2, 0, DefaultCertExpiration))
378+
379+
serverCfg := config.TLSConfig{
380+
Cert: certPath1,
381+
Key: keyPath1,
382+
}
383+
serverCert := NewCert(true)
384+
serverCert.cfg.Store(&serverCfg)
385+
serverTLS, err := serverCert.Reload(lg)
386+
require.NoError(t, err)
387+
388+
clientCfg := config.TLSConfig{
389+
CA: caPath2,
390+
}
391+
clientCert := NewCert(false)
392+
clientCert.cfg.Store(&clientCfg)
393+
clientTLS, err := clientCert.Reload(lg)
394+
require.NoError(t, err)
395+
// caPath2 fails to verify certPath1.
396+
clientErr, serverErr := connectWithTLS(clientTLS, serverTLS)
397+
require.Error(t, clientErr)
398+
require.Error(t, serverErr)
399+
400+
// Add both caPath1 and caPath2 to the cert pool and it succeeds to verify certPath1.
401+
err = loadCA(caPath1, clientCert.ca.Load())
402+
require.NoError(t, err)
403+
clientErr, serverErr = connectWithTLS(clientTLS, serverTLS)
404+
require.NoError(t, clientErr)
405+
require.NoError(t, serverErr)
406+
407+
// The cert pool can also verify certPath2.
408+
serverCfg = config.TLSConfig{
409+
Cert: certPath2,
410+
Key: keyPath2,
411+
}
412+
serverCert.cfg.Store(&serverCfg)
413+
serverTLS, err = serverCert.Reload(lg)
414+
require.NoError(t, err)
415+
clientErr, serverErr = connectWithTLS(clientTLS, serverTLS)
416+
require.NoError(t, clientErr)
417+
require.NoError(t, serverErr)
418+
}
419+
420+
func loadCA(caPath string, pool *x509.CertPool) error {
421+
pemCerts, err := os.ReadFile(caPath)
422+
if err != nil {
423+
return err
424+
}
425+
for len(pemCerts) > 0 {
426+
var block *pem.Block
427+
block, pemCerts = pem.Decode(pemCerts)
428+
if block == nil {
429+
break
430+
}
431+
if block.Type != "CERTIFICATE" || len(block.Headers) != 0 {
432+
continue
433+
}
434+
435+
certBytes := block.Bytes
436+
cert, err := x509.ParseCertificate(certBytes)
437+
if err != nil {
438+
continue
439+
}
440+
pool.AddCert(cert)
441+
}
442+
return nil
443+
}
444+
445+
func connectWithTLS(ctls, stls *tls.Config) (clientErr, serverErr error) {
446+
client, server := net.Pipe()
447+
var wg waitgroup.WaitGroup
448+
wg.Run(func() {
449+
tlsConn := tls.Client(client, ctls)
450+
clientErr = tlsConn.Handshake()
451+
_ = client.Close()
452+
})
453+
wg.Run(func() {
454+
tlsConn := tls.Server(server, stls)
455+
serverErr = tlsConn.Handshake()
456+
_ = server.Close()
457+
})
458+
wg.Wait()
459+
return
460+
}

0 commit comments

Comments
 (0)