Skip to content

Commit

Permalink
Update onnx app to Adams2019 autoscheduler and new autoscheduler API (h…
Browse files Browse the repository at this point in the history
…alide#7673)

* Update onnx app to Adams2019 autoscheduler and new autoscheduler API

Fixes halide#7670

* Add model test too

* Remove use of tmpnam

* Don't test onnx app in a 32-bit build
  • Loading branch information
abadams authored and ardier committed Mar 3, 2024
1 parent 6358504 commit 38c7884
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 11 deletions.
11 changes: 9 additions & 2 deletions apps/onnx/Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
include ../support/Makefile.inc

ifneq (,$(findstring -m32,$(CXX) $(CC) $(CCFLAGS) $(CXXFLAGS)))
build:
@echo "Not testing this app in a 32-bit build (-m32 found in flags)"
test: build
else

PROTOC := $(shell which protoc)

ifdef PROTOC
Expand Down Expand Up @@ -72,7 +78,7 @@ $(BIN)/%/onnx_converter_generator_test: onnx_converter_generator_test.cc $(BIN)/

build: $(BIN)/$(HL_TARGET)/onnx_converter_test $(BIN)/$(HL_TARGET)/onnx_converter_generator_test

test: build
test: build model_test
LD_LIBRARY_PATH=$(BIN) $(BIN)/$(HL_TARGET)/onnx_converter_test
LD_LIBRARY_PATH=$(BIN) $(BIN)/$(HL_TARGET)/onnx_converter_generator_test

Expand All @@ -88,7 +94,7 @@ PYCXXFLAGS = $(CXXFLAGS) $(PYBIND11_CFLAGS) -Wno-deprecated-register

# Python extension for HalideModel
$(BIN)/%/$(PY_MODEL_EXT): model.cpp $(BIN)/%/oclib.a
$(CXX) $(PYCXXFLAGS) -Wall -shared -fPIC -I$(BIN)/$* $^ $(LIBHALIDE_LDFLAGS) -o $@ $(LDFLAGS)
$(CXX) $(PYCXXFLAGS) -Wall -shared -fPIC -I$(BIN)/$* $^ $(LIBHALIDE_LDFLAGS) -Wl,--no-as-needed -lautoschedule_adams2019 -Wl,--as-needed -o $@ $(LDFLAGS)


model_test: $(BIN)/$(HL_TARGET)/$(PY_MODEL_EXT)
Expand All @@ -104,3 +110,4 @@ build:
test: build

endif
endif
7 changes: 4 additions & 3 deletions apps/onnx/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ HalideModel convert_onnx_model(
std::string auto_schedule(const HalideModel &pipeline) {
// Generate a schedule for the pipeline.
Halide::Target tgt = Halide::get_host_target();
auto schedule = pipeline.rep->auto_schedule(tgt);
Halide::AutoschedulerParams autoscheduler_params = Halide::AutoschedulerParams("Adams2019");
auto schedule = pipeline.rep->apply_autoscheduler(tgt, autoscheduler_params);
return schedule.schedule_source;
}

Expand Down Expand Up @@ -528,15 +529,15 @@ void print_loop_nest(const HalideModel &pipeline) {
}

void print_lowered_statement(const HalideModel &pipeline) {
std::string tmp_file = std::tmpnam(nullptr);
Halide::Internal::TemporaryFile f("model", ".stmt");
std::string tmp_file = f.pathname();
pipeline.rep->compile_to_lowered_stmt(
tmp_file, pipeline.rep->infer_arguments());
std::ifstream is(tmp_file);
std::string line;
while (std::getline(is, line)) {
std::cout << line << "\n";
}
std::remove(tmp_file.c_str());
}

} // namespace
Expand Down
12 changes: 6 additions & 6 deletions apps/onnx/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ def test_small_model(self):
model.BuildFromOnnxModel(onnx_model)
schedule = model.OptimizeSchedule()
schedule = schedule.replace('\n', ' ')
expected_schedule = r'// --- BEGIN machine-generated schedule // Target: .+// MachineParams: .+// Delete this line if not using Generator Pipeline pipeline = get_pipeline\(\);.+Func OUT = pipeline.get_func\(1\);.+{.+}.+'
expected_schedule = r'.*Func OUT = pipeline.get_func\(1\);.+'
self.assertRegex(schedule, expected_schedule)

input = np.random.rand(2, 3) - 0.5
input = (np.random.rand(2, 3) - 0.5).astype('float32')
outputs = model.run([input])
self.assertEqual(1, len(outputs))
output = outputs[0]
Expand All @@ -62,12 +62,12 @@ def test_scalars(self):
model = Model()
model.BuildFromOnnxModel(onnx_model)
schedule = model.OptimizeSchedule()
schedule = schedule.replace('\n', ' ')
expected_schedule = r'// --- BEGIN machine-generated schedule // Target: .+// MachineParams: .+// Delete this line if not using Generator Pipeline pipeline = get_pipeline\(\);.+Func C = pipeline.get_func\(2\);.+{.+}.+'
schedule = schedule.replace('\n', ' ')
expected_schedule = r'.*Func C = pipeline.get_func\(2\);.+'
self.assertRegex(schedule, expected_schedule)

input1 = np.random.randint(-10, 10, size=())
input2 = np.random.randint(-10, 10, size=())
input1 = np.random.randint(-10, 10, size=()).astype('int32')
input2 = np.random.randint(-10, 10, size=()).astype('int32')
outputs = model.run([input1, input2])
self.assertEqual(1, len(outputs))
output = outputs[0]
Expand Down

0 comments on commit 38c7884

Please sign in to comment.