diff --git a/redirect_iptables.go b/redirect_iptables.go index a9e4c86..c197dd9 100644 --- a/redirect_iptables.go +++ b/redirect_iptables.go @@ -10,23 +10,28 @@ import ( "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" F "github.com/sagernet/sing/common/format" - - "golang.org/x/sys/unix" ) -func (r *autoRedirect) iptablesPathForFamily(family int) string { - if family == unix.AF_INET { - return r.iptablesPath - } else { - return r.ip6tablesPath +func (r *autoRedirect) setupIPTables() error { + if r.enableIPv4 { + err := r.setupIPTablesForFamily(r.iptablesPath) + if err != nil { + return err + } + } + if r.enableIPv6 { + err := r.setupIPTablesForFamily(r.ip6tablesPath) + if err != nil { + return err + } } + return nil } -func (r *autoRedirect) setupIPTables(family int) error { +func (r *autoRedirect) setupIPTablesForFamily(iptablesPath string) error { tableNameOutput := r.tableName + "-output" tableNameForward := r.tableName + "-forward" tableNamePreRouteing := r.tableName + "-prerouting" - iptablesPath := r.iptablesPathForFamily(family) redirectPort := r.redirectPort() // OUTPUT err := r.runShell(iptablesPath, "-t nat -N", tableNameOutput) @@ -74,7 +79,7 @@ func (r *autoRedirect) setupIPTables(family int) error { routeAddress []netip.Prefix routeExcludeAddress []netip.Prefix ) - if family == unix.AF_INET { + if iptablesPath == r.iptablesPath { routeAddress = r.tunOptions.Inet4RouteAddress routeExcludeAddress = r.tunOptions.Inet4RouteExcludeAddress } else { @@ -112,10 +117,10 @@ func (r *autoRedirect) setupIPTables(family int) error { } if !r.tunOptions.EXP_DisableDNSHijack { dnsServer := common.Find(r.tunOptions.DNSServers, func(it netip.Addr) bool { - return it.Is4() == (family == unix.AF_INET) + return it.Is4() == (iptablesPath == r.iptablesPath) }) if !dnsServer.IsValid() { - if family == unix.AF_INET { + if iptablesPath == r.iptablesPath { dnsServer = r.tunOptions.Inet4Address[0].Addr().Next() } else { dnsServer = r.tunOptions.Inet6Address[0].Addr().Next() @@ -199,11 +204,19 @@ func (r *autoRedirect) setupIPTables(family int) error { return nil } -func (r *autoRedirect) cleanupIPTables(family int) { +func (r *autoRedirect) cleanupIPTables() { + if r.enableIPv4 { + r.cleanupIPTablesForFamily(r.iptablesPath) + } + if r.enableIPv6 { + r.cleanupIPTablesForFamily(r.ip6tablesPath) + } +} + +func (r *autoRedirect) cleanupIPTablesForFamily(iptablesPath string) { tableNameOutput := r.tableName + "-output" tableNameForward := r.tableName + "-forward" tableNamePreRouteing := r.tableName + "-prerouting" - iptablesPath := r.iptablesPathForFamily(family) _ = r.runShell(iptablesPath, "-t nat -D OUTPUT -j", tableNameOutput) _ = r.runShell(iptablesPath, "-t nat -F", tableNameOutput) _ = r.runShell(iptablesPath, "-t nat -X", tableNameOutput) diff --git a/redirect_linux.go b/redirect_linux.go index 815091b..bcad482 100644 --- a/redirect_linux.go +++ b/redirect_linux.go @@ -12,8 +12,6 @@ import ( E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" - - "golang.org/x/sys/unix" ) type autoRedirect struct { @@ -118,11 +116,19 @@ func (r *autoRedirect) Start() error { } r.redirectServer = server } - return r.setupTables() + if r.useNFTables { + return r.setupNFTables() + } else { + return r.setupIPTables() + } } func (r *autoRedirect) Close() error { - r.cleanupTables() + if r.useNFTables { + r.cleanupNFTables() + } else { + r.cleanupIPTables() + } return common.Close( common.PtrOrNil(r.redirectServer), ) @@ -134,7 +140,7 @@ func (r *autoRedirect) initializeNFTables() error { return err } defer nft.CloseLasting() - _, err = nft.ListTablesOfFamily(unix.AF_INET) + _, err = nft.ListTablesOfFamily(nftables.TableFamilyIPv4) if err != nil { return err } @@ -148,40 +154,3 @@ func (r *autoRedirect) redirectPort() uint16 { } return M.AddrPortFromNet(r.redirectServer.listener.Addr()).Port() } - -func (r *autoRedirect) setupTables() error { - var setupTables func(int) error - if r.useNFTables { - setupTables = r.setupNFTables - } else { - setupTables = r.setupIPTables - } - if r.enableIPv4 { - err := setupTables(unix.AF_INET) - if err != nil { - return err - } - } - if r.enableIPv6 { - err := setupTables(unix.AF_INET6) - if err != nil { - return err - } - } - return nil -} - -func (r *autoRedirect) cleanupTables() { - var cleanupTables func(int) - if r.useNFTables { - cleanupTables = r.cleanupNFTables - } else { - cleanupTables = r.cleanupIPTables - } - if r.enableIPv4 { - cleanupTables(unix.AF_INET) - } - if r.enableIPv6 { - cleanupTables(unix.AF_INET6) - } -} diff --git a/redirect_nftables.go b/redirect_nftables.go index ae1e689..3b3e04c 100644 --- a/redirect_nftables.go +++ b/redirect_nftables.go @@ -9,57 +9,24 @@ import ( "github.com/sagernet/nftables/binaryutil" "github.com/sagernet/nftables/expr" "github.com/sagernet/sing/common" - F "github.com/sagernet/sing/common/format" "golang.org/x/sys/unix" ) -const ( - nftablesChainOutput = "output" - nftablesChainForward = "forward" - nftablesChainPreRouting = "prerouting" -) - -func nftablesFamily(family int) nftables.TableFamily { - switch family { - case unix.AF_INET: - return nftables.TableFamilyIPv4 - case unix.AF_INET6: - return nftables.TableFamilyIPv6 - default: - panic(F.ToString("unknown family ", family)) - } -} - -func (r *autoRedirect) setupNFTables(family int) error { +func (r *autoRedirect) setupNFTables() error { nft, err := nftables.New() if err != nil { return err } defer nft.CloseLasting() - redirectPort := r.redirectPort() - table := nft.AddTable(&nftables.Table{ Name: r.tableName, - Family: nftablesFamily(family), - }) - - chainOutput := nft.AddChain(&nftables.Chain{ - Name: nftablesChainOutput, - Table: table, - Hooknum: nftables.ChainHookOutput, - Priority: nftables.ChainPriorityMangle, - Type: nftables.ChainTypeNAT, - }) - nft.AddRule(&nftables.Rule{ - Table: table, - Chain: chainOutput, - Exprs: nftablesRuleIfName(expr.MetaKeyOIFNAME, r.tunOptions.Name, nftablesRuleRedirectToPorts(redirectPort)...), + Family: nftables.TableFamilyINet, }) chainForward := nft.AddChain(&nftables.Chain{ - Name: nftablesChainForward, + Name: "forward", Table: table, Hooknum: nftables.ChainHookForward, Priority: nftables.ChainPriorityMangle, @@ -79,8 +46,22 @@ func (r *autoRedirect) setupNFTables(family int) error { }), }) + redirectPort := r.redirectPort() + chainOutput := nft.AddChain(&nftables.Chain{ + Name: "output", + Table: table, + Hooknum: nftables.ChainHookOutput, + Priority: nftables.ChainPriorityMangle, + Type: nftables.ChainTypeNAT, + }) + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chainOutput, + Exprs: nftablesRuleIfName(expr.MetaKeyOIFNAME, r.tunOptions.Name, nftablesRuleRedirectToPorts(redirectPort)...), + }) + chainPreRouting := nft.AddChain(&nftables.Chain{ - Name: nftablesChainPreRouting, + Name: "prerouting", Table: table, Hooknum: nftables.ChainHookPrerouting, Priority: nftables.ChainPriorityMangle, @@ -97,12 +78,13 @@ func (r *autoRedirect) setupNFTables(family int) error { routeAddress []netip.Prefix routeExcludeAddress []netip.Prefix ) - if table.Family == nftables.TableFamilyIPv4 { - routeAddress = r.tunOptions.Inet4RouteAddress - routeExcludeAddress = r.tunOptions.Inet4RouteExcludeAddress - } else { - routeAddress = r.tunOptions.Inet6RouteAddress - routeExcludeAddress = r.tunOptions.Inet6RouteExcludeAddress + if r.enableIPv4 { + routeAddress = append(routeAddress, r.tunOptions.Inet4RouteAddress...) + routeExcludeAddress = append(routeExcludeAddress, r.tunOptions.Inet4RouteExcludeAddress...) + } + if r.enableIPv6 { + routeAddress = append(routeAddress, r.tunOptions.Inet6RouteAddress...) + routeExcludeAddress = append(routeExcludeAddress, r.tunOptions.Inet6RouteExcludeAddress...) } for _, address := range routeExcludeAddress { nft.AddRule(&nftables.Rule{ @@ -140,37 +122,66 @@ func (r *autoRedirect) setupNFTables(family int) error { } if !r.tunOptions.EXP_DisableDNSHijack { - dnsServer := common.Find(r.tunOptions.DNSServers, func(it netip.Addr) bool { - return it.Is4() == (family == unix.AF_INET) + dnsServer4 := common.Find(r.tunOptions.DNSServers, func(it netip.Addr) bool { + return it.Is4() }) - if !dnsServer.IsValid() { - if family == unix.AF_INET { - dnsServer = r.tunOptions.Inet4Address[0].Addr().Next() - } else { - dnsServer = r.tunOptions.Inet6Address[0].Addr().Next() - } + dnsServer6 := common.Find(r.tunOptions.DNSServers, func(it netip.Addr) bool { + return it.Is6() + }) + if r.enableIPv4 && !dnsServer4.IsValid() { + dnsServer4 = r.tunOptions.Inet4Address[0].Addr().Next() + } + if r.enableIPv6 && !dnsServer6.IsValid() { + dnsServer6 = r.tunOptions.Inet6Address[0].Addr().Next() } if len(r.tunOptions.IncludeInterface) > 0 || len(r.tunOptions.IncludeUID) > 0 { for _, name := range r.tunOptions.IncludeInterface { + if r.enableIPv4 { + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chainPreRouting, + Exprs: nftablesRuleIfName(expr.MetaKeyIIFNAME, name, append(routeExprs, nftablesRuleHijackDNS(nftables.TableFamilyIPv4, dnsServer4)...)...), + }) + } + if r.enableIPv6 { + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chainPreRouting, + Exprs: nftablesRuleIfName(expr.MetaKeyIIFNAME, name, append(routeExprs, nftablesRuleHijackDNS(nftables.TableFamilyIPv6, dnsServer6)...)...), + }) + } + } + for _, uidRange := range r.tunOptions.IncludeUID { + if r.enableIPv4 { + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chainPreRouting, + Exprs: nftablesRuleMetaUInt32Range(expr.MetaKeySKUID, uidRange, append(routeExprs, nftablesRuleHijackDNS(nftables.TableFamilyIPv4, dnsServer4)...)...), + }) + } + if r.enableIPv6 { + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chainPreRouting, + Exprs: nftablesRuleMetaUInt32Range(expr.MetaKeySKUID, uidRange, append(routeExprs, nftablesRuleHijackDNS(nftables.TableFamilyIPv6, dnsServer6)...)...), + }) + } + } + } else { + if r.enableIPv4 { nft.AddRule(&nftables.Rule{ Table: table, Chain: chainPreRouting, - Exprs: nftablesRuleIfName(expr.MetaKeyIIFNAME, name, append(routeExprs, nftablesRuleHijackDNS(table.Family, dnsServer)...)...), + Exprs: append(routeExprs, nftablesRuleHijackDNS(nftables.TableFamilyIPv4, dnsServer4)...), }) } - for _, uidRange := range r.tunOptions.IncludeUID { + if r.enableIPv6 { nft.AddRule(&nftables.Rule{ Table: table, Chain: chainPreRouting, - Exprs: nftablesRuleMetaUInt32Range(expr.MetaKeySKUID, uidRange, append(routeExprs, nftablesRuleHijackDNS(table.Family, dnsServer)...)...), + Exprs: append(routeExprs, nftablesRuleHijackDNS(nftables.TableFamilyIPv6, dnsServer6)...), }) } - } else { - nft.AddRule(&nftables.Rule{ - Table: table, - Chain: chainPreRouting, - Exprs: append(routeExprs, nftablesRuleHijackDNS(table.Family, dnsServer)...), - }) } } @@ -219,18 +230,14 @@ func (r *autoRedirect) setupNFTables(family int) error { return nft.Flush() } -func (r *autoRedirect) cleanupNFTables(family int) { +func (r *autoRedirect) cleanupNFTables() { conn, err := nftables.New() if err != nil { return } - conn.FlushTable(&nftables.Table{ - Name: r.tableName, - Family: nftablesFamily(family), - }) conn.DelTable(&nftables.Table{ Name: r.tableName, - Family: nftablesFamily(family), + Family: nftables.TableFamilyINet, }) _ = conn.Flush() _ = conn.CloseLasting() diff --git a/redirect_nftables_expr.go b/redirect_nftables_expr.go index bf2d46f..9692d8a 100644 --- a/redirect_nftables_expr.go +++ b/redirect_nftables_expr.go @@ -48,37 +48,54 @@ func nftablesRuleMetaUInt32Range(key expr.MetaKey, uidRange ranges.Range[uint32] } func nftablesRuleDestinationAddress(address netip.Prefix, exprs ...expr.Any) []expr.Any { - var newExprs []expr.Any + newExprs := []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyNFPROTO, + Register: 1, + }, + } if address.Addr().Is4() { - newExprs = append(newExprs, &expr.Payload{ - OperationType: expr.PayloadLoad, - DestRegister: 1, - SourceRegister: 0, - Base: expr.PayloadBaseNetworkHeader, - Offset: 16, - Len: 4, - }, &expr.Bitwise{ - SourceRegister: 1, - DestRegister: 1, - Len: 4, - Xor: make([]byte, 4), - Mask: net.CIDRMask(address.Bits(), 32), - }) + newExprs = append(newExprs, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{unix.NFPROTO_IPV4}, + }, + &expr.Payload{ + OperationType: expr.PayloadLoad, + DestRegister: 1, + SourceRegister: 0, + Base: expr.PayloadBaseNetworkHeader, + Offset: 16, + Len: 4, + }, &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 4, + Xor: make([]byte, 4), + Mask: net.CIDRMask(address.Bits(), 32), + }) } else { - newExprs = append(newExprs, &expr.Payload{ - OperationType: expr.PayloadLoad, - DestRegister: 1, - SourceRegister: 0, - Base: expr.PayloadBaseNetworkHeader, - Offset: 24, - Len: 16, - }, &expr.Bitwise{ - SourceRegister: 1, - DestRegister: 1, - Len: 16, - Xor: make([]byte, 16), - Mask: net.CIDRMask(address.Bits(), 128), - }) + newExprs = append(newExprs, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{unix.NFPROTO_IPV6}, + }, + &expr.Payload{ + OperationType: expr.PayloadLoad, + DestRegister: 1, + SourceRegister: 0, + Base: expr.PayloadBaseNetworkHeader, + Offset: 24, + Len: 16, + }, &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 16, + Xor: make([]byte, 16), + Mask: net.CIDRMask(address.Bits(), 128), + }) } newExprs = append(newExprs, &expr.Cmp{ Op: expr.CmpOpEq, @@ -91,6 +108,15 @@ func nftablesRuleDestinationAddress(address netip.Prefix, exprs ...expr.Any) []e func nftablesRuleHijackDNS(family nftables.TableFamily, dnsServerAddress netip.Addr) []expr.Any { return []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyNFPROTO, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{uint8(family)}, + }, &expr.Meta{ Key: expr.MetaKeyL4PROTO, Register: 1,