Skip to content

Commit 1f7a0e8

Browse files
authored
[Feature](mlu-ops): 添加真值相关算子的test_list. (Cambricon#1222)
1 parent 3a226bd commit 1f7a0e8

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
[rely_real_data]

test/mlu_op_gtest/pb_gtest/src/parser.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ namespace mluoptest {
7272

7373
// env for test negative_scale
7474
__attribute__((__unused__)) bool negative_scale_ =
75-
getEnv("MLUOP_GTEST_NEGATIVE_SCALE", false);
75+
getEnv("MLUOP_GTEST_NEGATIVE_SCALE", false);
7676

7777
Parser::~Parser() {
7878
if (proto_node_ != nullptr) {
@@ -844,7 +844,7 @@ void Parser::getTensorValueByFile(Tensor *pt, void *data, size_t count) {
844844
// random data(for cpu compute) value is fp32 definitely
845845
// valueh valuef valuei dtype is according dtype in proto
846846
void Parser::getTensorValue(Tensor *pt, void *data, ValueType value_type,
847-
size_t count) {
847+
size_t count) {
848848
switch (value_type) {
849849
case VALUE_H:
850850
getTensorValueH(pt, data, count);
@@ -933,10 +933,9 @@ bool Parser::common_threshold() {
933933
return res;
934934
}
935935

936-
std::set<Evaluator::Criterion> Parser::criterions(
937-
int index, const std::set<Evaluator::Formula> &criterions_use) {
936+
std::set<Evaluator::Criterion> Parser::criterions(int index,
937+
const std::set<Evaluator::Formula> &criterions_use) {
938938
std::set<Evaluator::Criterion> res;
939-
940939
// check if there exists complex output tensor
941940
bool has_complex_output = false;
942941
for (auto i = 0; i < proto_node_->output_size(); ++i) {
@@ -1040,7 +1039,7 @@ std::set<Evaluator::Criterion> Parser::criterions(
10401039
}
10411040

10421041
static inline bool strEndsWith(const std::string &self,
1043-
const std::string &pattern) {
1042+
const std::string &pattern) {
10441043
if (self.size() < pattern.size()) return false;
10451044
return (self.compare(self.size() - pattern.size(), pattern.size(), pattern) ==
10461045
0);
@@ -1100,9 +1099,17 @@ bool Parser::readMessageFromFile(const std::string &filename, Node *proto) {
11001099

11011100
void Parser::getTestInfo() {
11021101
std::unordered_map<std::string, std::vector<std::string>> test_info;
1102+
std::unordered_map<std::string, std::vector<std::string>> internal_test_info;
11031103
test_info =
11041104
readFileByLine("../../test/mlu_op_gtest/pb_gtest/gtest_config/test_list");
1105+
internal_test_info = readFileByLine(
1106+
"../../test/mlu_op_gtest/pb_gtest/gtest_config/internal_test_list");
1107+
11051108
list_rely_real_data_ = test_info["rely_real_data"];
1109+
std::vector<std::string> temp_list = internal_test_info["rely_real_data"];
1110+
for (auto &internal_op_name : temp_list) {
1111+
list_rely_real_data_.push_back(internal_op_name);
1112+
}
11061113
}
11071114

11081115
Evaluator::Formula Parser::cvtProtoEvaluationCriterion(int f) {

0 commit comments

Comments
 (0)