Skip to content

Commit

Permalink
Add mixed stack
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Aug 12, 2023
1 parent aa8760b commit 10d98f2
Show file tree
Hide file tree
Showing 11 changed files with 276 additions and 4 deletions.
1 change: 1 addition & 0 deletions monitor_linux_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package tun

import (
"github.com/sagernet/netlink"

"golang.org/x/sys/unix"
)

Expand Down
3 changes: 2 additions & 1 deletion monitor_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
package tun

import (
"github.com/sagernet/sing/common/logger"
"os"

"github.com/sagernet/sing/common/logger"
)

func NewNetworkUpdateMonitor(logger logger.Logger) (NetworkUpdateMonitor, error) {
Expand Down
8 changes: 7 additions & 1 deletion stack.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,15 @@ func NewStack(
) (Stack, error) {
switch stack {
case "":
return NewSystem(options)
if WithGVisor {
return NewMixed(options)
} else {
return NewSystem(options)
}
case "gvisor":
return NewGVisor(options)
case "mixed":
return NewMixed(options)
case "system":
return NewSystem(options)
case "lwip":
Expand Down
2 changes: 1 addition & 1 deletion stack_gvisor.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func (t *GVisor) Start() error {
endpoint.Abort()
return
}
gConn := &gUDPConn{udpConn, ipStack, (*gRequest)(unsafe.Pointer(request)).pkt.IncRef()}
gConn := &gUDPConn{UDPConn: udpConn, stack: ipStack, packet: (*gRequest)(unsafe.Pointer(request)).pkt.IncRef()}
go func() {
var metadata M.Metadata
metadata.Source = M.SocksaddrFromNet(lAddr)
Expand Down
6 changes: 6 additions & 0 deletions stack_gvisor_stub.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,9 @@ func NewGVisor(
) (Stack, error) {
return nil, ErrGVisorNotIncluded
}

func NewMixed(
options StackOptions,
) (Stack, error) {
return nil, ErrGVisorNotIncluded
}
24 changes: 23 additions & 1 deletion stack_gvisor_udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"math"
"net/netip"
"os"
"sync"
"syscall"

"github.com/sagernet/gvisor/pkg/buffer"
Expand Down Expand Up @@ -74,19 +75,34 @@ func (f *UDPForwarder) newUDPConn(natConn N.PacketConn) N.PacketWriter {
source: f.cacheID.RemoteAddress,
sourcePort: f.cacheID.RemotePort,
sourceNetwork: f.cacheProto,
packet: f.cachePacket.IncRef(),
}
}

type UDPBackWriter struct {
access sync.Mutex
stack *stack.Stack
source tcpip.Address
sourcePort uint16
sourceNetwork tcpip.NetworkProtocolNumber
packet stack.PacketBufferPtr
}

func (w *UDPBackWriter) Close() error {
w.access.Lock()
defer w.access.Unlock()
if w.packet == nil {
return os.ErrClosed
}
w.packet.DecRef()
w.packet = nil
return nil
}

func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Socksaddr) error {
if destination.IsIPv4() && w.sourceNetwork == header.IPv6ProtocolNumber {
if !destination.IsIP() {
return E.Cause(os.ErrInvalid, "invalid destination")
} else if destination.IsIPv4() && w.sourceNetwork == header.IPv6ProtocolNumber {
destination = M.SocksaddrFrom(netip.AddrFrom16(destination.Addr.As16()), destination.Port)
} else if destination.IsIPv6() && (w.sourceNetwork == header.IPv4AddressSizeBits) {
return E.New("send IPv6 packet to IPv4 connection")
Expand Down Expand Up @@ -165,6 +181,7 @@ type gRequest struct {

type gUDPConn struct {
*gonet.UDPConn
access sync.Mutex
stack *stack.Stack
packet stack.PacketBufferPtr
}
Expand All @@ -188,6 +205,11 @@ func (c *gUDPConn) Write(b []byte) (n int, err error) {
}

func (c *gUDPConn) Close() error {
c.access.Lock()
defer c.access.Unlock()
if c.packet == nil {
return os.ErrClosed
}
c.packet.DecRef()
c.packet = nil
return c.UDPConn.Close()
Expand Down
202 changes: 202 additions & 0 deletions stack_mixed.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
//go:build with_gvisor

package tun

import (
"time"
"unsafe"

"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/link/channel"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
"github.com/sagernet/gvisor/pkg/waiter"
"github.com/sagernet/sing-tun/internal/clashtcpip"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/canceler"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)

type Mixed struct {
*System
writer N.VectorisedWriter
endpointIndependentNat bool
stack *stack.Stack
endpoint *channel.Endpoint
}

func NewMixed(
options StackOptions,
) (Stack, error) {
system, err := NewSystem(options)
if err != nil {
return nil, err
}
return &Mixed{
System: system.(*System),
writer: options.Tun.CreateVectorisedWriter(),
endpointIndependentNat: options.EndpointIndependentNat,
}, nil
}

func (m *Mixed) Start() error {
err := m.System.start()
if err != nil {
return err
}
endpoint := channel.New(1024, m.mtu, "")
ipStack, err := newGVisorStack(endpoint)
if err != nil {
return err
}
if !m.endpointIndependentNat {
udpForwarder := udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) {
var wq waiter.Queue
endpoint, err := request.CreateEndpoint(&wq)
if err != nil {
return
}
udpConn := gonet.NewUDPConn(ipStack, &wq, endpoint)
lAddr := udpConn.RemoteAddr()
rAddr := udpConn.LocalAddr()
if lAddr == nil || rAddr == nil {
endpoint.Abort()
return
}
gConn := &gUDPConn{UDPConn: udpConn, stack: ipStack, packet: (*gRequest)(unsafe.Pointer(request)).pkt.IncRef()}
go func() {
var metadata M.Metadata
metadata.Source = M.SocksaddrFromNet(lAddr)
metadata.Destination = M.SocksaddrFromNet(rAddr)
ctx, conn := canceler.NewPacketConn(m.ctx, bufio.NewPacketConn(&bufio.UnbindPacketConn{ExtendedConn: bufio.NewExtendedConn(gConn), Addr: M.SocksaddrFromNet(rAddr)}), time.Duration(m.udpTimeout)*time.Second)
hErr := m.handler.NewPacketConnection(ctx, conn, metadata)
if hErr != nil {
endpoint.Abort()
}
}()
})
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
} else {
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(m.ctx, ipStack, m.handler, m.udpTimeout).HandlePacket)
}
m.stack = ipStack
m.endpoint = endpoint
go m.tunLoop()
go m.packetLoop()
return nil
}

func (m *Mixed) tunLoop() {
if winTun, isWinTun := m.tun.(WinTun); isWinTun {
m.wintunLoop(winTun)
return
}
packetBuffer := make([]byte, m.mtu+PacketOffset)
for {
n, err := m.tun.Read(packetBuffer)
if err != nil {
return
}
if n < clashtcpip.IPv4PacketMinLength {
continue
}
packet := packetBuffer[PacketOffset:n]
switch ipVersion := packet[0] >> 4; ipVersion {
case 4:
err = m.processIPv4(packet)
case 6:
err = m.processIPv6(packet)
default:
err = E.New("ip: unknown version: ", ipVersion)
}
if err != nil {
m.logger.Trace(err)
}
}
}

func (m *Mixed) wintunLoop(winTun WinTun) {
for {
packet, release, err := winTun.ReadPacket()
if err != nil {
return
}
if len(packet) < clashtcpip.IPv4PacketMinLength {
release()
continue
}
switch ipVersion := packet[0] >> 4; ipVersion {
case 4:
err = m.processIPv4(packet)
case 6:
err = m.processIPv6(packet)
default:
err = E.New("ip: unknown version: ", ipVersion)
}
if err != nil {
m.logger.Trace(err)
}
release()
}
}

func (m *Mixed) processIPv4(packet clashtcpip.IPv4Packet) error {
switch packet.Protocol() {
case clashtcpip.TCP:
return m.processIPv4TCP(packet, packet.Payload())
case clashtcpip.UDP:
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packet),
})
m.endpoint.InjectInbound(header.IPv4ProtocolNumber, pkt)
pkt.DecRef()
return nil
case clashtcpip.ICMP:
return m.processIPv4ICMP(packet, packet.Payload())
default:
return common.Error(m.tun.Write(packet))
}
}

func (m *Mixed) processIPv6(packet clashtcpip.IPv6Packet) error {
switch packet.Protocol() {
case clashtcpip.TCP:
return m.processIPv6TCP(packet, packet.Payload())
case clashtcpip.UDP:
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packet),
})
m.endpoint.InjectInbound(header.IPv6ProtocolNumber, pkt)
pkt.DecRef()
return nil
case clashtcpip.ICMPv6:
return m.processIPv6ICMP(packet, packet.Payload())
default:
return common.Error(m.tun.Write(packet))
}
}

func (m *Mixed) packetLoop() {
for {
packet := m.endpoint.ReadContext(m.ctx)
if packet == nil {
break
}
bufio.WriteVectorised(m.writer, packet.AsSlices())
packet.DecRef()
}
}

func (m *Mixed) Close() error {
m.endpoint.Attach(nil)
m.stack.Close()
for _, endpoint := range m.stack.CleanupEndpoints() {
endpoint.Abort()
}
return m.System.Close()
}
1 change: 1 addition & 0 deletions tun.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type Handler interface {

type Tun interface {
io.ReadWriter
CreateVectorisedWriter() N.VectorisedWriter
Close() error
}

Expand Down
15 changes: 15 additions & 0 deletions tun_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"unsafe"

"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
Expand Down Expand Up @@ -101,6 +102,20 @@ func (t *NativeTun) Write(p []byte) (n int, err error) {
return
}

func (t *NativeTun) CreateVectorisedWriter() N.VectorisedWriter {
return t
}

func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
var packetHeader []byte
if buffers[0].Byte(0)>>4 == 4 {
packetHeader = packetHeader4[:]
} else {
packetHeader = packetHeader6[:]
}
return t.tunWriter.WriteVectorised(append([]*buf.Buffer{buf.As(packetHeader)}, buffers...))
}

func (t *NativeTun) Close() error {
flushDNSCache()
return t.tunFile.Close()
Expand Down
6 changes: 6 additions & 0 deletions tun_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ import (

"github.com/sagernet/netlink"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/shell"
"github.com/sagernet/sing/common/x/list"
Expand Down Expand Up @@ -68,6 +70,10 @@ func (t *NativeTun) Write(p []byte) (n int, err error) {
return t.tunFile.Write(p)
}

func (t *NativeTun) CreateVectorisedWriter() N.VectorisedWriter {
return bufio.NewVectorisedWriter(t.tunFile)
}

var controlPath string

func init() {
Expand Down
Loading

0 comments on commit 10d98f2

Please sign in to comment.