diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index 92d3e7149463..68cab83d13f2 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -28,7 +28,37 @@ #include #include +#include #include +#include +// We use c++14 std::experimental::string_view for optimizing hash computation +// only right now, its usage is limited in this file. Any broader usage of +// std::experiment in our core codebase is discouraged and needs community +// discussion for each use case. Reference for feature test macros of +// string_view: +// https://isocpp.org/std/standing-documents/sd-6-sg10-feature-test-recommendations +// https://en.cppreference.com/w/User:D41D8CD98F/feature_testing_macros +#if defined(__cpp_lib_experimental_string_view) && \ + __cpp_lib_experimental_string_view >= 201411 +#define TVM_USE_CXX14_STRING_VIEW_HASH 1 +#else +#define TVM_USE_CXX14_STRING_VIEW_HASH 0 +#endif + +// Tested with clang version 9.0.1 and c++17. It will detect string_view support +// correctly. +#if defined(__cpp_lib_string_view) && __cpp_lib_string_view >= 201606 +#define TVM_USE_CXX17_STRING_VIEW_HASH 1 +#else +#define TVM_USE_CXX17_STRING_VIEW_HASH 0 +#endif + +#if TVM_USE_CXX17_STRING_VIEW_HASH +#include +#elif TVM_USE_CXX14_STRING_VIEW_HASH +#include +#endif + #include #include #include @@ -274,7 +304,285 @@ class ADT : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(ADT, ObjectRef, ADTObj); }; +/*! \brief An object representing string. It's POD type. */ +class StringObj : public Object { + public: + /*! \brief The pointer to string data. */ + const char* data; + + /*! \brief The length of the string object. */ + uint64_t size; + + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "runtime.String"; + TVM_DECLARE_FINAL_OBJECT_INFO(StringObj, Object); + + private: + /*! \brief String object which is moved from std::string container. */ + class FromStd; + + friend class String; +}; + +/*! + * \brief Reference to string objects. + * + * \code + * + * // Example to create runtime String reference object from std::string + * std::string s = "hello world"; + * + * // You can create the reference from existing std::string + * String ref{std::move(s)}; + * + * // You can rebind the reference to another string. + * ref = std::string{"hello world2"}; + * + * // You can use the reference as hash map key + * std::unordered_map m; + * m[ref] = 1; + * + * // You can compare the reference object with other string objects + * assert(ref == "hello world", true); + * + * // You can convert the reference to std::string again + * string s2 = (string)ref; + * + * \endcode + */ +class String : public ObjectRef { + public: + /*! + * \brief Construct a new String object + * + * \param other The moved/copied std::string object + * + * \note If user passes const reference, it will trigger copy. If it's rvalue, + * it will be moved into other. + */ + explicit String(std::string other); + + /*! + * \brief Change the value the reference object points to. + * + * \param other The value for the new String + * + */ + inline String operator=(std::string other); + + /*! + * \brief Compare is equal to other std::string + * + * \param other The other string + * + * \return the comparison result + */ + bool operator==(const std::string& other) const { + return this->compare(other) == 0; + } + + /*! + * \brief Compare is not equal to other std::string + * + * \param other The other string + * + * \return the comparison result + */ + bool operator!=(const std::string& other) const { return !operator==(other); } + + /*! + * \brief Compare is equal to other char string + * + * \param other The other char string + * + * \return the comparison result + */ + bool operator==(const char* other) const { return compare(other) == 0; } + + /*! + * \brief Compare is not equal to other char string + * + * \param other The other char string + * + * \return the comparison result + */ + bool operator!=(const char* other) const { return !operator==(other); } + + /*! + * \brief Compares this String object to other + * + * \param other The String to compare with. + * + * \return zero if both char sequences compare equal. negative if this appear + * before other, positive otherwise. + */ + int compare(const String& other) const { + return memncmp(data(), other.data(), size(), other.size()); + } + + /*! + * \brief Compares this String object to other + * + * \param other The string to compare with. + * + * \return zero if both char sequences compare equal. negative if this appear + * before other, positive otherwise. + */ + int compare(const std::string& other) const { + return memncmp(data(), other.data(), size(), other.size()); + } + + /*! + * \brief Compares this to other + * + * \param other The character array to compare with. + * + * \return zero if both char sequences compare equal. negative if this appear + * before other, positive otherwise. + */ + int compare(const char* other) const { + return memncmp(data(), other, size(), std::strlen(other)); + } + + /*! + * \brief Returns a pointer to the char array in the string. + * + * \return const char* + */ + const char* c_str() const { return get()->data; } + + /*! + * \brief Return the length of the string + * + * \return size_t string length + */ + size_t size() const { + const auto* ptr = get(); + if (ptr == nullptr) { + return 0; + } + return ptr->size; + } + + /*! + * \brief Return the length of the string + * + * \return size_t string length + */ + size_t length() const { return size(); } + + /*! + * \brief Retun if the string is empty + * + * \return true if empty, false otherwise. + */ + bool empty() const { return size() == 0; } + + /*! + * \brief Return the data pointer + * + * \return const char* data pointer + */ + const char* data() const { return get()->data; } + + /*! + * \brief Convert String to an std::sting object + * + * \return std::string + */ + operator std::string() const { return std::string{get()->data, size()}; } + + TVM_DEFINE_OBJECT_REF_METHODS(String, ObjectRef, StringObj); + + private: + /*! \return the internal StringObj pointer */ + const StringObj* get() const { return operator->(); } + + /*! + * \brief Compare two char sequence + * + * \param lhs Pointers to the char array to compare + * \param rhs Pointers to the char array to compare + * \param lhs_count Length of the char array to compare + * \param rhs_count Length of the char array to compare + * \return int zero if both char sequences compare equal. negative if this + * appear before other, positive otherwise. + */ + static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, + size_t rhs_count); +}; + +/*! \brief An object representing string moved from std::string. */ +class StringObj::FromStd : public StringObj { + public: + /*! + * \brief Construct a new FromStd object + * + * \param other The moved/copied std::string object + * + * \note If user passes const reference, it will trigger copy. If it's rvalue, + * it will be moved into other. + */ + explicit FromStd(std::string other) : data_container{other} {} + + private: + /*! \brief Container that holds the memory. */ + std::string data_container; + + friend class String; +}; + +inline String::String(std::string other) { + auto ptr = make_object(std::move(other)); + ptr->size = ptr->data_container.size(); + ptr->data = ptr->data_container.data(); + data_ = std::move(ptr); +} + +inline String String::operator=(std::string other) { + String replace{std::move(other)}; + data_.swap(replace.data_); + return Downcast(*this); +} + +inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, + size_t rhs_count) { + if (lhs == rhs && lhs_count == rhs_count) return 0; + + for (size_t i = 0; i < lhs_count && i < rhs_count; ++i) { + if (lhs[i] < rhs[i]) return -1; + if (lhs[i] > rhs[i]) return 1; + } + if (lhs_count < rhs_count) { + return -1; + } else if (lhs_count > rhs_count) { + return 1; + } else { + return 0; + } +} + } // namespace runtime } // namespace tvm +namespace std { + +template <> +struct hash<::tvm::runtime::String> { + std::size_t operator()(const ::tvm::runtime::String& str) const { + // This function falls back to string copy with c++11 compiler and is + // recommended to be compiled with c++14 +#if TVM_USE_CXX17_STRING_VIEW_HASH + return std::hash{}( + std::string_view{str.data(), str.size()}); +#elif TVM_USE_CXX14_STRING_VIEW_HASH + return std::hash{}( + std::experimental::string_view{str.data(), str.size()}); +#else + return std::hash()(str.operator std::string()); +#endif + } +}; +} // namespace std + #endif // TVM_RUNTIME_CONTAINER_H_ diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index 3e6ef2138625..f1198e727401 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -19,8 +19,9 @@ #include #include -#include #include +#include + #include #include #include @@ -221,11 +222,185 @@ TEST(Map, Iterator) { using namespace tvm; PrimExpr a = 1, b = 2; Map map1{{a, b}}; - std::unordered_map - map2(map1.begin(), map1.end()); + std::unordered_map map2( + map1.begin(), map1.end()); CHECK(map2[a].as()->value == 2); } +TEST(String, MoveFromStd) { + using namespace std; + string source = "this is a string"; + string expect = source; + String s(std::move(source)); + string copy = (string)s; + CHECK_EQ(copy, expect); + CHECK_EQ(source.size(), 0); +} + +TEST(String, CopyFromStd) { + using namespace std; + string source = "this is a string"; + string expect = source; + String s{source}; + string copy = (string)s; + CHECK_EQ(copy, expect); + CHECK_EQ(source.size(), expect.size()); +} + +TEST(String, Assignment) { + using namespace std; + String s{string{"hello"}}; + s = string{"world"}; + CHECK_EQ(s == "world", true); + string s2{"world2"}; + s = std::move(s2); + CHECK_EQ(s == "world2", true); +} + +TEST(String, empty) { + using namespace std; + String s{"hello"}; + CHECK_EQ(s.empty(), false); + s = ""; + CHECK_EQ(s.empty(), true); +} + +TEST(String, Comparisons) { + using namespace std; + string source = "a string"; + string mismatch = "a string but longer"; + String s{source}; + + CHECK_EQ(s == source, true); + CHECK_EQ(s == mismatch, false); + CHECK_EQ(s == source.data(), true); + CHECK_EQ(s == mismatch.data(), false); +} + +// Check '\0' handling +TEST(String, null_byte_handling) { + using namespace std; + // Ensure string still compares equal if it contains '\0'. + string v1 = "hello world"; + size_t v1_size = v1.size(); + v1[5] = '\0'; + CHECK_EQ(v1[5], '\0'); + CHECK_EQ(v1.size(), v1_size); + String str_v1{v1}; + CHECK_EQ(str_v1.compare(v1), 0); + CHECK_EQ(str_v1.size(), v1_size); + + // Ensure bytes after '\0' are taken into account for mismatches. + string v2 = "aaa one"; + string v3 = "aaa two"; + v2[3] = '\0'; + v3[3] = '\0'; + String str_v2{v2}; + String str_v3{v3}; + CHECK_EQ(str_v2.compare(str_v3), -1); + CHECK_EQ(str_v2.size(), 7); + // strcmp won't be able to detect the mismatch + CHECK_EQ(strcmp(v2.data(), v3.data()), 0); + // string::compare can handle \0 since it knows size + CHECK_LT(v2.compare(v3), 0); + + // If there is mismatch before '\0', should still handle it. + string v4 = "acc one"; + string v5 = "abb two"; + v4[3] = '\0'; + v5[3] = '\0'; + String str_v4{v4}; + String str_v5{v5}; + CHECK_GT(str_v4.compare(str_v5), 0); + CHECK_EQ(str_v4.size(), 7); + // strcmp is able to detect the mismatch + CHECK_GT(strcmp(v4.data(), v5.data()), 0); + // string::compare can handle \0 since it knows size + CHECK_GT(v4.compare(v5), 0); +} + +TEST(String, compare_same_memory_region_different_size) { + using namespace std; + string source = "a string"; + String str_source{source}; + char* memory = const_cast(str_source.data()); + CHECK_EQ(str_source.compare(memory), 0); + // This changes the string size + memory[2] = '\0'; + // memory is logically shorter now + CHECK_GT(str_source.compare(memory), 0); +} + +TEST(String, compare) { + using namespace std; + string source = "a string"; + string mismatch1 = "a string but longer"; + string mismatch2 = "a strin"; + string mismatch3 = "a b"; + string mismatch4 = "a t"; + String str_source{source}; + String str_mismatch1{mismatch1}; + String str_mismatch2{mismatch2}; + String str_mismatch3{mismatch3}; + String str_mismatch4{mismatch4}; + + // compare with string + CHECK_EQ(str_source.compare(source), 0); + CHECK_LT(str_source.compare(mismatch1), 0); + CHECK_GT(str_source.compare(mismatch2), 0); + CHECK_GT(str_source.compare(mismatch3), 0); + CHECK_LT(str_source.compare(mismatch4), 0); + + // compare with char* + CHECK_EQ(str_source.compare(source.data()), 0); + CHECK_LT(str_source.compare(mismatch1.data()), 0); + CHECK_GT(str_source.compare(mismatch2.data()), 0); + CHECK_GT(str_source.compare(mismatch3.data()), 0); + CHECK_LT(str_source.compare(mismatch4.data()), 0); + + // compare with String + CHECK_LT(str_source.compare(str_mismatch1), 0); + CHECK_GT(str_source.compare(str_mismatch2), 0); + CHECK_GT(str_source.compare(str_mismatch3), 0); + CHECK_LT(str_source.compare(str_mismatch4), 0); +} + +TEST(String, c_str) { + using namespace std; + string source = "this is a string"; + string mismatch = "mismatch"; + String s{source}; + + CHECK_EQ(std::strcmp(s.c_str(), source.data()), 0); + CHECK_NE(std::strcmp(s.c_str(), mismatch.data()), 0); +} + +TEST(String, hash) { + using namespace std; + string source = "this is a string"; + String s{source}; + std::hash()(s); + + std::unordered_map map; + String k1{string{"k1"}}; + string v1{"v1"}; + String k2{string{"k2"}}; + string v2{"v2"}; + map[k1] = v1; + map[k2] = v2; + + CHECK_EQ(map[k1], v1); + CHECK_EQ(map[k2], v2); +} + +TEST(String, Cast) { + using namespace std; + string source = "this is a string"; + String s{source}; + ObjectRef r = s; + String s2 = Downcast(r); +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe";