diff --git a/Changelog.md b/Changelog.md index ff2908487e3..b06c9746b82 100644 --- a/Changelog.md +++ b/Changelog.md @@ -7,6 +7,7 @@ - Add a `quic.Config` option to request truncation of the connection ID from a server - Add a `quic.Config` option to configure the source address validation - Add a `quic.Config` option to configure the handshake timeout +- Remove the `tls.Config` from the `quic.Config`. The `tls.Config` must now be passed to the `Dial` and `Listen` functions as a separate parameter. See the [Godoc](https://godoc.org/github.com/lucas-clemente/quic-go) for details. - Changed the log level environment variable to only accept strings ("DEBUG", "INFO", "ERROR"), see [the wiki](https://github.com/lucas-clemente/quic-go/wiki/Logging) for more details. - Rename the `h2quic.QuicRoundTripper` to `h2quic.RoundTripper` - Various bugfixes diff --git a/benchmark_test.go b/benchmark_test.go index 8c3962a8681..697ec8a76eb 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -35,7 +35,7 @@ var _ = Describe("Benchmarks", func() { go func() { defer GinkgoRecover() var err error - ln, err = ListenAddr("localhost:0", &Config{TLSConfig: testdata.GetTLSConfig()}) + ln, err = ListenAddr("localhost:0", testdata.GetTLSConfig(), nil) Expect(err).ToNot(HaveOccurred()) serverAddr <- ln.Addr() sess, err := ln.Accept() @@ -49,11 +49,8 @@ var _ = Describe("Benchmarks", func() { }() // start the client - conf := &Config{ - TLSConfig: &tls.Config{InsecureSkipVerify: true}, - } addr := <-serverAddr - sess, err := DialAddr(addr.String(), conf) + sess, err := DialAddr(addr.String(), &tls.Config{InsecureSkipVerify: true}, nil) Expect(err).ToNot(HaveOccurred()) str, err := sess.AcceptStream() Expect(err).ToNot(HaveOccurred()) diff --git a/client.go b/client.go index 9e6ae0ad2cc..a4f6141244b 100644 --- a/client.go +++ b/client.go @@ -2,6 +2,7 @@ package quic import ( "bytes" + "crypto/tls" "errors" "fmt" "net" @@ -24,6 +25,7 @@ type client struct { errorChan chan struct{} handshakeChan <-chan handshakeEvent + tlsConf *tls.Config config *Config versionNegotiated bool // has version negotiation completed yet @@ -39,7 +41,7 @@ var ( // DialAddr establishes a new QUIC connection to a server. // The hostname for SNI is taken from the given address. -func DialAddr(addr string, config *Config) (Session, error) { +func DialAddr(addr string, tlsConf *tls.Config, config *Config) (Session, error) { udpAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err @@ -48,12 +50,16 @@ func DialAddr(addr string, config *Config) (Session, error) { if err != nil { return nil, err } - return Dial(udpConn, udpAddr, addr, config) + return Dial(udpConn, udpAddr, addr, tlsConf, config) } // DialAddrNonFWSecure establishes a new QUIC connection to a server. // The hostname for SNI is taken from the given address. -func DialAddrNonFWSecure(addr string, config *Config) (NonFWSession, error) { +func DialAddrNonFWSecure( + addr string, + tlsConf *tls.Config, + config *Config, +) (NonFWSession, error) { udpAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err @@ -62,20 +68,26 @@ func DialAddrNonFWSecure(addr string, config *Config) (NonFWSession, error) { if err != nil { return nil, err } - return DialNonFWSecure(udpConn, udpAddr, addr, config) + return DialNonFWSecure(udpConn, udpAddr, addr, tlsConf, config) } // DialNonFWSecure establishes a new non-forward-secure QUIC connection to a server using a net.PacketConn. // The host parameter is used for SNI. -func DialNonFWSecure(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config) (NonFWSession, error) { +func DialNonFWSecure( + pconn net.PacketConn, + remoteAddr net.Addr, + host string, + tlsConf *tls.Config, + config *Config, +) (NonFWSession, error) { connID, err := utils.GenerateConnectionID() if err != nil { return nil, err } var hostname string - if config.TLSConfig != nil { - hostname = config.TLSConfig.ServerName + if tlsConf != nil { + hostname = tlsConf.ServerName } if hostname == "" { @@ -90,6 +102,7 @@ func DialNonFWSecure(pconn net.PacketConn, remoteAddr net.Addr, host string, con conn: &conn{pconn: pconn, currentAddr: remoteAddr}, connectionID: connID, hostname: hostname, + tlsConf: tlsConf, config: clientConfig, version: clientConfig.Versions[0], errorChan: make(chan struct{}), @@ -107,8 +120,14 @@ func DialNonFWSecure(pconn net.PacketConn, remoteAddr net.Addr, host string, con // Dial establishes a new QUIC connection to a server using a net.PacketConn. // The host parameter is used for SNI. -func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config) (Session, error) { - sess, err := DialNonFWSecure(pconn, remoteAddr, host, config) +func Dial( + pconn net.PacketConn, + remoteAddr net.Addr, + host string, + tlsConf *tls.Config, + config *Config, +) (Session, error) { + sess, err := DialNonFWSecure(pconn, remoteAddr, host, tlsConf, config) if err != nil { return nil, err } @@ -119,7 +138,12 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config return sess, nil } +// populateClientConfig populates fields in the quic.Config with their default values, if none are set +// it may be called with nil func populateClientConfig(config *Config) *Config { + if config == nil { + config = &Config{} + } versions := config.Versions if len(versions) == 0 { versions = protocol.SupportedVersions @@ -140,7 +164,6 @@ func populateClientConfig(config *Config) *Config { } return &Config{ - TLSConfig: config.TLSConfig, Versions: versions, HandshakeTimeout: handshakeTimeout, RequestConnectionIDTruncation: config.RequestConnectionIDTruncation, @@ -270,6 +293,7 @@ func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) e c.hostname, c.version, c.connectionID, + c.tlsConf, c.config, negotiatedVersions, ) diff --git a/client_test.go b/client_test.go index 81d88ead0a0..ece3a012414 100644 --- a/client_test.go +++ b/client_test.go @@ -22,13 +22,13 @@ var _ = Describe("Client", func() { packetConn *mockPacketConn addr net.Addr - originalClientSessConstructor func(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, config *Config, negotiatedVersions []protocol.VersionNumber) (packetHandler, <-chan handshakeEvent, error) + originalClientSessConstructor func(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConf *tls.Config, config *Config, negotiatedVersions []protocol.VersionNumber) (packetHandler, <-chan handshakeEvent, error) ) BeforeEach(func() { originalClientSessConstructor = newClientSession Eventually(areSessionsRunning).Should(BeFalse()) - msess, _, _ := newMockSession(nil, 0, 0, nil, nil) + msess, _, _ := newMockSession(nil, 0, 0, nil, nil, nil) sess = msess.(*mockSession) packetConn = &mockPacketConn{addr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}} config = &Config{ @@ -63,6 +63,7 @@ var _ = Describe("Client", func() { _ string, _ protocol.VersionNumber, _ protocol.ConnectionID, + _ *tls.Config, _ *Config, _ []protocol.VersionNumber, ) (packetHandler, <-chan handshakeEvent, error) { @@ -75,7 +76,7 @@ var _ = Describe("Client", func() { go func() { defer GinkgoRecover() var err error - dialedSess, err = DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", config) + dialedSess, err = DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", nil, config) Expect(err).ToNot(HaveOccurred()) }() Consistently(func() Session { return dialedSess }).Should(BeNil()) @@ -89,7 +90,7 @@ var _ = Describe("Client", func() { go func() { defer GinkgoRecover() var err error - dialedSess, err = DialAddrNonFWSecure("localhost:18901", config) + dialedSess, err = DialAddrNonFWSecure("localhost:18901", nil, config) Expect(err).ToNot(HaveOccurred()) }() Consistently(func() Session { return dialedSess }).Should(BeNil()) @@ -103,7 +104,7 @@ var _ = Describe("Client", func() { go func() { defer GinkgoRecover() var err error - dialedSess, err = Dial(packetConn, addr, "quic.clemente.io:1337", config) + dialedSess, err = Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) Expect(err).ToNot(HaveOccurred()) }() sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure} @@ -120,13 +121,14 @@ var _ = Describe("Client", func() { _ string, _ protocol.VersionNumber, _ protocol.ConnectionID, + _ *tls.Config, _ *Config, _ []protocol.VersionNumber, ) (packetHandler, <-chan handshakeEvent, error) { cconn = conn return sess, nil, nil } - go DialAddr("localhost:17890", &Config{}) + go DialAddr("localhost:17890", nil, &Config{}) Eventually(func() connection { return cconn }).ShouldNot(BeNil()) Expect(cconn.RemoteAddr().String()).To(Equal("127.0.0.1:17890")) close(done) @@ -139,13 +141,14 @@ var _ = Describe("Client", func() { h string, _ protocol.VersionNumber, _ protocol.ConnectionID, + _ *tls.Config, _ *Config, _ []protocol.VersionNumber, ) (packetHandler, <-chan handshakeEvent, error) { hostname = h return sess, nil, nil } - go DialAddr("localhost:17890", &Config{TLSConfig: &tls.Config{ServerName: "foobar"}}) + go DialAddr("localhost:17890", &tls.Config{ServerName: "foobar"}, nil) Eventually(func() string { return hostname }).Should(Equal("foobar")) close(done) }) @@ -154,7 +157,7 @@ var _ = Describe("Client", func() { testErr := errors.New("early handshake error") var dialErr error go func() { - _, dialErr = Dial(packetConn, addr, "quic.clemente.io:1337", config) + _, dialErr = Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) }() sess.handshakeChan <- handshakeEvent{err: testErr} Eventually(func() error { return dialErr }).Should(MatchError(testErr)) @@ -165,7 +168,7 @@ var _ = Describe("Client", func() { testErr := errors.New("late handshake error") var dialErr error go func() { - _, dialErr = Dial(packetConn, addr, "quic.clemente.io:1337", config) + _, dialErr = Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) }() sess.handshakeChan <- handshakeEvent{encLevel: protocol.EncryptionSecure} sess.handshakeComplete <- testErr @@ -192,7 +195,7 @@ var _ = Describe("Client", func() { It("errors when receiving an invalid first packet from the server", func(done Done) { packetConn.dataToRead = []byte{0xff} - _, err := Dial(packetConn, addr, "quic.clemente.io:1337", config) + _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) Expect(err).To(HaveOccurred()) close(done) }) @@ -200,7 +203,7 @@ var _ = Describe("Client", func() { It("errors when receiving an error from the connection", func(done Done) { testErr := errors.New("connection error") packetConn.readErr = testErr - _, err := Dial(packetConn, addr, "quic.clemente.io:1337", config) + _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) Expect(err).To(MatchError(testErr)) close(done) }) @@ -212,12 +215,13 @@ var _ = Describe("Client", func() { _ string, _ protocol.VersionNumber, _ protocol.ConnectionID, + _ *tls.Config, _ *Config, _ []protocol.VersionNumber, ) (packetHandler, <-chan handshakeEvent, error) { return nil, nil, testErr } - _, err := DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", config) + _, err := DialNonFWSecure(packetConn, addr, "quic.clemente.io:1337", nil, config) Expect(err).To(MatchError(testErr)) }) @@ -243,6 +247,7 @@ var _ = Describe("Client", func() { _ string, _ protocol.VersionNumber, connectionID protocol.ConnectionID, + _ *tls.Config, _ *Config, negotiatedVersionsP []protocol.VersionNumber, ) (packetHandler, <-chan handshakeEvent, error) { @@ -324,6 +329,7 @@ var _ = Describe("Client", func() { hostnameP string, versionP protocol.VersionNumber, _ protocol.ConnectionID, + _ *tls.Config, configP *Config, _ []protocol.VersionNumber, ) (packetHandler, <-chan handshakeEvent, error) { @@ -335,7 +341,7 @@ var _ = Describe("Client", func() { return sess, nil, nil } go func() { - _, err := Dial(packetConn, addr, "quic.clemente.io:1337", config) + _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) Expect(err).ToNot(HaveOccurred()) }() <-c diff --git a/example/echo/echo.go b/example/echo/echo.go index 520130dd985..0f39c12701a 100644 --- a/example/echo/echo.go +++ b/example/echo/echo.go @@ -31,10 +31,7 @@ func main() { // Start a server that echos all data on the first stream opened by the client func echoServer() error { - cfgServer := &quic.Config{ - TLSConfig: generateTLSConfig(), - } - listener, err := quic.ListenAddr(addr, cfgServer) + listener, err := quic.ListenAddr(addr, generateTLSConfig(), nil) if err != nil { return err } @@ -52,10 +49,7 @@ func echoServer() error { } func clientMain() error { - cfgClient := &quic.Config{ - TLSConfig: &tls.Config{InsecureSkipVerify: true}, - } - session, err := quic.DialAddr(addr, cfgClient) + session, err := quic.DialAddr(addr, &tls.Config{InsecureSkipVerify: true}, nil) if err != nil { return err } diff --git a/h2quic/client.go b/h2quic/client.go index 3ae14c88fab..906253e08c5 100644 --- a/h2quic/client.go +++ b/h2quic/client.go @@ -28,7 +28,8 @@ type roundTripperOpts struct { type client struct { mutex sync.RWMutex - dialAddr func(hostname string, config *quic.Config) (quic.Session, error) + dialAddr func(hostname string, tlsConf *tls.Config, config *quic.Config) (quic.Session, error) + tlsConf *tls.Config config *quic.Config opts *roundTripperOpts @@ -55,8 +56,8 @@ func newClient(tlsConfig *tls.Config, hostname string, opts *roundTripperOpts) * hostname: authorityAddr("https", hostname), responses: make(map[protocol.StreamID]chan *http.Response), encryptionLevel: protocol.EncryptionUnencrypted, + tlsConf: tlsConfig, config: &quic.Config{ - TLSConfig: tlsConfig, RequestConnectionIDTruncation: true, }, opts: opts, @@ -67,7 +68,7 @@ func newClient(tlsConfig *tls.Config, hostname string, opts *roundTripperOpts) * // dial dials the connection func (c *client) dial() error { var err error - c.session, err = c.dialAddr(c.hostname, c.config) + c.session, err = c.dialAddr(c.hostname, c.tlsConf, c.config) if err != nil { return err } diff --git a/h2quic/client_test.go b/h2quic/client_test.go index 6f66425d2a6..fa61d142bad 100644 --- a/h2quic/client_test.go +++ b/h2quic/client_test.go @@ -45,7 +45,7 @@ var _ = Describe("Client", func() { It("saves the TLS config", func() { tlsConf := &tls.Config{InsecureSkipVerify: true} client = newClient(tlsConf, "", &roundTripperOpts{}) - Expect(client.config.TLSConfig).To(Equal(tlsConf)) + Expect(client.tlsConf).To(Equal(tlsConf)) }) It("adds the port to the hostname, if none is given", func() { @@ -56,7 +56,7 @@ var _ = Describe("Client", func() { It("dials", func(done Done) { client = newClient(nil, "localhost:1337", &roundTripperOpts{}) session.streamsToOpen = []quic.Stream{newMockStream(3), newMockStream(5)} - client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) { + client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return session, nil } close(headerStream.unblockRead) @@ -68,7 +68,7 @@ var _ = Describe("Client", func() { It("errors when dialing fails", func() { testErr := errors.New("handshake error") client = newClient(nil, "localhost:1337", &roundTripperOpts{}) - client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) { + client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return nil, testErr } _, err := client.RoundTrip(req) @@ -78,7 +78,7 @@ var _ = Describe("Client", func() { It("errors if the header stream has the wrong stream ID", func() { client = newClient(nil, "localhost:1337", &roundTripperOpts{}) session.streamsToOpen = []quic.Stream{&mockStream{id: 2}} - client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) { + client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return session, nil } _, err := client.RoundTrip(req) @@ -89,7 +89,7 @@ var _ = Describe("Client", func() { testErr := errors.New("you shall not pass") client = newClient(nil, "localhost:1337", &roundTripperOpts{}) session.streamOpenErr = testErr - client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) { + client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return session, nil } _, err := client.RoundTrip(req) @@ -98,7 +98,7 @@ var _ = Describe("Client", func() { It("returns a request when dial fails", func() { testErr := errors.New("dial error") - client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) { + client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return nil, testErr } request, err := http.NewRequest("https", "https://quic.clemente.io:1337/file1.dat", nil) @@ -140,7 +140,7 @@ var _ = Describe("Client", func() { BeforeEach(func() { var err error client.encryptionLevel = protocol.EncryptionForwardSecure - client.dialAddr = func(hostname string, conf *quic.Config) (quic.Session, error) { + client.dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { return session, nil } dataStream = newMockStream(5) diff --git a/h2quic/server.go b/h2quic/server.go index ae23762daba..e391ceb11d6 100644 --- a/h2quic/server.go +++ b/h2quic/server.go @@ -84,16 +84,15 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn *net.UDPConn) error { } config := quic.Config{ - TLSConfig: tlsConfig, - Versions: protocol.SupportedVersions, + Versions: protocol.SupportedVersions, } var ln quic.Listener var err error if conn == nil { - ln, err = quic.ListenAddr(s.Addr, &config) + ln, err = quic.ListenAddr(s.Addr, tlsConfig, &config) } else { - ln, err = quic.Listen(conn, &config) + ln, err = quic.Listen(conn, tlsConfig, &config) } if err != nil { s.listenerMutex.Unlock() diff --git a/integrationtests/handshake/rtt.go b/integrationtests/handshake/rtt.go index c1f4533fe00..29473e8c741 100644 --- a/integrationtests/handshake/rtt.go +++ b/integrationtests/handshake/rtt.go @@ -11,8 +11,8 @@ import ( "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" - "github.com/lucas-clemente/quic-go/testdata" + "github.com/lucas-clemente/quic-go/testdata" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -24,11 +24,10 @@ var _ = Describe("Handshake integration tets", func() { serverConfig *quic.Config testStartedAt time.Time ) - rtt := 350 * time.Millisecond BeforeEach(func() { - serverConfig = &quic.Config{TLSConfig: testdata.GetTLSConfig()} + serverConfig = &quic.Config{} }) AfterEach(func() { @@ -39,7 +38,7 @@ var _ = Describe("Handshake integration tets", func() { runServerAndProxy := func() { var err error // start the server - server, err = quic.ListenAddr("localhost:0", serverConfig) + server, err = quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) // start the proxy proxy, err = quicproxy.NewQuicProxy("localhost:0", quicproxy.Opts{ @@ -73,7 +72,7 @@ var _ = Describe("Handshake integration tets", func() { clientConfig := &quic.Config{ Versions: protocol.SupportedVersions[1:2], } - _, err := quic.DialAddr(proxy.LocalAddr().String(), clientConfig) + _, err := quic.DialAddr(proxy.LocalAddr().String(), nil, clientConfig) Expect(err).To(HaveOccurred()) Expect(err.(qerr.ErrorCode)).To(Equal(qerr.InvalidVersion)) expectDurationInRTTs(1) @@ -84,7 +83,7 @@ var _ = Describe("Handshake integration tets", func() { // 1 RTT to become forward-secure It("is forward-secure after 3 RTTs", func() { runServerAndProxy() - _, err := quic.DialAddr(proxy.LocalAddr().String(), &quic.Config{TLSConfig: &tls.Config{InsecureSkipVerify: true}}) + _, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil) Expect(err).ToNot(HaveOccurred()) expectDurationInRTTs(3) }) @@ -95,7 +94,7 @@ var _ = Describe("Handshake integration tets", func() { PIt("is secure after 2 RTTs", func() { utils.SetLogLevel(utils.LogLevelDebug) runServerAndProxy() - _, err := quic.DialAddrNonFWSecure(proxy.LocalAddr().String(), &quic.Config{TLSConfig: &tls.Config{InsecureSkipVerify: true}}) + _, err := quic.DialAddrNonFWSecure(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil) fmt.Println("#### is non fw secure ###") Expect(err).ToNot(HaveOccurred()) expectDurationInRTTs(2) @@ -106,7 +105,7 @@ var _ = Describe("Handshake integration tets", func() { return true } runServerAndProxy() - _, err := quic.DialAddr(proxy.LocalAddr().String(), &quic.Config{TLSConfig: &tls.Config{InsecureSkipVerify: true}}) + _, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil) Expect(err).ToNot(HaveOccurred()) expectDurationInRTTs(2) }) @@ -116,7 +115,7 @@ var _ = Describe("Handshake integration tets", func() { return false } runServerAndProxy() - _, err := quic.DialAddr(proxy.LocalAddr().String(), &quic.Config{TLSConfig: &tls.Config{InsecureSkipVerify: true}}) + _, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil) Expect(err).To(HaveOccurred()) Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.CryptoTooManyRejects)) }) @@ -124,7 +123,7 @@ var _ = Describe("Handshake integration tets", func() { It("doesn't complete the handshake when the handshake timeout is too short", func() { serverConfig.HandshakeTimeout = 2 * rtt runServerAndProxy() - _, err := quic.DialAddr(proxy.LocalAddr().String(), &quic.Config{TLSConfig: &tls.Config{InsecureSkipVerify: true}}) + _, err := quic.DialAddr(proxy.LocalAddr().String(), &tls.Config{InsecureSkipVerify: true}, nil) Expect(err).To(HaveOccurred()) Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeTimeout)) // 2 RTTs during the timeout diff --git a/interface.go b/interface.go index 9cf52d613f5..b812d37b4bc 100644 --- a/interface.go +++ b/interface.go @@ -1,7 +1,6 @@ package quic import ( - "crypto/tls" "io" "net" "time" @@ -64,7 +63,6 @@ type STK struct { // Config contains all configuration data needed for a QUIC server or client. // More config parameters (such as timeouts) will be added soon, see e.g. https://github.com/lucas-clemente/quic-go/issues/441. type Config struct { - TLSConfig *tls.Config // The QUIC versions that can be negotiated. // If not set, it uses all versions available. // Warning: This API should not be considered stable and will change soon. diff --git a/server.go b/server.go index 5048df5f392..d45168b3933 100644 --- a/server.go +++ b/server.go @@ -2,6 +2,7 @@ package quic import ( "bytes" + "crypto/tls" "errors" "net" "sync" @@ -23,7 +24,8 @@ type packetHandler interface { // A Listener of QUIC type server struct { - config *Config + tlsConf *tls.Config + config *Config conn net.PacketConn @@ -38,14 +40,15 @@ type server struct { sessionQueue chan Session errorChan chan struct{} - newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, config *Config) (packetHandler, <-chan handshakeEvent, error) + newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, tlsConf *tls.Config, config *Config) (packetHandler, <-chan handshakeEvent, error) } var _ Listener = &server{} // ListenAddr creates a QUIC server listening on a given address. // The listener is not active until Serve() is called. -func ListenAddr(addr string, config *Config) (Listener, error) { +// The tls.Config must not be nil, the quic.Config may be nil. +func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) { udpAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err @@ -54,13 +57,14 @@ func ListenAddr(addr string, config *Config) (Listener, error) { if err != nil { return nil, err } - return Listen(conn, config) + return Listen(conn, tlsConf, config) } // Listen listens for QUIC connections on a given net.PacketConn. // The listener is not active until Serve() is called. -func Listen(conn net.PacketConn, config *Config) (Listener, error) { - certChain := crypto.NewCertChain(config.TLSConfig) +// The tls.Config must not be nil, the quic.Config may be nil. +func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) { + certChain := crypto.NewCertChain(tlsConf) kex, err := crypto.NewCurve25519KEX() if err != nil { return nil, err @@ -72,6 +76,7 @@ func Listen(conn net.PacketConn, config *Config) (Listener, error) { s := &server{ conn: conn, + tlsConf: tlsConf, config: populateServerConfig(config), certChain: certChain, scfg: scfg, @@ -101,7 +106,12 @@ var defaultAcceptSTK = func(clientAddr net.Addr, stk *STK) bool { return sourceAddr == stk.remoteAddr } +// populateServerConfig populates fields in the quic.Config with their default values, if none are set +// it may be called with nil func populateServerConfig(config *Config) *Config { + if config == nil { + config = &Config{} + } versions := config.Versions if len(versions) == 0 { versions = protocol.SupportedVersions @@ -127,7 +137,6 @@ func populateServerConfig(config *Config) *Config { } return &Config{ - TLSConfig: config.TLSConfig, Versions: versions, HandshakeTimeout: handshakeTimeout, AcceptSTK: vsa, @@ -256,6 +265,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet version, hdr.ConnectionID, s.scfg, + s.tlsConf, s.config, ) if err != nil { diff --git a/server_test.go b/server_test.go index 6810b7eba56..893754fed56 100644 --- a/server_test.go +++ b/server_test.go @@ -75,6 +75,7 @@ func newMockSession( _ protocol.VersionNumber, connectionID protocol.ConnectionID, _ *handshake.ServerConfig, + _ *tls.Config, _ *Config, ) (packetHandler, <-chan handshakeEvent, error) { s := mockSession{ @@ -95,10 +96,7 @@ var _ = Describe("Server", func() { BeforeEach(func() { conn = &mockPacketConn{} - config = &Config{ - TLSConfig: &tls.Config{}, - Versions: protocol.SupportedVersions, - } + config = &Config{Versions: protocol.SupportedVersions} }) Context("with mock session", func() { @@ -225,7 +223,7 @@ var _ = Describe("Server", func() { }) It("closes sessions and the connection when Close is called", func() { - session, _, _ := newMockSession(nil, 0, 0, nil, nil) + session, _, _ := newMockSession(nil, 0, 0, nil, nil, nil) serv.sessions[1] = session err := serv.Close() Expect(err).NotTo(HaveOccurred()) @@ -241,8 +239,15 @@ var _ = Describe("Server", func() { Expect(serv.sessions[connID]).To(BeNil()) }) + It("works if no quic.Config is given", func(done Done) { + ln, err := ListenAddr("127.0.0.1:0", nil, config) + Expect(err).ToNot(HaveOccurred()) + Expect(ln.Close()).To(Succeed()) + close(done) + }, 1) + It("closes properly", func() { - ln, err := ListenAddr("127.0.0.1:0", config) + ln, err := ListenAddr("127.0.0.1:0", nil, config) Expect(err).ToNot(HaveOccurred()) var returned bool @@ -268,7 +273,7 @@ var _ = Describe("Server", func() { }, 0.5) It("closes all sessions when encountering a connection error", func() { - session, _, _ := newMockSession(nil, 0, 0, nil, nil) + session, _, _ := newMockSession(nil, 0, 0, nil, nil, nil) serv.sessions[0x12345] = session Expect(serv.sessions[0x12345].(*mockSession).closed).To(BeFalse()) testErr := errors.New("connection error") @@ -348,12 +353,11 @@ var _ = Describe("Server", func() { supportedVersions := []protocol.VersionNumber{1, 3, 5} acceptSTK := func(_ net.Addr, _ *STK) bool { return true } config := Config{ - TLSConfig: &tls.Config{}, Versions: supportedVersions, AcceptSTK: acceptSTK, HandshakeTimeout: 1337 * time.Hour, } - ln, err := Listen(conn, &config) + ln, err := Listen(conn, &tls.Config{}, &config) Expect(err).ToNot(HaveOccurred()) server := ln.(*server) Expect(server.deleteClosedSessionsAfter).To(Equal(protocol.ClosedSessionDeleteTimeout)) @@ -365,8 +369,7 @@ var _ = Describe("Server", func() { }) It("fills in default values if options are not set in the Config", func() { - config := Config{TLSConfig: &tls.Config{}} - ln, err := Listen(conn, &config) + ln, err := Listen(conn, &tls.Config{}, &Config{}) Expect(err).ToNot(HaveOccurred()) server := ln.(*server) Expect(server.config.Versions).To(Equal(protocol.SupportedVersions)) @@ -376,7 +379,7 @@ var _ = Describe("Server", func() { It("listens on a given address", func() { addr := "127.0.0.1:13579" - ln, err := ListenAddr(addr, config) + ln, err := ListenAddr(addr, nil, config) Expect(err).ToNot(HaveOccurred()) serv := ln.(*server) Expect(serv.Addr().String()).To(Equal(addr)) @@ -384,13 +387,13 @@ var _ = Describe("Server", func() { It("errors if given an invalid address", func() { addr := "127.0.0.1" - _, err := ListenAddr(addr, config) + _, err := ListenAddr(addr, nil, config) Expect(err).To(BeAssignableToTypeOf(&net.AddrError{})) }) It("errors if given an invalid address", func() { addr := "1.1.1.1:1111" - _, err := ListenAddr(addr, config) + _, err := ListenAddr(addr, nil, config) Expect(err).To(BeAssignableToTypeOf(&net.OpError{})) }) @@ -407,7 +410,7 @@ var _ = Describe("Server", func() { b.Write(bytes.Repeat([]byte{0}, protocol.ClientHelloMinimumSize)) // add a fake CHLO conn.dataToRead = b.Bytes() conn.dataReadFrom = udpAddr - ln, err := Listen(conn, config) + ln, err := Listen(conn, nil, config) Expect(err).ToNot(HaveOccurred()) var returned bool @@ -431,7 +434,7 @@ var _ = Describe("Server", func() { It("sends a PublicReset for new connections that don't have the VersionFlag set", func() { conn.dataReadFrom = udpAddr conn.dataToRead = []byte{0x08, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0x01} - ln, err := Listen(conn, config) + ln, err := Listen(conn, nil, config) Expect(err).ToNot(HaveOccurred()) go func() { defer GinkgoRecover() diff --git a/session.go b/session.go index c241601747c..ae9cf2155c0 100644 --- a/session.go +++ b/session.go @@ -1,6 +1,7 @@ package quic import ( + "crypto/tls" "errors" "fmt" "net" @@ -53,6 +54,7 @@ type session struct { connectionID protocol.ConnectionID perspective protocol.Perspective version protocol.VersionNumber + tlsConf *tls.Config config *Config conn connection @@ -119,6 +121,7 @@ func newSession( v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, + tlsConf *tls.Config, config *Config, ) (packetHandler, <-chan handshakeEvent, error) { s := &session{ @@ -137,6 +140,7 @@ var newClientSession = func( hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, + tlsConf *tls.Config, config *Config, negotiatedVersions []protocol.VersionNumber, ) (packetHandler, <-chan handshakeEvent, error) { @@ -145,6 +149,7 @@ var newClientSession = func( connectionID: connectionID, perspective: protocol.PerspectiveClient, version: v, + tlsConf: tlsConf, config: config, } return s.setup(nil, hostname, negotiatedVersions) @@ -209,7 +214,7 @@ func (s *session) setup( s.connectionID, s.version, cryptoStream, - s.config.TLSConfig, + s.tlsConf, s.connectionParameters, aeadChanged, &handshake.TransportParameters{RequestConnectionIDTruncation: s.config.RequestConnectionIDTruncation}, diff --git a/session_test.go b/session_test.go index ba0c69d5ee7..fd4952eae7d 100644 --- a/session_test.go +++ b/session_test.go @@ -169,6 +169,7 @@ var _ = Describe("Session", func() { protocol.Version35, 0, scfg, + nil, populateServerConfig(&Config{}), ) Expect(err).NotTo(HaveOccurred()) @@ -220,6 +221,7 @@ var _ = Describe("Session", func() { protocol.Version35, 0, scfg, + nil, conf, ) Expect(err).NotTo(HaveOccurred()) @@ -1635,6 +1637,7 @@ var _ = Describe("Client Session", func() { "hostname", protocol.Version35, 0, + nil, populateClientConfig(&Config{}), nil, )