From 16be071d29d7d6f526319c281b2d22fc2bba0392 Mon Sep 17 00:00:00 2001 From: Adam Rice Date: Wed, 27 Oct 2021 09:13:20 +0000 Subject: [PATCH] Fix internal WebSocket server's handling of interleaved control frames In the previous implementation, if ping or pong frames were sent between text frame and continuation frame, the WebSocket encoder didn't handle the fragmented message correctly (e.g. text -> ping -> continuation, text -> continuation -> pong -> continuation) In this case, the parts of the message before the ping or pong frames were lost, so only the following parts were passed to the HttpServer::Delegate. To solve this problem, this CL switches the method of processing messages in WebSocketEncoder::DecodeFrame() based on the type of frame. For text and continuation frames, the message is buffered until the end of the message, and no longer cleared if ping or pong frames are received. On the other hand, for ping frames, the contents are passed back directly and not buffered. To check this implementation works correctly, add tests which send text or ping frames in various orders. This is a copy of https://chromium-review.googlesource.com/c/chromium/src/+/3209721 by Shiho Noda. Bug: 1226710 Change-Id: I4ec52e72866e2ce534d654233babc0e07d886622 Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/3220895 Reviewed-by: Yoichi Osato Commit-Queue: Adam Rice Cr-Commit-Position: refs/heads/main@{#935349} --- net/server/http_server_unittest.cc | 295 +++++++++++++++++++++++++++-- net/server/web_socket_encoder.cc | 38 ++-- 2 files changed, 300 insertions(+), 33 deletions(-) diff --git a/net/server/http_server_unittest.cc b/net/server/http_server_unittest.cc index c2a56ef9b3c537..4c59764ee5d2d0 100644 --- a/net/server/http_server_unittest.cc +++ b/net/server/http_server_unittest.cc @@ -322,21 +322,26 @@ class WebSocketAcceptingTest : public WebSocketTest { void OnWebSocketMessage(int connection_id, std::string data) override { message_ = data; - if (message_.length() > 0 && run_loop_) { + got_message_ = true; + if (run_loop_) { run_loop_->Quit(); } } - std::string GetMessage() { - run_loop_ = std::make_unique(); - run_loop_->Run(); - run_loop_.reset(); + const std::string& GetMessage() { + if (!got_message_) { + run_loop_ = std::make_unique(); + run_loop_->Run(); + run_loop_.reset(); + } + got_message_ = false; return message_; } private: std::string message_; std::unique_ptr run_loop_; + bool got_message_ = false; }; std::string EncodeFrame(std::string message, @@ -354,7 +359,7 @@ std::string EncodeFrame(std::string message, WebSocketMaskingKey masking_key = GenerateWebSocketMaskingKey(); WriteWebSocketFrameHeader(header, &masking_key, &frame_header[0], header_size); - MaskWebSocketFramePayload(masking_key, 0, &message[0], message.length()); + MaskWebSocketFramePayload(masking_key, 0, &message[0], message.size()); } else { WriteWebSocketFrameHeader(header, nullptr, &frame_header[0], header_size); } @@ -622,24 +627,276 @@ TEST_F(WebSocketAcceptingTest, SendLongTextFrame) { "Sec-WebSocket-Key: key\r\n\r\n"); RunUntilRequestsReceived(1); ASSERT_TRUE(client.ReadResponse(&response)); - constexpr int kMessageSize = 100000; - const std::string text_message(kMessageSize, 'a'); - const std::string continuation_message(kMessageSize, 'b'); - const std::string text_frame = - EncodeFrame(text_message, WebSocketFrameHeader::OpCodeEnum::kOpCodeText, + constexpr int kFrameSize = 100000; + const std::string text_frame(kFrameSize, 'a'); + const std::string continuation_frame(kFrameSize, 'b'); + const std::string text_encoded_frame = + EncodeFrame(text_frame, WebSocketFrameHeader::OpCodeEnum::kOpCodeText, /* mask= */ true, /* finish= */ false); - const std::string continuation_frame = - EncodeFrame(continuation_message, + const std::string continuation_encoded_frame = EncodeFrame( + continuation_frame, WebSocketFrameHeader::OpCodeEnum::kOpCodeContinuation, + /* mask= */ true, /* finish= */ true); + client.Send(text_encoded_frame); + client.Send(continuation_encoded_frame); + std::string received_message = GetMessage(); + EXPECT_EQ(received_message.size(), + text_frame.size() + continuation_frame.size()); + EXPECT_EQ(received_message, text_frame + continuation_frame); +} + +TEST_F(WebSocketAcceptingTest, SendTwoTextFrame) { + TestHttpClient client; + CreateConnection(&client); + std::string response; + client.Send( + "GET /test HTTP/1.1\r\n" + "Upgrade: WebSocket\r\n" + "Connection: SomethingElse, Upgrade\r\n" + "Sec-WebSocket-Version: 8\r\n" + "Sec-WebSocket-Key: key\r\n\r\n"); + RunUntilRequestsReceived(1); + ASSERT_TRUE(client.ReadResponse(&response)); + const std::string text_frame_first = "foo"; + const std::string continuation_frame_first = "bar"; + const std::string text_encoded_frame_first = EncodeFrame( + text_frame_first, WebSocketFrameHeader::OpCodeEnum::kOpCodeText, + /* mask= */ true, + /* finish= */ false); + const std::string continuation_encoded_frame_first = + EncodeFrame(continuation_frame_first, + WebSocketFrameHeader::OpCodeEnum::kOpCodeContinuation, + /* mask= */ true, /* finish= */ true); + + const std::string text_frame_second = "FOO"; + const std::string continuation_frame_second = "BAR"; + const std::string text_encoded_frame_second = EncodeFrame( + text_frame_second, WebSocketFrameHeader::OpCodeEnum::kOpCodeText, + /* mask= */ true, + /* finish= */ false); + const std::string continuation_encoded_frame_second = + EncodeFrame(continuation_frame_second, WebSocketFrameHeader::OpCodeEnum::kOpCodeContinuation, /* mask= */ true, /* finish= */ true); - client.Send(text_frame); - client.Send(continuation_frame); + + // text_encoded_frame_first -> text_encoded_frame_second + client.Send(text_encoded_frame_first); + client.Send(continuation_encoded_frame_first); std::string received_message = GetMessage(); - EXPECT_EQ( - static_cast(received_message.length()), - static_cast(text_message.length() + continuation_message.length())); - EXPECT_EQ(received_message, text_message + continuation_message); + EXPECT_EQ(received_message, "foobar"); + client.Send(text_encoded_frame_second); + client.Send(continuation_encoded_frame_second); + received_message = GetMessage(); + EXPECT_EQ(received_message, "FOOBAR"); +} + +TEST_F(WebSocketAcceptingTest, SendPingPongFrame) { + TestHttpClient client; + CreateConnection(&client); + std::string response; + client.Send( + "GET /test HTTP/1.1\r\n" + "Upgrade: WebSocket\r\n" + "Connection: SomethingElse, Upgrade\r\n" + "Sec-WebSocket-Version: 8\r\n" + "Sec-WebSocket-Key: key\r\n\r\n"); + RunUntilRequestsReceived(1); + ASSERT_TRUE(client.ReadResponse(&response)); + + const std::string ping_message_first = ""; + const std::string ping_frame_first = EncodeFrame( + ping_message_first, WebSocketFrameHeader::OpCodeEnum::kOpCodePing, + /* mask= */ true, /* finish= */ true); + const std::string pong_frame_receive_first = EncodeFrame( + ping_message_first, WebSocketFrameHeader::OpCodeEnum::kOpCodePong, + /* mask= */ false, /* finish= */ true); + const std::string pong_frame_send = EncodeFrame( + /* message= */ "", WebSocketFrameHeader::OpCodeEnum::kOpCodePong, + /* mask= */ true, /* finish= */ true); + const std::string ping_message_second = "hello"; + const std::string ping_frame_second = EncodeFrame( + ping_message_second, WebSocketFrameHeader::OpCodeEnum::kOpCodePing, + /* mask= */ true, /* finish= */ true); + const std::string pong_frame_receive_second = EncodeFrame( + ping_message_second, WebSocketFrameHeader::OpCodeEnum::kOpCodePong, + /* mask= */ false, /* finish= */ true); + + // ping_frame_first -> pong_frame_send -> ping_frame_second + client.Send(ping_frame_first); + ASSERT_TRUE(client.Read(&response, pong_frame_receive_first.length())); + EXPECT_EQ(response, pong_frame_receive_first); + client.Send(pong_frame_send); + client.Send(ping_frame_second); + ASSERT_TRUE(client.Read(&response, pong_frame_receive_second.length())); + EXPECT_EQ(response, pong_frame_receive_second); +} + +TEST_F(WebSocketAcceptingTest, SendTextAndPingFrame) { + TestHttpClient client; + CreateConnection(&client); + std::string response; + client.Send( + "GET /test HTTP/1.1\r\n" + "Upgrade: WebSocket\r\n" + "Connection: SomethingElse, Upgrade\r\n" + "Sec-WebSocket-Version: 8\r\n" + "Sec-WebSocket-Key: key\r\n\r\n"); + RunUntilRequestsReceived(1); + ASSERT_TRUE(client.ReadResponse(&response)); + + const std::string text_frame = "foo"; + const std::string continuation_frame = "bar"; + const std::string text_encoded_frame = + EncodeFrame(text_frame, WebSocketFrameHeader::OpCodeEnum::kOpCodeText, + /* mask= */ true, + /* finish= */ false); + const std::string continuation_encoded_frame = EncodeFrame( + continuation_frame, WebSocketFrameHeader::OpCodeEnum::kOpCodeContinuation, + /* mask= */ true, /* finish= */ true); + const std::string ping_message = "ping"; + const std::string ping_frame = + EncodeFrame(ping_message, WebSocketFrameHeader::OpCodeEnum::kOpCodePing, + /* mask= */ true, /* finish= */ true); + const std::string pong_frame = + EncodeFrame(ping_message, WebSocketFrameHeader::OpCodeEnum::kOpCodePong, + /* mask= */ false, /* finish= */ true); + + // text_encoded_frame -> ping_frame -> continuation_encoded_frame + client.Send(text_encoded_frame); + client.Send(ping_frame); + client.Send(continuation_encoded_frame); + ASSERT_TRUE(client.Read(&response, pong_frame.length())); + EXPECT_EQ(response, pong_frame); + std::string received_message = GetMessage(); + EXPECT_EQ(received_message, "foobar"); +} + +TEST_F(WebSocketAcceptingTest, SendTextAndPingFrameWithMessage) { + TestHttpClient client; + CreateConnection(&client); + std::string response; + client.Send( + "GET /test HTTP/1.1\r\n" + "Upgrade: WebSocket\r\n" + "Connection: SomethingElse, Upgrade\r\n" + "Sec-WebSocket-Version: 8\r\n" + "Sec-WebSocket-Key: key\r\n\r\n"); + RunUntilRequestsReceived(1); + ASSERT_TRUE(client.ReadResponse(&response)); + + const std::string text_frame = "foo"; + const std::string continuation_frame = "bar"; + const std::string text_encoded_frame = + EncodeFrame(text_frame, WebSocketFrameHeader::OpCodeEnum::kOpCodeText, + /* mask= */ true, + /* finish= */ false); + const std::string continuation_encoded_frame = EncodeFrame( + continuation_frame, WebSocketFrameHeader::OpCodeEnum::kOpCodeContinuation, + /* mask= */ true, /* finish= */ true); + const std::string ping_message = "hello"; + const std::string ping_frame = + EncodeFrame(ping_message, WebSocketFrameHeader::OpCodeEnum::kOpCodePing, + /* mask= */ true, /* finish= */ true); + const std::string pong_frame = + EncodeFrame(ping_message, WebSocketFrameHeader::OpCodeEnum::kOpCodePong, + /* mask= */ false, /* finish= */ true); + + // text_encoded_frame -> ping_frame -> continuation_frame + client.Send(text_encoded_frame); + client.Send(ping_frame); + client.Send(continuation_encoded_frame); + ASSERT_TRUE(client.Read(&response, pong_frame.length())); + EXPECT_EQ(response, pong_frame); + std::string received_message = GetMessage(); + EXPECT_EQ(received_message, "foobar"); +} + +TEST_F(WebSocketAcceptingTest, SendTextAndPongFrame) { + TestHttpClient client; + CreateConnection(&client); + std::string response; + client.Send( + "GET /test HTTP/1.1\r\n" + "Upgrade: WebSocket\r\n" + "Connection: SomethingElse, Upgrade\r\n" + "Sec-WebSocket-Version: 8\r\n" + "Sec-WebSocket-Key: key\r\n\r\n"); + RunUntilRequestsReceived(1); + ASSERT_TRUE(client.ReadResponse(&response)); + + const std::string text_frame = "foo"; + const std::string continuation_frame = "bar"; + const std::string text_encoded_frame = + EncodeFrame(text_frame, WebSocketFrameHeader::OpCodeEnum::kOpCodeText, + /* mask= */ true, + /* finish= */ false); + const std::string continuation_encoded_frame = EncodeFrame( + continuation_frame, WebSocketFrameHeader::OpCodeEnum::kOpCodeContinuation, + /* mask= */ true, /* finish= */ true); + const std::string pong_message = "pong"; + const std::string pong_frame = + EncodeFrame(pong_message, WebSocketFrameHeader::OpCodeEnum::kOpCodePong, + /* mask= */ true, /* finish= */ true); + + // text_encoded_frame -> pong_frame -> continuation_encoded_frame + client.Send(text_encoded_frame); + client.Send(pong_frame); + client.Send(continuation_encoded_frame); + std::string received_message = GetMessage(); + EXPECT_EQ(received_message, "foobar"); +} + +TEST_F(WebSocketAcceptingTest, SendTextPingPongFrame) { + TestHttpClient client; + CreateConnection(&client); + std::string response; + client.Send( + "GET /test HTTP/1.1\r\n" + "Upgrade: WebSocket\r\n" + "Connection: SomethingElse, Upgrade\r\n" + "Sec-WebSocket-Version: 8\r\n" + "Sec-WebSocket-Key: key\r\n\r\n"); + RunUntilRequestsReceived(1); + ASSERT_TRUE(client.ReadResponse(&response)); + + const std::string text_frame = "foo"; + const std::string continuation_frame = "bar"; + const std::string text_encoded_frame = + EncodeFrame(text_frame, WebSocketFrameHeader::OpCodeEnum::kOpCodeText, + /* mask= */ true, + /* finish= */ false); + const std::string continuation_encoded_frame = EncodeFrame( + continuation_frame, WebSocketFrameHeader::OpCodeEnum::kOpCodeContinuation, + /* mask= */ true, /* finish= */ true); + + const std::string ping_message_first = "hello"; + const std::string ping_frame_first = EncodeFrame( + ping_message_first, WebSocketFrameHeader::OpCodeEnum::kOpCodePing, + /* mask= */ true, /* finish= */ true); + const std::string pong_frame_first = EncodeFrame( + ping_message_first, WebSocketFrameHeader::OpCodeEnum::kOpCodePong, + /* mask= */ false, /* finish= */ true); + + const std::string ping_message_second = "HELLO"; + const std::string ping_frame_second = EncodeFrame( + ping_message_second, WebSocketFrameHeader::OpCodeEnum::kOpCodePing, + /* mask= */ true, /* finish= */ true); + const std::string pong_frame_second = EncodeFrame( + ping_message_second, WebSocketFrameHeader::OpCodeEnum::kOpCodePong, + /* mask= */ false, /* finish= */ true); + + // text_encoded_frame -> ping_frame_first -> ping_frame_second -> + // continuation_encoded_frame + client.Send(text_encoded_frame); + client.Send(ping_frame_first); + ASSERT_TRUE(client.Read(&response, pong_frame_first.length())); + EXPECT_EQ(response, pong_frame_first); + client.Send(ping_frame_second); + ASSERT_TRUE(client.Read(&response, pong_frame_second.length())); + EXPECT_EQ(response, pong_frame_second); + client.Send(continuation_encoded_frame); + std::string received_message = GetMessage(); + EXPECT_EQ(received_message, "foobar"); } TEST_F(HttpServerTest, RequestWithTooLargeBody) { diff --git a/net/server/web_socket_encoder.cc b/net/server/web_socket_encoder.cc index 7e5e652786b961..a70e587861a382 100644 --- a/net/server/web_socket_encoder.cc +++ b/net/server/web_socket_encoder.cc @@ -301,22 +301,32 @@ WebSocket::ParseResult WebSocketEncoder::DecodeFrame( std::string current_output; WebSocket::ParseResult result = DecodeFrameHybi17( frame, type_ == FOR_SERVER, bytes_consumed, ¤t_output, &compressed); - if (result == WebSocket::FRAME_OK_FINAL || - result == WebSocket::FRAME_OK_MIDDLE || result == WebSocket::FRAME_PING) { - if (continuation_message_frames_.empty()) - is_current_message_compressed_ = compressed; - continuation_message_frames_.push_back(current_output); - } - if (result == WebSocket::FRAME_OK_FINAL || result == WebSocket::FRAME_PING) { - *output = base::StrCat(continuation_message_frames_); - if (is_current_message_compressed_) { - if (!Inflate(output)) - result = WebSocket::FRAME_ERROR; + switch (result) { + case WebSocket::FRAME_OK_FINAL: + case WebSocket::FRAME_OK_MIDDLE: { + if (continuation_message_frames_.empty()) + is_current_message_compressed_ = compressed; + continuation_message_frames_.push_back(current_output); + + if (result == WebSocket::FRAME_OK_FINAL) { + *output = base::StrCat(continuation_message_frames_); + continuation_message_frames_.clear(); + if (is_current_message_compressed_ && !Inflate(output)) { + return WebSocket::FRAME_ERROR; + } + } + break; } + + case WebSocket::FRAME_PING: + *output = current_output; + break; + + default: + // This function doesn't need special handling for other parse results. + break; } - if (result != WebSocket::FRAME_OK_MIDDLE && - result != WebSocket::FRAME_INCOMPLETE) - continuation_message_frames_.clear(); + return result; }