-
Notifications
You must be signed in to change notification settings - Fork 3
/
mlscript.py.save
55 lines (42 loc) · 1.65 KB
/
mlscript.py.save
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import json
from pathlib import Path

import torch
import torchvision
from PIL import Image
from torchvision import transforms, datasets

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Output labels. predict() indexes this list with the classifier's argmax, so
# the order must match the logit order the model was trained with
# (presumably alphabetical folder order at training time - TODO confirm).
classes = [
    'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
    'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
    'del', 'nothing', 'space',
]
# Derived rather than hard-coded so it can never drift out of sync with
# `classes` (it is used to size the model's final linear layer).
num_classes = len(classes)
class ASLTestDataset(torch.utils.data.Dataset):
    """Flat directory of test images whose label is encoded in the file name.

    Each item is ``(image, label)``, where *label* is the part of the file
    name before the first underscore (e.g. ``'A'`` for ``A_test.jpg``).

    Args:
        root_path: Directory scanned (non-recursively) for image files.
        transforms: Optional callable applied to each RGB ``PIL.Image``
            (for example a ``torchvision.transforms.Compose``). When it is
            ``None`` the raw ``PIL.Image`` is returned.
        pattern: Glob pattern selecting files inside *root_path*. Defaults
            to ``'*.jpg'``, the only pattern earlier versions accepted.
    """

    def __init__(self, root_path, transforms=None, pattern='*.jpg'):
        super().__init__()
        self.transforms = transforms
        # Sorted so that item order is deterministic regardless of the
        # filesystem's directory-listing order.
        self.imgs = sorted(Path(root_path).glob(pattern))

    def __len__(self):
        """Return the number of matched image files."""
        return len(self.imgs)

    def __getitem__(self, idx):
        """Load image *idx* as RGB, apply ``transforms`` and return ``(img, label)``."""
        img_path = self.imgs[idx]
        # The context manager releases the file handle deterministically;
        # convert() returns a new, fully loaded image that stays valid after
        # the source image is closed.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        label = img_path.name.split('_')[0]
        if self.transforms is not None:
            img = self.transforms(img)
        return img, label
# Evaluation-time preprocessing applied to every image in predict().
test_transforms = transforms.Compose(
    [
        # Integer size: the shorter edge is scaled to 224 px and the aspect
        # ratio is kept, so non-square inputs stay non-square.
        transforms.Resize(224),
        # PIL image -> float tensor laid out as (C, H, W).
        transforms.ToTensor(),
    ]
)
def predict(demo_data_path, model_path='model!'):
    """Classify every image in *demo_data_path* with the saved ResNet-50 model.

    Builds a ResNet-50 whose final fully connected layer is replaced by a
    ``num_classes``-way linear layer and loads its weights from
    *model_path*. It then runs each item of an :class:`ASLTestDataset`
    (preprocessed with ``test_transforms``) through the model on ``device``.

    Args:
        demo_data_path: Directory containing the images to classify.
        model_path: Path of the saved ``state_dict``. Defaults to the
            previously hard-coded ``'model!'``.

    Returns:
        List of predicted class names (entries of ``classes``), one per
        image, in the dataset's sorted file order. Earlier versions returned
        ``None``; callers that ignore the result are unaffected.

    Raises:
        FileNotFoundError: if *model_path* or ``'data.json'`` does not exist.
    """
    demo_dataset = ASLTestDataset(demo_data_path, transforms=test_transforms)

    model = torchvision.models.resnet50(pretrained=False)
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    # Load onto the CPU first so a checkpoint saved on a GPU machine also
    # loads where CUDA is unavailable...
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    # ...then move the model to the same device as its inputs. Previously only
    # the inputs were moved, which fails with a device mismatch on CUDA.
    model.to(device)
    # Switch BatchNorm/Dropout to inference behaviour; only needs doing once.
    model.eval()

    letters = []
    with torch.no_grad():  # pure inference: no autograd graph needed
        for img, _label in demo_dataset:
            logits = model(img.to(device).unsqueeze(0))  # batch of one
            letters.append(classes[logits.argmax(dim=1).item()])

    # NOTE(review): `data` is loaded but never used; the script appears to be
    # truncated after this point. Confirm what data.json is meant for.
    path = 'data.json'
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    return letters