Skip to content

Commit

Permalink
Clearly separate L4 and L7 connection handling logic
Browse files Browse the repository at this point in the history
Due to forward proxy's need to parse the CONNECT header, which is a
L7 layer feature, thus we are splitting the proxy into 2 types, for
better maintainability.

Reference:
- #17985 (comment)

Signed-off-by: Chun-Hung Tseng <[email protected]>
  • Loading branch information
henrybear327 committed May 16, 2024
1 parent 37b09a7 commit 954c141
Showing 1 changed file with 143 additions and 82 deletions.
225 changes: 143 additions & 82 deletions pkg/proxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
package proxy

import (
"bufio"
"bytes"
"context"
"fmt"
"io"
Expand Down Expand Up @@ -211,6 +209,8 @@ type server struct {

blackholePeerMap map[int]uint8 // port number, blackhole type
blackholePeerMapMu sync.RWMutex

httpServer *http.Server
}

// NewServer returns a proxy implementation with no iptables/tc dependencies.
Expand Down Expand Up @@ -278,25 +278,132 @@ func NewServer(cfg ServerConfig) Server {
addr = s.from.Host
}

var ln net.Listener
var err error
if !s.tlsInfo.Empty() {
ln, err = transport.NewListener(addr, s.from.Scheme, &s.tlsInfo)
// We do not have an unified implementation for the proxy because we are dealing with the connection in different layers
// L7 (serverHandler) can't deal with unix socket, as it's L4 (transport layer's feature)
s.closeWg.Add(1)
if s.isForwardProxy {
// L7 proxy
//
// the main goal is to parse the CONNECT header for the destination host first (at L7 application layer),
// then continuing on to forward the traffic like we do in L4
//
// this implementation won't have features such as delayed connection accept, as it's a L7 proxy
if !(s.tlsInfo.Empty() && s.from.Scheme == "tcp") {
panic("Unsupported configuration")
}

handler := &serverHandler{
closeWg: &s.closeWg,
s: s,
}

s.httpServer = startHTTPServer(&s.closeWg, s.readyc, addr, handler)
} else {
ln, err = net.Listen(s.from.Scheme, addr)
// L4 proxy
//
// the destination host is known, thus, we can directly forward the traffic (at L4 transport layer)
var ln net.Listener
var err error
if !s.tlsInfo.Empty() {
ln, err = transport.NewListener(addr, s.from.Scheme, &s.tlsInfo)
} else {
ln, err = net.Listen(s.from.Scheme, addr)
}
if err != nil {
s.errc <- err
s.Close()
return s
}
s.listener = ln

go s.listenAndServe()
}

s.lg.Info("started proxying", zap.String("from", s.From()), zap.String("to", s.To()))
return s
}

func startHTTPServer(closeWg *sync.WaitGroup, readyc chan struct{}, addr string, handler *serverHandler) *http.Server {
srv := &http.Server{
Addr: addr,
}
srv.Handler = handler

go func() {
defer closeWg.Done() // let main know we are done cleaning up

close(readyc)
// always returns error. ErrServerClosed on graceful close
if err := srv.ListenAndServe(); err != http.ErrServerClosed {
panic(fmt.Sprintf("ListenAndServe(): %v", err))
}
}()

// returning reference so caller can call Shutdown()
return srv
}

type serverHandler struct {
closeWg *sync.WaitGroup

s *server
}

func (s *serverHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
hijacker, _ := resp.(http.Hijacker)
conn, _, err := hijacker.Hijack()
if err != nil {
s.errc <- err
s.Close()
return s
// write error back to conn
return
}
s.listener = ln

s.closeWg.Add(1)
go s.listenAndServe()
// dial to target host
targetConn, err := net.Dial("tcp", req.URL.Host)
if err != nil {
// write error back to conn
return
}

s.lg.Info("started proxying", zap.String("from", s.From()), zap.String("to", s.To()))
return s
// for CONNECT, we need to send 200 response back first
if req.Method == "CONNECT" {
conn.Write([]byte("HTTP/1.0 200 Connection established\r\n\r\n"))
}

var dstPort int
dstPort, err = getPort(targetConn.RemoteAddr())
if err != nil {
select {
case s.s.errc <- err:
select {
case <-s.s.donec:
return
default:
}
case <-s.s.donec:
return
}
s.s.lg.Debug("failed to parse port in transmit", zap.Error(err))
return
}

out := targetConn
in := conn

s.closeWg.Add(2)
go func() {
defer s.closeWg.Done()
// read incoming bytes from listener, dispatch to outgoing connection
s.s.transmit(out, in, dstPort)
out.Close()
in.Close()
}()
go func() {
defer s.closeWg.Done()
// read response from outgoing connection, write back to listener
s.s.receive(in, out, dstPort)
in.Close()
out.Close()
}()
}

func (s *server) From() string {
Expand All @@ -314,7 +421,6 @@ func (s *server) To() string {
// buffer packets per connection for awhile, reorder before transmit
// - https://github.com/etcd-io/etcd/issues/5614
// - https://github.com/etcd-io/etcd/pull/6918#issuecomment-264093034

func (s *server) listenAndServe() {
defer s.closeWg.Done()

Expand Down Expand Up @@ -387,44 +493,6 @@ func (s *server) listenAndServe() {
continue
}

parseHeaderForDestination := func() *string {
// the first request should always contain a CONNECT header field
// since we set the transport to forward the traffic to the proxy
buf := make([]byte, s.bufferSize)
var data []byte
var nr1 int
if nr1, err = in.Read(buf); err != nil {
if err == io.EOF {
return nil
// why??
// panic("No data available for forward proxy to work on")
}
panic(err)
} else {
data = buf[:nr1]
}

// attempt to parse for the HOST from the CONNECT request
var req *http.Request
if req, err = http.ReadRequest(bufio.NewReader(bytes.NewReader(data))); err != nil {
panic("Failed to parse header in forward proxy")
}

if req.Method == http.MethodConnect {
// make sure a reply is sent back to the client
connectResponse := &http.Response{
StatusCode: 200,
ProtoMajor: 1,
ProtoMinor: 1,
}
connectResponse.Write(in)

return &req.URL.Host
}

panic("Wrong header type to start the connection")
}

var out net.Conn
if !s.tlsInfo.Empty() {
var tp *http.Transport
Expand All @@ -442,25 +510,9 @@ func (s *server) listenAndServe() {
}
continue
}
if s.isForwardProxy {
if dest := parseHeaderForDestination(); dest == nil {
continue
} else {
out, err = tp.DialContext(ctx, "tcp", *dest)
}
} else {
out, err = tp.DialContext(ctx, s.to.Scheme, s.to.Host)
}
out, err = tp.DialContext(ctx, s.to.Scheme, s.to.Host)
} else {
if s.isForwardProxy {
if dest := parseHeaderForDestination(); dest == nil {
continue
} else {
out, err = net.Dial("tcp", *dest)
}
} else {
out, err = net.Dial(s.to.Scheme, s.to.Host)
}
out, err = net.Dial(s.to.Scheme, s.to.Host)
}
if err != nil {
select {
Expand Down Expand Up @@ -764,17 +816,26 @@ func (s *server) Error() <-chan error { return s.errc }
func (s *server) Close() (err error) {
s.closeOnce.Do(func() {
close(s.donec)
s.listenerMu.Lock()
if s.listener != nil {
err = s.listener.Close()
s.lg.Info(
"closed proxy listener",
zap.String("from", s.From()),
zap.String("to", s.To()),
)

if s.httpServer != nil {
if err = s.httpServer.Shutdown(context.TODO()); err != nil {
return
}
s.httpServer = nil
} else {
s.listenerMu.Lock()

if s.listener != nil {
err = s.listener.Close()
s.lg.Info(
"closed proxy listener",
zap.String("from", s.From()),
zap.String("to", s.To()),
)
}
s.lg.Sync()
s.listenerMu.Unlock()
}
s.lg.Sync()
s.listenerMu.Unlock()
})
s.closeWg.Wait()
return err
Expand Down

0 comments on commit 954c141

Please sign in to comment.