layers.py
#from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


def _variable_on_device(name, shape, initializer, trainable, device):
    """Create (or reuse) a variable, optionally pinning it to a specific device."""
    if device is None:
        var = tf.get_variable(name, shape, initializer=initializer, trainable=trainable)
    else:
        with tf.device(device):
            var = tf.get_variable(name, shape, initializer=initializer, trainable=trainable)
    return var
def base_layer(input_tensor, shape, F, bias, layer_name, device):
    """Apply the op F (e.g. conv2d or matmul) with an He-initialized weight and an optional bias."""
    with tf.variable_scope(layer_name):
        weight = _variable_on_device('weight',
                                     shape=shape,
                                     initializer=tf.contrib.layers.variance_scaling_initializer(),
                                     trainable=True,
                                     device=device)
        if bias:
            b = _variable_on_device('bias',
                                    shape=shape[-1],
                                    initializer=tf.zeros_initializer(),
                                    trainable=True,
                                    device=device)
            preactivation = F(input_tensor, weight) + b
        else:
            preactivation = F(input_tensor, weight)
        return preactivation
def convolution_layer(input_tensor, shape, strides, bias, layer_name, device):
    """2-D convolution with SAME padding; `shape` is [height, width, in_channels, out_channels]."""
    def conv2d(x, W):
        return tf.nn.conv2d(x, W, strides=strides, padding='SAME', use_cudnn_on_gpu=True)
    return base_layer(input_tensor, shape, conv2d, bias, layer_name, device)


def full_connection_layer(input_tensor, shape, bias, layer_name, device):
    """Fully connected layer; `shape` is [input_dim, output_dim]."""
    return base_layer(input_tensor, shape, tf.matmul, bias, layer_name, device)
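

# ---------------------------------------------------------------------------
# Usage sketch: a minimal example of how these helpers might be wired together
# in TF1 graph mode. The placeholder shape, layer names, strides, and output
# size are illustrative assumptions, not values taken from the ResNet model
# that imports this file.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    # Assumed CIFAR-style input batch: NHWC with 3 colour channels.
    images = tf.placeholder(tf.float32, [None, 32, 32, 3], name='images')

    # 3x3 convolution mapping 3 input channels to 16 output channels, stride 1.
    conv = convolution_layer(images,
                             shape=[3, 3, 3, 16],
                             strides=[1, 1, 1, 1],
                             bias=True,
                             layer_name='conv_example',
                             device=None)
    relu = tf.nn.relu(conv)

    # Flatten the SAME-padded feature map (32 * 32 * 16) and project to 10 logits.
    flat = tf.reshape(relu, [-1, 32 * 32 * 16])
    logits = full_connection_layer(flat,
                                   shape=[32 * 32 * 16, 10],
                                   bias=True,
                                   layer_name='fc_example',
                                   device=None)
    print(conv, logits)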