Skip to content

Commit

Permalink
Add tls.Config.ClientCurveGuess to allow specifying which keyshares t…
Browse files Browse the repository at this point in the history
…o send

RTG-2919
  • Loading branch information
bwesterb authored Oct 2, 2023
1 parent 6ba8a88 commit ef1765f
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 45 deletions.
23 changes: 6 additions & 17 deletions src/crypto/tls/cfkem.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ import (
"fmt"
"io"

"crypto/ecdh"

"github.com/cloudflare/circl/kem"
"github.com/cloudflare/circl/kem/hybrid"
)

// Either *ecdh.PrivateKey or *kemPrivateKey
type clientKeySharePrivate interface{}
type singleClientKeySharePrivate interface{}

type clientKeySharePrivate map[CurveID]singleClientKeySharePrivate

type kemPrivateKey struct {
secretKey kem.PrivateKey
Expand All @@ -44,20 +44,9 @@ var (
invalidCurveID = CurveID(0)
)

// Extract CurveID from clientKeySharePrivate
func clientKeySharePrivateCurveID(ks clientKeySharePrivate) CurveID {
switch v := ks.(type) {
case *kemPrivateKey:
return v.curveID
case *ecdh.PrivateKey:
ret, ok := curveIDForCurve(v.Curve())
if !ok {
panic("cfkem: internal error: unknown curve")
}
return ret
default:
panic("cfkem: internal error: unknown clientKeySharePrivate")
}
func singleClientKeySharePrivateFor(ks clientKeySharePrivate, group CurveID) singleClientKeySharePrivate {
ret, _ := ks[group]
return ret
}

// Returns scheme by CurveID if supported by Circl
Expand Down
45 changes: 45 additions & 0 deletions src/crypto/tls/cfkem_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,48 @@ func TestHybridKEX(t *testing.T) {
run(curveID, true, true, true, true)
}
}

func TestClientCurveGuess(t *testing.T) {
run := func(guess, clientPrefs, serverPrefs []CurveID) {
t.Run(
fmt.Sprintf("guess=%v clientPrefs=%v serverPrefs=%v",
guess, clientPrefs, serverPrefs),
func(t *testing.T) {
testClientCurveGuess(t, guess, clientPrefs, serverPrefs)
})
}
both := []CurveID{X25519Kyber768Draft00, X25519}
run([]CurveID{}, []CurveID{X25519}, both)
run([]CurveID{X25519}, []CurveID{X25519}, both)
run([]CurveID{X25519Kyber768Draft00}, both, []CurveID{X25519})
run(both, both, both)
run(both, both, []CurveID{X25519})
run(both, both, []CurveID{X25519Kyber768Draft00})
}

func testClientCurveGuess(t *testing.T, guess, clientPrefs, serverPrefs []CurveID) {
clientConfig := testConfig.Clone()
serverConfig := testConfig.Clone()
serverConfig.CurvePreferences = serverPrefs
clientConfig.CurvePreferences = clientPrefs
clientConfig.ClientCurveGuess = guess

c, s := localPipe(t)
done := make(chan error)
defer c.Close()

go func() {
defer s.Close()
done <- Server(s, serverConfig).Handshake()
}()

cli := Client(c, clientConfig)
clientErr := cli.HandshakeContext(context.Background())
serverErr := <-done
if clientErr != nil {
t.Errorf("client error: %v", clientErr)
}
if serverErr != nil {
t.Errorf("server error: %v", serverErr)
}
}
13 changes: 13 additions & 0 deletions src/crypto/tls/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,18 @@ type Config struct {
// which is currently TLS 1.3.
MaxVersion uint16

// ClientCurveGuess contains the "curves" for which the client will create
// a keyshare in the initial ClientHello for TLS 1.3. If the client
// guesses incorrectly, and the server does not support or does not
// prefer those keyshares, then the server will return a HelloRetryRequest
// incurring an extra roundtrip.
//
// If empty, no keyshares will be included in the ClientHello.
//
// If nil (default), will send the single most preferred keyshare
// as configurable via CurvePreferences.
ClientCurveGuess []CurveID

// CurvePreferences contains the elliptic curves that will be used in
// an ECDHE handshake, in preference order. If empty, the default will
// be used. The client will use the first preference as the type for
Expand Down Expand Up @@ -974,6 +986,7 @@ func (c *Config) Clone() *Config {
MinVersion: c.MinVersion,
MaxVersion: c.MaxVersion,
CurvePreferences: c.CurvePreferences,
ClientCurveGuess: c.ClientCurveGuess,
PQSignatureSchemesEnabled: c.PQSignatureSchemesEnabled,
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
Renegotiation: c.Renegotiation,
Expand Down
73 changes: 52 additions & 21 deletions src/crypto/tls/handshake_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func (c *Conn) makeClientHello(minVersion uint16) (*clientHelloMsg, clientKeySha
hello.supportedSignatureAlgorithms = testingOnlyForceClientHelloSignatureAlgorithms
}

var secret clientKeySharePrivate
secret := make(clientKeySharePrivate)
if hello.supportedVersions[0] == VersionTLS13 {
// Reset the list of ciphers when the client only supports TLS 1.3.
if len(hello.supportedVersions) == 1 {
Expand All @@ -146,30 +146,61 @@ func (c *Conn) makeClientHello(minVersion uint16) (*clientHelloMsg, clientKeySha
hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13NoAES...)
}

curveID := config.curvePreferences()[0]
if scheme := curveIdToCirclScheme(curveID); scheme != nil {
pk, sk, err := generateKemKeyPair(scheme, curveID, config.rand())
if err != nil {
return nil, nil, fmt.Errorf("generateKemKeyPair %s: %w",
scheme.Name(), err)
curveIDs := []CurveID{config.curvePreferences()[0]}

if config.ClientCurveGuess != nil {
curveIDs = config.ClientCurveGuess
}

hello.keyShares = make([]keyShare, 0, len(curveIDs))

for _, curveID := range curveIDs {
var (
singleSecret interface{}
singleShare []byte
)

if _, ok := secret[curveID]; ok {
return nil, nil, errors.New("tls: ClientCurveGuess contains duplicate")
}
packedPk, err := pk.MarshalBinary()
if err != nil {
return nil, nil, fmt.Errorf("pack circl public key %s: %w",
scheme.Name(), err)

ok := false
for _, curveID2 := range config.curvePreferences() {
if curveID2 == curveID {
ok = true
break
}
}
hello.keyShares = []keyShare{{group: curveID, data: packedPk}}
secret = sk
} else {
if _, ok := curveForCurveID(curveID); !ok {
return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve")
if !ok {
return nil, nil, errors.New("tls: ClientCurveGuess contains curve not in CurvePreferences")
}
key, err := generateECDHEKey(config.rand(), curveID)
if err != nil {
return nil, nil, err

if scheme := curveIdToCirclScheme(curveID); scheme != nil {
pk, sk, err := generateKemKeyPair(scheme, curveID, config.rand())
if err != nil {
return nil, nil, fmt.Errorf("generateKemKeyPair %s: %w",
scheme.Name(), err)
}
packedPk, err := pk.MarshalBinary()
if err != nil {
return nil, nil, fmt.Errorf("pack circl public key %s: %w",
scheme.Name(), err)
}
singleShare = packedPk
singleSecret = sk
} else {
if _, ok := curveForCurveID(curveID); !ok {
return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve")
}
key, err := generateECDHEKey(config.rand(), curveID)
if err != nil {
return nil, nil, err
}
singleShare = key.PublicKey().Bytes()
singleSecret = key
}
hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}}
secret = key
hello.keyShares = append(hello.keyShares, keyShare{group: curveID, data: singleShare})
secret[curveID] = singleSecret
}

hello.delegatedCredentialSupported = config.SupportDelegatedCredential
Expand Down
18 changes: 11 additions & 7 deletions src/crypto/tls/handshake_client_tls13.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
}

// Consistency check on the presence of a keyShare and its parameters.
if hs.keySharePrivate == nil || len(hs.hello.keyShares) != 1 {
if hs.keySharePrivate == nil {
return c.sendAlert(alertInternalError)
}

Expand Down Expand Up @@ -379,7 +379,7 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected unsupported group")
}
if clientKeySharePrivateCurveID(hs.keySharePrivate) == curveID {
if singleClientKeySharePrivateFor(hs.keySharePrivate, curveID) != nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share")
}
Expand All @@ -396,7 +396,7 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
return fmt.Errorf("HRR pack circl public key %s: %w",
scheme.Name(), err)
}
hs.keySharePrivate = sk
hs.keySharePrivate = clientKeySharePrivate{curveID: sk}
hello.keyShares = []keyShare{{group: curveID, data: packedPk}}
} else {
if _, ok := curveForCurveID(curveID); !ok {
Expand All @@ -408,7 +408,7 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
c.sendAlert(alertInternalError)
return err
}
hs.keySharePrivate = key
hs.keySharePrivate = clientKeySharePrivate{curveID: key}
hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}}
}
}
Expand Down Expand Up @@ -558,7 +558,7 @@ func (hs *clientHandshakeStateTLS13) processServerHello() error {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server did not send a key share")
}
if hs.serverHello.serverShare.group != clientKeySharePrivateCurveID(hs.keySharePrivate) {
if singleClientKeySharePrivateFor(hs.keySharePrivate, hs.serverHello.serverShare.group) == nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: server selected unsupported group")
}
Expand Down Expand Up @@ -613,12 +613,16 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error {

var sharedKey []byte
var err error
if key, ok := hs.keySharePrivate.(*ecdh.PrivateKey); ok {

// We already checked that ks isn't nil in processServerHello()
ks := singleClientKeySharePrivateFor(hs.keySharePrivate, hs.serverHello.serverShare.group)

if key, ok := ks.(*ecdh.PrivateKey); ok {
peerKey, err := key.Curve().NewPublicKey(hs.serverHello.serverShare.data)
if err == nil {
sharedKey, _ = key.ECDH(peerKey)
}
} else if key, ok := hs.keySharePrivate.(*kemPrivateKey); ok {
} else if key, ok := ks.(*kemPrivateKey); ok {
sk := key.secretKey
sharedKey, err = sk.Scheme().Decapsulate(sk, hs.serverHello.serverShare.data)
if err != nil {
Expand Down
2 changes: 2 additions & 0 deletions src/crypto/tls/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,8 @@ func TestCloneNonFuncFields(t *testing.T) {
f.Set(reflect.ValueOf([]uint16{1, 2}))
case "CurvePreferences":
f.Set(reflect.ValueOf([]CurveID{CurveP256}))
case "ClientCurveGuess":
f.Set(reflect.ValueOf([]CurveID{CurveP256}))
case "PQSignatureSchemesEnabled":
f.Set(reflect.ValueOf(true))
case "Renegotiation":
Expand Down

0 comments on commit ef1765f

Please sign in to comment.