Skip to content

Commit

Permalink
Merge shared access key logic.
Browse files Browse the repository at this point in the history
  • Loading branch information
sbruens committed Jun 26, 2024
1 parent 8368480 commit f691d98
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 173 deletions.
13 changes: 13 additions & 0 deletions service/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,16 @@ package service
import logging "github.com/op/go-logging"

var logger = logging.MustGetLogger("shadowsocks")

type DebugLoggerFunc func(tag string, template string, val interface{})

// NewDebugLogger creates a wrapper for logger.Debugf during proxying.
func NewDebugLogger(protocol string) DebugLoggerFunc {
return func(tag string, template string, val interface{}) {
// This is an optimization to reduce unnecessary allocations due to an interaction
// between Go's inlining/escape analysis and varargs functions like logger.Debugf.
if logger.IsEnabledFor(logging.DEBUG) {
logger.Debugf("%s(%s): "+template, protocol, tag, val)
}
}
}
71 changes: 10 additions & 61 deletions service/shadowsocks.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,84 +15,33 @@
package service

import (
"bytes"
"container/list"
"errors"
"fmt"
"io"
"net/netip"
"time"

"github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks"
)

// bytesForKeyFinding is the number of bytes to read for finding the AccessKey.
// Is must satisfy provided >= bytesForKeyFinding >= required for every cipher in the list.
// provided = saltSize + 2 + 2 * cipher.TagSize, the minimum number of bytes we will see in a valid connection
// required = saltSize + 2 + cipher.TagSize, the number of bytes needed to authenticate the connection.
const bytesForKeyFinding = 50

func findAccessKey(clientReader io.Reader, clientIP netip.Addr, bufferSize int, cipherList CipherList) (*CipherEntry, io.Reader, []byte, time.Duration, error) {
firstBytes := make([]byte, bytesForKeyFinding)
if n, err := io.ReadFull(clientReader, firstBytes); err != nil {
return nil, clientReader, nil, 0, fmt.Errorf("reading header failed after %d bytes: %w", n, err)
}

// findAccessKey implements a trial decryption search. This assumes that all ciphers are AEAD.
func findAccessKey(clientIP netip.Addr, bufferSize int, src []byte, cipherList CipherList, logDebug DebugLoggerFunc) (*CipherEntry, []byte, time.Duration, error) {
// We snapshot the list because it may be modified while we use it.
ciphers := cipherList.SnapshotForClientIP(clientIP)

findStartTime := time.Now()
entry, _, elt, err := findEntry(firstBytes, bufferSize, ciphers)
timeToCipher := time.Since(findStartTime)
if err != nil {
// TODO: Ban and log client IPs with too many failures too quick to protect against DoS.
return nil, clientReader, nil, timeToCipher, err
}

// Move the active cipher to the front, so that the search is quicker next time.
cipherList.MarkUsedByClientIP(elt, clientIP)
salt := firstBytes[:entry.CryptoKey.SaltSize()]
return entry, io.MultiReader(bytes.NewReader(firstBytes), clientReader), salt, timeToCipher, nil
}

// findAccessKeyUDP decrypts src. It tries each cipher until it finds one that
// authenticates correctly.
func findAccessKeyUDP(clientIP netip.Addr, bufferSize int, src []byte, cipherList CipherList) (*CipherEntry, []byte, error) {
// We snapshot the list because it may be modified while we use it.
ciphers := cipherList.SnapshotForClientIP(clientIP)

// Try each cipher until we find one that authenticates successfully. This assumes that all ciphers are AEAD.
unpackStart := time.Now()
// To hold the decrypted chunk length.
chunkLenBuf := make([]byte, bufferSize)
for ci, elt := range ciphers {
entry := elt.Value.(*CipherEntry)
cryptoKey := entry.CryptoKey
buf, err := shadowsocks.Unpack(chunkLenBuf, src, cryptoKey)
buf, err := shadowsocks.Unpack(chunkLenBuf, src, entry.CryptoKey)
if err != nil {
debugUDP(entry.ID, "Failed to unpack: %v", err)
logDebug(entry.ID, "Failed to unpack: %v", err)
continue
}
debugUDP(entry.ID, "Found cipher at index %d", ci)
logDebug(entry.ID, "Found cipher at index %d", ci)

// Move the active cipher to the front, so that the search is quicker next time.
cipherList.MarkUsedByClientIP(elt, clientIP)
return entry, buf, nil
}
return nil, nil, errors.New("could not find valid cipher")
}

// Implements a trial decryption search. This assumes that all ciphers are AEAD.
func findEntry(src []byte, bufferSize int, ciphers []*list.Element) (*CipherEntry, []byte, *list.Element, error) {
// To hold the decrypted chunk length.
chunkLenBuf := make([]byte, bufferSize)
for ci, elt := range ciphers {
entry := elt.Value.(*CipherEntry)
cryptoKey := entry.CryptoKey
buf, err := shadowsocks.Unpack(chunkLenBuf, src[:cryptoKey.SaltSize()+2+cryptoKey.TagSize()], cryptoKey)
if err != nil {
debugTCP(entry.ID, "Failed to decrypt length: %v", err)
continue
}
debugTCP(entry.ID, "Found cipher at index %d", ci)
return entry, buf, elt, nil
return entry, buf, time.Since(unpackStart), nil
}
return nil, nil, nil, errors.New("could not find valid cipher")
return nil, nil, time.Since(unpackStart), errors.New("could not find valid cipher")
}
96 changes: 78 additions & 18 deletions service/shadowsocks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,30 +29,15 @@ func BenchmarkTCPFindCipherFail(b *testing.B) {
b.StopTimer()
b.ResetTimer()

listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
if err != nil {
b.Fatalf("ListenTCP failed: %v", err)
}

clientIP := netip.MustParseAddr("127.0.0.1")
cipherList, err := MakeTestCiphers(makeTestSecrets(100))
if err != nil {
b.Fatal(err)
}
testPayload := makeTestPayload(50)
for n := 0; n < b.N; n++ {
go func() {
conn, err := net.Dial("tcp", listener.Addr().String())
require.NoErrorf(b, err, "Failed to dial %v: %v", listener.Addr(), err)
conn.Write(testPayload)
conn.Close()
}()
clientConn, err := listener.AcceptTCP()
if err != nil {
b.Fatalf("AcceptTCP failed: %v", err)
}
clientIP := clientConn.RemoteAddr().(*net.TCPAddr).AddrPort().Addr()
b.StartTimer()
findAccessKey(clientConn, clientIP, 2, cipherList)
findAccessKey(clientIP, 2, testPayload, cipherList, NewDebugLogger("TCP"))
b.StopTimer()
}
}
Expand All @@ -73,6 +58,7 @@ func BenchmarkTCPFindCipherRepeat(b *testing.B) {
for cipherNumber, element := range snapshot {
cipherEntries[cipherNumber] = element.Value.(*CipherEntry)
}
testPayload := makeTestPayload(50)
for n := 0; n < b.N; n++ {
cipherNumber := byte(n % numCiphers)
reader, writer := io.Pipe()
Expand All @@ -82,11 +68,85 @@ func BenchmarkTCPFindCipherRepeat(b *testing.B) {
cipher := cipherEntries[cipherNumber].CryptoKey
go shadowsocks.NewWriter(writer, cipher).Write(makeTestPayload(50))
b.StartTimer()
_, _, _, _, err := findAccessKey(&c, clientIP, 2, cipherList)
_, _, _, err := findAccessKey(clientIP, 2, testPayload, cipherList, NewDebugLogger("TCP"))
b.StopTimer()
if err != nil {
b.Error(err)
}
c.Close()
}
}

// Simulates receiving invalid UDP packets on a server with 100 ciphers.
func BenchmarkUDPUnpackFail(b *testing.B) {
cipherList, err := MakeTestCiphers(makeTestSecrets(100))
if err != nil {
b.Fatal(err)
}
testPayload := makeTestPayload(50)
testIP := netip.MustParseAddr("192.0.2.1")
b.ResetTimer()
for n := 0; n < b.N; n++ {
findAccessKey(testIP, serverUDPBufferSize, testPayload, cipherList, NewDebugLogger("UDP"))
}
}

// Simulates receiving valid UDP packets from 100 different users, each with
// their own cipher and IP address.
func BenchmarkUDPUnpackRepeat(b *testing.B) {
const numCiphers = 100 // Must be <256
cipherList, err := MakeTestCiphers(makeTestSecrets(numCiphers))
if err != nil {
b.Fatal(err)
}
packets := [numCiphers][]byte{}
ips := [numCiphers]netip.Addr{}
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
for i, element := range snapshot {
packets[i] = make([]byte, 0, serverUDPBufferSize)
plaintext := makeTestPayload(50)
packets[i], err = shadowsocks.Pack(make([]byte, serverUDPBufferSize), plaintext, element.Value.(*CipherEntry).CryptoKey)
if err != nil {
b.Error(err)
}
ips[i] = netip.AddrFrom4([4]byte{192, 0, 2, byte(i)})
}
b.ResetTimer()
for n := 0; n < b.N; n++ {
cipherNumber := n % numCiphers
ip := ips[cipherNumber]
packet := packets[cipherNumber]
_, _, _, err := findAccessKey(ip, serverUDPBufferSize, packet, cipherList, NewDebugLogger("UDP"))
if err != nil {
b.Error(err)
}
}
}

// Simulates receiving valid UDP packets from 100 different IP addresses,
// all using the same cipher.
func BenchmarkUDPUnpackSharedKey(b *testing.B) {
cipherList, err := MakeTestCiphers(makeTestSecrets(1)) // One widely shared key
if err != nil {
b.Fatal(err)
}
plaintext := makeTestPayload(50)
snapshot := cipherList.SnapshotForClientIP(netip.Addr{})
cryptoKey := snapshot[0].Value.(*CipherEntry).CryptoKey
packet, err := shadowsocks.Pack(make([]byte, serverUDPBufferSize), plaintext, cryptoKey)
require.Nil(b, err)

const numIPs = 100 // Must be <256
ips := [numIPs]netip.Addr{}
for i := 0; i < numIPs; i++ {
ips[i] = netip.AddrFrom4([4]byte{192, 0, 2, byte(i)})
}
b.ResetTimer()
for n := 0; n < b.N; n++ {
ip := ips[n%numIPs]
_, _, _, err := findAccessKey(ip, serverUDPBufferSize, packet, cipherList, NewDebugLogger("UDP"))
if err != nil {
b.Error(err)
}
}
}
33 changes: 20 additions & 13 deletions service/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
package service

import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"net/netip"
Expand All @@ -29,7 +31,6 @@ import (
"github.com/Jigsaw-Code/outline-ss-server/ipinfo"
onet "github.com/Jigsaw-Code/outline-ss-server/net"
"github.com/Jigsaw-Code/outline-ss-server/service/metrics"
logging "github.com/op/go-logging"
"github.com/shadowsocks/go-shadowsocks2/socks"
)

Expand Down Expand Up @@ -59,15 +60,6 @@ func remoteIP(conn net.Conn) netip.Addr {
return netip.Addr{}
}

// Wrapper for logger.Debugf during TCP access key searches.
func debugTCP(cipherID, template string, val interface{}) {
// This is an optimization to reduce unnecessary allocations due to an interaction
// between Go's inlining/escape analysis and varargs functions like logger.Debugf.
if logger.IsEnabledFor(logging.DEBUG) {
logger.Debugf("TCP(%s): "+template, cipherID, val)
}
}

type StreamAuthenticateFunc func(clientConn transport.StreamConn) (string, transport.StreamConn, *onet.ConnectionError)

// ShadowsocksTCPMetrics is used to report Shadowsocks metrics on TCP connections.
Expand All @@ -76,23 +68,37 @@ type ShadowsocksTCPMetrics interface {
AddTCPCipherSearch(accessKeyFound bool, timeToCipher time.Duration)
}

// bytesForKeyFinding is the number of bytes to read for finding the AccessKey.
// Is must satisfy provided >= bytesForKeyFinding >= required for every cipher in the list.
// provided = saltSize + 2 + 2 * cipher.TagSize, the minimum number of bytes we will see in a valid connection
// required = saltSize + 2 + cipher.TagSize, the number of bytes needed to authenticate the connection.
const bytesForKeyFinding = 50

// NewShadowsocksStreamAuthenticator creates a stream authenticator that uses Shadowsocks.
// TODO(fortuna): Offer alternative transports.
func NewShadowsocksStreamAuthenticator(ciphers CipherList, replayCache *ReplayCache, metrics ShadowsocksTCPMetrics) StreamAuthenticateFunc {
return func(clientConn transport.StreamConn) (string, transport.StreamConn, *onet.ConnectionError) {
firstBytes := make([]byte, bytesForKeyFinding)
if n, err := io.ReadFull(clientConn, firstBytes); err != nil {
metrics.AddTCPCipherSearch(false, 0)
return "", clientConn, onet.NewConnectionError("ERR_CIPHER", fmt.Sprintf("Reading header failed after %d bytes", n), err)
}

// Find the cipher and acess key id.
cipherEntry, clientReader, clientSalt, timeToCipher, keyErr := findAccessKey(clientConn, remoteIP(clientConn), 2, ciphers)
bufferSize := 2
cipherEntry, _, timeToCipher, keyErr := findAccessKey(remoteIP(clientConn), bufferSize, firstBytes, ciphers, NewDebugLogger("TCP"))
metrics.AddTCPCipherSearch(keyErr == nil, timeToCipher)
if keyErr != nil {
const status = "ERR_CIPHER"
return "", nil, onet.NewConnectionError(status, "Failed to find a valid cipher", keyErr)
// TODO: Ban and log client IPs with too many failures too quick to protect against DoS.
return "", clientConn, onet.NewConnectionError("ERR_CIPHER", "Failed to find a valid cipher", keyErr)
}
var id string
if cipherEntry != nil {
id = cipherEntry.ID
}

// Check if the connection is a replay.
clientSalt := firstBytes[:cipherEntry.CryptoKey.SaltSize()]
isServerSalt := cipherEntry.SaltGenerator.IsServerSalt(clientSalt)
// Only check the cache if findAccessKey succeeded and the salt is unrecognized.
if isServerSalt || !replayCache.Add(cipherEntry.ID, clientSalt) {
Expand All @@ -105,6 +111,7 @@ func NewShadowsocksStreamAuthenticator(ciphers CipherList, replayCache *ReplayCa
return id, nil, onet.NewConnectionError(status, "Replay detected", nil)
}

clientReader := io.MultiReader(bytes.NewReader(firstBytes), clientConn)
ssr := shadowsocks.NewReader(clientReader, cipherEntry.CryptoKey)
ssw := shadowsocks.NewWriter(clientConn, cipherEntry.CryptoKey)
ssw.SetSaltGenerator(cipherEntry.SaltGenerator)
Expand Down
12 changes: 5 additions & 7 deletions service/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,17 +116,15 @@ func (h *packetHandler) authenticate(clientConn net.PacketConn) (*natconn, []byt
return nil, nil, 0, onet.NewConnectionError("ERR_READ", "Failed to read from client", err)
}

if logger.IsEnabledFor(logging.DEBUG) {
defer logger.Debugf("UDP(%v): done", clientAddr)
logger.Debugf("UDP(%v): Outbound packet has %d bytes", clientAddr, clientProxyBytes)
}
logDebug := NewDebugLogger("UDP")
defer logDebug(clientAddr.String(), "done%s", "")
logDebug(clientAddr.String(), "Outbound packet has %d bytes", clientProxyBytes)

targetConn := h.nm.Get(clientAddr.String())
remoteIP := clientAddr.(*net.UDPAddr).AddrPort().Addr()

unpackStart := time.Now()
cipherEntry, textData, keyErr := findAccessKeyUDP(remoteIP, serverUDPBufferSize, cipherBuf[:clientProxyBytes], h.ciphers)
timeToCipher := time.Since(unpackStart)
cipherEntry, textData, timeToCipher, keyErr := findAccessKey(remoteIP, serverUDPBufferSize, cipherBuf[:clientProxyBytes], h.ciphers, logDebug)
logDebug("test", "textdata: %v\n", textData)
h.m.AddUDPCipherSearch(err == nil, timeToCipher)
if keyErr != nil {
return targetConn, nil, 0, onet.NewConnectionError("ERR_CIPHER", "Failed to find a valid cipher", keyErr)
Expand Down
Loading

0 comments on commit f691d98

Please sign in to comment.