Skip to content

Commit

Permalink
Hijacked connection will not be handled by the graceful shutdown
Browse files Browse the repository at this point in the history
Need to fix this
  • Loading branch information
henrybear327 committed May 19, 2024
1 parent ce9ac27 commit ecb34d5
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 37 deletions.
126 changes: 92 additions & 34 deletions pkg/proxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,9 @@ type server struct {
donec chan struct{}
errc chan error

closeOnce sync.Once
closeWg sync.WaitGroup
closeOnce sync.Once
closeWg sync.WaitGroup
closeHijackedConn sync.WaitGroup

listenerMu sync.RWMutex
listener *customListener
Expand Down Expand Up @@ -279,7 +280,6 @@ func NewServer(cfg ServerConfig) Server {
}

s.closeWg.Add(1)

var ln net.Listener
var err error
if !s.tlsInfo.Empty() {
Expand Down Expand Up @@ -332,7 +332,8 @@ func (c customListener) Accept() (net.Conn, error) {
c.s.pauseAcceptMu.Unlock()
select {
case <-pausec:
default:
case <-c.s.donec:
return nil, nil
}

c.s.latencyAcceptMu.RLock()
Expand All @@ -341,24 +342,44 @@ func (c customListener) Accept() (net.Conn, error) {
if lat > 0 {
select {
case <-time.After(lat):
default:
case <-c.s.donec:
return nil, nil
}
}

conn, err := c.l.Accept()
if err != nil {
c.s.errc <- err
select {
case c.s.errc <- err:
select {
case <-c.s.donec:
return nil, nil
default:
}
case <-c.s.donec:
return nil, nil
}
c.s.lg.Debug("listener accept error", zap.Error(err))

if strings.HasSuffix(err.Error(), "use of closed network connection") {
select {
case <-time.After(c.s.retryInterval):
default:
case <-c.s.donec:
return nil, nil
}
c.s.lg.Debug("listener is closed; retry listening on", zap.String("from", c.s.From()))

if err = c.s.ResetListener(); err != nil {
c.s.errc <- err
select {
case c.s.errc <- err:
select {
case <-c.s.donec:
return nil, nil
default:
}
case <-c.s.donec:
return nil, nil
}
c.s.lg.Warn("failed to reset listener", zap.Error(err))
}
}
Expand Down Expand Up @@ -386,7 +407,16 @@ func (s *serverHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
hijacker, _ := resp.(http.Hijacker)
in, _, err := hijacker.Hijack()
if err != nil {
s.s.errc <- err
select {
case s.s.errc <- err:
select {
case <-s.s.donec:
return
default:
}
case <-s.s.donec:
return
}
s.s.lg.Debug("ServeHTTP hijack error", zap.Error(err))
panic(err)
}
Expand All @@ -407,7 +437,16 @@ func (s *serverHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
var tp *http.Transport
tp, err = transport.NewTransport(s.s.tlsInfo, s.s.dialTimeout)
if err != nil {
s.s.errc <- err
select {
case s.s.errc <- err:
select {
case <-s.s.donec:
return
default:
}
case <-s.s.donec:
return
}
s.s.lg.Debug("failed to get new Transport", zap.Error(err))
return
}
Expand All @@ -416,29 +455,47 @@ func (s *serverHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
out, err = net.Dial(targetScheme, targetHost)
}
if err != nil {
s.s.errc <- err
select {
case s.s.errc <- err:
select {
case <-s.s.donec:
return
default:
}
case <-s.s.donec:
return
}
s.s.lg.Debug("failed to dial", zap.Error(err))
return
}

var dstPort int
dstPort, err = getPort(out.RemoteAddr())
if err != nil {
s.s.errc <- err
select {
case s.s.errc <- err:
select {
case <-s.s.donec:
return
default:
}
case <-s.s.donec:
return
}
s.s.lg.Debug("failed to parse port in transmit", zap.Error(err))
return
}

s.s.closeWg.Add(2)
s.s.closeHijackedConn.Add(2)
go func() {
defer s.s.closeWg.Done()
defer s.s.closeHijackedConn.Done()
// read incoming bytes from listener, dispatch to outgoing connection
s.s.transmit(out, in, dstPort)
out.Close()
in.Close()
}()
go func() {
defer s.s.closeWg.Done()
defer s.s.closeHijackedConn.Done()
// read response from outgoing connection, write back to listener
s.s.receive(in, out, dstPort)
in.Close()
Expand Down Expand Up @@ -710,27 +767,28 @@ func (s *server) Close() (err error) {
s.closeOnce.Do(func() {
close(s.donec)

if s.httpServer != nil {
if err = s.httpServer.Shutdown(context.TODO()); err != nil {
return
}
s.httpServer = nil
} else {
s.listenerMu.Lock()

if s.listener != nil {
err = s.listener.Close()
s.lg.Info(
"closed proxy listener",
zap.String("from", s.From()),
zap.String("to", s.To()),
)
}
s.lg.Sync()
s.listenerMu.Unlock()
s.closeHijackedConn.Wait()

if err = s.httpServer.Shutdown(context.TODO()); err != nil {
return
}
s.httpServer = nil

// s.listenerMu.Lock()
// if s.listener != nil {
// err = s.listener.Close()
// s.lg.Info(
// "closed proxy listener",
// zap.String("from", s.From()),
// zap.String("to", s.To()),
// )
// }
// s.lg.Sync()
// s.listenerMu.Unlock()
})
s.closeWg.Wait()

// s.closeWg.Wait()

return err
}

Expand Down
6 changes: 3 additions & 3 deletions tests/e2e/blackhole_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ func blackholeTestByMockingPartition(t *testing.T, clusterSize int, partitionLea
t.Logf("Blackholing traffic from and to member %q", partitionedMember.Config().Name)
epc.BlackholePeer(partitionedMember)

t.Logf("Wait 5s for any open connections to expire")
time.Sleep(5 * time.Second)
t.Logf("Wait 1s for any open connections to expire")
time.Sleep(1 * time.Second)

t.Logf("Wait for new leader election with remaining members")
leaderEPC := epc.Procs[waitLeader(t, epc, mockPartitionNodeIndex)]
Expand All @@ -80,7 +80,7 @@ func blackholeTestByMockingPartition(t *testing.T, clusterSize int, partitionLea
epc.UnblackholePeer(partitionedMember)

leaderEPC = epc.Procs[epc.WaitLeader(t)]
time.Sleep(5 * time.Second)
time.Sleep(1 * time.Second)
assertRevision(t, leaderEPC, 21)
assertRevision(t, partitionedMember, 21)
}
Expand Down

0 comments on commit ecb34d5

Please sign in to comment.