-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_cnn.m
49 lines (43 loc) · 1.58 KB
/
train_cnn.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
function train_cnn()
%TRAIN_CNN Train a small 2-D CNN classifier on pre-extracted feature vectors.
%   Loads the feature matrix from FeatVectSelT.mat and the categorical
%   labels from T_categorical.mat, splits the observations 70/15/15 into
%   training, validation and test sets, trains the network with Adam,
%   prints the test-set accuracy, and saves the trained network and the
%   prepared test data to cnn_net.mat, XTest.mat and TTest.mat.
%
%   Requires the helper functions trainingPartitions and prepareCNNInputs
%   to be on the MATLAB path.

    X = importdata("FeatVectSelT.mat");
    T = importdata("T_categorical.mat");

    % One cell per column of X; each column is one observation and is
    % indexed in parallel with the labels in T.
    X = num2cell(X, 1);
    numObservations = numel(X);
    [idxTrain, idxVal, idxTest] = trainingPartitions(numObservations, [0.7 0.15 0.15]);

    XTrain = X(idxTrain);
    TTrain = T(idxTrain);
    XVal   = X(idxVal);
    TVal   = T(idxVal);
    XTest  = X(idxTest);
    TTest  = T(idxTest);

    % Side length of the square CNN input image passed to prepareCNNInputs.
    % NOTE(review): presumably the number of selected features per
    % observation - confirm it matches size(X{1}, 1).
    numFeatures = 29;
    miniBatchSize = 2048;

    % "auto" uses a supported GPU when present and falls back to the CPU
    % otherwise, instead of erroring on machines without a GPU.
    executionEnvironment = "auto";

    % Distinct class labels as a row. T(:) makes this independent of the
    % orientation of T; unique(T', "rows") only worked for a row vector.
    classes = unique(T(:))';
    % Derive the class count from the data so the fully connected layer
    % always agrees with classificationLayer's Classes.
    numClasses = numel(classes);

    [XTrainCNN, TTrainCNN, ~] = prepareCNNInputs(XTrain, TTrain, numFeatures);
    [XValCNN,   TValCNN,   ~] = prepareCNNInputs(XVal,   TVal,   numFeatures);
    [XTestCNN,  TTestCNN,  ~] = prepareCNNInputs(XTest,  TTest,  numFeatures);

    layers = [ ...
        imageInputLayer([numFeatures, numFeatures, 1])
        convolution2dLayer(5, 20)
        reluLayer
        fullyConnectedLayer(numClasses)
        softmaxLayer
        classificationLayer(Classes=classes)];

    % trainNetwork drops the samples that do not fill the last complete
    % mini-batch of each epoch; shuffling every epoch prevents the same
    % samples from being dropped (and never trained on) every time.
    options = trainingOptions("adam", ...
        InitialLearnRate=0.002, ...
        ExecutionEnvironment=executionEnvironment, ...
        MaxEpochs=100, ...
        ValidationData={XValCNN, TValCNN}, ...
        MiniBatchSize=miniBatchSize, ...
        Shuffle="every-epoch", ...
        GradientThreshold=1, ...
        Verbose=false, ...
        Plots="training-progress");

    cnn_net = trainNetwork(XTrainCNN, TTrainCNN, layers, options);

    YTestCNN = classify(cnn_net, XTestCNN, ...
        MiniBatchSize=miniBatchSize, ExecutionEnvironment=executionEnvironment);

    % Compare column views so a row/column orientation mismatch between the
    % predicted and true labels cannot broadcast into an N-by-N comparison.
    acc = mean(YTestCNN(:) == TTestCNN(:));
    fprintf("Test accuracy: %.2f%%\n", 100 * acc);

    save("cnn_net.mat", "cnn_net");
    save("XTest.mat", "XTestCNN");
    save("TTest.mat", "TTestCNN");
end