From 672053941b7b8e7191060dd19786e8e1635e1007 Mon Sep 17 00:00:00 2001 From: Samin <69201742+saminbassiri@users.noreply.github.com> Date: Fri, 18 Oct 2024 01:11:13 +0200 Subject: [PATCH] [DAPHNE-#629] Efficient processing of string data in DenseMatrix (#797) This commit addresses issue #629 by enhancing the string support in the DAPHNE runtime, making it practical to process string data sets. The main addition is generalizing or specializing current template structures for the FixedStr16 and std::string classes. Generalizations include, e.g., using std::fill()/std::copy() instead of memset()/memcpy() and using the newly introduced ValueTypeUtils::defaultValue instead of 0. Key Features Implemented - Focus on the runtime part, i.e., data structures and kernels; no DaphneDSL/compiler integration yet. - FixedStr16 Class: A fixed-size string class with a 16-character buffer. - DenseMatrix for string value type: Generalize the current DenseMatrix class to support string data. - I/O Operations: Reading CSV files containing string columns. - Kernels on string-valued matrices. - Convert string matrices to numeric: - oneHot: Applies one-hot-encoding to the given (n x m) matrix of strings. - recode: Applies dictionary encoding to the given (n x 1) matrix. - cast: String value and matrix objects can be cast to a particular numeric type. - Comparison operations: - Element-wise binary operators for comparing DenseMatrix<std::string> and/or DenseMatrix<FixedStr16>. - Element-wise unary/binary string operations: - Element-wise unary operators for converting all strings in a matrix to lower/upper case. - Element-wise binary operators for concatenating all corresponding strings in two matrices. - Source operations: - fill: Creates a matrix and sets all elements to a particular value. - Reorganization: - transpose: Transposes a given matrix. Testing - Initial unit tests for DenseMatrix<std::string> and DenseMatrix<FixedStr16> have been implemented, verifying functionality for newly added features and data types. 
--- src/runtime/local/datagen/GenGivenVals.h | 2 +- .../local/datastructures/DenseMatrix.cpp | 5 +- .../local/datastructures/DenseMatrix.h | 30 ++-- .../datastructures/FixedSizeStringValueType.h | 146 ++++++++++++++++++ .../local/datastructures/ValueTypeCode.h | 6 +- .../local/datastructures/ValueTypeUtils.cpp | 16 ++ .../local/datastructures/ValueTypeUtils.h | 19 +++ src/runtime/local/io/ReadCsvFile.h | 65 +++++++- src/runtime/local/io/utils.h | 57 +++++++ src/runtime/local/kernels/BinaryOpCode.h | 37 ++++- src/runtime/local/kernels/CastObj.h | 5 +- src/runtime/local/kernels/CastSca.h | 54 ++++++- src/runtime/local/kernels/EwBinarySca.h | 1 + src/runtime/local/kernels/EwUnarySca.h | 19 +++ src/runtime/local/kernels/Fill.h | 6 +- src/runtime/local/kernels/OneHot.h | 50 ++++++ src/runtime/local/kernels/UnaryOpCode.h | 19 ++- .../local/datastructures/DenseMatrixTest.cpp | 120 ++++++++++++++ test/runtime/local/io/ReadCsvStr.csv | 10 ++ test/runtime/local/io/ReadCsvStr.csv.meta | 5 + test/runtime/local/io/ReadCsvTest.cpp | 48 ++++++ test/runtime/local/kernels/CastObjTest.cpp | 76 +++++++++ test/runtime/local/kernels/CastScaTest.cpp | 29 +++- .../runtime/local/kernels/EwBinaryMatTest.cpp | 92 +++++++++++ .../runtime/local/kernels/EwBinaryScaTest.cpp | 84 ++++++++++ test/runtime/local/kernels/EwUnaryMatTest.cpp | 30 ++++ test/runtime/local/kernels/FillTest.cpp | 37 ++++- test/runtime/local/kernels/OneHotTest.cpp | 62 ++++++++ test/runtime/local/kernels/RecodeTest.cpp | 42 ++++- test/runtime/local/kernels/TransposeTest.cpp | 45 +++++- 30 files changed, 1184 insertions(+), 33 deletions(-) create mode 100644 src/runtime/local/datastructures/FixedSizeStringValueType.h create mode 100644 test/runtime/local/io/ReadCsvStr.csv create mode 100644 test/runtime/local/io/ReadCsvStr.csv.meta diff --git a/src/runtime/local/datagen/GenGivenVals.h b/src/runtime/local/datagen/GenGivenVals.h index e7398f9d5..0d62d1889 100644 --- a/src/runtime/local/datagen/GenGivenVals.h +++ 
b/src/runtime/local/datagen/GenGivenVals.h @@ -97,7 +97,7 @@ template struct GenGivenVals> { "divisible by given number of rows"); const size_t numCols = numCells / numRows; auto res = DataObjectFactory::create>(numRows, numCols, false); - memcpy(res->getValues(), elements.data(), numCells * sizeof(VT)); + std::copy(elements.begin(), elements.end(), res->getValues()); return res; } }; diff --git a/src/runtime/local/datastructures/DenseMatrix.cpp b/src/runtime/local/datastructures/DenseMatrix.cpp index 2d343ef15..463a5ca0c 100644 --- a/src/runtime/local/datastructures/DenseMatrix.cpp +++ b/src/runtime/local/datastructures/DenseMatrix.cpp @@ -78,8 +78,9 @@ DenseMatrix::DenseMatrix(size_t maxNumRows, size_t numCols, bool zero } else { AllocationDescriptorHost myHostAllocInfo; alloc_shared_values(); + if (zero) - memset(values.get(), 0, maxNumRows * numCols * sizeof(ValueType)); + std::fill(values.get(), values.get() + maxNumRows * numCols, ValueTypeUtils::defaultValue); new_data_placement = this->mdo->addDataPlacement(&myHostAllocInfo); } this->mdo->addLatest(new_data_placement->dp_id); @@ -341,3 +342,5 @@ template class DenseMatrix; template class DenseMatrix; template class DenseMatrix; template class DenseMatrix; +template class DenseMatrix; +template class DenseMatrix; diff --git a/src/runtime/local/datastructures/DenseMatrix.h b/src/runtime/local/datastructures/DenseMatrix.h index afd7f6b3f..ec4c9bd05 100644 --- a/src/runtime/local/datastructures/DenseMatrix.h +++ b/src/runtime/local/datastructures/DenseMatrix.h @@ -124,18 +124,22 @@ template class DenseMatrix : public Matrix { if (rowSkip == numCols || lastAppendedRowIdx == rowIdx) { const size_t startPosIncl = pos(lastAppendedRowIdx, lastAppendedColIdx) + 1; const size_t endPosExcl = pos(rowIdx, colIdx); + if (startPosIncl < endPosExcl) - memset(values.get() + startPosIncl, 0, (endPosExcl - startPosIncl) * sizeof(ValueType)); + std::fill(values.get() + startPosIncl, values.get() + endPosExcl, + 
ValueTypeUtils::defaultValue); } else { auto v = values.get() + lastAppendedRowIdx * rowSkip; - memset(v + lastAppendedColIdx + 1, 0, (numCols - lastAppendedColIdx - 1) * sizeof(ValueType)); + std::fill(v + lastAppendedColIdx + 1, v + numCols, ValueTypeUtils::defaultValue); + v += rowSkip; + for (size_t r = lastAppendedRowIdx + 1; r < rowIdx; r++) { - memset(v, 0, numCols * sizeof(ValueType)); + std::fill(v, v + numCols, ValueTypeUtils::defaultValue); v += rowSkip; } if (colIdx) - memset(v, 0, (colIdx - 1) * sizeof(ValueType)); + std::fill(v, v + colIdx - 1, ValueTypeUtils::defaultValue); } } @@ -258,7 +262,7 @@ template class DenseMatrix : public Matrix { void prepareAppend() override { // The matrix might be empty. if (numRows != 0 && numCols != 0) - values.get()[0] = ValueType(0); + values.get()[0] = ValueType(ValueTypeUtils::defaultValue); lastAppendedRowIdx = 0; lastAppendedColIdx = 0; } @@ -277,7 +281,7 @@ template class DenseMatrix : public Matrix { // The matrix might be empty. if ((numRows != 0 && numCols != 0) && ((lastAppendedRowIdx + 1 < numRows) || (lastAppendedColIdx + 1 < numCols))) - append(numRows - 1, numCols - 1, ValueType(0)); + append(numRows - 1, numCols - 1, ValueType(ValueTypeUtils::defaultValue)); } void print(std::ostream &os) const override { @@ -327,17 +331,15 @@ template class DenseMatrix : public Matrix { if (valuesLhs == valuesRhs && rowSkipLhs == rowSkipRhs) return true; - if (rowSkipLhs == numCols && rowSkipRhs == numCols) - return !memcmp(valuesLhs, valuesRhs, numRows * numCols * sizeof(ValueType)); - else { - for (size_t r = 0; r < numRows; r++) { - if (memcmp(valuesLhs, valuesRhs, numCols * sizeof(ValueType))) + for (size_t r = 0; r < numRows; ++r) { + for (size_t c = 0; c < numCols; ++c) { + if (*(valuesLhs + c) != *(valuesRhs + c)) return false; - valuesLhs += rowSkipLhs; - valuesRhs += rowSkipRhs; } - return true; + valuesLhs += rowSkipLhs; + valuesRhs += rowSkipRhs; } + return true; } size_t serialize(std::vector &buf) const 
override; diff --git a/src/runtime/local/datastructures/FixedSizeStringValueType.h b/src/runtime/local/datastructures/FixedSizeStringValueType.h new file mode 100644 index 000000000..f39cbda6a --- /dev/null +++ b/src/runtime/local/datastructures/FixedSizeStringValueType.h @@ -0,0 +1,146 @@ +/* + * Copyright 2024 The DAPHNE Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +/** + * @brief A string value type with a maximum length of 15 characters. + * + * Each instance is backed by a 16-character buffer, whereby at least the last character must always be a null + * character. The null-termination is required for some operations to work correctly (e.g., casting to a number). 
+ */ +struct FixedStr16 { + static const std::size_t N = 16; + char buffer[N]; + + // Default constructor + FixedStr16() { std::fill(buffer, buffer + N, '\0'); } + + // Constructor from a C-style string + FixedStr16(const char *str) { + size_t len = std::strlen(str); + if (len >= N) { + throw std::length_error("string exceeds fixed buffer size"); + } + std::copy(str, str + len, buffer); + std::fill(buffer + len, buffer + N, '\0'); + } + + // Copy constructor + FixedStr16(const FixedStr16 &other) { std::copy(other.buffer, other.buffer + N, buffer); } + + // Constructor from a std::string + FixedStr16(const std::string &other) { + size_t len = other.size(); + if (len >= N) { + throw std::length_error("string exceeds fixed buffer size"); + } + std::copy(other.begin(), other.end(), buffer); + std::fill(buffer + len, buffer + N, '\0'); + } + + // Assignment operator + FixedStr16 &operator=(const FixedStr16 &other) { + if (this != &other) { + std::copy(other.buffer, other.buffer + N, buffer); + } + return *this; + } + + // Overriding the equality operator + bool operator==(const FixedStr16 &other) const { return std::equal(buffer, buffer + N, other.buffer); } + + bool operator==(const char *str) const { return std::strncmp(buffer, str, sizeof(buffer)) == 0; } + + // Overriding the inequality operator + bool operator!=(const FixedStr16 &other) const { return !(std::equal(buffer, buffer + N, other.buffer)); } + + bool operator!=(const char *str) const { return !(std::strncmp(buffer, str, sizeof(buffer)) == 0); } + + // Overriding the Less than operator + bool operator<(const FixedStr16 &other) const { return std::strncmp(buffer, other.buffer, N) < 0; } + + // Overriding the Greater than operator + bool operator>(const FixedStr16 &other) const { return std::strncmp(buffer, other.buffer, N) > 0; } + + // Concatenation operator + friend std::string operator+(const FixedStr16 &lhs, const FixedStr16 &rhs) { + std::string result(lhs.buffer); + result.append(rhs.buffer); + return 
result; + } + + // Serialization function + void serialize(std::vector &outBuffer) const { outBuffer.insert(outBuffer.end(), buffer, buffer + N); } + + // Overload the output stream operator + friend std::ostream &operator<<(std::ostream &os, const FixedStr16 &fs) { + os.write(fs.buffer, N); + return os; + } + + // Size method + size_t size() const { return std::strlen(buffer); } + + // Method to set the string + void set(const char *str) { + size_t len = std::strlen(str); + if (len >= N) { + throw std::length_error("string exceeds fixed buffer size"); + } + std::transform(str, str + len, buffer, [](char c) { return c; }); + std::fill(buffer + len, buffer + N, '\0'); + } + + // C-string method for compatibility + std::string to_string() const { return std::string(buffer, size()); } + + // Compare method similar to std::string::compare + int compare(const FixedStr16 &other) const { return std::strncmp(buffer, other.buffer, N); } + + // Convert to lowercase + FixedStr16 lower() const { + FixedStr16 result; + std::transform(buffer, buffer + N, result.buffer, [](unsigned char c) { return std::tolower(c); }); + return result; + } + + // Convert to uppercase + FixedStr16 upper() const { + FixedStr16 result; + std::transform(buffer, buffer + N, result.buffer, [](unsigned char c) { return std::toupper(c); }); + return result; + } +}; + +// Specialize std::hash for FixedStr16 this is nessary to use FixedStr16 as a key in std::unordered_map +namespace std { +template <> struct hash { + std::size_t operator()(const FixedStr16 &key) const { + // Compute the hash of the fixed-size buffer + return std::hash()(std::string(key.buffer, key.N)); + } +}; +} // namespace std diff --git a/src/runtime/local/datastructures/ValueTypeCode.h b/src/runtime/local/datastructures/ValueTypeCode.h index cfcb9eba5..f1f7e5f64 100644 --- a/src/runtime/local/datastructures/ValueTypeCode.h +++ b/src/runtime/local/datastructures/ValueTypeCode.h @@ -34,8 +34,10 @@ enum class ValueTypeCode : uint8_t { 
UI32, UI64, // unsigned integers (uintx_t) F32, - F64, // floating point (float, double) - INVALID, // only for JSON enum conversion + F64, // floating point (float, double) + STR, // std::string + FIXEDSTR16, // fixed-size string (length 16) + INVALID, // only for JSON enum conversion // TODO Support bool as well, but poses some challenges (e.g. sizeof). // UI1 // boolean (bool) }; diff --git a/src/runtime/local/datastructures/ValueTypeUtils.cpp b/src/runtime/local/datastructures/ValueTypeUtils.cpp index ee26023b6..0739b29b2 100644 --- a/src/runtime/local/datastructures/ValueTypeUtils.cpp +++ b/src/runtime/local/datastructures/ValueTypeUtils.cpp @@ -88,6 +88,8 @@ template <> const ValueTypeCode ValueTypeUtils::codeFor = ValueTypeCod template <> const ValueTypeCode ValueTypeUtils::codeFor = ValueTypeCode::UI64; template <> const ValueTypeCode ValueTypeUtils::codeFor = ValueTypeCode::F32; template <> const ValueTypeCode ValueTypeUtils::codeFor = ValueTypeCode::F64; +template <> const ValueTypeCode ValueTypeUtils::codeFor = ValueTypeCode::STR; +template <> const ValueTypeCode ValueTypeUtils::codeFor = ValueTypeCode::FIXEDSTR16; template <> const std::string ValueTypeUtils::cppNameFor = "int8_t"; template <> const std::string ValueTypeUtils::cppNameFor = "int32_t"; @@ -99,6 +101,8 @@ template <> const std::string ValueTypeUtils::cppNameFor = "float"; template <> const std::string ValueTypeUtils::cppNameFor = "double"; template <> const std::string ValueTypeUtils::cppNameFor = "bool"; template <> const std::string ValueTypeUtils::cppNameFor = "const char*"; +template <> const std::string ValueTypeUtils::cppNameFor = "std::string"; +template <> const std::string ValueTypeUtils::cppNameFor = "FixedStr"; template <> const std::string ValueTypeUtils::irNameFor = "si8"; template <> const std::string ValueTypeUtils::irNameFor = "si32"; @@ -109,6 +113,18 @@ template <> const std::string ValueTypeUtils::irNameFor = "ui64"; template <> const std::string 
ValueTypeUtils::irNameFor = "f32"; template <> const std::string ValueTypeUtils::irNameFor = "f64"; +template <> const int8_t ValueTypeUtils::defaultValue = 0; +template <> const int32_t ValueTypeUtils::defaultValue = 0; +template <> const int64_t ValueTypeUtils::defaultValue = 0; +template <> const uint8_t ValueTypeUtils::defaultValue = 0; +template <> const uint32_t ValueTypeUtils::defaultValue = 0; +template <> const uint64_t ValueTypeUtils::defaultValue = 0; +template <> const float ValueTypeUtils::defaultValue = 0; +template <> const double ValueTypeUtils::defaultValue = 0; +template <> const bool ValueTypeUtils::defaultValue = false; +template <> const std::string ValueTypeUtils::defaultValue = std::string(""); +template <> const FixedStr16 ValueTypeUtils::defaultValue = FixedStr16(); + const std::string ValueTypeUtils::cppNameForCode(ValueTypeCode type) { switch (type) { case ValueTypeCode::SI8: diff --git a/src/runtime/local/datastructures/ValueTypeUtils.h b/src/runtime/local/datastructures/ValueTypeUtils.h index a043cc4b0..d8297cc2e 100644 --- a/src/runtime/local/datastructures/ValueTypeUtils.h +++ b/src/runtime/local/datastructures/ValueTypeUtils.h @@ -16,6 +16,7 @@ #pragma once +#include #include #include @@ -29,6 +30,8 @@ // changes to the list of supported data types local. 
#define ALL_VALUE_TYPES int8_t, int32_t, int64_t, uint8_t, uint32_t, uint64_t, float, double +#define ALL_STRING_VALUE_TYPES std::string, FixedStr16 + struct ValueTypeUtils { static size_t sizeOf(ValueTypeCode type); @@ -37,6 +40,8 @@ struct ValueTypeUtils { template static const ValueTypeCode codeFor; + template static const ValueType defaultValue; + template static const std::string cppNameFor; template static const std::string irNameFor; @@ -54,6 +59,8 @@ template <> const ValueTypeCode ValueTypeUtils::codeFor; template <> const ValueTypeCode ValueTypeUtils::codeFor; template <> const ValueTypeCode ValueTypeUtils::codeFor; template <> const ValueTypeCode ValueTypeUtils::codeFor; +template <> const ValueTypeCode ValueTypeUtils::codeFor; +template <> const ValueTypeCode ValueTypeUtils::codeFor; template <> const std::string ValueTypeUtils::cppNameFor; template <> const std::string ValueTypeUtils::cppNameFor; @@ -74,3 +81,15 @@ template <> const std::string ValueTypeUtils::irNameFor; template <> const std::string ValueTypeUtils::irNameFor; template <> const std::string ValueTypeUtils::irNameFor; template <> const std::string ValueTypeUtils::irNameFor; + +template <> const int8_t ValueTypeUtils::defaultValue; +template <> const int32_t ValueTypeUtils::defaultValue; +template <> const int64_t ValueTypeUtils::defaultValue; +template <> const uint8_t ValueTypeUtils::defaultValue; +template <> const uint32_t ValueTypeUtils::defaultValue; +template <> const uint64_t ValueTypeUtils::defaultValue; +template <> const float ValueTypeUtils::defaultValue; +template <> const double ValueTypeUtils::defaultValue; +template <> const std::string ValueTypeUtils::defaultValue; +template <> const FixedStr16 ValueTypeUtils::defaultValue; +template <> const char *ValueTypeUtils::defaultValue; diff --git a/src/runtime/local/io/ReadCsvFile.h b/src/runtime/local/io/ReadCsvFile.h index 66ccdfe30..db2fd0f92 100644 --- a/src/runtime/local/io/ReadCsvFile.h +++ 
b/src/runtime/local/io/ReadCsvFile.h @@ -21,7 +21,6 @@ #include #include -#include #include #include @@ -124,6 +123,70 @@ template struct ReadCsvFile> { } }; +template <> struct ReadCsvFile> { + static void apply(DenseMatrix *&res, struct File *file, size_t numRows, size_t numCols, char delim) { + if (file == nullptr) + throw std::runtime_error("ReadCsvFile: requires a file to be specified (must not be nullptr)"); + if (numRows <= 0) + throw std::runtime_error("ReadCsvFile: numRows must be > 0"); + if (numCols <= 0) + throw std::runtime_error("ReadCsvFile: numCols must be > 0"); + + if (res == nullptr) { + res = DataObjectFactory::create>(numRows, numCols, false); + } + + size_t cell = 0; + std::string *valuesRes = res->getValues(); + + for (size_t r = 0; r < numRows; r++) { + if (getFileLine(file) == -1) + throw std::runtime_error("ReadCsvFile::apply: getFileLine failed"); + + size_t pos = 0; + for (size_t c = 0; c < numCols; c++) { + std::string val(""); + int next_column_pos = setCString(file, pos, &val, delim); + // TODO This assumes that rowSkip == numCols. 
+ valuesRes[cell++] = val; + pos += next_column_pos + 1; + } + } + } +}; + +template <> struct ReadCsvFile> { + static void apply(DenseMatrix *&res, struct File *file, size_t numRows, size_t numCols, char delim) { + if (file == nullptr) + throw std::runtime_error("ReadCsvFile: requires a file to be specified (must not be nullptr)"); + if (numRows <= 0) + throw std::runtime_error("ReadCsvFile: numRows must be > 0"); + if (numCols <= 0) + throw std::runtime_error("ReadCsvFile: numCols must be > 0"); + + if (res == nullptr) { + res = DataObjectFactory::create>(numRows, numCols, false); + } + + size_t cell = 0; + FixedStr16 *valuesRes = res->getValues(); + + for (size_t r = 0; r < numRows; r++) { + if (getFileLine(file) == -1) + throw std::runtime_error("ReadCsvFile::apply: getFileLine failed"); + + size_t pos = 0; + for (size_t c = 0; c < numCols; c++) { + std::string val(""); + int next_column_pos = setCString(file, pos, &val, delim); + // TODO This assumes that rowSkip == numCols. + valuesRes[cell++].set(val.c_str()); + pos += next_column_pos + 1; + } + } + } +}; + // ---------------------------------------------------------------------------- // CSRMatrix // ---------------------------------------------------------------------------- diff --git a/src/runtime/local/io/utils.h b/src/runtime/local/io/utils.h index 8b8d2d5a2..eb8224405 100644 --- a/src/runtime/local/io/utils.h +++ b/src/runtime/local/io/utils.h @@ -20,6 +20,7 @@ #include #include +#include #include // Conversion of std::string. @@ -71,3 +72,59 @@ inline void convertCstr(const char *x, int64_t *v) { *v = atoi(x); } inline void convertCstr(const char *x, uint8_t *v) { *v = atoi(x); } inline void convertCstr(const char *x, uint32_t *v) { *v = atoi(x); } inline void convertCstr(const char *x, uint64_t *v) { *v = atoi(x); } + +/** + * @brief This function reads a CSV column that contains strings. + * + * This function processes a column from a CSV file starting at the given position in the current line. 
+ * It reads and appends characters to the result string (`res`) until it encounters the column delimiter + * or the end of the line. If the column contains multiline strings (enclosed in double quotes), it + * continues reading until the closing quote is found, handling embedded quotes and newline characters + * as necessary. + * + * @param file Pointer to the file object from which the CSV data is being read. The file's `line` + * attribute is expected to contain the current line being processed. + * @param start_pos The starting position within the current line to begin reading the column. This + * function may move beyond the current line if the field contains a multiline string. + * @param res A pointer to the result string that will store the contents of the current column. + * @param delim The delimiter character separating columns (e.g., a comma `,`). + * @return The position pointing to the character immediately before the next column in the line. + */ +inline size_t setCString(struct File *file, size_t start_pos, std::string *res, const char delim) { + size_t pos = 0; + const char *str = file->line + start_pos; + bool is_multiLine = (str[0] == '"'); + if (is_multiLine) + pos++; + + int is_not_end = 1; + while (is_not_end && str[pos]) { + is_not_end -= (!is_multiLine && str[pos] == delim); + is_not_end -= (!is_multiLine && str[pos] == '\n'); + is_not_end -= (!is_multiLine && str[pos] == '\r'); + + is_not_end -= (is_multiLine && str[pos] == '"' && str[pos + 1] != '"'); + if (!is_not_end) + break; + if (is_multiLine && str[pos] == '"' && str[pos + 1] == '"') { + res->append("\"\""); + pos += 2; + } else if (is_multiLine && str[pos] == '\\' && str[pos + 1] == '"') { + res->append("\\\""); + pos += 2; + } else if (is_multiLine && (str[pos] == '\n' || str[pos] == '\r')) { + res->push_back('\n'); + getFileLine(file); + str = file->line; + pos = 0; + } else { + res->push_back(str[pos]); + pos++; + } + } + + if (is_multiLine) + pos++; + + return pos; +} diff --git 
a/src/runtime/local/kernels/BinaryOpCode.h b/src/runtime/local/kernels/BinaryOpCode.h index ad2db6642..626e08cbb 100644 --- a/src/runtime/local/kernels/BinaryOpCode.h +++ b/src/runtime/local/kernels/BinaryOpCode.h @@ -93,6 +93,11 @@ static constexpr bool supportsBinaryOp = false; // simplicity). #define SUPPORT(Op, VT) template <> constexpr bool supportsBinaryOp = true; +// Generates code specifying that the binary operation `Op` should be supported on +// the value types `VTLhs` and `VTRhs` with result `VTRes`. +#define SUPPORT_RLR(Op, VTRes, VTLhs, VTRhs) \ + template <> constexpr bool supportsBinaryOp = true; + // Generates code specifying that all binary operations of a certain category // should be supported on the given value type `VT` (for the result and the two // arguments, for simplicity). @@ -126,6 +131,23 @@ static constexpr bool supportsBinaryOp = false; /* Bitwise. */ \ SUPPORT(BITWISE_AND, VT) +// Generates code specifying that all binary operations of a certain category should be +// supported on the given argument value type `VTArg` (for the left and right-hand-side +// arguments, for simplicity) and the given result value type `VTRes`. +#define SUPPORT_COMPARISONS_RA(VTRes, VTArg) \ + /* string Comparisons operations. */ \ + SUPPORT_RLR(LT, VTRes, VTArg, VTArg) \ + SUPPORT_RLR(GT, VTRes, VTArg, VTArg) +#define SUPPORT_EQUALITY_RA(VTRes, VTArg) \ + /* string Comparisons operations. */ \ + SUPPORT_RLR(EQ, VTRes, VTArg, VTArg) \ + SUPPORT_RLR(NEQ, VTRes, VTArg, VTArg) +#define SUPPORT_STRING_RA(VTRes, VTArg) \ + /* string concatenation operations. */ \ + /* Since the result may not fit in FixedStr16,*/ \ + /* it always return std::string*/ \ + SUPPORT_RLR(CONCAT, VTRes, VTArg, VTArg) + // Generates code specifying that all binary operations typically supported on a // certain category of value types should be supported on the given value type // `VT` (for the result and the two arguments, for simplicity). 
@@ -151,11 +173,19 @@ SUPPORT_NUMERIC_INT(int8_t) SUPPORT_NUMERIC_INT(uint64_t) SUPPORT_NUMERIC_INT(uint32_t) SUPPORT_NUMERIC_INT(uint8_t) -template <> constexpr bool supportsBinaryOp = true; -template <> constexpr bool supportsBinaryOp = true; +// Strings binary operations. +SUPPORT_EQUALITY_RA(int64_t, std::string) +SUPPORT_EQUALITY_RA(int64_t, FixedStr16) +SUPPORT_EQUALITY_RA(int64_t, const char *) +SUPPORT_COMPARISONS_RA(int64_t, std::string) +SUPPORT_COMPARISONS_RA(int64_t, FixedStr16) +SUPPORT_STRING_RA(std::string, std::string) +SUPPORT_STRING_RA(std::string, FixedStr16) +SUPPORT_STRING_RA(const char *, const char *) // Undefine helper macros. #undef SUPPORT +#undef SUPPORT_RLR #undef SUPPORT_ARITHMETIC #undef SUPPORT_EQUALITY #undef SUPPORT_COMPARISONS @@ -163,3 +193,6 @@ template <> constexpr bool supportsBinaryOp #include #include +#include // **************************************************************************** // Struct for partial template specialization @@ -187,12 +188,12 @@ template class CastObj, Dens // a single dense array of values, we can simply // perform cast in one loop over that array. for (size_t idx = 0; idx < numCols * numRows; idx++) - resVals[idx] = static_cast(argVals[idx]); + resVals[idx] = castSca(argVals[idx], ctx); else // res and arg might be views into a larger DenseMatrix. 
for (size_t r = 0; r < numRows; r++) { for (size_t c = 0; c < numCols; c++) - resVals[c] = static_cast(argVals[c]); + resVals[c] = castSca(argVals[c], ctx); resVals += res->getRowSkip(); argVals += arg->getRowSkip(); } diff --git a/src/runtime/local/kernels/CastSca.h b/src/runtime/local/kernels/CastSca.h index 60c4dc438..c443e2b85 100644 --- a/src/runtime/local/kernels/CastSca.h +++ b/src/runtime/local/kernels/CastSca.h @@ -18,6 +18,7 @@ #define SRC_RUNTIME_LOCAL_KERNELS_CASTSCA_H #include +#include #include @@ -67,4 +68,55 @@ template struct CastSca { } }; -#endif // SRC_RUNTIME_LOCAL_KERNELS_CASTSCA_H \ No newline at end of file +// ---------------------------------------------------------------------------- +// any type <- string +// ---------------------------------------------------------------------------- + +template struct CastSca { + static VTRes apply(std::string arg, DCTX(ctx)) { + if constexpr (std::is_integral::value) { + if constexpr (std::is_unsigned::value) + return static_cast(std::stoull(arg)); + else + return static_cast(std::stoll(arg)); + } else if constexpr (std::is_same::value) + return static_cast(std::stold(arg)); + + else if constexpr (std::is_same::value) + return static_cast(std::stof(arg)); + else { + // Trigger a compiler warning using deprecated attribute. + return throwUnsupportedType(arg); + } + } + + [[deprecated("CastSca: Warning! 
Unsupported result type in casting string values.")]] + static VTRes throwUnsupportedType(std::string arg) { + throw std::runtime_error("CastSca: Unsupported result type in casting string values"); + } +}; + +template struct CastSca { + static VTRes apply(FixedStr16 arg, DCTX(ctx)) { + if constexpr (std::is_integral::value) { + if constexpr (std::is_unsigned::value) + return static_cast(std::stoull(arg.buffer)); + else + return static_cast(std::stoll(arg.buffer)); + } else if constexpr (std::is_same::value) + return static_cast(std::stold(arg.buffer)); + else if constexpr (std::is_same::value) + return static_cast(std::stof(arg.buffer)); + else { + // Trigger a compiler warning using deprecated attribute. + return throwUnsupportedType(arg); + } + } + + [[deprecated("CastSca: Warning! Unsupported result type in casting string values.")]] + static VTRes throwUnsupportedType(std::string arg) { + throw std::runtime_error("CastSca: Unsupported result type in casting string values"); + } +}; + +#endif // SRC_RUNTIME_LOCAL_KERNELS_CASTSCA_H diff --git a/src/runtime/local/kernels/EwBinarySca.h b/src/runtime/local/kernels/EwBinarySca.h index 0fa1768db..2864fa446 100644 --- a/src/runtime/local/kernels/EwBinarySca.h +++ b/src/runtime/local/kernels/EwBinarySca.h @@ -166,6 +166,7 @@ MAKE_EW_BINARY_SCA(BinaryOpCode::MAX, std::max(lhs, rhs)) MAKE_EW_BINARY_SCA(BinaryOpCode::AND, lhs &&rhs) MAKE_EW_BINARY_SCA(BinaryOpCode::OR, lhs || rhs) // Strings. 
+MAKE_EW_BINARY_SCA(BinaryOpCode::CONCAT, lhs + rhs) template <> struct EwBinarySca { inline static const char *apply(const char *lhs, const char *rhs, DCTX(ctx)) { const auto lenLhs = std::string_view(lhs).size(); diff --git a/src/runtime/local/kernels/EwUnarySca.h b/src/runtime/local/kernels/EwUnarySca.h index b185bd00a..b16639f40 100644 --- a/src/runtime/local/kernels/EwUnarySca.h +++ b/src/runtime/local/kernels/EwUnarySca.h @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -92,6 +93,9 @@ template EwUnaryScaFuncPtr getEwU MAKE_CASE(UnaryOpCode::ROUND) // Comparison. MAKE_CASE(UnaryOpCode::ISNAN) + // String. + MAKE_CASE(UnaryOpCode::LOWER) + MAKE_CASE(UnaryOpCode::UPPER) #undef MAKE_CASE default: throw std::runtime_error("unknown UnaryOpCode: " + std::to_string(static_cast(opCode))); @@ -151,6 +155,15 @@ template TRes ewUnarySca(UnaryOpCode opCode, TArg } \ }; +#define MAKE_EW_UNARY_STRING_TRANSFORM(opCode, expr) \ + template <> struct EwUnarySca { \ + inline static std::string apply(std::string arg, DCTX(ctx)) { \ + std::string new_string = arg; \ + std::transform(new_string.begin(), new_string.end(), new_string.begin(), static_cast(expr)); \ + return new_string; \ + } \ + }; + // One such line for each unary function to support. // Arithmetic/general math. MAKE_EW_UNARY_SCA(UnaryOpCode::MINUS, -arg); @@ -178,9 +191,15 @@ MAKE_EW_UNARY_SCA(UnaryOpCode::CEIL, std::ceil(arg)); MAKE_EW_UNARY_SCA(UnaryOpCode::ROUND, round(arg)); // Comparison. MAKE_EW_UNARY_SCA(UnaryOpCode::ISNAN, std::isnan(arg)); +// String. 
+MAKE_EW_UNARY_SCA(UnaryOpCode::LOWER, arg.lower()) +MAKE_EW_UNARY_SCA(UnaryOpCode::UPPER, arg.upper()) +MAKE_EW_UNARY_STRING_TRANSFORM(UnaryOpCode::LOWER, std::tolower) +MAKE_EW_UNARY_STRING_TRANSFORM(UnaryOpCode::UPPER, std::toupper) #undef MAKE_EW_UNARY_SCA_CLOSED_DOMAIN_ERROR #undef MAKE_EW_UNARY_SCA_OPEN_DOMAIN_ERROR #undef MAKE_EW_UNARY_SCA +#undef MAKE_EW_UNARY_STRING_TRANSFORM #endif // SRC_RUNTIME_LOCAL_KERNELS_EWUNARYSCA_H diff --git a/src/runtime/local/kernels/Fill.h b/src/runtime/local/kernels/Fill.h index 2e98546a5..920a39826 100644 --- a/src/runtime/local/kernels/Fill.h +++ b/src/runtime/local/kernels/Fill.h @@ -20,6 +20,7 @@ #include #include #include +#include // **************************************************************************** // Struct for partial template specialization @@ -47,11 +48,10 @@ template void fill(DTRes *&res, VTArg arg, size_t template struct Fill, VT> { static void apply(DenseMatrix *&res, VT arg, size_t numRows, size_t numCols, DCTX(ctx)) { - if (res == nullptr) - res = DataObjectFactory::create>(numRows, numCols, arg == 0); + res = DataObjectFactory::create>(numRows, numCols, arg == ValueTypeUtils::defaultValue); - if (arg != 0) { + if (arg != ValueTypeUtils::defaultValue) { VT *valuesRes = res->getValues(); for (auto i = 0ul; i < res->getNumItems(); ++i) valuesRes[i] = arg; diff --git a/src/runtime/local/kernels/OneHot.h b/src/runtime/local/kernels/OneHot.h index 865ae2f7e..1ebe430bb 100644 --- a/src/runtime/local/kernels/OneHot.h +++ b/src/runtime/local/kernels/OneHot.h @@ -21,8 +21,11 @@ #include #include #include +#include #include +#include +#include #include #include @@ -120,6 +123,53 @@ template struct OneHot, DenseMatrix> { } }; +// ---------------------------------------------------------------------------- +// DenseMatrix <- DenseMatrix +// ---------------------------------------------------------------------------- + +template +void oneHotString(DenseMatrix *&res, const DenseMatrix *arg, const DenseMatrix 
*info, + DCTX(ctx)) { + const size_t numRows = arg->getNumRows(); + const size_t numColsArg = arg->getNumCols(); + const size_t rowSkipArg = arg->getRowSkip(); + auto recode_result = DataObjectFactory::create>(numRows, numColsArg, false); + + VTRes *valuesRes = recode_result->getValues(); + const VTArg *valuesArg = arg->getValues(); + + // Recode arg with string elements to a dense matrix based on indices without ordering + for (size_t cArg = 0; cArg < numColsArg; cArg++) { + std::unordered_map firstIndexMap; + for (size_t r = 0; r < numRows; r++) { + size_t value_index = (rowSkipArg * r) + cArg; + if (firstIndexMap.find(valuesArg[value_index]) == firstIndexMap.end()) { + firstIndexMap[valuesArg[value_index]] = firstIndexMap.size(); + } + valuesRes[value_index] = VTRes(firstIndexMap[valuesArg[value_index]]); + } + } + + // call oneHot with recoded matrix as arg + oneHot(res, recode_result, info, ctx); + + DataObjectFactory::destroy(recode_result); +} + +template struct OneHot, DenseMatrix> { + static void apply(DenseMatrix *&res, const DenseMatrix *arg, const DenseMatrix *info, + DCTX(ctx)) { + oneHotString(res, arg, info, ctx); + } +}; + +template struct OneHot, DenseMatrix> { + static void apply(DenseMatrix *&res, const DenseMatrix *arg, const DenseMatrix *info, + DCTX(ctx)) { + oneHotString(res, arg, info, ctx); + } +}; + // ---------------------------------------------------------------------------- // Matrix <- Matrix // ---------------------------------------------------------------------------- diff --git a/src/runtime/local/kernels/UnaryOpCode.h b/src/runtime/local/kernels/UnaryOpCode.h index 9a327b94b..3bc79bb92 100644 --- a/src/runtime/local/kernels/UnaryOpCode.h +++ b/src/runtime/local/kernels/UnaryOpCode.h @@ -19,6 +19,8 @@ #pragma once +#include + // **************************************************************************** // Enum for unary op codes and their names // **************************************************************************** @@ 
-46,7 +48,10 @@ enum class UnaryOpCode { CEIL, ROUND, // Comparison. - ISNAN + ISNAN, + // String. + UPPER, + LOWER }; /** @@ -64,7 +69,9 @@ static std::string_view unary_op_codes[] = { // Rounding. "FLOOR", "CEIL", "ROUND", // Comparison. - "ISNAN"}; + "ISNAN", + // String. + "UPPER", "LOWER"}; // **************************************************************************** // Specification which unary ops should be supported on which value types @@ -116,6 +123,10 @@ template static constexpr bool /* Comparison */ \ SUPPORT(ISNAN, VT) +#define SUPPORT_STRING(VT) \ + /* String */ \ + SUPPORT(UPPER, VT) \ + SUPPORT(LOWER, VT) // Concise specification of which unary operations should be supported on // which value types. SUPPORT_NUMERIC(double) @@ -126,9 +137,13 @@ SUPPORT_NUMERIC(int8_t) SUPPORT_NUMERIC(uint64_t) SUPPORT_NUMERIC(uint32_t) SUPPORT_NUMERIC(uint8_t) +// String operations +SUPPORT_STRING(std::string) +SUPPORT_STRING(FixedStr16) // Undefine helper macros. #undef SUPPORT #undef SUPPORT_NUMERIC +#undef SUPPORT_STRING #endif // SRC_RUNTIME_LOCAL_KERNELS_UNARYOPCODE_H diff --git a/test/runtime/local/datastructures/DenseMatrixTest.cpp b/test/runtime/local/datastructures/DenseMatrixTest.cpp index 6f6df1094..8314439f4 100644 --- a/test/runtime/local/datastructures/DenseMatrixTest.cpp +++ b/test/runtime/local/datastructures/DenseMatrixTest.cpp @@ -191,4 +191,124 @@ TEST_CASE("DenseMatrix sub-matrix works properly", TAG_DATASTRUCTURES) { DataObjectFactory::destroy(mSub); DataObjectFactory::destroy(mOrig); } +} + +TEMPLATE_TEST_CASE("DenseMatrix with string value type", TAG_DATASTRUCTURES, ALL_STRING_VALUE_TYPES) { + using ValueType = TestType; + + using expectedStrings = const std::vector; + + // We do not use operator == to compare to a matrix created by genGivenVals() + // here, since this would rely on the functionality we want to test. 
+ auto compareMatToArr = [](const DenseMatrix *mat, const expectedStrings &exp) { + for (size_t r = 0; r < mat->getNumRows(); r++) + for (size_t c = 0; c < mat->getNumCols(); c++) + if ((mat->get(r, c) != exp[r * mat->getNumCols() + c])) + return false; + return true; + }; + + SECTION("Create") { + const size_t numRows = 10000; + const size_t numCols = 2000; + + DenseMatrix *m = DataObjectFactory::create>(numRows, numCols, false); + + ValueType *values = m->getValues(); + const size_t numCells = numRows * numCols; + + for (size_t i = 0; i < numCells; i++) + values[i] = ValueType(); + + DataObjectFactory::destroy(m); + } + + SECTION("Append") { + const size_t numRows = 3; + const size_t numCols = 4; + + DenseMatrix *m = DataObjectFactory::create>(numRows, numCols, false); + expectedStrings exp = {ValueType("0"), ValueType(""), ValueType(""), ValueType("3"), + ValueType("10"), ValueType(""), ValueType(""), ValueType("13"), + ValueType("20"), ValueType(""), ValueType(""), ValueType("23")}; + m->prepareAppend(); + + for (size_t r = 0; r < numRows; r++) + for (size_t c = 0; c < numCols; c++) + if (c % 3 == 0) + m->append(r, c, ValueType(std::to_string(r * 10 + c).c_str())); + + m->finishAppend(); + + CHECK(compareMatToArr(m, exp)); + + DataObjectFactory::destroy>(m); + } + + SECTION("Set") { + const size_t numRows = 3; + const size_t numCols = 4; + expectedStrings exp1 = {ValueType(""), ValueType("1"), ValueType(""), ValueType("3"), + ValueType(""), ValueType("11"), ValueType(""), ValueType("13"), + ValueType(""), ValueType("21"), ValueType(""), ValueType("23")}; + + expectedStrings exp2 = {ValueType("0"), ValueType("1"), ValueType("2"), ValueType("3"), + ValueType("10"), ValueType("11"), ValueType("12"), ValueType("13"), + ValueType("20"), ValueType("21"), ValueType("22"), ValueType("23")}; + DenseMatrix *m = DataObjectFactory::create>(numRows, numCols, false); + + for (size_t r = 0; r < numRows; r++) + for (size_t c = 0; c < numCols; c++) { + size_t num = r * 10 + c; 
+ if (num % 2) { + m->set(r, c, ValueType(std::to_string(num).c_str())); + } + } + CHECK(compareMatToArr(m, exp1)); + + for (size_t r = 0; r < numRows; r++) + for (size_t c = 0; c < numCols; c++) { + size_t num = r * 10 + c; + if (!(num % 2)) { + m->set(r, c, ValueType(std::to_string(num).c_str())); + } + } + CHECK(compareMatToArr(m, exp2)); + + DataObjectFactory::destroy(m); + } + + SECTION("View") { + const size_t numRows = 3; + const size_t numCols = 4; + expectedStrings exp1 = {ValueType("1"), ValueType("2"), ValueType("11"), "12"}; + expectedStrings exp2 = {ValueType("1"), ValueType("2"), ValueType("11"), + ValueType(std::string(5, 'X').c_str())}; + expectedStrings exp3 = {ValueType("0"), + ValueType("1"), + ValueType("2"), + ValueType("3"), + ValueType("10"), + ValueType("11"), + ValueType(std::string(5, 'X').c_str()), + ValueType("13"), + ValueType("20"), + ValueType("21"), + ValueType("22"), + ValueType("23")}; + DenseMatrix *m = DataObjectFactory::create>(numRows, numCols, false); + for (size_t r = 0; r < numRows; r++) + for (size_t c = 0; c < numCols; c++) + m->set(r, c, ValueType(std::to_string(r * 10 + c).c_str())); + auto mView = DataObjectFactory::create>(m, 0, 2, 1, 3); + CHECK(compareMatToArr(mView, exp1)); + + mView->set(1, 1, ValueType(std::string(5, 'X').c_str())); + CHECK(compareMatToArr(mView, exp2)); + + CHECK(compareMatToArr(m, exp3)); + + DataObjectFactory::destroy(m); + DataObjectFactory::destroy(mView); + } } \ No newline at end of file diff --git a/test/runtime/local/io/ReadCsvStr.csv b/test/runtime/local/io/ReadCsvStr.csv new file mode 100644 index 000000000..50eea33eb --- /dev/null +++ b/test/runtime/local/io/ReadCsvStr.csv @@ -0,0 +1,10 @@ +"apple, orange",35,Fruit Basket +"dog, cat",30,Pets +table,27,Furniture Set +"""",22,Unknown Item +"abc""def",33,"No Category\"" +"red, blue\n",50, +"\n\"abc""def\"",28,"Mixed string" +"line1 +line2",27,"with newline" +"\"red, \"\"",41,"" diff --git a/test/runtime/local/io/ReadCsvStr.csv.meta 
b/test/runtime/local/io/ReadCsvStr.csv.meta new file mode 100644 index 000000000..0fe44b980 --- /dev/null +++ b/test/runtime/local/io/ReadCsvStr.csv.meta @@ -0,0 +1,5 @@ +{ + "numRows": 9, + "numCols": 3, + "valueType": "str" +} \ No newline at end of file diff --git a/test/runtime/local/io/ReadCsvTest.cpp b/test/runtime/local/io/ReadCsvTest.cpp index 6af843135..32134d81d 100644 --- a/test/runtime/local/io/ReadCsvTest.cpp +++ b/test/runtime/local/io/ReadCsvTest.cpp @@ -264,3 +264,51 @@ TEST_CASE("ReadCsv, varying columns", TAG_IO) { DataObjectFactory::destroy(m); } + +TEMPLATE_PRODUCT_TEST_CASE("ReadCsv", TAG_IO, (DenseMatrix), (ALL_STRING_VALUE_TYPES)) { + using DT = TestType; + DT *m = nullptr; + + size_t numRows = 9; + size_t numCols = 3; + + char filename[] = "./test/runtime/local/io/ReadCsvStr.csv"; + char delim = ','; + + readCsv(m, filename, numRows, numCols, delim); + + REQUIRE(m->getNumRows() == numRows); + REQUIRE(m->getNumCols() == numCols); + + CHECK(m->get(0, 0) == "apple, orange"); + CHECK(m->get(1, 0) == "dog, cat"); + CHECK(m->get(2, 0) == "table"); + CHECK(m->get(3, 0) == "\"\""); + CHECK(m->get(4, 0) == "abc\"\"def"); + CHECK(m->get(5, 0) == "red, blue\\n"); + CHECK(m->get(6, 0) == "\\n\\\"abc\"\"def\\\""); + CHECK(m->get(7, 0) == "line1\nline2"); + CHECK(m->get(8, 0) == "\\\"red, \\\"\\\""); + + CHECK(m->get(0, 1) == "35"); + CHECK(m->get(1, 1) == "30"); + CHECK(m->get(2, 1) == "27"); + CHECK(m->get(3, 1) == "22"); + CHECK(m->get(4, 1) == "33"); + CHECK(m->get(5, 1) == "50"); + CHECK(m->get(6, 1) == "28"); + CHECK(m->get(7, 1) == "27"); + CHECK(m->get(8, 1) == "41"); + + CHECK(m->get(0, 2) == "Fruit Basket"); + CHECK(m->get(1, 2) == "Pets"); + CHECK(m->get(2, 2) == "Furniture Set"); + CHECK(m->get(3, 2) == "Unknown Item"); + CHECK(m->get(4, 2) == "No Category\\\""); + CHECK(m->get(5, 2) == ""); + CHECK(m->get(6, 2) == "Mixed string"); + CHECK(m->get(7, 2) == "with newline"); + CHECK(m->get(8, 2) == ""); + + DataObjectFactory::destroy(m); +} diff 
--git a/test/runtime/local/kernels/CastObjTest.cpp b/test/runtime/local/kernels/CastObjTest.cpp index 3916e7eb2..d66d748e0 100644 --- a/test/runtime/local/kernels/CastObjTest.cpp +++ b/test/runtime/local/kernels/CastObjTest.cpp @@ -212,6 +212,82 @@ TEMPLATE_PRODUCT_TEST_CASE("castObj, matrix to matrix, multi-column", TAG_KERNEL DataObjectFactory::destroy(check123); } +TEMPLATE_PRODUCT_TEST_CASE("castObj, DenseMatrix to DenseMatrix, multi-column", TAG_KERNELS, + (DenseMatrix), (double, int64_t, uint32_t)) { + using DTRes = TestType; + using VTRes = typename DTRes::VT; + + const size_t numRows = 2; + + auto arg_stdSTR = + genGivenVals>(numRows, {std::string("3.1"), std::string("1.1"), std::string("4.1"), + std::string("1.1"), std::string("5.1"), std::string("9.1")}); + DTRes *res_stdSTR = nullptr; + + auto arg_FixedStr16 = + genGivenVals>(numRows, {FixedStr16("3.1"), FixedStr16("1.1"), FixedStr16("4.1"), + FixedStr16("1.1"), FixedStr16("5.1"), FixedStr16("9.1")}); + DTRes *res_FixedStr16 = nullptr; + + auto check = genGivenVals>( + numRows, {VTRes(3.1), VTRes(1.1), VTRes(4.1), VTRes(1.1), VTRes(5.1), VTRes(9.1)}); + + castObj, DenseMatrix>(res_stdSTR, arg_stdSTR, nullptr); + CHECK(*res_stdSTR == *check); + + castObj, DenseMatrix>(res_FixedStr16, arg_FixedStr16, nullptr); + CHECK(*res_FixedStr16 == *check); + + DataObjectFactory::destroy(check); + DataObjectFactory::destroy(res_stdSTR, res_FixedStr16); + DataObjectFactory::destroy(arg_stdSTR, arg_FixedStr16); +} + +TEMPLATE_PRODUCT_TEST_CASE("castObj, DenseMatrix to DenseMatrix, int64_t", TAG_KERNELS, (DenseMatrix), + (int64_t)) { + using DTRes = TestType; + using VTRes = typename DTRes::VT; + + const size_t numRows = 2; + + SECTION("std::string") { + auto arg_string = genGivenVals>( + numRows, {std::string("9223372036854775807"), std::string("9223372036854775806"), + std::string("9223372036854775805"), std::string("9223372036854775804"), + std::string("9223372036854775803"), std::string("9223372036854775802")}); + 
DTRes *res_string = nullptr; + auto check_string = + genGivenVals>(numRows, {9223372036854775807, 9223372036854775806, 9223372036854775805, + 9223372036854775804, 9223372036854775803, 9223372036854775802}); + + castObj, DenseMatrix>(res_string, arg_string, nullptr); + + CHECK(*res_string == *check_string); + + DataObjectFactory::destroy(check_string); + DataObjectFactory::destroy(res_string); + DataObjectFactory::destroy(arg_string); + } + + SECTION("FixedStr16") { + auto arg_FixedStr16 = genGivenVals>( + numRows, {FixedStr16("123456789012345"), FixedStr16("123456789012344"), FixedStr16("123456789012343"), + FixedStr16("123456789012342"), FixedStr16("123456789012341"), FixedStr16("123456789012340")}); + DTRes *res_FixedStr16 = nullptr; + auto check_FixedStr16 = + genGivenVals>(numRows, {123456789012345, 123456789012344, 123456789012343, + 123456789012342, 123456789012341, 123456789012340}); + + castObj, DenseMatrix>(res_FixedStr16, arg_FixedStr16, nullptr); + + CHECK(*res_FixedStr16 == *check_FixedStr16); + + DataObjectFactory::destroy(check_FixedStr16); + DataObjectFactory::destroy(res_FixedStr16); + DataObjectFactory::destroy(arg_FixedStr16); + } +} + TEMPLATE_PRODUCT_TEST_CASE("castObj, matrix to matrix, single dim", TAG_KERNELS, (DenseMatrix), (double, int64_t, uint32_t)) { using DTRes = TestType; diff --git a/test/runtime/local/kernels/CastScaTest.cpp b/test/runtime/local/kernels/CastScaTest.cpp index bdccf6919..d1143db38 100644 --- a/test/runtime/local/kernels/CastScaTest.cpp +++ b/test/runtime/local/kernels/CastScaTest.cpp @@ -20,6 +20,8 @@ #include +#include + #include TEST_CASE("castSca, no-op casts", TAG_KERNELS) { @@ -46,4 +48,29 @@ TEST_CASE("castSca, actual casts", TAG_KERNELS) { CHECK(castSca(123.4, nullptr) == true); CHECK(castSca(-123.4, nullptr) == true); CHECK(castSca(0.0, nullptr) == false); -} \ No newline at end of file +} + +TEST_CASE("castSca, actual casts strings to numbers", TAG_KERNELS) { + + CHECK(castSca("123", nullptr) == 123); + 
CHECK(castSca("-123", nullptr) == -123); + CHECK(castSca("0", nullptr) == 0); + CHECK(castSca("123.4", nullptr) == 123.4); + CHECK(castSca("-123.4", nullptr) == -123.4); + CHECK(castSca("0.0", nullptr) == 0.0); + CHECK(castSca("9223372036854775807", nullptr) == std::numeric_limits::max()); + CHECK(castSca("-9223372036854775808", nullptr) == std::numeric_limits::min()); + CHECK(castSca("18446744073709551615", nullptr) == std::numeric_limits::max()); + CHECK(castSca("0", nullptr) == std::numeric_limits::min()); + + CHECK(castSca("123", nullptr) == 123); + CHECK(castSca("-123", nullptr) == -123); + CHECK(castSca("0", nullptr) == 0); + CHECK(castSca("123.4", nullptr) == 123.4); + CHECK(castSca("-123.4", nullptr) == -123.4); + CHECK(castSca("0.0", nullptr) == 0.0); + CHECK(castSca("123456789012345", nullptr) == 123456789012345ll); + CHECK(castSca("-12345678901234", nullptr) == -12345678901234ll); + CHECK(castSca("123456789012345", nullptr) == 123456789012345ull); + CHECK(castSca("0", nullptr) == std::numeric_limits::min()); +} diff --git a/test/runtime/local/kernels/EwBinaryMatTest.cpp b/test/runtime/local/kernels/EwBinaryMatTest.cpp index 041b0179f..648a1b232 100644 --- a/test/runtime/local/kernels/EwBinaryMatTest.cpp +++ b/test/runtime/local/kernels/EwBinaryMatTest.cpp @@ -35,6 +35,13 @@ // CSRMatrix currently only supports ADD and MUL opCodes #define DATA_TYPES_NO_CSR DenseMatrix, Matrix +template +void checkEwBinaryMat(BinaryOpCode opCode, const DTArg *lhs, const DTArg *rhs, const DTRes *exp) { + DTRes *res = nullptr; + ewBinaryMat(opCode, res, lhs, rhs, nullptr); + CHECK(*res == *exp); +} + template void checkEwBinaryMat(BinaryOpCode opCode, const DT *lhs, const DT *rhs, const DT *exp) { DT *res = nullptr; ewBinaryMat(opCode, res, lhs, rhs, nullptr); @@ -189,6 +196,36 @@ TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("div"), TAG_KERNELS, (DATA_TYPES_NO_CSR), ( // Comparisons // **************************************************************************** 
+TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("eq"), TAG_KERNELS, (DenseMatrix), (ALL_STRING_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + auto m1 = genGivenVals
(2, {VT("1"), VT("2"), VT("abc"), VT("abcd"), VT("ABCD"), VT("34ab")}); + auto m2 = genGivenVals
(2, {VT("1"), VT("0"), VT("3"), VT("abcd"), VT("abcd"), VT("34ab")}); + auto m3 = genGivenVals>(2, {1, 0, 0, 1, 0, 1}); + + SECTION("matrix") { checkEwBinaryMat(BinaryOpCode::EQ, m1, m2, m3); } + + DataObjectFactory::destroy(m1); + DataObjectFactory::destroy(m2); + DataObjectFactory::destroy(m3); +} + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("neq"), TAG_KERNELS, (DenseMatrix), (ALL_STRING_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + auto m1 = genGivenVals
(2, {VT("1"), VT("2"), VT("abc"), VT("abcd"), VT("ABCD"), VT("34ab")}); + auto m2 = genGivenVals
(2, {VT("1"), VT("0"), VT("3"), VT("abcd"), VT("abcd"), VT("34ab")}); + auto m3 = genGivenVals>(2, {0, 1, 1, 0, 1, 0}); + + SECTION("matrix") { checkEwBinaryMat(BinaryOpCode::NEQ, m1, m2, m3); } + + DataObjectFactory::destroy(m1); + DataObjectFactory::destroy(m2); + DataObjectFactory::destroy(m3); +} + TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("eq"), TAG_KERNELS, (DATA_TYPES_NO_CSR), (VALUE_TYPES)) { using DT = TestType; @@ -288,6 +325,23 @@ TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("lt"), TAG_KERNELS, (DATA_TYPES_NO_CSR), (V DataObjectFactory::destroy(m1, m2, m3); } +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("lt"), TAG_KERNELS, (DenseMatrix), (ALL_STRING_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + auto m1 = genGivenVals
( + 3, {VT("1"), VT("2"), VT("1"), VT("abc"), VT("abcd"), VT("abcd"), VT("abcd"), VT("ABC"), VT("35abcd")}); + auto m2 = genGivenVals
( + 3, {VT("1"), VT("0"), VT("3"), VT("abcd"), VT("abce"), VT("abcd"), VT("abc"), VT("abc"), VT("30abcd")}); + auto m3 = genGivenVals>(3, {0, 0, 1, 1, 1, 0, 0, 1, 0}); + + SECTION("matrix") { checkEwBinaryMat(BinaryOpCode::LT, m1, m2, m3); } + + DataObjectFactory::destroy(m1); + DataObjectFactory::destroy(m2); + DataObjectFactory::destroy(m3); +} + TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("le"), TAG_KERNELS, (DATA_TYPES_NO_CSR), (VALUE_TYPES)) { using DT = TestType; @@ -354,6 +408,23 @@ TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("gt"), TAG_KERNELS, (DATA_TYPES_NO_CSR), (V DataObjectFactory::destroy(m1, m2, m3); } +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("gt"), TAG_KERNELS, (DenseMatrix), (ALL_STRING_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + auto m1 = genGivenVals
( + 3, {VT("1"), VT("2"), VT("1"), VT("abc"), VT("abcd"), VT("abcd"), VT("abcd"), VT("ABC"), VT("35abcd")}); + auto m2 = genGivenVals
( + 3, {VT("1"), VT("0"), VT("3"), VT("abcd"), VT("abce"), VT("abcd"), VT("abc"), VT("abc"), VT("30abcd")}); + auto m3 = genGivenVals>(3, {0, 1, 0, 0, 0, 0, 1, 0, 1}); + + SECTION("matrix") { checkEwBinaryMat(BinaryOpCode::GT, m1, m2, m3); } + + DataObjectFactory::destroy(m1); + DataObjectFactory::destroy(m2); + DataObjectFactory::destroy(m3); +} + TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("ge"), TAG_KERNELS, (DATA_TYPES_NO_CSR), (VALUE_TYPES)) { using DT = TestType; @@ -487,6 +558,27 @@ TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("or"), TAG_KERNELS, (DATA_TYPES_NO_CSR), (V DataObjectFactory::destroy(m1, m2, m3); } +// **************************************************************************** +// string. +// **************************************************************************** + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("concat"), TAG_KERNELS, (DenseMatrix), (ALL_STRING_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + using VTr = std::string; + + auto m1 = genGivenVals
(2, {VT("1"), VT("2"), VT(""), VT(""), VT("ab"), VT("abcd")}); + auto m2 = genGivenVals
(2, {VT(""), VT("0"), VT(""), VT("abc"), VT("ce"), VT("abcd")}); + auto m3 = + genGivenVals>(2, {VTr("1"), VTr("20"), VTr(""), VTr("abc"), VTr("abce"), VTr("abcdabcd")}); + + SECTION("matrix") { checkEwBinaryMat(BinaryOpCode::CONCAT, m1, m2, m3); } + + DataObjectFactory::destroy(m1); + DataObjectFactory::destroy(m2); + DataObjectFactory::destroy(m3); +} + // **************************************************************************** // Invalid op-code // **************************************************************************** diff --git a/test/runtime/local/kernels/EwBinaryScaTest.cpp b/test/runtime/local/kernels/EwBinaryScaTest.cpp index 02a0042b3..1504071d8 100644 --- a/test/runtime/local/kernels/EwBinaryScaTest.cpp +++ b/test/runtime/local/kernels/EwBinaryScaTest.cpp @@ -30,6 +30,21 @@ template void checkEwBinarySca(VT lhs, VT rhs CHECK(ewBinarySca(opCode, lhs, rhs, nullptr) == exp); } +template void checkEwBinarySca(std::string lhs, std::string rhs, int64_t exp) { + CHECK(EwBinarySca::apply(lhs, rhs, nullptr) == exp); + CHECK(ewBinarySca(opCode, lhs, rhs, nullptr) == exp); +} + +template void checkEwBinarySca(FixedStr16 lhs, FixedStr16 rhs, int64_t exp) { + CHECK(EwBinarySca::apply(lhs, rhs, nullptr) == exp); + CHECK(ewBinarySca(opCode, lhs, rhs, nullptr) == exp); +} + +template void checkEwBinarySca(VT lhs, VT rhs, std::string exp) { + CHECK(EwBinarySca::apply(lhs, rhs, nullptr) == exp); + CHECK(ewBinarySca(BinaryOpCode::CONCAT, lhs, rhs, nullptr) == exp); +} + // **************************************************************************** // Arithmetic // **************************************************************************** @@ -65,6 +80,19 @@ TEMPLATE_TEST_CASE(TEST_NAME("eq"), TAG_KERNELS, VALUE_TYPES) { checkEwBinarySca(3, 5, 0); } +TEMPLATE_TEST_CASE(TEST_NAME("eq"), TAG_KERNELS, ALL_STRING_VALUE_TYPES) { + using VT = TestType; + checkEwBinarySca(VT("abcd"), VT("abcd"), 1); + checkEwBinarySca(VT("abce"), VT("abcd"), 0); + 
checkEwBinarySca(VT("abcda"), VT("abcd"), 0); + checkEwBinarySca(VT("abc"), VT("abcd"), 0); + checkEwBinarySca(VT("ABCD"), VT("abcd"), 0); + checkEwBinarySca(VT("36abcd"), VT("30abcd"), 0); + checkEwBinarySca(VT("3"), VT("4"), 0); + checkEwBinarySca(VT(""), VT("abc"), 0); + checkEwBinarySca(VT(""), VT(""), 1); +} + TEMPLATE_TEST_CASE(TEST_NAME("neq"), TAG_KERNELS, VALUE_TYPES) { using VT = TestType; checkEwBinarySca(0, 0, 0); @@ -72,6 +100,19 @@ TEMPLATE_TEST_CASE(TEST_NAME("neq"), TAG_KERNELS, VALUE_TYPES) { checkEwBinarySca(3, 5, 1); } +TEMPLATE_TEST_CASE(TEST_NAME("neq"), TAG_KERNELS, ALL_STRING_VALUE_TYPES) { + using VT = TestType; + checkEwBinarySca(VT("abcd"), VT("abcd"), 0); + checkEwBinarySca(VT("abce"), VT("abcd"), 1); + checkEwBinarySca(VT("abcda"), VT("abcd"), 1); + checkEwBinarySca(VT("abc"), VT("abcd"), 1); + checkEwBinarySca(VT("ABCD"), VT("abcd"), 1); + checkEwBinarySca(VT("36abcd"), VT("30abcd"), 1); + checkEwBinarySca(VT("3"), VT("4"), 1); + checkEwBinarySca(VT(""), VT("abc"), 1); + checkEwBinarySca(VT(""), VT(""), 0); +} + TEMPLATE_TEST_CASE(TEST_NAME("lt"), TAG_KERNELS, VALUE_TYPES) { using VT = TestType; checkEwBinarySca(1, 1, 0); @@ -79,6 +120,21 @@ TEMPLATE_TEST_CASE(TEST_NAME("lt"), TAG_KERNELS, VALUE_TYPES) { checkEwBinarySca(4, 2, 0); } +TEMPLATE_TEST_CASE(TEST_NAME("lt"), TAG_KERNELS, ALL_STRING_VALUE_TYPES) { + using VT = TestType; + checkEwBinarySca(VT("abcd"), VT("abcd"), 0); + checkEwBinarySca(VT("abce"), VT("abcd"), 0); + checkEwBinarySca(VT("abcb"), VT("abcd"), 1); + checkEwBinarySca(VT("abcda"), VT("abcd"), 0); + checkEwBinarySca(VT("abc"), VT("abcd"), 1); + checkEwBinarySca(VT("ABCD"), VT("abcd"), 1); + checkEwBinarySca(VT("abcD"), VT("abcd"), 1); + checkEwBinarySca(VT("36abcd"), VT("30abcd"), 0); + checkEwBinarySca(VT("3"), VT("4"), 1); + checkEwBinarySca(VT(""), VT("abc"), 1); + checkEwBinarySca(VT(""), VT(""), 0); +} + TEMPLATE_TEST_CASE(TEST_NAME("le"), TAG_KERNELS, VALUE_TYPES) { using VT = TestType; checkEwBinarySca(1, 1, 1); 
@@ -93,6 +149,21 @@ TEMPLATE_TEST_CASE(TEST_NAME("gt"), TAG_KERNELS, VALUE_TYPES) { checkEwBinarySca(4, 2, 1); } +TEMPLATE_TEST_CASE(TEST_NAME("gt"), TAG_KERNELS, ALL_STRING_VALUE_TYPES) { + using VT = TestType; + checkEwBinarySca(VT("abcd"), VT("abcd"), 0); + checkEwBinarySca(VT("abce"), VT("abcd"), 1); + checkEwBinarySca(VT("abcb"), VT("abcd"), 0); + checkEwBinarySca(VT("abcda"), VT("abcd"), 1); + checkEwBinarySca(VT("abc"), VT("abcd"), 0); + checkEwBinarySca(VT("ABCD"), VT("abcd"), 0); + checkEwBinarySca(VT("abcD"), VT("abcd"), 0); + checkEwBinarySca(VT("36abcd"), VT("30abcd"), 1); + checkEwBinarySca(VT("3"), VT("4"), 0); + checkEwBinarySca(VT(""), VT("abc"), 0); + checkEwBinarySca(VT(""), VT(""), 0); +} + TEMPLATE_TEST_CASE(TEST_NAME("ge"), TAG_KERNELS, VALUE_TYPES) { using VT = TestType; checkEwBinarySca(1, 1, 1); @@ -150,6 +221,19 @@ TEMPLATE_TEST_CASE(TEST_NAME("or"), TAG_KERNELS, VALUE_TYPES) { checkEwBinarySca(-2, -2, 1); } +// **************************************************************************** +// String ops +// **************************************************************************** + +TEMPLATE_TEST_CASE(TEST_NAME("concat"), TAG_KERNELS, ALL_STRING_VALUE_TYPES) { + using VT = TestType; + checkEwBinarySca(VT("abcd"), VT("abcd"), std::string("abcdabcd")); + checkEwBinarySca(VT("ABCD"), VT("abcd"), std::string("ABCDabcd")); + checkEwBinarySca(VT("3"), VT("4"), std::string("34")); + checkEwBinarySca(VT(""), VT("abc"), std::string("abc")); + checkEwBinarySca(VT(""), VT(""), std::string("")); +} + // **************************************************************************** // Invalid op-code // **************************************************************************** diff --git a/test/runtime/local/kernels/EwUnaryMatTest.cpp b/test/runtime/local/kernels/EwUnaryMatTest.cpp index 09bea0b0d..9a923db98 100644 --- a/test/runtime/local/kernels/EwUnaryMatTest.cpp +++ b/test/runtime/local/kernels/EwUnaryMatTest.cpp @@ -602,6 +602,36 @@ 
TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("isNan, floating-point specific"), TAG_KERN DataObjectFactory::destroy(arg, exp); } +// **************************************************************************** +// String Upper and Lower +// **************************************************************************** + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("Upper, string data"), TAG_KERNELS, (DenseMatrix), (ALL_STRING_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + auto arg = genGivenVals
(2, {VT(""), VT("Ab"), VT("123 abc"), VT("ABc"), VT("ab"), VT("12")}); + + auto exp = genGivenVals>(2, {VT(""), VT("AB"), VT("123 ABC"), VT("ABC"), VT("AB"), VT("12")}); + + checkEwUnaryMat(UnaryOpCode::UPPER, arg, exp); + + DataObjectFactory::destroy(arg, exp); +} + +TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("Lower, string data"), TAG_KERNELS, (DenseMatrix), (ALL_STRING_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + auto arg = genGivenVals
(2, {VT(""), VT("Ab"), VT("123 ABC"), VT("ABc"), VT("ab"), VT("14")}); + + auto exp = genGivenVals>(2, {VT(""), VT("ab"), VT("123 abc"), VT("abc"), VT("ab"), VT("14")}); + + checkEwUnaryMat(UnaryOpCode::LOWER, arg, exp); + + DataObjectFactory::destroy(arg, exp); +} + // **************************************************************************** // Invalid op-code // **************************************************************************** diff --git a/test/runtime/local/kernels/FillTest.cpp b/test/runtime/local/kernels/FillTest.cpp index b4db0801f..3c8714bea 100644 --- a/test/runtime/local/kernels/FillTest.cpp +++ b/test/runtime/local/kernels/FillTest.cpp @@ -58,4 +58,39 @@ TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("Matrix"), TAG_KERNELS, (DATA_TYPES), (VALU CHECK(*res == *exp); DataObjectFactory::destroy(exp, res); -} \ No newline at end of file +} + +TEMPLATE_PRODUCT_TEST_CASE("FillString", TAG_KERNELS, (DenseMatrix), (ALL_STRING_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + + size_t numRows = 3; + size_t numCols = 4; + + SECTION("empty_string") { + DenseMatrix *res = nullptr; + VT arg = VT(""); + + auto *exp = genGivenVals>( + 3, {VT(""), VT(""), VT(""), VT(""), VT(""), VT(""), VT(""), VT(""), VT(""), VT(""), VT(""), VT("")}); + + fill(res, arg, numRows, numCols, nullptr); + CHECK(*exp == *res); + + DataObjectFactory::destroy(res, exp); + } + + SECTION("not_empty_string") { + DenseMatrix *res = nullptr; + VT arg = VT("abc"); + + auto *exp = + genGivenVals>(3, {VT("abc"), VT("abc"), VT("abc"), VT("abc"), VT("abc"), VT("abc"), + VT("abc"), VT("abc"), VT("abc"), VT("abc"), VT("abc"), VT("abc")}); + + fill(res, arg, numRows, numCols, nullptr); + CHECK(*exp == *res); + + DataObjectFactory::destroy(res, exp); + } +} diff --git a/test/runtime/local/kernels/OneHotTest.cpp b/test/runtime/local/kernels/OneHotTest.cpp index 777904e09..48f1d8d02 100644 --- a/test/runtime/local/kernels/OneHotTest.cpp +++ b/test/runtime/local/kernels/OneHotTest.cpp @@ 
-93,5 +93,67 @@ TEMPLATE_PRODUCT_TEST_CASE("OneHot", TAG_KERNELS, (DATA_TYPES), (VALUE_TYPES)) { REQUIRE_THROWS_AS(oneHot(res, arg, info, nullptr), std::out_of_range); } + DataObjectFactory::destroy(arg, info); +} + +TEMPLATE_PRODUCT_TEST_CASE("OneHot", TAG_KERNELS, (DenseMatrix), (ALL_STRING_VALUE_TYPES)) { + using DTArg = TestType; + using VT = typename DTArg::VT; + using DTRes = DenseMatrix; + + auto *arg = genGivenVals>(4, {VT("a"), VT("blue"), VT("a"), VT("5"), VT("b"), VT("green"), VT("a"), + VT("20"), VT("c"), VT("red"), VT("b"), VT("10"), VT("d"), VT("blue"), + VT("b"), VT("20")}); + + DTRes *res = nullptr; + DenseMatrix *info = nullptr; + + /* + recoded_matrix = { + 0, 0, 0, 0, + 1, 1, 0, 1, + 2, 2, 1, 2, + 3, 0, 1, 1 + } + */ + SECTION("normal encoding") { + info = genGivenVals>(1, {-1, -1, 2, 3}); + auto *exp = genGivenVals( + 4, { + 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 2, 2, 0, 1, 0, 0, 1, 3, 0, 0, 1, 0, 1, 0, + }); + + oneHot(res, arg, info, nullptr); + CHECK(*res == *exp); + + DataObjectFactory::destroy(exp, res); + } + SECTION("normal encoding - skip columns") { + info = genGivenVals>(1, {4, 0, 0, 0}); + auto *exp = genGivenVals(4, { + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + }); + + oneHot(res, arg, info, nullptr); + CHECK(*res == *exp); + + DataObjectFactory::destroy(exp, res); + } + DataObjectFactory::destroy(arg, info); } \ No newline at end of file diff --git a/test/runtime/local/kernels/RecodeTest.cpp b/test/runtime/local/kernels/RecodeTest.cpp index df6c27c37..8193229ea 100644 --- a/test/runtime/local/kernels/RecodeTest.cpp +++ b/test/runtime/local/kernels/RecodeTest.cpp @@ -78,4 +78,44 @@ TEMPLATE_PRODUCT_TEST_CASE("Recode", TAG_KERNELS, (DenseMatrix, Matrix), (double } DataObjectFactory::destroy(arg, expRes, expDict); -} \ No newline at end of file +} + +TEMPLATE_PRODUCT_TEST_CASE("Recode", TAG_KERNELS, (DenseMatrix), (ALL_STRING_VALUE_TYPES)) { + using DTArg = TestType; + using VTArg = typename 
DTArg::VT; + using DTRes = DenseMatrix; + using DTDict = DenseMatrix; + + DTArg *arg = nullptr; + DTRes *expRes = nullptr; + DTArg *expDict = nullptr; + + SECTION("empty arg, non-order-preserving recoding") { + arg = DataObjectFactory::create(0, 1, false); + expRes = DataObjectFactory::create(0, 1, false); + expDict = DataObjectFactory::create(0, 1, false); + checkRecode(arg, false, expRes, expDict); + } + SECTION("empty arg, order-preserving recoding") { + arg = DataObjectFactory::create(0, 1, false); + expRes = DataObjectFactory::create(0, 1, false); + expDict = DataObjectFactory::create(0, 1, false); + checkRecode(arg, true, expRes, expDict); + } + SECTION("non-empty arg, non-order-preserving recoding") { + arg = genGivenVals(8, {VTArg("abc"), VTArg("ab"), VTArg("abcde"), VTArg("ab"), VTArg("ab"), VTArg("a"), + VTArg("abcd"), VTArg("abcde")}); + expRes = genGivenVals(8, {0, 1, 2, 1, 1, 3, 4, 2}); + expDict = genGivenVals(5, {VTArg("abc"), VTArg("ab"), VTArg("abcde"), VTArg("a"), VTArg("abcd")}); + checkRecode(arg, false, expRes, expDict); + } + SECTION("non-empty arg, order-preserving recoding") { + arg = genGivenVals(8, {VTArg("abc"), VTArg("ab"), VTArg("abcde"), VTArg("ab"), VTArg("ab"), VTArg("a"), + VTArg("abcd"), VTArg("abcde")}); + expRes = genGivenVals(8, {2, 1, 4, 1, 1, 0, 3, 4}); + expDict = genGivenVals(5, {VTArg("a"), VTArg("ab"), VTArg("abc"), VTArg("abcd"), VTArg("abcde")}); + checkRecode(arg, true, expRes, expDict); + } + + DataObjectFactory::destroy(arg, expRes, expDict); +} diff --git a/test/runtime/local/kernels/TransposeTest.cpp b/test/runtime/local/kernels/TransposeTest.cpp index 1157f01bb..cd4162fed 100644 --- a/test/runtime/local/kernels/TransposeTest.cpp +++ b/test/runtime/local/kernels/TransposeTest.cpp @@ -115,4 +115,47 @@ TEMPLATE_PRODUCT_TEST_CASE("Transpose", TAG_KERNELS, (DATA_TYPES), (VALUE_TYPES) checkTranspose(m, mt); DataObjectFactory::destroy(m, mt); -} \ No newline at end of file +} + +TEMPLATE_PRODUCT_TEST_CASE("Transpose", 
TAG_KERNELS, (DenseMatrix), (ALL_STRING_VALUE_TYPES)) { + using DT = TestType; + using VT = typename DT::VT; + ; + + DT *m = nullptr; + DT *mt = nullptr; + + m = genGivenVals
(3, { + VT("1"), + VT("a"), + VT("3"), + VT("4"), + VT("5"), + VT("ab"), + VT("7"), + VT("8"), + VT("9"), + VT("abc"), + VT("11"), + VT("12"), + }); + mt = genGivenVals
(4, { + VT("1"), + VT("5"), + VT("9"), + VT("a"), + VT("ab"), + VT("abc"), + VT("3"), + VT("7"), + VT("11"), + VT("4"), + VT("8"), + VT("12"), + }); + + checkTranspose(m, mt); + + DataObjectFactory::destroy(m); + DataObjectFactory::destroy(mt); +}