Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

구현: 라즈베리파이 tflite mobilenet inference 구현 #108

Merged
merged 1 commit into from
May 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions deep_learning/pi/install.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#!/bin/bash

# TFLite runtime
echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list
curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -
sudo apt-get update
sudo apt-get install python3-tflite-runtime

# OpenCV
sudo apt-get -y install libjpeg-dev libtiff5-dev libjasper-dev libpng12-dev
sudo apt-get -y install libavcodec-dev libavformat-dev libswscale-dev libv4l-dev
sudo apt-get -y install libxvidcore-dev libx264-dev
sudo apt-get -y install qt4-dev-tools libatlas-base-dev
sudo apt-get -y install libhdf5-dev libqtgui4 libqt4-test
pip3 install opencv-python
52 changes: 52 additions & 0 deletions deep_learning/pi/run_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import argparse
import os
import time
from pathlib import Path

import tflite_runtime.interpreter as tflite
import cv2
import numpy as np


def read_input_images(data_path):
img_files = data_path.rglob('*.jpg')
imgs = [cv2.imread(str(img_file)) for img_file in img_files]
imgs = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in imgs]

return imgs


def main():
parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model_file', help="path to tflite model file")
parser.add_argument('-d',
'--data_path',
help="path to dataset to run inference")
args = parser.parse_args()

interpreter = tflite.Interpreter(model_path=args.model_file)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
input_shape = input_details[0]['shape']

images = read_input_images(Path(args.data_path))

for image in images:
input_data = np.expand_dims(image, axis=0)
input_data = np.float32(input_data)

interpreter.set_tensor(input_details[0]['index'], input_data)
start_time = time.time()
interpreter.invoke()
stop_time = time.time()

print('Inference time: {:.3f}'.format((stop_time - start_time) * 1000))

prediction = interpreter.get_tensor(output_details[0]['index'])
print(f'Prediction: {prediction[0]}')


if __name__ == "__main__":
main()