From 5d3d6db9a241a874244303982f455d4c691b6471 Mon Sep 17 00:00:00 2001 From: Sander Bruens Date: Mon, 23 Sep 2024 14:33:45 -0400 Subject: [PATCH] refactor: pass in logger to service so caller can control logs (#207) * refactor: create re-usable service that can be re-used by Caddy * Remove need to return errors in opt functions. * Move the service into `shadowsocks.go`. * refactor: pass in logger to service so caller can control logs * Move initialization of handlers to the constructor. * Pass a `list.List` instead of a `CipherList`. * Rename `SSServer` to `OutlineServer`. * refactor: make connection metrics optional * Make setting the logger a setter function. * Revert "Pass a `list.List` instead of a `CipherList`." This reverts commit 1259af8d312fe0676856301c6961b848e96cc967. * Create noop metrics if nil. * Revert some more changes. * Use a noop metrics struct if no metrics provided. * Add noop implementation for `ShadowsocksConnMetrics`. * Move logger arg. * Resolve nil metrics. * Set logger explicitly to `noopLogger` in service creation. * Set `noopLogger` in `NewShadowsocksStreamAuthenticator()` if nil. * Fix logger reference. * Use a `noopLogger` if `SetLogger()` is called with `nil`. * Update tests. * Use concrete `slog.Logger` instead of `Logger` interface now that we don't need a zap adapter for Caddy. * Move `WithLogger()` down. * Remove `nil` check. * Use `math.MaxInt` to make sure no error log records are created. --- cmd/outline-ss-server/main.go | 2 + internal/integration_test/integration_test.go | 8 ++-- service/logger.go | 11 ++++- service/shadowsocks.go | 22 ++++++++- service/tcp.go | 46 +++++++++++------- service/tcp_test.go | 20 ++++---- service/udp.go | 48 ++++++++++++------- service/udp_test.go | 10 ++-- 8 files changed, 110 insertions(+), 57 deletions(-) diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index bd2aa177..c2704c1f 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -222,6 +222,7 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) { service.WithNatTimeout(s.natTimeout), service.WithMetrics(s.serviceMetrics), service.WithReplayCache(&s.replayCache), + service.WithLogger(slog.Default()), ) ln, err := lnSet.ListenStream(addr) if err != nil { @@ -248,6 +249,7 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) { service.WithNatTimeout(s.natTimeout), service.WithMetrics(s.serviceMetrics), service.WithReplayCache(&s.replayCache), + service.WithLogger(slog.Default()), ) if err != nil { return err diff --git a/internal/integration_test/integration_test.go b/internal/integration_test/integration_test.go index f847f761..0994b90f 100644 --- a/internal/integration_test/integration_test.go +++ b/internal/integration_test/integration_test.go @@ -131,7 +131,7 @@ func TestTCPEcho(t *testing.T) { replayCache := service.NewReplayCache(5) const testTimeout = 200 * time.Millisecond testMetrics := &statusMetrics{} - authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &replayCache, &fakeShadowsocksMetrics{}) + authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &replayCache, &fakeShadowsocksMetrics{}, nil) handler := service.NewStreamHandler(authFunc, testTimeout) handler.SetTargetDialer(&transport.TCPDialer{}) done := make(chan struct{}) @@ -211,7 +211,7 @@ func TestRestrictedAddresses(t *testing.T) { require.NoError(t, err) const testTimeout = 200 * time.Millisecond testMetrics := &statusMetrics{} - authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}) + authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}, nil) handler := service.NewStreamHandler(authFunc, testTimeout) done := make(chan struct{}) go func() { @@ -400,7 +400,7 @@ func BenchmarkTCPThroughput(b *testing.B) { } const testTimeout = 200 * time.Millisecond testMetrics := &service.NoOpTCPConnMetrics{} - authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}) + authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}, nil) handler := service.NewStreamHandler(authFunc, testTimeout) handler.SetTargetDialer(&transport.TCPDialer{}) done := make(chan struct{}) @@ -467,7 +467,7 @@ func BenchmarkTCPMultiplexing(b *testing.B) { replayCache := service.NewReplayCache(service.MaxCapacity) const testTimeout = 200 * time.Millisecond testMetrics := &service.NoOpTCPConnMetrics{} - authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &replayCache, &fakeShadowsocksMetrics{}) + authFunc := service.NewShadowsocksStreamAuthenticator(cipherList, &replayCache, &fakeShadowsocksMetrics{}, nil) handler := service.NewStreamHandler(authFunc, testTimeout) handler.SetTargetDialer(&transport.TCPDialer{}) done := make(chan struct{}) diff --git a/service/logger.go b/service/logger.go index 79b8ee3a..343dd5df 100644 --- a/service/logger.go +++ b/service/logger.go @@ -14,6 +14,13 @@ package service -import logging "github.com/op/go-logging" +import ( + "io" + "log/slog" + "math" +) -var logger = logging.MustGetLogger("shadowsocks") +func noopLogger() *slog.Logger { + // TODO: Use built-in no-op log level when available: https://go.dev/issue/62005 + return slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.Level(math.MaxInt)})) +} diff --git a/service/shadowsocks.go b/service/shadowsocks.go index f979bcce..636fa94e 100644 --- a/service/shadowsocks.go +++ b/service/shadowsocks.go @@ -16,6 +16,7 @@ package service import ( "context" + "log/slog" "net" "time" @@ -50,6 +51,7 @@ type Service interface { type Option func(s *ssService) type ssService struct { + logger *slog.Logger metrics ServiceMetrics ciphers CipherList natTimeout time.Duration @@ -59,7 +61,7 @@ type ssService struct { ph PacketHandler } -// NewShadowsocksService creates a new service +// NewShadowsocksService creates a new Shadowsocks service. func NewShadowsocksService(opts ...Option) (Service, error) { s := &ssService{} @@ -67,20 +69,36 @@ func NewShadowsocksService(opts ...Option) (Service, error) { opt(s) } + // If no NAT timeout is provided via options, use the recommended default. if s.natTimeout == 0 { s.natTimeout = defaultNatTimeout } + // If no logger is provided via options, use a noop logger. + if s.logger == nil { + s.logger = noopLogger() + } // TODO: Register initial data metrics at zero. s.sh = NewStreamHandler( - NewShadowsocksStreamAuthenticator(s.ciphers, s.replayCache, &ssConnMetrics{ServiceMetrics: s.metrics, proto: "tcp"}), + NewShadowsocksStreamAuthenticator(s.ciphers, s.replayCache, &ssConnMetrics{ServiceMetrics: s.metrics, proto: "tcp"}, s.logger), tcpReadTimeout, ) + s.sh.SetLogger(s.logger) + s.ph = NewPacketHandler(s.natTimeout, s.ciphers, s.metrics, &ssConnMetrics{ServiceMetrics: s.metrics, proto: "udp"}) + s.ph.SetLogger(s.logger) return s, nil } +// WithLogger can be used to provide a custom log target. If not provided, +// the service uses a noop logger (i.e., no logging). +func WithLogger(l *slog.Logger) Option { + return func(s *ssService) { + s.logger = l + } +} + // WithCiphers option function. func WithCiphers(ciphers CipherList) Option { return func(s *ssService) { diff --git a/service/tcp.go b/service/tcp.go index 6a12afc0..775509e9 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -58,11 +58,11 @@ func remoteIP(conn net.Conn) netip.Addr { } // Wrapper for slog.Debug during TCP access key searches. -func debugTCP(template string, cipherID string, attr slog.Attr) { +func debugTCP(l *slog.Logger, template string, cipherID string, attr slog.Attr) { // This is an optimization to reduce unnecessary allocations due to an interaction // between Go's inlining/escape analysis and varargs functions like slog.Debug. - if slog.Default().Enabled(nil, slog.LevelDebug) { - slog.LogAttrs(nil, slog.LevelDebug, fmt.Sprintf("TCP: %s", template), slog.String("ID", cipherID), attr) + if l.Enabled(nil, slog.LevelDebug) { + l.LogAttrs(nil, slog.LevelDebug, fmt.Sprintf("TCP: %s", template), slog.String("ID", cipherID), attr) } } @@ -72,7 +72,7 @@ func debugTCP(template string, cipherID string, attr slog.Attr) { // required = saltSize + 2 + cipher.TagSize, the number of bytes needed to authenticate the connection. const bytesForKeyFinding = 50 -func findAccessKey(clientReader io.Reader, clientIP netip.Addr, cipherList CipherList) (*CipherEntry, io.Reader, []byte, time.Duration, error) { +func findAccessKey(clientReader io.Reader, clientIP netip.Addr, cipherList CipherList, l *slog.Logger) (*CipherEntry, io.Reader, []byte, time.Duration, error) { // We snapshot the list because it may be modified while we use it. ciphers := cipherList.SnapshotForClientIP(clientIP) firstBytes := make([]byte, bytesForKeyFinding) @@ -81,7 +81,7 @@ func findAccessKey(clientReader io.Reader, clientIP netip.Addr, cipherList Ciphe } findStartTime := time.Now() - entry, elt := findEntry(firstBytes, ciphers) + entry, elt := findEntry(firstBytes, ciphers, l) timeToCipher := time.Since(findStartTime) if entry == nil { // TODO: Ban and log client IPs with too many failures too quick to protect against DoS. @@ -95,7 +95,7 @@ func findAccessKey(clientReader io.Reader, clientIP netip.Addr, cipherList Ciphe } // Implements a trial decryption search. This assumes that all ciphers are AEAD. -func findEntry(firstBytes []byte, ciphers []*list.Element) (*CipherEntry, *list.Element) { +func findEntry(firstBytes []byte, ciphers []*list.Element, l *slog.Logger) (*CipherEntry, *list.Element) { // To hold the decrypted chunk length. chunkLenBuf := [2]byte{} for ci, elt := range ciphers { @@ -103,10 +103,10 @@ func findEntry(firstBytes []byte, ciphers []*list.Element) (*CipherEntry, *list. cryptoKey := entry.CryptoKey _, err := shadowsocks.Unpack(chunkLenBuf[:0], firstBytes[:cryptoKey.SaltSize()+2+cryptoKey.TagSize()], cryptoKey) if err != nil { - debugTCP("Failed to decrypt length.", entry.ID, slog.Any("err", err)) + debugTCP(l, "Failed to decrypt length.", entry.ID, slog.Any("err", err)) continue } - debugTCP("Found cipher.", entry.ID, slog.Int("index", ci)) + debugTCP(l, "Found cipher.", entry.ID, slog.Int("index", ci)) return entry, elt } return nil, nil @@ -116,13 +116,16 @@ type StreamAuthenticateFunc func(clientConn transport.StreamConn) (string, trans // NewShadowsocksStreamAuthenticator creates a stream authenticator that uses Shadowsocks. // TODO(fortuna): Offer alternative transports. -func NewShadowsocksStreamAuthenticator(ciphers CipherList, replayCache *ReplayCache, metrics ShadowsocksConnMetrics) StreamAuthenticateFunc { +func NewShadowsocksStreamAuthenticator(ciphers CipherList, replayCache *ReplayCache, metrics ShadowsocksConnMetrics, l *slog.Logger) StreamAuthenticateFunc { if metrics == nil { metrics = &NoOpShadowsocksConnMetrics{} } + if l == nil { + l = noopLogger() + } return func(clientConn transport.StreamConn) (string, transport.StreamConn, *onet.ConnectionError) { // Find the cipher and acess key id. - cipherEntry, clientReader, clientSalt, timeToCipher, keyErr := findAccessKey(clientConn, remoteIP(clientConn), ciphers) + cipherEntry, clientReader, clientSalt, timeToCipher, keyErr := findAccessKey(clientConn, remoteIP(clientConn), ciphers, l) metrics.AddCipherSearch(keyErr == nil, timeToCipher) if keyErr != nil { const status = "ERR_CIPHER" @@ -154,6 +157,7 @@ func NewShadowsocksStreamAuthenticator(ciphers CipherList, replayCache *ReplayCa } type streamHandler struct { + logger *slog.Logger listenerId string readTimeout time.Duration authenticate StreamAuthenticateFunc @@ -163,6 +167,7 @@ type streamHandler struct { // NewStreamHandler creates a StreamHandler func NewStreamHandler(authenticate StreamAuthenticateFunc, timeout time.Duration) StreamHandler { return &streamHandler{ + logger: noopLogger(), readTimeout: timeout, authenticate: authenticate, dialer: defaultDialer, @@ -181,10 +186,19 @@ func makeValidatingTCPStreamDialer(targetIPValidator onet.TargetIPValidator) tra // StreamHandler is a handler that handles stream connections. type StreamHandler interface { Handle(ctx context.Context, conn transport.StreamConn, connMetrics TCPConnMetrics) + // SetLogger sets the logger used to log messages. Uses a no-op logger if nil. + SetLogger(l *slog.Logger) // SetTargetDialer sets the [transport.StreamDialer] to be used to connect to target addresses. SetTargetDialer(dialer transport.StreamDialer) } +func (s *streamHandler) SetLogger(l *slog.Logger) { + if l == nil { + l = noopLogger() + } + s.logger = l +} + func (s *streamHandler) SetTargetDialer(dialer transport.StreamDialer) { s.dialer = dialer } @@ -257,11 +271,11 @@ func (h *streamHandler) Handle(ctx context.Context, clientConn transport.StreamC status := "OK" if connError != nil { status = connError.Status - slog.LogAttrs(nil, slog.LevelDebug, "TCP: Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause)) + h.logger.LogAttrs(nil, slog.LevelDebug, "TCP: Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause)) } connMetrics.AddClosed(status, proxyMetrics, connDuration) measuredClientConn.Close() // Closing after the metrics are added aids integration testing. - slog.LogAttrs(nil, slog.LevelDebug, "TCP: Done.", slog.String("status", status), slog.Duration("duration", connDuration)) + h.logger.LogAttrs(nil, slog.LevelDebug, "TCP: Done.", slog.String("status", status), slog.Duration("duration", connDuration)) } func getProxyRequest(clientConn transport.StreamConn) (string, error) { @@ -276,14 +290,14 @@ func getProxyRequest(clientConn transport.StreamConn) (string, error) { return tgtAddr.String(), nil } -func proxyConnection(ctx context.Context, dialer transport.StreamDialer, tgtAddr string, clientConn transport.StreamConn) *onet.ConnectionError { +func proxyConnection(l *slog.Logger, ctx context.Context, dialer transport.StreamDialer, tgtAddr string, clientConn transport.StreamConn) *onet.ConnectionError { tgtConn, dialErr := dialer.DialStream(ctx, tgtAddr) if dialErr != nil { // We don't drain so dial errors and invalid addresses are communicated quickly. return ensureConnectionError(dialErr, "ERR_CONNECT", "Failed to connect to target") } defer tgtConn.Close() - slog.LogAttrs(nil, slog.LevelDebug, "Proxy connection.", slog.String("client", clientConn.RemoteAddr().String()), slog.String("target", tgtConn.RemoteAddr().String())) + l.LogAttrs(nil, slog.LevelDebug, "Proxy connection.", slog.String("client", clientConn.RemoteAddr().String()), slog.String("target", tgtConn.RemoteAddr().String())) fromClientErrCh := make(chan error) go func() { @@ -351,7 +365,7 @@ func (h *streamHandler) handleConnection(ctx context.Context, outerConn transpor tgtConn = metrics.MeasureConn(tgtConn, &proxyMetrics.ProxyTarget, &proxyMetrics.TargetProxy) return tgtConn, nil }) - return proxyConnection(ctx, dialer, tgtAddr, innerConn) + return proxyConnection(h.logger, ctx, dialer, tgtAddr, innerConn) } // Keep the connection open until we hit the authentication deadline to protect against probing attacks @@ -360,7 +374,7 @@ func (h *streamHandler) absorbProbe(clientConn io.ReadCloser, connMetrics TCPCon // This line updates proxyMetrics.ClientProxy before it's used in AddTCPProbe. _, drainErr := io.Copy(io.Discard, clientConn) // drain socket drainResult := drainErrToString(drainErr) - slog.LogAttrs(nil, slog.LevelDebug, "Drain error.", slog.Any("err", drainErr), slog.String("result", drainResult)) + h.logger.LogAttrs(nil, slog.LevelDebug, "Drain error.", slog.Any("err", drainErr), slog.String("result", drainResult)) connMetrics.AddProbe(status, drainResult, proxyMetrics.ClientProxy) } diff --git a/service/tcp_test.go b/service/tcp_test.go index a4efcb2e..e69d1d1e 100644 --- a/service/tcp_test.go +++ b/service/tcp_test.go @@ -102,7 +102,7 @@ func BenchmarkTCPFindCipherFail(b *testing.B) { } clientIP := clientConn.RemoteAddr().(*net.TCPAddr).AddrPort().Addr() b.StartTimer() - findAccessKey(clientConn, clientIP, cipherList) + findAccessKey(clientConn, clientIP, cipherList, noopLogger()) b.StopTimer() } } @@ -205,7 +205,7 @@ func BenchmarkTCPFindCipherRepeat(b *testing.B) { cipher := cipherEntries[cipherNumber].CryptoKey go shadowsocks.NewWriter(writer, cipher).Write(makeTestPayload(50)) b.StartTimer() - _, _, _, _, err := findAccessKey(&c, clientIP, cipherList) + _, _, _, _, err := findAccessKey(&c, clientIP, cipherList, noopLogger()) b.StopTimer() if err != nil { b.Error(err) @@ -285,7 +285,7 @@ func TestProbeRandom(t *testing.T) { cipherList, err := MakeTestCiphers(makeTestSecrets(1)) require.NoError(t, err, "MakeTestCiphers failed: %v", err) testMetrics := &probeTestMetrics{} - authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}) + authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}, nil) handler := NewStreamHandler(authFunc, 200*time.Millisecond) done := make(chan struct{}) go func() { @@ -365,7 +365,7 @@ func TestProbeClientBytesBasicTruncated(t *testing.T) { require.NoError(t, err, "MakeTestCiphers failed: %v", err) cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} - authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}) + authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}, nil) handler := NewStreamHandler(authFunc, 200*time.Millisecond) handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) done := make(chan struct{}) @@ -403,7 +403,7 @@ func TestProbeClientBytesBasicModified(t *testing.T) { require.NoError(t, err, "MakeTestCiphers failed: %v", err) cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} - authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}) + authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}, nil) handler := NewStreamHandler(authFunc, 200*time.Millisecond) handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) done := make(chan struct{}) @@ -442,7 +442,7 @@ func TestProbeClientBytesCoalescedModified(t *testing.T) { require.NoError(t, err, "MakeTestCiphers failed: %v", err) cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} - authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}) + authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}, nil) handler := NewStreamHandler(authFunc, 200*time.Millisecond) handler.SetTargetDialer(makeValidatingTCPStreamDialer(allowAll)) done := make(chan struct{}) @@ -488,7 +488,7 @@ func TestProbeServerBytesModified(t *testing.T) { require.NoError(t, err, "MakeTestCiphers failed: %v", err) cipher := firstCipher(cipherList) testMetrics := &probeTestMetrics{} - authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}) + authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}, nil) handler := NewStreamHandler(authFunc, 200*time.Millisecond) done := make(chan struct{}) go func() { @@ -522,7 +522,7 @@ func TestReplayDefense(t *testing.T) { replayCache := NewReplayCache(5) testMetrics := &probeTestMetrics{} const testTimeout = 200 * time.Millisecond - authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics) + authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics, nil) handler := NewStreamHandler(authFunc, testTimeout) snapshot := cipherList.SnapshotForClientIP(netip.Addr{}) cipherEntry := snapshot[0].Value.(*CipherEntry) @@ -604,7 +604,7 @@ func TestReverseReplayDefense(t *testing.T) { replayCache := NewReplayCache(5) testMetrics := &probeTestMetrics{} const testTimeout = 200 * time.Millisecond - authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics) + authFunc := NewShadowsocksStreamAuthenticator(cipherList, &replayCache, testMetrics, nil) handler := NewStreamHandler(authFunc, testTimeout) snapshot := cipherList.SnapshotForClientIP(netip.Addr{}) cipherEntry := snapshot[0].Value.(*CipherEntry) @@ -678,7 +678,7 @@ func probeExpectTimeout(t *testing.T, payloadSize int) { cipherList, err := MakeTestCiphers(makeTestSecrets(5)) require.NoError(t, err, "MakeTestCiphers failed: %v", err) testMetrics := &probeTestMetrics{} - authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}) + authFunc := NewShadowsocksStreamAuthenticator(cipherList, nil, &fakeShadowsocksMetrics{}, nil) handler := NewStreamHandler(authFunc, testTimeout) done := make(chan struct{}) diff --git a/service/udp.go b/service/udp.go index 39091239..8ff5352f 100644 --- a/service/udp.go +++ b/service/udp.go @@ -44,23 +44,23 @@ type UDPMetrics interface { const serverUDPBufferSize = 64 * 1024 // Wrapper for slog.Debug during UDP proxying. -func debugUDP(template string, cipherID string, attr slog.Attr) { +func debugUDP(l *slog.Logger, template string, cipherID string, attr slog.Attr) { // This is an optimization to reduce unnecessary allocations due to an interaction // between Go's inlining/escape analysis and varargs functions like slog.Debug. - if slog.Default().Enabled(nil, slog.LevelDebug) { - slog.LogAttrs(nil, slog.LevelDebug, fmt.Sprintf("UDP: %s", template), slog.String("ID", cipherID), attr) + if l.Enabled(nil, slog.LevelDebug) { + l.LogAttrs(nil, slog.LevelDebug, fmt.Sprintf("UDP: %s", template), slog.String("ID", cipherID), attr) } } -func debugUDPAddr(template string, addr net.Addr, attr slog.Attr) { - if slog.Default().Enabled(nil, slog.LevelDebug) { - slog.LogAttrs(nil, slog.LevelDebug, fmt.Sprintf("UDP: %s", template), slog.String("address", addr.String()), attr) +func debugUDPAddr(l *slog.Logger, template string, addr net.Addr, attr slog.Attr) { + if l.Enabled(nil, slog.LevelDebug) { + l.LogAttrs(nil, slog.LevelDebug, fmt.Sprintf("UDP: %s", template), slog.String("address", addr.String()), attr) } } // Decrypts src into dst. It tries each cipher until it finds one that authenticates // correctly. dst and src must not overlap. -func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherList) ([]byte, string, *shadowsocks.EncryptionKey, error) { +func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherList, l *slog.Logger) ([]byte, string, *shadowsocks.EncryptionKey, error) { // Try each cipher until we find one that authenticates successfully. This assumes that all ciphers are AEAD. // We snapshot the list because it may be modified while we use it. snapshot := cipherList.SnapshotForClientIP(clientIP) @@ -68,10 +68,10 @@ func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherLis id, cryptoKey := entry.Value.(*CipherEntry).ID, entry.Value.(*CipherEntry).CryptoKey buf, err := shadowsocks.Unpack(dst, src, cryptoKey) if err != nil { - debugUDP("Failed to unpack.", id, slog.Any("err", err)) + debugUDP(l, "Failed to unpack.", id, slog.Any("err", err)) continue } - debugUDP("Found cipher.", id, slog.Int("index", ci)) + debugUDP(l, "Found cipher.", id, slog.Int("index", ci)) // Move the active cipher to the front, so that the search is quicker next time. cipherList.MarkUsedByClientIP(entry, clientIP) return buf, id, cryptoKey, nil @@ -80,6 +80,7 @@ func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherLis } type packetHandler struct { + logger *slog.Logger natTimeout time.Duration ciphers CipherList m UDPMetrics @@ -96,6 +97,7 @@ func NewPacketHandler(natTimeout time.Duration, cipherList CipherList, m UDPMetr ssMetrics = &NoOpShadowsocksConnMetrics{} } return &packetHandler{ + logger: noopLogger(), natTimeout: natTimeout, ciphers: cipherList, m: m, @@ -106,12 +108,21 @@ func NewPacketHandler(natTimeout time.Duration, cipherList CipherList, m UDPMetr // PacketHandler is a running UDP shadowsocks proxy that can be stopped. type PacketHandler interface { + // SetLogger sets the logger used to log messages. Uses a no-op logger if nil. + SetLogger(l *slog.Logger) // SetTargetIPValidator sets the function to be used to validate the target IP addresses. SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) // Handle returns after clientConn closes and all the sub goroutines return. Handle(clientConn net.PacketConn) } +func (h *packetHandler) SetLogger(l *slog.Logger) { + if l == nil { + l = noopLogger() + } + h.logger = l +} + func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) { h.targetIPValidator = targetIPValidator } @@ -121,7 +132,7 @@ func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPVali func (h *packetHandler) Handle(clientConn net.PacketConn) { var running sync.WaitGroup - nm := newNATmap(h.natTimeout, h.m, &running) + nm := newNATmap(h.natTimeout, h.m, &running, h.logger) defer nm.Close() cipherBuf := make([]byte, serverUDPBufferSize) textBuf := make([]byte, serverUDPBufferSize) @@ -149,7 +160,7 @@ func (h *packetHandler) Handle(clientConn net.PacketConn) { return onet.NewConnectionError("ERR_READ", "Failed to read from client", err) } defer slog.LogAttrs(nil, slog.LevelDebug, "UDP: Done", slog.String("address", clientAddr.String())) - debugUDPAddr("Outbound packet.", clientAddr, slog.Int("bytes", clientProxyBytes)) + debugUDPAddr(h.logger, "Outbound packet.", clientAddr, slog.Int("bytes", clientProxyBytes)) cipherData := cipherBuf[:clientProxyBytes] var payload []byte @@ -160,7 +171,7 @@ func (h *packetHandler) Handle(clientConn net.PacketConn) { var textData []byte var cryptoKey *shadowsocks.EncryptionKey unpackStart := time.Now() - textData, keyID, cryptoKey, err = findAccessKeyUDP(ip, textBuf, cipherData, h.ciphers) + textData, keyID, cryptoKey, err = findAccessKeyUDP(ip, textBuf, cipherData, h.ciphers, h.logger) timeToCipher := time.Since(unpackStart) h.ssm.AddCipherSearch(err == nil, timeToCipher) @@ -197,7 +208,7 @@ func (h *packetHandler) Handle(clientConn net.PacketConn) { } } - debugUDPAddr("Proxy exit.", clientAddr, slog.Any("target", targetConn.LocalAddr())) + debugUDPAddr(h.logger, "Proxy exit.", clientAddr, slog.Any("target", targetConn.LocalAddr())) proxyTargetBytes, err = targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature if err != nil { return onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err) @@ -306,13 +317,14 @@ func (c *natconn) ReadFrom(buf []byte) (int, net.Addr, error) { type natmap struct { sync.RWMutex keyConn map[string]*natconn + logger *slog.Logger timeout time.Duration metrics UDPMetrics running *sync.WaitGroup } -func newNATmap(timeout time.Duration, sm UDPMetrics, running *sync.WaitGroup) *natmap { - m := &natmap{metrics: sm, running: running} +func newNATmap(timeout time.Duration, sm UDPMetrics, running *sync.WaitGroup, l *slog.Logger) *natmap { + m := &natmap{logger: l, metrics: sm, running: running} m.keyConn = make(map[string]*natconn) m.timeout = timeout return m @@ -358,7 +370,7 @@ func (m *natmap) Add(clientAddr net.Addr, clientConn net.PacketConn, cryptoKey * m.running.Add(1) go func() { - timedCopy(clientAddr, clientConn, entry, keyID) + timedCopy(clientAddr, clientConn, entry, keyID, m.logger) connMetrics.RemoveNatEntry() if pc := m.del(clientAddr.String()); pc != nil { pc.Close() @@ -387,7 +399,7 @@ func (m *natmap) Close() error { var maxAddrLen int = len(socks.ParseAddr("[2001:db8::1]:12345")) // copy from target to client until read timeout -func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natconn, keyID string) { +func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natconn, keyID string, l *slog.Logger) { // pkt is used for in-place encryption of downstream UDP packets, with the layout // [padding?][salt][address][body][tag][extra] // Padding is only used if the address is IPv4. @@ -420,7 +432,7 @@ func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natco return onet.NewConnectionError("ERR_READ", "Failed to read from target", err) } - debugUDPAddr("Got response.", clientAddr, slog.Any("target", raddr)) + debugUDPAddr(l, "Got response.", clientAddr, slog.Any("target", raddr)) srcAddr := socks.ParseAddr(raddr.String()) addrStart := bodyStart - len(srcAddr) // `plainTextBuf` concatenates the SOCKS address and body: diff --git a/service/udp_test.go b/service/udp_test.go index 8ba00af3..ae792363 100644 --- a/service/udp_test.go +++ b/service/udp_test.go @@ -207,14 +207,14 @@ func assertAlmostEqual(t *testing.T, a, b time.Time) { } func TestNATEmpty(t *testing.T) { - nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{}) + nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{}, noopLogger()) if nat.Get("foo") != nil { t.Error("Expected nil value from empty NAT map") } } func setupNAT() (*fakePacketConn, *fakePacketConn, *natconn) { - nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{}) + nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{}, noopLogger()) clientConn := makePacketConn() targetConn := makePacketConn() nat.Add(&clientAddr, clientConn, natCryptoKey, targetConn, "key id") @@ -409,7 +409,7 @@ func BenchmarkUDPUnpackFail(b *testing.B) { testIP := netip.MustParseAddr("192.0.2.1") b.ResetTimer() for n := 0; n < b.N; n++ { - findAccessKeyUDP(testIP, textBuf, testPayload, cipherList) + findAccessKeyUDP(testIP, textBuf, testPayload, cipherList, noopLogger()) } } @@ -439,7 +439,7 @@ func BenchmarkUDPUnpackRepeat(b *testing.B) { cipherNumber := n % numCiphers ip := ips[cipherNumber] packet := packets[cipherNumber] - _, _, _, err := findAccessKeyUDP(ip, testBuf, packet, cipherList) + _, _, _, err := findAccessKeyUDP(ip, testBuf, packet, cipherList, noopLogger()) if err != nil { b.Error(err) } @@ -468,7 +468,7 @@ func BenchmarkUDPUnpackSharedKey(b *testing.B) { b.ResetTimer() for n := 0; n < b.N; n++ { ip := ips[n%numIPs] - _, _, _, err := findAccessKeyUDP(ip, testBuf, packet, cipherList) + _, _, _, err := findAccessKeyUDP(ip, testBuf, packet, cipherList, noopLogger()) if err != nil { b.Error(err) }