neural_encryption.py
# -*- coding: utf-8 -*-
import datetime
import time

import tensorflow as tf

from datagen import get_random_block
from net import build_input_layers, build_network
from session_manager import save_session
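
# Note: get_random_block lives in datagen (not shown here). From its use
# below, it is assumed to return a (batch, length) array of uniformly random
# bits encoded as -1/1 floats; a minimal sketch under that assumption:
#
#     import numpy as np
#
#     def get_random_block(length, batch):
#         return 2.0 * np.random.randint(0, 2, size=(batch, length)) - 1.0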


def reconstruction_loss(msg, output):
    """Autoencoder error: mean absolute difference between the original
    message and its reconstruction.

    With bits encoded as -1/1, each wrong bit contributes an absolute
    difference of 2, so dividing by 2 yields the fraction of incorrect
    bits, a value in [0, 1].
    """
    return tf.reduce_mean(tf.abs(tf.subtract(msg, output))) / 2


def bits_loss(msg, output, message_length):
    """Autoencoder error expressed as the number of incorrect bits."""
    return reconstruction_loss(msg, output) * message_length
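

# Worked example: with message_length = 16 and 4 bits decoded incorrectly,
# the mean absolute difference is (4 * 2) / 16 = 0.5 (each wrong -1/1 bit
# differs by 2), so reconstruction_loss is 0.25 and bits_loss is 4.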

message_length = 16           # in bits
key_length = message_length   # in bits
batch = 512                   # number of messages to train on at once
adv_iter = 100                # adversarial (outer) training rounds
max_iter = 20                 # inner steps per round (Eve gets 2 * max_iter)
learning_rate = 0.0008

if __name__ == "__main__":
    msg, key = build_input_layers(message_length, key_length)
    alice_output, bob_output, eve_output = build_network(msg, key)

    eve_loss = reconstruction_loss(msg, eve_output)
    bob_reconst_loss = reconstruction_loss(msg, bob_output)
    # Bob is penalised both for his own reconstruction error and for Eve
    # doing better than random guessing (an eve_loss of 0.5 is chance level).
    bob_loss = bob_reconst_loss + (0.5 - eve_loss) ** 2
    eve_bit_loss = bits_loss(msg, eve_output, message_length)
    bob_bit_loss = bits_loss(msg, bob_output, message_length)
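    # Both terms of bob_loss are non-negative, so its minimum is reached when
    # bob_reconst_loss -> 0 and eve_loss -> 0.5: Bob decodes the message
    # perfectly while Eve is reduced to guessing.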

    # The scope names passed to get_collection act as prefix filters,
    # matching the "alice", "bob" and "eve" variable scopes presumably
    # created in net.py.
    AB_vars = (
        tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "alice") +
        tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "bob")
    )
    E_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "eve")

    # Two optimisers over disjoint variable sets: Alice and Bob minimise
    # bob_loss, while Eve independently minimises her reconstruction error.
    trainAB = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        bob_loss, var_list=AB_vars)
    trainE = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        eve_loss, var_list=E_vars)

    writer = tf.summary.FileWriter("logs/{}".format(datetime.datetime.now()))
    tf.summary.scalar("eve_error", eve_loss)
    tf.summary.scalar("bob_reconst_error", bob_reconst_loss)
    tf.summary.scalar("bob_error", bob_loss)
    tf.summary.scalar("eve_bit_error", eve_bit_loss)
    tf.summary.scalar("bob_bit_error", bob_bit_loss)
    merged_summary = tf.summary.merge_all()

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        writer.add_graph(sess.graph)
        for i in range(adv_iter):
            print("\nIteration:", i)
            start_time = time.time()
            feed_dict = {
                msg: get_random_block(message_length, batch),
                key: get_random_block(key_length, batch),
            }
            print("\tTraining Alice and Bob for {} iterations..."
                  .format(max_iter))
            for j in range(max_iter):
                sess.run(trainAB, feed_dict=feed_dict)
            # Eve trains for twice as many steps so the attacker stays as
            # strong as possible against the current Alice/Bob scheme.
            # (Bug fix: this loop previously ran trainAB, so Eve was never
            # actually trained.)
            print("\tTraining Eve for {} iterations...".format(2 * max_iter))
            for j in range(2 * max_iter):
                sess.run(trainE, feed_dict=feed_dict)
            eve_error, bob_error, summary = sess.run(
                [eve_loss, bob_loss, merged_summary], feed_dict=feed_dict)
            writer.add_summary(summary, global_step=i)
            writer.flush()
            print("\tEve error: {:.4f} | Bob error: {:.4f} | Time: {:.2f}s"
                  .format(eve_error, bob_error, time.time() - start_time))
        save_session(sess, "alice_bob")
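
# The summary writer above logs training curves under logs/<timestamp>;
# they can be inspected with TensorBoard, e.g.:
#
#     tensorboard --logdir logs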