A Flutter plugin for PyTorch model inference, supported on both Android and iOS.
To use this plugin, add pytorch_mobile as a dependency in your pubspec.yaml file.
Create an assets folder containing your PyTorch model and, if needed, a labels file. Modify pubspec.yaml accordingly, so that every asset path you later pass to the plugin is declared there.
assets:
- assets/models/model.pt
- assets/labels.csv
Run flutter pub get
import 'package:pytorch_mobile/pytorch_mobile.dart';
Load the model. Either a custom model:
Model customModel = await PyTorchMobile
    .loadModel('assets/models/custom_model.pt');
Or an image model:
Model imageModel = await PyTorchMobile
    .loadModel('assets/models/resnet18.pt');
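Get a prediction from the custom model, passing the flat input values, the tensor shape, and the data type:
List prediction = await customModel
    .getPrediction([1, 2, 3, 4], [1, 2, 2], DType.float32);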
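In the call above, the input is a flat list of four values and the shape [1, 2, 2] describes a tensor with 1 * 2 * 2 = 4 elements, so the two agree. A minimal sketch of how that agreement could be checked before calling getPrediction is shown here. The helpers elementCount, flattenRows, and checkInput are hypothetical and not part of pytorch_mobile, and the sketch assumes the flat list is read in row-major order.

```dart
// Hypothetical helpers for illustration; none of these belong to pytorch_mobile.

/// Number of elements a tensor of the given shape holds.
int elementCount(List<int> shape) =>
    shape.fold<int>(1, (product, dim) => product * dim);

/// Flattens a 2-D list row by row (row-major order is an assumption here).
List<double> flattenRows(List<List<double>> rows) =>
    [for (final row in rows) ...row];

/// Throws if the flat input cannot fill a tensor of the given shape.
void checkInput(List<double> input, List<int> shape) {
  final expected = elementCount(shape);
  if (input.length != expected) {
    throw ArgumentError(
        'Shape $shape needs $expected values, got ${input.length}.');
  }
}

void main() {
  // A 2 x 2 matrix, flattened into the four values used in the call above.
  final input = flattenRows([
    [1, 2],
    [3, 4],
  ]);
  const shape = [1, 2, 2];

  checkInput(input, shape); // passes: 1 * 2 * 2 == 4
  print('$input fits shape $shape');
}
```

Running a check like this first turns a shape mismatch into a readable Dart error rather than a failure inside the native inference call.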
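Get a prediction for an image from the image model, passing the image, the input width and height, and the path to the labels file:
String prediction = await imageModel
    .getImagePrediction(image, 224, 224, "assets/labels/labels.csv");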
Get a prediction for an image with a custom mean and std:
final mean = [0.5, 0.5, 0.5];
final std = [0.5, 0.5, 0.5];
String prediction = await imageModel
    .getImagePrediction(image, 224, 224, "assets/labels/labels.csv", mean: mean, std: std);
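To make the effect of the mean and std arguments concrete, here is a small sketch of the per-channel normalization commonly used for image models: scale each channel to the 0 to 1 range, subtract the mean, then divide by the std. It assumes the plugin follows this convention, which the text above does not confirm. normalizePixel is a hypothetical helper, not a plugin function.

```dart
// Hypothetical helper; shows the conventional formula, not the plugin's source.

/// Normalizes one RGB pixel: scales each 0-255 channel to 0-1,
/// then subtracts the channel mean and divides by the channel std.
List<double> normalizePixel(
    List<int> rgb, List<double> mean, List<double> std) {
  return [
    for (var c = 0; c < 3; c++) (rgb[c] / 255.0 - mean[c]) / std[c],
  ];
}

void main() {
  final mean = [0.5, 0.5, 0.5];
  final std = [0.5, 0.5, 0.5];

  // Red channel at its minimum, green and blue at their maximum.
  print(normalizePixel([0, 255, 255], mean, std));
}
```

Under this formula, a mean and std of 0.5 map the darkest channel value to -1 and the brightest to 1, so they should match whatever normalization the model was trained with.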