Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] [TT-13206] Add support for websocket rate limiting #6630

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from
4 changes: 4 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,10 @@ func (pwl *PortsWhiteList) Decode(value string) error {

// StreamingConfig is for configuring tyk streaming
type StreamingConfig struct {
EnableWebSocketRateLimiting bool `json:"enable_websockets_rate_limiting"`
EnableWebSocketDetailedRecording bool `json:"enable_websockets_detailed_recording"`
EnableWebSocketCloseOnRateLimit bool `json:"enable_websocket_close_on_rate_limit"`

Enabled bool `json:"enabled"`
AllowUnsafe []string `json:"allow_unsafe"`
}
Expand Down
126 changes: 126 additions & 0 deletions gateway/analytics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package gateway

import (
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"

"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"

"github.com/TykTechnologies/tyk-pump/analytics"
Expand Down Expand Up @@ -546,6 +548,130 @@ func TestGeoIPLookup(t *testing.T) {
}
}

func TestWebsocketAnalytics(t *testing.T) {
ts := StartTest(nil)
t.Cleanup(ts.Close)

globalConf := ts.Gw.GetConfig()
globalConf.HttpServerOptions.EnableWebSockets = true
globalConf.Streaming.EnableWebSocketDetailedRecording = true
globalConf.AnalyticsConfig.EnableDetailedRecording = true
ts.Gw.SetConfig(globalConf)

// Create a session with a rate limit of 5 requests per second
session := CreateSession(ts.Gw, func(s *user.SessionState) {
s.Rate = 5
s.Per = 1
s.QuotaMax = -1
})

ts.Gw.BuildAndLoadAPI(func(spec *APISpec) {
spec.Proxy.ListenPath = "/"
spec.UseKeylessAccess = false
})

baseURL := strings.Replace(ts.URL, "http://", "ws://", -1)

// Function to create a new WebSocket connection
dialWS := func() (*websocket.Conn, *http.Response, error) {
headers := http.Header{"Authorization": {session}}
return websocket.DefaultDialer.Dial(baseURL+"/ws", headers)
}

// Cleanup before the test
ts.Gw.Analytics.Store.GetAndDeleteSet(analyticsKeyName)

// Connect and send messages
conn, _, err := dialWS()
if err != nil {
t.Fatalf("cannot make websocket connection: %v", err)
}

// Send and receive 3 messages
for i := 0; i < 3; i++ {
err = conn.WriteMessage(websocket.TextMessage, []byte(fmt.Sprintf("test message %d", i+1)))
if err != nil {
t.Fatalf("cannot write message: %v", err)
}

_, _, err := conn.ReadMessage()
if err != nil {
t.Fatalf("cannot read message: %v", err)
}
}

conn.Close()

time.Sleep(100 * time.Millisecond)

// Flush analytics
ts.Gw.Analytics.Flush()

time.Sleep(100 * time.Millisecond)

// Retrieve analytics records
analyticsRecords := ts.Gw.Analytics.Store.GetAndDeleteSet(analyticsKeyName)

// We expect 1 record for the initial handshake, and 3 records each for requests and responses
expectedRecords := 7
if len(analyticsRecords) != expectedRecords {
t.Errorf("Expected %d analytics records, got %d", expectedRecords, len(analyticsRecords))
}

// Verify the content of the analytics records
var handshakeFound bool
var requestCount, responseCount int

for _, record := range analyticsRecords {
var analyticRecord analytics.AnalyticsRecord
err := ts.Gw.Analytics.analyticsSerializer.Decode([]byte(record.(string)), &analyticRecord)
if err != nil {
t.Errorf("Error decoding analytics record: %v", err)
continue
}

// Check for handshake record
if analyticRecord.Path == "/ws" && analyticRecord.Method == "GET" {
handshakeFound = true
}

// Check for WebSocket message records
if strings.Contains(analyticRecord.Path, "/ws") {
if strings.HasSuffix(analyticRecord.Path, "/in") {
requestCount++
if analyticRecord.RawRequest == "" {
t.Errorf("Request body is empty for request record: %+v", analyticRecord)
}
rawResponse, _ := base64.StdEncoding.DecodeString(analyticRecord.RawResponse)
if !strings.Contains(string(rawResponse), "Content-Length: 0") {
t.Errorf("Response should contain Content-Length header for request record. Got: %s", rawResponse)
}
} else if strings.HasSuffix(analyticRecord.Path, "/out") {
responseCount++
if analyticRecord.RawResponse == "" {
t.Errorf("Response body is empty for response record: %+v", analyticRecord)
}
rawRequest, _ := base64.StdEncoding.DecodeString(analyticRecord.RawRequest)
if strings.Contains(string(rawRequest), "Content-Length") {
t.Errorf("Request should notcontain Content-Length header for response record. Got: %s", rawRequest)
}
}
}
}

if !handshakeFound {
t.Error("Handshake record not found in analytics")
}

if requestCount != 3 {
t.Errorf("Expected 3 WebSocket request records, got %d", requestCount)
}

if responseCount != 3 {
t.Errorf("Expected 3 WebSocket response records, got %d", responseCount)
}
}

func TestURLReplacer(t *testing.T) {

ts := StartTest(func(globalConf *config.Config) {
Expand Down
86 changes: 86 additions & 0 deletions gateway/gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1573,6 +1573,92 @@ func TestWebsocketsAndHTTPEndpointMatch(t *testing.T) {
})
}

func TestWebsocketsWithRateLimit(t *testing.T) {
ts := StartTest(nil)
t.Cleanup(ts.Close)

globalConf := ts.Gw.GetConfig()
globalConf.HttpServerOptions.EnableWebSockets = true
globalConf.Streaming.EnableWebSocketRateLimiting = true
ts.Gw.SetConfig(globalConf)

// Create a session with a rate limit of 5 requests per 5 seconds
session := CreateSession(ts.Gw, func(s *user.SessionState) {
s.Rate = 5
s.Per = 1
s.QuotaMax = -1
})

ts.Gw.BuildAndLoadAPI(func(spec *APISpec) {
spec.Proxy.ListenPath = "/"
spec.UseKeylessAccess = false
})

baseURL := strings.Replace(ts.URL, "http://", "ws://", -1)

// Function to create a new WebSocket connection
dialWS := func() (*websocket.Conn, *http.Response, error) {
headers := http.Header{"Authorization": {session}}
return websocket.DefaultDialer.Dial(baseURL+"/ws", headers)
}

conn, _, err := dialWS()
if err != nil {
t.Fatalf("cannot make websocket connection: %v", err)
}
defer conn.Close()

// Connect and send messages within rate limit
for i := 0; i < 4; i++ {
err = conn.WriteMessage(websocket.TextMessage, []byte(fmt.Sprintf("test message %d", i+1)))
if err != nil {
t.Fatalf("cannot write message: %v", err)
}
_, p, err := conn.ReadMessage()
if err != nil {
t.Fatalf("cannot read message: %v, %d", err, i)
}
if !strings.Contains(string(p), fmt.Sprintf("test message %d", i+1)) {
t.Errorf("Unexpected reply: %s", string(p))
}
}

err = conn.WriteMessage(websocket.TextMessage, []byte("exceeding rate limit"))
if err != nil {
t.Fatalf("cannot write message: %v", err)
}

_, _, err = conn.ReadMessage()
if err == nil {
t.Errorf("Expected rate limit error, got no error")
} else if !websocket.IsCloseError(err, 4000) {
t.Errorf("Expected rate limit error (code 4000), got: %v", err)
}

// Wait for rate limit to reset
time.Sleep(1 * time.Second)

// Should be able to connect and send message again
conn, _, err = dialWS()
if err != nil {
t.Fatalf("cannot make websocket connection after rate limit reset: %v", err)
}
defer conn.Close()

err = conn.WriteMessage(websocket.TextMessage, []byte("after rate limit reset"))
if err != nil {
t.Fatalf("cannot write message after rate limit reset: %v", err)
}

_, p, err := conn.ReadMessage()
if err != nil {
t.Fatalf("cannot read message after rate limit reset: %v", err)
}
if !strings.Contains(string(p), "after rate limit reset") {
t.Errorf("Unexpected reply after rate limit reset: %s", string(p))
}
}

func createTestUptream(t *testing.T, allowedConns int, readsPerConn int) net.Listener {
l, _ := net.Listen("tcp", "127.0.0.1:0")
go func() {
Expand Down
31 changes: 17 additions & 14 deletions gateway/handler_success.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bytes"
"encoding/base64"
"io"
"io/ioutil"
"net/http"
"strconv"
"strings"
Expand Down Expand Up @@ -163,14 +162,14 @@ func recordGraphDetails(rec *analytics.AnalyticsRecord, r *http.Request, resp *h
}

func (s *SuccessHandler) RecordHit(r *http.Request, timing analytics.Latency, code int, responseCopy *http.Response, cached bool) {
log.Debug("Recording analytics hit")

if s.Spec.DoNotTrack || ctxGetDoNotTrack(r) {
return
}

ip := request.RealIP(r)
if s.Spec.GlobalConfig.StoreAnalytics(ip) {

t := time.Now()

// Track the key ID if it exists
Expand Down Expand Up @@ -224,18 +223,20 @@ func (s *SuccessHandler) RecordHit(r *http.Request, timing analytics.Latency, co
// we need to delete the chunked transfer encoding header to avoid malformed body in our rawResponse
httputil.RemoveResponseTransferEncoding(responseCopy, "chunked")

responseContent, err := io.ReadAll(responseCopy.Body)
if err != nil {
log.Error("Couldn't read response body", err)
if responseCopy.Body != nil {
responseContent, err := io.ReadAll(responseCopy.Body)
if err != nil {
log.Error("Couldn't read response body", err)
} else {
responseCopy.Body = respBodyReader(r, responseCopy)

// Get the wire format representation
var wireFormatRes bytes.Buffer
responseCopy.Body = io.NopCloser(bytes.NewBuffer(responseContent))
responseCopy.Write(&wireFormatRes)
rawResponse = base64.StdEncoding.EncodeToString(wireFormatRes.Bytes())
}
}

responseCopy.Body = respBodyReader(r, responseCopy)

// Get the wire format representation
var wireFormatRes bytes.Buffer
responseCopy.Write(&wireFormatRes)
responseCopy.Body = ioutil.NopCloser(bytes.NewBuffer(responseContent))
rawResponse = base64.StdEncoding.EncodeToString(wireFormatRes.Bytes())
}
}

Expand Down Expand Up @@ -320,6 +321,8 @@ func (s *SuccessHandler) RecordHit(r *http.Request, timing analytics.Latency, co
err := s.Gw.Analytics.RecordHit(&record)
if err != nil {
log.WithError(err).Error("could not store analytic record")
} else {
log.Debug("Succesfully recorded analytics")
}
}

Expand All @@ -329,7 +332,7 @@ func (s *SuccessHandler) RecordHit(r *http.Request, timing analytics.Latency, co

func recordDetail(r *http.Request, spec *APISpec) bool {
// when streaming in grpc, we do not record the request
if httputil.IsStreamingRequest(r) {
if spec.GlobalConfig.Streaming.EnableWebSocketDetailedRecording != true && httputil.IsStreamingRequest(r) {
return false
}

Expand Down
Loading
Loading