Skip to content

Commit c617913

Browse files
commoddityadshmhOlshansk
authored
[QoS][Cosmos] add minimal websockets request context to allow WS on Cosmos (#375)
## 🌿 Summary Add minimal WebSocket request context validation to enable WebSocket connections for Cosmos services. ### 🌱 Primary Changes: - Implement `validateWebsocketRequest` to handle WebSocket upgrade requests and RPC type validation - Add specialized `buildWebsocketRequestContext` to create WebSocket-specific request contexts - Introduce error handling for unsupported WebSocket RPC types via `createWebsocketUnsupportedRPCTypeContext` ### 🍃 Secondary changes: - Add new file `request_validator_websocket.go` to encapsulate WebSocket validation logic - Include request observation building for WebSocket connections - Update request context structure to accommodate WebSocket-specific fields ## 🛠️ Type of change Select one or more from the following: - [x] New feature, functionality or library - [x] Bug fix - [ ] Code health or cleanup - [ ] Documentation - [ ] Other (specify) --------- Co-authored-by: Arash Deshmeh <[email protected]> Co-authored-by: Daniel Olshansky <[email protected]>
1 parent 8ed80dc commit c617913

File tree

5 files changed

+195
-7
lines changed

5 files changed

+195
-7
lines changed

qos/cosmos/qos.go

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,7 @@ func (qos *QoS) ParseHTTPRequest(_ context.Context, req *http.Request) (gateway.
9090
//
9191
// Implements gateway.QoSService interface.
9292
func (qos *QoS) ParseWebsocketRequest(_ context.Context) (gateway.RequestQoSContext, bool) {
93-
return &requestContext{
94-
logger: qos.logger,
95-
serviceState: qos.serviceState,
96-
}, true
93+
return qos.validateWebsocketRequest()
9794
}
9895

9996
// HydrateDisqualifiedEndpointsResponse hydrates the disqualified endpoint response with the QoS-specific data.

qos/cosmos/request_validator_jsonrpc.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package cosmos
33
import (
44
"encoding/json"
55
"errors"
6+
"fmt"
67
"net/http"
78

89
"github.com/pokt-network/poktroll/pkg/polylog"
@@ -48,7 +49,7 @@ func (rv *requestValidator) validateJSONRPCRequest(
4849

4950
// Hydrate the logger with data extracted from the request.
5051
logger = logger.With(
51-
"detected_rpc_type", rpcType.String(),
52+
"rpc_type", rpcType.String(),
5253
"jsonrpc_method", method,
5354
)
5455

@@ -281,7 +282,7 @@ func (rv *requestValidator) createJSONRPCUnsupportedRPCTypeObservation(
281282
},
282283
RequestLevelError: &qosobservations.RequestError{
283284
ErrorKind: qosobservations.RequestErrorKind_REQUEST_ERROR_USER_ERROR_JSONRPC_UNSUPPORTED_RPC_TYPE,
284-
ErrorDetails: "Unsupported RPC type: " + rpcType.String(),
285+
ErrorDetails: fmt.Sprintf("Unsupported RPC type %s for service %s", rpcType.String(), string(rv.serviceID)),
285286
HttpStatusCode: int32(jsonrpcResponse.GetRecommendedHTTPStatusCode()),
286287
},
287288
}

qos/cosmos/request_validator_rest.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ func (rv *requestValidator) validateRESTRequest(
3636
rpcType := determineRESTRPCType(httpRequestPath)
3737

3838
logger = logger.With(
39-
"detected_rpc_type", rpcType.String(),
39+
"rpc_type", rpcType.String(),
4040
"request_path", httpRequestPath,
4141
)
4242

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
package cosmos
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
7+
sharedtypes "github.com/pokt-network/poktroll/x/shared/types"
8+
9+
"github.com/buildwithgrove/path/gateway"
10+
qosobservations "github.com/buildwithgrove/path/observation/qos"
11+
"github.com/buildwithgrove/path/qos"
12+
"github.com/buildwithgrove/path/qos/jsonrpc"
13+
)
14+
15+
// TODO_IMPROVE(@commoddity): Add endpoint-level QoS checks to determine WebSocket support.
16+
// Currently validates WebSocket upgrade requests at the service level only.
17+
18+
// validateWebsocketRequest validates WebSocket upgrade requests for Cosmos SDK services.
19+
// Returns (requestContext, true) if WebSocket is supported.
20+
// Returns (errorContext, false) if WebSocket is not configured for this service.
21+
func (rv *requestValidator) validateWebsocketRequest() (gateway.RequestQoSContext, bool) {
22+
logger := rv.logger.With(
23+
"validator", "WebSocket",
24+
"method", "validateWebsocketRequest",
25+
)
26+
rpcType := sharedtypes.RPCType_WEBSOCKET
27+
logger = logger.With("rpc_type", rpcType.String())
28+
29+
// Verify WebSocket support in service configuration
30+
if _, supported := rv.supportedAPIs[sharedtypes.RPCType_WEBSOCKET]; !supported {
31+
logger.Warn().Msg("Request uses unsupported WebSocket RPC type")
32+
return rv.createWebsocketUnsupportedRPCTypeContext(rpcType), false
33+
}
34+
35+
// Build and return the request context
36+
return rv.buildWebsocketRequestContext(
37+
rpcType,
38+
qosobservations.RequestOrigin_REQUEST_ORIGIN_ORGANIC,
39+
), true
40+
}
41+
42+
// buildWebsocketRequestContext builds a request context for WebSocket upgrade requests.
43+
func (rv *requestValidator) buildWebsocketRequestContext(
44+
rpcType sharedtypes.RPCType,
45+
requestOrigin qosobservations.RequestOrigin,
46+
) gateway.RequestQoSContext {
47+
logger := rv.logger.With(
48+
"method", "buildWebsocketRequestContext",
49+
)
50+
requestObservation := rv.buildWebsocketRequestObservations(
51+
rpcType,
52+
requestOrigin,
53+
)
54+
return &requestContext{
55+
logger: logger,
56+
serviceState: rv.serviceState,
57+
observations: requestObservation,
58+
protocolErrorObservationBuilder: buildProtocolErrorObservation,
59+
}
60+
}
61+
62+
// buildWebsocketRequestObservations builds a request observation for WebSocket upgrade requests.
63+
func (rv *requestValidator) buildWebsocketRequestObservations(
64+
rpcType sharedtypes.RPCType,
65+
requestOrigin qosobservations.RequestOrigin,
66+
) *qosobservations.CosmosRequestObservations {
67+
68+
return &qosobservations.CosmosRequestObservations{
69+
CosmosChainId: rv.cosmosChainID,
70+
ServiceId: string(rv.serviceID),
71+
RequestOrigin: requestOrigin,
72+
RequestProfile: &qosobservations.CosmosRequestProfile{
73+
BackendServiceDetails: &qosobservations.BackendServiceDetails{
74+
BackendServiceType: convertToProtoBackendServiceType(rpcType),
75+
SelectionReason: "WebSocket upgrade request detection",
76+
},
77+
},
78+
}
79+
}
80+
81+
// createWebsocketUnsupportedRPCTypeContext creates error context when WebSocket is not configured
82+
func (rv *requestValidator) createWebsocketUnsupportedRPCTypeContext(
83+
rpcType sharedtypes.RPCType,
84+
) gateway.RequestQoSContext {
85+
err := errors.New("WebSocket not supported for this service")
86+
response := jsonrpc.NewErrResponseInvalidRequest(jsonrpc.ID{}, err)
87+
88+
observations := rv.createWebsocketUnsupportedRPCTypeObservation(rpcType, response)
89+
90+
return &qos.RequestErrorContext{
91+
Logger: rv.logger,
92+
Response: response,
93+
Observations: &qosobservations.Observations{
94+
ServiceObservations: &qosobservations.Observations_Cosmos{
95+
Cosmos: observations,
96+
},
97+
},
98+
}
99+
}
100+
101+
// createWebsocketUnsupportedRPCTypeObservation creates an observation for unsupported WebSocket requests
102+
func (rv *requestValidator) createWebsocketUnsupportedRPCTypeObservation(
103+
rpcType sharedtypes.RPCType,
104+
jsonrpcResponse jsonrpc.Response,
105+
) *qosobservations.CosmosRequestObservations {
106+
return &qosobservations.CosmosRequestObservations{
107+
ServiceId: string(rv.serviceID),
108+
CosmosChainId: rv.cosmosChainID,
109+
RequestOrigin: qosobservations.RequestOrigin_REQUEST_ORIGIN_ORGANIC,
110+
RequestProfile: &qosobservations.CosmosRequestProfile{
111+
BackendServiceDetails: &qosobservations.BackendServiceDetails{
112+
BackendServiceType: convertToProtoBackendServiceType(rpcType),
113+
SelectionReason: "WebSocket upgrade request detection (unsupported)",
114+
},
115+
},
116+
RequestLevelError: &qosobservations.RequestError{
117+
ErrorKind: qosobservations.RequestErrorKind_REQUEST_ERROR_USER_ERROR_JSONRPC_UNSUPPORTED_RPC_TYPE,
118+
ErrorDetails: fmt.Sprintf("Unsupported RPC type %s for service %s", rpcType.String(), string(rv.serviceID)),
119+
HttpStatusCode: int32(jsonrpcResponse.GetRecommendedHTTPStatusCode()),
120+
},
121+
}
122+
}
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
package cosmos
2+
3+
import (
4+
"testing"
5+
6+
"github.com/pokt-network/poktroll/pkg/polylog/polyzero"
7+
sharedtypes "github.com/pokt-network/poktroll/x/shared/types"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func Test_validateWebsocketRequest(t *testing.T) {
12+
tests := []struct {
13+
name string
14+
supportedAPIs map[sharedtypes.RPCType]struct{}
15+
expectSuccess bool
16+
expectErrorContextType bool
17+
}{
18+
{
19+
name: "supported websockets service config",
20+
supportedAPIs: map[sharedtypes.RPCType]struct{}{
21+
sharedtypes.RPCType_WEBSOCKET: {},
22+
},
23+
expectSuccess: true,
24+
expectErrorContextType: false,
25+
},
26+
{
27+
name: "unsupported websockets service config",
28+
supportedAPIs: map[sharedtypes.RPCType]struct{}{
29+
sharedtypes.RPCType_REST: {},
30+
},
31+
expectSuccess: false,
32+
expectErrorContextType: true,
33+
},
34+
}
35+
36+
for _, tt := range tests {
37+
t.Run(tt.name, func(t *testing.T) {
38+
// Set up the request validator with test data
39+
validator := &requestValidator{
40+
logger: polyzero.NewLogger(),
41+
cosmosChainID: "test-chain",
42+
serviceID: "test-service",
43+
supportedAPIs: tt.supportedAPIs,
44+
serviceState: &serviceState{}, // minimal setup for context building
45+
}
46+
47+
// Call the function under test
48+
ctx, success := validator.validateWebsocketRequest()
49+
50+
// Verify the success/failure expectation
51+
require.Equal(t, tt.expectSuccess, success, "validateWebsocketRequest success result mismatch")
52+
53+
// Verify context is not nil
54+
require.NotNil(t, ctx, "returned context should not be nil")
55+
56+
// Additional verification based on expected result type
57+
if tt.expectErrorContextType {
58+
// For unsupported case, we expect some kind of error context
59+
// We can't easily type assert the exact error context type without more imports,
60+
// but we can verify it's not nil and success is false
61+
require.False(t, success, "should return false for unsupported WebSocket")
62+
} else {
63+
// For supported case, we expect success
64+
require.True(t, success, "should return true for supported WebSocket")
65+
}
66+
})
67+
}
68+
}

0 commit comments

Comments
 (0)