Skip to content

Commit

Permalink
fftw: support bandwidth limit and max traffic per day limit
Browse files Browse the repository at this point in the history
  • Loading branch information
fatedier committed Mar 21, 2019
1 parent 470bd04 commit 7abc256
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 55 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
### v0.2.0

- Support bandwidth limit on fftw.
- Support limit the max traffic fftw can used every day.

### v0.1.0

Expand Down
1 change: 1 addition & 0 deletions CHANGELOG_zh.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
### v0.2.0

- fftw 支持限速。
- fftw 支持限制每天使用的流量,超过限制后会自动从服务器端注销,第二天恢复。

### v0.1.0

Expand Down
3 changes: 2 additions & 1 deletion cmd/fftw/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&options.ServerAddr, "server_addr", "s", version.DefaultServerAddr(), "remote fft server address")
rootCmd.PersistentFlags().StringVarP(&options.BindAddr, "bind_addr", "b", "0.0.0.0:7778", "bind address")
rootCmd.PersistentFlags().StringVarP(&options.AdvicePublicIP, "advice_public_ip", "p", "", "fft worker's advice public ip")
rootCmd.PersistentFlags().IntVarP(&options.RateKB, "rate", "", 2048, "max bandwidth fftw will provide, unit is KB, default is 2048KB and min value is 50KB")
rootCmd.PersistentFlags().IntVarP(&options.RateKB, "rate", "", 4096, "max bandwidth fftw will provide, unit is KB, default is 4096KB and min value is 50KB")
rootCmd.PersistentFlags().IntVarP(&options.MaxTrafficMBPerDay, "max_traffic_per_day", "", 0, "max traffic fftw can use every day, 0 means no limit, unit is MB, default is 0MB and min value is 128MB")

rootCmd.PersistentFlags().StringVarP(&options.LogFile, "log_file", "", "console", "log file path")
rootCmd.PersistentFlags().StringVarP(&options.LogLevel, "log_level", "", "info", "log level")
Expand Down
12 changes: 8 additions & 4 deletions worker/match.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"sync"
"time"

rateio "github.com/fatedier/fft/pkg/io"
fio "github.com/fatedier/fft/pkg/io"
"github.com/fatedier/fft/pkg/log"
"github.com/fatedier/fft/pkg/msg"

Expand Down Expand Up @@ -37,16 +37,18 @@ type MatchController struct {
conns map[string]*TransferConn

rateLimit *rate.Limiter
statFunc func(int)
mu sync.Mutex
}

func NewMatchController(rateByte int) *MatchController {
func NewMatchController(rateByte int, statFunc func(int)) *MatchController {
if rateByte < 50*1024 {
rateByte = 50 * 1024
}
return &MatchController{
conns: make(map[string]*TransferConn),
rateLimit: rate.NewLimiter(rate.Limit(float64(rateByte)), 16*1024),
statFunc: statFunc,
}
}

Expand All @@ -66,12 +68,14 @@ func (mc *MatchController) DealTransferConn(tc *TransferConn, timeout time.Durat
case pairConn := <-tc.pairConnCh:
var sender, receiver io.ReadWriteCloser
if tc.isSender {
sender = gio.WrapReadWriteCloser(rateio.NewRateReader(tc.conn, mc.rateLimit), tc.conn, func() error {
wrapReader := fio.NewCallbackReader(fio.NewRateReader(tc.conn, mc.rateLimit), mc.statFunc)
sender = gio.WrapReadWriteCloser(wrapReader, tc.conn, func() error {
return tc.conn.Close()
})
receiver = pairConn.conn
} else {
sender = gio.WrapReadWriteCloser(rateio.NewRateReader(pairConn.conn, mc.rateLimit), pairConn.conn, func() error {
wrapReader := fio.NewCallbackReader(fio.NewRateReader(pairConn.conn, mc.rateLimit), mc.statFunc)
sender = gio.WrapReadWriteCloser(wrapReader, pairConn.conn, func() error {
return pairConn.conn.Close()
})
receiver = tc.conn
Expand Down
77 changes: 62 additions & 15 deletions worker/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"crypto/tls"
"fmt"
"net"
"sync"
"time"

"github.com/fatedier/fft/pkg/log"
Expand All @@ -15,30 +16,42 @@ type Register struct {
port int64
advicePublicIP string
serverAddr string
conn net.Conn

closed bool
mu sync.Mutex
}

func NewRegister(port int64, advicePublicIP string, serverAddr string) *Register {
func NewRegister(port int64, advicePublicIP string, serverAddr string) (*Register, error) {
conn, err := net.Dial("tcp", serverAddr)
if err != nil {
return nil, err
}
conn = tls.Client(conn, &tls.Config{InsecureSkipVerify: true})

return &Register{
port: port,
advicePublicIP: advicePublicIP,
serverAddr: serverAddr,
}
conn: conn,
closed: false,
}, nil
}

func (r *Register) Register(conn net.Conn) error {
msg.WriteMsg(conn, &msg.RegisterWorker{
func (r *Register) Register() error {
msg.WriteMsg(r.conn, &msg.RegisterWorker{
Version: version.Full(),
PublicIP: r.advicePublicIP,
BindPort: r.port,
})

conn.SetReadDeadline(time.Now().Add(10 * time.Second))
m, err := msg.ReadMsg(conn)
r.conn.SetReadDeadline(time.Now().Add(10 * time.Second))
m, err := msg.ReadMsg(r.conn)
if err != nil {
log.Warn("read RegisterWorkerResp error: %v", err)
return err
}
conn.SetReadDeadline(time.Time{})
r.conn.SetReadDeadline(time.Time{})

resp, ok := m.(*msg.RegisterWorkerResp)
if !ok {
Expand All @@ -51,39 +64,73 @@ func (r *Register) Register(conn net.Conn) error {
return nil
}

func (r *Register) RunKeepAlive(conn net.Conn) error {
func (r *Register) RunKeepAlive() {
var err error
for {
// send ping and read pong
for {
msg.WriteMsg(conn, &msg.Ping{})
// in case it is closed before
if r.conn == nil {
break
}

msg.WriteMsg(r.conn, &msg.Ping{})

_, err = msg.ReadMsg(conn)
_, err = msg.ReadMsg(r.conn)
if err != nil {
conn.Close()
r.conn.Close()
break
}

time.Sleep(10 * time.Second)
}

for {
conn, err = net.Dial("tcp", r.serverAddr)
r.mu.Lock()
closed := r.closed
r.mu.Unlock()
if r.closed {
return
}

conn, err := net.Dial("tcp", r.serverAddr)
if err != nil {
time.Sleep(10 * time.Second)
continue
}
conn = tls.Client(conn, &tls.Config{InsecureSkipVerify: true})

err = r.Register(conn)
if err != nil {
r.mu.Lock()
closed = r.closed
if closed {
conn.Close()
r.mu.Unlock()
return
}
r.conn = conn
r.mu.Unlock()

err = r.Register()
if err != nil {
r.conn.Close()
time.Sleep(10 * time.Second)
continue
}

break
}
}
return nil
}

func (r *Register) Close() {
r.mu.Lock()
defer r.mu.Unlock()
r.closed = true
r.conn.Close()
}

// Reset can be only called after Close
func (r *Register) Reset() {
r.closed = false
r.conn = nil
}
97 changes: 62 additions & 35 deletions worker/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ import (
)

type Options struct {
ServerAddr string
BindAddr string
AdvicePublicIP string
RateKB int // xx KB/s
ServerAddr string
BindAddr string
AdvicePublicIP string
RateKB int // xx KB/s
MaxTrafficMBPerDay int // xx MB, 0 is no limit

LogFile string
LogLevel string
Expand All @@ -32,18 +33,27 @@ func (op *Options) Check() error {
op.LogMaxDays = 3
}
if op.RateKB < 50 {
return fmt.Errorf("rate should greater than 50KB")
return fmt.Errorf("rate should be greater than 50KB")
}
if op.MaxTrafficMBPerDay < 128 && op.MaxTrafficMBPerDay != 0 {
return fmt.Errorf("max_traffic_per_day should be greater than 128MB")
}
return nil
}

type Service struct {
serverAddr string
advicePublicIP string

l net.Listener
matchCtl *MatchController
tlsConfig *tls.Config
serverAddr string
advicePublicIP string
rateKB int
maxTrafficMBPerDay int

l net.Listener
matchCtl *MatchController
register *Register
trafficLimiter *TrafficLimiter
tlsConfig *tls.Config

stopCh chan struct{}
}

func NewService(options Options) (*Service, error) {
Expand All @@ -63,43 +73,60 @@ func NewService(options Options) (*Service, error) {
}
log.Info("fftw listen on: %s", l.Addr().String())

return &Service{
serverAddr: options.ServerAddr,
advicePublicIP: options.AdvicePublicIP,
_, portStr, err := net.SplitHostPort(l.Addr().String())
if err != nil {
return nil, fmt.Errorf("get bind port error, bind address: %v", l.Addr().String())
}
port, err := strconv.Atoi(portStr)
if err != nil {
return nil, fmt.Errorf("get bind port error: %v", err)
}

register, err := NewRegister(int64(port), options.AdvicePublicIP, options.ServerAddr)
if err != nil {
return nil, fmt.Errorf("new register error: %v", err)
}

svc := &Service{
serverAddr: options.ServerAddr,
advicePublicIP: options.AdvicePublicIP,
rateKB: options.RateKB,
maxTrafficMBPerDay: options.MaxTrafficMBPerDay,

l: l,
matchCtl: NewMatchController(options.RateKB * 1024),
register: register,
tlsConfig: generateTLSConfig(),
}, nil

stopCh: make(chan struct{}),
}

svc.trafficLimiter = NewTrafficLimiter(uint64(options.MaxTrafficMBPerDay*1024*1024), func() {
svc.register.Close()
log.Info("reach traffic limit %dMB one day, unregister from server", options.MaxTrafficMBPerDay)
}, func() {
svc.register.Reset()
go svc.register.RunKeepAlive()
log.Info("restore from traffic limit since it's a new day")
})

svc.matchCtl = NewMatchController(options.RateKB*1024, func(n int) {
svc.trafficLimiter.AddCount(uint64(n))
})
return svc, nil
}

func (svc *Service) Run() error {
go svc.worker()
go svc.trafficLimiter.Run()

// connect to server
conn, err := net.Dial("tcp", svc.serverAddr)
if err != nil {
return err
}
conn = tls.Client(conn, &tls.Config{InsecureSkipVerify: true})

_, portStr, err := net.SplitHostPort(svc.l.Addr().String())
if err != nil {
return fmt.Errorf("get bind port error, bind address: %v", svc.l.Addr().String())
}
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("get bind port error")
}

register := NewRegister(int64(port), svc.advicePublicIP, svc.serverAddr)
err = register.Register(conn)
err := svc.register.Register()
if err != nil {
return fmt.Errorf("register worker to server error: %v", err)
}
log.Info("register to server success")

register.RunKeepAlive(conn)
svc.register.RunKeepAlive()
<-svc.stopCh
return nil
}

Expand Down
Loading

0 comments on commit 7abc256

Please sign in to comment.