caffe-使用已有的lenet模型

关于如何使用,在网上找了半天,没有一篇博客说清楚。之后才发现,下载好的caffe的examples目录中就是我想找的。通过examples,caffe的工作流程基本明了。一般由4步:

  1. 数据准备: 将原始数据转化为caffe-format
  2. 定义网络结构和参数,
  3. 定义训练超参数solver,
  4. 训练

给个caffe自带的mnist例子

1.数据准备

回到caffe的根目录后执行下面脚本:

1
2
$./data/mnist/get_mnist.sh
$./examples/mnist/create_mnist.sh

上述代码所做的事情是:先下载解压mnist数据集,后将源数据通过build/examples/mnist/convert_mnist_data.bin写入将原始mnist数据写成LMDB格式,所以会有两个文件生成:mnist_train_lmdbmnist_test_lmdb

2.定义模型结构和参数

我使用caffe为我们定义好的LeNet网络,它的定义在文件examples/mnist/lenet_train_test.prototxt中。定义的格式是Google Protobuff,解析.prorotxt文件的 统一规则由文件/src/caffe/proto/caffe.proto.提供。

打开文件lenet_train_test.prototxt,是lenet的网络结构:

1
2
3
4
5
6
7
8
9
name: "LeNet"
layer {
// ...layer definition...
include: { phase: TRAIN }
}
layer {
}
layer {
}
  • 若干层(Layer)堆叠在一起,构成了一个网络(Net)。lenet网络每层定义的细节间附录。
  • 这个文件被称作network definition protobuf file。其中top表示这层输出,bottom表示这层出入。caffe中模型的逻辑图结构与其.prototxt文件的对应。
  • 网络结构可以通过可视化工具绘制(详见后)。

3.指明训练超参数:solver.prototxt

从这里查看训练超参数examples/mnist/lenet_solver.prototxt。在caffe中,设定训练超参数(非训练参数)被称作solver,这个文件被称作solver protobuf file,也是Google protobuff 文件格式。如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# 网络结构定义
net: "examples/mnist/lenet_train_test.prototxt"
# Test阶段的迭代次数。batch_size也是100,所以由10,000测试数据
test_iter: 100
# 每500次训练后执行1次测试
test_interval: 500
# The base learning rate, momentum and the weight decay of the network.
base_lr: 0.01
momentum: 0.9
weight_decay: 0.0005
# 设定学习率的策略
lr_policy: "inv"
gamma: 0.0001
power: 0.75
# 每100次迭代后,打印
display: 100
# 最大迭代次数
max_iter: 10000
# 每5000次记录一次快照,中间结果
snapshot: 5000
snapshot_prefix: "examples/mnist/lenet"
# 用什么计算
solver_mode: GPU

这个文件在程序中按照caffe.proto的协议解析到一个Caffe::SolverParameter 对象中,进而将解析后的参数使用到网络中。用于解析的文件是 caffe.pb.hcaffe.pb.cc,其位于build/src/caffe/proto/

默认使用GPU。可以使用同目录的其他的参数配置文件(不同优化方法):

lenet_multistep_solver.prototxt
lenet_solver_adam.prototxt
lenet_solver_rmsprop.prototxt

4.训练与测试

有了caffe-format的数据;有了网络结构;有了确定的超参数,就可以执行下面脚本开始训练:

1
./examples/mnist/train_lenet.sh

train_lenet.sh文件包含一行代码:

1
2
3
4
#!/usr/bin/env sh
set -e

./build/tools/caffe train --solver=examples/mnist/lenet_solver.prototxt $@

使用caffe命令(这个caffe是tools中的命令,这是caffe提供的命令行接口cmdcaffe), 指明是trian,指明训练超参数配置,即使用哪个*solver.prototxt文件。(而网络结构超参数的地址,在*solver.prototxt中指明)

训练结束后,在examples/mnist/路径下生成两个文件lenet_iter_10000.caffemodellenet_iter_10000.solverstate

  1. .caffemodel文件是最终模型,二进制文件。
  2. .solverstate是保存的训练状态(snapshot),可以从此继续开始训练。

训练结束

lenet_iter_10000.caffemodel 是binary protobuf 文件,就是训练10000次的最终模型,用于对真实样本的推理。

MNIST数据集是小数据集,所以使用GPU的效果并不好,原因是GPU的计算核心与存储的通讯开销。所以对于复杂模型和数据集GPU的速度提升是显著的。

使用预训练的model对新样本预测

训练结束后生成的.caffemodel是可以直接拿来使用的模型(其实就是所有训练好的参数),用作推理。如何使用?

如果要使用预训练的模型,从这里下载预训练的模型caffe model index,如bvlc_reference_caffenet.caffemodel, 并且将其放到caffe-root/models/bvlc_reference_caffenet/中。

使用预训练模型:

1
2
3
4
5
# 指明 -weights 关键字,提供预训练模型
./build/tools/caffe \
train \
--solver examples/finetuning_on_flickr_style/solver.prototxt \
--weights models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel

上面有问题

总结

生成制定格式数据{数据&模型之数据}:

1
2
$./data/mnist/get_mnist.sh
$./examples/mnist/create_mnist.sh

在GPU上训练{数据&模型之模型}:

1
2
./build/tools/caffe train \
--solver=examples/mnist/lenet_solver.prototxt -gpu 0

关键字为train,需要提供*solver.prototxt文件,和网络结构文件,后者的路径在*solver.prototxt文件中指明。有了结构参数和超参数,就可以进行学习。

提供快照文件.solverstate可以接着训练:

1
2
3
./build/tools/caffe train \
-solver examples/mnist/lenet_solver.prototxt \
-snapshot examples/mnist/lenet_iter_5000.solverstate

附录 lenet模型定义

说明,以下是模型结构参数文件的内容,其包括了TRAIN网络和TEST网络,所需的每个Layer。其实在运行Log中打印的两个网络才更直观。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
name: "LeNet"
layer {
name: "mnist"
type: "Data"
top: "data"
top: "label"
include {
phase: TRAIN # 本层只用于训练过程中
}
transform_param { # 数据变换使用的方所因子
scale: 0.00390625
}
data_param { # 数据层参数,包括二进制数据的路径
source: "examples/mnist/mnist_train_lmdb"
batch_size: 64
backend: LMDB
}
}
layer {
name: "mnist"
type: "Data"
top: "data"
top: "label"
include {
phase: TEST # 同样的数据层,只用于测试阶段
}
transform_param {
scale: 0.00390625
}
data_param {
source: "examples/mnist/mnist_test_lmdb"
batch_size: 100
backend: LMDB
}
}
layer {
name: "conv1"
type: "Convolution"
bottom: "data"
top: "conv1"
param {
lr_mult: 1 # 权值学习速率因子,1表示保持与全局参数一致
}
param {
lr_mult: 2 # bias学习因子,2表示是全局的2倍
}
convolution_param { # 卷基层参数
num_output: 20
kernel_size: 5
stride: 1
weight_filler {
type: "xavier" # 权值使用xavier方法初始化?
}
bias_filler {
type: "constant" # bias初始化为常量?
}
}
}
layer {
name: "pool1"
type: "Pooling"
bottom: "conv1"
top: "pool1"
pooling_param {
pool: MAX
kernel_size: 2
stride: 2
}
}
layer {
name: "conv2"
type: "Convolution"
bottom: "pool1"
top: "conv2"
param {
lr_mult: 1
}
param {
lr_mult: 2
}
convolution_param {
num_output: 50
kernel_size: 5
stride: 1
weight_filler {
type: "xavier"
}
bias_filler {
type: "constant"
}
}
}
layer {
name: "pool2"
type: "Pooling"
bottom: "conv2"
top: "pool2"
pooling_param {
pool: MAX
kernel_size: 2
stride: 2
}
}
layer {
name: "ip1"
type: "InnerProduct"
bottom: "pool2"
top: "ip1"
param {
lr_mult: 1
}
param {
lr_mult: 2
}
inner_product_param { # 全连接层InnerProduct
num_output: 500
weight_filler {
type: "xavier"
}
bias_filler {
type: "constant"
}
}
}
layer { # 非现行变换层
name: "relu1"
type: "ReLU"
bottom: "ip1"
top: "ip1"
}
layer {
name: "ip2"
type: "InnerProduct"
bottom: "ip1"
top: "ip2"
param {
lr_mult: 1
}
param {
lr_mult: 2
}
inner_product_param {
num_output: 10
weight_filler {
type: "xavier"
}
bias_filler {
type: "constant"
}
}
}
layer {
name: "accuracy" # 计算accuracy 只用于Test阶段
type: "Accuracy"
bottom: "ip2"
bottom: "label"
top: "accuracy"
include {
phase: TEST
}
}
layer {
name: "loss" # loss层
type: "SoftmaxWithLoss"
bottom: "ip2"
bottom: "label"
top: "loss"
}