Skip to content

Commit ed9466c

Browse files
authored
read file in one time (#460)
* read whole label file into memory, use string find instead of stringstream * format doc
1 parent 720c45c commit ed9466c

File tree

3 files changed

+79
-21
lines changed

3 files changed

+79
-21
lines changed

AnyBuildLogs/latest.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
20231019-111207-d314f8bf

include/pq_flash_index.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
118118
DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, LabelT label_id);
119119
std::unordered_map<std::string, LabelT> load_label_map(std::basic_istream<char> &infile);
120120
DISKANN_DLLEXPORT void parse_label_file(std::basic_istream<char> &infile, size_t &num_pts_labels);
121-
DISKANN_DLLEXPORT void get_label_file_metadata(std::basic_istream<char> &infile, uint32_t &num_pts,
121+
DISKANN_DLLEXPORT void get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts,
122122
uint32_t &num_total_labels);
123123
DISKANN_DLLEXPORT void generate_random_labels(std::vector<LabelT> &labels, const uint32_t num_labels,
124124
const uint32_t nthreads);

src/pq_flash_index.cpp

Lines changed: 77 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -610,28 +610,47 @@ void PQFlashIndex<T, LabelT>::reset_stream_for_reading(std::basic_istream<char>
610610
}
611611

612612
template <typename T, typename LabelT>
613-
void PQFlashIndex<T, LabelT>::get_label_file_metadata(std::basic_istream<char> &infile, uint32_t &num_pts,
613+
void PQFlashIndex<T, LabelT>::get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts,
614614
uint32_t &num_total_labels)
615615
{
616-
std::string line, token;
617616
num_pts = 0;
618617
num_total_labels = 0;
619618

620-
while (std::getline(infile, line))
619+
size_t file_size = fileContent.length();
620+
621+
std::string label_str;
622+
size_t cur_pos = 0;
623+
size_t next_pos = 0;
624+
while (cur_pos < file_size && cur_pos != std::string::npos)
621625
{
622-
std::istringstream iss(line);
623-
while (getline(iss, token, ','))
626+
next_pos = fileContent.find('\n', cur_pos);
627+
if (next_pos == std::string::npos)
628+
{
629+
break;
630+
}
631+
632+
size_t lbl_pos = cur_pos;
633+
size_t next_lbl_pos = 0;
634+
while (lbl_pos < next_pos && lbl_pos != std::string::npos)
624635
{
625-
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
626-
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
636+
next_lbl_pos = fileContent.find(',', lbl_pos);
637+
if (next_lbl_pos == std::string::npos) // the last label
638+
{
639+
next_lbl_pos = next_pos;
640+
}
641+
627642
num_total_labels++;
643+
644+
lbl_pos = next_lbl_pos + 1;
628645
}
646+
647+
cur_pos = next_pos + 1;
648+
629649
num_pts++;
630650
}
631651

632652
diskann::cout << "Labels file metadata: num_points: " << num_pts << ", #total_labels: " << num_total_labels
633653
<< std::endl;
634-
reset_stream_for_reading(infile);
635654
}
636655

637656
template <typename T, typename LabelT>
@@ -654,44 +673,82 @@ inline bool PQFlashIndex<T, LabelT>::point_has_label(uint32_t point_id, LabelT l
654673
template <typename T, typename LabelT>
655674
void PQFlashIndex<T, LabelT>::parse_label_file(std::basic_istream<char> &infile, size_t &num_points_labels)
656675
{
657-
std::string line, token;
676+
infile.seekg(0, std::ios::end);
677+
size_t file_size = infile.tellg();
678+
679+
std::string buffer(file_size, ' ');
680+
681+
infile.seekg(0, std::ios::beg);
682+
infile.read(&buffer[0], file_size);
683+
684+
std::string line;
658685
uint32_t line_cnt = 0;
659686

660687
uint32_t num_pts_in_label_file;
661688
uint32_t num_total_labels;
662-
get_label_file_metadata(infile, num_pts_in_label_file, num_total_labels);
689+
get_label_file_metadata(buffer, num_pts_in_label_file, num_total_labels);
663690

664691
_pts_to_label_offsets = new uint32_t[num_pts_in_label_file];
665692
_pts_to_label_counts = new uint32_t[num_pts_in_label_file];
666693
_pts_to_labels = new LabelT[num_total_labels];
667694
uint32_t labels_seen_so_far = 0;
668695

669-
while (std::getline(infile, line))
696+
std::string label_str;
697+
size_t cur_pos = 0;
698+
size_t next_pos = 0;
699+
while (cur_pos < file_size && cur_pos != std::string::npos)
670700
{
671-
std::istringstream iss(line);
672-
std::vector<uint32_t> lbls(0);
701+
next_pos = buffer.find('\n', cur_pos);
702+
if (next_pos == std::string::npos)
703+
{
704+
break;
705+
}
673706

674707
_pts_to_label_offsets[line_cnt] = labels_seen_so_far;
675708
uint32_t &num_lbls_in_cur_pt = _pts_to_label_counts[line_cnt];
676709
num_lbls_in_cur_pt = 0;
677-
getline(iss, token, '\t');
678-
std::istringstream new_iss(token);
679-
while (getline(new_iss, token, ','))
710+
711+
size_t lbl_pos = cur_pos;
712+
size_t next_lbl_pos = 0;
713+
while (lbl_pos < next_pos && lbl_pos != std::string::npos)
680714
{
681-
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
682-
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
683-
LabelT token_as_num = (LabelT)std::stoul(token);
715+
next_lbl_pos = buffer.find(',', lbl_pos);
716+
if (next_lbl_pos == std::string::npos) // the last label in the whole file
717+
{
718+
next_lbl_pos = next_pos;
719+
}
720+
721+
if (next_lbl_pos > next_pos) // the last label in one line, just read to the end
722+
{
723+
next_lbl_pos = next_pos;
724+
}
725+
726+
label_str.assign(buffer.c_str() + lbl_pos, next_lbl_pos - lbl_pos);
727+
if (label_str[label_str.length() - 1] == '\t') // '\t' won't exist in label file?
728+
{
729+
label_str.erase(label_str.length() - 1);
730+
}
731+
732+
LabelT token_as_num = (LabelT)std::stoul(label_str);
684733
_pts_to_labels[labels_seen_so_far++] = (LabelT)token_as_num;
685734
num_lbls_in_cur_pt++;
735+
736+
// move to next label
737+
lbl_pos = next_lbl_pos + 1;
686738
}
687739

740+
// move to next line
741+
cur_pos = next_pos + 1;
742+
688743
if (num_lbls_in_cur_pt == 0)
689744
{
690745
diskann::cout << "No label found for point " << line_cnt << std::endl;
691746
exit(-1);
692747
}
748+
693749
line_cnt++;
694750
}
751+
695752
num_points_labels = line_cnt;
696753
reset_stream_for_reading(infile);
697754
}
@@ -784,7 +841,7 @@ int PQFlashIndex<T, LabelT>::load_from_separate_paths(uint32_t num_threads, cons
784841
#else
785842
if (file_exists(labels_file))
786843
{
787-
std::ifstream infile(labels_file);
844+
std::ifstream infile(labels_file, std::ios::binary);
788845
if (infile.fail())
789846
{
790847
throw diskann::ANNException(std::string("Failed to open file ") + labels_file, -1);

0 commit comments

Comments
 (0)