Skip to content

Commit 5182d15

Browse files
author
Stepan Bagritsevich
committed
fix(zset): fix random in ZRANDMEMBER command
fixes dragonflydb#2850 Signed-off-by: Stepan Bagritsevich <[email protected]>
1 parent 5d66e2f commit 5182d15

File tree

2 files changed

+261
-47
lines changed

2 files changed

+261
-47
lines changed

src/server/zset_family.cc

Lines changed: 106 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,65 @@ OpResult<DbSlice::ItAndUpdater> FindZEntry(const ZParams& zparams, const OpArgs&
222222
return DbSlice::ItAndUpdater{add_res.it, add_res.exp_it, std::move(add_res.post_updater)};
223223
}
224224

225+
using RandomPick = std::size_t;
226+
227+
class PicksGenerator {
228+
public:
229+
virtual RandomPick Generate() = 0;
230+
virtual ~PicksGenerator() = default;
231+
};
232+
233+
class NonUniquePicksGenerator : public PicksGenerator {
234+
public:
235+
NonUniquePicksGenerator(std::size_t total_size) : total_size_(total_size) {
236+
CHECK_GT(total_size, std::size_t(0));
237+
}
238+
239+
RandomPick Generate() override {
240+
return absl::Uniform(bitgen_, 0u, total_size_);
241+
}
242+
243+
private:
244+
const std::size_t total_size_;
245+
absl::BitGen bitgen_{};
246+
};
247+
248+
/*
249+
* Generates unique index in O(1).
250+
*
251+
* picks_count specifies the number of random indexes to be generated.
252+
* In other words, this is the number of times the Generate() function is called.
253+
*
254+
* The class uses Robert Floyd's sampling algorithm
255+
* https://dl.acm.org/doi/pdf/10.1145/30401.315746
256+
* */
257+
class UniquePicksGenerator : public PicksGenerator {
258+
public:
259+
UniquePicksGenerator(std::size_t picks_count, std::size_t total_size)
260+
: picked_indexes_(picks_count) {
261+
CHECK_GE(total_size, picks_count);
262+
current_random_limit_ = total_size - picks_count;
263+
}
264+
265+
RandomPick Generate() override {
266+
const std::size_t max_index = current_random_limit_++;
267+
const RandomPick random_index = absl::Uniform(bitgen_, 0u, max_index + 1u);
268+
269+
const bool random_index_is_picked = picked_indexes_.emplace(random_index).second;
270+
if (random_index_is_picked) {
271+
return random_index;
272+
}
273+
274+
picked_indexes_.insert(max_index);
275+
return max_index;
276+
}
277+
278+
private:
279+
std::size_t current_random_limit_;
280+
std::unordered_set<RandomPick> picked_indexes_;
281+
absl::BitGen bitgen_{};
282+
};
283+
225284
bool ScoreToLongLat(const std::optional<double>& val, double* xy) {
226285
if (!val.has_value())
227286
return false;
@@ -1702,6 +1761,48 @@ OpResult<StringVec> OpScan(const OpArgs& op_args, std::string_view key, uint64_t
17021761
return res;
17031762
}
17041763

1764+
OpResult<ScoredArray> OpRandMember(int count, const ZSetFamily::RangeParams& params,
1765+
const OpArgs& op_args, string_view key) {
1766+
auto it = op_args.shard->db_slice().FindReadOnly(op_args.db_cntx, key, OBJ_ZSET);
1767+
if (!it)
1768+
return it.status();
1769+
1770+
// Action::RANGE is a read-only operation, but requires const_cast
1771+
PrimeValue& pv = const_cast<PrimeValue&>(it.value()->second);
1772+
1773+
const std::size_t size = pv.Size();
1774+
const std::size_t picks_count =
1775+
count >= 0 ? std::min(static_cast<std::size_t>(count), size) : std::abs(count);
1776+
1777+
ScoredArray result{picks_count};
1778+
std::unique_ptr<PicksGenerator> generator =
1779+
count >= 0 ? static_cast<std::unique_ptr<PicksGenerator>>(
1780+
std::make_unique<UniquePicksGenerator>(picks_count, size))
1781+
: std::make_unique<NonUniquePicksGenerator>(size);
1782+
1783+
if (picks_count * static_cast<std::uint64_t>(std::log2(size)) < size) {
1784+
for (std::size_t i = 0; i < picks_count; i++) {
1785+
const std::size_t picked_index = generator->Generate();
1786+
1787+
IntervalVisitor iv{Action::RANGE, params, &pv};
1788+
iv(ZSetFamily::IndexInterval{picked_index, picked_index});
1789+
1790+
result[i] = iv.PopResult().front();
1791+
}
1792+
} else {
1793+
IntervalVisitor iv{Action::RANGE, params, &pv};
1794+
iv(ZSetFamily::IndexInterval{0, -1});
1795+
1796+
ScoredArray all_elements = iv.PopResult();
1797+
1798+
for (std::size_t i = 0; i < picks_count; i++) {
1799+
result[i] = all_elements[generator->Generate()];
1800+
}
1801+
}
1802+
1803+
return result;
1804+
}
1805+
17051806
void ZAddGeneric(string_view key, const ZParams& zparams, ScoredMemberSpan memb_sp,
17061807
ConnectionContext* cntx) {
17071808
auto cb = [&](Transaction* t, EngineShard* shard) {
@@ -2323,43 +2424,32 @@ void ZSetFamily::ZRandMember(CmdArgList args, ConnectionContext* cntx) {
23232424
if (args.size() > 3)
23242425
return cntx->SendError(WrongNumArgsError("ZRANDMEMBER"));
23252426

2326-
ZRangeSpec range_spec;
2327-
range_spec.interval = IndexInterval(0, -1);
2328-
23292427
CmdArgParser parser{args};
23302428
string_view key = parser.Next();
23312429

23322430
bool is_count = parser.HasNext();
23332431
int count = is_count ? parser.Next<int>() : 1;
23342432

2335-
range_spec.params.with_scores = static_cast<bool>(parser.Check("WITHSCORES").IgnoreCase());
2433+
ZSetFamily::RangeParams params;
2434+
params.with_scores = static_cast<bool>(parser.Check("WITHSCORES").IgnoreCase());
23362435

23372436
if (parser.HasNext())
23382437
return cntx->SendError(absl::StrCat("Unsupported option:", string_view(parser.Next())));
23392438

23402439
if (auto err = parser.Error(); err)
23412440
return cntx->SendError(err->MakeReply());
23422441

2343-
bool sign = count < 0;
2344-
range_spec.params.limit = std::abs(count);
2345-
23462442
const auto cb = [&](Transaction* t, EngineShard* shard) {
2347-
return OpRange(range_spec, t->GetOpArgs(shard), key);
2443+
return OpRandMember(count, params, t->GetOpArgs(shard), key);
23482444
};
23492445

23502446
OpResult<ScoredArray> result = cntx->transaction->ScheduleSingleHopT(cb);
23512447
auto* rb = static_cast<RedisReplyBuilder*>(cntx->reply_builder());
23522448
if (result) {
2353-
if (sign && !result->empty()) {
2354-
for (auto i = result->size(); i < range_spec.params.limit; ++i) {
2355-
// we can return duplicate elements, so first is OK
2356-
result->push_back(result->front());
2357-
}
2358-
}
2359-
rb->SendScoredArray(result.value(), range_spec.params.with_scores);
2449+
rb->SendScoredArray(result.value(), params.with_scores);
23602450
} else if (result.status() == OpStatus::KEY_NOTFOUND) {
23612451
if (is_count) {
2362-
rb->SendScoredArray(ScoredArray(), range_spec.params.with_scores);
2452+
rb->SendScoredArray(ScoredArray(), params.with_scores);
23632453
} else {
23642454
rb->SendNull();
23652455
}

src/server/zset_family_test.cc

Lines changed: 155 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,88 @@ class ZSetFamilyTest : public BaseFamilyTest {
2020
protected:
2121
};
2222

23+
using ScoredElement = std::pair<std::string, std::string>;
24+
25+
auto ParseToScoredArray(auto vec) {
26+
std::vector<ScoredElement> scored_elements;
27+
for (std::size_t i = 1; i < vec.size(); i += 2) {
28+
scored_elements.emplace_back(vec[i - 1].GetString(), vec[i].GetString());
29+
}
30+
return scored_elements;
31+
}
32+
33+
MATCHER_P(ConsistsOfMatcher, elements, "") {
34+
auto vec = arg.GetVec();
35+
for (const auto& x : vec) {
36+
if (elements.find(x.GetString()) == elements.end()) {
37+
return false;
38+
}
39+
}
40+
return true;
41+
}
42+
43+
MATCHER_P(ConsistsOfScoredElementsMatcher, elements, "") {
44+
auto vec = arg.GetVec();
45+
if (vec.size() % 2) {
46+
return false;
47+
}
48+
49+
auto scored_vec = ParseToScoredArray(vec);
50+
for (const auto& scored_element : scored_vec) {
51+
if (elements.find(scored_element) == elements.end()) {
52+
return false;
53+
}
54+
}
55+
return true;
56+
}
57+
58+
MATCHER_P(IsScoredSubsetOfMatcher, elements_list, "") {
59+
auto vec = arg.GetVec();
60+
if (vec.size() % 2) {
61+
return false;
62+
}
63+
64+
auto scored_vec = ParseToScoredArray(vec);
65+
66+
std::vector<ScoredElement> diff;
67+
std::set_difference(scored_vec.begin(), scored_vec.end(), elements_list.begin(),
68+
elements_list.end(), std::back_inserter(diff));
69+
70+
return diff.empty();
71+
}
72+
73+
MATCHER_P(UnorderedScoredElementsAreMatcher, elements_list, "") {
74+
auto vec = arg.GetVec();
75+
if (vec.size() % 2) {
76+
return false;
77+
}
78+
79+
auto scored_vec = ParseToScoredArray(vec);
80+
81+
std::vector<ScoredElement> diff;
82+
std::set_difference(scored_vec.begin(), scored_vec.end(), elements_list.begin(),
83+
elements_list.end(), std::back_inserter(diff));
84+
85+
return diff.empty() && scored_vec.size() == elements_list.size();
86+
}
87+
88+
auto ConsistsOf(std::initializer_list<std::string> elements) {
89+
return ConsistsOfMatcher(std::unordered_set<std::string>{elements});
90+
}
91+
92+
auto ConsistsOfScoredElements(std::initializer_list<std::pair<std::string, std::string>> elements) {
93+
return ConsistsOfScoredElementsMatcher(std::set<std::pair<std::string, std::string>>{elements});
94+
}
95+
96+
auto IsScoredSubsetOf(std::initializer_list<std::pair<std::string, std::string>> elements) {
97+
return IsScoredSubsetOfMatcher(elements);
98+
}
99+
100+
auto UnorderedScoredElementsAre(
101+
std::initializer_list<std::pair<std::string, std::string>> elements) {
102+
return UnorderedScoredElementsAreMatcher(elements);
103+
}
104+
23105
TEST_F(ZSetFamilyTest, Add) {
24106
auto resp = Run({"zadd", "x", "1.1", "a"});
25107
EXPECT_THAT(resp, IntArg(1));
@@ -77,53 +159,95 @@ TEST_F(ZSetFamilyTest, ZRem) {
77159
}
78160

79161
TEST_F(ZSetFamilyTest, ZRandMember) {
80-
auto resp = Run({
81-
"zadd",
82-
"x",
83-
"1",
84-
"a",
85-
"2",
86-
"b",
87-
"3",
88-
"c",
89-
});
162+
auto resp = Run({"ZAdd", "x", "1", "a", "2", "b", "3", "c"});
163+
EXPECT_THAT(resp, IntArg(3));
164+
165+
// Test if count > 0
90166
resp = Run({"ZRandMember", "x"});
91167
ASSERT_THAT(resp, ArgType(RespExpr::STRING));
92-
EXPECT_THAT(resp, "a");
168+
EXPECT_THAT(resp, AnyOf("a", "b", "c"));
169+
170+
resp = Run({"ZRandMember", "x", "1"});
171+
ASSERT_THAT(resp, ArgType(RespExpr::STRING));
172+
EXPECT_THAT(resp, AnyOf("a", "b", "c"));
93173

94174
resp = Run({"ZRandMember", "x", "2"});
95-
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
96-
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "b"));
175+
ASSERT_THAT(resp, ArrLen(2));
176+
EXPECT_THAT(resp.GetVec(), IsSubsetOf({"a", "b", "c"}));
97177

98-
resp = Run({"ZRandMember", "x", "0"});
99-
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
100-
EXPECT_EQ(resp.GetVec().size(), 0);
178+
resp = Run({"ZRandMember", "x", "3"});
179+
ASSERT_THAT(resp, ArrLen(3));
180+
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "b", "c"));
101181

102-
resp = Run({"ZRandMember", "k"});
103-
ASSERT_THAT(resp, ArgType(RespExpr::NIL));
182+
// Test if count < 0
183+
resp = Run({"ZRandMember", "x", "-1"});
184+
ASSERT_THAT(resp, ArgType(RespExpr::STRING));
185+
EXPECT_THAT(resp, AnyOf("a", "b", "c"));
104186

105-
resp = Run({"ZRandMember", "k", "2"});
106-
ASSERT_THAT(resp, ArgType(RespExpr::ARRAY));
107-
EXPECT_EQ(resp.GetVec().size(), 0);
187+
resp = Run({"ZRandMember", "x", "-2"});
188+
ASSERT_THAT(resp, ArrLen(2));
189+
EXPECT_THAT(resp, ConsistsOf({"a", "b", "c"}));
190+
191+
resp = Run({"ZRandMember", "x", "-3"});
192+
ASSERT_THAT(resp, ArrLen(3));
193+
EXPECT_THAT(resp, ConsistsOf({"a", "b", "c"}));
194+
195+
// Test if count < 0, but the absolute value is larger than the size of the sorted set
196+
resp = Run({"ZRandMember", "x", "-15"});
197+
ASSERT_THAT(resp, ArrLen(15));
198+
EXPECT_THAT(resp, ConsistsOf({"a", "b", "c"}));
108199

109-
resp = Run({"ZRandMember", "x", "-5"});
110-
ASSERT_THAT(resp, ArrLen(5));
111-
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "b", "c", "a", "a"));
200+
// Test if count is 0
201+
ASSERT_THAT(Run({"ZRandMember", "x", "0"}), ArrLen(0));
112202

113-
resp = Run({"ZRandMember", "x", "5"});
203+
// Test if count is larger than the size of the sorted set
204+
resp = Run({"ZRandMember", "x", "15"});
114205
ASSERT_THAT(resp, ArrLen(3));
115206
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "b", "c"));
116207

117-
resp = Run({"ZRandMember", "x", "-5", "WITHSCORES"});
118-
ASSERT_THAT(resp, ArrLen(10));
119-
EXPECT_THAT(resp.GetVec(), ElementsAre("a", "1", "b", "2", "c", "3", "a", "1", "a", "1"));
208+
// Test if sorted set is empty
209+
EXPECT_THAT(Run({"ZAdd", "empty::zset", "1", "one"}), IntArg(1));
210+
EXPECT_THAT(Run({"ZRem", "empty::zset", "one"}), IntArg(1));
211+
ASSERT_THAT(Run({"ZRandMember", "empty::zset", "0"}), ArrLen(0));
212+
ASSERT_THAT(Run({"ZRandMember", "empty::zset", "3"}), ArrLen(0));
213+
ASSERT_THAT(Run({"ZRandMember", "empty::zset", "-4"}), ArrLen(0));
214+
215+
// Test if key does not exist
216+
ASSERT_THAT(Run({"ZRandMember", "y"}), ArgType(RespExpr::NIL));
217+
ASSERT_THAT(Run({"ZRandMember", "y", "0"}), ArrLen(0));
218+
219+
// Test WITHSCORES
220+
resp = Run({"ZRandMember", "x", "1", "WITHSCORES"});
221+
ASSERT_THAT(resp, ArrLen(2));
222+
EXPECT_THAT(resp, IsScoredSubsetOf({{"a", "1"}, {"b", "2"}, {"c", "3"}}));
223+
224+
resp = Run({"ZRandMember", "x", "2", "WITHSCORES"});
225+
ASSERT_THAT(resp, ArrLen(4));
226+
EXPECT_THAT(resp, IsScoredSubsetOf({{"a", "1"}, {"b", "2"}, {"c", "3"}}));
120227

121228
resp = Run({"ZRandMember", "x", "3", "WITHSCORES"});
122229
ASSERT_THAT(resp, ArrLen(6));
123-
EXPECT_THAT(resp.GetVec(), UnorderedElementsAre("a", "1", "b", "2", "c", "3"));
230+
EXPECT_THAT(resp, UnorderedScoredElementsAre({{"a", "1"}, {"b", "2"}, {"c", "3"}}));
124231

125-
resp = Run({"ZRandMember", "x", "3", "WITHSCORES", "test"});
126-
EXPECT_THAT(resp, ErrArg("wrong number of arguments"));
232+
resp = Run({"ZRandMember", "x", "15", "WITHSCORES"});
233+
ASSERT_THAT(resp, ArrLen(6));
234+
EXPECT_THAT(resp, UnorderedScoredElementsAre({{"a", "1"}, {"b", "2"}, {"c", "3"}}));
235+
236+
resp = Run({"ZRandMember", "x", "-1", "WITHSCORES"});
237+
ASSERT_THAT(resp, ArrLen(2));
238+
EXPECT_THAT(resp, ConsistsOfScoredElements({{"a", "1"}, {"b", "2"}, {"c", "3"}}));
239+
240+
resp = Run({"ZRandMember", "x", "-2", "WITHSCORES"});
241+
ASSERT_THAT(resp, ArrLen(4));
242+
EXPECT_THAT(resp, ConsistsOfScoredElements({{"a", "1"}, {"b", "2"}, {"c", "3"}}));
243+
244+
resp = Run({"ZRandMember", "x", "-3", "WITHSCORES"});
245+
ASSERT_THAT(resp, ArrLen(6));
246+
EXPECT_THAT(resp, ConsistsOfScoredElements({{"a", "1"}, {"b", "2"}, {"c", "3"}}));
247+
248+
resp = Run({"ZRandMember", "x", "-15", "WITHSCORES"});
249+
ASSERT_THAT(resp, ArrLen(30));
250+
EXPECT_THAT(resp, ConsistsOfScoredElements({{"a", "1"}, {"b", "2"}, {"c", "3"}}));
127251
}
128252

129253
TEST_F(ZSetFamilyTest, ZMScore) {

0 commit comments

Comments
 (0)