diff --git a/monitor_linux_default.go b/monitor_linux_default.go index a8446a5..3a3290c 100644 --- a/monitor_linux_default.go +++ b/monitor_linux_default.go @@ -4,6 +4,7 @@ package tun import ( "github.com/sagernet/netlink" + "golang.org/x/sys/unix" ) diff --git a/monitor_other.go b/monitor_other.go index 76f4c29..c6b447c 100644 --- a/monitor_other.go +++ b/monitor_other.go @@ -3,8 +3,9 @@ package tun import ( - "github.com/sagernet/sing/common/logger" "os" + + "github.com/sagernet/sing/common/logger" ) func NewNetworkUpdateMonitor(logger logger.Logger) (NetworkUpdateMonitor, error) { diff --git a/stack.go b/stack.go index 2191187..2e96e9d 100644 --- a/stack.go +++ b/stack.go @@ -35,9 +35,15 @@ func NewStack( ) (Stack, error) { switch stack { case "": - return NewSystem(options) + if WithGVisor { + return NewMixed(options) + } else { + return NewSystem(options) + } case "gvisor": return NewGVisor(options) + case "mixed": + return NewMixed(options) case "system": return NewSystem(options) case "lwip": diff --git a/stack_gvisor_stub.go b/stack_gvisor_stub.go index bd380f4..64c8a65 100644 --- a/stack_gvisor_stub.go +++ b/stack_gvisor_stub.go @@ -13,3 +13,9 @@ func NewGVisor( ) (Stack, error) { return nil, ErrGVisorNotIncluded } + +func NewMixed( + options StackOptions, +) (Stack, error) { + return nil, ErrGVisorNotIncluded +} diff --git a/stack_gvisor_udp.go b/stack_gvisor_udp.go index 0846e5d..d29fa46 100644 --- a/stack_gvisor_udp.go +++ b/stack_gvisor_udp.go @@ -8,6 +8,7 @@ import ( "math" "net/netip" "os" + "sync" "syscall" "github.com/sagernet/gvisor/pkg/buffer" @@ -74,10 +75,12 @@ func (f *UDPForwarder) newUDPConn(natConn N.PacketConn) N.PacketWriter { source: f.cacheID.RemoteAddress, sourcePort: f.cacheID.RemotePort, sourceNetwork: f.cacheProto, + packet: f.cachePacket.IncRef(), } } type UDPBackWriter struct { + access sync.Mutex stack *stack.Stack source tcpip.Address sourcePort uint16 @@ -85,8 +88,21 @@ type UDPBackWriter struct { packet stack.PacketBufferPtr } +func (w *UDPBackWriter) Close() error { + w.access.Lock() + defer w.access.Unlock() + if w.packet == nil { + return os.ErrClosed + } + w.packet.DecRef() + w.packet = nil + return nil +} + func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Socksaddr) error { - if destination.IsIPv4() && w.sourceNetwork == header.IPv6ProtocolNumber { + if !destination.IsIP() { + return E.Cause(os.ErrInvalid, "invalid destination") + } else if destination.IsIPv4() && w.sourceNetwork == header.IPv6ProtocolNumber { destination = M.SocksaddrFrom(netip.AddrFrom16(destination.Addr.As16()), destination.Port) } else if destination.IsIPv6() && (w.sourceNetwork == header.IPv4AddressSizeBits) { return E.New("send IPv6 packet to IPv4 connection") @@ -165,6 +181,7 @@ type gRequest struct { type gUDPConn struct { *gonet.UDPConn + access sync.Mutex stack *stack.Stack packet stack.PacketBufferPtr } @@ -188,6 +205,11 @@ func (c *gUDPConn) Write(b []byte) (n int, err error) { } func (c *gUDPConn) Close() error { + c.access.Lock() + defer c.access.Unlock() + if c.packet == nil { + return os.ErrClosed + } c.packet.DecRef() c.packet = nil return c.UDPConn.Close() diff --git a/stack_mixed.go b/stack_mixed.go new file mode 100644 index 0000000..2d55814 --- /dev/null +++ b/stack_mixed.go @@ -0,0 +1,202 @@ +//go:build with_gvisor + +package tun + +import ( + "time" + "unsafe" + + "github.com/sagernet/gvisor/pkg/buffer" + "github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet" + "github.com/sagernet/gvisor/pkg/tcpip/header" + "github.com/sagernet/gvisor/pkg/tcpip/link/channel" + "github.com/sagernet/gvisor/pkg/tcpip/stack" + "github.com/sagernet/gvisor/pkg/tcpip/transport/udp" + "github.com/sagernet/gvisor/pkg/waiter" + "github.com/sagernet/sing-tun/internal/clashtcpip" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/bufio" + "github.com/sagernet/sing/common/canceler" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +type Mixed struct { + *System + writer N.VectorisedWriter + endpointIndependentNat bool + stack *stack.Stack + endpoint *channel.Endpoint +} + +func NewMixed( + options StackOptions, +) (Stack, error) { + system, err := NewSystem(options) + if err != nil { + return nil, err + } + return &Mixed{ + System: system.(*System), + writer: options.Tun.CreateVectorisedWriter(), + endpointIndependentNat: options.EndpointIndependentNat, + }, nil +} + +func (m *Mixed) Start() error { + err := m.System.start() + if err != nil { + return err + } + endpoint := channel.New(1024, m.mtu, "") + ipStack, err := newGVisorStack(endpoint) + if err != nil { + return err + } + if !m.endpointIndependentNat { + udpForwarder := udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) { + var wq waiter.Queue + endpoint, err := request.CreateEndpoint(&wq) + if err != nil { + return + } + udpConn := gonet.NewUDPConn(ipStack, &wq, endpoint) + lAddr := udpConn.RemoteAddr() + rAddr := udpConn.LocalAddr() + if lAddr == nil || rAddr == nil { + endpoint.Abort() + return + } + gConn := &gUDPConn{udpConn, ipStack, (*gRequest)(unsafe.Pointer(request)).pkt.IncRef()} + go func() { + var metadata M.Metadata + metadata.Source = M.SocksaddrFromNet(lAddr) + metadata.Destination = M.SocksaddrFromNet(rAddr) + ctx, conn := canceler.NewPacketConn(m.ctx, bufio.NewPacketConn(&bufio.UnbindPacketConn{ExtendedConn: bufio.NewExtendedConn(gConn), Addr: M.SocksaddrFromNet(rAddr)}), time.Duration(m.udpTimeout)*time.Second) + hErr := m.handler.NewPacketConnection(ctx, conn, metadata) + if hErr != nil { + endpoint.Abort() + } + }() + }) + ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) + } else { + ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(m.ctx, ipStack, m.handler, m.udpTimeout).HandlePacket) + } + m.stack = ipStack + m.endpoint = endpoint + go m.tunLoop() + go m.packetLoop() + return nil +} + +func (m *Mixed) tunLoop() { + if winTun, isWinTun := m.tun.(WinTun); isWinTun { + m.wintunLoop(winTun) + return + } + packetBuffer := make([]byte, m.mtu+PacketOffset) + for { + n, err := m.tun.Read(packetBuffer) + if err != nil { + return + } + if n < clashtcpip.IPv4PacketMinLength { + continue + } + packet := packetBuffer[PacketOffset:n] + switch ipVersion := packet[0] >> 4; ipVersion { + case 4: + err = m.processIPv4(packet) + case 6: + err = m.processIPv6(packet) + default: + err = E.New("ip: unknown version: ", ipVersion) + } + if err != nil { + m.logger.Trace(err) + } + } +} + +func (m *Mixed) wintunLoop(winTun WinTun) { + for { + packet, release, err := winTun.ReadPacket() + if err != nil { + return + } + if len(packet) < clashtcpip.IPv4PacketMinLength { + release() + continue + } + switch ipVersion := packet[0] >> 4; ipVersion { + case 4: + err = m.processIPv4(packet) + case 6: + err = m.processIPv6(packet) + default: + err = E.New("ip: unknown version: ", ipVersion) + } + if err != nil { + m.logger.Trace(err) + } + release() + } +} + +func (m *Mixed) processIPv4(packet clashtcpip.IPv4Packet) error { + switch packet.Protocol() { + case clashtcpip.TCP: + return m.processIPv4TCP(packet, packet.Payload()) + case clashtcpip.UDP: + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(packet), + }) + m.endpoint.InjectInbound(header.IPv4ProtocolNumber, pkt) + pkt.DecRef() + return nil + case clashtcpip.ICMP: + return m.processIPv4ICMP(packet, packet.Payload()) + default: + return common.Error(m.tun.Write(packet)) + } +} + +func (m *Mixed) processIPv6(packet clashtcpip.IPv6Packet) error { + switch packet.Protocol() { + case clashtcpip.TCP: + return m.processIPv6TCP(packet, packet.Payload()) + case clashtcpip.UDP: + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(packet), + }) + m.endpoint.InjectInbound(header.IPv6ProtocolNumber, pkt) + pkt.DecRef() + return nil + case clashtcpip.ICMPv6: + return m.processIPv6ICMP(packet, packet.Payload()) + default: + return common.Error(m.tun.Write(packet)) + } +} + +func (m *Mixed) packetLoop() { + for { + packet := m.endpoint.ReadContext(m.ctx) + if packet == nil { + break + } + bufio.WriteVectorised(m.writer, packet.AsSlices()) + packet.DecRef() + } +} + +func (m *Mixed) Close() error { + m.endpoint.Attach(nil) + m.stack.Close() + for _, endpoint := range m.stack.CleanupEndpoints() { + endpoint.Abort() + } + return m.System.Close() +} diff --git a/tun.go b/tun.go index 2784339..52aa6a5 100644 --- a/tun.go +++ b/tun.go @@ -23,6 +23,7 @@ type Handler interface { type Tun interface { io.ReadWriter + CreateVectorisedWriter() N.VectorisedWriter Close() error } diff --git a/tun_darwin.go b/tun_darwin.go index 2013e21..a9bba89 100644 --- a/tun_darwin.go +++ b/tun_darwin.go @@ -10,6 +10,7 @@ import ( "unsafe" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" N "github.com/sagernet/sing/common/network" @@ -101,6 +102,20 @@ func (t *NativeTun) Write(p []byte) (n int, err error) { return } +func (t *NativeTun) CreateVectorisedWriter() N.VectorisedWriter { + return t +} + +func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error { + var packetHeader []byte + if buffers[0].Byte(0)>>4 == 4 { + packetHeader = packetHeader4[:] + } else { + packetHeader = packetHeader6[:] + } + return t.tunWriter.WriteVectorised(append([]*buf.Buffer{buf.As(packetHeader)}, buffers...)) +} + func (t *NativeTun) Close() error { flushDNSCache() return t.tunFile.Close() diff --git a/tun_linux.go b/tun_linux.go index 597881f..465fc5c 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -12,7 +12,9 @@ import ( "github.com/sagernet/netlink" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/rw" "github.com/sagernet/sing/common/shell" "github.com/sagernet/sing/common/x/list" @@ -68,6 +70,10 @@ func (t *NativeTun) Write(p []byte) (n int, err error) { return t.tunFile.Write(p) } +func (t *NativeTun) CreateVectorisedWriter() N.VectorisedWriter { + return bufio.NewVectorisedWriter(t.tunFile) +} + var controlPath string func init() { diff --git a/tun_windows.go b/tun_windows.go index 488f8a7..656251f 100644 --- a/tun_windows.go +++ b/tun_windows.go @@ -16,7 +16,10 @@ import ( "github.com/sagernet/sing-tun/internal/winipcfg" "github.com/sagernet/sing-tun/internal/winsys" "github.com/sagernet/sing-tun/internal/wintun" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/windnsapi" "golang.org/x/sys/windows" @@ -467,6 +470,15 @@ func (t *NativeTun) write(packetElementList [][]byte) (n int, err error) { return 0, fmt.Errorf("write failed: %w", err) } +func (t *NativeTun) CreateVectorisedWriter() N.VectorisedWriter { + return t +} + +func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error { + defer buf.ReleaseMulti(buffers) + return common.Error(t.write(buf.ToSliceMulti(buffers))) +} + func (t *NativeTun) Close() error { var err error t.closeOnce.Do(func() {