Skip to content

Commit

Permalink
Update onnx app to Adams2019 autoscheduler and new autoscheduler API
Browse files Browse the repository at this point in the history
Fixes #7670
  • Loading branch information
abadams committed Jul 10, 2023
1 parent 9755e3d commit 3157885
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 8 deletions.
2 changes: 1 addition & 1 deletion apps/onnx/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ PYCXXFLAGS = $(CXXFLAGS) $(PYBIND11_CFLAGS) -Wno-deprecated-register

# Python extension for HalideModel
$(BIN)/%/$(PY_MODEL_EXT): model.cpp $(BIN)/%/oclib.a
$(CXX) $(PYCXXFLAGS) -Wall -shared -fPIC -I$(BIN)/$* $^ $(LIBHALIDE_LDFLAGS) -o $@ $(LDFLAGS)
$(CXX) $(PYCXXFLAGS) -Wall -shared -fPIC -I$(BIN)/$* $^ $(LIBHALIDE_LDFLAGS) -Wl,--no-as-needed -lautoschedule_adams2019 -Wl,--as-needed -o $@ $(LDFLAGS)


model_test: $(BIN)/$(HL_TARGET)/$(PY_MODEL_EXT)
Expand Down
3 changes: 2 additions & 1 deletion apps/onnx/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ HalideModel convert_onnx_model(
std::string auto_schedule(const HalideModel &pipeline) {
// Generate a schedule for the pipeline.
Halide::Target tgt = Halide::get_host_target();
auto schedule = pipeline.rep->auto_schedule(tgt);
Halide::AutoschedulerParams autoscheduler_params = Halide::AutoschedulerParams("Adams2019");
auto schedule = pipeline.rep->apply_autoscheduler(tgt, autoscheduler_params);
return schedule.schedule_source;
}

Expand Down
12 changes: 6 additions & 6 deletions apps/onnx/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ def test_small_model(self):
model.BuildFromOnnxModel(onnx_model)
schedule = model.OptimizeSchedule()
schedule = schedule.replace('\n', ' ')
expected_schedule = r'// --- BEGIN machine-generated schedule // Target: .+// MachineParams: .+// Delete this line if not using Generator Pipeline pipeline = get_pipeline\(\);.+Func OUT = pipeline.get_func\(1\);.+{.+}.+'
expected_schedule = r'.*Func OUT = pipeline.get_func\(1\);.+'
self.assertRegex(schedule, expected_schedule)

input = np.random.rand(2, 3) - 0.5
input = (np.random.rand(2, 3) - 0.5).astype('float32')
outputs = model.run([input])
self.assertEqual(1, len(outputs))
output = outputs[0]
Expand All @@ -62,12 +62,12 @@ def test_scalars(self):
model = Model()
model.BuildFromOnnxModel(onnx_model)
schedule = model.OptimizeSchedule()
schedule = schedule.replace('\n', ' ')
expected_schedule = r'// --- BEGIN machine-generated schedule // Target: .+// MachineParams: .+// Delete this line if not using Generator Pipeline pipeline = get_pipeline\(\);.+Func C = pipeline.get_func\(2\);.+{.+}.+'
schedule = schedule.replace('\n', ' ')
expected_schedule = r'.*Func C = pipeline.get_func\(2\);.+'
self.assertRegex(schedule, expected_schedule)

input1 = np.random.randint(-10, 10, size=())
input2 = np.random.randint(-10, 10, size=())
input1 = np.random.randint(-10, 10, size=()).astype('int32')
input2 = np.random.randint(-10, 10, size=()).astype('int32')
outputs = model.run([input1, input2])
self.assertEqual(1, len(outputs))
output = outputs[0]
Expand Down

0 comments on commit 3157885

Please sign in to comment.