diff --git a/net/server/http_server_unittest.cc b/net/server/http_server_unittest.cc index c2a56ef9b3c537..4c59764ee5d2d0 100644 --- a/net/server/http_server_unittest.cc +++ b/net/server/http_server_unittest.cc @@ -322,21 +322,26 @@ class WebSocketAcceptingTest : public WebSocketTest { void OnWebSocketMessage(int connection_id, std::string data) override { message_ = data; - if (message_.length() > 0 && run_loop_) { + got_message_ = true; + if (run_loop_) { run_loop_->Quit(); } } - std::string GetMessage() { - run_loop_ = std::make_unique(); - run_loop_->Run(); - run_loop_.reset(); + const std::string& GetMessage() { + if (!got_message_) { + run_loop_ = std::make_unique(); + run_loop_->Run(); + run_loop_.reset(); + } + got_message_ = false; return message_; } private: std::string message_; std::unique_ptr run_loop_; + bool got_message_ = false; }; std::string EncodeFrame(std::string message, @@ -354,7 +359,7 @@ std::string EncodeFrame(std::string message, WebSocketMaskingKey masking_key = GenerateWebSocketMaskingKey(); WriteWebSocketFrameHeader(header, &masking_key, &frame_header[0], header_size); - MaskWebSocketFramePayload(masking_key, 0, &message[0], message.length()); + MaskWebSocketFramePayload(masking_key, 0, &message[0], message.size()); } else { WriteWebSocketFrameHeader(header, nullptr, &frame_header[0], header_size); } @@ -622,24 +627,276 @@ TEST_F(WebSocketAcceptingTest, SendLongTextFrame) { "Sec-WebSocket-Key: key\r\n\r\n"); RunUntilRequestsReceived(1); ASSERT_TRUE(client.ReadResponse(&response)); - constexpr int kMessageSize = 100000; - const std::string text_message(kMessageSize, 'a'); - const std::string continuation_message(kMessageSize, 'b'); - const std::string text_frame = - EncodeFrame(text_message, WebSocketFrameHeader::OpCodeEnum::kOpCodeText, + constexpr int kFrameSize = 100000; + const std::string text_frame(kFrameSize, 'a'); + const std::string continuation_frame(kFrameSize, 'b'); + const std::string text_encoded_frame = + EncodeFrame(text_frame, WebSocketFrameHeader::OpCodeEnum::kOpCodeText, /* mask= */ true, /* finish= */ false); - const std::string continuation_frame = - EncodeFrame(continuation_message, + const std::string continuation_encoded_frame = EncodeFrame( + continuation_frame, WebSocketFrameHeader::OpCodeEnum::kOpCodeContinuation, + /* mask= */ true, /* finish= */ true); + client.Send(text_encoded_frame); + client.Send(continuation_encoded_frame); + std::string received_message = GetMessage(); + EXPECT_EQ(received_message.size(), + text_frame.size() + continuation_frame.size()); + EXPECT_EQ(received_message, text_frame + continuation_frame); +} + +TEST_F(WebSocketAcceptingTest, SendTwoTextFrame) { + TestHttpClient client; + CreateConnection(&client); + std::string response; + client.Send( + "GET /test HTTP/1.1\r\n" + "Upgrade: WebSocket\r\n" + "Connection: SomethingElse, Upgrade\r\n" + "Sec-WebSocket-Version: 8\r\n" + "Sec-WebSocket-Key: key\r\n\r\n"); + RunUntilRequestsReceived(1); + ASSERT_TRUE(client.ReadResponse(&response)); + const std::string text_frame_first = "foo"; + const std::string continuation_frame_first = "bar"; + const std::string text_encoded_frame_first = EncodeFrame( + text_frame_first, WebSocketFrameHeader::OpCodeEnum::kOpCodeText, + /* mask= */ true, + /* finish= */ false); + const std::string continuation_encoded_frame_first = + EncodeFrame(continuation_frame_first, + WebSocketFrameHeader::OpCodeEnum::kOpCodeContinuation, + /* mask= */ true, /* finish= */ true); + + const std::string text_frame_second = "FOO"; + const std::string continuation_frame_second = "BAR"; + const std::string text_encoded_frame_second = EncodeFrame( + text_frame_second, WebSocketFrameHeader::OpCodeEnum::kOpCodeText, + /* mask= */ true, + /* finish= */ false); + const std::string continuation_encoded_frame_second = + EncodeFrame(continuation_frame_second, WebSocketFrameHeader::OpCodeEnum::kOpCodeContinuation, /* mask= */ true, /* finish= */ true); - client.Send(text_frame); - client.Send(continuation_frame); + + // text_encoded_frame_first -> text_encoded_frame_second + client.Send(text_encoded_frame_first); + client.Send(continuation_encoded_frame_first); std::string received_message = GetMessage(); - EXPECT_EQ( - static_cast(received_message.length()), - static_cast(text_message.length() + continuation_message.length())); - EXPECT_EQ(received_message, text_message + continuation_message); + EXPECT_EQ(received_message, "foobar"); + client.Send(text_encoded_frame_second); + client.Send(continuation_encoded_frame_second); + received_message = GetMessage(); + EXPECT_EQ(received_message, "FOOBAR"); +} + +TEST_F(WebSocketAcceptingTest, SendPingPongFrame) { + TestHttpClient client; + CreateConnection(&client); + std::string response; + client.Send( + "GET /test HTTP/1.1\r\n" + "Upgrade: WebSocket\r\n" + "Connection: SomethingElse, Upgrade\r\n" + "Sec-WebSocket-Version: 8\r\n" + "Sec-WebSocket-Key: key\r\n\r\n"); + RunUntilRequestsReceived(1); + ASSERT_TRUE(client.ReadResponse(&response)); + + const std::string ping_message_first = ""; + const std::string ping_frame_first = EncodeFrame( + ping_message_first, WebSocketFrameHeader::OpCodeEnum::kOpCodePing, + /* mask= */ true, /* finish= */ true); + const std::string pong_frame_receive_first = EncodeFrame( + ping_message_first, WebSocketFrameHeader::OpCodeEnum::kOpCodePong, + /* mask= */ false, /* finish= */ true); + const std::string pong_frame_send = EncodeFrame( + /* message= */ "", WebSocketFrameHeader::OpCodeEnum::kOpCodePong, + /* mask= */ true, /* finish= */ true); + const std::string ping_message_second = "hello"; + const std::string ping_frame_second = EncodeFrame( + ping_message_second, WebSocketFrameHeader::OpCodeEnum::kOpCodePing, + /* mask= */ true, /* finish= */ true); + const std::string pong_frame_receive_second = EncodeFrame( + ping_message_second, WebSocketFrameHeader::OpCodeEnum::kOpCodePong, + /* mask= */ false, /* finish= */ true); + + // ping_frame_first -> pong_frame_send -> ping_frame_second + client.Send(ping_frame_first); + ASSERT_TRUE(client.Read(&response, pong_frame_receive_first.length())); + EXPECT_EQ(response, pong_frame_receive_first); + client.Send(pong_frame_send); + client.Send(ping_frame_second); + ASSERT_TRUE(client.Read(&response, pong_frame_receive_second.length())); + EXPECT_EQ(response, pong_frame_receive_second); +} + +TEST_F(WebSocketAcceptingTest, SendTextAndPingFrame) { + TestHttpClient client; + CreateConnection(&client); + std::string response; + client.Send( + "GET /test HTTP/1.1\r\n" + "Upgrade: WebSocket\r\n" + "Connection: SomethingElse, Upgrade\r\n" + "Sec-WebSocket-Version: 8\r\n" + "Sec-WebSocket-Key: key\r\n\r\n"); + RunUntilRequestsReceived(1); + ASSERT_TRUE(client.ReadResponse(&response)); + + const std::string text_frame = "foo"; + const std::string continuation_frame = "bar"; + const std::string text_encoded_frame = + EncodeFrame(text_frame, WebSocketFrameHeader::OpCodeEnum::kOpCodeText, + /* mask= */ true, + /* finish= */ false); + const std::string continuation_encoded_frame = EncodeFrame( + continuation_frame, WebSocketFrameHeader::OpCodeEnum::kOpCodeContinuation, + /* mask= */ true, /* finish= */ true); + const std::string ping_message = "ping"; + const std::string ping_frame = + EncodeFrame(ping_message, WebSocketFrameHeader::OpCodeEnum::kOpCodePing, + /* mask= */ true, /* finish= */ true); + const std::string pong_frame = + EncodeFrame(ping_message, WebSocketFrameHeader::OpCodeEnum::kOpCodePong, + /* mask= */ false, /* finish= */ true); + + // text_encoded_frame -> ping_frame -> continuation_encoded_frame + client.Send(text_encoded_frame); + client.Send(ping_frame); + client.Send(continuation_encoded_frame); + ASSERT_TRUE(client.Read(&response, pong_frame.length())); + EXPECT_EQ(response, pong_frame); + std::string received_message = GetMessage(); + EXPECT_EQ(received_message, "foobar"); +} + +TEST_F(WebSocketAcceptingTest, SendTextAndPingFrameWithMessage) { + TestHttpClient client; + CreateConnection(&client); + std::string response; + client.Send( + "GET /test HTTP/1.1\r\n" + "Upgrade: WebSocket\r\n" + "Connection: SomethingElse, Upgrade\r\n" + "Sec-WebSocket-Version: 8\r\n" + "Sec-WebSocket-Key: key\r\n\r\n"); + RunUntilRequestsReceived(1); + ASSERT_TRUE(client.ReadResponse(&response)); + + const std::string text_frame = "foo"; + const std::string continuation_frame = "bar"; + const std::string text_encoded_frame = + EncodeFrame(text_frame, WebSocketFrameHeader::OpCodeEnum::kOpCodeText, + /* mask= */ true, + /* finish= */ false); + const std::string continuation_encoded_frame = EncodeFrame( + continuation_frame, WebSocketFrameHeader::OpCodeEnum::kOpCodeContinuation, + /* mask= */ true, /* finish= */ true); + const std::string ping_message = "hello"; + const std::string ping_frame = + EncodeFrame(ping_message, WebSocketFrameHeader::OpCodeEnum::kOpCodePing, + /* mask= */ true, /* finish= */ true); + const std::string pong_frame = + EncodeFrame(ping_message, WebSocketFrameHeader::OpCodeEnum::kOpCodePong, + /* mask= */ false, /* finish= */ true); + + // text_encoded_frame -> ping_frame -> continuation_frame + client.Send(text_encoded_frame); + client.Send(ping_frame); + client.Send(continuation_encoded_frame); + ASSERT_TRUE(client.Read(&response, pong_frame.length())); + EXPECT_EQ(response, pong_frame); + std::string received_message = GetMessage(); + EXPECT_EQ(received_message, "foobar"); +} + +TEST_F(WebSocketAcceptingTest, SendTextAndPongFrame) { + TestHttpClient client; + CreateConnection(&client); + std::string response; + client.Send( + "GET /test HTTP/1.1\r\n" + "Upgrade: WebSocket\r\n" + "Connection: SomethingElse, Upgrade\r\n" + "Sec-WebSocket-Version: 8\r\n" + "Sec-WebSocket-Key: key\r\n\r\n"); + RunUntilRequestsReceived(1); + ASSERT_TRUE(client.ReadResponse(&response)); + + const std::string text_frame = "foo"; + const std::string continuation_frame = "bar"; + const std::string text_encoded_frame = + EncodeFrame(text_frame, WebSocketFrameHeader::OpCodeEnum::kOpCodeText, + /* mask= */ true, + /* finish= */ false); + const std::string continuation_encoded_frame = EncodeFrame( + continuation_frame, WebSocketFrameHeader::OpCodeEnum::kOpCodeContinuation, + /* mask= */ true, /* finish= */ true); + const std::string pong_message = "pong"; + const std::string pong_frame = + EncodeFrame(pong_message, WebSocketFrameHeader::OpCodeEnum::kOpCodePong, + /* mask= */ true, /* finish= */ true); + + // text_encoded_frame -> pong_frame -> continuation_encoded_frame + client.Send(text_encoded_frame); + client.Send(pong_frame); + client.Send(continuation_encoded_frame); + std::string received_message = GetMessage(); + EXPECT_EQ(received_message, "foobar"); +} + +TEST_F(WebSocketAcceptingTest, SendTextPingPongFrame) { + TestHttpClient client; + CreateConnection(&client); + std::string response; + client.Send( + "GET /test HTTP/1.1\r\n" + "Upgrade: WebSocket\r\n" + "Connection: SomethingElse, Upgrade\r\n" + "Sec-WebSocket-Version: 8\r\n" + "Sec-WebSocket-Key: key\r\n\r\n"); + RunUntilRequestsReceived(1); + ASSERT_TRUE(client.ReadResponse(&response)); + + const std::string text_frame = "foo"; + const std::string continuation_frame = "bar"; + const std::string text_encoded_frame = + EncodeFrame(text_frame, WebSocketFrameHeader::OpCodeEnum::kOpCodeText, + /* mask= */ true, + /* finish= */ false); + const std::string continuation_encoded_frame = EncodeFrame( + continuation_frame, WebSocketFrameHeader::OpCodeEnum::kOpCodeContinuation, + /* mask= */ true, /* finish= */ true); + + const std::string ping_message_first = "hello"; + const std::string ping_frame_first = EncodeFrame( + ping_message_first, WebSocketFrameHeader::OpCodeEnum::kOpCodePing, + /* mask= */ true, /* finish= */ true); + const std::string pong_frame_first = EncodeFrame( + ping_message_first, WebSocketFrameHeader::OpCodeEnum::kOpCodePong, + /* mask= */ false, /* finish= */ true); + + const std::string ping_message_second = "HELLO"; + const std::string ping_frame_second = EncodeFrame( + ping_message_second, WebSocketFrameHeader::OpCodeEnum::kOpCodePing, + /* mask= */ true, /* finish= */ true); + const std::string pong_frame_second = EncodeFrame( + ping_message_second, WebSocketFrameHeader::OpCodeEnum::kOpCodePong, + /* mask= */ false, /* finish= */ true); + + // text_encoded_frame -> ping_frame_first -> ping_frame_second -> + // continuation_encoded_frame + client.Send(text_encoded_frame); + client.Send(ping_frame_first); + ASSERT_TRUE(client.Read(&response, pong_frame_first.length())); + EXPECT_EQ(response, pong_frame_first); + client.Send(ping_frame_second); + ASSERT_TRUE(client.Read(&response, pong_frame_second.length())); + EXPECT_EQ(response, pong_frame_second); + client.Send(continuation_encoded_frame); + std::string received_message = GetMessage(); + EXPECT_EQ(received_message, "foobar"); } TEST_F(HttpServerTest, RequestWithTooLargeBody) { diff --git a/net/server/web_socket_encoder.cc b/net/server/web_socket_encoder.cc index 7e5e652786b961..a70e587861a382 100644 --- a/net/server/web_socket_encoder.cc +++ b/net/server/web_socket_encoder.cc @@ -301,22 +301,32 @@ WebSocket::ParseResult WebSocketEncoder::DecodeFrame( std::string current_output; WebSocket::ParseResult result = DecodeFrameHybi17( frame, type_ == FOR_SERVER, bytes_consumed, ¤t_output, &compressed); - if (result == WebSocket::FRAME_OK_FINAL || - result == WebSocket::FRAME_OK_MIDDLE || result == WebSocket::FRAME_PING) { - if (continuation_message_frames_.empty()) - is_current_message_compressed_ = compressed; - continuation_message_frames_.push_back(current_output); - } - if (result == WebSocket::FRAME_OK_FINAL || result == WebSocket::FRAME_PING) { - *output = base::StrCat(continuation_message_frames_); - if (is_current_message_compressed_) { - if (!Inflate(output)) - result = WebSocket::FRAME_ERROR; + switch (result) { + case WebSocket::FRAME_OK_FINAL: + case WebSocket::FRAME_OK_MIDDLE: { + if (continuation_message_frames_.empty()) + is_current_message_compressed_ = compressed; + continuation_message_frames_.push_back(current_output); + + if (result == WebSocket::FRAME_OK_FINAL) { + *output = base::StrCat(continuation_message_frames_); + continuation_message_frames_.clear(); + if (is_current_message_compressed_ && !Inflate(output)) { + return WebSocket::FRAME_ERROR; + } + } + break; } + + case WebSocket::FRAME_PING: + *output = current_output; + break; + + default: + // This function doesn't need special handling for other parse results. + break; } - if (result != WebSocket::FRAME_OK_MIDDLE && - result != WebSocket::FRAME_INCOMPLETE) - continuation_message_frames_.clear(); + return result; }