diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc index 53d7754be9469..a40fd04959f88 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.cc +++ b/src/runtime/contrib/tflite/tflite_runtime.cc @@ -93,8 +93,12 @@ DataType TfLiteDType2TVMDType(TfLiteType dtype) { void TFLiteRuntime::Init(const std::string& tflite_model_bytes, TVMContext ctx) { const char* buffer = tflite_model_bytes.c_str(); size_t buffer_size = tflite_model_bytes.size(); + // The buffer used to construct the model must be kept alive for + // dependent interpreters to be used. + flatBuffersBuffer_ = std::unique_ptr(new char[buffer_size]); + std::memcpy(flatBuffersBuffer_.get(), buffer, buffer_size); std::unique_ptr model = - tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size); + tflite::FlatBufferModel::BuildFromBuffer(flatBuffersBuffer_.get(), buffer_size); tflite::ops::builtin::BuiltinOpResolver resolver; // Build interpreter TfLiteStatus status = tflite::InterpreterBuilder(*model, resolver)(&interpreter_); diff --git a/src/runtime/contrib/tflite/tflite_runtime.h b/src/runtime/contrib/tflite/tflite_runtime.h index f61f6ee37e0b9..e4d231c03fb14 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.h +++ b/src/runtime/contrib/tflite/tflite_runtime.h @@ -93,6 +93,8 @@ class TFLiteRuntime : public ModuleNode { */ NDArray GetOutput(int index) const; + // Buffer backing the interpreter's model + std::unique_ptr flatBuffersBuffer_; // TFLite interpreter std::unique_ptr interpreter_; // TVM context