diff --git a/conn/bind_std_tcp.go b/conn/bind_std_tcp.go new file mode 100644 index 000000000..299c242c8 --- /dev/null +++ b/conn/bind_std_tcp.go @@ -0,0 +1,331 @@ +/* + * Copyright (c) 2022. Proton AG + * + * This file is part of ProtonVPN. + * + * ProtonVPN is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * ProtonVPN is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with ProtonVPN. If not, see . + */ + +package conn + +import ( + "bytes" + "errors" + "io" + "net" + "net/netip" + "sync" + "time" + + tls "github.com/refraction-networking/utls" +) + +var lastErrorTimestamp time.Time + +type StdNetBindTcp struct { + mu sync.Mutex + + useTls bool + tcp *net.TCPConn + tls *tls.UConn + endpoint *StdNetEndpoint + currentPacket *bytes.Reader + closed bool + log *Logger + errorChan chan<- error + protectSocket func(fd int) int + + tunsafe *TunSafeData +} + +//goland:noinspection GoUnusedExportedFunction +func CreateStdNetBind(socketType string, log *Logger, errorChan chan<- error) Bind { + if socketType == "udp" { + return NewStdNetBind() + } else { + return &StdNetBindTcp{tunsafe: NewTunSafeData(), useTls: socketType == "tls", log: log, errorChan: errorChan} + } +} + +func (s *StdNetBindTcp) BatchSize() int { + return 1 +} + +func (s *StdNetBindTcp) GetOffloadInfo() string { + return "" +} + +func (bind *StdNetBindTcp) ParseEndpoint(s string) (Endpoint, error) { + e, err := netip.ParseAddrPort(s) + if err == nil { + bind.endpoint = &StdNetEndpoint{AddrPort: e} + } + if err != nil { + return nil, err + } + return asEndpoint(e), err +} + +func dialTcp(addr string, protectSocket func(fd int) int) (*net.TCPConn, int, error) { + dialer := net.Dialer{Timeout: 5 * time.Second} + netConn, err := dialer.Dial("tcp", addr) + if err != nil { + return nil, 0, err + } + + conn := netConn.(*net.TCPConn) + conn.SetLinger(0) + + // Retrieve port. + laddr := conn.LocalAddr() + taddr, err := net.ResolveTCPAddr( + laddr.Network(), + laddr.String(), + ) + if err != nil { + _ = conn.Close() + return nil, 0, err + } + return conn, taddr.Port, nil +} + +func (bind *StdNetBindTcp) upgradeToTls() error { + tlsConf := &tls.Config{ + InsecureSkipVerify: true, + ServerName: randomServerName(), + } + + conn := tls.UClient(bind.tcp, tlsConf, tls.HelloChrome_Auto) + conn.SetDeadline(time.Now().Add(5 * time.Second)) + bind.log.Verbosef("TLS: Starting handshake") + err := conn.Handshake() + bind.log.Verbosef("TLS: Handshake result: %v", err) + conn.SetDeadline(time.Time{}) + + // On some devices (e.g. Samsung S21 FE) we see first WireGuard handshake failing on TLS socket and adding small + // delay seems to fix that - issue is likely with timing on the server side, but couldn't find server-side fix. + time.Sleep(100 * time.Millisecond) + + if err == nil { + bind.tls = conn + } else { + bind.onSocketError(err) + conn.Close() + } + return err +} + +func (bind *StdNetBindTcp) Open(uport uint16) ([]ReceiveFunc, uint16, error) { + bind.mu.Lock() + defer bind.mu.Unlock() + + bind.log.Verbosef("TCP/TLS: Open %d", uport) + bind.closed = false + return []ReceiveFunc{bind.makeReceiveFunc()}, uport, nil +} + +func (bind *StdNetBindTcp) initTcp() error { + var err error + + if bind.tcp != nil { + return ErrBindAlreadyOpen + } + + var tcp *net.TCPConn + + tcp, _, err = dialTcp(bind.endpoint.DstToString(), bind.protectSocket) + bind.log.Verbosef("TCP dial result: %v", err) + if err != nil { + bind.onSocketError(err) + return err + } + bind.tcp = tcp + return nil +} + +func (bind *StdNetBindTcp) Close() error { + bind.mu.Lock() + defer bind.mu.Unlock() + + bind.log.Verbosef("TCP/TLS: Close") + bind.closed = true + err := bind.closeInternal() + return err +} + +func (bind *StdNetBindTcp) closeInternal() error { + var err error + if bind.tls != nil { + err = bind.tls.Close() + bind.tls = nil + } + if bind.tcp != nil { + err = bind.tcp.Close() + bind.tcp = nil + } + bind.tunsafe.clear() + return err +} + +func (bind *StdNetBindTcp) getConn() (net.Conn, error) { + bind.mu.Lock() + defer bind.mu.Unlock() + + if bind.closed { + return nil, net.ErrClosed + } + + conn, err := bind.getConnInternal() + if err != nil { + bind.closed = true + } + return conn, err +} + +func (bind *StdNetBindTcp) getConnInternal() (net.Conn, error) { + if bind.tcp == nil { + err := bind.initTcp() + if err != nil { + return nil, err + } + } + if !bind.useTls { + return bind.tcp, nil + } + if bind.tls == nil { + err := bind.upgradeToTls() + if err != nil { + bind.closeInternal() + return nil, err + } + } + return bind.tls, nil +} + +func (bind *StdNetBindTcp) makeReceiveFunc() ReceiveFunc { + return func(packets [][]byte, sizes []int, eps []Endpoint) (int, error) { + var err error + eps[0] = bind.endpoint + if bind.currentPacket == nil || bind.currentPacket.Len() == 0 { + var conn net.Conn + conn, err = bind.getConn() + if err != nil { + bind.logError("recv getConn", err) + return 0, err + } + err = bind.readNextPacket(conn) + if err != nil { + if !errors.Is(err, net.ErrClosed) { + bind.onSocketError(err) + bind.logError("recv", err) + } + return 0, err + } + } + n, err := bind.currentPacket.Read(packets[0]) + if err != nil { + bind.logError("read packet", err) + return 0, err + } + sizes[0] = n + return 1, err + } +} + +func (bind *StdNetBindTcp) readNextPacket(conn net.Conn) error { + tunSafeHeader := make([]byte, tunSafeHeaderSize) + _, err := io.ReadFull(conn, tunSafeHeader) + if err != nil { + return err + } + + tunSafeType, payloadSize := parseTunSafeHeader(tunSafeHeader) + wgPacket, offset, err := bind.tunsafe.prepareWgPacket(tunSafeType, payloadSize) + if err != nil { + return err + } + + _, err = io.ReadFull(conn, wgPacket[offset:]) + if err != nil { + return err + } + + bind.tunsafe.onRecvPacket(tunSafeType, wgPacket) + bind.currentPacket = bytes.NewReader(wgPacket) + return nil +} + +func (bind *StdNetBindTcp) Send(buff [][]byte, endpoint Endpoint) error { + conn, err := bind.getConn() + if err != nil { + bind.logError("send conn", err) + return err + } + + // As single tcp socket can send only to single destination. We assume endpoint passed to ParseEndpoint will be + // the same. + boundEndpoint := asEndpoint(bind.endpoint.AddrPort) + if endpoint != boundEndpoint { + return errors.New("StdNetBindTcp.Send endpoints mismatch") + } + for i := range buff { + tunSafePacket := bind.tunsafe.wgToTunSafe(buff[i]) + _, err = conn.Write(tunSafePacket) + if err != nil { + bind.onSocketError(err) + bind.logError("send", err) + break + } + + } + return err +} + +func (bind *StdNetBindTcp) SetMark(_ uint32) error { + return nil +} + +func (bind *StdNetBindTcp) onSocketError(err error) { + if err != nil && !bind.closed { + bind.errorChan <- err + } +} + +func (bind *StdNetBindTcp) logError(t string, err error) { + if time.Now().After(lastErrorTimestamp.Add(5 * time.Second)) { + lastErrorTimestamp = time.Now() + bind.log.Errorf("TCP/TLS error %s: %v", t, err) + } +} + +// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint. +// This exists to reduce allocations: Putting a netip.AddrPort in an Endpoint allocates, +// but Endpoints are immutable, so we can re-use them. +var endpointPool = sync.Pool{ + New: func() any { + return make(map[netip.AddrPort]Endpoint) + }, +} + +// asEndpoint returns an Endpoint containing ap. +func asEndpoint(ap netip.AddrPort) Endpoint { + m := endpointPool.Get().(map[netip.AddrPort]Endpoint) + defer endpointPool.Put(m) + e, ok := m[ap] + if !ok { + e = Endpoint(&StdNetEndpoint{AddrPort: ap}) + m[ap] = e + } + return e +} diff --git a/conn/boundif_android.go b/conn/boundif_android.go index dd3ca5b07..a9f8e126e 100644 --- a/conn/boundif_android.go +++ b/conn/boundif_android.go @@ -32,3 +32,11 @@ func (s *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) { } return } + +func (bind *StdNetBindTcp) PeekLookAtSocketFd4() (fd int, err error) { + return -1, err +} + +func (bind *StdNetBindTcp) PeekLookAtSocketFd6() (fd int, err error) { + return -1, err +} diff --git a/conn/conn.go b/conn/conn.go index 489cb3520..f87b24521 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -59,6 +59,11 @@ type Bind interface { GetOffloadInfo() string } +type Logger struct { + Verbosef func(format string, args ...any) + Errorf func(format string, args ...any) +} + // BindSocketToInterface is implemented by Bind objects that support being // tied to a single network interface. Used by wireguard-windows. type BindSocketToInterface interface { diff --git a/conn/default.go b/conn/default.go index b6f761b9e..8f93c172b 100644 --- a/conn/default.go +++ b/conn/default.go @@ -7,4 +7,6 @@ package conn -func NewDefaultBind() Bind { return NewStdNetBind() } +func NewDefaultBind(logger *Logger, errorChan chan<- error) Bind { + return CreateStdNetBind("tls", logger, errorChan) +} diff --git a/conn/tcp_tls_utils.go b/conn/tcp_tls_utils.go new file mode 100644 index 000000000..f1c454255 --- /dev/null +++ b/conn/tcp_tls_utils.go @@ -0,0 +1,175 @@ +/* + * Copyright (c) 2022. Proton AG + * + * This file is part of ProtonVPN. + * + * ProtonVPN is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * ProtonVPN is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with ProtonVPN. If not, see . + */ + +package conn + +import ( + "bytes" + cryptoRand "crypto/rand" + "encoding/binary" + "errors" + "math/big" + "math/rand" + "time" +) + +var wgDataPrefix = []byte{4, 0, 0, 0} +var wgDataHeaderSize = 16 +var wgDataPrefixSize = 8 // Wireguard data header without counter + +var tunSafeHeaderSize = 2 +var tunSafeNormalType = uint8(0b00) +var tunSafeDataType = uint8(0b10) + +type TunSafeData struct { + wgSendPrefix []byte + wgSendCount uint64 + wgRecvPrefix []byte + wgRecvCount uint64 +} + +var topLevelDomains = []string{"com", "net", "org", "it", "fr", "me", "ru", "cn", "es", "tr", "top", "xyz", "info"} + +func NewTunSafeData() *TunSafeData { + return &TunSafeData{ + wgRecvPrefix: make([]byte, 8), + wgSendPrefix: make([]byte, 8), + } +} + +// Returns (type, size) +func parseTunSafeHeader(header []byte) (byte, int) { + tunSafeType := header[0] >> 6 + size := (int(header[0])&0b00111111)<<8 | int(header[1]) + return tunSafeType, size +} + +func (tunSafe *TunSafeData) clear() { + tunSafe.wgSendCount = 0 + tunSafe.wgRecvCount = 0 +} + +func (tunSafe *TunSafeData) writeWgHeader(wgPacket []byte) { + buffer := new(bytes.Buffer) + buffer.Grow(len(tunSafe.wgRecvPrefix) + binary.Size(tunSafe.wgRecvCount)) + buffer.Write(tunSafe.wgRecvPrefix) + _ = binary.Write(buffer, binary.LittleEndian, tunSafe.wgRecvCount) + copy(wgPacket, buffer.Bytes()) +} + +func (tunSafe *TunSafeData) prepareWgPacket(tunSafeType byte, payloadSize int) ([]byte, int, error) { + var wgPacket []byte + offset := 0 + switch tunSafeType { + case tunSafeNormalType: + wgPacket = make([]byte, payloadSize) + case tunSafeDataType: + offset = wgDataHeaderSize + wgPacket = make([]byte, payloadSize+offset) + tunSafe.writeWgHeader(wgPacket) + default: + return nil, 0, errors.New("StdNetBindTcp: unknown TunSafe type") + } + return wgPacket, offset, nil +} + +func (tunSafe *TunSafeData) onRecvPacket(tunSafeType byte, wgPacket []byte) { + if tunSafeType == tunSafeNormalType { + isWgDataPacket := bytes.HasPrefix(wgPacket, wgDataPrefix) + if isWgDataPacket { + copy(tunSafe.wgRecvPrefix, wgPacket[:wgDataPrefixSize]) + countBuffer := bytes.NewBuffer(wgPacket[wgDataPrefixSize:wgDataHeaderSize]) + _ = binary.Read(countBuffer, binary.LittleEndian, &tunSafe.wgRecvCount) + } + } + tunSafe.wgRecvCount++ +} + +func (tunSafe *TunSafeData) wgToTunSafe(wgPacket []byte) []byte { + wgLen := len(wgPacket) + if wgLen < wgDataHeaderSize { + return wgToTunSafeNormal(wgPacket) + } + wgPrefix := wgPacket[:wgDataPrefixSize] + var wgCount uint64 + _ = binary.Read(bytes.NewReader(wgPacket[wgDataPrefixSize:wgDataHeaderSize]), binary.LittleEndian, &wgCount) + prefixMatch := bytes.Equal(wgPrefix, tunSafe.wgSendPrefix) + if prefixMatch && wgCount == tunSafe.wgSendCount+1 { + tunSafe.wgSendCount += 1 + return wgToTunSafeData(wgPacket) + } else { + isWgDataPacket := bytes.HasPrefix(wgPacket, wgDataPrefix) + if isWgDataPacket { + tunSafe.wgSendPrefix = wgPrefix + tunSafe.wgSendCount = wgCount + } + return wgToTunSafeNormal(wgPacket) + } +} + +func wgToTunSafeNormal(wgPacket []byte) []byte { + payloadSize := len(wgPacket) + result := make([]byte, payloadSize+tunSafeHeaderSize) + + // Tunsafe normal header + result[0] = uint8(payloadSize >> 8) + result[1] = uint8(payloadSize & 0xff) + + // Full packet + copy(result[tunSafeHeaderSize:], wgPacket) + + return result +} + +func wgToTunSafeData(wgPacket []byte) []byte { + payloadSize := len(wgPacket) - wgDataHeaderSize + result := make([]byte, payloadSize+tunSafeHeaderSize) + + // TunSafe data header + result[0] = uint8(0b10<<6 | payloadSize>>8) + result[1] = uint8(payloadSize & 0xff) + + // Packet without header + copy(result[tunSafeHeaderSize:], wgPacket[wgDataHeaderSize:]) + + return result +} + +func randomServerName() string { + charNum := int('z') - int('a') + 1 + size := 3 + randInt(10) + name := make([]byte, size) + for i := range name { + name[i] = byte(int('a') + randInt(charNum)) + } + return string(name) + "." + randItem(topLevelDomains) +} + +func randItem(list []string) string { + return list[randInt(len(list))] +} + +func randInt(n int) int { + size, err := cryptoRand.Int(cryptoRand.Reader, big.NewInt(int64(n))) + if err == nil { + return int(size.Int64()) + } + rand.Seed(time.Now().UnixNano()) + return rand.Intn(n) +} diff --git a/device/device.go b/device/device.go index 24ae1eab6..d32803620 100644 --- a/device/device.go +++ b/device/device.go @@ -6,6 +6,7 @@ package device import ( + "net" "runtime" "sync" "sync/atomic" @@ -95,6 +96,9 @@ type Device struct { isASecOn abool.AtomicBool aSecMux sync.RWMutex aSecCfg aSecCfgType + + handshakeStateChan chan<- HandshakeState + allowedSrcAddresses []net.IP } type aSecCfgType struct { @@ -110,6 +114,14 @@ type aSecCfgType struct { transportPacketMagicHeader uint32 } +type HandshakeState int + +const ( + HandshakeInit HandshakeState = iota + HandshakeSuccess = iota + HandshakeFail = iota +) + // deviceState represents the state of a Device. // There are three states: down, up, closed. // Transitions: @@ -189,6 +201,7 @@ func (device *Device) changeState(want deviceState) (err error) { // upLocked attempts to bring the device up and reports whether it succeeded. // The caller must hold device.state.mu and is responsible for updating device.state.state. func (device *Device) upLocked() error { + device.handshakeStateChan <- HandshakeInit if err := device.BindUpdate(); err != nil { device.log.Errorf("Unable to update bind: %v", err) return err @@ -301,9 +314,18 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { return nil } -func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device { +func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger, handshakeStateChan chan<- HandshakeState /*, allowedSrcAddresses string*/) *Device { device := new(Device) device.state.state.Store(uint32(deviceStateDown)) + device.handshakeStateChan = handshakeStateChan + /*var allowedSources = strings.Split(allowedSrcAddresses, ",") + device.allowedSrcAddresses = make([]net.IP, len(allowedSources)) + for i, source := range allowedSources { + ip := net.ParseIP(source) + if ip != nil { + device.allowedSrcAddresses[i] = ip + } + }*/ device.closed = make(chan struct{}) device.log = logger device.net.bind = bind @@ -417,6 +439,7 @@ func (device *Device) Close() { device.log.Verbosef("Device closed") close(device.closed) + close(device.handshakeStateChan) } func (device *Device) Wait() chan struct{} { diff --git a/device/device_test.go b/device/device_test.go index e6664a681..5f5d6c430 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -237,7 +237,8 @@ func genTestPair( if _, ok := tb.(*testing.B); ok && !testing.Verbose() { level = LogLevelError } - p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i))) + p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i)), + make(chan HandshakeState)) if err := p.dev.IpcSet(cfg[i]); err != nil { tb.Errorf("failed to configure device %d: %v", i, err) p.dev.Close() diff --git a/device/logger.go b/device/logger.go index 22b0df028..9577e2292 100644 --- a/device/logger.go +++ b/device/logger.go @@ -6,6 +6,7 @@ package device import ( + "github.com/amnezia-vpn/amneziawg-go/conn" "log" "os" ) @@ -16,8 +17,7 @@ import ( // They do not require a trailing newline in the format. // If nil, that level of logging will be silent. type Logger struct { - Verbosef func(format string, args ...any) - Errorf func(format string, args ...any) + conn.Logger } // Log levels for use with NewLogger. @@ -34,7 +34,7 @@ func DiscardLogf(format string, args ...any) {} // It logs at the specified log level and above. // It decorates log lines with the log level, date, time, and prepend. func NewLogger(level int, prepend string) *Logger { - logger := &Logger{DiscardLogf, DiscardLogf} + logger := &Logger{conn.Logger{Verbosef: DiscardLogf, Errorf: DiscardLogf}} logf := func(prefix string) func(string, ...any) { return log.New(os.Stdout, prefix+": "+prepend, log.Ldate|log.Ltime).Printf } diff --git a/device/noise_test.go b/device/noise_test.go index 075b6d304..ff178f876 100644 --- a/device/noise_test.go +++ b/device/noise_test.go @@ -39,7 +39,7 @@ func randDevice(t *testing.T) *Device { } tun := tuntest.NewChannelTUN() logger := NewLogger(LogLevelError, "") - device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger) + device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger, make(chan HandshakeState)) device.SetPrivateKey(sk) return device } diff --git a/device/receive.go b/device/receive.go index 66c1a32de..9016ebd47 100644 --- a/device/receive.go +++ b/device/receive.go @@ -406,6 +406,7 @@ func (device *Device) RoutineHandshake(id int) { // update endpoint peer.SetEndpointFromPacket(elem.endpoint) + device.handshakeStateChan <- HandshakeSuccess device.log.Verbosef("%v - Received handshake initiation", peer) peer.rxBytes.Add(uint64(len(elem.packet))) @@ -434,6 +435,7 @@ func (device *Device) RoutineHandshake(id int) { // update endpoint peer.SetEndpointFromPacket(elem.endpoint) + device.handshakeStateChan <- HandshakeSuccess device.log.Verbosef("%v - Received handshake response", peer) peer.rxBytes.Add(uint64(len(elem.packet))) diff --git a/device/send.go b/device/send.go index 1b4406d53..4af8028f9 100644 --- a/device/send.go +++ b/device/send.go @@ -175,6 +175,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { err = peer.SendBuffers(sendBuffer) if err != nil { + peer.device.handshakeStateChan <- HandshakeFail peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err) } peer.timersHandshakeInitiated() @@ -318,12 +319,14 @@ func (device *Device) RoutineReadFromTUN() { // lookup peer var peer *Peer + var src []byte switch elem.packet[0] >> 4 { case 4: if len(elem.packet) < ipv4.HeaderLen { continue } dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] + src = elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] peer = device.allowedips.Lookup(dst) case 6: @@ -331,6 +334,7 @@ func (device *Device) RoutineReadFromTUN() { continue } dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] + src = elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] peer = device.allowedips.Lookup(dst) default: @@ -340,6 +344,13 @@ func (device *Device) RoutineReadFromTUN() { if peer == nil { continue } + + // Drop packets with unexpected src IP. + if device.allowedSrcAddresses != nil && device.isUnexpectedSrcIP(src) { + //device.log.Verbosef("Dropping packet with unexpected src IP: %v (allowed = %v)", src, device.allowedSrcAddresses) + continue + } + elemsForPeer, ok := elemsByPeer[peer] if !ok { elemsForPeer = device.GetOutboundElementsContainer() @@ -383,6 +394,15 @@ func (device *Device) RoutineReadFromTUN() { } } +func (device *Device) isUnexpectedSrcIP(src []byte) bool { + for _, allowed := range device.allowedSrcAddresses { + if allowed.Equal(src) { + return false + } + } + return true +} + func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) { for { select { diff --git a/device/statemanager.go b/device/statemanager.go new file mode 100644 index 000000000..e6634d0d9 --- /dev/null +++ b/device/statemanager.go @@ -0,0 +1,247 @@ +/* + * Copyright (c) 2022. Proton AG + * + * This file is part of ProtonVPN. + * + * ProtonVPN is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * ProtonVPN is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with ProtonVPN. If not, see . + */ + +package device + +import ( + "strings" + "sync" + "time" +) + +var initialRestartDelay = 4 * time.Second +var maxRestartDelay = 32 * time.Second +var resetRestartDelay = 10 * time.Minute +var timeNow = time.Now + +// WireGuardStateManager handles enabling/disabling WireGuard in response to network availability changes, serves +// connection state to the client and resets WireGuard connection in response to socket and handshake errors. +// +// Client should call SetNetworkAvailable every time network changes - WireGuard will remain inactive until +// SetNetworkAvailable(true) is called. When SetNetworkAvailable(true) is called twice in a row it'll be interpreted +// as network change and trigger reset of the connection (on TCP/TLS socket). +// +// GetState is blocking and therefore should run in dedicated thread in a loop. After Close is called GetState will +// return immediately with WireGuardDisabled. +type WireGuardStateManager struct { + HandshakeStateChan chan HandshakeState + SocketErrChan chan error + networkAvailableChan chan bool + closeChan chan bool + + stateChan chan WireGuardState + isNetAvailable bool + + lastRestart time.Time + transmission string + + log *Logger + mu sync.Mutex + closed bool + startedTimestamp time.Time + nextRestartDelay time.Duration +} + +type WireGuardState int + +const ( + WireGuardDisabled WireGuardState = iota + WireGuardConnecting + WireGuardConnected + WireGuardError + WireGuardWaitingForNetwork +) + +type BaseDevice interface { + Up() error + Down() error +} + +//goland:noinspection GoUnusedExportedFunction +func NewWireGuardStateManager(log *Logger, transmission string) *WireGuardStateManager { + return &WireGuardStateManager{ + networkAvailableChan: make(chan bool, 100), + SocketErrChan: make(chan error, 100), + HandshakeStateChan: make(chan HandshakeState, 100), + closeChan: make(chan bool, 1), + stateChan: make(chan WireGuardState, 1), + transmission: transmission, + log: log, + nextRestartDelay: initialRestartDelay, + lastRestart: timeNow(), + } +} + +func (man *WireGuardStateManager) Start(device BaseDevice) { + go man.handlerLoop(device) +} + +func (man *WireGuardStateManager) GetState() WireGuardState { + state, ok := <-man.stateChan + if !ok { + return -1 + } + return state +} + +func (man *WireGuardStateManager) Close() { + man.log.Verbosef("StateManager: closing") + man.closed = true + go func() { + man.closeChan <- true + man.stateChan <- WireGuardDisabled + close(man.stateChan) + }() +} + +func (man *WireGuardStateManager) SetNetworkAvailable(available bool) { + man.networkAvailableChan <- available +} + +func (man *WireGuardStateManager) handlerLoop(device BaseDevice) { + man.log.Verbosef("StateManager: start loop") + // Ugly way of emulating optional bool type + var wasNetAvailablePtr *bool = nil + for { + select { + case netAvailable := <-man.networkAvailableChan: + man.onNetworkAvailabilityChange(device, wasNetAvailablePtr, netAvailable) + man.isNetAvailable = netAvailable + wasNetAvailablePtr = &man.isNetAvailable + case socketErr := <-man.SocketErrChan: + if man.isNetAvailable { + man.handleSocketErr(device, socketErr) + } + case handshakeState := <-man.HandshakeStateChan: + if man.isNetAvailable { + man.handleHandshakeState(device, handshakeState) + } + case <-man.closeChan: + man.log.Verbosef("StateManager: end loop") + return + } + } +} + +func (man *WireGuardStateManager) onNetworkAvailabilityChange(device BaseDevice, wasAvailable *bool, available bool) { + if !available { + man.postState(WireGuardWaitingForNetwork) + } + if available && wasAvailable == nil { + man.log.Verbosef("StateManager: network on") + man.setActive(device, true) + man.startedTimestamp = timeNow() + } else if available && *wasAvailable && !man.startedTimestamp.IsZero() && + timeNow().After(man.startedTimestamp.Add(5*time.Second)) { + // Ignore network changes at the very beginning of connection as those might be false positive + // (VPN tunnel opening) + man.log.Verbosef("StateManager: network change detected") + man.maybeRestart(device) + } else if available && !*wasAvailable { + man.log.Verbosef("StateManager: network back") + man.setActive(device, true) + } else if !available && wasAvailable != nil && *wasAvailable { + man.log.Verbosef("StateManager: network gone") + man.setActive(device, false) + } +} + +func (man *WireGuardStateManager) setActive(device BaseDevice, activate bool) { + man.mu.Lock() + defer man.mu.Unlock() + + var err error + if activate { + man.postState(WireGuardConnecting) + err = device.Up() + } else { + err = device.Down() + } + if err != nil { + man.log.Errorf("StateManager: setActive(%t) error %v", activate, err) + man.postState(WireGuardError) + } +} + +func (man *WireGuardStateManager) handleSocketErr(device BaseDevice, err error) { + if err != nil { + errStr := err.Error() + if strings.Contains(errStr, "broken pipe") || + strings.Contains(errStr, "connection reset by peer") { + man.log.Errorf("StateManager: %s", errStr) + man.maybeRestart(device) + } + } +} + +func (man *WireGuardStateManager) handleHandshakeState(device BaseDevice, state HandshakeState) { + switch state { + case HandshakeInit: + man.postState(WireGuardConnecting) + case HandshakeSuccess: + man.postState(WireGuardConnected) + case HandshakeFail: + man.postState(WireGuardError) + man.maybeRestart(device) + } +} + +func (man *WireGuardStateManager) maybeRestart(device BaseDevice) { + if man.transmission == "udp" { + return + } + + man.mu.Lock() + defer man.mu.Unlock() + + if man.shouldRestart() { + man.log.Verbosef("StateManager: restarting") + man.postState(WireGuardConnecting) + device.Down() + if !man.closed { + device.Up() + } + } +} + +// Don't restart too often, grow delay exponentially up to a limit and after some time reset to small initial value +func (man *WireGuardStateManager) shouldRestart() bool { + now := timeNow() + restart := now.After(man.lastRestart.Add(man.nextRestartDelay)) + if restart { + if now.After(man.lastRestart.Add(resetRestartDelay)) { + man.nextRestartDelay = initialRestartDelay + } else { + man.nextRestartDelay *= 2 + if man.nextRestartDelay > maxRestartDelay { + man.nextRestartDelay = maxRestartDelay + } + } + man.lastRestart = now + } + return restart +} + +func (man *WireGuardStateManager) postState(state WireGuardState) { + go func() { + if !man.closed && (man.isNetAvailable || state == WireGuardWaitingForNetwork) { + man.stateChan <- state + } + }() +} diff --git a/device/statemanager_test.go b/device/statemanager_test.go new file mode 100644 index 000000000..b2b91bc3b --- /dev/null +++ b/device/statemanager_test.go @@ -0,0 +1,152 @@ +/* + * Copyright (c) 2022. Proton AG + * + * This file is part of ProtonVPN. + * + * ProtonVPN is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * ProtonVPN is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with ProtonVPN. If not, see . + */ + +package device + +import ( + "errors" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +var timeMs int64 = 0 +var mockDevice MockDevice +var manager *WireGuardStateManager +var lastState WireGuardState + +type MockDevice struct { + isUp bool + upCount int +} + +func (dev *MockDevice) Up() error { + dev.isUp = true + dev.upCount++ + return nil +} + +func (dev *MockDevice) Down() error { + dev.isUp = false + return nil +} + +func setup() { + timeMs = 0 + timeNow = func() time.Time { return time.UnixMilli(timeMs) } + mockDevice.isUp = false + + manager = NewWireGuardStateManager(NewLogger(LogLevelVerbose, ""), "tcp") + manager.Start(&mockDevice) + lastState = WireGuardDisabled + go func() { + for lastState != -1 { + lastState = manager.GetState() + } + }() +} + +func setdown() { + manager.Close() +} + +func TestWireGuardStateManager_shouldRestart(t *testing.T) { + assert := assert.New(t) + setup() + defer setdown() + + assert.Equal(initialRestartDelay, manager.nextRestartDelay) + + assert.Equal(false, manager.shouldRestart()) + timeMs += initialRestartDelay.Milliseconds() + assert.Equal(false, manager.shouldRestart()) + timeMs += 1 + assert.Equal(true, manager.shouldRestart()) + + assert.Equal(2*initialRestartDelay, manager.nextRestartDelay) + assert.Equal(false, manager.shouldRestart()) + timeMs += 2 * initialRestartDelay.Milliseconds() + assert.Equal(false, manager.shouldRestart()) + timeMs += 1 + assert.Equal(true, manager.shouldRestart()) + + timeMs += resetRestartDelay.Milliseconds() + 1 + assert.Equal(true, manager.shouldRestart()) + assert.Equal(initialRestartDelay, manager.nextRestartDelay) +} + +func TestWireGuardStateManager_networkStartsAndStopsDevice(t *testing.T) { + assert := assert.New(t) + setup() + defer setdown() + + assert.Equal(false, mockDevice.isUp) + manager.SetNetworkAvailable(true) + time.Sleep(time.Millisecond) // Poor substitute for advanceUntilIdle, make sure goroutines finish before checking + assert.Equal(true, mockDevice.isUp) + assert.Equal(WireGuardConnecting, lastState) + manager.SetNetworkAvailable(false) + time.Sleep(time.Millisecond) + assert.Equal(WireGuardWaitingForNetwork, lastState) + assert.Equal(false, mockDevice.isUp) +} + +func TestWireGuardStateManager_happyConnectionPath(t *testing.T) { + assert := assert.New(t) + setup() + defer setdown() + + manager.SetNetworkAvailable(true) + time.Sleep(time.Millisecond) + manager.HandshakeStateChan <- HandshakeSuccess + time.Sleep(time.Millisecond) + assert.Equal(WireGuardConnected, lastState) + assert.Equal(true, mockDevice.isUp) +} + +func TestWireGuardStateManager_handshakeFailCausesRestart(t *testing.T) { + assert := assert.New(t) + setup() + defer setdown() + + manager.SetNetworkAvailable(true) + time.Sleep(time.Millisecond) + manager.HandshakeStateChan <- HandshakeFail + time.Sleep(time.Millisecond) + assert.Equal(WireGuardError, lastState) + timeMs += initialRestartDelay.Milliseconds() + 1 + manager.HandshakeStateChan <- HandshakeFail + time.Sleep(time.Millisecond) + assert.Equal(WireGuardConnecting, lastState) + assert.Equal(2, mockDevice.upCount) +} + +func TestWireGuardStateManager_brokenPipeCausesRestart(t *testing.T) { + assert := assert.New(t) + setup() + defer setdown() + + manager.SetNetworkAvailable(true) + timeMs += initialRestartDelay.Milliseconds() + 1 + time.Sleep(time.Millisecond) + manager.SocketErrChan <- errors.New("broken pipe") + time.Sleep(time.Millisecond) + assert.Equal(WireGuardConnecting, lastState) + assert.Equal(2, mockDevice.upCount) +} diff --git a/device/timers.go b/device/timers.go index d4a4ed4e5..72b9a2c2b 100644 --- a/device/timers.go +++ b/device/timers.go @@ -78,6 +78,7 @@ func (peer *Peer) timersActive() bool { func expiredRetransmitHandshake(peer *Peer) { if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes { + peer.device.handshakeStateChan <- HandshakeFail peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2) if peer.timersActive() { @@ -97,6 +98,7 @@ func expiredRetransmitHandshake(peer *Peer) { } } else { peer.timers.handshakeAttempts.Add(1) + peer.device.handshakeStateChan <- HandshakeFail peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1) /* We clear the endpoint address src address, in case this is the cause of trouble. */ diff --git a/go.mod b/go.mod index 33182eea1..c018932ec 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,8 @@ module github.com/amnezia-vpn/amneziawg-go go 1.20 require ( + github.com/refraction-networking/utls v1.1.5 + github.com/stretchr/testify v1.8.0 github.com/tevino/abool/v2 v2.1.0 golang.org/x/crypto v0.19.0 golang.org/x/net v0.21.0 @@ -15,3 +17,11 @@ require ( github.com/google/btree v1.0.1 // indirect golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect ) + +require ( + github.com/andybalholm/brotli v1.0.4 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/klauspost/compress v1.15.9 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum index 0e6f733c7..06741f43e 100644 --- a/go.sum +++ b/go.sum @@ -1,16 +1,41 @@ -github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= +github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= +github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= +github.com/klauspost/compress v1.15.9 h1:wKRjX6JRtDdrE9qwa4b/Cip7ACOshUI4smpCQanqjSY= +github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/refraction-networking/utls v1.1.5 h1:JtrojoNhbUQkBqEg05sP3gDgDj6hIEAAVKbI9lx4n6w= +github.com/refraction-networking/utls v1.1.5/go.mod h1:jRQxtYi7nkq1p28HF2lwOH5zQm9aC8rpK0O9lIIzGh8= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c= github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY= +golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220909164309-bea034e7d591/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= -gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY= diff --git a/main.go b/main.go index 775372c90..baff1e33c 100644 --- a/main.go +++ b/main.go @@ -222,11 +222,11 @@ func main() { return } - device := device.NewDevice(tdev, conn.NewDefaultBind(), logger) + errs := make(chan error) + device := device.NewDevice(tdev, conn.NewDefaultBind(&logger.Logger, errs), logger, make(chan device.HandshakeState)) logger.Verbosef("Device started") - errs := make(chan error) term := make(chan os.Signal, 1) uapi, err := ipc.UAPIListen(interfaceName, fileUAPI)