-
Notifications
You must be signed in to change notification settings - Fork 202
/
dataset.py
44 lines (42 loc) · 1.66 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import glob
import os
def loader(path, batch_size=32, num_workers=4, pin_memory=True,
           resize_size=256, crop_size=224):
    """Build a shuffled, augmenting DataLoader over an ImageFolder dataset.

    Each image is resized, randomly cropped and rescaled, randomly flipped
    horizontally, converted to a tensor and normalized.

    Args:
        path: Root directory in ``torchvision.datasets.ImageFolder`` layout
            (one sub-directory per class).
        batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes.
        pin_memory: Whether the DataLoader copies batches into pinned memory.
        resize_size: Target size passed to ``transforms.Resize``. An int
            resizes the shorter image edge to this length.
        crop_size: Output size of the random resized crop.

    Returns:
        A ``torch.utils.data.DataLoader`` that shuffles every epoch.
    """
    # Per-channel mean/std; these are the ImageNet statistics used by
    # torchvision's pretrained models.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        # `Resize` / `RandomResizedCrop` replace the removed `Scale` /
        # `RandomSizedCrop` aliases; behavior is the same.
        transforms.Resize(resize_size),
        transforms.RandomResizedCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    return data.DataLoader(
        datasets.ImageFolder(path, train_transform),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
def test_loader(path, batch_size=32, num_workers=4, pin_memory=True,
                resize_size=256, crop_size=224):
    """Build a deterministic (non-shuffled) DataLoader for evaluation.

    Each image is resized, center-cropped, converted to a tensor and
    normalized. No random augmentation is applied.

    Args:
        path: Root directory in ``torchvision.datasets.ImageFolder`` layout
            (one sub-directory per class).
        batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes.
        pin_memory: Whether the DataLoader copies batches into pinned memory.
        resize_size: Target size passed to ``transforms.Resize``. An int
            resizes the shorter image edge to this length.
        crop_size: Output size of the center crop.

    Returns:
        A ``torch.utils.data.DataLoader`` that iterates in dataset order.
    """
    # Per-channel mean/std; these are the ImageNet statistics used by
    # torchvision's pretrained models.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    eval_transform = transforms.Compose([
        # `Resize` replaces the removed `Scale` alias; behavior is the same.
        transforms.Resize(resize_size),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        normalize,
    ])
    return data.DataLoader(
        datasets.ImageFolder(path, eval_transform),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )