Skip to content

Commit

Permalink
remove the tls.Config from the quic.Config
Browse files Browse the repository at this point in the history
The tls.Config now is a separate parameter to all Listen and Dial
functions in the quic package.
  • Loading branch information
marten-seemann committed Jul 3, 2017
1 parent 890b801 commit a851aaa
Show file tree
Hide file tree
Showing 14 changed files with 126 additions and 86 deletions.
1 change: 1 addition & 0 deletions Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- Add a `quic.Config` option to request truncation of the connection ID from a server
- Add a `quic.Config` option to configure the source address validation
- Add a `quic.Config` option to configure the handshake timeout
- Remove the `tls.Config` from the `quic.Config`. The `tls.Config` must now be passed to the `Dial` and `Listen` functions as a separate parameter. See the [Godoc](https://godoc.org/github.com/lucas-clemente/quic-go) for details.
- Changed the log level environment variable to only accept strings ("DEBUG", "INFO", "ERROR"), see [the wiki](https://github.com/lucas-clemente/quic-go/wiki/Logging) for more details.
- Rename the `h2quic.QuicRoundTripper` to `h2quic.RoundTripper`
- Various bugfixes
7 changes: 2 additions & 5 deletions benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ var _ = Describe("Benchmarks", func() {
go func() {
defer GinkgoRecover()
var err error
ln, err = ListenAddr("localhost:0", &Config{TLSConfig: testdata.GetTLSConfig()})
ln, err = ListenAddr("localhost:0", testdata.GetTLSConfig(), nil)
Expect(err).ToNot(HaveOccurred())
serverAddr <- ln.Addr()
sess, err := ln.Accept()
Expand All @@ -49,11 +49,8 @@ var _ = Describe("Benchmarks", func() {
}()

// start the client
conf := &Config{
TLSConfig: &tls.Config{InsecureSkipVerify: true},
}
addr := <-serverAddr
sess, err := DialAddr(addr.String(), conf)
sess, err := DialAddr(addr.String(), &tls.Config{InsecureSkipVerify: true}, nil)
Expect(err).ToNot(HaveOccurred())
str, err := sess.AcceptStream()
Expect(err).ToNot(HaveOccurred())
Expand Down
44 changes: 34 additions & 10 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package quic

import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"net"
Expand All @@ -24,6 +25,7 @@ type client struct {
errorChan chan struct{}
handshakeChan <-chan handshakeEvent

tlsConf *tls.Config
config *Config
versionNegotiated bool // has version negotiation completed yet

Expand All @@ -39,7 +41,7 @@ var (

// DialAddr establishes a new QUIC connection to a server.
// The hostname for SNI is taken from the given address.
func DialAddr(addr string, config *Config) (Session, error) {
func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
Expand All @@ -48,12 +50,16 @@ func DialAddr(addr string, config *Config) (Session, error) {
if err != nil {
return nil, err
}
return Dial(udpConn, udpAddr, addr, config)
return Dial(udpConn, udpAddr, addr, tlsConf, config)
}

// DialAddrNonFWSecure establishes a new QUIC connection to a server.
// The hostname for SNI is taken from the given address.
func DialAddrNonFWSecure(addr string, config *Config) (NonFWSession, error) {
func DialAddrNonFWSecure(
addr string,
tlsConf *tls.Config,
config *Config,
) (NonFWSession, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
Expand All @@ -62,20 +68,26 @@ func DialAddrNonFWSecure(addr string, config *Config) (NonFWSession, error) {
if err != nil {
return nil, err
}
return DialNonFWSecure(udpConn, udpAddr, addr, config)
return DialNonFWSecure(udpConn, udpAddr, addr, tlsConf, config)
}

// DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn.
// The host parameter is used for SNI.
func DialNonFWSecure(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config) (NonFWSession, error) {
func DialNonFWSecure(
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (NonFWSession, error) {
connID, err := utils.GenerateConnectionID()
if err != nil {
return nil, err
}

var hostname string
if config.TLSConfig != nil {
hostname = config.TLSConfig.ServerName
if tlsConf != nil {
hostname = tlsConf.ServerName
}

if hostname == "" {
Expand All @@ -90,6 +102,7 @@ func DialNonFWSecure(pconn net.PacketConn, remoteAddr net.Addr, host string, con
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
connectionID: connID,
hostname: hostname,
tlsConf: tlsConf,
config: clientConfig,
version: clientConfig.Versions[0],
errorChan: make(chan struct{}),
Expand All @@ -107,8 +120,14 @@ func DialNonFWSecure(pconn net.PacketConn, remoteAddr net.Addr, host string, con

// Dial establishes a new QUIC connection to a server using a net.PacketConn.
// The host parameter is used for SNI.
func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config) (Session, error) {
sess, err := DialNonFWSecure(pconn, remoteAddr, host, config)
func Dial(
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (Session, error) {
sess, err := DialNonFWSecure(pconn, remoteAddr, host, tlsConf, config)
if err != nil {
return nil, err
}
Expand All @@ -119,7 +138,12 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config
return sess, nil
}

// populateClientConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateClientConfig(config *Config) *Config {
if config == nil {
config = &Config{}
}
versions := config.Versions
if len(versions) == 0 {
versions = protocol.SupportedVersions
Expand All @@ -140,7 +164,6 @@ func populateClientConfig(config *Config) *Config {
}

return &Config{
TLSConfig: config.TLSConfig,
Versions: versions,
HandshakeTimeout: handshakeTimeout,
RequestConnectionIDTruncation: config.RequestConnectionIDTruncation,
Expand Down Expand Up @@ -270,6 +293,7 @@ func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) e
c.hostname,
c.version,
c.connectionID,
c.tlsConf,
c.config,
negotiatedVersions,
)
Expand Down
32 changes: 19 additions & 13 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ var _ = Describe("Client", func() {
packetConn *mockPacketConn
addr net.Addr

originalClientSessConstructor func(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, config *Config, negotiatedVersions []protocol.VersionNumber) (packetHandler, <-chan handshakeEvent, error)
originalClientSessConstructor func(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConf *tls.Config, config *Config, negotiatedVersions []protocol.VersionNumber) (packetHandler, <-chan handshakeEvent, error)
)

BeforeEach(func() {
originalClientSessConstructor = newClientSession
Eventually(areSessionsRunning).Should(BeFalse())
msess, _, _ := newMockSession(nil, 0, 0, nil, nil)
msess, _, _ := newMockSession(nil, 0, 0, nil, nil, nil)
sess = msess.(*mockSession)
packetConn = &mockPacketConn{addr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}}
config = &Config{
Expand Down Expand Up @@ -63,6 +63,7 @@ var _ = Describe("Client", func() {
_ string,
_ protocol.VersionNumber,
_ protocol.ConnectionID,
_ *tls.Config,
_ *Config,
_ []protocol.VersionNumber,
) (packetHandler, <-chan handshakeEvent, error) {
Expand All @@ -75,7 +76,7 @@ var _ = Describe("Client", func() {
go func() {
defer GinkgoRecover()
var err error
dialedSess, err = DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", config)
dialedSess, err = DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).ToNot(HaveOccurred())
}()
Consistently(func() Session { return dialedSess }).Should(BeNil())
Expand All @@ -89,7 +90,7 @@ var _ = Describe("Client", func() {
go func() {
defer GinkgoRecover()
var err error
dialedSess, err = DialAddrNonFWSecure("localhost:18901", config)
dialedSess, err = DialAddrNonFWSecure("localhost:18901", nil, config)
Expect(err).ToNot(HaveOccurred())
}()
Consistently(func() Session { return dialedSess }).Should(BeNil())
Expand All @@ -103,7 +104,7 @@ var _ = Describe("Client", func() {
go func() {
defer GinkgoRecover()
var err error
dialedSess, err = Dial(packetConn, addr, "quic.clemente.io:1337", config)
dialedSess, err = Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).ToNot(HaveOccurred())
}()
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
Expand All @@ -120,13 +121,14 @@ var _ = Describe("Client", func() {
_ string,
_ protocol.VersionNumber,
_ protocol.ConnectionID,
_ *tls.Config,
_ *Config,
_ []protocol.VersionNumber,
) (packetHandler, <-chan handshakeEvent, error) {
cconn = conn
return sess, nil, nil
}
go DialAddr("localhost:17890", &Config{})
go DialAddr("localhost:17890", nil, &Config{})
Eventually(func() connection { return cconn }).ShouldNot(BeNil())
Expect(cconn.RemoteAddr().String()).To(Equal("127.0.0.1:17890"))
close(done)
Expand All @@ -139,13 +141,14 @@ var _ = Describe("Client", func() {
h string,
_ protocol.VersionNumber,
_ protocol.ConnectionID,
_ *tls.Config,
_ *Config,
_ []protocol.VersionNumber,
) (packetHandler, <-chan handshakeEvent, error) {
hostname = h
return sess, nil, nil
}
go DialAddr("localhost:17890", &Config{TLSConfig: &tls.Config{ServerName: "foobar"}})
go DialAddr("localhost:17890", &tls.Config{ServerName: "foobar"}, nil)
Eventually(func() string { return hostname }).Should(Equal("foobar"))
close(done)
})
Expand All @@ -154,7 +157,7 @@ var _ = Describe("Client", func() {
testErr := errors.New("early handshake error")
var dialErr error
go func() {
_, dialErr = Dial(packetConn, addr, "quic.clemente.io:1337", config)
_, dialErr = Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
}()
sess.handshakeChan <- handshakeEvent{err: testErr}
Eventually(func() error { return dialErr }).Should(MatchError(testErr))
Expand All @@ -165,7 +168,7 @@ var _ = Describe("Client", func() {
testErr := errors.New("late handshake error")
var dialErr error
go func() {
_, dialErr = Dial(packetConn, addr, "quic.clemente.io:1337", config)
_, dialErr = Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
}()
sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure}
sess.handshakeComplete <- testErr
Expand All @@ -192,15 +195,15 @@ var _ = Describe("Client", func() {

It("errors when receiving an invalid first packet from the server", func(done Done) {
packetConn.dataToRead = []byte{0xff}
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).To(HaveOccurred())
close(done)
})

It("errors when receiving an error from the connection", func(done Done) {
testErr := errors.New("connection error")
packetConn.readErr = testErr
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).To(MatchError(testErr))
close(done)
})
Expand All @@ -212,12 +215,13 @@ var _ = Describe("Client", func() {
_ string,
_ protocol.VersionNumber,
_ protocol.ConnectionID,
_ *tls.Config,
_ *Config,
_ []protocol.VersionNumber,
) (packetHandler, <-chan handshakeEvent, error) {
return nil, nil, testErr
}
_, err := DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", config)
_, err := DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).To(MatchError(testErr))
})

Expand All @@ -243,6 +247,7 @@ var _ = Describe("Client", func() {
_ string,
_ protocol.VersionNumber,
connectionID protocol.ConnectionID,
_ *tls.Config,
_ *Config,
negotiatedVersionsP []protocol.VersionNumber,
) (packetHandler, <-chan handshakeEvent, error) {
Expand Down Expand Up @@ -324,6 +329,7 @@ var _ = Describe("Client", func() {
hostnameP string,
versionP protocol.VersionNumber,
_ protocol.ConnectionID,
_ *tls.Config,
configP *Config,
_ []protocol.VersionNumber,
) (packetHandler, <-chan handshakeEvent, error) {
Expand All @@ -335,7 +341,7 @@ var _ = Describe("Client", func() {
return sess, nil, nil
}
go func() {
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
_, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config)
Expect(err).ToNot(HaveOccurred())
}()
<-c
Expand Down
10 changes: 2 additions & 8 deletions example/echo/echo.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,7 @@ func main() {

// Start a server that echos all data on the first stream opened by the client
func echoServer() error {
cfgServer := &quic.Config{
TLSConfig: generateTLSConfig(),
}
listener, err := quic.ListenAddr(addr, cfgServer)
listener, err := quic.ListenAddr(addr, generateTLSConfig(), nil)
if err != nil {
return err
}
Expand All @@ -52,10 +49,7 @@ func echoServer() error {
}

func clientMain() error {
cfgClient := &quic.Config{
TLSConfig: &tls.Config{InsecureSkipVerify: true},
}
session, err := quic.DialAddr(addr, cfgClient)
session, err := quic.DialAddr(addr, &tls.Config{InsecureSkipVerify: true}, nil)
if err != nil {
return err
}
Expand Down
7 changes: 4 additions & 3 deletions h2quic/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ type roundTripperOpts struct {
type client struct {
mutex sync.RWMutex

dialAddr func(hostname string, config *quic.Config) (quic.Session, error)
dialAddr func(hostname string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error)
tlsConf *tls.Config
config *quic.Config
opts *roundTripperOpts

Expand All @@ -55,8 +56,8 @@ func newClient(tlsConfig *tls.Config, hostname string, opts *roundTripperOpts) *
hostname: authorityAddr("https", hostname),
responses: make(map[protocol.StreamID]chan *http.Response),
encryptionLevel: protocol.EncryptionUnencrypted,
tlsConf: tlsConfig,
config: &quic.Config{
TLSConfig: tlsConfig,
RequestConnectionIDTruncation: true,
},
opts: opts,
Expand All @@ -67,7 +68,7 @@ func newClient(tlsConfig *tls.Config, hostname string, opts *roundTripperOpts) *
// dial dials the connection
func (c *client) dial() error {
var err error
c.session, err = c.dialAddr(c.hostname, c.config)
c.session, err = c.dialAddr(c.hostname, c.tlsConf, c.config)
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit a851aaa

Please sign in to comment.