Skip to content

Commit

Permalink
Fix autoscheduling trivial lut wrappers (#6905)
Browse files Browse the repository at this point in the history
* Fix autoscheduling trivial lut wrappers

Fixes #6899

* trigger buildbots

Co-authored-by: Steven Johnson <[email protected]>
  • Loading branch information
abadams and steven-johnson authored Aug 2, 2022
1 parent 8871404 commit 2239119
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/autoschedulers/adams2019/DefaultCostModel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ void DefaultCostModel::enqueue(const Internal::Autoscheduler::FunctionDAG &dag,
void DefaultCostModel::enqueue(int ns, Runtime::Buffer<float> *schedule_feats, double *cost_ptr) {
num_stages = ns;

// We know the most stages that will ever be enqueued from the schedule features
internal_assert(pipeline_feat_queue.data() && "Call set_schedule_features before calling enqueue\n");
// We know the most stages that will ever be enqueued from the pipeline features
internal_assert(pipeline_feat_queue.data() && "Call set_pipeline_features before calling enqueue\n");
const int max_num_stages = pipeline_feat_queue.dim(2).extent();
internal_assert(num_stages <= max_num_stages)
<< "schedule features has more stages (" << num_stages
Expand Down
2 changes: 1 addition & 1 deletion src/autoschedulers/adams2019/FunctionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -959,7 +959,7 @@ FunctionDAG::FunctionDAG(const vector<Function> &outputs, const Target &target)
}

node.is_wrapper = node.func.is_wrapper();
node.is_input = !node.func.has_update_definition() && node.is_wrapper && !any_incoming_edges;
node.is_input = !node.is_output && !node.func.has_update_definition() && node.is_wrapper && !any_incoming_edges;
node.dimensions = node.func.dimensions();
}
}
Expand Down
24 changes: 24 additions & 0 deletions src/autoschedulers/adams2019/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -913,6 +913,30 @@ int main(int argc, char **argv) {
}
}

// A trivial pipeline that just loads from a LUT
if (true) {
Pipeline p1;
Pipeline p2;
for (int test_condition = 0; test_condition < 2; test_condition++) {
Buffer<uint8_t> lut(256);
Func f;
f(x) = lut(x);

f.set_estimate(x, 0, 256);

if (test_condition) {
p2 = Pipeline(f);
} else {
p1 = Pipeline(f);
}
}

if (!test_caching(p1, p2, target)) {
std::cerr << "Caching check failed on stencil chain" << std::endl;
return 1;
}
}

#ifdef HALIDE_ALLOW_LEGACY_AUTOSCHEDULER_API
// Reset environment variables.
set_env_variable("HL_DISABLE_MEMOIZED_FEATURES", cache_features, /* overwrite */ 1);
Expand Down

0 comments on commit 2239119

Please sign in to comment.