-
Notifications
You must be signed in to change notification settings - Fork 2
/
ckpt_transfer.py
132 lines (116 loc) · 5.63 KB
/
ckpt_transfer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#!/usr/bin/env python
# coding=utf-8
import tensorflow as tf
import argparse
import os
# Command-line interface. Arguments are parsed at import time and read by the
# transfer_* functions below through the module-level ``args``.
parser = argparse.ArgumentParser(description='')
# Path prefix of the source checkpoint (e.g. .../xxx.ckpt); the transfer
# functions require a matching .../xxx.ckpt.meta file to exist.
parser.add_argument("--old_ckpt", required=True, help="old ckpt path prefix, e.g. .../xxx.ckpt")
# Output directory for the converted checkpoint; created when missing.
parser.add_argument("--new_path", required=True, help="directory to write the new ckpt into")
# Selects the conversion mode (teacher / student / fastdgcnn) and is also
# prepended to renamed scopes and to the output checkpoint file name.
parser.add_argument("--prefix", default='teacher', help="conversion mode and new name prefix: teacher, student or fastdgcnn")
args = parser.parse_args()
# Top-level variable names in the source checkpoint that the teacher/student
# conversions move under ``args.prefix``.
# NOTE(review): these look like an unnamed tf.Variable and Adam's beta-power
# accumulators; confirm against the training graph.
_UNSCOPED_NAMES = ("Variable", "beta1_power", "beta2_power")


def _prepare_paths():
    """Check that the source checkpoint exists and create the output directory.

    Reads ``args.old_ckpt`` and ``args.new_path``. Exits with status 1 when
    ``args.old_ckpt + '.meta'`` does not exist.
    """
    if not os.path.exists(args.old_ckpt + '.meta'):
        print("ckpt: {} not exist!".format(args.old_ckpt))
        # Non-zero status so calling scripts notice the failure; this also
        # avoids the site-module ``exit()`` helper, absent under ``python -S``.
        raise SystemExit(1)
    if not os.path.exists(args.new_path):
        print("new_path is not exist!")
        os.makedirs(args.new_path)
        print("make new_path")


def _scope_under_prefix(name, extra_replacements=()):
    """Return *name* with ``query_triplets`` scopes moved under ``args.prefix``.

    *extra_replacements* is a sequence of ``(old, new)`` substring replacements
    applied after the ``query_triplets`` rewrite. Names that then appear in
    ``_UNSCOPED_NAMES`` are additionally prefixed with ``args.prefix + '/'``.
    """
    name = name.replace("query_triplets", args.prefix + "/query_triplets")
    for old, new in extra_replacements:
        name = name.replace(old, new)
    if name in _UNSCOPED_NAMES:
        name = args.prefix + '/' + name
    return name


def _transfer_checkpoint(rename, log_fmt='Renaming %s \n ==> %s.'):
    """Rewrite ``args.old_ckpt`` as a new checkpoint with renamed variables.

    Args:
        rename: callable taking a variable name from the source checkpoint and
            returning its new name, or ``None`` to leave that variable out.
        log_fmt: ``%``-format string with two ``%s`` slots (old name, new name)
            printed for every variable that is kept.

    The result is saved as ``<args.new_path>/<args.prefix>_<basename of
    args.old_ckpt>``. The new variables are created in the default graph.
    """
    _prepare_paths()
    with tf.Session() as sess:
        new_var_list = []
        for var_name, _ in tf.contrib.framework.list_variables(args.old_ckpt):
            new_name = rename(var_name)
            if new_name is None:
                # Decided before loading, so dropped tensors are never read.
                continue
            value = tf.contrib.framework.load_variable(args.old_ckpt, var_name)
            print(log_fmt % (var_name, new_name))
            new_var_list.append(tf.Variable(value, name=new_name))
        print('starting to write new checkpoint !')
        saver = tf.train.Saver(var_list=new_var_list)
        # The loaded arrays are only the variables' initial values; the
        # variables must be initialized in this session before saving.
        sess.run(tf.global_variables_initializer())
        # basename rather than split('/') so Windows-style paths also work.
        model_name = args.prefix + '_' + os.path.basename(args.old_ckpt)
        print('model_name: {}'.format(model_name))
        new_ckpt = os.path.join(args.new_path, model_name)
        print('new_ckpt: {}'.format(new_ckpt))
        saver.save(sess, new_ckpt)
        print("done !")


def transfer_teacher():
    """Copy every variable, moving ``query_triplets`` scopes and the names in
    ``_UNSCOPED_NAMES`` under ``args.prefix``."""
    _transfer_checkpoint(_scope_under_prefix)


def transfer_student():
    """Like transfer_teacher, but drop variables whose name contains ``VLAD``
    and rename ``fastdgcnn`` to ``BACKBONE``."""

    def rename(var_name):
        if "VLAD" in var_name:
            return None
        return _scope_under_prefix(var_name, (("fastdgcnn", "BACKBONE"),))

    _transfer_checkpoint(rename)


def transfer_fastdgcnn():
    """Copy variables under their original names, dropping any whose name
    contains ``VLAD``."""

    def rename(var_name):
        return None if "VLAD" in var_name else var_name

    _transfer_checkpoint(rename, log_fmt='Renaming %s \n ==> %s')
if __name__ == '__main__':
    # Each supported --prefix value selects one conversion routine.
    transfers = {
        "teacher": transfer_teacher,
        "student": transfer_student,
        "fastdgcnn": transfer_fastdgcnn,
    }
    if args.prefix not in transfers:
        # Report every valid choice (derived from the table so it cannot drift)
        # and exit non-zero so calling scripts see the failure.
        print("prefix error! prefix only supports {}".format(list(transfers)))
        raise SystemExit(1)
    transfers[args.prefix]()