Skip to content

Commit

Permalink
Fix internal WebSocket server's handling of interleaved control frames
Browse files Browse the repository at this point in the history
In the previous implementation, if ping or pong frames were sent between
text frame and continuation frame, the WebSocket encoder didn't handle
the fragmented message correctly (e.g. text -> ping -> continuation,
      text -> continuation -> pong -> continuation)
In this case, the parts of the message before the ping or pong frames
were lost, so only the following parts were passed to the
HttpServer::Delegate.

To solve this problem, this CL switches the method of processing
messages in WebSocketEncoder::DecodeFrame() based on the type of
frame. For text and continuation frames, the message is buffered until
the end of the message, and no longer cleared if ping or pong frames
are received. On the other hand, for ping frames, the contents are
passed back directly and not buffered.

To check this implementation works correctly, add tests which send text
or ping frames in various orders.

This is a copy of
https://chromium-review.googlesource.com/c/chromium/src/+/3209721 by
Shiho Noda.

Bug: 1226710
Change-Id: I4ec52e72866e2ce534d654233babc0e07d886622
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3220895
Reviewed-by: Yoichi Osato <[email protected]>
Commit-Queue: Adam Rice <[email protected]>
Cr-Commit-Position: refs/heads/main@{#935349}
  • Loading branch information
ricea authored and Chromium LUCI CQ committed Oct 27, 2021
1 parent 8d3417e commit 16be071
Show file tree
Hide file tree
Showing 2 changed files with 300 additions and 33 deletions.
295 changes: 276 additions & 19 deletions net/server/http_server_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -322,21 +322,26 @@ class WebSocketAcceptingTest : public WebSocketTest {

void OnWebSocketMessage(int connection_id, std::string data) override {
message_ = data;
if (message_.length() > 0 && run_loop_) {
got_message_ = true;
if (run_loop_) {
run_loop_->Quit();
}
}

std::string GetMessage() {
run_loop_ = std::make_unique<base::RunLoop>();
run_loop_->Run();
run_loop_.reset();
const std::string& GetMessage() {
if (!got_message_) {
run_loop_ = std::make_unique<base::RunLoop>();
run_loop_->Run();
run_loop_.reset();
}
got_message_ = false;
return message_;
}

private:
std::string message_;
std::unique_ptr<base::RunLoop> run_loop_;
bool got_message_ = false;
};

std::string EncodeFrame(std::string message,
Expand All @@ -354,7 +359,7 @@ std::string EncodeFrame(std::string message,
WebSocketMaskingKey masking_key = GenerateWebSocketMaskingKey();
WriteWebSocketFrameHeader(header, &masking_key, &frame_header[0],
header_size);
MaskWebSocketFramePayload(masking_key, 0, &message[0], message.length());
MaskWebSocketFramePayload(masking_key, 0, &message[0], message.size());
} else {
WriteWebSocketFrameHeader(header, nullptr, &frame_header[0], header_size);
}
Expand Down Expand Up @@ -622,24 +627,276 @@ TEST_F(WebSocketAcceptingTest, SendLongTextFrame) {
"Sec-WebSocket-Key: key\r\n\r\n");
RunUntilRequestsReceived(1);
ASSERT_TRUE(client.ReadResponse(&response));
constexpr int kMessageSize = 100000;
const std::string text_message(kMessageSize, 'a');
const std::string continuation_message(kMessageSize, 'b');
const std::string text_frame =
EncodeFrame(text_message, WebSocketFrameHeader::OpCodeEnum::kOpCodeText,
constexpr int kFrameSize = 100000;
const std::string text_frame(kFrameSize, 'a');
const std::string continuation_frame(kFrameSize, 'b');
const std::string text_encoded_frame =
EncodeFrame(text_frame, WebSocketFrameHeader::OpCodeEnum::kOpCodeText,
/* mask= */ true,
/* finish= */ false);
const std::string continuation_frame =
EncodeFrame(continuation_message,
const std::string continuation_encoded_frame = EncodeFrame(
continuation_frame, WebSocketFrameHeader::OpCodeEnum::kOpCodeContinuation,
/* mask= */ true, /* finish= */ true);
client.Send(text_encoded_frame);
client.Send(continuation_encoded_frame);
std::string received_message = GetMessage();
EXPECT_EQ(received_message.size(),
text_frame.size() + continuation_frame.size());
EXPECT_EQ(received_message, text_frame + continuation_frame);
}

TEST_F(WebSocketAcceptingTest, SendTwoTextFrame) {
TestHttpClient client;
CreateConnection(&client);
std::string response;
client.Send(
"GET /test HTTP/1.1\r\n"
"Upgrade: WebSocket\r\n"
"Connection: SomethingElse, Upgrade\r\n"
"Sec-WebSocket-Version: 8\r\n"
"Sec-WebSocket-Key: key\r\n\r\n");
RunUntilRequestsReceived(1);
ASSERT_TRUE(client.ReadResponse(&response));
const std::string text_frame_first = "foo";
const std::string continuation_frame_first = "bar";
const std::string text_encoded_frame_first = EncodeFrame(
text_frame_first, WebSocketFrameHeader::OpCodeEnum::kOpCodeText,
/* mask= */ true,
/* finish= */ false);
const std::string continuation_encoded_frame_first =
EncodeFrame(continuation_frame_first,
WebSocketFrameHeader::OpCodeEnum::kOpCodeContinuation,
/* mask= */ true, /* finish= */ true);

const std::string text_frame_second = "FOO";
const std::string continuation_frame_second = "BAR";
const std::string text_encoded_frame_second = EncodeFrame(
text_frame_second, WebSocketFrameHeader::OpCodeEnum::kOpCodeText,
/* mask= */ true,
/* finish= */ false);
const std::string continuation_encoded_frame_second =
EncodeFrame(continuation_frame_second,
WebSocketFrameHeader::OpCodeEnum::kOpCodeContinuation,
/* mask= */ true, /* finish= */ true);
client.Send(text_frame);
client.Send(continuation_frame);

// text_encoded_frame_first -> text_encoded_frame_second
client.Send(text_encoded_frame_first);
client.Send(continuation_encoded_frame_first);
std::string received_message = GetMessage();
EXPECT_EQ(
static_cast<int>(received_message.length()),
static_cast<int>(text_message.length() + continuation_message.length()));
EXPECT_EQ(received_message, text_message + continuation_message);
EXPECT_EQ(received_message, "foobar");
client.Send(text_encoded_frame_second);
client.Send(continuation_encoded_frame_second);
received_message = GetMessage();
EXPECT_EQ(received_message, "FOOBAR");
}

TEST_F(WebSocketAcceptingTest, SendPingPongFrame) {
TestHttpClient client;
CreateConnection(&client);
std::string response;
client.Send(
"GET /test HTTP/1.1\r\n"
"Upgrade: WebSocket\r\n"
"Connection: SomethingElse, Upgrade\r\n"
"Sec-WebSocket-Version: 8\r\n"
"Sec-WebSocket-Key: key\r\n\r\n");
RunUntilRequestsReceived(1);
ASSERT_TRUE(client.ReadResponse(&response));

const std::string ping_message_first = "";
const std::string ping_frame_first = EncodeFrame(
ping_message_first, WebSocketFrameHeader::OpCodeEnum::kOpCodePing,
/* mask= */ true, /* finish= */ true);
const std::string pong_frame_receive_first = EncodeFrame(
ping_message_first, WebSocketFrameHeader::OpCodeEnum::kOpCodePong,
/* mask= */ false, /* finish= */ true);
const std::string pong_frame_send = EncodeFrame(
/* message= */ "", WebSocketFrameHeader::OpCodeEnum::kOpCodePong,
/* mask= */ true, /* finish= */ true);
const std::string ping_message_second = "hello";
const std::string ping_frame_second = EncodeFrame(
ping_message_second, WebSocketFrameHeader::OpCodeEnum::kOpCodePing,
/* mask= */ true, /* finish= */ true);
const std::string pong_frame_receive_second = EncodeFrame(
ping_message_second, WebSocketFrameHeader::OpCodeEnum::kOpCodePong,
/* mask= */ false, /* finish= */ true);

// ping_frame_first -> pong_frame_send -> ping_frame_second
client.Send(ping_frame_first);
ASSERT_TRUE(client.Read(&response, pong_frame_receive_first.length()));
EXPECT_EQ(response, pong_frame_receive_first);
client.Send(pong_frame_send);
client.Send(ping_frame_second);
ASSERT_TRUE(client.Read(&response, pong_frame_receive_second.length()));
EXPECT_EQ(response, pong_frame_receive_second);
}

TEST_F(WebSocketAcceptingTest, SendTextAndPingFrame) {
TestHttpClient client;
CreateConnection(&client);
std::string response;
client.Send(
"GET /test HTTP/1.1\r\n"
"Upgrade: WebSocket\r\n"
"Connection: SomethingElse, Upgrade\r\n"
"Sec-WebSocket-Version: 8\r\n"
"Sec-WebSocket-Key: key\r\n\r\n");
RunUntilRequestsReceived(1);
ASSERT_TRUE(client.ReadResponse(&response));

const std::string text_frame = "foo";
const std::string continuation_frame = "bar";
const std::string text_encoded_frame =
EncodeFrame(text_frame, WebSocketFrameHeader::OpCodeEnum::kOpCodeText,
/* mask= */ true,
/* finish= */ false);
const std::string continuation_encoded_frame = EncodeFrame(
continuation_frame, WebSocketFrameHeader::OpCodeEnum::kOpCodeContinuation,
/* mask= */ true, /* finish= */ true);
const std::string ping_message = "ping";
const std::string ping_frame =
EncodeFrame(ping_message, WebSocketFrameHeader::OpCodeEnum::kOpCodePing,
/* mask= */ true, /* finish= */ true);
const std::string pong_frame =
EncodeFrame(ping_message, WebSocketFrameHeader::OpCodeEnum::kOpCodePong,
/* mask= */ false, /* finish= */ true);

// text_encoded_frame -> ping_frame -> continuation_encoded_frame
client.Send(text_encoded_frame);
client.Send(ping_frame);
client.Send(continuation_encoded_frame);
ASSERT_TRUE(client.Read(&response, pong_frame.length()));
EXPECT_EQ(response, pong_frame);
std::string received_message = GetMessage();
EXPECT_EQ(received_message, "foobar");
}

TEST_F(WebSocketAcceptingTest, SendTextAndPingFrameWithMessage) {
TestHttpClient client;
CreateConnection(&client);
std::string response;
client.Send(
"GET /test HTTP/1.1\r\n"
"Upgrade: WebSocket\r\n"
"Connection: SomethingElse, Upgrade\r\n"
"Sec-WebSocket-Version: 8\r\n"
"Sec-WebSocket-Key: key\r\n\r\n");
RunUntilRequestsReceived(1);
ASSERT_TRUE(client.ReadResponse(&response));

const std::string text_frame = "foo";
const std::string continuation_frame = "bar";
const std::string text_encoded_frame =
EncodeFrame(text_frame, WebSocketFrameHeader::OpCodeEnum::kOpCodeText,
/* mask= */ true,
/* finish= */ false);
const std::string continuation_encoded_frame = EncodeFrame(
continuation_frame, WebSocketFrameHeader::OpCodeEnum::kOpCodeContinuation,
/* mask= */ true, /* finish= */ true);
const std::string ping_message = "hello";
const std::string ping_frame =
EncodeFrame(ping_message, WebSocketFrameHeader::OpCodeEnum::kOpCodePing,
/* mask= */ true, /* finish= */ true);
const std::string pong_frame =
EncodeFrame(ping_message, WebSocketFrameHeader::OpCodeEnum::kOpCodePong,
/* mask= */ false, /* finish= */ true);

// text_encoded_frame -> ping_frame -> continuation_frame
client.Send(text_encoded_frame);
client.Send(ping_frame);
client.Send(continuation_encoded_frame);
ASSERT_TRUE(client.Read(&response, pong_frame.length()));
EXPECT_EQ(response, pong_frame);
std::string received_message = GetMessage();
EXPECT_EQ(received_message, "foobar");
}

TEST_F(WebSocketAcceptingTest, SendTextAndPongFrame) {
TestHttpClient client;
CreateConnection(&client);
std::string response;
client.Send(
"GET /test HTTP/1.1\r\n"
"Upgrade: WebSocket\r\n"
"Connection: SomethingElse, Upgrade\r\n"
"Sec-WebSocket-Version: 8\r\n"
"Sec-WebSocket-Key: key\r\n\r\n");
RunUntilRequestsReceived(1);
ASSERT_TRUE(client.ReadResponse(&response));

const std::string text_frame = "foo";
const std::string continuation_frame = "bar";
const std::string text_encoded_frame =
EncodeFrame(text_frame, WebSocketFrameHeader::OpCodeEnum::kOpCodeText,
/* mask= */ true,
/* finish= */ false);
const std::string continuation_encoded_frame = EncodeFrame(
continuation_frame, WebSocketFrameHeader::OpCodeEnum::kOpCodeContinuation,
/* mask= */ true, /* finish= */ true);
const std::string pong_message = "pong";
const std::string pong_frame =
EncodeFrame(pong_message, WebSocketFrameHeader::OpCodeEnum::kOpCodePong,
/* mask= */ true, /* finish= */ true);

// text_encoded_frame -> pong_frame -> continuation_encoded_frame
client.Send(text_encoded_frame);
client.Send(pong_frame);
client.Send(continuation_encoded_frame);
std::string received_message = GetMessage();
EXPECT_EQ(received_message, "foobar");
}

TEST_F(WebSocketAcceptingTest, SendTextPingPongFrame) {
TestHttpClient client;
CreateConnection(&client);
std::string response;
client.Send(
"GET /test HTTP/1.1\r\n"
"Upgrade: WebSocket\r\n"
"Connection: SomethingElse, Upgrade\r\n"
"Sec-WebSocket-Version: 8\r\n"
"Sec-WebSocket-Key: key\r\n\r\n");
RunUntilRequestsReceived(1);
ASSERT_TRUE(client.ReadResponse(&response));

const std::string text_frame = "foo";
const std::string continuation_frame = "bar";
const std::string text_encoded_frame =
EncodeFrame(text_frame, WebSocketFrameHeader::OpCodeEnum::kOpCodeText,
/* mask= */ true,
/* finish= */ false);
const std::string continuation_encoded_frame = EncodeFrame(
continuation_frame, WebSocketFrameHeader::OpCodeEnum::kOpCodeContinuation,
/* mask= */ true, /* finish= */ true);

const std::string ping_message_first = "hello";
const std::string ping_frame_first = EncodeFrame(
ping_message_first, WebSocketFrameHeader::OpCodeEnum::kOpCodePing,
/* mask= */ true, /* finish= */ true);
const std::string pong_frame_first = EncodeFrame(
ping_message_first, WebSocketFrameHeader::OpCodeEnum::kOpCodePong,
/* mask= */ false, /* finish= */ true);

const std::string ping_message_second = "HELLO";
const std::string ping_frame_second = EncodeFrame(
ping_message_second, WebSocketFrameHeader::OpCodeEnum::kOpCodePing,
/* mask= */ true, /* finish= */ true);
const std::string pong_frame_second = EncodeFrame(
ping_message_second, WebSocketFrameHeader::OpCodeEnum::kOpCodePong,
/* mask= */ false, /* finish= */ true);

// text_encoded_frame -> ping_frame_first -> ping_frame_second ->
// continuation_encoded_frame
client.Send(text_encoded_frame);
client.Send(ping_frame_first);
ASSERT_TRUE(client.Read(&response, pong_frame_first.length()));
EXPECT_EQ(response, pong_frame_first);
client.Send(ping_frame_second);
ASSERT_TRUE(client.Read(&response, pong_frame_second.length()));
EXPECT_EQ(response, pong_frame_second);
client.Send(continuation_encoded_frame);
std::string received_message = GetMessage();
EXPECT_EQ(received_message, "foobar");
}

TEST_F(HttpServerTest, RequestWithTooLargeBody) {
Expand Down
38 changes: 24 additions & 14 deletions net/server/web_socket_encoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -301,22 +301,32 @@ WebSocket::ParseResult WebSocketEncoder::DecodeFrame(
std::string current_output;
WebSocket::ParseResult result = DecodeFrameHybi17(
frame, type_ == FOR_SERVER, bytes_consumed, &current_output, &compressed);
if (result == WebSocket::FRAME_OK_FINAL ||
result == WebSocket::FRAME_OK_MIDDLE || result == WebSocket::FRAME_PING) {
if (continuation_message_frames_.empty())
is_current_message_compressed_ = compressed;
continuation_message_frames_.push_back(current_output);
}
if (result == WebSocket::FRAME_OK_FINAL || result == WebSocket::FRAME_PING) {
*output = base::StrCat(continuation_message_frames_);
if (is_current_message_compressed_) {
if (!Inflate(output))
result = WebSocket::FRAME_ERROR;
switch (result) {
case WebSocket::FRAME_OK_FINAL:
case WebSocket::FRAME_OK_MIDDLE: {
if (continuation_message_frames_.empty())
is_current_message_compressed_ = compressed;
continuation_message_frames_.push_back(current_output);

if (result == WebSocket::FRAME_OK_FINAL) {
*output = base::StrCat(continuation_message_frames_);
continuation_message_frames_.clear();
if (is_current_message_compressed_ && !Inflate(output)) {
return WebSocket::FRAME_ERROR;
}
}
break;
}

case WebSocket::FRAME_PING:
*output = current_output;
break;

default:
// This function doesn't need special handling for other parse results.
break;
}
if (result != WebSocket::FRAME_OK_MIDDLE &&
result != WebSocket::FRAME_INCOMPLETE)
continuation_message_frames_.clear();

return result;
}

Expand Down

0 comments on commit 16be071

Please sign in to comment.