diff --git a/apps/onnx/Makefile b/apps/onnx/Makefile index 738fb8fe7af6..940a4e66e7ff 100644 --- a/apps/onnx/Makefile +++ b/apps/onnx/Makefile @@ -1,5 +1,11 @@ include ../support/Makefile.inc +ifneq (,$(findstring -m32,$(CXX) $(CC) $(CCFLAGS) $(CXXFLAGS))) +build: + @echo "Not testing this app in a 32-bit build (-m32 found in flags)" +test: build +else + PROTOC := $(shell which protoc) ifdef PROTOC @@ -72,7 +78,7 @@ $(BIN)/%/onnx_converter_generator_test: onnx_converter_generator_test.cc $(BIN)/ build: $(BIN)/$(HL_TARGET)/onnx_converter_test $(BIN)/$(HL_TARGET)/onnx_converter_generator_test -test: build +test: build model_test LD_LIBRARY_PATH=$(BIN) $(BIN)/$(HL_TARGET)/onnx_converter_test LD_LIBRARY_PATH=$(BIN) $(BIN)/$(HL_TARGET)/onnx_converter_generator_test @@ -88,7 +94,7 @@ PYCXXFLAGS = $(CXXFLAGS) $(PYBIND11_CFLAGS) -Wno-deprecated-register # Python extension for HalideModel $(BIN)/%/$(PY_MODEL_EXT): model.cpp $(BIN)/%/oclib.a - $(CXX) $(PYCXXFLAGS) -Wall -shared -fPIC -I$(BIN)/$* $^ $(LIBHALIDE_LDFLAGS) -o $@ $(LDFLAGS) + $(CXX) $(PYCXXFLAGS) -Wall -shared -fPIC -I$(BIN)/$* $^ $(LIBHALIDE_LDFLAGS) -Wl,--no-as-needed -lautoschedule_adams2019 -Wl,--as-needed -o $@ $(LDFLAGS) model_test: $(BIN)/$(HL_TARGET)/$(PY_MODEL_EXT) @@ -104,3 +110,4 @@ build: test: build endif +endif diff --git a/apps/onnx/model.cpp b/apps/onnx/model.cpp index eb7327974612..2d8676ed32bc 100644 --- a/apps/onnx/model.cpp +++ b/apps/onnx/model.cpp @@ -65,7 +65,8 @@ HalideModel convert_onnx_model( std::string auto_schedule(const HalideModel &pipeline) { // Generate a schedule for the pipeline. Halide::Target tgt = Halide::get_host_target(); - auto schedule = pipeline.rep->auto_schedule(tgt); + Halide::AutoschedulerParams autoscheduler_params = Halide::AutoschedulerParams("Adams2019"); + auto schedule = pipeline.rep->apply_autoscheduler(tgt, autoscheduler_params); return schedule.schedule_source; } @@ -528,7 +529,8 @@ void print_loop_nest(const HalideModel &pipeline) { } void print_lowered_statement(const HalideModel &pipeline) { - std::string tmp_file = std::tmpnam(nullptr); + Halide::Internal::TemporaryFile f("model", ".stmt"); + std::string tmp_file = f.pathname(); pipeline.rep->compile_to_lowered_stmt( tmp_file, pipeline.rep->infer_arguments()); std::ifstream is(tmp_file); @@ -536,7 +538,6 @@ void print_lowered_statement(const HalideModel &pipeline) { while (std::getline(is, line)) { std::cout << line << "\n"; } - std::remove(tmp_file.c_str()); } } // namespace diff --git a/apps/onnx/model_test.py b/apps/onnx/model_test.py index a8e2d67cee14..bfd47d30cc98 100644 --- a/apps/onnx/model_test.py +++ b/apps/onnx/model_test.py @@ -35,10 +35,10 @@ def test_small_model(self): model.BuildFromOnnxModel(onnx_model) schedule = model.OptimizeSchedule() schedule = schedule.replace('\n', ' ') - expected_schedule = r'// --- BEGIN machine-generated schedule // Target: .+// MachineParams: .+// Delete this line if not using Generator Pipeline pipeline = get_pipeline\(\);.+Func OUT = pipeline.get_func\(1\);.+{.+}.+' + expected_schedule = r'.*Func OUT = pipeline.get_func\(1\);.+' self.assertRegex(schedule, expected_schedule) - input = np.random.rand(2, 3) - 0.5 + input = (np.random.rand(2, 3) - 0.5).astype('float32') outputs = model.run([input]) self.assertEqual(1, len(outputs)) output = outputs[0] @@ -62,12 +62,12 @@ def test_scalars(self): model = Model() model.BuildFromOnnxModel(onnx_model) schedule = model.OptimizeSchedule() - schedule = schedule.replace('\n', ' ') - expected_schedule = r'// --- BEGIN machine-generated schedule // Target: .+// MachineParams: .+// Delete this line if not using Generator Pipeline pipeline = get_pipeline\(\);.+Func C = pipeline.get_func\(2\);.+{.+}.+' + schedule = schedule.replace('\n', ' ') + expected_schedule = r'.*Func C = pipeline.get_func\(2\);.+' self.assertRegex(schedule, expected_schedule) - input1 = np.random.randint(-10, 10, size=()) - input2 = np.random.randint(-10, 10, size=()) + input1 = np.random.randint(-10, 10, size=()).astype('int32') + input2 = np.random.randint(-10, 10, size=()).astype('int32') outputs = model.run([input1, input2]) self.assertEqual(1, len(outputs)) output = outputs[0]