Skip to content

Commit

Permalink
Add hybrid post-quantum key agreement.
Browse files Browse the repository at this point in the history
Adds X25519Kyber512Draft00 and X25519Kyber768Draft00 hybrid post-quantum key
agreements with temporary group identifiers.

The hybrid post-quantum key exchanges uses plain X{25519,448} instead
of HPKE, which we assume will be more likely to be adopted. The order
is chosen to match CECPQ2.

Not enabled by default.

Adds CFEvents to detect `HelloRetryRequest`s and to signal which
key agreement was used.

Cf #121 #122 #123 #132

Co-authored-by: Christopher Wood <[email protected]>

[ bas, 1.20.1: also adds P256Kyber768Draft00 ]
  • Loading branch information
bwesterb committed Mar 1, 2023
1 parent 4b880d3 commit 837598b
Show file tree
Hide file tree
Showing 59 changed files with 11,098 additions and 65 deletions.
113 changes: 113 additions & 0 deletions src/crypto/tls/cfkem.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright 2022 Cloudflare, Inc. All rights reserved. Use of this source code
// is governed by a BSD-style license that can be found in the LICENSE file.
//
// Glue to add Circl's (post-quantum) hybrid KEMs.
//
// To enable set CurvePreferences with the desired scheme as the first element:
//
// import (
// "github.com/cloudflare/circl/kem/tls"
// "github.com/cloudflare/circl/kem/hybrid"
//
// [...]
//
// config.CurvePreferences = []tls.CurveID{
// hybrid.X25519Kyber512Draft00().(tls.TLSScheme).TLSCurveID(),
// tls.X25519,
// tls.P256,
// }

package tls

import (
"fmt"
"io"

"crypto/ecdh"

"github.com/cloudflare/circl/kem"
"github.com/cloudflare/circl/kem/hybrid"
)

// Either ecdheParameters or kem.PrivateKey
type clientKeySharePrivate interface{}

var (
X25519Kyber512Draft00 = CurveID(0xfe30)
X25519Kyber768Draft00 = CurveID(0xfe31)
P256Kyber768Draft00 = CurveID(0xfe32)
invalidCurveID = CurveID(0)
)

func kemSchemeKeyToCurveID(s kem.Scheme) CurveID {
switch s.Name() {
case "Kyber512-X25519":
return X25519Kyber512Draft00
case "Kyber768-X25519":
return X25519Kyber768Draft00
case "P256Kyber768Draft00":
return P256Kyber768Draft00
default:
return invalidCurveID
}
}

// Extract CurveID from clientKeySharePrivate
func clientKeySharePrivateCurveID(ks clientKeySharePrivate) CurveID {
switch v := ks.(type) {
case kem.PrivateKey:
ret := kemSchemeKeyToCurveID(v.Scheme())
if ret == invalidCurveID {
panic("cfkem: internal error: don't know CurveID for this KEM")
}
return ret
case *ecdh.PrivateKey:
ret, ok := curveIDForCurve(v.Curve())
if !ok {
panic("cfkem: internal error: unknown curve")
}
return ret
default:
panic("cfkem: internal error: unknown clientKeySharePrivate")
}
}

// Returns scheme by CurveID if supported by Circl
func curveIdToCirclScheme(id CurveID) kem.Scheme {
switch id {
case X25519Kyber512Draft00:
return hybrid.Kyber512X25519()
case X25519Kyber768Draft00:
return hybrid.Kyber768X25519()
case P256Kyber768Draft00:
return hybrid.P256Kyber768Draft00()
}
return nil
}

// Generate a new shared secret and encapsulates it for the packed
// public key in ppk using randomness from rnd.
func encapsulateForKem(scheme kem.Scheme, rnd io.Reader, ppk []byte) (
ct, ss []byte, alert alert, err error) {
pk, err := scheme.UnmarshalBinaryPublicKey(ppk)
if err != nil {
return nil, nil, alertIllegalParameter, fmt.Errorf("unpack pk: %w", err)
}
seed := make([]byte, scheme.EncapsulationSeedSize())
if _, err := io.ReadFull(rnd, seed); err != nil {
return nil, nil, alertInternalError, fmt.Errorf("random: %w", err)
}
ct, ss, err = scheme.EncapsulateDeterministically(pk, seed)
return ct, ss, alertIllegalParameter, err
}

// Generate a new keypair using randomness from rnd.
func generateKemKeyPair(scheme kem.Scheme, rnd io.Reader) (
kem.PublicKey, kem.PrivateKey, error) {
seed := make([]byte, scheme.SeedSize())
if _, err := io.ReadFull(rnd, seed); err != nil {
return nil, nil, err
}
pk, sk := scheme.DeriveKeyPair(seed)
return pk, sk, nil
}
119 changes: 119 additions & 0 deletions src/crypto/tls/cfkem_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// Copyright 2022 Cloudflare, Inc. All rights reserved. Use of this source code
// is governed by a BSD-style license that can be found in the LICENSE file.

package tls

import (
"fmt"
"testing"

"github.com/cloudflare/circl/kem"
"github.com/cloudflare/circl/kem/hybrid"
)

func testHybridKEX(t *testing.T, scheme kem.Scheme, clientPQ, serverPQ,
clientTLS12, serverTLS12 bool) {
var clientSelectedKEX *CurveID
var retry bool

rsaCert := Certificate{
Certificate: [][]byte{testRSACertificate},
PrivateKey: testRSAPrivateKey,
}
serverCerts := []Certificate{rsaCert}

clientConfig := testConfig.Clone()
if clientPQ {
clientConfig.CurvePreferences = []CurveID{
kemSchemeKeyToCurveID(scheme),
X25519,
}
}
clientConfig.CFEventHandler = func(ev CFEvent) {
switch e := ev.(type) {
case CFEventTLSNegotiatedNamedKEX:
clientSelectedKEX = &e.KEX
case CFEventTLS13HRR:
retry = true
}
}
if clientTLS12 {
clientConfig.MaxVersion = VersionTLS12
}

serverConfig := testConfig.Clone()
if serverPQ {
serverConfig.CurvePreferences = []CurveID{
kemSchemeKeyToCurveID(scheme),
X25519,
}
}
if serverTLS12 {
serverConfig.MaxVersion = VersionTLS12
}
serverConfig.Certificates = serverCerts

c, s := localPipe(t)
done := make(chan error)
defer c.Close()

go func() {
defer s.Close()
done <- Server(s, serverConfig).Handshake()
}()

cli := Client(c, clientConfig)
clientErr := cli.Handshake()
serverErr := <-done
if clientErr != nil {
t.Errorf("client error: %s", clientErr)
}
if serverErr != nil {
t.Errorf("server error: %s", serverErr)
}

var expectedKEX CurveID
var expectedRetry bool

if clientPQ && serverPQ && !clientTLS12 && !serverTLS12 {
expectedKEX = kemSchemeKeyToCurveID(scheme)
} else {
expectedKEX = X25519
}
if !clientTLS12 && clientPQ && !serverPQ {
expectedRetry = true
}

if clientSelectedKEX == nil {
t.Error("No KEX happened?")
}

if *clientSelectedKEX != expectedKEX {
t.Errorf("failed to negotiate: expected %d, got %d",
expectedKEX, *clientSelectedKEX)
}
if expectedRetry != retry {
t.Errorf("Expected retry=%v, got retry=%v", expectedRetry, retry)
}
}

func TestHybridKEX(t *testing.T) {
run := func(scheme kem.Scheme, clientPQ, serverPQ, clientTLS12, serverTLS12 bool) {
t.Run(fmt.Sprintf("%s serverPQ:%v clientPQ:%v serverTLS12:%v clientTLS12:%v", scheme.Name(),
serverPQ, clientPQ, serverTLS12, clientTLS12), func(t *testing.T) {
testHybridKEX(t, scheme, clientPQ, serverPQ, clientTLS12, serverTLS12)
})
}
for _, scheme := range []kem.Scheme{
hybrid.Kyber512X25519(),
hybrid.Kyber768X25519(),
hybrid.P256Kyber768Draft00(),
} {
run(scheme, true, true, false, false)
run(scheme, true, false, false, false)
run(scheme, false, true, false, false)
run(scheme, true, true, true, false)
run(scheme, true, true, false, true)
run(scheme, true, true, true, true)
}
}
68 changes: 47 additions & 21 deletions src/crypto/tls/handshake_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"bytes"
"context"
"crypto"
"crypto/ecdh"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
Expand Down Expand Up @@ -38,7 +37,7 @@ type clientHandshakeState struct {

var testingOnlyForceClientHelloSignatureAlgorithms []SignatureScheme

func (c *Conn) makeClientHello(minVersion uint16) (*clientHelloMsg, *ecdh.PrivateKey, error) {
func (c *Conn) makeClientHello(minVersion uint16) (*clientHelloMsg, clientKeySharePrivate, error) {
config := c.config
if len(config.ServerName) == 0 && !config.InsecureSkipVerify {
return nil, nil, errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config")
Expand Down Expand Up @@ -127,7 +126,7 @@ func (c *Conn) makeClientHello(minVersion uint16) (*clientHelloMsg, *ecdh.Privat
hello.supportedSignatureAlgorithms = testingOnlyForceClientHelloSignatureAlgorithms
}

var key *ecdh.PrivateKey
var secret clientKeySharePrivate
if hello.supportedVersions[0] == VersionTLS13 {
if hasAESGCMHardwareSupport {
hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13...)
Expand All @@ -136,19 +135,36 @@ func (c *Conn) makeClientHello(minVersion uint16) (*clientHelloMsg, *ecdh.Privat
}

curveID := config.curvePreferences()[0]
if _, ok := curveForCurveID(curveID); !ok {
return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve")
}
key, err = generateECDHEKey(config.rand(), curveID)
if err != nil {
return nil, nil, err
if scheme := curveIdToCirclScheme(curveID); scheme != nil {
pk, sk, err := generateKemKeyPair(scheme, config.rand())
if err != nil {
return nil, nil, fmt.Errorf("generateKemKeyPair %s: %w",
scheme.Name(), err)
}
packedPk, err := pk.MarshalBinary()
if err != nil {
return nil, nil, fmt.Errorf("pack circl public key %s: %w",
scheme.Name(), err)
}
hello.keyShares = []keyShare{{group: curveID, data: packedPk}}
secret = sk
} else {
if _, ok := curveForCurveID(curveID); !ok {
return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve")
}
key, err := generateECDHEKey(config.rand(), curveID)
if err != nil {
return nil, nil, err
}
hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}}
secret = key
}
hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}}

hello.delegatedCredentialSupported = config.SupportDelegatedCredential
hello.supportedSignatureAlgorithmsDC = supportedSignatureAlgorithmsDC
}

return hello, key, nil
return hello, secret, nil
}

func (c *Conn) clientHandshake(ctx context.Context) (err error) {
Expand Down Expand Up @@ -239,16 +255,16 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {

if c.vers == VersionTLS13 {
hs := &clientHandshakeStateTLS13{
c: c,
ctx: ctx,
serverHello: serverHello,
hello: hello,
ecdheKey: ecdheKey,
helloInner: helloInner,
session: session,
earlySecret: earlySecret,
binderKey: binderKey,
hsTimings: hsTimings,
c: c,
ctx: ctx,
serverHello: serverHello,
hello: hello,
helloInner: helloInner,
keySharePrivate: ecdheKey,
session: session,
earlySecret: earlySecret,
binderKey: binderKey,
hsTimings: hsTimings,
}

// In TLS 1.3, session tickets are delivered after the handshake.
Expand Down Expand Up @@ -581,6 +597,16 @@ func (hs *clientHandshakeState) doFullHandshake() error {
return err
}

if eccKex, ok := keyAgreement.(*ecdheKeyAgreement); ok {
curveId, ok := curveIDForCurve(eccKex.key.Curve())
if !ok {
panic("internal error: unknown curve")
}
c.handleCFEvent(CFEventTLSNegotiatedNamedKEX{
KEX: curveId,
})
}

msg, err = c.readHandshake(&hs.finishedHash)
if err != nil {
return err
Expand Down
Loading

0 comments on commit 837598b

Please sign in to comment.