Code for custom datasets and visualization notebooks added #34

Open · wants to merge 3 commits into master
286 changes: 286 additions & 0 deletions .ipynb_checkpoints/DEC-viz-checkpoint.ipynb

Large diffs are not rendered by default.

286 changes: 286 additions & 0 deletions DEC-viz.ipynb

Large diffs are not rendered by default.
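
The notebook diffs are not rendered, so their contents cannot be quoted here. Purely as an illustration (not the notebooks' actual code), a typical DEC visualization projects the encoder's latent features to 2-D with t-SNE and colors points by predicted cluster; the sketch below assumes a trained dec object from DEC.py, and the function name and output path are hypothetical:

import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def plot_embedding(dec, x, y_pred, out_path='dec_tsne.png'):
    # project latent features from the trained encoder down to 2-D
    features = dec.encoder.predict(x)
    z = TSNE(n_components=2, init='pca').fit_transform(features)
    plt.figure(figsize=(6, 6))
    plt.scatter(z[:, 0], z[:, 1], c=y_pred, s=3, cmap='tab10')
    plt.title('DEC latent space (t-SNE)')
    plt.savefig(out_path, dpi=150)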

49 changes: 33 additions & 16 deletions DEC.py
@@ -12,15 +12,15 @@

 from time import time
 import numpy as np
-import keras.backend as K
-from keras.engine.topology import Layer, InputSpec
-from keras.layers import Dense, Input
-from keras.models import Model
-from keras.optimizers import SGD
-from keras import callbacks
-from keras.initializers import VarianceScaling
+import tensorflow.keras.backend as K
+from tensorflow.keras.layers import Layer, InputSpec, Dense, Input
+from tensorflow.keras.models import Model
+from tensorflow.keras.optimizers import SGD
+from tensorflow.keras import callbacks
+from tensorflow.keras.initializers import VarianceScaling
 from sklearn.cluster import KMeans
 import metrics
+import pdb


 def autoencoder(dims, act='relu', init='glorot_uniform'):
@@ -93,6 +93,7 @@ def build(self, input_shape):
             self.set_weights(self.initial_weights)
             del self.initial_weights
         self.built = True
+        #pdb.set_trace()

     def call(self, inputs, **kwargs):
         """ student t-distribution, as same as used in t-SNE algorithm.
@@ -120,7 +121,7 @@ def get_config(self):
 class DEC(object):
     def __init__(self,
                  dims,
-                 n_clusters=10,
+                 n_clusters=5,
                  alpha=1.0,
                  init='glorot_uniform'):

@@ -158,16 +159,18 @@ def on_epoch_end(self, epoch, logs=None):
                         self.model.get_layer(
                             'encoder_%d' % (int(len(self.model.layers) / 2) - 1)).output)
                     features = feature_model.predict(self.x)
-                    km = KMeans(n_clusters=len(np.unique(self.y)), n_init=20, n_jobs=4)
+                    km = KMeans(n_clusters=len(np.unique(self.y)), n_init=20)
                     y_pred = km.fit_predict(features)
                     # print()
                     print(' '*8 + '|==> acc: %.4f, nmi: %.4f <==|'
                           % (metrics.acc(self.y, y_pred), metrics.nmi(self.y, y_pred)))
+                    #pdb.set_trace()

             cb.append(PrintACC(x, y))

         # begin pretraining
         t0 = time()
+        #pdb.set_trace()
         self.autoencoder.fit(x, x, batch_size=batch_size, epochs=epochs, callbacks=cb)
         print('Pretraining time: %ds' % round(time() - t0))
         self.autoencoder.save_weights(save_dir + '/ae_weights.h5')
@@ -192,7 +195,7 @@ def target_distribution(q):
     def compile(self, optimizer='sgd', loss='kld'):
         self.model.compile(optimizer=optimizer, loss=loss)

-    def fit(self, x, y=None, maxiter=2e4, batch_size=256, tol=1e-3,
+    def fit(self, x, y=None, maxiter=2e5, batch_size=256, tol=1e-3,
             update_interval=140, save_dir='./results/temp'):

         print('Update interval', update_interval)
@@ -203,6 +206,7 @@ def fit(self, x, y=None, maxiter=2e4, batch_size=256, tol=1e-3,
         t1 = time()
         print('Initializing cluster centers with k-means.')
         kmeans = KMeans(n_clusters=self.n_clusters, n_init=20)
+        #pdb.set_trace()
         y_pred = kmeans.fit_predict(self.encoder.predict(x))
         y_pred_last = np.copy(y_pred)
         self.model.get_layer(name='clustering').set_weights([kmeans.cluster_centers_])
@@ -217,6 +221,7 @@ def fit(self, x, y=None, maxiter=2e4, batch_size=256, tol=1e-3,
         loss = 0
         index = 0
         index_array = np.arange(x.shape[0])
+        #pdb.set_trace()
         for ite in range(int(maxiter)):
             if ite % update_interval == 0:
                 q = self.model.predict(x, verbose=0)
@@ -233,6 +238,8 @@ def fit(self, x, y=None, maxiter=2e4, batch_size=256, tol=1e-3,
                     logwriter.writerow(logdict)
                     print('Iter %d: acc = %.5f, nmi = %.5f, ari = %.5f' % (ite, acc, nmi, ari), ' ; loss=', loss)

+                #print(q, p)
+                #pdb.set_trace()
                 # check stop criterion
                 delta_label = np.sum(y_pred != y_pred_last).astype(np.float32) / y_pred.shape[0]
                 y_pred_last = np.copy(y_pred)
@@ -249,7 +256,9 @@ def fit(self, x, y=None, maxiter=2e4, batch_size=256, tol=1e-3,
             loss = self.model.train_on_batch(x=x[idx], y=p[idx])
             index = index + 1 if (index + 1) * batch_size <= x.shape[0] else 0

+
             # save intermediate model
+            #pdb.set_trace()
             if ite % save_interval == 0:
                 print('saving model to:', save_dir + '/DEC_model_' + str(ite) + '.h5')
                 self.model.save_weights(save_dir + '/DEC_model_' + str(ite) + '.h5')
@@ -258,6 +267,7 @@ def fit(self, x, y=None, maxiter=2e4, batch_size=256, tol=1e-3,

         # save the trained model
         logfile.close()
+        #pdb.set_trace()
         print('saving model to:', save_dir + '/DEC_model_final.h5')
         self.model.save_weights(save_dir + '/DEC_model_final.h5')

@@ -271,9 +281,9 @@ def fit(self, x, y=None, maxiter=2e4, batch_size=256, tol=1e-3,
     parser = argparse.ArgumentParser(description='train',
                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add_argument('--dataset', default='mnist',
-                        choices=['mnist', 'fmnist', 'usps', 'reuters10k', 'stl'])
+                        choices=['mnist', 'fmnist', 'usps', 'reuters10k', 'stl', 'custom'])
     parser.add_argument('--batch_size', default=256, type=int)
-    parser.add_argument('--maxiter', default=2e4, type=int)
+    parser.add_argument('--maxiter', default=2e5, type=int)
     parser.add_argument('--pretrain_epochs', default=None, type=int)
     parser.add_argument('--update_interval', default=None, type=int)
     parser.add_argument('--tol', default=0.001, type=float)
@@ -288,7 +298,7 @@ def fit(self, x, y=None, maxiter=2e4, batch_size=256, tol=1e-3,
     # load dataset
     from datasets import load_data
     x, y = load_data(args.dataset)
-    n_clusters = len(np.unique(y))
+    n_clusters = 5

     init = 'glorot_uniform'
     pretrain_optimizer = 'adam'
@@ -298,13 +308,19 @@ def fit(self, x, y=None, maxiter=2e4, batch_size=256, tol=1e-3,
         pretrain_epochs = 300
         init = VarianceScaling(scale=1. / 3., mode='fan_in',
                                distribution='uniform')  # [-limit, limit], limit=sqrt(1./fan_in)
-        pretrain_optimizer = SGD(lr=1, momentum=0.9)
+        pretrain_optimizer = SGD(learning_rate=1, momentum=0.9)
+    elif args.dataset == 'custom':
+        update_interval = 140
+        pretrain_epochs = 500
+        init = VarianceScaling(scale=1. / 3., mode='fan_in',
+                               distribution='uniform')  # [-limit, limit], limit=sqrt(1./fan_in)
+        #pretrain_optimizer = SGD(learning_rate=1, momentum=0.9)
     elif args.dataset == 'reuters10k':
         update_interval = 30
         pretrain_epochs = 50
         init = VarianceScaling(scale=1. / 3., mode='fan_in',
                                distribution='uniform')  # [-limit, limit], limit=sqrt(1./fan_in)
-        pretrain_optimizer = SGD(lr=1, momentum=0.9)
+        pretrain_optimizer = SGD(learning_rate=1, momentum=0.9)
     elif args.dataset == 'usps':
         update_interval = 30
         pretrain_epochs = 50
@@ -318,7 +334,7 @@ def fit(self, x, y=None, maxiter=2e4, batch_size=256, tol=1e-3,
         pretrain_epochs = args.pretrain_epochs

     # prepare the DEC model
-    dec = DEC(dims=[x.shape[-1], 500, 500, 2000, 10], n_clusters=n_clusters, init=init)
+    dec = DEC(dims=[x.shape[-1], 500, 500, 2000, 5], n_clusters=n_clusters, init=init)

     if args.ae_weights is None:
         dec.pretrain(x=x, y=y, optimizer=pretrain_optimizer,
@@ -332,5 +348,6 @@ def fit(self, x, y=None, maxiter=2e4, batch_size=256, tol=1e-3,
     dec.compile(optimizer=SGD(0.01, 0.9), loss='kld')
     y_pred = dec.fit(x, y=y, tol=args.tol, maxiter=args.maxiter, batch_size=args.batch_size,
                      update_interval=update_interval, save_dir=args.save_dir)
+    #pdb.set_trace()
     print('acc:', metrics.acc(y, y_pred))
     print('clustering time: ', (time() - t0))
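
Note: the loader behind the new 'custom' choice is not part of this diff; load_data is imported from datasets.py, whose change is not rendered above. A minimal sketch of what such a branch might look like, assuming a CSV of flattened feature vectors with a trailing integer label column (the path data/custom.csv and the helper name load_custom are hypothetical, not from this PR):

import numpy as np

def load_custom(data_path='data/custom.csv'):
    # hypothetical loader: each row is one flattened sample, last column an integer label
    data = np.loadtxt(data_path, delimiter=',')
    x, y = data[:, :-1], data[:, -1].astype(int)
    # scale features to a comparable range, as DEC does for its image datasets
    x = x / np.max(np.abs(x))
    return x.astype('float32'), y

With such a branch wired into load_data, the new option would be run as:

python DEC.py --dataset custom --save_dir results/custom

Also worth noting: this diff hardcodes n_clusters = 5 and a 5-dimensional embedding for every dataset, so the label count of a custom dataset no longer determines the number of clusters.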