Skip to content

Commit 7da7140

Browse files
duxiao1212facebook-github-bot
authored andcommitted
feat: Impl sort key for LocalShuffleReader (#26620)
Summary: Implement sorted shuffle k-way merge for LocalShuffleReader, when it's sortedShuffle. Added k-way merge support using TreeOfLosers to efficiently merge multiple sorted shuffle files. The reader streams data from sorted files and returns merged results in sorted order. Reviewed By: tanjialiang, xiaoxmeng Differential Revision: D86888221
1 parent ea7978b commit 7da7140

File tree

7 files changed

+425
-47
lines changed

7 files changed

+425
-47
lines changed

presto-native-execution/presto_cpp/main/operators/LocalShuffle.cpp

Lines changed: 194 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,96 @@
1515
#include "presto_cpp/external/json/nlohmann/json.hpp"
1616
#include "presto_cpp/main/common/Configs.h"
1717

18-
#include <folly/lang/Bits.h>
18+
#include "velox/common/Casts.h"
19+
#include "velox/common/file/FileInputStream.h"
1920

20-
using namespace facebook::velox::exec;
21-
using namespace facebook::velox;
21+
#include <boost/range/algorithm/sort.hpp>
2222

2323
namespace facebook::presto::operators {
2424

2525
using json = nlohmann::json;
2626

2727
namespace {
2828

29+
using TStreamIdx = uint16_t;
30+
31+
// Default buffer size for SortedFileInputStream
32+
// This buffer is used for streaming reads from shuffle files during k-way
33+
// merge.
34+
constexpr uint64_t kDefaultInputStreamBufferSize = 8 * 1024 * 1024; // 8MB
35+
36+
/// SortedFileInputStream reads sorted (key, data) pairs from a single
37+
/// shuffle file with buffered I/O. It extends FileInputStream for efficient
38+
/// buffered I/O and implements MergeStream interface for k-way merge.
39+
class SortedFileInputStream final : public velox::common::FileInputStream,
40+
public velox::MergeStream {
41+
public:
42+
SortedFileInputStream(
43+
const std::string& filePath,
44+
TStreamIdx streamIdx,
45+
velox::memory::MemoryPool* pool,
46+
size_t bufferSize = kDefaultInputStreamBufferSize)
47+
: velox::common::FileInputStream(
48+
velox::filesystems::getFileSystem(filePath, nullptr)
49+
->openFileForRead(filePath),
50+
bufferSize,
51+
pool),
52+
streamIdx_(streamIdx) {
53+
next();
54+
}
55+
56+
~SortedFileInputStream() override = default;
57+
58+
bool next() {
59+
if (atEnd()) {
60+
currentKey_.clear();
61+
currentValue_.clear();
62+
return false;
63+
}
64+
const TRowSize keySize = folly::Endian::big(read<TRowSize>());
65+
const TRowSize valueSize = folly::Endian::big(read<TRowSize>());
66+
67+
// TODO: Optimize with zero-copy approach when data is contiguous in buffer.
68+
readString(currentKey_, keySize);
69+
readString(currentValue_, valueSize);
70+
return true;
71+
}
72+
73+
std::string_view currentKey() const {
74+
return currentKey_;
75+
}
76+
77+
std::string_view currentValue() const {
78+
return currentValue_;
79+
}
80+
81+
bool hasData() const override {
82+
return !currentValue_.empty() || !atEnd();
83+
}
84+
85+
bool operator<(const velox::MergeStream& other) const override {
86+
const auto* otherReader = static_cast<const SortedFileInputStream*>(&other);
87+
if (currentKey_ != otherReader->currentKey_) {
88+
return compareKeys(currentKey_, otherReader->currentKey_);
89+
}
90+
return streamIdx_ < otherReader->streamIdx_;
91+
}
92+
93+
private:
94+
void readString(std::string& target, TRowSize size) {
95+
if (size > 0) {
96+
target.resize(size);
97+
readBytes(reinterpret_cast<uint8_t*>(target.data()), size);
98+
} else {
99+
target.clear();
100+
}
101+
}
102+
103+
const TStreamIdx streamIdx_;
104+
std::string currentKey_;
105+
std::string currentValue_;
106+
};
107+
29108
std::vector<RowMetadata>
30109
extractRowMetadata(const char* buffer, size_t bufferSize, bool sortedShuffle) {
31110
std::vector<RowMetadata> rows;
@@ -91,13 +170,9 @@ extractRowMetadata(const char* buffer, size_t bufferSize, bool sortedShuffle) {
91170

92171
inline std::string_view
93172
extractRowData(const RowMetadata& row, const char* buffer, bool sortedShuffle) {
94-
if (sortedShuffle) {
95-
const size_t dataOffset = row.rowStart + (kUint32Size * 2) + row.keySize;
96-
return {buffer + dataOffset, row.dataSize};
97-
} else {
98-
const size_t dataOffset = row.rowStart + kUint32Size;
99-
return {buffer + dataOffset, row.dataSize};
100-
}
173+
const auto dataOffset = row.rowStart +
174+
(sortedShuffle ? (kUint32Size * 2) + row.keySize : kUint32Size);
175+
return {buffer + dataOffset, row.dataSize};
101176
}
102177

103178
std::vector<RowMetadata> extractAndSortRowMetadata(
@@ -106,10 +181,8 @@ std::vector<RowMetadata> extractAndSortRowMetadata(
106181
bool sortedShuffle) {
107182
auto rows = extractRowMetadata(buffer, bufferSize, sortedShuffle);
108183
if (!rows.empty() && sortedShuffle) {
109-
std::sort(
110-
rows.begin(),
111-
rows.end(),
112-
[buffer](const RowMetadata& lhs, const RowMetadata& rhs) {
184+
boost::range::sort(
185+
rows, [buffer](const RowMetadata& lhs, const RowMetadata& rhs) {
113186
const char* lhsKey = buffer + lhs.rowStart + (kUint32Size * 2);
114187
const char* rhsKey = buffer + rhs.rowStart + (kUint32Size * 2);
115188
return compareKeys(
@@ -147,6 +220,7 @@ LocalShuffleWriteInfo LocalShuffleWriteInfo::deserialize(
147220
jsonReadInfo.at("queryId").get_to(shuffleInfo.queryId);
148221
jsonReadInfo.at("shuffleId").get_to(shuffleInfo.shuffleId);
149222
jsonReadInfo.at("numPartitions").get_to(shuffleInfo.numPartitions);
223+
shuffleInfo.sortedShuffle = jsonReadInfo.value("sortedShuffle", false);
150224
return shuffleInfo;
151225
}
152226

@@ -157,6 +231,7 @@ LocalShuffleReadInfo LocalShuffleReadInfo::deserialize(
157231
jsonReadInfo.at("rootPath").get_to(shuffleInfo.rootPath);
158232
jsonReadInfo.at("queryId").get_to(shuffleInfo.queryId);
159233
jsonReadInfo.at("partitionIds").get_to(shuffleInfo.partitionIds);
234+
shuffleInfo.sortedShuffle = jsonReadInfo.value("sortedShuffle", false);
160235
return shuffleInfo;
161236
}
162237

@@ -276,10 +351,11 @@ void LocalShuffleWriter::collect(
276351
sortedShuffle_ || key.empty(),
277352
"key '{}' must be empty for non-sorted shuffle",
278353
key);
354+
279355
const auto rowSize = this->rowSize(key.size(), data.size());
280356
auto& buffer = inProgressPartitions_[partition];
281357
if (buffer == nullptr) {
282-
buffer = AlignedBuffer::allocate<char>(
358+
buffer = velox::AlignedBuffer::allocate<char>(
283359
std::max(static_cast<uint64_t>(rowSize), maxBytesPerPartition_),
284360
pool_,
285361
0);
@@ -319,31 +395,107 @@ LocalShuffleReader::LocalShuffleReader(
319395
fileSystem_ = velox::filesystems::getFileSystem(rootPath_, nullptr);
320396
}
321397

322-
folly::SemiFuture<std::vector<std::unique_ptr<ReadBatch>>>
323-
LocalShuffleReader::next(uint64_t maxBytes) {
324-
if (readPartitionFiles_.empty()) {
325-
readPartitionFiles_ = getReadPartitionFiles();
398+
void LocalShuffleReader::initialize() {
399+
VELOX_CHECK(!initialized_, "LocalShuffleReader already initialized");
400+
readPartitionFiles_ = getReadPartitionFiles();
401+
if (sortedShuffle_ && !readPartitionFiles_.empty()) {
402+
initSortedShuffleRead();
326403
}
327404

405+
initialized_ = true;
406+
}
407+
408+
void LocalShuffleReader::initSortedShuffleRead() {
409+
std::vector<std::unique_ptr<velox::MergeStream>> streams;
410+
streams.reserve(readPartitionFiles_.size());
411+
TStreamIdx streamIdx = 0;
412+
for (const auto& filename : readPartitionFiles_) {
413+
VELOX_CHECK(
414+
!filename.empty(),
415+
"Invalid empty shuffle file path for query {}, partitions: [{}]",
416+
queryId_,
417+
folly::join(", ", partitionIds_));
418+
auto reader =
419+
std::make_unique<SortedFileInputStream>(filename, streamIdx, pool_);
420+
if (reader->hasData()) {
421+
streams.push_back(std::move(reader));
422+
++streamIdx;
423+
}
424+
}
425+
if (!streams.empty()) {
426+
merge_ =
427+
std::make_unique<velox::TreeOfLosers<velox::MergeStream, uint16_t>>(
428+
std::move(streams));
429+
}
430+
}
431+
432+
std::vector<std::unique_ptr<ReadBatch>> LocalShuffleReader::nextSorted(
433+
uint64_t maxBytes) {
434+
std::vector<std::unique_ptr<ReadBatch>> batches;
435+
436+
if (merge_ == nullptr) {
437+
return batches;
438+
}
439+
440+
auto batchBuffer = velox::AlignedBuffer::allocate<char>(maxBytes, pool_, 0);
441+
std::vector<std::string_view> rows;
442+
uint64_t bufferUsed = 0;
443+
444+
while (auto* stream = merge_->next()) {
445+
auto* reader = velox::checked_pointer_cast<SortedFileInputStream>(stream);
446+
const auto data = reader->currentValue();
447+
448+
if (bufferUsed + data.size() > maxBytes) {
449+
if (bufferUsed > 0) {
450+
batches.push_back(
451+
std::make_unique<ReadBatch>(
452+
std::move(rows), std::move(batchBuffer)));
453+
return batches;
454+
}
455+
// Single row exceeds buffer - allocate larger buffer
456+
batchBuffer = velox::AlignedBuffer::allocate<char>(data.size(), pool_, 0);
457+
}
458+
459+
char* writePos = batchBuffer->asMutable<char>() + bufferUsed;
460+
if (!data.empty()) {
461+
memcpy(writePos, data.data(), data.size());
462+
}
463+
464+
rows.emplace_back(batchBuffer->as<char>() + bufferUsed, data.size());
465+
bufferUsed += data.size();
466+
reader->next();
467+
}
468+
469+
if (!rows.empty()) {
470+
batches.push_back(
471+
std::make_unique<ReadBatch>(std::move(rows), std::move(batchBuffer)));
472+
}
473+
474+
return batches;
475+
}
476+
477+
std::vector<std::unique_ptr<ReadBatch>> LocalShuffleReader::nextUnsorted(
478+
uint64_t maxBytes) {
328479
std::vector<std::unique_ptr<ReadBatch>> batches;
329480
uint64_t totalBytes{0};
330-
// Read files until we reach maxBytes limit or run out of files.
481+
331482
while (readPartitionFileIndex_ < readPartitionFiles_.size()) {
332483
const auto filename = readPartitionFiles_[readPartitionFileIndex_];
333484
auto file = fileSystem_->openFileForRead(filename);
334485
const auto fileSize = file->size();
335486

336-
// Stop if adding this file would exceed maxBytes (unless we haven't read
337-
// any files yet)
487+
// TODO: Refactor to use streaming I/O with bounded buffer size instead of
488+
// loading entire files into memory at once. A streaming approach would
489+
// reduce peak memory consumption and enable processing arbitrarily large
490+
// shuffle files while maintaining constant memory usage.
338491
if (!batches.empty() && totalBytes + fileSize > maxBytes) {
339492
break;
340493
}
341494

342-
auto buffer = AlignedBuffer::allocate<char>(fileSize, pool_, 0);
495+
auto buffer = velox::AlignedBuffer::allocate<char>(fileSize, pool_, 0);
343496
file->pread(0, fileSize, buffer->asMutable<void>());
344497
++readPartitionFileIndex_;
345498

346-
// Parse the buffer to extract individual rows
347499
const char* data = buffer->as<char>();
348500
const auto parsedRows = extractRowMetadata(data, fileSize, sortedShuffle_);
349501
std::vector<std::string_view> rows;
@@ -357,7 +509,17 @@ LocalShuffleReader::next(uint64_t maxBytes) {
357509
std::make_unique<ReadBatch>(std::move(rows), std::move(buffer)));
358510
}
359511

360-
return folly::makeSemiFuture(std::move(batches));
512+
return batches;
513+
}
514+
515+
folly::SemiFuture<std::vector<std::unique_ptr<ReadBatch>>>
516+
LocalShuffleReader::next(uint64_t maxBytes) {
517+
VELOX_CHECK(
518+
initialized_,
519+
"LocalShuffleReader::initialize() must be called before next()");
520+
521+
return folly::makeSemiFuture(
522+
sortedShuffle_ ? nextSorted(maxBytes) : nextUnsorted(maxBytes));
361523
}
362524

363525
void LocalShuffleReader::noMoreData(bool success) {
@@ -403,12 +565,15 @@ std::shared_ptr<ShuffleReader> LocalPersistentShuffleFactory::createReader(
403565
velox::memory::MemoryPool* pool) {
404566
const operators::LocalShuffleReadInfo readInfo =
405567
operators::LocalShuffleReadInfo::deserialize(serializedStr);
406-
return std::make_shared<LocalShuffleReader>(
568+
569+
auto reader = std::make_shared<LocalShuffleReader>(
407570
readInfo.rootPath,
408571
readInfo.queryId,
409572
readInfo.partitionIds,
410-
/*sortShuffle=*/false, // default to false for now
573+
readInfo.sortedShuffle,
411574
pool);
575+
reader->initialize();
576+
return reader;
412577
}
413578

414579
std::shared_ptr<ShuffleWriter> LocalPersistentShuffleFactory::createWriter(
@@ -418,13 +583,14 @@ std::shared_ptr<ShuffleWriter> LocalPersistentShuffleFactory::createWriter(
418583
SystemConfig::instance()->localShuffleMaxPartitionBytes();
419584
const operators::LocalShuffleWriteInfo writeInfo =
420585
operators::LocalShuffleWriteInfo::deserialize(serializedStr);
586+
421587
return std::make_shared<LocalShuffleWriter>(
422588
writeInfo.rootPath,
423589
writeInfo.queryId,
424590
writeInfo.shuffleId,
425591
writeInfo.numPartitions,
426592
maxBytesPerPartition,
427-
/*sortedShuffle=*/false, // default to false for now
593+
writeInfo.sortedShuffle,
428594
pool);
429595
}
430596

@@ -436,5 +602,4 @@ std::vector<RowMetadata> testingExtractRowMetadata(
436602
bool sortedShuffle) {
437603
return extractRowMetadata(buffer, bufferSize, sortedShuffle);
438604
}
439-
440605
} // namespace facebook::presto::operators

0 commit comments

Comments
 (0)