Skip to content

Commit ef1765f

Browse files
authored
Add tls.Config.ClientCurveGuess to allow specifying which keyshares to send
RTG-2919
1 parent 6ba8a88 commit ef1765f

File tree

6 files changed

+129
-45
lines changed

6 files changed

+129
-45
lines changed

src/crypto/tls/cfkem.go

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@ import (
2222
"fmt"
2323
"io"
2424

25-
"crypto/ecdh"
26-
2725
"github.com/cloudflare/circl/kem"
2826
"github.com/cloudflare/circl/kem/hybrid"
2927
)
3028

3129
// Either *ecdh.PrivateKey or *kemPrivateKey
32-
type clientKeySharePrivate interface{}
30+
type singleClientKeySharePrivate interface{}
31+
32+
type clientKeySharePrivate map[CurveID]singleClientKeySharePrivate
3333

3434
type kemPrivateKey struct {
3535
secretKey kem.PrivateKey
@@ -44,20 +44,9 @@ var (
4444
invalidCurveID = CurveID(0)
4545
)
4646

47-
// Extract CurveID from clientKeySharePrivate
48-
func clientKeySharePrivateCurveID(ks clientKeySharePrivate) CurveID {
49-
switch v := ks.(type) {
50-
case *kemPrivateKey:
51-
return v.curveID
52-
case *ecdh.PrivateKey:
53-
ret, ok := curveIDForCurve(v.Curve())
54-
if !ok {
55-
panic("cfkem: internal error: unknown curve")
56-
}
57-
return ret
58-
default:
59-
panic("cfkem: internal error: unknown clientKeySharePrivate")
60-
}
47+
func singleClientKeySharePrivateFor(ks clientKeySharePrivate, group CurveID) singleClientKeySharePrivate {
48+
ret, _ := ks[group]
49+
return ret
6150
}
6251

6352
// Returns scheme by CurveID if supported by Circl

src/crypto/tls/cfkem_test.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,48 @@ func TestHybridKEX(t *testing.T) {
104104
run(curveID, true, true, true, true)
105105
}
106106
}
107+
108+
func TestClientCurveGuess(t *testing.T) {
109+
run := func(guess, clientPrefs, serverPrefs []CurveID) {
110+
t.Run(
111+
fmt.Sprintf("guess=%v clientPrefs=%v serverPrefs=%v",
112+
guess, clientPrefs, serverPrefs),
113+
func(t *testing.T) {
114+
testClientCurveGuess(t, guess, clientPrefs, serverPrefs)
115+
})
116+
}
117+
both := []CurveID{X25519Kyber768Draft00, X25519}
118+
run([]CurveID{}, []CurveID{X25519}, both)
119+
run([]CurveID{X25519}, []CurveID{X25519}, both)
120+
run([]CurveID{X25519Kyber768Draft00}, both, []CurveID{X25519})
121+
run(both, both, both)
122+
run(both, both, []CurveID{X25519})
123+
run(both, both, []CurveID{X25519Kyber768Draft00})
124+
}
125+
126+
func testClientCurveGuess(t *testing.T, guess, clientPrefs, serverPrefs []CurveID) {
127+
clientConfig := testConfig.Clone()
128+
serverConfig := testConfig.Clone()
129+
serverConfig.CurvePreferences = serverPrefs
130+
clientConfig.CurvePreferences = clientPrefs
131+
clientConfig.ClientCurveGuess = guess
132+
133+
c, s := localPipe(t)
134+
done := make(chan error)
135+
defer c.Close()
136+
137+
go func() {
138+
defer s.Close()
139+
done <- Server(s, serverConfig).Handshake()
140+
}()
141+
142+
cli := Client(c, clientConfig)
143+
clientErr := cli.HandshakeContext(context.Background())
144+
serverErr := <-done
145+
if clientErr != nil {
146+
t.Errorf("client error: %v", clientErr)
147+
}
148+
if serverErr != nil {
149+
t.Errorf("server error: %v", serverErr)
150+
}
151+
}

src/crypto/tls/common.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -837,6 +837,18 @@ type Config struct {
837837
// which is currently TLS 1.3.
838838
MaxVersion uint16
839839

840+
// ClientCurveGuess contains the "curves" for which the client will create
841+
// a keyshare in the initial ClientHello for TLS 1.3. If the client
842+
// guesses incorrectly, and the server does not support or does not
843+
// prefer those keyshares, then the server will return a HelloRetryRequest
844+
// incurring an extra roundtrip.
845+
//
846+
// If empty, no keyshares will be included in the ClientHello.
847+
//
848+
// If nil (default), will send the single most preferred keyshare
849+
// as configurable via CurvePreferences.
850+
ClientCurveGuess []CurveID
851+
840852
// CurvePreferences contains the elliptic curves that will be used in
841853
// an ECDHE handshake, in preference order. If empty, the default will
842854
// be used. The client will use the first preference as the type for
@@ -974,6 +986,7 @@ func (c *Config) Clone() *Config {
974986
MinVersion: c.MinVersion,
975987
MaxVersion: c.MaxVersion,
976988
CurvePreferences: c.CurvePreferences,
989+
ClientCurveGuess: c.ClientCurveGuess,
977990
PQSignatureSchemesEnabled: c.PQSignatureSchemesEnabled,
978991
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
979992
Renegotiation: c.Renegotiation,

src/crypto/tls/handshake_client.go

Lines changed: 52 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ func (c *Conn) makeClientHello(minVersion uint16) (*clientHelloMsg, clientKeySha
134134
hello.supportedSignatureAlgorithms = testingOnlyForceClientHelloSignatureAlgorithms
135135
}
136136

137-
var secret clientKeySharePrivate
137+
secret := make(clientKeySharePrivate)
138138
if hello.supportedVersions[0] == VersionTLS13 {
139139
// Reset the list of ciphers when the client only supports TLS 1.3.
140140
if len(hello.supportedVersions) == 1 {
@@ -146,30 +146,61 @@ func (c *Conn) makeClientHello(minVersion uint16) (*clientHelloMsg, clientKeySha
146146
hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13NoAES...)
147147
}
148148

149-
curveID := config.curvePreferences()[0]
150-
if scheme := curveIdToCirclScheme(curveID); scheme != nil {
151-
pk, sk, err := generateKemKeyPair(scheme, curveID, config.rand())
152-
if err != nil {
153-
return nil, nil, fmt.Errorf("generateKemKeyPair %s: %w",
154-
scheme.Name(), err)
149+
curveIDs := []CurveID{config.curvePreferences()[0]}
150+
151+
if config.ClientCurveGuess != nil {
152+
curveIDs = config.ClientCurveGuess
153+
}
154+
155+
hello.keyShares = make([]keyShare, 0, len(curveIDs))
156+
157+
for _, curveID := range curveIDs {
158+
var (
159+
singleSecret interface{}
160+
singleShare []byte
161+
)
162+
163+
if _, ok := secret[curveID]; ok {
164+
return nil, nil, errors.New("tls: ClientCurveGuess contains duplicate")
155165
}
156-
packedPk, err := pk.MarshalBinary()
157-
if err != nil {
158-
return nil, nil, fmt.Errorf("pack circl public key %s: %w",
159-
scheme.Name(), err)
166+
167+
ok := false
168+
for _, curveID2 := range config.curvePreferences() {
169+
if curveID2 == curveID {
170+
ok = true
171+
break
172+
}
160173
}
161-
hello.keyShares = []keyShare{{group: curveID, data: packedPk}}
162-
secret = sk
163-
} else {
164-
if _, ok := curveForCurveID(curveID); !ok {
165-
return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve")
174+
if !ok {
175+
return nil, nil, errors.New("tls: ClientCurveGuess contains curve not in CurvePreferences")
166176
}
167-
key, err := generateECDHEKey(config.rand(), curveID)
168-
if err != nil {
169-
return nil, nil, err
177+
178+
if scheme := curveIdToCirclScheme(curveID); scheme != nil {
179+
pk, sk, err := generateKemKeyPair(scheme, curveID, config.rand())
180+
if err != nil {
181+
return nil, nil, fmt.Errorf("generateKemKeyPair %s: %w",
182+
scheme.Name(), err)
183+
}
184+
packedPk, err := pk.MarshalBinary()
185+
if err != nil {
186+
return nil, nil, fmt.Errorf("pack circl public key %s: %w",
187+
scheme.Name(), err)
188+
}
189+
singleShare = packedPk
190+
singleSecret = sk
191+
} else {
192+
if _, ok := curveForCurveID(curveID); !ok {
193+
return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve")
194+
}
195+
key, err := generateECDHEKey(config.rand(), curveID)
196+
if err != nil {
197+
return nil, nil, err
198+
}
199+
singleShare = key.PublicKey().Bytes()
200+
singleSecret = key
170201
}
171-
hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}}
172-
secret = key
202+
hello.keyShares = append(hello.keyShares, keyShare{group: curveID, data: singleShare})
203+
secret[curveID] = singleSecret
173204
}
174205

175206
hello.delegatedCredentialSupported = config.SupportDelegatedCredential

src/crypto/tls/handshake_client_tls13.go

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
103103
}
104104

105105
// Consistency check on the presence of a keyShare and its parameters.
106-
if hs.keySharePrivate == nil || len(hs.hello.keyShares) != 1 {
106+
if hs.keySharePrivate == nil {
107107
return c.sendAlert(alertInternalError)
108108
}
109109

@@ -379,7 +379,7 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
379379
c.sendAlert(alertIllegalParameter)
380380
return errors.New("tls: server selected unsupported group")
381381
}
382-
if clientKeySharePrivateCurveID(hs.keySharePrivate) == curveID {
382+
if singleClientKeySharePrivateFor(hs.keySharePrivate, curveID) != nil {
383383
c.sendAlert(alertIllegalParameter)
384384
return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share")
385385
}
@@ -396,7 +396,7 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
396396
return fmt.Errorf("HRR pack circl public key %s: %w",
397397
scheme.Name(), err)
398398
}
399-
hs.keySharePrivate = sk
399+
hs.keySharePrivate = clientKeySharePrivate{curveID: sk}
400400
hello.keyShares = []keyShare{{group: curveID, data: packedPk}}
401401
} else {
402402
if _, ok := curveForCurveID(curveID); !ok {
@@ -408,7 +408,7 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
408408
c.sendAlert(alertInternalError)
409409
return err
410410
}
411-
hs.keySharePrivate = key
411+
hs.keySharePrivate = clientKeySharePrivate{curveID: key}
412412
hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}}
413413
}
414414
}
@@ -558,7 +558,7 @@ func (hs *clientHandshakeStateTLS13) processServerHello() error {
558558
c.sendAlert(alertIllegalParameter)
559559
return errors.New("tls: server did not send a key share")
560560
}
561-
if hs.serverHello.serverShare.group != clientKeySharePrivateCurveID(hs.keySharePrivate) {
561+
if singleClientKeySharePrivateFor(hs.keySharePrivate, hs.serverHello.serverShare.group) == nil {
562562
c.sendAlert(alertIllegalParameter)
563563
return errors.New("tls: server selected unsupported group")
564564
}
@@ -613,12 +613,16 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error {
613613

614614
var sharedKey []byte
615615
var err error
616-
if key, ok := hs.keySharePrivate.(*ecdh.PrivateKey); ok {
616+
617+
// We already checked that ks isn't nil in processServerHello()
618+
ks := singleClientKeySharePrivateFor(hs.keySharePrivate, hs.serverHello.serverShare.group)
619+
620+
if key, ok := ks.(*ecdh.PrivateKey); ok {
617621
peerKey, err := key.Curve().NewPublicKey(hs.serverHello.serverShare.data)
618622
if err == nil {
619623
sharedKey, _ = key.ECDH(peerKey)
620624
}
621-
} else if key, ok := hs.keySharePrivate.(*kemPrivateKey); ok {
625+
} else if key, ok := ks.(*kemPrivateKey); ok {
622626
sk := key.secretKey
623627
sharedKey, err = sk.Scheme().Decapsulate(sk, hs.serverHello.serverShare.data)
624628
if err != nil {

src/crypto/tls/tls_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -865,6 +865,8 @@ func TestCloneNonFuncFields(t *testing.T) {
865865
f.Set(reflect.ValueOf([]uint16{1, 2}))
866866
case "CurvePreferences":
867867
f.Set(reflect.ValueOf([]CurveID{CurveP256}))
868+
case "ClientCurveGuess":
869+
f.Set(reflect.ValueOf([]CurveID{CurveP256}))
868870
case "PQSignatureSchemesEnabled":
869871
f.Set(reflect.ValueOf(true))
870872
case "Renegotiation":

0 commit comments

Comments
 (0)