Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use fibers to guarantee stack size on Windows #5873

Merged
merged 12 commits into from
Apr 2, 2021
2 changes: 1 addition & 1 deletion cmake/HalideTestHelpers.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,4 @@ function(tests)
endforeach ()

set(TEST_NAMES "${TEST_NAMES}" PARENT_SCOPE)
endfunction(tests)
endfunction()
4 changes: 3 additions & 1 deletion src/CodeGen_LLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,9 @@ std::unique_ptr<llvm::Module> CodeGen_LLVM::compile(const Module &input) {
for (const auto &f : input.functions()) {
const auto names = get_mangled_names(f, get_target());

compile_func(f, names.simple_name, names.extern_name);
run_with_large_stack([&]() {
compile_func(f, names.simple_name, names.extern_name);
});

// If the Func is externally visible, also create the argv wrapper and metadata.
// (useful for calling from JIT and other machine interfaces).
Expand Down
39 changes: 24 additions & 15 deletions src/Lower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,23 +98,17 @@ class LoweringLogger {
}
};

} // namespace

Module lower(const vector<Function> &output_funcs,
const string &pipeline_name,
const Target &t,
const vector<Argument> &args,
const LinkageType linkage_type,
const vector<Stmt> &requirements,
bool trace_pipeline,
const vector<IRMutator *> &custom_passes) {
void lower_impl(const vector<Function> &output_funcs,
const string &pipeline_name,
const Target &t,
const vector<Argument> &args,
const LinkageType linkage_type,
const vector<Stmt> &requirements,
bool trace_pipeline,
const vector<IRMutator *> &custom_passes,
Module &result_module) {
auto time_start = std::chrono::high_resolution_clock::now();

std::vector<std::string> namespaces;
std::string simple_pipeline_name = extract_namespaces(pipeline_name, namespaces);

Module result_module(simple_pipeline_name, t);

// Compute an environment
map<string, Function> env;
for (const Function &f : output_funcs) {
Expand Down Expand Up @@ -524,7 +518,22 @@ Module lower(const vector<Function> &output_funcs,
std::chrono::duration<double> diff = time_end - time_start;
logger->record_compilation_time(CompilerLogger::Phase::HalideLowering, diff.count());
}
}

} // namespace

Module lower(const vector<Function> &output_funcs,
const string &pipeline_name,
const Target &t,
const vector<Argument> &args,
const LinkageType linkage_type,
const vector<Stmt> &requirements,
bool trace_pipeline,
const vector<IRMutator *> &custom_passes) {
Module result_module{extract_namespaces(pipeline_name), t};
run_with_large_stack([&]() {
lower_impl(output_funcs, pipeline_name, t, args, linkage_type, requirements, trace_pipeline, custom_passes, result_module);
});
return result_module;
}

Expand Down
83 changes: 83 additions & 0 deletions src/Util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,11 @@ std::string extract_namespaces(const std::string &name, std::vector<std::string>
return result;
}

std::string extract_namespaces(const std::string &name) {
std::vector<std::string> unused;
return extract_namespaces(name, unused);
}

bool file_exists(const std::string &name) {
#ifdef _MSC_VER
return _access(name.c_str(), 0) == 0;
Expand Down Expand Up @@ -572,6 +577,84 @@ int get_llvm_version() {
return LLVM_VERSION;
}

#ifdef _WIN32

namespace {

struct GenericFiberArgs {
const std::function<void()> &run;
LPVOID main_fiber;
#ifdef HALIDE_WITH_EXCEPTIONS
std::exception_ptr exception = nullptr; // NOLINT - clang-tidy complains this isn't thrown
#endif
};

void WINAPI generic_fiber_entry_point(LPVOID argument) {
auto *action = reinterpret_cast<GenericFiberArgs *>(argument);
#ifdef HALIDE_WITH_EXCEPTIONS
try {
#endif
action->run();
#ifdef HALIDE_WITH_EXCEPTIONS
} catch (...) {
action->exception = std::current_exception();
}
#endif
SwitchToFiber(action->main_fiber);
}

} // namespace

#endif

void run_with_large_stack(const std::function<void()> &action) {
#if _WIN32
constexpr SIZE_T required_stack = 8 * 1024 * 1024;

// Only exists for its address, which is used to compute remaining stack space.
ULONG_PTR approx_stack_pos;

ULONG_PTR stack_low, stack_high;
GetCurrentThreadStackLimits(&stack_low, &stack_high);
ptrdiff_t stack_remaining = (char *)&approx_stack_pos - (char *)stack_low;

if (stack_remaining < required_stack) {
debug(1) << "Insufficient stack space (" << stack_remaining << " bytes). Switching to fiber with " << required_stack << "-byte stack.\n";

auto was_a_fiber = IsThreadAFiber();

auto *main_fiber = was_a_fiber ? GetCurrentFiber() : ConvertThreadToFiber(nullptr);
internal_assert(main_fiber) << "ConvertThreadToFiber failed with code: " << GetLastError() << "\n";

GenericFiberArgs fiber_args{action, main_fiber};
auto *lower_fiber = CreateFiber(required_stack, generic_fiber_entry_point, &fiber_args);
internal_assert(lower_fiber) << "CreateFiber failed with code: " << GetLastError() << "\n";

SwitchToFiber(lower_fiber);
DeleteFiber(lower_fiber);

debug(1) << "Returned from fiber.\n";

#ifdef HALIDE_WITH_EXCEPTIONS
if (fiber_args.exception) {
debug(1) << "Fiber threw exception. Rethrowing...\n";
std::rethrow_exception(fiber_args.exception);
}
#endif

if (!was_a_fiber) {
BOOL success = ConvertFiberToThread();
internal_assert(success) << "ConvertFiberToThread failed with code: " << GetLastError() << "\n";
}

return;
}

#endif

action();
}

} // namespace Internal

void load_plugin(const std::string &lib_name) {
Expand Down
14 changes: 9 additions & 5 deletions src/Util.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include <cstdint>
#include <cstring>
#include <functional>
#include <limits>
#include <string>
#include <utility>
Expand Down Expand Up @@ -44,11 +45,6 @@
#define HALIDE_NO_USER_CODE_INLINE HALIDE_NEVER_INLINE
#endif

// On windows, Halide needs a larger stack than the default MSVC provides
#ifdef _MSC_VER
#pragma comment(linker, "/STACK:8388608,1048576")
#endif

namespace Halide {

/** Load a plugin in the form of a dynamic library (e.g. for custom autoschedulers).
Expand Down Expand Up @@ -212,6 +208,9 @@ struct all_are_convertible : meta_and<std::is_convertible<Args, To>...> {};
/** Returns base name and fills in namespaces, outermost one first in vector. */
std::string extract_namespaces(const std::string &name, std::vector<std::string> &namespaces);

/** Overload that returns base name only */
std::string extract_namespaces(const std::string &name);

struct FileStat {
uint64_t file_size;
uint32_t mod_time; // Unix epoch time
Expand Down Expand Up @@ -466,6 +465,11 @@ std::string c_print_name(const std::string &name);
* of Halide tests. */
int get_llvm_version();

/** Call the given action in a platform-specific context that provides at least
* 8MB of stack space. Currently only has any effect on Windows where it uses
* a Fiber. */
void run_with_large_stack(const std::function<void()> &action);

} // namespace Internal
} // namespace Halide

Expand Down
1 change: 1 addition & 0 deletions test/error/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ tests(GROUPS error
reuse_var_in_schedule.cpp
reused_args.cpp
rfactor_inner_dim_non_commutative.cpp
run_with_large_stack_throws.cpp
specialize_fail.cpp
split_inner_wrong_tail_strategy.cpp
thread_id_outside_block_id.cpp
Expand Down
16 changes: 16 additions & 0 deletions test/error/run_with_large_stack_throws.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#include "Halide.h"
#include <iostream>

int main() {
try {
Halide::Internal::run_with_large_stack([]() {
throw Halide::RuntimeError("Error from run_with_large_stack");
});
} catch (const Halide::RuntimeError &ex) {
std::cerr << ex.what() << "\n";
return 1;
}

std::cout << "Success!\n";
return 0;
}
2 changes: 1 addition & 1 deletion tutorial/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ add_dependencies(lesson_15_targets

##
add_test(NAME tutorial_lesson_15_build_gens
COMMAND ${CMAKE_COMMAND} --build ${CMAKE_BINARY_DIR} --target lesson_15_targets)
COMMAND ${CMAKE_COMMAND} --build ${CMAKE_BINARY_DIR} --target lesson_15_targets --config $<CONFIG>)
set_tests_properties(tutorial_lesson_15_build_gens PROPERTIES
LABELS tutorial
FIXTURES_SETUP tutorial_lesson_15)
Expand Down