@@ -6,12 +6,16 @@ package security
6
6
import (
7
7
"crypto/tls"
8
8
"crypto/x509"
9
+ "encoding/pem"
10
+ "net"
11
+ "os"
9
12
"path/filepath"
10
13
"testing"
11
14
"time"
12
15
13
16
"github.com/pingcap/tiproxy/lib/config"
14
17
"github.com/pingcap/tiproxy/lib/util/logger"
18
+ "github.com/pingcap/tiproxy/lib/util/waitgroup"
15
19
"github.com/stretchr/testify/require"
16
20
)
17
21
@@ -357,3 +361,100 @@ func TestSetConfig(t *testing.T) {
357
361
require .NoError (t , err )
358
362
require .Nil (t , tcfg )
359
363
}
364
+
365
+ // Test that a cert pool can store multiple CAs and every CA works.
366
+ func TestCertPool (t * testing.T ) {
367
+ tmpdir := t .TempDir ()
368
+ lg , _ := logger .CreateLoggerForTest (t )
369
+ caPath1 := filepath .Join (tmpdir , "c1" , "ca" )
370
+ keyPath1 := filepath .Join (tmpdir , "c1" , "key" )
371
+ certPath1 := filepath .Join (tmpdir , "c1" , "cert" )
372
+ caPath2 := filepath .Join (tmpdir , "c2" , "ca" )
373
+ keyPath2 := filepath .Join (tmpdir , "c2" , "key" )
374
+ certPath2 := filepath .Join (tmpdir , "c2" , "cert" )
375
+
376
+ require .NoError (t , CreateTLSCertificates (lg , certPath1 , keyPath1 , caPath1 , 0 , DefaultCertExpiration ))
377
+ require .NoError (t , CreateTLSCertificates (lg , certPath2 , keyPath2 , caPath2 , 0 , DefaultCertExpiration ))
378
+
379
+ serverCfg := config.TLSConfig {
380
+ Cert : certPath1 ,
381
+ Key : keyPath1 ,
382
+ }
383
+ serverCert := NewCert (true )
384
+ serverCert .cfg .Store (& serverCfg )
385
+ serverTLS , err := serverCert .Reload (lg )
386
+ require .NoError (t , err )
387
+
388
+ clientCfg := config.TLSConfig {
389
+ CA : caPath2 ,
390
+ }
391
+ clientCert := NewCert (false )
392
+ clientCert .cfg .Store (& clientCfg )
393
+ clientTLS , err := clientCert .Reload (lg )
394
+ require .NoError (t , err )
395
+ // caPath2 fails to verify certPath1.
396
+ clientErr , serverErr := connectWithTLS (clientTLS , serverTLS )
397
+ require .Error (t , clientErr )
398
+ require .Error (t , serverErr )
399
+
400
+ // Add both caPath1 and caPath2 to the cert pool and it succeeds to verify certPath1.
401
+ err = loadCA (caPath1 , clientCert .ca .Load ())
402
+ require .NoError (t , err )
403
+ clientErr , serverErr = connectWithTLS (clientTLS , serverTLS )
404
+ require .NoError (t , clientErr )
405
+ require .NoError (t , serverErr )
406
+
407
+ // The cert pool can also verify certPath2.
408
+ serverCfg = config.TLSConfig {
409
+ Cert : certPath2 ,
410
+ Key : keyPath2 ,
411
+ }
412
+ serverCert .cfg .Store (& serverCfg )
413
+ serverTLS , err = serverCert .Reload (lg )
414
+ require .NoError (t , err )
415
+ clientErr , serverErr = connectWithTLS (clientTLS , serverTLS )
416
+ require .NoError (t , clientErr )
417
+ require .NoError (t , serverErr )
418
+ }
419
+
420
+ func loadCA (caPath string , pool * x509.CertPool ) error {
421
+ pemCerts , err := os .ReadFile (caPath )
422
+ if err != nil {
423
+ return err
424
+ }
425
+ for len (pemCerts ) > 0 {
426
+ var block * pem.Block
427
+ block , pemCerts = pem .Decode (pemCerts )
428
+ if block == nil {
429
+ break
430
+ }
431
+ if block .Type != "CERTIFICATE" || len (block .Headers ) != 0 {
432
+ continue
433
+ }
434
+
435
+ certBytes := block .Bytes
436
+ cert , err := x509 .ParseCertificate (certBytes )
437
+ if err != nil {
438
+ continue
439
+ }
440
+ pool .AddCert (cert )
441
+ }
442
+ return nil
443
+ }
444
+
445
+ func connectWithTLS (ctls , stls * tls.Config ) (clientErr , serverErr error ) {
446
+ client , server := net .Pipe ()
447
+ var wg waitgroup.WaitGroup
448
+ wg .Run (func () {
449
+ tlsConn := tls .Client (client , ctls )
450
+ clientErr = tlsConn .Handshake ()
451
+ _ = client .Close ()
452
+ })
453
+ wg .Run (func () {
454
+ tlsConn := tls .Server (server , stls )
455
+ serverErr = tlsConn .Handshake ()
456
+ _ = server .Close ()
457
+ })
458
+ wg .Wait ()
459
+ return
460
+ }
0 commit comments