-
Notifications
You must be signed in to change notification settings - Fork 1
/
oselm.py
121 lines (100 loc) · 4.14 KB
/
oselm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# -*- coding: utf-8 -*-
"""OSELM.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1eONPMFeE9M7gCrarqnZpVb4nJ7ZslQXZ
"""
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import random
import os
class AverageMeter(object):
    """Tracks the most recent value plus a running sum, count, and mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out every statistic."""
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        """Record `val` observed `n` times and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
def calculate_accuracy(outputs, targets):
    """Return the fraction of rows in `outputs` whose argmax equals `targets`.

    Args:
        outputs: (B, C) score tensor.
        targets: tensor of B integer class labels.
    """
    with torch.no_grad():
        predictions = outputs.argmax(dim=1)
        hits = (predictions == targets.view(-1)).float().sum().item()
        return hits / targets.size(0)
class OSELM(object):
    """Online Sequential Extreme Learning Machine (OS-ELM).

    A single-hidden-layer network whose input weights `W` and biases are
    random and fixed; only the output weights `beta` are learned, in closed
    form.  `initialize` performs a batch least-squares fit on an initial
    chunk of data; `seq_train` then applies the recursive least-squares
    update to fold in further mini-batches without revisiting old samples.
    """

    # Hidden-layer activations selectable by name in __init__.
    _ACTIVATIONS = {
        'sigmoid': torch.sigmoid,
        'tanh': torch.tanh,
        'relu': torch.relu,
    }

    def __init__(self, n_input_nodes=784, n_hidden_nodes=1024,
                 n_output_nodes=10, activation='sigmoid'):
        """Build the random hidden layer.

        Args:
            n_input_nodes: dimensionality of (flattened) input samples.
            n_hidden_nodes: number of random hidden units.
            n_output_nodes: number of classes (one-hot target width).
            activation: one of 'sigmoid', 'tanh', 'relu'.

        Raises:
            ValueError: if `activation` names an unknown function.
        """
        self.n_output_nodes = n_output_nodes
        self.n_hidden_nodes = n_hidden_nodes
        # Random, never-trained input weights and biases in [-1, 1].
        self.W = torch.FloatTensor(n_input_nodes, n_hidden_nodes).uniform_(-1, 1)
        self.bias = torch.FloatTensor(n_hidden_nodes).uniform_(-1, 1)
        # Fix: the `activation` argument was previously ignored and sigmoid
        # was always used; honor it (the default is unchanged).
        try:
            self.activation = self._ACTIVATIONS[activation]
        except KeyError:
            raise ValueError(
                "unknown activation %r; expected one of %s"
                % (activation, sorted(self._ACTIVATIONS)))
        self.criterion = nn.CrossEntropyLoss()
        self.losses = AverageMeter()
        self.accuracies = AverageMeter()
        self.is_finished_initialize = False

    def initialize(self, inputs, targets):
        """Run the initial batch training phase.

        Solves beta = (H^T H)^{-1} H^T T and caches P = (H^T H)^{-1} for the
        later sequential updates.

        Args:
            inputs: tensor of shape (N, ...); flattened to (N, n_input_nodes).
            targets: tensor of N integer class labels.

        Raises:
            Exception: if called a second time.
            ValueError: if N < n_hidden_nodes (H^T H would be singular).
        """
        if self.is_finished_initialize:
            raise Exception(
                'the initial training phase has already finished. '
                'please call \'seq_train\' method for further training.'
            )
        inputs = inputs.view(inputs.size(0), -1)
        targets = F.one_hot(targets.long(), num_classes=self.n_output_nodes).float()
        numsamples = inputs.size(0)
        if numsamples < self.n_hidden_nodes:
            # Fix: message said "greater than" while the guard allows equality.
            raise ValueError(
                'in the initial training phase, the number of training samples '
                'must be at least the number of hidden nodes. '
                'But this time len(x) = %d, while n_hidden_nodes = %d' % (numsamples, self.n_hidden_nodes)
            )
        H = self.activation(torch.mm(inputs, self.W) + self.bias)
        HT = H.t()
        # P is the inverse correlation matrix of hidden activations,
        # reused by the recursive updates in seq_train.
        self.P = torch.inverse(torch.mm(HT, H))
        self.beta = torch.mm(torch.mm(self.P, HT), targets)
        self.is_finished_initialize = True

    def seq_train(self, inputs, targets):
        """Update P and beta with one mini-batch (recursive least squares).

        Args:
            inputs: tensor of shape (B, ...); flattened to (B, n_input_nodes).
            targets: tensor of B integer class labels.

        Raises:
            Exception: if `initialize` has not been called yet.
        """
        inputs = inputs.view(inputs.size(0), -1)
        if not self.is_finished_initialize:
            # Fix: the message pointed callers at a nonexistent 'init_train'
            # method; the real entry point is 'initialize'.
            raise Exception(
                'you have not gone through the initial training phase yet. '
                'please first initialize the model\'s weights by the '
                '\'initialize\' method before calling \'seq_train\'.'
            )
        targets = F.one_hot(targets.long(), num_classes=self.n_output_nodes).float()
        # Fix: the identity must be float to match H; the original used
        # dtype=torch.int and relied on implicit type promotion.
        I = torch.eye(inputs.size(0))
        H = self.activation(torch.mm(inputs, self.W) + self.bias)
        Ht = H.t()
        # P update: P <- P - P H^T (I + H P H^T)^{-1} H P
        HP = torch.mm(H, self.P)
        M = torch.inverse(I + torch.mm(HP, Ht))
        self.P -= torch.mm(torch.mm(torch.mm(self.P, Ht), M), HP)
        # beta update: beta <- beta + P H^T (T - H beta)
        Z = targets - torch.mm(H, self.beta)
        self.beta += torch.mm(torch.mm(self.P, Ht), Z)

    def predict(self, inputs, targets):
        """Score a batch, accumulate the loss/accuracy meters, and print them.

        Args:
            inputs: tensor of shape (B, ...); flattened before the forward pass.
            targets: tensor of B integer class labels.
        """
        inputs = inputs.view(inputs.size(0), -1)
        result = torch.mm(self.activation(torch.mm(inputs, self.W) + self.bias), self.beta)
        loss = self.criterion(result, targets)
        accuracy = calculate_accuracy(result, targets)
        self.accuracies.update(accuracy, targets.size(0))
        # Fix: store the Python float, not the loss tensor, so the meter's
        # running sum/average stay plain numbers.
        self.losses.update(loss.item(), targets.size(0))
        print("Loss:{loss.val:.4f},({loss.avg:.4f})\t"
              "Acc {acc.val:.3f}({acc.avg:.3f})".format(loss=self.losses, acc=self.accuracies))
def oneHotVectorize(targets):
    """Convert a 1-D tensor of class indices into a one-hot int matrix.

    The output width is targets.max() + 1, matching the original behavior.

    Args:
        targets: 1-D tensor of non-negative integer class labels.

    Returns:
        (len(targets), max+1) tensor of dtype torch.int with a single 1 per row.
    """
    n_rows = targets.size(0)
    oneHotTarget = torch.zeros(n_rows, targets.max().item() + 1,
                               dtype=torch.int, requires_grad=False)
    # Vectorized row-wise scatter replaces the original Python-level loop.
    oneHotTarget[torch.arange(n_rows), targets.long()] = 1
    assert(oneHotTarget.requires_grad == False)
    return oneHotTarget