From 639d48aacff7a058c334de862995c3569c01723f Mon Sep 17 00:00:00 2001 From: Kian O'Hara Date: Tue, 22 Aug 2023 10:50:41 +0200 Subject: [PATCH 1/5] [lang] Clean up dtypes to be alphabetical We make sure that structs come before tuples, in order to keep cleanliness for later additions. --- include/occa/dtype/dtype.hpp | 107 ++++++++------- src/dtype/dtype.cpp | 246 ++++++++++++++++------------------- 2 files changed, 178 insertions(+), 175 deletions(-) diff --git a/include/occa/dtype/dtype.hpp b/include/occa/dtype/dtype.hpp index 6fce760ac..30676b6e8 100644 --- a/include/occa/dtype/dtype.hpp +++ b/include/occa/dtype/dtype.hpp @@ -10,8 +10,8 @@ namespace occa { class dtype_t; - class dtypeTuple_t; class dtypeStruct_t; + class dtypeTuple_t; class json; typedef std::map dtypeGlobalMap_t; @@ -40,8 +40,8 @@ namespace occa { int bytes_; bool registered; - dtypeTuple_t *tuple_; dtypeStruct_t *struct_; + dtypeTuple_t *tuple_; mutable dtypeVector_t flatDtype; public: @@ -99,59 +99,60 @@ namespace occa { bool isRegistered() const; - // Tuple methods + + // Struct methods /** - * @startDoc{isTuple} + * @startDoc{isStruct} * * Description: - * Returns `true` if the data type holds a tuple type. - * For example: `occa::dtype::int2` is a tuple of two `int`s + * Returns `true` if the data type represents a struct. + * It's different that a tuple since it can keep distinct data types in its fields. * * @endDoc */ - bool isTuple() const; + bool isStruct() const; /** - * @startDoc{tupleSize} + * @startDoc{structFieldCount} * * Description: - * Return how big the tuple is, for example `int2` would return `2` + * Returns how many fields are defined in the struct * * @endDoc */ - int tupleSize() const; + int structFieldCount() const; - // Struct methods /** - * @startDoc{isStruct} + * @startDoc{structFieldNames} * * Description: - * Returns `true` if the data type represents a struct. - * It's different that a tuple since it can keep distinct data types in its fields. + * Return the list of field names for the struct * * @endDoc */ - bool isStruct() const; + const strVector& structFieldNames() const; + // Tuple methods /** - * @startDoc{structFieldCount} + * @startDoc{isTuple} * * Description: - * Returns how many fields are defined in the struct + * Returns `true` if the data type holds a tuple type. + * For example: `occa::dtype::int2` is a tuple of two `int`s * * @endDoc */ - int structFieldCount() const; + bool isTuple() const; /** - * @startDoc{structFieldNames} + * @startDoc{tupleSize} * * Description: - * Return the list of field names for the struct + * Return how big the tuple is, for example `int2` would return `2` * * @endDoc */ - const strVector& structFieldNames() const; + int tupleSize() const; /** * @startDoc{operator_bracket[0]} @@ -186,7 +187,21 @@ namespace occa { const int tupleSize_ = 1); // Dtype methods + /** + * @startDoc{setFlattenedDtype} + * + * Description: + * Add flatten dtypes of each field. + * @endDoc + */ void setFlattenedDtype() const; + /** + * @startDoc{addFlatDtypes} + * + * Description: + * Add dtypes of each field. + * @endDoc + */ void addFlatDtypes(dtypeVector_t &vec) const; /** @@ -260,29 +275,6 @@ namespace occa { const dtype_t &dtype); - //---[ Tuple ]------------------------ - class dtypeTuple_t { - friend class dtype_t; - - private: - const dtype_t dtype; - int size; - - dtypeTuple_t(const dtype_t &dtype_, - const int size_); - - dtypeTuple_t* clone() const; - - bool matches(const dtypeTuple_t &other) const; - - void addFlatDtypes(dtypeVector_t &vec) const; - - void toJson(json &j, const std::string &name = "") const; - static dtypeTuple_t fromJson(const json &j); - - std::string toString(const std::string &varName = "") const; - }; - //==================================== //---[ Struct ]----------------------- @@ -315,6 +307,33 @@ namespace occa { std::string toString(const std::string &varName = "") const; }; //==================================== + + + //---[ Tuple ]------------------------ + class dtypeTuple_t { + friend class dtype_t; + + private: + const dtype_t dtype; + int size; + + dtypeTuple_t(const dtype_t &dtype_, + const int size_); + + dtypeTuple_t* clone() const; + + bool matches(const dtypeTuple_t &other) const; + + void addFlatDtypes(dtypeVector_t &vec) const; + + void toJson(json &j, const std::string &name = "") const; + static dtypeTuple_t fromJson(const json &j); + + std::string toString(const std::string &varName = "") const; + }; + //==================================== + + } #endif diff --git a/src/dtype/dtype.cpp b/src/dtype/dtype.cpp index 1e61889b5..5eb51bcc8 100644 --- a/src/dtype/dtype.cpp +++ b/src/dtype/dtype.cpp @@ -1,3 +1,5 @@ +#include + #include #include #include @@ -12,8 +14,8 @@ namespace occa { name_(), bytes_(0), registered(false), - tuple_(NULL), - struct_(NULL) {} + struct_(NULL), + tuple_(NULL) {} dtype_t::dtype_t(const std::string &name__, const int bytes__, @@ -22,8 +24,8 @@ namespace occa { name_(name__), bytes_(bytes__), registered(registered_), - tuple_(NULL), - struct_(NULL) {} + struct_(NULL), + tuple_(NULL) {} dtype_t::dtype_t(const std::string &name__, const dtype_t &other, @@ -32,8 +34,8 @@ namespace occa { name_(), bytes_(0), registered(false), - tuple_(NULL), - struct_(NULL) { + struct_(NULL), + tuple_(NULL) { *this = other; @@ -46,8 +48,8 @@ namespace occa { name_(), bytes_(0), registered(false), - tuple_(NULL), - struct_(NULL) { + struct_(NULL), + tuple_(NULL) { *this = other; } @@ -59,30 +61,30 @@ namespace occa { const dtype_t &other = other_.self(); if (!ref || ref != &other) { - delete tuple_; delete struct_; + delete tuple_; if (other.registered) { // Clear values - ref = &other; - name_ = ""; - bytes_ = 0; - tuple_ = NULL; - struct_ = NULL; + ref = &other; + name_ = ""; + bytes_ = 0; + struct_ = NULL; + tuple_ = NULL; } else { - ref = NULL; - name_ = other.name_; - bytes_ = other.bytes_; - tuple_ = other.tuple_ ? other.tuple_->clone() : NULL; - struct_ = other.struct_ ? other.struct_->clone() : NULL; + ref = NULL; + name_ = other.name_; + bytes_ = other.bytes_; + struct_ = other.struct_ ? other.struct_->clone() : NULL; + tuple_ = other.tuple_ ? other.tuple_->clone() : NULL; } } return *this; } dtype_t::~dtype_t() { - delete tuple_; delete struct_; + delete tuple_; } const std::string& dtype_t::name() const { @@ -94,8 +96,7 @@ namespace occa { } void dtype_t::registerType() { - OCCA_ERROR("Unable to register dtype references", - ref == NULL); + OCCA_ERROR("Unable to register dtype references", ref == NULL); registered = true; } @@ -103,19 +104,6 @@ namespace occa { return self().registered; } - // Tuple methods - bool dtype_t::isTuple() const { - return self().tuple_; - } - - int dtype_t::tupleSize() const { - const dtypeTuple_t *tuplePtr = self().tuple_; - if (tuplePtr) { - return tuplePtr->size; - } - return 0; - } - // Struct methods bool dtype_t::isStruct() const { return self().struct_; @@ -131,34 +119,28 @@ namespace occa { const strVector& dtype_t::structFieldNames() const { const dtypeStruct_t *structPtr = self().struct_; - OCCA_ERROR("Cannot get fields from a non-struct dtype_t", - structPtr != NULL); + OCCA_ERROR("Cannot get fields from a non-struct dtype_t", structPtr != NULL); return structPtr->fieldNames; } const dtype_t& dtype_t::operator [] (const int field) const { const dtypeStruct_t *structPtr = self().struct_; - OCCA_ERROR("Cannot access fields from a non-struct dtype_t", - structPtr != NULL); + OCCA_ERROR("Cannot access fields from a non-struct dtype_t", structPtr != NULL); return (*structPtr)[field]; } const dtype_t& dtype_t::operator [] (const std::string &field) const { const dtypeStruct_t *structPtr = self().struct_; - OCCA_ERROR("Cannot access fields from a non-struct dtype_t", - structPtr != NULL); + OCCA_ERROR("Cannot access fields from a non-struct dtype_t", structPtr != NULL); return (*structPtr)[field]; } dtype_t& dtype_t::addField(const std::string &field, const dtype_t &dtype, const int tupleSize_) { - OCCA_ERROR("Cannot add a field to a dtype_t reference", - ref == NULL); - OCCA_ERROR("Cannot add a field to an tuple dtype_t", - tuple_ == NULL); - OCCA_ERROR("Tuple size must be a positive integer", - tupleSize_ > 0); + OCCA_ERROR("Cannot add a field to a dtype_t reference", ref == NULL); + OCCA_ERROR("Cannot add a field to an tuple dtype_t", tuple_ == NULL); + OCCA_ERROR("Tuple size must be a positive integer", tupleSize_ > 0); if (!struct_) { struct_ = new dtypeStruct_t(); @@ -169,8 +151,7 @@ namespace occa { if (tupleSize_ == 1) { struct_->addField(field, dtype); } else { - struct_->addField(field, - tuple(dtype, tupleSize_)); + struct_->addField(field, tuple(dtype, tupleSize_)); } return *this; @@ -227,17 +208,17 @@ namespace occa { } // Check type differences - if (((bool) a.tuple_ != (bool) b.tuple_) || - ((bool) a.struct_ != (bool) b.struct_)) { + if (((bool) a.struct_ != (bool) b.struct_) || + ((bool) a.tuple_ != (bool) b.tuple_)) { return false; } // Check from the dtype type - if (a.tuple_) { - return a.tuple_->matches(*(b.tuple_)); - } if (a.struct_) { return a.struct_->matches(*(b.struct_)); } + if (a.tuple_) { + return a.tuple_->matches(*(b.tuple_)); + } // Shouldn't get here return false; @@ -402,10 +383,10 @@ namespace occa { return ref->toJson(j, name); } - if (tuple_) { - return tuple_->toJson(j, name); - } else if (struct_) { + if (struct_) { return struct_->toJson(j, name); + } else if (tuple_) { + return tuple_->toJson(j, name); } j.clear(); @@ -441,10 +422,10 @@ namespace occa { OCCA_ERROR("Unknown dtype builtin [" << dtype.name_ << "]", &builtin != &dtype::none); dtype = builtin; - } else if (type == "tuple") { - dtype.tuple_ = dtypeTuple_t::fromJson(j).clone(); } else if (type == "struct") { dtype.struct_ = dtypeStruct_t::fromJson(j).clone(); + } else if (type == "tuple") { + dtype.tuple_ = dtypeTuple_t::fromJson(j).clone(); } else if (type == "custom") { dtype.bytes_ = (int) j["bytes"]; } else { @@ -465,10 +446,10 @@ namespace occa { name = self_.name_; } - if (self_.tuple_) { - ss << self_.tuple_->toString(name); - } else if (self_.struct_) { + if (self_.struct_) { ss << self_.struct_->toString(name); + } else if (self_.tuple_) { + ss << self_.tuple_->toString(name); } else { ss << name; } @@ -484,75 +465,6 @@ namespace occa { //==================================== - //---[ Tuple ]---------------------- - dtypeTuple_t::dtypeTuple_t(const dtype_t &dtype_, - const int size_) : - dtype(dtype_), - size(size_) {} - - dtypeTuple_t* dtypeTuple_t::clone() const { - return new dtypeTuple_t(dtype, size); - } - - bool dtypeTuple_t::matches(const dtypeTuple_t &other) const { - if (size != other.size) { - return false; - } - return dtype.matches(other.dtype); - } - - void dtypeTuple_t::addFlatDtypes(dtypeVector_t &vec) const { - for (int i = 0; i < size; ++i) { - dtype.addFlatDtypes(vec); - } - } - - void dtypeTuple_t::toJson(json &j, const std::string &name) const { - j.clear(); - j.asObject(); - - j["type"] = "tuple"; - if (name.size()) { - j["name"] = name; - } - j["dtype"] = dtype::toJson(dtype); - j["size"] = size; - } - - dtypeTuple_t dtypeTuple_t::fromJson(const json &j) { - OCCA_ERROR("JSON field [dtype] missing from tuple", - j.has("dtype")); - OCCA_ERROR("JSON field [size] missing from tuple", - j.has("size")); - OCCA_ERROR("JSON field [size] must be an integer", - j["size"].isNumber()); - - return dtypeTuple_t(dtype_t::fromJson(j["dtype"]), - (int) j["size"]); - } - - std::string dtypeTuple_t::toString(const std::string &varName) const { - std::stringstream ss; - - ss << dtype; - - if (varName.size()) { - ss << ' ' << varName; - } - - ss << '['; - if (size >= 0) { - ss << size; - } else { - ss << '?'; - } - ss << ']'; - - return ss.str(); - } - //==================================== - - //---[ Struct ]----------------------- dtypeStruct_t::dtypeStruct_t() {} @@ -719,4 +631,76 @@ namespace occa { return ss.str(); } //==================================== + + + + //---[ Tuple ]---------------------- + dtypeTuple_t::dtypeTuple_t(const dtype_t &dtype_, + const int size_) : + dtype(dtype_), + size(size_) {} + + dtypeTuple_t* dtypeTuple_t::clone() const { + return new dtypeTuple_t(dtype, size); + } + + bool dtypeTuple_t::matches(const dtypeTuple_t &other) const { + if (size != other.size) { + return false; + } + return dtype.matches(other.dtype); + } + + void dtypeTuple_t::addFlatDtypes(dtypeVector_t &vec) const { + for (int i = 0; i < size; ++i) { + dtype.addFlatDtypes(vec); + } + } + + void dtypeTuple_t::toJson(json &j, const std::string &name) const { + j.clear(); + j.asObject(); + + j["type"] = "tuple"; + if (name.size()) { + j["name"] = name; + } + j["dtype"] = dtype::toJson(dtype); + j["size"] = size; + } + + dtypeTuple_t dtypeTuple_t::fromJson(const json &j) { + OCCA_ERROR("JSON field [dtype] missing from tuple", + j.has("dtype")); + OCCA_ERROR("JSON field [size] missing from tuple", + j.has("size")); + OCCA_ERROR("JSON field [size] must be an integer", + j["size"].isNumber()); + + return dtypeTuple_t(dtype_t::fromJson(j["dtype"]), + (int) j["size"]); + } + + std::string dtypeTuple_t::toString(const std::string &varName) const { + std::stringstream ss; + + ss << dtype; + + if (varName.size()) { + ss << ' ' << varName; + } + + ss << '['; + if (size >= 0) { + ss << size; + } else { + ss << '?'; + } + ss << ']'; + + return ss.str(); + } + //==================================== + + } From 9705f7edc727825d14443f0049c9c65c54e1d28d Mon Sep 17 00:00:00 2001 From: Kian O'Hara Date: Tue, 22 Aug 2023 10:51:17 +0200 Subject: [PATCH 2/5] [lang] Added initial support for enums We add support for (typedef) enums by extending lang/type/enum.(cpp|hpp) adding lang/loaders/enumLoader.(cpp|hpp) and lang/enumerator.(cpp|hpp) --- include/occa/dtype/dtype.hpp | 69 +++++++ src/dtype/dtype.cpp | 169 +++++++++++++++++- src/occa/internal/lang/builtins/types.hpp | 2 +- src/occa/internal/lang/enumerator.cpp | 49 +++++ src/occa/internal/lang/enumerator.hpp | 36 ++++ src/occa/internal/lang/keyword.cpp | 2 +- src/occa/internal/lang/loaders.hpp | 1 + .../internal/lang/loaders/attributeLoader.cpp | 3 +- src/occa/internal/lang/loaders/enumLoader.cpp | 84 +++++++++ src/occa/internal/lang/loaders/enumLoader.hpp | 38 ++++ src/occa/internal/lang/loaders/typeLoader.cpp | 71 +++++++- src/occa/internal/lang/loaders/typeLoader.hpp | 10 ++ .../lang/statement/declarationStatement.cpp | 37 ++-- src/occa/internal/lang/type/enum.cpp | 80 ++++++++- src/occa/internal/lang/type/enum.hpp | 17 +- src/occa/internal/lang/type/type.hpp | 4 +- src/occa/internal/lang/type/typedef.cpp | 2 +- src/occa/internal/lang/type/vartype.cpp | 15 ++ src/occa/internal/lang/type/vartype.hpp | 2 + 19 files changed, 655 insertions(+), 36 deletions(-) create mode 100644 src/occa/internal/lang/enumerator.cpp create mode 100644 src/occa/internal/lang/enumerator.hpp create mode 100644 src/occa/internal/lang/loaders/enumLoader.cpp create mode 100644 src/occa/internal/lang/loaders/enumLoader.hpp diff --git a/include/occa/dtype/dtype.hpp b/include/occa/dtype/dtype.hpp index 30676b6e8..5e11126ab 100644 --- a/include/occa/dtype/dtype.hpp +++ b/include/occa/dtype/dtype.hpp @@ -10,6 +10,7 @@ namespace occa { class dtype_t; + class dtypeEnum_t; class dtypeStruct_t; class dtypeTuple_t; class json; @@ -40,6 +41,7 @@ namespace occa { int bytes_; bool registered; + dtypeEnum_t *enum_; dtypeStruct_t *struct_; dtypeTuple_t *tuple_; mutable dtypeVector_t flatDtype; @@ -99,6 +101,47 @@ namespace occa { bool isRegistered() const; + // Enum methods + /** + * @startDoc{isEnum} + * + * Description: + * Returns `true` if the data type represents a enum. + * It's different that a tuple since it can keep distinct data types in its fields. + * + * @endDoc + */ + bool isEnum() const; + + /** + * @startDoc{enumEnumeratorCount} + * + * Description: + * Returns how many enumerator are defined in the enum + * + * @endDoc + */ + int enumEnumeratorCount() const; + + /** + * @startDoc{enumEnumeratorNames} + * + * Description: + * Return the list of enumerator names for the enum + * + * @endDoc + */ + const strVector& enumEnumeratorNames() const; + + /** + * @startDoc{addEnumerator} + * + * Description: + * Add a enumerator to the enum type + * + * @endDoc + */ + dtype_t& addEnumerator(const std::string &enumerator); // Struct methods /** @@ -275,6 +318,32 @@ namespace occa { const dtype_t &dtype); + //---[ Enum ]----------------------- + class dtypeEnum_t { + friend class dtype_t; + + private: + strVector enumeratorNames; + + dtypeEnum_t(); + + dtypeEnum_t* clone() const; + + bool matches(const dtypeEnum_t &other) const; + + int enumeratorCount() const; + + const dtype_t& operator [] (const int enumerator) const; + const dtype_t& operator [] (const std::string &enumerator) const; + + void addEnumerator(const std::string &enumerator); + + void toJson(json &j, const std::string &name = "") const; + static dtypeEnum_t fromJson(const json &j); + + std::string toString(const std::string &varName = "") const; + }; + //==================================== //---[ Struct ]----------------------- diff --git a/src/dtype/dtype.cpp b/src/dtype/dtype.cpp index 5eb51bcc8..87b4102a3 100644 --- a/src/dtype/dtype.cpp +++ b/src/dtype/dtype.cpp @@ -14,6 +14,7 @@ namespace occa { name_(), bytes_(0), registered(false), + enum_(NULL), struct_(NULL), tuple_(NULL) {} @@ -24,6 +25,7 @@ namespace occa { name_(name__), bytes_(bytes__), registered(registered_), + enum_(NULL), struct_(NULL), tuple_(NULL) {} @@ -34,6 +36,7 @@ namespace occa { name_(), bytes_(0), registered(false), + enum_(NULL), struct_(NULL), tuple_(NULL) { @@ -48,6 +51,7 @@ namespace occa { name_(), bytes_(0), registered(false), + enum_(NULL), struct_(NULL), tuple_(NULL) { @@ -61,6 +65,7 @@ namespace occa { const dtype_t &other = other_.self(); if (!ref || ref != &other) { + delete enum_; delete struct_; delete tuple_; @@ -69,12 +74,14 @@ namespace occa { ref = &other; name_ = ""; bytes_ = 0; + enum_ = NULL; struct_ = NULL; tuple_ = NULL; } else { ref = NULL; name_ = other.name_; bytes_ = other.bytes_; + enum_ = other.enum_ ? other.enum_->clone() : NULL; struct_ = other.struct_ ? other.struct_->clone() : NULL; tuple_ = other.tuple_ ? other.tuple_->clone() : NULL; } @@ -83,6 +90,7 @@ namespace occa { } dtype_t::~dtype_t() { + delete enum_; delete struct_; delete tuple_; } @@ -104,6 +112,36 @@ namespace occa { return self().registered; } + // Enum methods + bool dtype_t::isEnum() const { + return self().enum_; + } + + int dtype_t::enumEnumeratorCount() const { + const dtypeEnum_t *enumPtr = self().enum_; + if (enumPtr) { + return enumPtr->enumeratorCount(); + } + return 0; + } + + const strVector& dtype_t::enumEnumeratorNames() const { + const dtypeEnum_t *enumPtr = self().enum_; + OCCA_ERROR("Cannot get enumerators from a non-enum dtype_t", enumPtr != NULL); + return enumPtr->enumeratorNames; + } + + dtype_t& dtype_t::addEnumerator(const std::string &enumerator) { + + if (!enum_) { + enum_ = new dtypeEnum_t(); + } + + enum_->addEnumerator(enumerator); + + return *this; + } + // Struct methods bool dtype_t::isStruct() const { return self().struct_; @@ -208,11 +246,15 @@ namespace occa { } // Check type differences - if (((bool) a.struct_ != (bool) b.struct_) || + if (((bool) a.enum_ != (bool) b.enum_) || + ((bool) a.struct_ != (bool) b.struct_) || ((bool) a.tuple_ != (bool) b.tuple_)) { return false; } // Check from the dtype type + if (a.enum_) { + return a.enum_->matches(*(b.enum_)); + } if (a.struct_) { return a.struct_->matches(*(b.struct_)); } @@ -383,7 +425,9 @@ namespace occa { return ref->toJson(j, name); } - if (struct_) { + if (enum_) { + return enum_->toJson(j, name); + } else if (struct_) { return struct_->toJson(j, name); } else if (tuple_) { return tuple_->toJson(j, name); @@ -422,6 +466,8 @@ namespace occa { OCCA_ERROR("Unknown dtype builtin [" << dtype.name_ << "]", &builtin != &dtype::none); dtype = builtin; + } else if (type == "enum") { + dtype.enum_ = dtypeEnum_t::fromJson(j).clone(); } else if (type == "struct") { dtype.struct_ = dtypeStruct_t::fromJson(j).clone(); } else if (type == "tuple") { @@ -446,7 +492,9 @@ namespace occa { name = self_.name_; } - if (self_.struct_) { + if (self_.enum_) { + ss << self_.enum_->toString(name); + } else if (self_.struct_) { ss << self_.struct_->toString(name); } else if (self_.tuple_) { ss << self_.tuple_->toString(name); @@ -465,6 +513,121 @@ namespace occa { //==================================== + //---[ Enum ]----------------------- + dtypeEnum_t::dtypeEnum_t() {} + + dtypeEnum_t* dtypeEnum_t::clone() const { + dtypeEnum_t *s = new dtypeEnum_t(); + s->enumeratorNames = enumeratorNames; + return s; + } + + bool dtypeEnum_t::matches(const dtypeEnum_t &other) const { + const int enumeratorCount = (int) enumeratorNames.size(); + if (enumeratorCount != (int) other.enumeratorNames.size()) { + return false; + } + + // Compare enumerators + const std::string *names1 = &(enumeratorNames[0]); + const std::string *names2 = &(other.enumeratorNames[0]); + for (int i = 0; i < enumeratorCount; ++i) { + const std::string &name1 = names1[i]; + const std::string &name2 = names2[i]; + if (name1 != name2) { + return false; + } + } + + return true; + } + + int dtypeEnum_t::enumeratorCount() const { + return (int) enumeratorNames.size(); + } + + void dtypeEnum_t::addEnumerator(const std::string &enumerator) { + const bool enumeratorExists = std::find(enumeratorNames.begin(), enumeratorNames.end(), enumerator) != enumeratorNames.end(); + OCCA_ERROR("Enumerator [" << enumerator << "] is already in dtype_t", !enumeratorExists); + + if (!enumeratorExists) { + enumeratorNames.push_back(enumerator); + } + } + + void dtypeEnum_t::toJson(json &j, const std::string &name) const { + j.clear(); + j.asObject(); + + j["type"] = "enum"; + if (name.size()) { + j["name"] = name; + } + + json &enumeratorsJson = j["enumerators"].asArray(); + const int enumeratorCount = (int) enumeratorNames.size(); + + const std::string *names = &(enumeratorNames[0]); + for (int i = 0; i < enumeratorCount; ++i) { + const std::string &enumeratorName = names[i]; + + json enumeratorJson; + enumeratorJson["name"] = enumeratorName; + enumeratorsJson += enumeratorJson; + } + } + + dtypeEnum_t dtypeEnum_t::fromJson(const json &j) { + OCCA_ERROR("JSON enumerator [enumerators] missing from enum", j.has("enumerators")); + OCCA_ERROR("JSON enumerator [enumerators] must be an array of dtypes", j["enumerators"].isArray()); + + const jsonArray &enumerators = j["enumerators"].array(); + const int enumeratorCount = (int) enumerators.size(); + + dtypeEnum_t enum_; + for (int i = 0; i < enumeratorCount; ++i) { + const json &enumeratorJson = enumerators[i]; + OCCA_ERROR("JSON enumerator [name] missing from enum enumerator", enumeratorJson.has("name")); + OCCA_ERROR("JSON enumerator [name] must be a string for enum enumerators", enumeratorJson["name"].isString()); + + enum_.addEnumerator(enumeratorJson["name"].string()); + } + + return enum_; + } + + std::string dtypeEnum_t::toString(const std::string &enumName) const { + std::stringstream ss; + const int enumeratorCount = (int) enumeratorNames.size(); + + ss << "enum "; + if (enumName.size()) { + ss << enumName << ' '; + } + ss << '{'; + + if (!enumeratorCount) { + ss << '}'; + return ss.str(); + } + + ss << '\n'; + + const std::string *names = &(enumeratorNames[0]); + dtype_t prevDtype = dtype::none; + for (int i = 0; i < enumeratorCount; ++i) { + const std::string &name = names[i]; + if (i) { + ss << ", "; + } + ss << name; + } + ss << "\n}"; + + return ss.str(); + } + //==================================== + //---[ Struct ]----------------------- dtypeStruct_t::dtypeStruct_t() {} diff --git a/src/occa/internal/lang/builtins/types.hpp b/src/occa/internal/lang/builtins/types.hpp index 70821999e..ce8998221 100644 --- a/src/occa/internal/lang/builtins/types.hpp +++ b/src/occa/internal/lang/builtins/types.hpp @@ -28,8 +28,8 @@ namespace occa { extern const qualifier_t virtual_; extern const qualifier_t class_; - extern const qualifier_t struct_; extern const qualifier_t enum_; + extern const qualifier_t struct_; extern const qualifier_t union_; // Windows types diff --git a/src/occa/internal/lang/enumerator.cpp b/src/occa/internal/lang/enumerator.cpp new file mode 100644 index 000000000..7801d4399 --- /dev/null +++ b/src/occa/internal/lang/enumerator.cpp @@ -0,0 +1,49 @@ +#include +#include + +namespace occa { + namespace lang { + + enumerator_t::enumerator_t(identifierToken *source_, exprNode *expr_) : + source((identifierToken*) token_t::clone(source_)), + expr(exprNode::clone(expr_)) {}; + + enumerator_t::enumerator_t(const std::string &name_, exprNode *expr_) : + source(new identifierToken(fileOrigin(), name_)), + expr(exprNode::clone(expr_)) {}; + + enumerator_t::enumerator_t(const enumerator_t &other) : + source((identifierToken*) token_t::clone(other.source)), + expr(other.expr) {}; + + enumerator_t& enumerator_t::operator = (const enumerator_t &other) { + if (this == &other) { + return *this; + } + expr = other.expr; + if (source != other.source) { + delete source; + source = (identifierToken*) token_t::clone(other.source); + } + return *this; + } + + enumerator_t& enumerator_t::clone() const { + return *(new enumerator_t(*this)); + } + + enumerator_t::~enumerator_t() {} + + void enumerator_t::clear() { + delete expr; + expr = NULL; + delete source; + source = NULL; + } + + bool enumerator_t::exists() const { + return expr; + } + + } +} \ No newline at end of file diff --git a/src/occa/internal/lang/enumerator.hpp b/src/occa/internal/lang/enumerator.hpp new file mode 100644 index 000000000..ddfc39ae8 --- /dev/null +++ b/src/occa/internal/lang/enumerator.hpp @@ -0,0 +1,36 @@ +#ifndef OCCA_INTERNAL_LANG_ENUMERATOR_HEADER +#define OCCA_INTERNAL_LANG_ENUMERATOR_HEADER +#include +#include +#include +#include + +namespace occa { + namespace lang { + class exprNode; + + //---[ Enumerator ]-------------- + class enumerator_t { + public: + identifierToken *source; + exprNode *expr; + + enumerator_t(); + enumerator_t(identifierToken *source, exprNode *expr_); + enumerator_t(const std::string &name_, exprNode *expr_); + enumerator_t(const enumerator_t &other); + + enumerator_t& operator = (const enumerator_t &other); + enumerator_t& clone() const; + + ~enumerator_t(); + + void clear(); + bool exists() const; + + }; + //================================== + } +} + +#endif diff --git a/src/occa/internal/lang/keyword.cpp b/src/occa/internal/lang/keyword.cpp index 27bb4940e..16d929362 100644 --- a/src/occa/internal/lang/keyword.cpp +++ b/src/occa/internal/lang/keyword.cpp @@ -305,8 +305,8 @@ namespace occa { keywords.add(*(new qualifierKeyword(virtual_))); keywords.add(*(new qualifierKeyword(class_))); - keywords.add(*(new qualifierKeyword(struct_))); keywords.add(*(new qualifierKeyword(enum_))); + keywords.add(*(new qualifierKeyword(struct_))); keywords.add(*(new qualifierKeyword(union_))); // Types diff --git a/src/occa/internal/lang/loaders.hpp b/src/occa/internal/lang/loaders.hpp index 28a792c07..d650b47dc 100644 --- a/src/occa/internal/lang/loaders.hpp +++ b/src/occa/internal/lang/loaders.hpp @@ -2,6 +2,7 @@ #define OCCA_INTERNAL_LANG_LOADERS_HEADER #include +#include #include #include #include diff --git a/src/occa/internal/lang/loaders/attributeLoader.cpp b/src/occa/internal/lang/loaders/attributeLoader.cpp index 31cb81840..f22546444 100644 --- a/src/occa/internal/lang/loaders/attributeLoader.cpp +++ b/src/occa/internal/lang/loaders/attributeLoader.cpp @@ -20,8 +20,7 @@ namespace occa { success(true) {} bool attributeLoader_t::loadAttributes(attributeTokenMap &attrs) { - while (success && - (token_t::safeOperatorType(tokenContext[0]) & operatorType::attribute)) { + while (success && (token_t::safeOperatorType(tokenContext[0]) & operatorType::attribute)) { loadAttribute(attrs); if (!success) { break; diff --git a/src/occa/internal/lang/loaders/enumLoader.cpp b/src/occa/internal/lang/loaders/enumLoader.cpp new file mode 100644 index 000000000..f210f0131 --- /dev/null +++ b/src/occa/internal/lang/loaders/enumLoader.cpp @@ -0,0 +1,84 @@ +#include +#include +#include +#include +#include + +namespace occa { + namespace lang { + enumLoader_t::enumLoader_t(tokenContext_t &tokenContext_, + statementContext_t &smntContext_, + parser_t &parser_ + ) : + tokenContext(tokenContext_), + smntContext(smntContext_), + parser(parser_) {} + + bool enumLoader_t::loadEnum(enum_t *&type) { + type = NULL; + + // Store type declarations in temporary block statement + blockStatement *blockSmnt = new blockStatement(smntContext.up, + tokenContext[0]); + smntContext.pushUp(*blockSmnt); + + identifierToken *nameToken = NULL; + const bool hasName = token_t::safeType(tokenContext[0]) & tokenType::identifier; + if (hasName) { + nameToken = (identifierToken*) tokenContext[0]; + ++tokenContext; + } + + opType_t opType = token_t::safeOperatorType(tokenContext[0]); + if (!(opType & (operatorType::braceStart | + operatorType::scope))) { + tokenContext.printError("Expected enum body {}"); + delete blockSmnt; + smntContext.popUp(); + return false; + } + tokenContext.pushPairRange(); + // Load type expression statements + enumeratorVector enumerators; + if (tokenContext.size()) { + identifierToken &source = (tokenContext[0]->clone()->to()); + while (source.value != "") { + exprNode *expr_ = NULL; + ++tokenContext; + if ((token_t::safeOperatorType(tokenContext[0]) & operatorType::assign)) { + ++tokenContext; + const int end = tokenContext.getNextOperator(operatorType::comma); + if (end>0) { + expr_ = parser.parseTokenContextExpression(0, end); + tokenContext += end; + } else { + expr_ = parser.parseTokenContextExpression(); + } + } + enumerators.push_back(*(new enumerator_t(&source, expr_))); + if (!(token_t::safeOperatorType(tokenContext[0]) & operatorType::comma)) { + break; + } + ++tokenContext; + source = (tokenContext[0]->clone()->to()); + } + } + delete blockSmnt; + smntContext.popUp(); + tokenContext.popAndSkip(); + + type = nameToken ? new enum_t(*nameToken) : new enum_t(); + type->addEnumerators(enumerators); + + return true; + } + + bool loadEnum(tokenContext_t &tokenContext, + statementContext_t &smntContext, + parser_t &parser, + enum_t *&type) { + enumLoader_t loader(tokenContext, smntContext, parser); + return loader.loadEnum(type); + } + } +} diff --git a/src/occa/internal/lang/loaders/enumLoader.hpp b/src/occa/internal/lang/loaders/enumLoader.hpp new file mode 100644 index 000000000..b80eb8a8a --- /dev/null +++ b/src/occa/internal/lang/loaders/enumLoader.hpp @@ -0,0 +1,38 @@ +#ifndef OCCA_INTERNAL_LANG_PARSER_ENUMLOADER_HEADER +#define OCCA_INTERNAL_LANG_PARSER_ENUMLOADER_HEADER + +#include +#include +#include + +namespace occa { + namespace lang { + class enum_t; + class parser_t; + + class enumLoader_t { + public: + tokenContext_t &tokenContext; + statementContext_t &smntContext; + parser_t &parser; + + enumLoader_t(tokenContext_t &tokenContext_, + statementContext_t &smntContext_, + parser_t &parser_); + + bool loadEnum(enum_t *&type); + + friend bool loadEnum(tokenContext_t &tokenContext, + statementContext_t &smntContext, + parser_t &parser, + enum_t *&type); + }; + + bool loadEnum(tokenContext_t &tokenContext, + statementContext_t &smntContext, + parser_t &parser, + enum_t *&type); + } +} + +#endif diff --git a/src/occa/internal/lang/loaders/typeLoader.cpp b/src/occa/internal/lang/loaders/typeLoader.cpp index c19a02b46..fc2529ce5 100644 --- a/src/occa/internal/lang/loaders/typeLoader.cpp +++ b/src/occa/internal/lang/loaders/typeLoader.cpp @@ -1,10 +1,12 @@ #include +#include #include #include #include #include #include #include +#include #include #include @@ -63,11 +65,7 @@ namespace occa { if (kType & keywordType::qualifier) { const qualifier_t &qualifier = keyword.to().qualifier; type_t *type = NULL; - if (qualifier == enum_) { - // TODO: type = loadEnum(); - token->printError("Enums are not supported yet"); - success = false; - } else if (qualifier == union_) { + if (qualifier == union_) { // TODO: type = loadUnion(); token->printError("Unions are not supported yet"); success = false; @@ -115,7 +113,10 @@ namespace occa { vartype.type = &int_; return true; } - + if (vartype.has(enum_)) { + loadEnum(vartype); + return success; + } if (vartype.has(struct_)) { loadStruct(vartype); return success; @@ -205,6 +206,45 @@ namespace occa { ++tokenContext; } + void typeLoader_t::loadEnum(vartype_t &vartype) { + enumLoader_t enumLoader(tokenContext, smntContext, parser); + + // Load enum + enum_t *enumType = NULL; + success &= enumLoader.loadEnum(enumType); + if (!success) { + return; + } + + if (!vartype.has(typedef_)) { + vartype.setType(*((identifierToken*) enumType->source), + *enumType); + return; + } + + // Load typedef name + if (!(token_t::safeType(tokenContext[0]) & tokenType::identifier)) { + tokenContext.printError("Expected typedef name"); + success = false; + return; + } + + identifierToken *nameToken = (identifierToken*) tokenContext[0]; + ++tokenContext; + + // Move the enum qualifier over + vartype_t enumVartype(*((identifierToken*) enumType->source), + *enumType); + enumVartype += enum_; + vartype -= enum_; + + typedef_t *typedefType = new typedef_t(enumVartype, *nameToken); + typedefType->declaredBaseType = true; + + vartype.setType(*nameToken, + *typedefType); + } + void typeLoader_t::loadStruct(vartype_t &vartype) { structLoader_t structLoader(tokenContext, smntContext, parser); @@ -260,6 +300,23 @@ namespace occa { return loader.loadBaseType(vartype); } + bool isLoadingEnum(tokenContext_t &tokenContext, + statementContext_t &smntContext, + parser_t &parser) { + tokenContext.push(); + tokenContext.supressErrors = true; + + vartype_t vartype; + loadType(tokenContext, smntContext, parser, vartype); + + tokenContext.supressErrors = false; + tokenContext.pop(); + + return (!vartype.isValid() && // Should not have a base type since we're defining it + vartype.has(enum_) && // Should have enum_ + !vartype.has(typedef_)); // typedef enum is not loaded as a enum + } + bool isLoadingStruct(tokenContext_t &tokenContext, statementContext_t &smntContext, parser_t &parser) { @@ -272,7 +329,7 @@ namespace occa { tokenContext.supressErrors = false; tokenContext.pop(); - return (!vartype.isValid() && // Should not have a base type since we're defining it + return (!vartype.isValid() && // Should not have a base type since we're defining it vartype.has(struct_) && // Should have struct_ !vartype.has(typedef_)); // typedef struct is not loaded as a struct } diff --git a/src/occa/internal/lang/loaders/typeLoader.hpp b/src/occa/internal/lang/loaders/typeLoader.hpp index fcbf7c1cf..260d411f7 100644 --- a/src/occa/internal/lang/loaders/typeLoader.hpp +++ b/src/occa/internal/lang/loaders/typeLoader.hpp @@ -32,6 +32,8 @@ namespace occa { void setVartypeReference(vartype_t &vartype); + void loadEnum(vartype_t &vartype); + void loadStruct(vartype_t &vartype); friend bool loadType(tokenContext_t &tokenContext, @@ -44,6 +46,10 @@ namespace occa { parser_t &parser, vartype_t &vartype); + friend bool isLoadingEnum(tokenContext_t &tokenContext, + statementContext_t &smntContext, + parser_t &parser); + friend bool isLoadingStruct(tokenContext_t &tokenContext, statementContext_t &smntContext, parser_t &parser); @@ -59,6 +65,10 @@ namespace occa { parser_t &parser, vartype_t &vartype); + bool isLoadingEnum(tokenContext_t &tokenContext, + statementContext_t &smntContext, + parser_t &parser); + bool isLoadingStruct(tokenContext_t &tokenContext, statementContext_t &smntContext, parser_t &parser); diff --git a/src/occa/internal/lang/statement/declarationStatement.cpp b/src/occa/internal/lang/statement/declarationStatement.cpp index 2d1320025..fb124cf55 100644 --- a/src/occa/internal/lang/statement/declarationStatement.cpp +++ b/src/occa/internal/lang/statement/declarationStatement.cpp @@ -90,16 +90,21 @@ namespace occa { const typedef_t *originalTypedef = dynamic_cast(var.vartype.type); + const bool typedefingEnum = ( + originalTypedef != NULL + && originalTypedef->baseType.has(enum_) + && originalTypedef->declaredBaseType + ); + const bool typedefingStruct = ( originalTypedef != NULL && originalTypedef->baseType.has(struct_) && originalTypedef->declaredBaseType ); - typedef_t *type = NULL; - if (typedefingStruct) { - // Struct typedefs already allocate a new type + if (typedefingStruct || typedefingEnum) { + // Struct & Enum typedefs already allocate a new type type = (typedef_t*) originalTypedef; } else { type = new typedef_t(var.vartype); @@ -117,19 +122,23 @@ namespace occa { success = up->addToScope(*type, force); - // This type typedef's a struct so we need to add that - // type to the current scope - if (success && typedefingStruct) { - struct_t &structType = *((struct_t*) type->baseType.type); - success = up->addToScope(structType, - force); + if (success) { + // We need to add to the current scope + if (typedefingStruct) { + // that this type typedef's a struct + struct_t &structType = *((struct_t*) type->baseType.type); + success = up->addToScope(structType, force); + } else if (typedefingEnum) { + // that this type typedef's an enum + enum_t &enumType = *((enum_t*) type->baseType.type); + success = up->addToScope(enumType, force); + } } - if (!success) { delete type; } - } else if (var.vartype.definesStruct()) { - // Struct + } else if (var.vartype.definesStruct() || var.vartype.definesEnum()) { + // Struct or Enum declaredType = true; success = up->addToScope(var.vartype.type->clone(), @@ -205,10 +214,10 @@ namespace occa { const variableDeclaration &firstDecl = declarations[0]; - // Pretty print newlines around the struct definition + // Pretty print newlines around the struct or enum definition const bool printNewlines = ( declaredType - && firstDecl.variable().vartype.definesStruct() + && (firstDecl.variable().vartype.definesStruct() || firstDecl.variable().vartype.definesEnum()) ); if (printNewlines) { diff --git a/src/occa/internal/lang/type/enum.cpp b/src/occa/internal/lang/type/enum.cpp index 97e1394e2..7957edece 100644 --- a/src/occa/internal/lang/type/enum.cpp +++ b/src/occa/internal/lang/type/enum.cpp @@ -1,23 +1,97 @@ #include +#include +#include +#include namespace occa { namespace lang { + enum_t::enum_t() : - structure_t("") {} + type_t() {} + + enum_t::enum_t(identifierToken &nameToken) : + type_t(nameToken) {} + + enum_t::enum_t(const enum_t &other) : + type_t(other) { + + const int count = (int) other.enumerators.size(); + for (int i = 0; i < count; ++i) { + enumerators.push_back( + other.enumerators[i].clone() + ); + } + } int enum_t::type() const { return typeType::enum_; } type_t& enum_t::clone() const { - return *(new enum_t()); + return *(new enum_t(*this)); } dtype_t enum_t::dtype() const { - return dtype::byte; + dtype_t dtype_; + const int enumeratorsCount = (int) enumerators.size(); + for (int i =0; i < enumeratorsCount; ++i) { + const enumerator_t &enumerator_ = enumerators[i]; + dtype_.addEnumerator(enumerator_.source->value); + } + return dtype_; + } + + void enum_t::addEnumerator(enumerator_t &enumerator_) { + enumerators.push_back( + enumerator_t( + (identifierToken*) enumerator_.source->clone(), + exprNode::clone(enumerator_.expr) + ) + ); + } + + void enum_t::addEnumerators(enumeratorVector &enumerators_) { + const int enumeratorCount = (int) enumerators_.size(); + for (int i = 0; i < enumeratorCount; ++i) { + addEnumerator(enumerators_[i]); + } + } + + void enum_t::debugPrint() const { + printer pout(io::stderr); + printDeclaration(pout); } void enum_t::printDeclaration(printer &pout) const { + const std::string name_ = name(); + if (name_.size()) { + pout << name_ << ' '; + } + + const int enumeratorsCount = (int) enumerators.size(); + if (!enumeratorsCount) { + pout << "{}"; + } else { + pout << "{\n"; + pout.addIndentation(); + pout.printIndentation(); + for (int i = 0; i < enumeratorsCount; ++i) { + const enumerator_t &enumerator_ = enumerators[i]; + if (i) { + pout << ", \n"; + pout.printIndentation(); + } + pout << enumerator_.source->value; + if (enumerator_.expr) { + pout << "="; + pout << enumerator_.expr; + } + } + pout << "\n"; + pout.removeIndentation(); + pout.printIndentation(); + pout << "}"; + } } } } diff --git a/src/occa/internal/lang/type/enum.hpp b/src/occa/internal/lang/type/enum.hpp index 6ac081e85..6b318d7cf 100644 --- a/src/occa/internal/lang/type/enum.hpp +++ b/src/occa/internal/lang/type/enum.hpp @@ -1,19 +1,30 @@ #ifndef OCCA_INTERNAL_LANG_TYPE_ENUM_HEADER #define OCCA_INTERNAL_LANG_TYPE_ENUM_HEADER -#include +#include namespace occa { namespace lang { - class enum_t : public structure_t { + class enum_t; + + class enum_t : public type_t { public: + enumeratorVector enumerators; + enum_t(); + enum_t(identifierToken &nameToken); + enum_t(const enum_t &other); + virtual int type() const; virtual type_t& clone() const; - virtual dtype_t dtype() const; + void addEnumerator(enumerator_t &enumerator_); + void addEnumerators(enumeratorVector &enumerators_); + + void debugPrint() const; + virtual void printDeclaration(printer &pout) const; }; } diff --git a/src/occa/internal/lang/type/type.hpp b/src/occa/internal/lang/type/type.hpp index 48288f20e..c1a8934b0 100644 --- a/src/occa/internal/lang/type/type.hpp +++ b/src/occa/internal/lang/type/type.hpp @@ -19,11 +19,13 @@ namespace occa { class pointer_t; class array_t; class variable_t; + class enumerator_t; typedef std::vector arrayVector; typedef std::vector pointerVector; typedef std::vector variableVector; typedef std::vector variablePtrVector; + typedef std::vectorenumeratorVector; namespace typeType { extern const int none; @@ -35,9 +37,9 @@ namespace occa { extern const int lambda; extern const int class_; + extern const int enum_; extern const int struct_; extern const int union_; - extern const int enum_; extern const int structure; } diff --git a/src/occa/internal/lang/type/typedef.cpp b/src/occa/internal/lang/type/typedef.cpp index 0be9021d2..f2a7df27a 100644 --- a/src/occa/internal/lang/type/typedef.cpp +++ b/src/occa/internal/lang/type/typedef.cpp @@ -20,7 +20,7 @@ namespace occa { declaredBaseType(other.declaredBaseType) {} typedef_t::~typedef_t() { - if (baseType.isNamed() || !baseType.has(struct_)) { + if (baseType.isNamed() || !baseType.has(struct_) || !baseType.has(enum_)) { return; } // The typedef owns the nameless struct diff --git a/src/occa/internal/lang/type/vartype.cpp b/src/occa/internal/lang/type/vartype.cpp index 9d10a2e4d..7eff190dd 100644 --- a/src/occa/internal/lang/type/vartype.cpp +++ b/src/occa/internal/lang/type/vartype.cpp @@ -355,6 +355,21 @@ namespace occa { return flat; } + bool vartype_t::definesEnum() const { + if (typeToken && type && (type->type() & typeType::enum_)) { + return (typeToken->origin == type->source->origin); + } + if (!has(typedef_)) { + return false; + } + + typedef_t &typedefType = *((typedef_t*) type); + return ( + typedefType.declaredBaseType + && typedefType.baseType.has(enum_) + ); + } + bool vartype_t::definesStruct() const { if (typeToken && type && (type->type() & typeType::struct_)) { return (typeToken->origin == type->source->origin); diff --git a/src/occa/internal/lang/type/vartype.hpp b/src/occa/internal/lang/type/vartype.hpp index 75235d76c..310900a00 100644 --- a/src/occa/internal/lang/type/vartype.hpp +++ b/src/occa/internal/lang/type/vartype.hpp @@ -96,6 +96,8 @@ namespace occa { vartype_t flatten() const; + bool definesEnum() const; + bool definesStruct() const; void printDeclaration(printer &pout, From 4d3bc1cc150b9b0ded1b61845c8a2f9f20ce7caa Mon Sep 17 00:00:00 2001 From: Kian O'Hara Date: Tue, 22 Aug 2023 09:53:59 +0200 Subject: [PATCH 3/5] [tests]Add enum tests for statement & type loading We add the necessary tests to ensure that (typedef) enums can be parsed and loaded by OCCA. --- .../internal/lang/parser/statementLoading.cpp | 110 +++++++++++++++++- .../src/internal/lang/parser/typeLoading.cpp | 27 +++++ 2 files changed, 132 insertions(+), 5 deletions(-) diff --git a/tests/src/internal/lang/parser/statementLoading.cpp b/tests/src/internal/lang/parser/statementLoading.cpp index 69988bffd..0261b6ff2 100644 --- a/tests/src/internal/lang/parser/statementLoading.cpp +++ b/tests/src/internal/lang/parser/statementLoading.cpp @@ -31,7 +31,7 @@ int main(const int argc, const char **argv) { testStructLoading(); // testClassLoading(); // testUnionLoading(); - // testEnumLoading(); + testEnumLoading(); testFunctionLoading(); testIfLoading(); testForLoading(); @@ -287,8 +287,109 @@ void testUnionLoading() { } void testEnumLoading() { - // TODO: Add enum tests + statement_t *statement = NULL; + enum_t *enumType = NULL; + typedef_t *typedefType = NULL; + +#define declSmnt statement->to() +#define getDeclType declSmnt.declarations[0].variable().vartype.type +#define setEnumType() enumType = (enum_t*) getDeclType +#define setTypedefType() typedefType = (typedef_t*) getDeclType + + // Test default enum + setStatement( + "enum Foo {\n" + " a,\n" + " b, c = 10,\n" + " d, e = 1, f,\n" + "g = 1 + 2,\n" + "h = c + g" + "};", + statementType::declaration + ); + + setEnumType(); + + ASSERT_EQ("Foo", + enumType->name()); + + ASSERT_EQ(8, + (int) enumType->enumerators.size()); + + ASSERT_EQ("a", + enumType->enumerators[0].source->value); + + ASSERT_EQ("b", + enumType->enumerators[1].source->value); + + ASSERT_EQ("c", + enumType->enumerators[2].source->value); + ASSERT_TRUE(enumType->enumerators[2].expr->canEvaluate()); + ASSERT_EQ(10, + (int) enumType->enumerators[2].expr->evaluate()); + + ASSERT_EQ("d", + enumType->enumerators[3].source->value); + + ASSERT_EQ("e", + enumType->enumerators[4].source->value); + ASSERT_TRUE(enumType->enumerators[4].expr->canEvaluate()); + ASSERT_EQ(1, + (int) enumType->enumerators[4].expr->evaluate()); + + ASSERT_EQ("f", + enumType->enumerators[5].source->value); + + ASSERT_EQ("g", + enumType->enumerators[6].source->value); + ASSERT_TRUE(enumType->enumerators[6].expr->canEvaluate()); + ASSERT_EQ(3, + (int) enumType->enumerators[6].expr->evaluate()); + + // How use Assert to check evaluate() "h"? + ASSERT_EQ("h", + enumType->enumerators[7].source->value); + ASSERT_TRUE((strcmp(enumType->enumerators[7].expr->toString().c_str(), "c + g") == 0)); + // Test default typedef enum + setStatement( + "typedef enum Bar_t {\n" + " a, b,\n" + " c = 0,\n" + "d = c + 2\n" + "} Bar;", + statementType::declaration + ); + setTypedefType(); + + ASSERT_EQ("Bar", + typedefType->name()); + + ASSERT_EQ("Bar_t", + typedefType->baseType.name()); + + // Test typedef anonymous enum + setStatement( + "typedef enum {\n" + " a, b,\n" + " c = 0,\n" + "d = c + 2\n" + "} Bar;", + statementType::declaration + ); + + setTypedefType(); + + ASSERT_EQ("Bar", + typedefType->name()); + + ASSERT_EQ(0, + (int) typedefType->baseType.name().size()); + +#undef declSmnt +#undef getDeclType +#undef getEnumType +#undef getTypedefType } void testFunctionLoading() { @@ -329,9 +430,8 @@ void testFunctionLoading() { setStatement("void foo3(int a, int b) { int x; int y; }", statementType::functionDecl); ASSERT_EQ("foo3", - - funcDecl.name()) ; - + funcDecl.name()) ; + ASSERT_EQ(&void_, funcDecl.returnType.type); ASSERT_EQ(2, diff --git a/tests/src/internal/lang/parser/typeLoading.cpp b/tests/src/internal/lang/parser/typeLoading.cpp index 8efe05a67..51b886503 100644 --- a/tests/src/internal/lang/parser/typeLoading.cpp +++ b/tests/src/internal/lang/parser/typeLoading.cpp @@ -8,6 +8,7 @@ void testVariableLoading(); void testArgumentLoading(); void testFunctionPointerLoading(); void testStructLoading(); +void testEnumLoading(); void testBaseTypeErrors(); void testPointerTypeErrors(); @@ -25,6 +26,8 @@ int main(const int argc, const char **argv) { testArgumentLoading(); testFunctionPointerLoading(); testStructLoading(); + testEnumLoading(); + std::cerr << "\n---[ Testing type errors ]----------------------\n\n"; testBaseTypeErrors(); @@ -334,6 +337,30 @@ void testStructLoading() { ASSERT_TRUE(foo4.has(struct_)); } +void testEnumLoading() { + vartype_t type; + + type = loadType("enum foo1 {}"); + ASSERT_EQ("foo1", type.name()); + ASSERT_TRUE(type.has(enum_)); + + type = loadType("enum foo2 {} bar2"); + ASSERT_EQ("foo2", type.name()); + ASSERT_TRUE(type.has(enum_)); + + type = loadType("enum {} bar3"); + ASSERT_EQ(0, (int) type.name().size()); + ASSERT_TRUE(type.has(enum_)); + + type = loadType("typedef enum foo4 {} bar4"); + ASSERT_EQ("bar4", type.name()); + ASSERT_TRUE(type.has(typedef_)); + + vartype_t foo4 = ((typedef_t*) type.type)->baseType; + ASSERT_EQ("foo4", foo4.name()); + ASSERT_TRUE(foo4.has(enum_)); +} + void testBaseTypeErrors() { vartype_t type; type = loadType("const"); From 008ed5872cdae8c76801e3732e0112d96111e624 Mon Sep 17 00:00:00 2001 From: Kian O'Hara Date: Tue, 22 Aug 2023 10:53:34 +0200 Subject: [PATCH 4/5] [lang] Support for Basic Unions Only the most basic tests with int,float,etc, work. --- include/occa/dtype/dtype.hpp | 66 +++++ src/dtype/dtype.cpp | 269 ++++++++++++++++-- src/occa/internal/lang/loaders.hpp | 1 + src/occa/internal/lang/loaders/typeLoader.cpp | 69 ++++- src/occa/internal/lang/loaders/typeLoader.hpp | 10 + .../internal/lang/loaders/unionLoader.cpp | 100 +++++++ .../internal/lang/loaders/unionLoader.hpp | 37 +++ src/occa/internal/lang/type/typedef.cpp | 4 +- src/occa/internal/lang/type/union.cpp | 77 ++++- src/occa/internal/lang/type/union.hpp | 12 +- src/occa/internal/lang/type/vartype.cpp | 15 + src/occa/internal/lang/type/vartype.hpp | 2 + 12 files changed, 631 insertions(+), 31 deletions(-) create mode 100644 src/occa/internal/lang/loaders/unionLoader.cpp create mode 100644 src/occa/internal/lang/loaders/unionLoader.hpp diff --git a/include/occa/dtype/dtype.hpp b/include/occa/dtype/dtype.hpp index 5e11126ab..1ee4ebc98 100644 --- a/include/occa/dtype/dtype.hpp +++ b/include/occa/dtype/dtype.hpp @@ -13,6 +13,7 @@ namespace occa { class dtypeEnum_t; class dtypeStruct_t; class dtypeTuple_t; + class dtypeUnion_t; class json; typedef std::map dtypeGlobalMap_t; @@ -44,6 +45,8 @@ namespace occa { dtypeEnum_t *enum_; dtypeStruct_t *struct_; dtypeTuple_t *tuple_; + dtypeUnion_t *union_; + mutable dtypeVector_t flatDtype; public: @@ -197,6 +200,39 @@ namespace occa { */ int tupleSize() const; + // Union methods + /** + * @startDoc{isUnion} + * + * Description: + * Returns `true` if the data type represents a union. + * It's different that a tuple since it can keep distinct data types in its fields. + * + * @endDoc + */ + bool isUnion() const; + + /** + * @startDoc{unionFieldCount} + * + * Description: + * Returns how many fields are defined in the union + * + * @endDoc + */ + int unionFieldCount() const; + + /** + * @startDoc{unionFieldNames} + * + * Description: + * Return the list of field names for the union + * + * @endDoc + */ + const strVector& unionFieldNames() const; + + /** * @startDoc{operator_bracket[0]} * @@ -403,6 +439,36 @@ namespace occa { //==================================== + //---[ Union ]----------------------- + class dtypeUnion_t { + friend class dtype_t; + + private: + strVector fieldNames; + dtypeNameMap_t fieldTypes; + + dtypeUnion_t(); + + dtypeUnion_t* clone() const; + + bool matches(const dtypeUnion_t &other) const; + + int fieldCount() const; + + const dtype_t& operator [] (const int field) const; + const dtype_t& operator [] (const std::string &field) const; + + void addField(const std::string &field, + const dtype_t &dtype); + + void addFlatDtypes(dtypeVector_t &vec) const; + + void toJson(json &j, const std::string &name = "") const; + static dtypeUnion_t fromJson(const json &j); + + std::string toString(const std::string &varName = "") const; + }; + //==================================== } #endif diff --git a/src/dtype/dtype.cpp b/src/dtype/dtype.cpp index 87b4102a3..b742c4c00 100644 --- a/src/dtype/dtype.cpp +++ b/src/dtype/dtype.cpp @@ -16,7 +16,8 @@ namespace occa { registered(false), enum_(NULL), struct_(NULL), - tuple_(NULL) {} + tuple_(NULL), + union_(NULL) {} dtype_t::dtype_t(const std::string &name__, const int bytes__, @@ -27,7 +28,8 @@ namespace occa { registered(registered_), enum_(NULL), struct_(NULL), - tuple_(NULL) {} + tuple_(NULL), + union_(NULL) {} dtype_t::dtype_t(const std::string &name__, const dtype_t &other, @@ -38,7 +40,8 @@ namespace occa { registered(false), enum_(NULL), struct_(NULL), - tuple_(NULL) { + tuple_(NULL), + union_(NULL) { *this = other; @@ -53,7 +56,8 @@ namespace occa { registered(false), enum_(NULL), struct_(NULL), - tuple_(NULL) { + tuple_(NULL), + union_(NULL) { *this = other; } @@ -68,6 +72,7 @@ namespace occa { delete enum_; delete struct_; delete tuple_; + delete union_; if (other.registered) { // Clear values @@ -77,6 +82,7 @@ namespace occa { enum_ = NULL; struct_ = NULL; tuple_ = NULL; + union_ = NULL; } else { ref = NULL; name_ = other.name_; @@ -84,6 +90,7 @@ namespace occa { enum_ = other.enum_ ? other.enum_->clone() : NULL; struct_ = other.struct_ ? other.struct_->clone() : NULL; tuple_ = other.tuple_ ? other.tuple_->clone() : NULL; + union_ = other.union_ ? other.union_->clone() : NULL; } } return *this; @@ -93,6 +100,7 @@ namespace occa { delete enum_; delete struct_; delete tuple_; + delete union_; } const std::string& dtype_t::name() const { @@ -162,15 +170,45 @@ namespace occa { } const dtype_t& dtype_t::operator [] (const int field) const { - const dtypeStruct_t *structPtr = self().struct_; - OCCA_ERROR("Cannot access fields from a non-struct dtype_t", structPtr != NULL); - return (*structPtr)[field]; + if (self().union_) { + const dtypeUnion_t *unionPtr = self().union_; + OCCA_ERROR("Cannot access fields from a non-union dtype_t", unionPtr != NULL); + return (*unionPtr)[field]; + } else { + const dtypeStruct_t *structPtr = self().struct_; + OCCA_ERROR("Cannot access fields from a non-struct dtype_t", structPtr != NULL); + return (*structPtr)[field]; + } } const dtype_t& dtype_t::operator [] (const std::string &field) const { - const dtypeStruct_t *structPtr = self().struct_; - OCCA_ERROR("Cannot access fields from a non-struct dtype_t", structPtr != NULL); - return (*structPtr)[field]; + if (self().union_) { + const dtypeUnion_t *unionPtr = self().union_; + OCCA_ERROR("Cannot access fields from a non-union dtype_t", unionPtr != NULL); + return (*unionPtr)[field]; + } else { + const dtypeStruct_t *structPtr = self().struct_; + OCCA_ERROR("Cannot access fields from a non-struct dtype_t", structPtr != NULL); + return (*structPtr)[field]; + } + } + // Union methods + bool dtype_t::isUnion() const { + return self().union_; + } + + int dtype_t::unionFieldCount() const { + const dtypeUnion_t *unionPtr = self().union_; + if (unionPtr) { + return unionPtr->fieldCount(); + } + return 0; + } + + const strVector& dtype_t::unionFieldNames() const { + const dtypeUnion_t *unionPtr = self().union_; + OCCA_ERROR("Cannot get fields from a non-union dtype_t", unionPtr != NULL); + return unionPtr->fieldNames; } dtype_t& dtype_t::addField(const std::string &field, @@ -180,16 +218,30 @@ namespace occa { OCCA_ERROR("Cannot add a field to an tuple dtype_t", tuple_ == NULL); OCCA_ERROR("Tuple size must be a positive integer", tupleSize_ > 0); - if (!struct_) { - struct_ = new dtypeStruct_t(); - } + if (self().union_) { + if (!union_) { + union_ = new dtypeUnion_t(); + } - bytes_ += (dtype.bytes_ * tupleSize_); + bytes_ += (dtype.bytes_ * tupleSize_); - if (tupleSize_ == 1) { - struct_->addField(field, dtype); + if (tupleSize_ == 1) { + union_->addField(field, dtype); + } else { + union_->addField(field, tuple(dtype, tupleSize_)); + } } else { - struct_->addField(field, tuple(dtype, tupleSize_)); + if (!struct_) { + struct_ = new dtypeStruct_t(); + } + + bytes_ += (dtype.bytes_ * tupleSize_); + + if (tupleSize_ == 1) { + struct_->addField(field, dtype); + } else { + struct_->addField(field, tuple(dtype, tupleSize_)); + } } return *this; @@ -208,6 +260,8 @@ namespace occa { self_.struct_->addFlatDtypes(vec); } else if (self_.tuple_) { self_.tuple_->addFlatDtypes(vec); + } else if (self_.union_) { + self_.union_->addFlatDtypes(vec); } else { vec.push_back(&self_); } @@ -240,6 +294,7 @@ namespace occa { if (a.registered != b.registered) { return false; } + // Refs didn't match and both a and b are registered if (a.registered) { return false; @@ -248,7 +303,8 @@ namespace occa { // Check type differences if (((bool) a.enum_ != (bool) b.enum_) || ((bool) a.struct_ != (bool) b.struct_) || - ((bool) a.tuple_ != (bool) b.tuple_)) { + ((bool) a.tuple_ != (bool) b.tuple_) || + ((bool) a.union_ != (bool) b.union_)) { return false; } // Check from the dtype type @@ -261,6 +317,9 @@ namespace occa { if (a.tuple_) { return a.tuple_->matches(*(b.tuple_)); } + if (a.union_) { + return a.union_->matches(*(b.union_)); + } // Shouldn't get here return false; @@ -431,6 +490,8 @@ namespace occa { return struct_->toJson(j, name); } else if (tuple_) { return tuple_->toJson(j, name); + } else if (union_) { + return union_->toJson(j, name); } j.clear(); @@ -472,6 +533,8 @@ namespace occa { dtype.struct_ = dtypeStruct_t::fromJson(j).clone(); } else if (type == "tuple") { dtype.tuple_ = dtypeTuple_t::fromJson(j).clone(); + } else if (type == "union") { + dtype.union_ = dtypeUnion_t::fromJson(j).clone(); } else if (type == "custom") { dtype.bytes_ = (int) j["bytes"]; } else { @@ -498,6 +561,8 @@ namespace occa { ss << self_.struct_->toString(name); } else if (self_.tuple_) { ss << self_.tuple_->toString(name); + } else if (self_.union_) { + ss << self_.union_->toString(name); } else { ss << name; } @@ -865,5 +930,173 @@ namespace occa { } //==================================== + //---[ Union ]----------------------- + dtypeUnion_t::dtypeUnion_t() {} + + dtypeUnion_t* dtypeUnion_t::clone() const { + dtypeUnion_t *s = new dtypeUnion_t(); + s->fieldNames = fieldNames; + s->fieldTypes = fieldTypes; + return s; + } + + bool dtypeUnion_t::matches(const dtypeUnion_t &other) const { + const int fieldCount = (int) fieldNames.size(); + if (fieldCount != (int) other.fieldNames.size()) { + return false; + } + + // Compare fields + const std::string *names1 = &(fieldNames[0]); + const std::string *names2 = &(other.fieldNames[0]); + for (int i = 0; i < fieldCount; ++i) { + const std::string &name1 = names1[i]; + const std::string &name2 = names2[i]; + if (name1 != name2) { + return false; + } + const dtype_t &dtype1 = fieldTypes.find(name1)->second; + const dtype_t &dtype2 = fieldTypes.find(name2)->second; + if (!dtype1.matches(dtype2)) { + return false; + } + } + + return true; + } + + int dtypeUnion_t::fieldCount() const { + return (int) fieldNames.size(); + } + + const dtype_t& dtypeUnion_t::operator [] (const int field) const { + OCCA_ERROR("Field index is out of bounds", + (0 <= field) && (field < (int) fieldNames.size())); + dtypeNameMap_t::const_iterator it = fieldTypes.find(fieldNames[field]); + return it->second; + } + + const dtype_t& dtypeUnion_t::operator [] (const std::string &field) const { + dtypeNameMap_t::const_iterator it = fieldTypes.find(field); + OCCA_ERROR("Field [" << field << "] is not in dtype_t", + it != fieldTypes.end()); + return it->second; + } + + void dtypeUnion_t::addField(const std::string &field, + const dtype_t &dtype) { + const bool fieldExists = (fieldTypes.find(field) != fieldTypes.end()); + OCCA_ERROR("Field [" << field << "] is already in dtype_t", + !fieldExists); + + if (!fieldExists) { + fieldNames.push_back(field); + fieldTypes[field] = dtype; + } + } + + void dtypeUnion_t::addFlatDtypes(dtypeVector_t &vec) const { + const int fieldCount = (int) fieldNames.size(); + const std::string *names = &(fieldNames[0]); + for (int i = 0; i < fieldCount; ++i) { + const std::string &name = names[i]; + const dtype_t &dtype = fieldTypes.find(name)->second; + dtype.addFlatDtypes(vec); + } + } + + void dtypeUnion_t::toJson(json &j, const std::string &name) const { + j.clear(); + j.asObject(); + + j["type"] = "union"; + if (name.size()) { + j["name"] = name; + } + + json &fieldsJson = j["fields"].asArray(); + const int fieldCount = (int) fieldNames.size(); + + const std::string *names = &(fieldNames[0]); + for (int i = 0; i < fieldCount; ++i) { + const std::string &fieldName = names[i]; + const dtype_t &dtype = fieldTypes.find(fieldName)->second; + + json fieldJson; + fieldJson["dtype"] = dtype::toJson(dtype); + fieldJson["name"] = fieldName; + fieldsJson += fieldJson; + } + } + + dtypeUnion_t dtypeUnion_t::fromJson(const json &j) { + OCCA_ERROR("JSON field [fields] missing from union", + j.has("fields")); + OCCA_ERROR("JSON field [fields] must be an array of dtypes", + j["fields"].isArray()); + + const jsonArray &fields = j["fields"].array(); + const int fieldCount = (int) fields.size(); + + dtypeUnion_t union_; + for (int i = 0; i < fieldCount; ++i) { + const json &fieldJson = fields[i]; + OCCA_ERROR("JSON field [dtype] missing from union field", + fieldJson.has("dtype")); + OCCA_ERROR("JSON field [name] missing from union field", + fieldJson.has("name")); + OCCA_ERROR("JSON field [name] must be a string for union fields", + fieldJson["name"].isString()); + + union_.addField(fieldJson["name"].string(), + dtype_t::fromJson(fieldJson["dtype"])); + } + + return union_; + } + + std::string dtypeUnion_t::toString(const std::string &varName) const { + std::stringstream ss; + const int fieldCount = (int) fieldNames.size(); + + ss << "union "; + if (varName.size()) { + ss << varName << ' '; + } + ss << '{'; + + if (!fieldCount) { + ss << '}'; + return ss.str(); + } + + ss << '\n'; + + const std::string *names = &(fieldNames[0]); + dtype_t prevDtype = dtype::none; + for (int i = 0; i < fieldCount; ++i) { + const std::string &name = names[i]; + const dtype_t &dtype = fieldTypes.find(name)->second; + + if (prevDtype != dtype) { + prevDtype = dtype; + if (i) { + ss << ";\n"; + } + ss << " " << dtype.toString(name); + } else { + if (!i) { + prevDtype = dtype; + } + ss << ", " << name; + } + } + ss << ";\n}"; + + return ss.str(); + } + //==================================== + + } diff --git a/src/occa/internal/lang/loaders.hpp b/src/occa/internal/lang/loaders.hpp index d650b47dc..febd91bd9 100644 --- a/src/occa/internal/lang/loaders.hpp +++ b/src/occa/internal/lang/loaders.hpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #endif diff --git a/src/occa/internal/lang/loaders/typeLoader.cpp b/src/occa/internal/lang/loaders/typeLoader.cpp index fc2529ce5..f56379201 100644 --- a/src/occa/internal/lang/loaders/typeLoader.cpp +++ b/src/occa/internal/lang/loaders/typeLoader.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -8,6 +9,7 @@ #include #include #include +#include #include namespace occa { @@ -65,11 +67,7 @@ namespace occa { if (kType & keywordType::qualifier) { const qualifier_t &qualifier = keyword.to().qualifier; type_t *type = NULL; - if (qualifier == union_) { - // TODO: type = loadUnion(); - token->printError("Unions are not supported yet"); - success = false; - } else if (qualifier == class_) { + if (qualifier == class_) { // TODO: type = loadClass(); token->printError("Classes are not supported yet"); success = false; @@ -121,7 +119,10 @@ namespace occa { loadStruct(vartype); return success; } - + if (vartype.has(union_)) { + loadUnion(vartype); + return success; + } tokenContext.printError("Expected a type"); return false; } @@ -284,6 +285,45 @@ namespace occa { *typedefType); } + void typeLoader_t::loadUnion(vartype_t &vartype) { + unionLoader_t unionLoader(tokenContext, smntContext, parser); + + // Load union + union_t *unionType = NULL; + success &= unionLoader.loadUnion(unionType); + if (!success) { + return; + } + + if (!vartype.has(typedef_)) { + vartype.setType(*((identifierToken*) unionType->source), + *unionType); + return; + } + + // Load typedef name + if (!(token_t::safeType(tokenContext[0]) & tokenType::identifier)) { + tokenContext.printError("Expected typedef name"); + success = false; + return; + } + + identifierToken *nameToken = (identifierToken*) tokenContext[0]; + ++tokenContext; + + // Move the union qualifier over + vartype_t unionVartype(*((identifierToken*) unionType->source), + *unionType); + unionVartype += union_; + vartype -= union_; + + typedef_t *typedefType = new typedef_t(unionVartype, *nameToken); + typedefType->declaredBaseType = true; + + vartype.setType(*nameToken, + *typedefType); + } + bool loadType(tokenContext_t &tokenContext, statementContext_t &smntContext, parser_t &parser, @@ -333,5 +373,22 @@ namespace occa { vartype.has(struct_) && // Should have struct_ !vartype.has(typedef_)); // typedef struct is not loaded as a struct } + + bool isLoadingUnion(tokenContext_t &tokenContext, + statementContext_t &smntContext, + parser_t &parser) { + tokenContext.push(); + tokenContext.supressErrors = true; + + vartype_t vartype; + loadType(tokenContext, smntContext, parser, vartype); + + tokenContext.supressErrors = false; + tokenContext.pop(); + + return (!vartype.isValid() && // Should not have a base type since we're defining it + vartype.has(union_) && // Should have union_ + !vartype.has(typedef_)); // typedef union is not loaded as a union + } } } diff --git a/src/occa/internal/lang/loaders/typeLoader.hpp b/src/occa/internal/lang/loaders/typeLoader.hpp index 260d411f7..c293b0baa 100644 --- a/src/occa/internal/lang/loaders/typeLoader.hpp +++ b/src/occa/internal/lang/loaders/typeLoader.hpp @@ -36,6 +36,8 @@ namespace occa { void loadStruct(vartype_t &vartype); + void loadUnion(vartype_t &vartype); + friend bool loadType(tokenContext_t &tokenContext, statementContext_t &smntContext, parser_t &parser, @@ -53,6 +55,10 @@ namespace occa { friend bool isLoadingStruct(tokenContext_t &tokenContext, statementContext_t &smntContext, parser_t &parser); + + friend bool isLoadingUnion(tokenContext_t &tokenContext, + statementContext_t &smntContext, + parser_t &parser); }; bool loadType(tokenContext_t &tokenContext, @@ -72,6 +78,10 @@ namespace occa { bool isLoadingStruct(tokenContext_t &tokenContext, statementContext_t &smntContext, parser_t &parser); + + bool isLoadingUnion(tokenContext_t &tokenContext, + statementContext_t &smntContext, + parser_t &parser); } } diff --git a/src/occa/internal/lang/loaders/unionLoader.cpp b/src/occa/internal/lang/loaders/unionLoader.cpp new file mode 100644 index 000000000..72bfa9117 --- /dev/null +++ b/src/occa/internal/lang/loaders/unionLoader.cpp @@ -0,0 +1,100 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace occa { + namespace lang { + unionLoader_t::unionLoader_t(tokenContext_t &tokenContext_, + statementContext_t &smntContext_, + parser_t &parser_) : + tokenContext(tokenContext_), + smntContext(smntContext_), + parser(parser_) {} + + bool unionLoader_t::loadUnion(union_t *&type) { + type = NULL; + + // Store type declarations in temporary block statement + blockStatement *blockSmnt = new blockStatement(smntContext.up, + tokenContext[0]); + smntContext.pushUp(*blockSmnt); + + identifierToken *nameToken = NULL; + const bool hasName = token_t::safeType(tokenContext[0]) & tokenType::identifier; + if (hasName) { + nameToken = (identifierToken*) tokenContext[0]; + ++tokenContext; + } + + opType_t opType = token_t::safeOperatorType(tokenContext[0]); + if (!(opType & (operatorType::braceStart | + operatorType::scope))) { + tokenContext.printError("Expected union body {}"); + delete blockSmnt; + smntContext.popUp(); + return false; + } + + tokenContext.pushPairRange(); + + // Load type declaration statements + statement_t *smnt = parser.getNextStatement(); + variableVector fields; + while (smnt) { + const int sType = smnt->type(); + if (!(sType & statementType::declaration)) { + if (sType & (statementType::function | + statementType::functionDecl)) { + smnt->printError("Union functions are not supported yet"); + } else if (sType & statementType::classAccess) { + smnt->printError("Access modifiers are not supported yet"); + } else { + smnt->printError("Expected variable declaration statements"); + } + delete blockSmnt; + smntContext.popUp(); + return false; + } + + variableDeclarationVector &declarations = (smnt + ->to() + .declarations); + const int varCount = (int) declarations.size(); + for (int i = 0; i < varCount; ++i) { + variableDeclaration &decl = declarations[i]; + if (decl.value) { + decl.value->printError("Union fields cannot have default values"); + delete blockSmnt; + smntContext.popUp(); + return false; + } + fields.push_back(decl.variable()); + } + delete smnt; + smnt = parser.getNextStatement(); + } + + delete blockSmnt; + smntContext.popUp(); + + tokenContext.popAndSkip(); + + type = nameToken ? new union_t(*nameToken) : new union_t(); + type->addFields(fields); + + return true; + } + + bool loadUnion(tokenContext_t &tokenContext, + statementContext_t &smntContext, + parser_t &parser, + union_t *&type) { + unionLoader_t loader(tokenContext, smntContext, parser); + return loader.loadUnion(type); + } + } +} diff --git a/src/occa/internal/lang/loaders/unionLoader.hpp b/src/occa/internal/lang/loaders/unionLoader.hpp new file mode 100644 index 000000000..f508aca13 --- /dev/null +++ b/src/occa/internal/lang/loaders/unionLoader.hpp @@ -0,0 +1,37 @@ +#ifndef OCCA_INTERNAL_LANG_PARSER_UNIONLOADER_HEADER +#define OCCA_INTERNAL_LANG_PARSER_UNIONLOADER_HEADER + +#include +#include + +namespace occa { + namespace lang { + class union_t; + class parser_t; + + class unionLoader_t { + public: + tokenContext_t &tokenContext; + statementContext_t &smntContext; + parser_t &parser; + + unionLoader_t(tokenContext_t &tokenContext_, + statementContext_t &smntContext_, + parser_t &parser_); + + bool loadUnion(union_t *&type); + + friend bool loadUnion(tokenContext_t &tokenContext, + statementContext_t &smntContext, + parser_t &parser, + union_t *&type); + }; + + bool loadUnion(tokenContext_t &tokenContext, + statementContext_t &smntContext, + parser_t &parser, + union_t *&type); + } +} + +#endif diff --git a/src/occa/internal/lang/type/typedef.cpp b/src/occa/internal/lang/type/typedef.cpp index f2a7df27a..bd1f35208 100644 --- a/src/occa/internal/lang/type/typedef.cpp +++ b/src/occa/internal/lang/type/typedef.cpp @@ -20,10 +20,10 @@ namespace occa { declaredBaseType(other.declaredBaseType) {} typedef_t::~typedef_t() { - if (baseType.isNamed() || !baseType.has(struct_) || !baseType.has(enum_)) { + if (baseType.isNamed() || !baseType.has(enum_) || !baseType.has(struct_) || !baseType.has(union_)) { return; } - // The typedef owns the nameless struct + // The typedef owns the nameless object delete baseType.type; } diff --git a/src/occa/internal/lang/type/union.cpp b/src/occa/internal/lang/type/union.cpp index 32a01d9b0..729dce3f2 100644 --- a/src/occa/internal/lang/type/union.cpp +++ b/src/occa/internal/lang/type/union.cpp @@ -1,23 +1,94 @@ #include +#include +#include namespace occa { namespace lang { union_t::union_t() : - structure_t("") {} + type_t() {} + + union_t::union_t(identifierToken &nameToken) : + type_t(nameToken) {} + + union_t::union_t(const union_t &other) : + type_t(other) { + + const int count = (int) other.fields.size(); + for (int i = 0; i < count; ++i) { + fields.push_back( + other.fields[i].clone() + ); + } + } int union_t::type() const { return typeType::union_; } type_t& union_t::clone() const { - return *(new union_t()); + return *(new union_t(*this)); } dtype_t union_t::dtype() const { - return dtype::byte; + dtype_t dtype_; + + const int fieldCount = (int) fields.size(); + for (int i = 0; i < fieldCount; ++i) { + const variable_t &var = fields[i]; + dtype_.addField(var.name(), + var.dtype()); + } + + return dtype_; + } + + void union_t::addField(variable_t &field) { + fields.push_back(field.clone()); + } + + void union_t::addFields(variableVector &fields_) { + const int fieldCount = (int) fields_.size(); + for (int i = 0; i < fieldCount; ++i) { + fields.push_back(fields_[i].clone()); + } } void union_t::printDeclaration(printer &pout) const { + const std::string name_ = name(); + if (name_.size()) { + pout << name_ << ' '; + } + + const int fieldCount = (int) fields.size(); + if (!fieldCount) { + pout << "{}"; + } else { + vartype_t prevVartype; + + pout << "{\n"; + pout.addIndentation(); + pout.printIndentation(); + + for (int i = 0; i < fieldCount; ++i) { + const variable_t &var = fields[i]; + if (prevVartype != var.vartype) { + if (i) { + pout << ";\n"; + pout.printIndentation(); + } + prevVartype = var.vartype; + var.printDeclaration(pout); + } else { + pout << ", "; + var.printExtraDeclaration(pout); + } + } + pout << ";\n"; + + pout.removeIndentation(); + pout.printIndentation(); + pout << "}"; + } } } } diff --git a/src/occa/internal/lang/type/union.hpp b/src/occa/internal/lang/type/union.hpp index 51dc7ed59..ec48c2b4a 100644 --- a/src/occa/internal/lang/type/union.hpp +++ b/src/occa/internal/lang/type/union.hpp @@ -1,19 +1,27 @@ #ifndef OCCA_INTERNAL_LANG_TYPE_UNION_HEADER #define OCCA_INTERNAL_LANG_TYPE_UNION_HEADER -#include +#include namespace occa { namespace lang { - class union_t : public structure_t { + class union_t : public type_t { public: + variableVector fields; + union_t(); + union_t(identifierToken &nameToken); + + union_t(const union_t &other); virtual int type() const; virtual type_t& clone() const; virtual dtype_t dtype() const; + void addField(variable_t &var); + void addFields(variableVector &fields_); + virtual void printDeclaration(printer &pout) const; }; } diff --git a/src/occa/internal/lang/type/vartype.cpp b/src/occa/internal/lang/type/vartype.cpp index 7eff190dd..589472fae 100644 --- a/src/occa/internal/lang/type/vartype.cpp +++ b/src/occa/internal/lang/type/vartype.cpp @@ -385,6 +385,21 @@ namespace occa { ); } + bool vartype_t::definesUnion() const { + if (typeToken && type && (type->type() & typeType::union_)) { + return (typeToken->origin == type->source->origin); + } + if (!has(typedef_)) { + return false; + } + + typedef_t &typedefType = *((typedef_t*) type); + return ( + typedefType.declaredBaseType + && typedefType.baseType.has(union_) + ); + } + void vartype_t::printDeclaration(printer &pout, const std::string &varName, const vartypePrintType_t printType) const { diff --git a/src/occa/internal/lang/type/vartype.hpp b/src/occa/internal/lang/type/vartype.hpp index 310900a00..6515e8ac8 100644 --- a/src/occa/internal/lang/type/vartype.hpp +++ b/src/occa/internal/lang/type/vartype.hpp @@ -100,6 +100,8 @@ namespace occa { bool definesStruct() const; + bool definesUnion() const; + void printDeclaration(printer &pout, const std::string &varName, const vartypePrintType_t printType = vartypePrintType_t::type) const; From 84f843e30aa52fd3eb0017e9d9bccdf30ad90a3e Mon Sep 17 00:00:00 2001 From: Kian O'Hara Date: Tue, 22 Aug 2023 09:56:24 +0200 Subject: [PATCH 5/5] [tests]Add union tests, statement & type loading We add the necessary tests to ensure that (typedef) unions can be parsed and loaded by OCCA. --- .../internal/lang/parser/statementLoading.cpp | 79 ++++++++++++++++++- .../src/internal/lang/parser/typeLoading.cpp | 26 ++++++ 2 files changed, 103 insertions(+), 2 deletions(-) diff --git a/tests/src/internal/lang/parser/statementLoading.cpp b/tests/src/internal/lang/parser/statementLoading.cpp index 0261b6ff2..e78dff690 100644 --- a/tests/src/internal/lang/parser/statementLoading.cpp +++ b/tests/src/internal/lang/parser/statementLoading.cpp @@ -30,8 +30,8 @@ int main(const int argc, const char **argv) { testNamespaceLoading(); testStructLoading(); // testClassLoading(); - // testUnionLoading(); testEnumLoading(); + testUnionLoading(); testFunctionLoading(); testIfLoading(); testForLoading(); @@ -283,7 +283,82 @@ void testClassLoading() { } void testUnionLoading() { - // TODO: Add union tests + statement_t *statement = NULL; + union_t *unionType = NULL; + typedef_t *typedefType = NULL; + +#define declSmnt statement->to() +#define getDeclType declSmnt.declarations[0].variable().vartype.type +#define setUnionType() unionType = (union_t*) getDeclType +#define setTypedefType() typedefType = (typedef_t*) getDeclType + + // Test default union + setStatement( + "union idx3 {\n" + " int i, *j, &k;\n" + "};", + statementType::declaration + ); + + setUnionType(); + + ASSERT_EQ("idx3", + unionType->name()); + + ASSERT_EQ(3, + (int) unionType->fields.size()); + + ASSERT_EQ("i", + unionType->fields[0].name()); + ASSERT_EQ(&int_, + unionType->fields[0].vartype.type); + + ASSERT_EQ("j", + unionType->fields[1].name()); + ASSERT_EQ(&int_, + unionType->fields[1].vartype.type); + + ASSERT_EQ("k", + unionType->fields[2].name()); + ASSERT_EQ(&int_, + unionType->fields[2].vartype.type); + + // Test default typedef union + setStatement( + "typedef union idx3_t {\n" + " int i, *j, &k;\n" + "} idx3;", + statementType::declaration + ); + + setTypedefType(); + + ASSERT_EQ("idx3", + typedefType->name()); + + ASSERT_EQ("idx3_t", + typedefType->baseType.name()); + + // Test typedef anonymous union + setStatement( + "typedef union {\n" + " int i, *j, &k;\n" + "} idx3;", + statementType::declaration + ); + + setTypedefType(); + + ASSERT_EQ("idx3", + typedefType->name()); + + ASSERT_EQ(0, + (int) typedefType->baseType.name().size()); + +#undef declSmnt +#undef getDeclType +#undef getUnionType +#undef getTypedefType } void testEnumLoading() { diff --git a/tests/src/internal/lang/parser/typeLoading.cpp b/tests/src/internal/lang/parser/typeLoading.cpp index 51b886503..9a421d398 100644 --- a/tests/src/internal/lang/parser/typeLoading.cpp +++ b/tests/src/internal/lang/parser/typeLoading.cpp @@ -9,6 +9,7 @@ void testArgumentLoading(); void testFunctionPointerLoading(); void testStructLoading(); void testEnumLoading(); +void testUnionLoading(); void testBaseTypeErrors(); void testPointerTypeErrors(); @@ -28,6 +29,7 @@ int main(const int argc, const char **argv) { testStructLoading(); testEnumLoading(); + testUnionLoading(); std::cerr << "\n---[ Testing type errors ]----------------------\n\n"; testBaseTypeErrors(); @@ -361,6 +363,30 @@ void testEnumLoading() { ASSERT_TRUE(foo4.has(enum_)); } +void testUnionLoading() { + vartype_t type; + + type = loadType("union foo1 {}"); + ASSERT_EQ("foo1", type.name()); + ASSERT_TRUE(type.has(union_)); + + type = loadType("union foo2 {} bar2"); + ASSERT_EQ("foo2", type.name()); + ASSERT_TRUE(type.has(union_)); + + type = loadType("union {} bar3"); + ASSERT_EQ(0, (int) type.name().size()); + ASSERT_TRUE(type.has(union_)); + + type = loadType("typedef union foo4 {} bar4"); + ASSERT_EQ("bar4", type.name()); + ASSERT_TRUE(type.has(typedef_)); + + vartype_t foo4 = ((typedef_t*) type.type)->baseType; + ASSERT_EQ("foo4", foo4.name()); + ASSERT_TRUE(foo4.has(union_)); +} + void testBaseTypeErrors() { vartype_t type; type = loadType("const");