diff --git a/include/tvm/support/parallel_for.h b/include/tvm/support/parallel_for.h new file mode 100644 index 000000000000..49a9d4889e33 --- /dev/null +++ b/include/tvm/support/parallel_for.h @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file parallel_for.h + * \brief An implementation to run loop in parallel. + */ +#ifndef TVM_SUPPORT_PARALLEL_FOR_H_ +#define TVM_SUPPORT_PARALLEL_FOR_H_ + +#include + +#include +#include + +namespace tvm { +namespace support { + +using PartitionerFuncType = std::function>(int, int, int, int)>; + +/*! + * \brief A partitioner to split the task to each thread in Round-robin manner. + * \param begin The start index of this parallel loop(inclusive). + * \param end The end index of this parallel loop(exclusive). + * \param step The traversal step to the index. + * \param num_threads The number of threads(the number of tasks to be partitioned to). + * \return A list with `num_threads` elements, and each is a list of integers indicating the loop + * indexes for the corresponding thread to process. + */ +TVM_DLL std::vector> rr_partitioner(int begin, int end, int step, int num_threads); + +/*! + * \brief A runtime api provided to run the task function in parallel. + * e.g. A for loop: + * for (int i = 0; i < 10; i++) { + * a[i] = i; + * } + * should work the same as: + * parallel_for(0, 10, [&a](int index) { + * a[i] = i; + * }); + * \param begin The start index of this parallel loop(inclusive). + * \param end The end index of this parallel loop(exclusive). + * \param f The task function to be excuted. Assert to take an int index as input with no output. + * \param step The traversal step to the index. + * \param partitioner A partition function to split tasks to different threads. Use Round-robin + * partitioner by default. + * \note 1. Currently do not support nested parallel_for; 2. The order of execution in each thread + * is not guaranteed, the for loop task should be thread independent and thread safe. + */ +TVM_DLL void parallel_for(int begin, int end, const std::function& f, int step = 1, + const PartitionerFuncType partitioner = rr_partitioner); + +} // namespace support +} // namespace tvm + +#endif // TVM_SUPPORT_PARALLEL_FOR_H_ diff --git a/src/support/parallel_for.cc b/src/support/parallel_for.cc new file mode 100644 index 000000000000..30f39fbee6f9 --- /dev/null +++ b/src/support/parallel_for.cc @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file parallel_for.cc + * \brief An implementation to run loop in parallel. + */ +#include +#include + +#include +#include +#include +#include + +namespace tvm { +namespace support { + +std::vector> rr_partitioner(int begin, int end, int step, int num_threads) { + int total_task_count = (end - begin) / step; + CHECK_GT(total_task_count, 0) << "Infinite loop condition, check the input value of " + << "`begin`, `end`, `step`."; + std::vector> ret; + ret.reserve(num_threads); + for (size_t thread = 0; begin < end; begin += step, thread = (thread + 1) % num_threads) { + if (thread >= ret.size()) { + ret.push_back(std::vector()); + } + ret[thread].push_back(begin); + } + return ret; +} + +void parallel_for(int begin, int end, const std::function& f, int step, + const PartitionerFuncType partitioner) { + int default_num_threads = std::thread::hardware_concurrency(); + const auto& run_partitions = partitioner(begin, end, step, default_num_threads); + + std::vector threads; + threads.reserve(run_partitions.size()); + std::vector> res_vec; + res_vec.reserve(run_partitions.size()); + for (const auto& run_partition : run_partitions) { + std::packaged_task&, const std::function&)> task( + [](const std::vector& run_pattition, const std::function& f) { + for (const auto& i : run_pattition) { + f(i); + } + }); + res_vec.emplace_back(task.get_future()); + threads.emplace_back(std::move(task), run_partition, f); + } + + for (auto&& thread : threads) { + thread.join(); + } + try { + for (auto&& i : res_vec) { + i.get(); + } + } catch (const std::exception& e) { + LOG(FATAL) << "Parallel_for error with " << e.what(); + } +} + +} // namespace support +} // namespace tvm diff --git a/tests/cpp/parallel_for_test.cc b/tests/cpp/parallel_for_test.cc new file mode 100644 index 000000000000..3d586fc1aa15 --- /dev/null +++ b/tests/cpp/parallel_for_test.cc @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +#include + +TEST(ParallelFor, Basic) { + using tvm::support::parallel_for; + + int a[1000], b[1000]; + + // Check for a small size of parallel + for (int i = 0; i < 10; i++) { + a[i] = i; + } + parallel_for(0, 10, [&b](int i) { b[i] = i; }); + for (int i = 0; i < 10; i++) { + CHECK_EQ(a[i], b[i]); + } + + // Check for a large size of parallel + for (int i = 0; i < 1000; i++) { + a[i] = i; + } + parallel_for(0, 1000, [&b](int i) { b[i] = i; }); + for (int i = 0; i < 1000; i++) { + CHECK_EQ(a[i], b[i]); + } + + // Check for step != 1 + for (int i = 0; i < 1000; i += 2) { + a[i] *= 2; + } + parallel_for( + 0, 1000, [&b](int i) { b[i] *= 2; }, 2); + for (int i = 0; i < 1000; i++) { + CHECK_EQ(a[i], b[i]); + } +} + +TEST(ParallelFor, NestedWithNormalForLoop) { + using tvm::support::parallel_for; + + int a[500][500], b[500][500], c[500][500]; + + for (int i = 0; i < 500; i++) { + for (int j = 0; j < 500; j++) { + a[i][j] = i * j; + } + } + + parallel_for(0, 500, [&b](int i) { + for (int j = 0; j < 500; j++) { + b[i][j] = i * j; + } + }); + for (int i = 0; i < 500; i++) { + for (int j = 0; j < 500; j++) { + CHECK_EQ(a[i][j], b[i][j]); + } + } + + for (int i = 0; i < 500; i++) { + parallel_for(0, 500, [&c, &i](int j) { c[i][j] = i * j; }); + } + for (int i = 0; i < 500; i++) { + for (int j = 0; j < 500; j++) { + CHECK_EQ(a[i][j], c[i][j]); + } + } +} + +TEST(ParallelFor, Exception) { + using tvm::support::parallel_for; + + bool exception = false; + try { + parallel_for(0, 100, [](int i) { LOG(FATAL) << "error"; }); + } catch (const std::exception& e) { + exception = true; + } + CHECK(exception); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +}