文本分类(三)-构建模型III-LSTM网络参数

假如网络的超参数配置如下:

1
2
3
4
5
6
7
8
9
10
11
12
def set_default_parameters():
    """Return the default hyperparameters used to build the model.

    Returns:
        A ``tf.contrib.training.HParams`` object carrying the fields
        listed below.
    """
    defaults = dict(
        embedding_size=16,
        encoded_length=50,
        num_word_threshold=20,
        # Only index 0 is read by the LSTM graph code shown in this file;
        # 999 is a deliberately odd value so it is easy to spot in shapes.
        num_lstm_nodes=[999, 32],
        num_lstm_layers=2,
        # 555 is likewise an easy-to-spot marker value.
        num_fc_nodes=555,
        batch_size=100,
        learning_rate=0.001,
        clip_lstm_grads=1.0,
    )
    return tf.contrib.training.HParams(**defaults)

在构建模型图的过程中,需要为LSTM单元中的每一个门指定参数的形状(大小):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def _generate_params_for_lstm_cell(x_size, h_size, bias_size):
    """Create the weight and bias variables for one LSTM gate.

    The variables are obtained with ``tf.get_variable`` under the fixed
    names 'x_weights', 'h_weights' and 'bias'.

    Args:
        x_size: shape for the 'x_weights' variable.
        h_size: shape for the 'h_weights' variable.
        bias_size: shape for the 'bias' variable.

    Returns:
        A tuple ``(x_weights, h_weights, bias)`` of the three variables.
    """
    # The bias is the only variable with an explicit initializer (zeros).
    zero_init = tf.constant_initializer(0.0)
    input_weights = tf.get_variable('x_weights', shape=x_size)
    hidden_weights = tf.get_variable('h_weights', shape=h_size)
    bias = tf.get_variable('bias', shape=bias_size, initializer=zero_init)
    return input_weights, hidden_weights, bias
# Parameters and state for a single LSTM layer.
with tf.variable_scope('lstm', initializer=lstm_init):
    lstm_width = hps.num_lstm_nodes[0]

    # All four gates use identically shaped parameters.
    gate_shapes = dict(
        x_size=[hps.embedding_size, lstm_width],
        h_size=[lstm_width, lstm_width],
        bias_size=[1, lstm_width],
    )

    # Variables are created in the same order as before
    # (inputs, outputs, forget, memory); 'memory' is the tanh branch.
    gate_params = {}
    for gate_scope in ('inputs', 'outputs', 'forget', 'memory'):
        with tf.variable_scope(gate_scope):
            gate_params[gate_scope] = _generate_params_for_lstm_cell(
                **gate_shapes
            )

    ix_w, ih_w, ib = gate_params['inputs']
    ox_w, oh_w, ob = gate_params['outputs']
    fx_w, fh_w, fb = gate_params['forget']
    cx_w, ch_w, cb = gate_params['memory']

    # NOTE(review): the source listing lost its indentation; the two state
    # variables are assumed to live inside the 'lstm' scope -- confirm.
    # Both start as zeros and are excluded from training.
    state_shape = [batch_size, lstm_width]
    state_C = tf.Variable(tf.zeros(state_shape), trainable=False)
    h = tf.Variable(tf.zeros(state_shape), trainable=False)

查看所有可训练的参数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
<tf.Variable 'embedding/embedding:0' shape=(50513, 16) dtype=float32_ref>
<tf.Variable 'lstm/inputs/x_weights:0' shape=(16, 999) dtype=float32_ref>
<tf.Variable 'lstm/inputs/h_weights:0' shape=(999, 999) dtype=float32_ref>
<tf.Variable 'lstm/inputs/bias:0' shape=(1, 999) dtype=float32_ref>
<tf.Variable 'lstm/outputs/x_weights:0' shape=(16, 999) dtype=float32_ref>
<tf.Variable 'lstm/outputs/h_weights:0' shape=(999, 999) dtype=float32_ref>
<tf.Variable 'lstm/outputs/bias:0' shape=(1, 999) dtype=float32_ref>
<tf.Variable 'lstm/forget/x_weights:0' shape=(16, 999) dtype=float32_ref>
<tf.Variable 'lstm/forget/h_weights:0' shape=(999, 999) dtype=float32_ref>
<tf.Variable 'lstm/forget/bias:0' shape=(1, 999) dtype=float32_ref>
<tf.Variable 'lstm/memory/x_weights:0' shape=(16, 999) dtype=float32_ref>
<tf.Variable 'lstm/memory/h_weights:0' shape=(999, 999) dtype=float32_ref>
<tf.Variable 'lstm/memory/bias:0' shape=(1, 999) dtype=float32_ref>
<tf.Variable 'fc/fc1/kernel:0' shape=(999, 555) dtype=float32_ref>
<tf.Variable 'fc/fc1/bias:0' shape=(555,) dtype=float32_ref>
<tf.Variable 'fc/fc2/kernel:0' shape=(555, 10) dtype=float32_ref>
<tf.Variable 'fc/fc2/bias:0' shape=(10,) dtype=float32_ref>

可以看出LSTM单元中每一个部件的参数大小。模型中只构建了一层LSTM(代码只用到了num_lstm_nodes[0],尽管超参数中num_lstm_layers=2),上面所有LSTM参数形状中的999,均指该层有999个LSTM单元(即隐藏状态的维度)。完整实现看这里

结合colah的这篇文章进一步理解。