Skip to content
This repository was archived by the owner on Aug 11, 2020. It is now read-only.

Commit 21377c1

Browse files
committed
quic: ensure callbacks of QuicSocket.connect() get called
1. The callbacks of QuicSocket.connect() won't get called after QuicSocket binding. To fix this issue, This PR calls QuicSession[kReady]() directly when QuicSocket bound. 2. This PR also modify SocketAddress::Hash and SocketAddress::Compare to accept struct values instead of pointers because SocketAddress might get freed firstly which would cause the values aren't safe to use. For example, the test added in this PR is likely to abort before this PR.
1 parent 1c316ac commit 21377c1

File tree

8 files changed

+124
-27
lines changed

8 files changed

+124
-27
lines changed

lib/internal/quic/core.js

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,6 +1063,9 @@ class QuicSocket extends EventEmitter {
10631063
if (typeof callback === 'function')
10641064
session.on('ready', callback);
10651065

1066+
if (this.bound)
1067+
session[kReady]();
1068+
10661069
this[kMaybeBind](connectAfterBind.bind(
10671070
this,
10681071
session,

src/node_crypto.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929

3030
#include "env.h"
3131
#include "base_object.h"
32+
// TODO do not included
33+
#include "base_object-inl.h"
3234
#include "util.h"
3335

3436
#include "v8.h"

src/node_quic_session.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2050,7 +2050,7 @@ void QuicSession::RemoveFromSocket() {
20502050
socket_->DisassociateCID(QuicCID(&cid));
20512051

20522052
Debug(this, "Removed from the QuicSocket.");
2053-
socket_->RemoveSession(QuicCID(scid_), **GetRemoteAddress());
2053+
socket_->RemoveSession(QuicCID(scid_), GetRemoteAddress()->GetSockaddrStorage());
20542054
socket_.reset();
20552055
}
20562056

src/node_quic_socket.cc

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ void QuicSocket::AddSession(
247247
const QuicCID& cid,
248248
BaseObjectPtr<QuicSession> session) {
249249
sessions_[cid.ToStr()] = session;
250-
IncrementSocketAddressCounter(**session->GetRemoteAddress());
250+
IncrementSocketAddressCounter(session->GetRemoteAddress()->GetSockaddrStorage());
251251
IncrementSocketStat(
252252
1, &socket_stats_,
253253
session->IsServer() ?
@@ -485,7 +485,7 @@ int QuicSocket::ReceiveStop() {
485485
return udp_->RecvStop();
486486
}
487487

488-
void QuicSocket::RemoveSession(const QuicCID& cid, const sockaddr* addr) {
488+
void QuicSocket::RemoveSession(const QuicCID& cid, const sockaddr_storage* addr) {
489489
sessions_.erase(cid.ToStr());
490490
DecrementSocketAddressCounter(addr);
491491
}
@@ -659,7 +659,7 @@ namespace {
659659
void QuicSocket::SetValidatedAddress(const sockaddr* addr) {
660660
if (IsOptionSet(QUICSOCKET_OPTIONS_VALIDATE_ADDRESS_LRU)) {
661661
// Remove the oldest item if we've hit the LRU limit
662-
validated_addrs_.push_back(addr_hash(addr));
662+
validated_addrs_.push_back(addr_hash(*addr));
663663
if (validated_addrs_.size() > MAX_VALIDATE_ADDRESS_LRU)
664664
validated_addrs_.pop_front();
665665
}
@@ -669,7 +669,7 @@ bool QuicSocket::IsValidatedAddress(const sockaddr* addr) const {
669669
if (IsOptionSet(QUICSOCKET_OPTIONS_VALIDATE_ADDRESS_LRU)) {
670670
auto res = std::find(std::begin(validated_addrs_),
671671
std::end(validated_addrs_),
672-
addr_hash(addr));
672+
addr_hash(*addr));
673673
return res != std::end(validated_addrs_);
674674
}
675675
return false;
@@ -721,9 +721,13 @@ BaseObjectPtr<QuicSession> QuicSocket::AcceptInitialPacket(
721721
// Check to see if the number of connections for this peer has been exceeded.
722722
// If the count has been exceeded, shutdown the connection immediately
723723
// after the initial keys are installed.
724-
if (GetCurrentSocketAddressCounter(addr) >= max_connections_per_host_) {
725-
Debug(this, "Connection count for address exceeded");
726-
initial_connection_close = NGTCP2_SERVER_BUSY;
724+
{
725+
sockaddr_storage storage;
726+
memcpy(&storage, addr, SocketAddress::GetLength(addr));
727+
if (GetCurrentSocketAddressCounter(&storage) >= max_connections_per_host_) {
728+
Debug(this, "Connection count for address exceeded");
729+
initial_connection_close = NGTCP2_SERVER_BUSY;
730+
}
727731
}
728732

729733
// QUIC has address validation built in to the handshake but allows for
@@ -782,22 +786,22 @@ BaseObjectPtr<QuicSession> QuicSocket::AcceptInitialPacket(
782786
return session;
783787
}
784788

785-
void QuicSocket::IncrementSocketAddressCounter(const sockaddr* addr) {
786-
addr_counts_[addr]++;
789+
void QuicSocket::IncrementSocketAddressCounter(const sockaddr_storage* addr) {
790+
addr_counts_[*addr]++;
787791
}
788792

789-
void QuicSocket::DecrementSocketAddressCounter(const sockaddr* addr) {
790-
auto it = addr_counts_.find(addr);
793+
void QuicSocket::DecrementSocketAddressCounter(const sockaddr_storage* addr) {
794+
auto it = addr_counts_.find(*addr);
791795
if (it == std::end(addr_counts_))
792796
return;
793797
it->second--;
794798
// Remove the address if the counter reaches zero again.
795799
if (it->second == 0)
796-
addr_counts_.erase(addr);
800+
addr_counts_.erase(*addr);
797801
}
798802

799-
size_t QuicSocket::GetCurrentSocketAddressCounter(const sockaddr* addr) {
800-
auto it = addr_counts_.find(addr);
803+
size_t QuicSocket::GetCurrentSocketAddressCounter(const sockaddr_storage* addr) {
804+
auto it = addr_counts_.find(*addr);
801805
if (it == std::end(addr_counts_))
802806
return 0;
803807
return it->second;

src/node_quic_socket.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ class QuicSocket : public AsyncWrap,
133133
int ReceiveStop();
134134
void RemoveSession(
135135
const QuicCID& cid,
136-
const sockaddr* addr);
136+
const sockaddr_storage* addr);
137137
void ReportSendError(
138138
int error);
139139
int SendPacket(
@@ -233,9 +233,9 @@ class QuicSocket : public AsyncWrap,
233233
const struct sockaddr* addr,
234234
unsigned int flags);
235235

236-
void IncrementSocketAddressCounter(const sockaddr* addr);
237-
void DecrementSocketAddressCounter(const sockaddr* addr);
238-
size_t GetCurrentSocketAddressCounter(const sockaddr* addr);
236+
void IncrementSocketAddressCounter(const sockaddr_storage* addr);
237+
void DecrementSocketAddressCounter(const sockaddr_storage* addr);
238+
size_t GetCurrentSocketAddressCounter(const sockaddr_storage* addr);
239239

240240
void IncrementPendingCallbacks() { pending_callbacks_++; }
241241
void DecrementPendingCallbacks() { pending_callbacks_--; }
@@ -315,7 +315,7 @@ class QuicSocket : public AsyncWrap,
315315
// value reaches the value of max_connections_per_host_,
316316
// attempts to create new connections will be ignored
317317
// until the value falls back below the limit.
318-
std::unordered_map<const sockaddr*, size_t, SocketAddress::Hash,
318+
std::unordered_map<const sockaddr_storage, size_t, SocketAddress::Hash,
319319
SocketAddress::Compare> addr_counts_;
320320

321321
// The validated_addrs_ vector is used as an LRU cache for

src/node_sockaddr-inl.h

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ inline void hash_combine(size_t* seed, const T& value, Args... rest) {
2525
}
2626
} // namespace
2727

28-
size_t SocketAddress::Hash::operator()(const sockaddr* addr) const {
28+
static size_t GetHash(const sockaddr* addr) {
2929
size_t hash = 0;
3030
switch (addr->sa_family) {
3131
case AF_INET: {
@@ -48,11 +48,20 @@ size_t SocketAddress::Hash::operator()(const sockaddr* addr) const {
4848
return hash;
4949
}
5050

51+
size_t SocketAddress::Hash::operator()(const sockaddr& addr) const {
52+
return GetHash(&addr);
53+
}
54+
55+
size_t SocketAddress::Hash::operator()(const sockaddr_storage& addr_storage) const {
56+
const sockaddr* addr = reinterpret_cast<const sockaddr*>(&addr_storage);
57+
return GetHash(addr);
58+
}
59+
5160
bool SocketAddress::Compare::operator()(
52-
const sockaddr* laddr,
53-
const sockaddr* raddr) const {
54-
CHECK(laddr->sa_family == AF_INET || laddr->sa_family == AF_INET6);
55-
return memcmp(laddr, raddr, GetLength(laddr)) == 0;
61+
const sockaddr_storage& laddr,
62+
const sockaddr_storage& raddr) const {
63+
CHECK(laddr.ss_family == AF_INET || laddr.ss_family == AF_INET6);
64+
return memcmp(&laddr, &raddr, GetLength(&laddr)) == 0;
5665
}
5766

5867
bool SocketAddress::is_numeric_host(const char* hostname) {
@@ -146,6 +155,10 @@ const sockaddr* SocketAddress::operator*() const {
146155
return reinterpret_cast<const sockaddr*>(&address_);
147156
}
148157

158+
const sockaddr_storage* SocketAddress::GetSockaddrStorage() const {
159+
return &address_;
160+
}
161+
149162
size_t SocketAddress::GetLength() const {
150163
return GetLength(&address_);
151164
}

src/node_sockaddr.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@ namespace node {
1515
class SocketAddress {
1616
public:
1717
struct Hash {
18-
inline size_t operator()(const sockaddr* addr) const;
18+
inline size_t operator()(const sockaddr& addr) const;
19+
inline size_t operator()(const sockaddr_storage& addr_storage) const;
1920
};
2021

2122
struct Compare {
22-
inline bool operator()(const sockaddr* laddr, const sockaddr* raddr) const;
23+
inline bool operator()(const sockaddr_storage& laddr, const sockaddr_storage& raddr) const;
2324
};
2425

2526
inline static bool is_numeric_host(const char* hostname);
@@ -56,6 +57,8 @@ class SocketAddress {
5657

5758
inline const sockaddr* operator*() const;
5859

60+
inline const sockaddr_storage* GetSockaddrStorage() const;
61+
5962
inline size_t GetLength() const;
6063

6164
inline int GetFamily() const;
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
'use strict';
2+
3+
const common = require('../common');
4+
if (!common.hasQuic)
5+
common.skip('missing quic');
6+
7+
const { createSocket } = require('quic');
8+
const fixtures = require('../common/fixtures');
9+
const Countdown = require('../common/countdown');
10+
const key = fixtures.readKey('agent1-key.pem', 'binary');
11+
const cert = fixtures.readKey('agent1-cert.pem', 'binary');
12+
const ca = fixtures.readKey('ca1-cert.pem', 'binary');
13+
14+
const kServerName = 'agent2';
15+
const kALPN = 'zzz';
16+
const kIdleTimeout = 0;
17+
const kConnections = 5;
18+
19+
// After QuicSocket bound, the callback of QuicSocket.connect()
20+
// should still get called.
21+
{
22+
let client;
23+
const server = createSocket({
24+
port: 0,
25+
});
26+
27+
server.listen({
28+
key,
29+
cert,
30+
ca,
31+
alpn: kALPN,
32+
idleTimeout: kIdleTimeout,
33+
});
34+
35+
const countdown = new Countdown(kConnections, () => {
36+
client.close();
37+
server.close();
38+
});
39+
40+
server.on('ready', common.mustCall(() => {
41+
const options = {
42+
key,
43+
cert,
44+
ca,
45+
address: common.localhostIPv4,
46+
port: server.address.port,
47+
servername: kServerName,
48+
alpn: kALPN,
49+
idleTimeout: kIdleTimeout,
50+
};
51+
52+
client = createSocket({
53+
port: 0,
54+
});
55+
56+
const session = client.connect(options, common.mustCall(() => {
57+
session.close(common.mustCall(() => {
58+
// After a session being ready, the socket should have bound
59+
// and we could start the test.
60+
testConnections();
61+
}));
62+
}));
63+
64+
const testConnections = common.mustCall(() => {
65+
for (let i = 0; i < kConnections; i += 1) {
66+
client.connect(options, common.mustCall(() => {
67+
countdown.dec();
68+
}));
69+
}
70+
});
71+
}));
72+
}

0 commit comments

Comments
 (0)