Skip to content

Commit

Permalink
Move GeneratorContext into a standalone class (#6618)
Browse files Browse the repository at this point in the history
* Move GeneratorContext into a standalone class

* Minor Fixes

* clang-tidy

* Update Generator.cpp

* Update Generator.cpp
  • Loading branch information
steven-johnson authored Feb 17, 2022
1 parent 846592f commit 7373eb9
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 131 deletions.
99 changes: 76 additions & 23 deletions src/Generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,26 @@

namespace Halide {

GeneratorContext::GeneratorContext(const Target &t, bool auto_schedule,
GeneratorContext::GeneratorContext(const Target &target,
bool auto_schedule,
const MachineParams &machine_params,
std::shared_ptr<ExternsMap> externs_map,
std::shared_ptr<Internal::ValueTracker> value_tracker)
: target_(target),
auto_schedule_(auto_schedule),
machine_params_(machine_params),
externs_map_(std::move(externs_map)),
value_tracker_(std::move(value_tracker)) {
}

GeneratorContext::GeneratorContext(const Target &target,
bool auto_schedule,
const MachineParams &machine_params)
: target("target", t),
auto_schedule("auto_schedule", auto_schedule),
machine_params("machine_params", machine_params),
externs_map(std::make_shared<ExternsMap>()),
value_tracker(std::make_shared<Internal::ValueTracker>()) {
}

void GeneratorContext::init_from_context(const Halide::GeneratorContext &context) {
target.set(context.get_target());
auto_schedule.set(context.get_auto_schedule());
machine_params.set(context.get_machine_params());
value_tracker = context.get_value_tracker();
externs_map = context.get_externs_map();
: GeneratorContext(target,
auto_schedule,
machine_params,
std::make_shared<ExternsMap>(),
std::make_shared<Internal::ValueTracker>()) {
}

namespace Internal {
Expand Down Expand Up @@ -135,6 +140,44 @@ std::vector<Type> parse_halide_type_list(const std::string &types) {
return result;
}

/**
* ValueTracker is an internal utility class that attempts to track and flag certain
* obvious Stub-related errors at Halide compile time: it tracks the constraints set
* on any Parameter-based argument (i.e., Input<Buffer> and Output<Buffer>) to
* ensure that incompatible values aren't set.
*
* e.g.: if a Generator A requires stride[0] == 1,
* and Generator B uses Generator A via stub, but requires stride[0] == 4,
* we should be able to detect this at Halide compilation time, and fail immediately,
* rather than producing code that fails at runtime and/or runs slowly due to
* vectorization being unavailable.
*
* We do this by tracking the active values at entrance and exit to all user-provided
* Generator methods (generate()/schedule()); if we ever find more than two unique
* values active, we know we have a potential conflict. ("two" here because the first
* value is the default value for a given constraint.)
*
* Note that this won't catch all cases:
* -- JIT compilation has no way to check for conflicts at the top-level
* -- constraints that match the default value (e.g. if dim(0).set_stride(1) is the
* first value seen by the tracker) will be ignored, so an explicit requirement set
* this way can be missed
*
* Nevertheless, this is likely to be much better than nothing when composing multiple
* layers of Stubs in a single fused result.
*/
class ValueTracker {
private:
std::map<std::string, std::vector<std::vector<Expr>>> values_history;
const size_t max_unique_values;

public:
explicit ValueTracker(size_t max_unique_values = 2)
: max_unique_values(max_unique_values) {
}
void track_values(const std::string &name, const std::vector<Expr> &values);
};

void ValueTracker::track_values(const std::string &name, const std::vector<Expr> &values) {
std::vector<std::vector<Expr>> &history = values_history[name];
if (history.empty()) {
Expand Down Expand Up @@ -597,24 +640,24 @@ void StubEmitter::emit() {
for (const auto &out : out_info) {
stream << get_indent() << "stub." << out.getter << ",\n";
}
stream << get_indent() << "stub.generator->get_target()\n";
stream << get_indent() << "stub.generator->context().get_target()\n";
indent_level--;
stream << get_indent() << "};\n";
indent_level--;
stream << get_indent() << "}\n";
stream << "\n";

stream << get_indent() << "// overload to allow GeneratorContext-pointer\n";
stream << get_indent() << "// overload to allow GeneratorBase-pointer\n";
stream << get_indent() << "inline static Outputs generate(\n";
indent_level++;
stream << get_indent() << "const GeneratorContext* context,\n";
stream << get_indent() << "const Halide::Internal::GeneratorBase* generator,\n";
stream << get_indent() << "const Inputs& inputs,\n";
stream << get_indent() << "const GeneratorParams& generator_params = GeneratorParams()\n";
indent_level--;
stream << get_indent() << ")\n";
stream << get_indent() << "{\n";
indent_level++;
stream << get_indent() << "return generate(*context, inputs, generator_params);\n";
stream << get_indent() << "return generate(generator->context(), inputs, generator_params);\n";
indent_level--;
stream << get_indent() << "}\n";
stream << "\n";
Expand Down Expand Up @@ -1346,10 +1389,20 @@ void GeneratorBase::set_generator_param_values(const GeneratorParamsMap &params)
}
}

GeneratorContext GeneratorBase::context() const {
return GeneratorContext(target, auto_schedule, machine_params, externs_map, value_tracker);
}

void GeneratorBase::init_from_context(const Halide::GeneratorContext &context) {
Halide::GeneratorContext::init_from_context(context);
internal_assert(param_info_ptr == nullptr);
target.set(context.target_);
auto_schedule.set(context.auto_schedule_);
machine_params.set(context.machine_params_);

externs_map = context.externs_map_;
value_tracker = context.value_tracker_;

// pre-emptively build our param_info now
internal_assert(param_info_ptr == nullptr);
param_info_ptr = std::make_unique<GeneratorParamInfo>(this, size);
}

Expand Down Expand Up @@ -1381,7 +1434,7 @@ void GeneratorBase::track_parameter_values(bool include_outputs) {
internal_assert(!input->parameters_.empty());
for (auto &p : input->parameters_) {
// This must use p.name(), *not* input->name()
get_value_tracker()->track_values(p.name(), parameter_constraints(p));
value_tracker->track_values(p.name(), parameter_constraints(p));
}
}
}
Expand All @@ -1395,7 +1448,7 @@ void GeneratorBase::track_parameter_values(bool include_outputs) {
for (auto &o : output_buffers) {
Parameter p = o.parameter();
// This must use p.name(), *not* output->name()
get_value_tracker()->track_values(p.name(), parameter_constraints(p));
value_tracker->track_values(p.name(), parameter_constraints(p));
}
}
}
Expand Down Expand Up @@ -1550,7 +1603,7 @@ Module GeneratorBase::build_module(const std::string &function_name,
}

Module result = pipeline.compile_to_module(filter_arguments, function_name, get_target(), linkage_type);
std::shared_ptr<ExternsMap> externs_map = get_externs_map();
std::shared_ptr<GeneratorContext::ExternsMap> externs_map = get_externs_map();
for (const auto &map_entry : *externs_map) {
result.append(map_entry.second);
}
Expand Down
Loading

0 comments on commit 7373eb9

Please sign in to comment.