diff --git a/conn/bind_std_tcp.go b/conn/bind_std_tcp.go
new file mode 100644
index 000000000..299c242c8
--- /dev/null
+++ b/conn/bind_std_tcp.go
@@ -0,0 +1,331 @@
+/*
+ * Copyright (c) 2022. Proton AG
+ *
+ * This file is part of ProtonVPN.
+ *
+ * ProtonVPN is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * ProtonVPN is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with ProtonVPN. If not, see .
+ */
+
+package conn
+
+import (
+ "bytes"
+ "errors"
+ "io"
+ "net"
+ "net/netip"
+ "sync"
+ "time"
+
+ tls "github.com/refraction-networking/utls"
+)
+
+var lastErrorTimestamp time.Time
+
+type StdNetBindTcp struct {
+ mu sync.Mutex
+
+ useTls bool
+ tcp *net.TCPConn
+ tls *tls.UConn
+ endpoint *StdNetEndpoint
+ currentPacket *bytes.Reader
+ closed bool
+ log *Logger
+ errorChan chan<- error
+ protectSocket func(fd int) int
+
+ tunsafe *TunSafeData
+}
+
+//goland:noinspection GoUnusedExportedFunction
+func CreateStdNetBind(socketType string, log *Logger, errorChan chan<- error) Bind {
+ if socketType == "udp" {
+ return NewStdNetBind()
+ } else {
+ return &StdNetBindTcp{tunsafe: NewTunSafeData(), useTls: socketType == "tls", log: log, errorChan: errorChan}
+ }
+}
+
+func (s *StdNetBindTcp) BatchSize() int {
+ return 1
+}
+
+func (s *StdNetBindTcp) GetOffloadInfo() string {
+ return ""
+}
+
+func (bind *StdNetBindTcp) ParseEndpoint(s string) (Endpoint, error) {
+ e, err := netip.ParseAddrPort(s)
+ if err == nil {
+ bind.endpoint = &StdNetEndpoint{AddrPort: e}
+ }
+ if err != nil {
+ return nil, err
+ }
+ return asEndpoint(e), err
+}
+
+func dialTcp(addr string, protectSocket func(fd int) int) (*net.TCPConn, int, error) {
+ dialer := net.Dialer{Timeout: 5 * time.Second}
+ netConn, err := dialer.Dial("tcp", addr)
+ if err != nil {
+ return nil, 0, err
+ }
+
+ conn := netConn.(*net.TCPConn)
+ conn.SetLinger(0)
+
+ // Retrieve port.
+ laddr := conn.LocalAddr()
+ taddr, err := net.ResolveTCPAddr(
+ laddr.Network(),
+ laddr.String(),
+ )
+ if err != nil {
+ _ = conn.Close()
+ return nil, 0, err
+ }
+ return conn, taddr.Port, nil
+}
+
+func (bind *StdNetBindTcp) upgradeToTls() error {
+ tlsConf := &tls.Config{
+ InsecureSkipVerify: true,
+ ServerName: randomServerName(),
+ }
+
+ conn := tls.UClient(bind.tcp, tlsConf, tls.HelloChrome_Auto)
+ conn.SetDeadline(time.Now().Add(5 * time.Second))
+ bind.log.Verbosef("TLS: Starting handshake")
+ err := conn.Handshake()
+ bind.log.Verbosef("TLS: Handshake result: %v", err)
+ conn.SetDeadline(time.Time{})
+
+ // On some devices (e.g. Samsung S21 FE) we see first WireGuard handshake failing on TLS socket and adding small
+ // delay seems to fix that - issue is likely with timing on the server side, but couldn't find server-side fix.
+ time.Sleep(100 * time.Millisecond)
+
+ if err == nil {
+ bind.tls = conn
+ } else {
+ bind.onSocketError(err)
+ conn.Close()
+ }
+ return err
+}
+
+func (bind *StdNetBindTcp) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
+ bind.mu.Lock()
+ defer bind.mu.Unlock()
+
+ bind.log.Verbosef("TCP/TLS: Open %d", uport)
+ bind.closed = false
+ return []ReceiveFunc{bind.makeReceiveFunc()}, uport, nil
+}
+
+func (bind *StdNetBindTcp) initTcp() error {
+ var err error
+
+ if bind.tcp != nil {
+ return ErrBindAlreadyOpen
+ }
+
+ var tcp *net.TCPConn
+
+ tcp, _, err = dialTcp(bind.endpoint.DstToString(), bind.protectSocket)
+ bind.log.Verbosef("TCP dial result: %v", err)
+ if err != nil {
+ bind.onSocketError(err)
+ return err
+ }
+ bind.tcp = tcp
+ return nil
+}
+
+func (bind *StdNetBindTcp) Close() error {
+ bind.mu.Lock()
+ defer bind.mu.Unlock()
+
+ bind.log.Verbosef("TCP/TLS: Close")
+ bind.closed = true
+ err := bind.closeInternal()
+ return err
+}
+
+func (bind *StdNetBindTcp) closeInternal() error {
+ var err error
+ if bind.tls != nil {
+ err = bind.tls.Close()
+ bind.tls = nil
+ }
+ if bind.tcp != nil {
+ err = bind.tcp.Close()
+ bind.tcp = nil
+ }
+ bind.tunsafe.clear()
+ return err
+}
+
+func (bind *StdNetBindTcp) getConn() (net.Conn, error) {
+ bind.mu.Lock()
+ defer bind.mu.Unlock()
+
+ if bind.closed {
+ return nil, net.ErrClosed
+ }
+
+ conn, err := bind.getConnInternal()
+ if err != nil {
+ bind.closed = true
+ }
+ return conn, err
+}
+
+func (bind *StdNetBindTcp) getConnInternal() (net.Conn, error) {
+ if bind.tcp == nil {
+ err := bind.initTcp()
+ if err != nil {
+ return nil, err
+ }
+ }
+ if !bind.useTls {
+ return bind.tcp, nil
+ }
+ if bind.tls == nil {
+ err := bind.upgradeToTls()
+ if err != nil {
+ bind.closeInternal()
+ return nil, err
+ }
+ }
+ return bind.tls, nil
+}
+
+func (bind *StdNetBindTcp) makeReceiveFunc() ReceiveFunc {
+ return func(packets [][]byte, sizes []int, eps []Endpoint) (int, error) {
+ var err error
+ eps[0] = bind.endpoint
+ if bind.currentPacket == nil || bind.currentPacket.Len() == 0 {
+ var conn net.Conn
+ conn, err = bind.getConn()
+ if err != nil {
+ bind.logError("recv getConn", err)
+ return 0, err
+ }
+ err = bind.readNextPacket(conn)
+ if err != nil {
+ if !errors.Is(err, net.ErrClosed) {
+ bind.onSocketError(err)
+ bind.logError("recv", err)
+ }
+ return 0, err
+ }
+ }
+ n, err := bind.currentPacket.Read(packets[0])
+ if err != nil {
+ bind.logError("read packet", err)
+ return 0, err
+ }
+ sizes[0] = n
+ return 1, err
+ }
+}
+
+func (bind *StdNetBindTcp) readNextPacket(conn net.Conn) error {
+ tunSafeHeader := make([]byte, tunSafeHeaderSize)
+ _, err := io.ReadFull(conn, tunSafeHeader)
+ if err != nil {
+ return err
+ }
+
+ tunSafeType, payloadSize := parseTunSafeHeader(tunSafeHeader)
+ wgPacket, offset, err := bind.tunsafe.prepareWgPacket(tunSafeType, payloadSize)
+ if err != nil {
+ return err
+ }
+
+ _, err = io.ReadFull(conn, wgPacket[offset:])
+ if err != nil {
+ return err
+ }
+
+ bind.tunsafe.onRecvPacket(tunSafeType, wgPacket)
+ bind.currentPacket = bytes.NewReader(wgPacket)
+ return nil
+}
+
+func (bind *StdNetBindTcp) Send(buff [][]byte, endpoint Endpoint) error {
+ conn, err := bind.getConn()
+ if err != nil {
+ bind.logError("send conn", err)
+ return err
+ }
+
+ // As single tcp socket can send only to single destination. We assume endpoint passed to ParseEndpoint will be
+ // the same.
+ boundEndpoint := asEndpoint(bind.endpoint.AddrPort)
+ if endpoint != boundEndpoint {
+ return errors.New("StdNetBindTcp.Send endpoints mismatch")
+ }
+ for i := range buff {
+ tunSafePacket := bind.tunsafe.wgToTunSafe(buff[i])
+ _, err = conn.Write(tunSafePacket)
+ if err != nil {
+ bind.onSocketError(err)
+ bind.logError("send", err)
+ break
+ }
+
+ }
+ return err
+}
+
+func (bind *StdNetBindTcp) SetMark(_ uint32) error {
+ return nil
+}
+
+func (bind *StdNetBindTcp) onSocketError(err error) {
+ if err != nil && !bind.closed {
+ bind.errorChan <- err
+ }
+}
+
+func (bind *StdNetBindTcp) logError(t string, err error) {
+ if time.Now().After(lastErrorTimestamp.Add(5 * time.Second)) {
+ lastErrorTimestamp = time.Now()
+ bind.log.Errorf("TCP/TLS error %s: %v", t, err)
+ }
+}
+
+// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
+// This exists to reduce allocations: Putting a netip.AddrPort in an Endpoint allocates,
+// but Endpoints are immutable, so we can re-use them.
+var endpointPool = sync.Pool{
+ New: func() any {
+ return make(map[netip.AddrPort]Endpoint)
+ },
+}
+
+// asEndpoint returns an Endpoint containing ap.
+func asEndpoint(ap netip.AddrPort) Endpoint {
+ m := endpointPool.Get().(map[netip.AddrPort]Endpoint)
+ defer endpointPool.Put(m)
+ e, ok := m[ap]
+ if !ok {
+ e = Endpoint(&StdNetEndpoint{AddrPort: ap})
+ m[ap] = e
+ }
+ return e
+}
diff --git a/conn/boundif_android.go b/conn/boundif_android.go
index dd3ca5b07..a9f8e126e 100644
--- a/conn/boundif_android.go
+++ b/conn/boundif_android.go
@@ -32,3 +32,11 @@ func (s *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
}
return
}
+
+func (bind *StdNetBindTcp) PeekLookAtSocketFd4() (fd int, err error) {
+ return -1, err
+}
+
+func (bind *StdNetBindTcp) PeekLookAtSocketFd6() (fd int, err error) {
+ return -1, err
+}
diff --git a/conn/conn.go b/conn/conn.go
index 489cb3520..f87b24521 100644
--- a/conn/conn.go
+++ b/conn/conn.go
@@ -59,6 +59,11 @@ type Bind interface {
GetOffloadInfo() string
}
+type Logger struct {
+ Verbosef func(format string, args ...any)
+ Errorf func(format string, args ...any)
+}
+
// BindSocketToInterface is implemented by Bind objects that support being
// tied to a single network interface. Used by wireguard-windows.
type BindSocketToInterface interface {
diff --git a/conn/default.go b/conn/default.go
index b6f761b9e..8f93c172b 100644
--- a/conn/default.go
+++ b/conn/default.go
@@ -7,4 +7,6 @@
package conn
-func NewDefaultBind() Bind { return NewStdNetBind() }
+func NewDefaultBind(logger *Logger, errorChan chan<- error) Bind {
+ return CreateStdNetBind("tls", logger, errorChan)
+}
diff --git a/conn/tcp_tls_utils.go b/conn/tcp_tls_utils.go
new file mode 100644
index 000000000..f1c454255
--- /dev/null
+++ b/conn/tcp_tls_utils.go
@@ -0,0 +1,175 @@
+/*
+ * Copyright (c) 2022. Proton AG
+ *
+ * This file is part of ProtonVPN.
+ *
+ * ProtonVPN is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * ProtonVPN is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with ProtonVPN. If not, see .
+ */
+
+package conn
+
+import (
+ "bytes"
+ cryptoRand "crypto/rand"
+ "encoding/binary"
+ "errors"
+ "math/big"
+ "math/rand"
+ "time"
+)
+
+var wgDataPrefix = []byte{4, 0, 0, 0}
+var wgDataHeaderSize = 16
+var wgDataPrefixSize = 8 // Wireguard data header without counter
+
+var tunSafeHeaderSize = 2
+var tunSafeNormalType = uint8(0b00)
+var tunSafeDataType = uint8(0b10)
+
+type TunSafeData struct {
+ wgSendPrefix []byte
+ wgSendCount uint64
+ wgRecvPrefix []byte
+ wgRecvCount uint64
+}
+
+var topLevelDomains = []string{"com", "net", "org", "it", "fr", "me", "ru", "cn", "es", "tr", "top", "xyz", "info"}
+
+func NewTunSafeData() *TunSafeData {
+ return &TunSafeData{
+ wgRecvPrefix: make([]byte, 8),
+ wgSendPrefix: make([]byte, 8),
+ }
+}
+
+// Returns (type, size)
+func parseTunSafeHeader(header []byte) (byte, int) {
+ tunSafeType := header[0] >> 6
+ size := (int(header[0])&0b00111111)<<8 | int(header[1])
+ return tunSafeType, size
+}
+
+func (tunSafe *TunSafeData) clear() {
+ tunSafe.wgSendCount = 0
+ tunSafe.wgRecvCount = 0
+}
+
+func (tunSafe *TunSafeData) writeWgHeader(wgPacket []byte) {
+ buffer := new(bytes.Buffer)
+ buffer.Grow(len(tunSafe.wgRecvPrefix) + binary.Size(tunSafe.wgRecvCount))
+ buffer.Write(tunSafe.wgRecvPrefix)
+ _ = binary.Write(buffer, binary.LittleEndian, tunSafe.wgRecvCount)
+ copy(wgPacket, buffer.Bytes())
+}
+
+func (tunSafe *TunSafeData) prepareWgPacket(tunSafeType byte, payloadSize int) ([]byte, int, error) {
+ var wgPacket []byte
+ offset := 0
+ switch tunSafeType {
+ case tunSafeNormalType:
+ wgPacket = make([]byte, payloadSize)
+ case tunSafeDataType:
+ offset = wgDataHeaderSize
+ wgPacket = make([]byte, payloadSize+offset)
+ tunSafe.writeWgHeader(wgPacket)
+ default:
+ return nil, 0, errors.New("StdNetBindTcp: unknown TunSafe type")
+ }
+ return wgPacket, offset, nil
+}
+
+func (tunSafe *TunSafeData) onRecvPacket(tunSafeType byte, wgPacket []byte) {
+ if tunSafeType == tunSafeNormalType {
+ isWgDataPacket := bytes.HasPrefix(wgPacket, wgDataPrefix)
+ if isWgDataPacket {
+ copy(tunSafe.wgRecvPrefix, wgPacket[:wgDataPrefixSize])
+ countBuffer := bytes.NewBuffer(wgPacket[wgDataPrefixSize:wgDataHeaderSize])
+ _ = binary.Read(countBuffer, binary.LittleEndian, &tunSafe.wgRecvCount)
+ }
+ }
+ tunSafe.wgRecvCount++
+}
+
+func (tunSafe *TunSafeData) wgToTunSafe(wgPacket []byte) []byte {
+ wgLen := len(wgPacket)
+ if wgLen < wgDataHeaderSize {
+ return wgToTunSafeNormal(wgPacket)
+ }
+ wgPrefix := wgPacket[:wgDataPrefixSize]
+ var wgCount uint64
+ _ = binary.Read(bytes.NewReader(wgPacket[wgDataPrefixSize:wgDataHeaderSize]), binary.LittleEndian, &wgCount)
+ prefixMatch := bytes.Equal(wgPrefix, tunSafe.wgSendPrefix)
+ if prefixMatch && wgCount == tunSafe.wgSendCount+1 {
+ tunSafe.wgSendCount += 1
+ return wgToTunSafeData(wgPacket)
+ } else {
+ isWgDataPacket := bytes.HasPrefix(wgPacket, wgDataPrefix)
+ if isWgDataPacket {
+ tunSafe.wgSendPrefix = wgPrefix
+ tunSafe.wgSendCount = wgCount
+ }
+ return wgToTunSafeNormal(wgPacket)
+ }
+}
+
+func wgToTunSafeNormal(wgPacket []byte) []byte {
+ payloadSize := len(wgPacket)
+ result := make([]byte, payloadSize+tunSafeHeaderSize)
+
+ // Tunsafe normal header
+ result[0] = uint8(payloadSize >> 8)
+ result[1] = uint8(payloadSize & 0xff)
+
+ // Full packet
+ copy(result[tunSafeHeaderSize:], wgPacket)
+
+ return result
+}
+
+func wgToTunSafeData(wgPacket []byte) []byte {
+ payloadSize := len(wgPacket) - wgDataHeaderSize
+ result := make([]byte, payloadSize+tunSafeHeaderSize)
+
+ // TunSafe data header
+ result[0] = uint8(0b10<<6 | payloadSize>>8)
+ result[1] = uint8(payloadSize & 0xff)
+
+ // Packet without header
+ copy(result[tunSafeHeaderSize:], wgPacket[wgDataHeaderSize:])
+
+ return result
+}
+
+func randomServerName() string {
+ charNum := int('z') - int('a') + 1
+ size := 3 + randInt(10)
+ name := make([]byte, size)
+ for i := range name {
+ name[i] = byte(int('a') + randInt(charNum))
+ }
+ return string(name) + "." + randItem(topLevelDomains)
+}
+
+func randItem(list []string) string {
+ return list[randInt(len(list))]
+}
+
+func randInt(n int) int {
+ size, err := cryptoRand.Int(cryptoRand.Reader, big.NewInt(int64(n)))
+ if err == nil {
+ return int(size.Int64())
+ }
+ rand.Seed(time.Now().UnixNano())
+ return rand.Intn(n)
+}
diff --git a/device/device.go b/device/device.go
index 24ae1eab6..d32803620 100644
--- a/device/device.go
+++ b/device/device.go
@@ -6,6 +6,7 @@
package device
import (
+ "net"
"runtime"
"sync"
"sync/atomic"
@@ -95,6 +96,9 @@ type Device struct {
isASecOn abool.AtomicBool
aSecMux sync.RWMutex
aSecCfg aSecCfgType
+
+ handshakeStateChan chan<- HandshakeState
+ allowedSrcAddresses []net.IP
}
type aSecCfgType struct {
@@ -110,6 +114,14 @@ type aSecCfgType struct {
transportPacketMagicHeader uint32
}
+type HandshakeState int
+
+const (
+ HandshakeInit HandshakeState = iota
+ HandshakeSuccess = iota
+ HandshakeFail = iota
+)
+
// deviceState represents the state of a Device.
// There are three states: down, up, closed.
// Transitions:
@@ -189,6 +201,7 @@ func (device *Device) changeState(want deviceState) (err error) {
// upLocked attempts to bring the device up and reports whether it succeeded.
// The caller must hold device.state.mu and is responsible for updating device.state.state.
func (device *Device) upLocked() error {
+ device.handshakeStateChan <- HandshakeInit
if err := device.BindUpdate(); err != nil {
device.log.Errorf("Unable to update bind: %v", err)
return err
@@ -301,9 +314,18 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
return nil
}
-func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
+func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger, handshakeStateChan chan<- HandshakeState /*, allowedSrcAddresses string*/) *Device {
device := new(Device)
device.state.state.Store(uint32(deviceStateDown))
+ device.handshakeStateChan = handshakeStateChan
+ /*var allowedSources = strings.Split(allowedSrcAddresses, ",")
+ device.allowedSrcAddresses = make([]net.IP, len(allowedSources))
+ for i, source := range allowedSources {
+ ip := net.ParseIP(source)
+ if ip != nil {
+ device.allowedSrcAddresses[i] = ip
+ }
+ }*/
device.closed = make(chan struct{})
device.log = logger
device.net.bind = bind
@@ -417,6 +439,7 @@ func (device *Device) Close() {
device.log.Verbosef("Device closed")
close(device.closed)
+ close(device.handshakeStateChan)
}
func (device *Device) Wait() chan struct{} {
diff --git a/device/device_test.go b/device/device_test.go
index e6664a681..5f5d6c430 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -237,7 +237,8 @@ func genTestPair(
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
level = LogLevelError
}
- p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i)))
+ p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i)),
+ make(chan HandshakeState))
if err := p.dev.IpcSet(cfg[i]); err != nil {
tb.Errorf("failed to configure device %d: %v", i, err)
p.dev.Close()
diff --git a/device/logger.go b/device/logger.go
index 22b0df028..9577e2292 100644
--- a/device/logger.go
+++ b/device/logger.go
@@ -6,6 +6,7 @@
package device
import (
+ "github.com/amnezia-vpn/amneziawg-go/conn"
"log"
"os"
)
@@ -16,8 +17,7 @@ import (
// They do not require a trailing newline in the format.
// If nil, that level of logging will be silent.
type Logger struct {
- Verbosef func(format string, args ...any)
- Errorf func(format string, args ...any)
+ conn.Logger
}
// Log levels for use with NewLogger.
@@ -34,7 +34,7 @@ func DiscardLogf(format string, args ...any) {}
// It logs at the specified log level and above.
// It decorates log lines with the log level, date, time, and prepend.
func NewLogger(level int, prepend string) *Logger {
- logger := &Logger{DiscardLogf, DiscardLogf}
+ logger := &Logger{conn.Logger{Verbosef: DiscardLogf, Errorf: DiscardLogf}}
logf := func(prefix string) func(string, ...any) {
return log.New(os.Stdout, prefix+": "+prepend, log.Ldate|log.Ltime).Printf
}
diff --git a/device/noise_test.go b/device/noise_test.go
index 075b6d304..ff178f876 100644
--- a/device/noise_test.go
+++ b/device/noise_test.go
@@ -39,7 +39,7 @@ func randDevice(t *testing.T) *Device {
}
tun := tuntest.NewChannelTUN()
logger := NewLogger(LogLevelError, "")
- device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger)
+ device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger, make(chan HandshakeState))
device.SetPrivateKey(sk)
return device
}
diff --git a/device/receive.go b/device/receive.go
index 66c1a32de..9016ebd47 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -406,6 +406,7 @@ func (device *Device) RoutineHandshake(id int) {
// update endpoint
peer.SetEndpointFromPacket(elem.endpoint)
+ device.handshakeStateChan <- HandshakeSuccess
device.log.Verbosef("%v - Received handshake initiation", peer)
peer.rxBytes.Add(uint64(len(elem.packet)))
@@ -434,6 +435,7 @@ func (device *Device) RoutineHandshake(id int) {
// update endpoint
peer.SetEndpointFromPacket(elem.endpoint)
+ device.handshakeStateChan <- HandshakeSuccess
device.log.Verbosef("%v - Received handshake response", peer)
peer.rxBytes.Add(uint64(len(elem.packet)))
diff --git a/device/send.go b/device/send.go
index 1b4406d53..4af8028f9 100644
--- a/device/send.go
+++ b/device/send.go
@@ -175,6 +175,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
err = peer.SendBuffers(sendBuffer)
if err != nil {
+ peer.device.handshakeStateChan <- HandshakeFail
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
}
peer.timersHandshakeInitiated()
@@ -318,12 +319,14 @@ func (device *Device) RoutineReadFromTUN() {
// lookup peer
var peer *Peer
+ var src []byte
switch elem.packet[0] >> 4 {
case 4:
if len(elem.packet) < ipv4.HeaderLen {
continue
}
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
+ src = elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
peer = device.allowedips.Lookup(dst)
case 6:
@@ -331,6 +334,7 @@ func (device *Device) RoutineReadFromTUN() {
continue
}
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
+ src = elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
peer = device.allowedips.Lookup(dst)
default:
@@ -340,6 +344,13 @@ func (device *Device) RoutineReadFromTUN() {
if peer == nil {
continue
}
+
+ // Drop packets with unexpected src IP.
+ if device.allowedSrcAddresses != nil && device.isUnexpectedSrcIP(src) {
+ //device.log.Verbosef("Dropping packet with unexpected src IP: %v (allowed = %v)", src, device.allowedSrcAddresses)
+ continue
+ }
+
elemsForPeer, ok := elemsByPeer[peer]
if !ok {
elemsForPeer = device.GetOutboundElementsContainer()
@@ -383,6 +394,15 @@ func (device *Device) RoutineReadFromTUN() {
}
}
+func (device *Device) isUnexpectedSrcIP(src []byte) bool {
+ for _, allowed := range device.allowedSrcAddresses {
+ if allowed.Equal(src) {
+ return false
+ }
+ }
+ return true
+}
+
func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
for {
select {
diff --git a/device/statemanager.go b/device/statemanager.go
new file mode 100644
index 000000000..e6634d0d9
--- /dev/null
+++ b/device/statemanager.go
@@ -0,0 +1,247 @@
+/*
+ * Copyright (c) 2022. Proton AG
+ *
+ * This file is part of ProtonVPN.
+ *
+ * ProtonVPN is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * ProtonVPN is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with ProtonVPN. If not, see .
+ */
+
+package device
+
+import (
+ "strings"
+ "sync"
+ "time"
+)
+
+var initialRestartDelay = 4 * time.Second
+var maxRestartDelay = 32 * time.Second
+var resetRestartDelay = 10 * time.Minute
+var timeNow = time.Now
+
+// WireGuardStateManager handles enabling/disabling WireGuard in response to network availability changes, serves
+// connection state to the client and resets WireGuard connection in response to socket and handshake errors.
+//
+// Client should call SetNetworkAvailable every time network changes - WireGuard will remain inactive until
+// SetNetworkAvailable(true) is called. When SetNetworkAvailable(true) is called twice in a row it'll be interpreted
+// as network change and trigger reset of the connection (on TCP/TLS socket).
+//
+// GetState is blocking and therefore should run in dedicated thread in a loop. After Close is called GetState will
+// return immediately with WireGuardDisabled.
+type WireGuardStateManager struct {
+ HandshakeStateChan chan HandshakeState
+ SocketErrChan chan error
+ networkAvailableChan chan bool
+ closeChan chan bool
+
+ stateChan chan WireGuardState
+ isNetAvailable bool
+
+ lastRestart time.Time
+ transmission string
+
+ log *Logger
+ mu sync.Mutex
+ closed bool
+ startedTimestamp time.Time
+ nextRestartDelay time.Duration
+}
+
+type WireGuardState int
+
+const (
+ WireGuardDisabled WireGuardState = iota
+ WireGuardConnecting
+ WireGuardConnected
+ WireGuardError
+ WireGuardWaitingForNetwork
+)
+
+type BaseDevice interface {
+ Up() error
+ Down() error
+}
+
+//goland:noinspection GoUnusedExportedFunction
+func NewWireGuardStateManager(log *Logger, transmission string) *WireGuardStateManager {
+ return &WireGuardStateManager{
+ networkAvailableChan: make(chan bool, 100),
+ SocketErrChan: make(chan error, 100),
+ HandshakeStateChan: make(chan HandshakeState, 100),
+ closeChan: make(chan bool, 1),
+ stateChan: make(chan WireGuardState, 1),
+ transmission: transmission,
+ log: log,
+ nextRestartDelay: initialRestartDelay,
+ lastRestart: timeNow(),
+ }
+}
+
+func (man *WireGuardStateManager) Start(device BaseDevice) {
+ go man.handlerLoop(device)
+}
+
+func (man *WireGuardStateManager) GetState() WireGuardState {
+ state, ok := <-man.stateChan
+ if !ok {
+ return -1
+ }
+ return state
+}
+
+func (man *WireGuardStateManager) Close() {
+ man.log.Verbosef("StateManager: closing")
+ man.closed = true
+ go func() {
+ man.closeChan <- true
+ man.stateChan <- WireGuardDisabled
+ close(man.stateChan)
+ }()
+}
+
+func (man *WireGuardStateManager) SetNetworkAvailable(available bool) {
+ man.networkAvailableChan <- available
+}
+
+func (man *WireGuardStateManager) handlerLoop(device BaseDevice) {
+ man.log.Verbosef("StateManager: start loop")
+ // Ugly way of emulating optional bool type
+ var wasNetAvailablePtr *bool = nil
+ for {
+ select {
+ case netAvailable := <-man.networkAvailableChan:
+ man.onNetworkAvailabilityChange(device, wasNetAvailablePtr, netAvailable)
+ man.isNetAvailable = netAvailable
+ wasNetAvailablePtr = &man.isNetAvailable
+ case socketErr := <-man.SocketErrChan:
+ if man.isNetAvailable {
+ man.handleSocketErr(device, socketErr)
+ }
+ case handshakeState := <-man.HandshakeStateChan:
+ if man.isNetAvailable {
+ man.handleHandshakeState(device, handshakeState)
+ }
+ case <-man.closeChan:
+ man.log.Verbosef("StateManager: end loop")
+ return
+ }
+ }
+}
+
+func (man *WireGuardStateManager) onNetworkAvailabilityChange(device BaseDevice, wasAvailable *bool, available bool) {
+ if !available {
+ man.postState(WireGuardWaitingForNetwork)
+ }
+ if available && wasAvailable == nil {
+ man.log.Verbosef("StateManager: network on")
+ man.setActive(device, true)
+ man.startedTimestamp = timeNow()
+ } else if available && *wasAvailable && !man.startedTimestamp.IsZero() &&
+ timeNow().After(man.startedTimestamp.Add(5*time.Second)) {
+ // Ignore network changes at the very beginning of connection as those might be false positive
+ // (VPN tunnel opening)
+ man.log.Verbosef("StateManager: network change detected")
+ man.maybeRestart(device)
+ } else if available && !*wasAvailable {
+ man.log.Verbosef("StateManager: network back")
+ man.setActive(device, true)
+ } else if !available && wasAvailable != nil && *wasAvailable {
+ man.log.Verbosef("StateManager: network gone")
+ man.setActive(device, false)
+ }
+}
+
+func (man *WireGuardStateManager) setActive(device BaseDevice, activate bool) {
+ man.mu.Lock()
+ defer man.mu.Unlock()
+
+ var err error
+ if activate {
+ man.postState(WireGuardConnecting)
+ err = device.Up()
+ } else {
+ err = device.Down()
+ }
+ if err != nil {
+ man.log.Errorf("StateManager: setActive(%t) error %v", activate, err)
+ man.postState(WireGuardError)
+ }
+}
+
+func (man *WireGuardStateManager) handleSocketErr(device BaseDevice, err error) {
+ if err != nil {
+ errStr := err.Error()
+ if strings.Contains(errStr, "broken pipe") ||
+ strings.Contains(errStr, "connection reset by peer") {
+ man.log.Errorf("StateManager: %s", errStr)
+ man.maybeRestart(device)
+ }
+ }
+}
+
+func (man *WireGuardStateManager) handleHandshakeState(device BaseDevice, state HandshakeState) {
+ switch state {
+ case HandshakeInit:
+ man.postState(WireGuardConnecting)
+ case HandshakeSuccess:
+ man.postState(WireGuardConnected)
+ case HandshakeFail:
+ man.postState(WireGuardError)
+ man.maybeRestart(device)
+ }
+}
+
+func (man *WireGuardStateManager) maybeRestart(device BaseDevice) {
+ if man.transmission == "udp" {
+ return
+ }
+
+ man.mu.Lock()
+ defer man.mu.Unlock()
+
+ if man.shouldRestart() {
+ man.log.Verbosef("StateManager: restarting")
+ man.postState(WireGuardConnecting)
+ device.Down()
+ if !man.closed {
+ device.Up()
+ }
+ }
+}
+
+// Don't restart too often, grow delay exponentially up to a limit and after some time reset to small initial value
+func (man *WireGuardStateManager) shouldRestart() bool {
+ now := timeNow()
+ restart := now.After(man.lastRestart.Add(man.nextRestartDelay))
+ if restart {
+ if now.After(man.lastRestart.Add(resetRestartDelay)) {
+ man.nextRestartDelay = initialRestartDelay
+ } else {
+ man.nextRestartDelay *= 2
+ if man.nextRestartDelay > maxRestartDelay {
+ man.nextRestartDelay = maxRestartDelay
+ }
+ }
+ man.lastRestart = now
+ }
+ return restart
+}
+
+func (man *WireGuardStateManager) postState(state WireGuardState) {
+ go func() {
+ if !man.closed && (man.isNetAvailable || state == WireGuardWaitingForNetwork) {
+ man.stateChan <- state
+ }
+ }()
+}
diff --git a/device/statemanager_test.go b/device/statemanager_test.go
new file mode 100644
index 000000000..b2b91bc3b
--- /dev/null
+++ b/device/statemanager_test.go
@@ -0,0 +1,152 @@
+/*
+ * Copyright (c) 2022. Proton AG
+ *
+ * This file is part of ProtonVPN.
+ *
+ * ProtonVPN is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * ProtonVPN is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with ProtonVPN. If not, see .
+ */
+
+package device
+
+import (
+ "errors"
+ "github.com/stretchr/testify/assert"
+ "testing"
+ "time"
+)
+
+var timeMs int64 = 0
+var mockDevice MockDevice
+var manager *WireGuardStateManager
+var lastState WireGuardState
+
+type MockDevice struct {
+ isUp bool
+ upCount int
+}
+
+func (dev *MockDevice) Up() error {
+ dev.isUp = true
+ dev.upCount++
+ return nil
+}
+
+func (dev *MockDevice) Down() error {
+ dev.isUp = false
+ return nil
+}
+
+func setup() {
+ timeMs = 0
+ timeNow = func() time.Time { return time.UnixMilli(timeMs) }
+ mockDevice.isUp = false
+
+ manager = NewWireGuardStateManager(NewLogger(LogLevelVerbose, ""), "tcp")
+ manager.Start(&mockDevice)
+ lastState = WireGuardDisabled
+ go func() {
+ for lastState != -1 {
+ lastState = manager.GetState()
+ }
+ }()
+}
+
+func setdown() {
+ manager.Close()
+}
+
+func TestWireGuardStateManager_shouldRestart(t *testing.T) {
+ assert := assert.New(t)
+ setup()
+ defer setdown()
+
+ assert.Equal(initialRestartDelay, manager.nextRestartDelay)
+
+ assert.Equal(false, manager.shouldRestart())
+ timeMs += initialRestartDelay.Milliseconds()
+ assert.Equal(false, manager.shouldRestart())
+ timeMs += 1
+ assert.Equal(true, manager.shouldRestart())
+
+ assert.Equal(2*initialRestartDelay, manager.nextRestartDelay)
+ assert.Equal(false, manager.shouldRestart())
+ timeMs += 2 * initialRestartDelay.Milliseconds()
+ assert.Equal(false, manager.shouldRestart())
+ timeMs += 1
+ assert.Equal(true, manager.shouldRestart())
+
+ timeMs += resetRestartDelay.Milliseconds() + 1
+ assert.Equal(true, manager.shouldRestart())
+ assert.Equal(initialRestartDelay, manager.nextRestartDelay)
+}
+
+func TestWireGuardStateManager_networkStartsAndStopsDevice(t *testing.T) {
+ assert := assert.New(t)
+ setup()
+ defer setdown()
+
+ assert.Equal(false, mockDevice.isUp)
+ manager.SetNetworkAvailable(true)
+ time.Sleep(time.Millisecond) // Poor substitute for advanceUntilIdle, make sure goroutines finish before checking
+ assert.Equal(true, mockDevice.isUp)
+ assert.Equal(WireGuardConnecting, lastState)
+ manager.SetNetworkAvailable(false)
+ time.Sleep(time.Millisecond)
+ assert.Equal(WireGuardWaitingForNetwork, lastState)
+ assert.Equal(false, mockDevice.isUp)
+}
+
+func TestWireGuardStateManager_happyConnectionPath(t *testing.T) {
+ assert := assert.New(t)
+ setup()
+ defer setdown()
+
+ manager.SetNetworkAvailable(true)
+ time.Sleep(time.Millisecond)
+ manager.HandshakeStateChan <- HandshakeSuccess
+ time.Sleep(time.Millisecond)
+ assert.Equal(WireGuardConnected, lastState)
+ assert.Equal(true, mockDevice.isUp)
+}
+
+func TestWireGuardStateManager_handshakeFailCausesRestart(t *testing.T) {
+ assert := assert.New(t)
+ setup()
+ defer setdown()
+
+ manager.SetNetworkAvailable(true)
+ time.Sleep(time.Millisecond)
+ manager.HandshakeStateChan <- HandshakeFail
+ time.Sleep(time.Millisecond)
+ assert.Equal(WireGuardError, lastState)
+ timeMs += initialRestartDelay.Milliseconds() + 1
+ manager.HandshakeStateChan <- HandshakeFail
+ time.Sleep(time.Millisecond)
+ assert.Equal(WireGuardConnecting, lastState)
+ assert.Equal(2, mockDevice.upCount)
+}
+
+func TestWireGuardStateManager_brokenPipeCausesRestart(t *testing.T) {
+ assert := assert.New(t)
+ setup()
+ defer setdown()
+
+ manager.SetNetworkAvailable(true)
+ timeMs += initialRestartDelay.Milliseconds() + 1
+ time.Sleep(time.Millisecond)
+ manager.SocketErrChan <- errors.New("broken pipe")
+ time.Sleep(time.Millisecond)
+ assert.Equal(WireGuardConnecting, lastState)
+ assert.Equal(2, mockDevice.upCount)
+}
diff --git a/device/timers.go b/device/timers.go
index d4a4ed4e5..72b9a2c2b 100644
--- a/device/timers.go
+++ b/device/timers.go
@@ -78,6 +78,7 @@ func (peer *Peer) timersActive() bool {
func expiredRetransmitHandshake(peer *Peer) {
if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes {
+ peer.device.handshakeStateChan <- HandshakeFail
peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2)
if peer.timersActive() {
@@ -97,6 +98,7 @@ func expiredRetransmitHandshake(peer *Peer) {
}
} else {
peer.timers.handshakeAttempts.Add(1)
+ peer.device.handshakeStateChan <- HandshakeFail
peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)
/* We clear the endpoint address src address, in case this is the cause of trouble. */
diff --git a/go.mod b/go.mod
index 33182eea1..c018932ec 100644
--- a/go.mod
+++ b/go.mod
@@ -3,6 +3,8 @@ module github.com/amnezia-vpn/amneziawg-go
go 1.20
require (
+ github.com/refraction-networking/utls v1.1.5
+ github.com/stretchr/testify v1.8.0
github.com/tevino/abool/v2 v2.1.0
golang.org/x/crypto v0.19.0
golang.org/x/net v0.21.0
@@ -15,3 +17,11 @@ require (
github.com/google/btree v1.0.1 // indirect
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect
)
+
+require (
+ github.com/andybalholm/brotli v1.0.4 // indirect
+ github.com/davecgh/go-spew v1.1.1 // indirect
+ github.com/klauspost/compress v1.15.9 // indirect
+ github.com/pmezard/go-difflib v1.0.0 // indirect
+ gopkg.in/yaml.v3 v3.0.1 // indirect
+)
diff --git a/go.sum b/go.sum
index 0e6f733c7..06741f43e 100644
--- a/go.sum
+++ b/go.sum
@@ -1,16 +1,41 @@
-github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
+github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
+github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
+github.com/klauspost/compress v1.15.9 h1:wKRjX6JRtDdrE9qwa4b/Cip7ACOshUI4smpCQanqjSY=
+github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/refraction-networking/utls v1.1.5 h1:JtrojoNhbUQkBqEg05sP3gDgDj6hIEAAVKbI9lx4n6w=
+github.com/refraction-networking/utls v1.1.5/go.mod h1:jRQxtYi7nkq1p28HF2lwOH5zQm9aC8rpK0O9lIIzGh8=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
+github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c=
github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY=
+golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
+golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
+golang.org/x/net v0.0.0-20220909164309-bea034e7d591/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
+golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
-golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44=
+golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
+golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
+golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
+golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
-gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY=
diff --git a/main.go b/main.go
index 775372c90..baff1e33c 100644
--- a/main.go
+++ b/main.go
@@ -222,11 +222,11 @@ func main() {
return
}
- device := device.NewDevice(tdev, conn.NewDefaultBind(), logger)
+ errs := make(chan error)
+ device := device.NewDevice(tdev, conn.NewDefaultBind(&logger.Logger, errs), logger, make(chan device.HandshakeState))
logger.Verbosef("Device started")
- errs := make(chan error)
term := make(chan os.Signal, 1)
uapi, err := ipc.UAPIListen(interfaceName, fileUAPI)