Skip to content

Commit

Permalink
xtcp: when connection timeout occurs, support fallback to STCP (#3460)
Browse files Browse the repository at this point in the history
  • Loading branch information
fatedier authored May 30, 2023
1 parent 555db9d commit c7a0cfc
Show file tree
Hide file tree
Showing 16 changed files with 230 additions and 68 deletions.
6 changes: 4 additions & 2 deletions client/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,8 @@ func (cm *ConnectionManager) OpenConnection() error {
}
tlsConfig.NextProtos = []string{"frp"}

conn, err := quic.DialAddr(
conn, err := quic.DialAddrContext(
cm.ctx,
net.JoinHostPort(cm.cfg.ServerAddr, strconv.Itoa(cm.cfg.ServerPort)),
tlsConfig, &quic.Config{
MaxIdleTimeout: time.Duration(cm.cfg.QUICMaxIdleTimeout) * time.Second,
Expand Down Expand Up @@ -467,7 +468,8 @@ func (cm *ConnectionManager) realConnect() (net.Conn, error) {
Hook: utilnet.DialHookCustomTLSHeadByte(tlsConfig != nil, cm.cfg.DisableCustomTLSFirstByte),
}),
)
conn, err := libdial.Dial(
conn, err := libdial.DialContext(
cm.ctx,
net.JoinHostPort(cm.cfg.ServerAddr, strconv.Itoa(cm.cfg.ServerPort)),
dialOptions...,
)
Expand Down
24 changes: 19 additions & 5 deletions client/visitor/stcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,20 @@ type STCPVisitor struct {
}

func (sv *STCPVisitor) Run() (err error) {
sv.l, err = net.Listen("tcp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort)))
if err != nil {
return
if sv.cfg.BindPort > 0 {
sv.l, err = net.Listen("tcp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort)))
if err != nil {
return
}
go sv.worker()
}

go sv.worker()
go sv.internalConnWorker()
return
}

func (sv *STCPVisitor) Close() {
sv.l.Close()
sv.BaseVisitor.Close()
}

func (sv *STCPVisitor) worker() {
Expand All @@ -56,7 +59,18 @@ func (sv *STCPVisitor) worker() {
xl.Warn("stcp local listener closed")
return
}
go sv.handleConn(conn)
}
}

func (sv *STCPVisitor) internalConnWorker() {
xl := xlog.FromContextSafe(sv.ctx)
for {
conn, err := sv.internalLn.Accept()
if err != nil {
xl.Warn("stcp internal listener closed")
return
}
go sv.handleConn(conn)
}
}
Expand Down
1 change: 1 addition & 0 deletions client/visitor/sudp.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ func (sv *SUDPVisitor) Close() {
default:
close(sv.checkCloseCh)
}
sv.BaseVisitor.Close()
if sv.udpConn != nil {
sv.udpConn.Close()
}
Expand Down
20 changes: 20 additions & 0 deletions client/visitor/visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ import (

"github.com/fatedier/frp/pkg/config"
"github.com/fatedier/frp/pkg/transport"
utilnet "github.com/fatedier/frp/pkg/util/net"
"github.com/fatedier/frp/pkg/util/xlog"
)

// Visitor is used for forward traffics from local port tot remote service.
type Visitor interface {
Run() error
AcceptConn(conn net.Conn) error
Close()
}

Expand All @@ -35,14 +37,17 @@ func NewVisitor(
cfg config.VisitorConf,
clientCfg config.ClientCommonConf,
connectServer func() (net.Conn, error),
transferConn func(string, net.Conn) error,
msgTransporter transport.MessageTransporter,
) (visitor Visitor) {
xl := xlog.FromContextSafe(ctx).Spawn().AppendPrefix(cfg.GetBaseInfo().ProxyName)
baseVisitor := BaseVisitor{
clientCfg: clientCfg,
connectServer: connectServer,
transferConn: transferConn,
msgTransporter: msgTransporter,
ctx: xlog.NewContext(ctx, xl),
internalLn: utilnet.NewInternalListener(),
}
switch cfg := cfg.(type) {
case *config.STCPVisitorConf:
Expand All @@ -69,9 +74,24 @@ func NewVisitor(
type BaseVisitor struct {
clientCfg config.ClientCommonConf
connectServer func() (net.Conn, error)
transferConn func(string, net.Conn) error
msgTransporter transport.MessageTransporter
l net.Listener
internalLn *utilnet.InternalListener

mu sync.RWMutex
ctx context.Context
}

func (v *BaseVisitor) AcceptConn(conn net.Conn) error {
return v.internalLn.PutConn(conn)
}

func (v *BaseVisitor) Close() {
if v.l != nil {
v.l.Close()
}
if v.internalLn != nil {
v.internalLn.Close()
}
}
36 changes: 24 additions & 12 deletions client/visitor/visitor_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package visitor

import (
"context"
"fmt"
"net"
"sync"
"time"
Expand All @@ -34,7 +35,7 @@ type Manager struct {

checkInterval time.Duration

mu sync.Mutex
mu sync.RWMutex
ctx context.Context

stopCh chan struct{}
Expand Down Expand Up @@ -83,11 +84,24 @@ func (vm *Manager) Run() {
}
}

func (vm *Manager) Close() {
vm.mu.Lock()
defer vm.mu.Unlock()
for _, v := range vm.visitors {
v.Close()
}
select {
case <-vm.stopCh:
default:
close(vm.stopCh)
}
}

// Hold lock before calling this function.
func (vm *Manager) startVisitor(cfg config.VisitorConf) (err error) {
xl := xlog.FromContextSafe(vm.ctx)
name := cfg.GetBaseInfo().ProxyName
visitor := NewVisitor(vm.ctx, cfg, vm.clientCfg, vm.connectServer, vm.msgTransporter)
visitor := NewVisitor(vm.ctx, cfg, vm.clientCfg, vm.connectServer, vm.TransferConn, vm.msgTransporter)
err = visitor.Run()
if err != nil {
xl.Warn("start error: %v", err)
Expand Down Expand Up @@ -139,15 +153,13 @@ func (vm *Manager) Reload(cfgs map[string]config.VisitorConf) {
}
}

func (vm *Manager) Close() {
vm.mu.Lock()
defer vm.mu.Unlock()
for _, v := range vm.visitors {
v.Close()
}
select {
case <-vm.stopCh:
default:
close(vm.stopCh)
// TransferConn transfers a connection to a visitor.
func (vm *Manager) TransferConn(name string, conn net.Conn) error {
vm.mu.RLock()
defer vm.mu.RUnlock()
v, ok := vm.visitors[name]
if !ok {
return fmt.Errorf("visitor [%s] not found", name)
}
return v.AcceptConn(conn)
}
60 changes: 51 additions & 9 deletions client/visitor/xtcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,15 @@ func (sv *XTCPVisitor) Run() (err error) {
sv.session = NewQUICTunnelSession(&sv.clientCfg)
}

sv.l, err = net.Listen("tcp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort)))
if err != nil {
return
if sv.cfg.BindPort > 0 {
sv.l, err = net.Listen("tcp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort)))
if err != nil {
return
}
go sv.worker()
}

go sv.worker()
go sv.internalConnWorker()
go sv.processTunnelStartEvents()
if sv.cfg.KeepTunnelOpen {
sv.retryLimiter = rate.NewLimiter(rate.Every(time.Hour/time.Duration(sv.cfg.MaxRetriesAnHour)), sv.cfg.MaxRetriesAnHour)
Expand All @@ -74,8 +77,12 @@ func (sv *XTCPVisitor) Run() (err error) {
}

func (sv *XTCPVisitor) Close() {
sv.l.Close()
sv.cancel()
sv.mu.Lock()
defer sv.mu.Unlock()
sv.BaseVisitor.Close()
if sv.cancel != nil {
sv.cancel()
}
if sv.session != nil {
sv.session.Close()
}
Expand All @@ -89,7 +96,18 @@ func (sv *XTCPVisitor) worker() {
xl.Warn("xtcp local listener closed")
return
}
go sv.handleConn(conn)
}
}

func (sv *XTCPVisitor) internalConnWorker() {
xl := xlog.FromContextSafe(sv.ctx)
for {
conn, err := sv.internalLn.Accept()
if err != nil {
xl.Warn("xtcp internal listener closed")
return
}
go sv.handleConn(conn)
}
}
Expand Down Expand Up @@ -139,15 +157,37 @@ func (sv *XTCPVisitor) keepTunnelOpenWorker() {

func (sv *XTCPVisitor) handleConn(userConn net.Conn) {
xl := xlog.FromContextSafe(sv.ctx)
defer userConn.Close()
isConnTrasfered := false
defer func() {
if !isConnTrasfered {
userConn.Close()
}
}()

xl.Debug("get a new xtcp user connection")

// Open a tunnel connection to the server. If there is already a successful hole-punching connection,
// it will be reused. Otherwise, it will block and wait for a successful hole-punching connection until timeout.
tunnelConn, err := sv.openTunnel()
ctx := context.Background()
if sv.cfg.FallbackTo != "" {
timeoutCtx, cancel := context.WithTimeout(ctx, time.Duration(sv.cfg.FallbackTimeoutMs)*time.Millisecond)
defer cancel()
ctx = timeoutCtx
}
tunnelConn, err := sv.openTunnel(ctx)
if err != nil {
xl.Error("open tunnel error: %v", err)
// no fallback, just return
if sv.cfg.FallbackTo == "" {
return
}

xl.Debug("try to transfer connection to visitor: %s", sv.cfg.FallbackTo)
if err := sv.transferConn(sv.cfg.FallbackTo, userConn); err != nil {
xl.Error("transfer connection to visitor %s error: %v", sv.cfg.FallbackTo, err)
return
}
isConnTrasfered = true
return
}

Expand All @@ -171,7 +211,7 @@ func (sv *XTCPVisitor) handleConn(userConn net.Conn) {
}

// openTunnel will open a tunnel connection to the target server.
func (sv *XTCPVisitor) openTunnel() (conn net.Conn, err error) {
func (sv *XTCPVisitor) openTunnel(ctx context.Context) (conn net.Conn, err error) {
xl := xlog.FromContextSafe(sv.ctx)
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
Expand All @@ -185,6 +225,8 @@ func (sv *XTCPVisitor) openTunnel() (conn net.Conn, err error) {
select {
case <-sv.ctx.Done():
return nil, sv.ctx.Err()
case <-ctx.Done():
return nil, ctx.Err()
case <-immediateTrigger:
conn, err = sv.getTunnelConn()
case <-ticker.C:
Expand Down
6 changes: 3 additions & 3 deletions cmd/frpc/sub/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ var rootCmd = &cobra.Command{
// Do not show command usage here.
err := runClient(cfgFile)
if err != nil {
fmt.Println(err)
os.Exit(1)
}
return nil
Expand Down Expand Up @@ -199,6 +198,7 @@ func parseClientCommonCfgFromCmd() (cfg config.ClientCommonConf, err error) {
func runClient(cfgFilePath string) error {
cfg, pxyCfgs, visitorCfgs, err := config.ParseClientConfig(cfgFilePath)
if err != nil {
fmt.Println(err)
return err
}
return startService(cfg, pxyCfgs, visitorCfgs, cfgFilePath)
Expand All @@ -214,8 +214,8 @@ func startService(
cfg.LogMaxDays, cfg.DisableLogColor)

if cfgFile != "" {
log.Trace("start frpc service for config file [%s]", cfgFile)
defer log.Trace("frpc service for config file [%s] stopped", cfgFile)
log.Info("start frpc service for config file [%s]", cfgFile)
defer log.Info("frpc service for config file [%s] stopped", cfgFile)
}
svr, errRet := client.NewService(cfg, pxyCfgs, visitorCfgs, cfgFile)
if errRet != nil {
Expand Down
6 changes: 6 additions & 0 deletions conf/frpc_full.ini
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,8 @@ server_name = secret_tcp
sk = abcdefg
# connect this address to visitor stcp server
bind_addr = 127.0.0.1
# bind_port can be less than 0, it means don't bind to the port and only receive connections redirected from
# other visitors. (This is not supported for SUDP now)
bind_port = 9000
use_encryption = false
use_compression = false
Expand All @@ -355,6 +357,8 @@ type = xtcp
server_name = p2p_tcp
sk = abcdefg
bind_addr = 127.0.0.1
# bind_port can be less than 0, it means don't bind to the port and only receive connections redirected from
# other visitors. (This is not supported for SUDP now)
bind_port = 9001
use_encryption = false
use_compression = false
Expand All @@ -363,6 +367,8 @@ keep_tunnel_open = false
# effective when keep_tunnel_open is set to true, the number of attempts to punch through per hour
max_retries_an_hour = 8
min_retry_interval = 90
# fallback_to = stcp_visitor
# fallback_timeout_ms = 500

[tcpmuxhttpconnect]
type = tcpmux
Expand Down
7 changes: 4 additions & 3 deletions pkg/config/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -661,9 +661,10 @@ func Test_LoadClientBasicConf(t *testing.T) {
BindAddr: "127.0.0.1",
BindPort: 9001,
},
Protocol: "quic",
MaxRetriesAnHour: 8,
MinRetryInterval: 90,
Protocol: "quic",
MaxRetriesAnHour: 8,
MinRetryInterval: 90,
FallbackTimeoutMs: 1000,
},
}

Expand Down
Loading

0 comments on commit c7a0cfc

Please sign in to comment.