LeetCode-merge应用-求逆序对

分治&归并

涉及到两两比较的操作,都可以考虑分治&归并。并归使得当前的对象两两比较完备,而分治使得两个分别两两比较后的对象之间两两比较,完备。

问题描述,求逆序对个数:[2,7,8,1,3,5],比如其中的(2,7) (2,8) (2,3) (2,5)以2 开始的逆序对有4个。求一共有多少个逆序对。

思路:

如描述中的例子,当分制后,有[2,7,8][1,3,5]两部分,而2>1,所以以1为逆序对中第二个元素的逆序对有3个,(2,1) (7,1) (8,1)。这是因为并归操作的前提是两部分,分别有序。

核心是,当左部分第一个元素2大于右边第一个元素1时,count+=3,(2,7,8 有3个元素)。这是这个问题的关键!

实现:

看Merge操作:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
void mergeOp_(vector<int>& nums, int l, int mid, int r){
int* aux = new int[r-l+1];
for (int i=l;i<=r;i++){
aux[i-l]=nums[i];
}

int i=l;
int j=mid+1;
int b=j;
for (int k=l;k<=r;k++){
if (i>mid) { // 左边排玩
nums[k]=aux[j-l];
j++;
}
else if (j>r) { // 右边排完
nums[k]=aux[i-l];
i++;
}
else if (aux[i-l]<=aux[j-l]){
nums[k] = aux[i-l];
i++;
}
// 左部分元素大于右部分元素时,记录逆序对个数
else if(aux[i-l]>aux[j-l]){
count_ += (b-i); // 这里
nums[k] = aux[j-l];
j++;
}
}
delete aux;
aux=nullptr;
return;
}

count_ += (b-i); 便是记录逆序对个数。

将上述操作放入下面code的完整的求逆序对解法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class Sol_mergeSort{
int count_=0;

void split_(vector<int>& nums, int l, int r){
if (r<=l) return;
int mid = (r-l)/2+l;
split_(nums, l, mid);
split_(nums, mid+1, r);
mergeOp_(nums, l, mid, r);
return;
}

// mergeOp_()

public:
void mergeSort(vector<int>& nums){
split_(nums, 0, nums.size()-1);
return;
}

int count(){
return count_;
}
};

测试对象:{1,3,4,0,2,5},逆序对有5对儿。

上述code结果:

1
2
5  
0 1 2 3 4 5 // 排序后的结果

复杂度:O(N*longN)

敲黑板:从这个问题体会分治策略

LeetCode-链表相关

给出链表结点的定义:

1
2
3
4
5
6
7
struct ListNode {
int val;
ListNode *next;
ListNode() : val(0), next(nullptr) {}
ListNode(int x) : val(x), next(nullptr) {}
ListNode(int x, ListNode *next) : val(x), next(next) {}
};

从数组构造一个链表:

1
2
3
4
5
6
7
8
9
10

// create a linkedList from a vector
ListNode* createList(vector<int>& nums) {
ListNode* tail = nullptr;
for (int i = nums.size() - 1; i >= 0; i--) {
ListNode* tmpNode = new ListNode(nums[i], tail);
tail = tmpNode;
}
return tail;
}

补充 递归

从链表到树的递归操作,体会递归从上到下,执行从底向上

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// traverse a list
void printListRecurvely(ListNode* head) {
if (head == nullptr) return;
printListRecurvely(head->next);
cout << head->val << "->";

return;
}

void inOrderTraverse(TreeNode* root) {

if (root == nullptr) return;
cout << root->val << ", ";
inOrderTraverse(root->left);
inOrderTraverse(root->right);

return;
}

206 翻转整个链表

思路一与#92的翻转子操作相似,定义三个指针,并初始化,在移动到链表尾部的过程中改变next指针的指向,从而实现翻转链表。

实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// # 206 翻转整个链表
ListNode* reverseList(ListNode* head) {
if (head == nullptr || head->next == nullptr) return head;
ListNode* left = nullptr;
ListNode* mid = head;
ListNode* right = head->next;

while (right != nullptr) {
// operation
mid->next = left;
// move ptrs
left = mid;
mid = right;
right = right->next;
}
mid->next = left;
return mid;
}

思路二:递归实现,从最后一个结点开始向前指向,最终返回最后一个结点。实现如下:

1
2
3
4
5
6
7
8
ListNode* reverseList(ListNode* head){
if(!head || !head->next) return head;
ListNode* p = reverseList(head->next);

head->next->next = head;
head->next = nullptr;
return p;
}

其过程图示见笔记本。

分析递归返回值:返回值p始终指向原始链表的尾结点,而它又正好是翻转后的头结点。

92 翻转链表两个位置间的结点,部分翻转

例:

1
2
Input: 1->2->3->4->5->NULL, m = 2, n = 4
Output: 1->4->3->2->5->NULL

思路:先实现翻转连个结点间的结点,作为子造作,后移动指针,找到对应的m和n,调用子操作。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
//# 92 翻转链表某个部分
class Sol92 {
private:
/// reverse node from head to tail
ListNode* reverseList(ListNode* head, ListNode* tail) {
ListNode* pre = tail;
ListNode* cur = head;
while (cur != tail) {
ListNode* next = cur->next;
cur->next = pre;
pre = cur;
cur = next;
}
return pre;
}

对于对整个链表的翻转,上述函数参数tailnullptr,即表示链表结尾的结点。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
public:
ListNode* reverseBetween(ListNode* head, int m, int n) {
assert(m <= n);
// if (!head) return head;
// if (m==n) return head;

ListNode* dhead = new ListNode(-1, head);

ListNode* nodeM = dhead;
for (int i = 0; i < m; i++)
nodeM = nodeM->next;

ListNode* nodeN = dhead;
for (int i = 0; i < n + 1; i++)
nodeN = nodeN->next;

ListNode* subHead = dhead;
for (int i = 0; i < m - 1; i++)
subHead = subHead->next;

subHead->next = reverseList(nodeM, nodeN);

return dhead->next;
}
};

过程示意图见笔记本。

142 判断链表是否有环,若有则返回环的起点

思路一:在纸上比划比划,找规律。当快慢两指针相遇时,慢指针和entry指针同时向后移动,直到两者相遇,此时两指针所指向的就是环的起点。

思路二:首先判断是否有环,然后快慢指针在若干步后相遇(相遇时慢指针所走的步数正好是环中的节点数)。最后让新指针从头开始,慢指针从相遇位置开始,一起移动知道两者相遇,此时的相遇结点就是环的入口。

两思路其实是一样的。

实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class Sol142{
public:
ListNode* entryPoint(ListNode* head){
if (head == nullptr || head->next == nullptr)
return nullptr;
ListNode* dhead = new ListNode(-1, head);
ListNode* fast = dhead;
ListNode* slow = dhead;
ListNode* entry = dhead;
while(fast->next && fast->next->next ){
fast = fast->next->next;
slow = slow->next;
if (slow == fast){ // 表示有环
// slow和entry同时移动,直到两者相遇,相遇点就是
// 环的起点
while(slow != entry){
slow = slow->next;
entry = entry->next;
}
return entry;
}
}
return nullptr;
}
};

61 切断链表后再接上

描述:

翻转链表,从右边第k个位置切断,交换两部分位置,后将两部分接上,返回新的链表。

1
2
3
4
5
Input: 1->2->3->4->5->NULL, k = 2
Output: 4->5->1->2->3->NULL
Explanation:
rotate 1 steps to the right: 5->1->2->3->4->NULL
rotate 2 steps to the right: 4->5->1->2->3->NULL

思路:

初始化所有指针,后移到其应该在的位置,最后重新链接两部分。就是指针操作。

实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class Sol61 {
public:
ListNode* rotateRight(ListNode* head, int k) {
if (head==nullptr) return head;
//1. get the length of list
ListNode* curr = head;
int length=1;
while(curr->next){
curr = curr->next;
length++;
}
// 注意这里,
int newK = k%length;
if (newK == 0) return head;

//2. init ptrs
ListNode* dhead = new ListNode(-1, head);
ListNode* end = dhead;
for (int i=0;i<newK;i++){
end=end->next;
}
ListNode* pre = dhead;
ListNode* newHead = head;
while(end->next){
end=end->next;
pre=pre->next;
newHead=newHead->next;
}
//3. move ptrs
pre->next = nullptr;
end->next = head;

//4. clean up
delete dhead;
dhead=nullptr;

return newHead;
}
};

328 奇数号结点链表,偶数号结点链表

描述:

Given a singly linked list, group all odd nodes together followed by the even nodes. Please note here we are talking about the node number and not the value in the nodes.

思路:遍历结点,更改结点next指针,奇数结点相连,偶数结点相连,最后将偶数节点链表的头接在奇数结点链表的尾。示意图见笔记。

实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class Sol328 {
public:
ListNode* oddEvenList(ListNode* head) {
if (!head || !head->next) return head;

ListNode* odd = head;
ListNode* oddCur = head;
ListNode* even = head->next;
ListNode* evenCur = head->next;

while( true ){
// 奇数个结点
if (oddCur->next==nullptr && evenCur==nullptr){
oddCur->next = even;
break;
}
// 偶数个结点
if (oddCur->next->next==nullptr && evenCur->next==nullptr){
oddCur->next = even;
break;
}

oddCur->next = oddCur->next->next;
oddCur = oddCur->next;

evenCur->next = evenCur->next->next;
evenCur = evenCur->next;
}
return odd;
}
};

时间复杂度: O(N)

很直接的解法,但花了好长时间: 原因 1.测试vector马虎了;2.多个if判断的顺序问题

重点在这儿,如果上述code中这两个if调换顺序:

1
2
3
4
5
6
7
8
9
10
// 奇数个结点
if (oddCur->next->next==nullptr && evenCur->next==nullptr){
oddCur->next = even;
break;
}
// 偶数个结点
if (oddCur->next==nullptr && evenCur==nullptr){
oddCur->next = even;
break;
}

那么当结点个数为奇数个,且到链表尾时,第一个if语句中oddCur->next->next==nullptr就会出错,因为oddCur->next->next指向未定义的位置。所以总是出现Segmentation fault (core dumped)。!!!

注意:

  1. 多个if判断的顺序不是随意的。
  2. 如果在中间某处找不到错误,那么从头开始再找,别马虎。

430 扁平化一个多层链表

描述:给出一个结点定义:

1
2
3
4
5
6
7
class Node {
public:
int val;
Node* prev;
Node* next;
Node* child;
};

并构造一个多层链表,目标是将这个多层链表变为一层,如下:

1
2
3
4
5
6
7
8
9
input:
1 - 2 - 3 - 4 - 5 - null
|
6 - 7 - 8 - 9 - null
|
10 - 11 - null

output:
1 - 2 - [6 - 7 - 8 - [10 - 11] - 9] - 3 - 4 - 5 - null

思路:

在循环的过程中改变了循环的对象,所以可以将所有深度的结点遍历到。这是由浅入深的扁平化过程。

1
2
3
4
5
6
7
遇到第一个含有child的结点后:
1 - 2 - 6 - 7 - 8 - 9 - 3 - 4 - 5 - null
|
10 - 11 - null

遇到第二个含有child的结点后:
1 - 2 - 6 - 7 - 8 - 10 - 11 - 9 - 3 - 4 - 5 - null

实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
Node* flatten(Node* head) {
for (Node* h = head; h; h = h->next){
if (h->child){
Node* next = h->next;
h->next = h->child;
h->next->prev = h;
h->child = NULL;
Node* p = h->next;
while (p->next) p = p->next;
p->next = next;
if (next) next->prev = p;
}
}
return head;
}

时间复杂度: O(N)

这个思路很neat!

caffe-在新样本上使用训练好的模型

用训练好的模型给测试样本打分

经过训练有了模型:my_lr_iter_10000.caffemodel,将新的样本数据的LMDB文件 mnist_newTEST_lmdb 的路径在TEST网络的DataLayer层中指明,将含有标签的新样本文件dataLayer的source:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
layer{
name: "mnist"
type: "Data"
top: "data"
top: "label"
include {
phase: TEST # 该层应用与TEST
}
transform_param{
scale: 0.0039063
}
data_param{
source: "/media/junhui/DATA/caffe_workspace/my_linearReggresion/mnist_newTEST_lmdb"
batch_size: 100 # 每一批 100个样本
backend: LMDB
}
}

使用一下命令执行:

1
2
3
4
~/caffe-master/build/tools/caffe test \
-model /media/junhui/DATA/caffe_workspace/my_linearReggresion/mylr.prototxt \
-weights /media/junhui/DATA/caffe_workspace/my_linearReggresion/my_lr_iter_10000.caffemodel \
-iterations 2

test 表示使用TEST网络【只进行前行传播】打分,-model 指明模型结构参数,-weights 指明训练好的模型权值,-iterations 指明对多少样本进行inference,2*100 个样本。

其中caffe实际是可执行文件caffe.bin

这里是日志文件:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
I0617 13:28:09.083691  4697 caffe.cpp:275] Use CPU.
I0617 13:28:09.308647 4697 net.cpp:296] The NetState phase (1) differed from the phase (0) specified by a rule in layer mnist
I0617 13:28:09.308768 4697 net.cpp:53] Initializing net from parameters:
name: "lrNet"
state {
phase: TEST
level: 0
stage: ""
}
layer {
name: "mnist"
type: "Data"
top: "data"
top: "label"
include {
phase: TEST
}
transform_param {
scale: 0.0039063
}
data_param {
source: "/media/junhui/DATA/caffe_workspace/my_linearReggresion/mnist_newtest_lmdb"
batch_size: 100
backend: LMDB
}
}
layer {
name: "ip"
type: "InnerProduct"
bottom: "data"
top: "ip"
param {
lr_mult: 1
}
param {
lr_mult: 2
}
inner_product_param {
num_output: 10
weight_filler {
type: "xavier"
}
bias_filler {
type: "constant"
}
}
}
layer {
name: "accuracy"
type: "Accuracy"
bottom: "ip"
bottom: "label"
top: "accuracy"
include {
phase: TEST
}
}
layer {
name: "loss"
type: "SoftmaxWithLoss"
bottom: "ip"
bottom: "label"
top: "loss"
}
I0617 13:28:09.308953 4697 layer_factory.hpp:77] Creating layer mnist
I0617 13:28:09.309533 4697 db_lmdb.cpp:35] Opened lmdb /media/junhui/DATA/caffe_workspace/my_linearReggresion/mnist_newtest_lmdb
# 略
I0617 13:14:44.887392 3080 net.cpp:200] loss needs backward computation.
I0617 13:14:44.887401 3080 net.cpp:202] accuracy does not need backward computation.
I0617 13:14:44.887406 3080 net.cpp:200] ip_ip_0_split needs backward computation.
I0617 13:14:44.887411 3080 net.cpp:200] ip needs backward computation.
I0617 13:14:44.887415 3080 net.cpp:202] label_mnist_1_split does not need backward computation.
I0617 13:14:44.887420 3080 net.cpp:202] mnist does not need backward computation.
I0617 13:14:44.887424 3080 net.cpp:244] This network produces output accuracy
I0617 13:14:44.887430 3080 net.cpp:244] This network produces output loss
I0617 13:14:44.887439 3080 net.cpp:257] Network initialization done.
I0617 13:14:44.888084 3080 caffe.cpp:281] Running for 2 iterations.
# 结果:
I0617 13:14:44.888623 3080 caffe.cpp:304] Batch 0, accuracy = 0.94
I0617 13:14:44.888634 3080 caffe.cpp:304] Batch 0, loss = 0.21204
I0617 13:14:44.889169 3080 caffe.cpp:304] Batch 1, accuracy = 0.96
I0617 13:14:44.889178 3080 caffe.cpp:304] Batch 1, loss = 0.199202
I0617 13:14:44.889200 3080 caffe.cpp:309] Loss: 0.205621
I0617 13:14:44.889210 3080 caffe.cpp:321] accuracy = 0.95
I0617 13:14:44.889219 3080 caffe.cpp:321] loss = 0.205621 (* 1 = 0.205621 loss)

由Log发现,在新样本上使用TEST网络,依然进行了反向传播?因为在TEST网络中要求loss,与accuracy一样是衡量模型在样本上的性能的指标,accuracy越大,loss越小,新能相对越好。而这层需要反向传播计算获得。从尾反向传播到头,所以Log就理所当然了

使用不带标签的新样本

测试猜想

猜测这种情况下的TEST网络结构只有数据层和全连接层(因为上述网络只有这两个层与label数据无关),最终结果是全连接层的输出。只有两层,进行前向计算(对一批测试,每一批只有一个元素):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
...
I0617 13:48:08.206188 5203 caffe.cpp:304] Batch 0, ip = 0.483029
I0617 13:48:08.206202 5203 caffe.cpp:304] Batch 0, ip = -4.20101
I0617 13:48:08.206208 5203 caffe.cpp:304] Batch 0, ip = 0.95857
I0617 13:48:08.206213 5203 caffe.cpp:304] Batch 0, ip = 4.77686
I0617 13:48:08.206219 5203 caffe.cpp:304] Batch 0, ip = -5.86976
I0617 13:48:08.206223 5203 caffe.cpp:304] Batch 0, ip = 5.99296
I0617 13:48:08.206228 5203 caffe.cpp:304] Batch 0, ip = -1.1105
I0617 13:48:08.206233 5203 caffe.cpp:304] Batch 0, ip = 0.394778
I0617 13:48:08.206238 5203 caffe.cpp:304] Batch 0, ip = -0.198426
I0617 13:48:08.206243 5203 caffe.cpp:304] Batch 0, ip = -0.884502
I0617 13:48:08.206248 5203 caffe.cpp:309] Loss: 0
I0617 13:48:08.206256 5203 caffe.cpp:321] ip = 0.483029
I0617 13:48:08.206264 5203 caffe.cpp:321] ip = -4.20101
I0617 13:48:08.206269 5203 caffe.cpp:321] ip = 0.95857
I0617 13:48:08.206272 5203 caffe.cpp:321] ip = 4.77686
I0617 13:48:08.206277 5203 caffe.cpp:321] ip = -5.86976
I0617 13:48:08.206282 5203 caffe.cpp:321] ip = 5.99296
I0617 13:48:08.206287 5203 caffe.cpp:321] ip = -1.1105
I0617 13:48:08.206292 5203 caffe.cpp:321] ip = 0.394778
I0617 13:48:08.206297 5203 caffe.cpp:321] ip = -0.198426
I0617 13:48:08.206301 5203 caffe.cpp:321] ip = -0.884502

发现比没有反向计算,因为不需要求loss,且输出是10个值,与网络结构对应 num_output: 10

转化自己的raw数据为LMDB数据

对于自己的MNIST数据,使用 /home/junhui/caffe-master/build/examples/mnist/convert_mnist_data.bin 命令,用法如下:

1
2
# 命令 图片文件 对应标签文件 保存位置
convert_mnist_data.bin input_image_file input_label_file output_db_file

那么caffe是如何将所下载的mnist数据集转化为LMDB的,见examples/mnist/create_mnist.sh,分析一下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# 只要这个bash脚本中任何一条命令的返回值不是0,
# 就退出整个脚本,因为默认是有错继续执行。
set -e

# 4个变量,
EXAMPLE=examples/mnist
DATA=data/mnist
BUILD=build/examples/mnist

BACKEND="lmdb"

# 打印内容:
echo "Creating ${BACKEND}..."

# 将之前的DB文件删除
rm -rf $EXAMPLE/mnist_train_${BACKEND}
rm -rf $EXAMPLE/mnist_test_${BACKEND}

# 调用 convert_mnist_data.bin 命令,
$BUILD/convert_mnist_data.bin $DATA/train-images-idx3-ubyte \
$DATA/train-labels-idx1-ubyte $EXAMPLE/mnist_train_${BACKEND} --backend=${BACKEND}
$BUILD/convert_mnist_data.bin $DATA/t10k-images-idx3-ubyte \
$DATA/t10k-labels-idx1-ubyte $EXAMPLE/mnist_test_${BACKEND} --backend=${BACKEND}
# 最终在制定的路径中生成 mnist_train_ 和 mnist_test_ 。

echo "Done."

convert_mnist_data.bin 命令对应的源文件在 examples/mnist/convert_mnist_data.cpp,阅读发现其解析过程是根据 train-images-idx3-ubyte 文件的格式的。

下一步:

  1. 如何从LMDB中读取每一条数据?
  2. 将不含标签的数据保存为LMDB?
  3. 写自己的 convert_XXX_data.bin 工具,和bash脚本。

LeetCode-BST相关

BST与一般二叉树的不同之处是,BST的中序遍历得到有序序列,具体说是left->root->right遍历结果是递增序列,right->root->left 的遍历结果是递减序列。根据BST的这个特点,与BST相关的问题,先遍历得到有序序列,后就是对一个有序数组的操作了。

先给出BST结点的定义:

1
2
3
4
5
6
7
8
9
struct TreeNode {
int val;
TreeNode *left;
TreeNode *right;
TreeNode() : val(0), left(nullptr), right(nullptr) {}
TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
TreeNode(int x, TreeNode *left, TreeNode *right)
: val(x), left(left), right(right) {}
};

下面记录LeetCode 中的BST相关问题

530,783,230

思路一样,先遍历,后对一个有序数组操作。不细展开了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
// #530 #783 # 230 BST inOrder to get a 递增序列
class Sol530_783_230{

public:
// time: O(2*N)
int getMinimumDifference(TreeNode* root) {
inOrder(root);
int minAD = INT_MAX;
for (int i=1; i<vec.size(); i++){
minAD = min(minAD, vec[i]-vec[i-1]);
cout<<minAD<<endl;
}
return minAD;
}

private:
vector<int> vec;
void inOrder(TreeNode* root){
if (root!=NULL){
inOrder(root->left); // 只是进去,没有其他操作,直到进到最里层,最里层root==nullptr,所以直接return
vec.push_back(root->val); // 上一句return后才开始执行这一句,及之后的语句。
inOrder(root->right);
}
}
};

653

描述:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
Given a Binary Search Tree and a target number, return true if there exist two elements in the BST such that their sum is equal to the given target.

Example 1:

Input:
5
/ \
3 6
/ \ \
2 4 7

Target = 9

Output: True

思路:仍是对有序数组的操作。

先序遍历后得到有序数组;后使用对撞指针找到一对值和为target。

实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class Sol653 {
public:
// time: O(2×N)
bool findTarget(TreeNode* root, int k) {
inOrder(root);
int i=0;
int j=vec.size()-1;
while(i<j){
if (vec[i]+vec[j]==k) return true;
else if(vec[i]+vec[j] < k)
i++;
else
j--;
}
return false;
}
private:
vector<int> vec;
void inOrder(TreeNode* root){
if (root!=NULL){
inOrder(root->left);
vec.push_back(root->val);
inOrder(root->right);
}
}
};

938

描述:

1
2
3
4
5
6
7
8
Given the root node of a binary search tree, return the sum of values of all nodes with value between L and R (inclusive).

The binary search tree is guaranteed to have unique values.

Example

Input: root = [10,5,15,3,7,null,18], L = 7, R = 15
Output: 32

思路:

方法一:中序遍历后得有序序列;后从头和尾设置两个指针对撞,找到两个位置,后将两位置间的元素累加得结果。

方法二:选择一种树的遍历方法,在遍历的过程中,判断每一个结点值是否在指定区间内,如果是,就累加到一个变量中。

实现方法二:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class Sol938 {

public:
int sum=0;
int rangeSumBST(TreeNode* root, const int L, cosnt int R) {
dfs(root, L,R);
return sum;
}
void dfs(TreeNode* root, int L, int R){
if (root==nullptr) return;
else{
if (root->val>=L && root->val<=R)
sum += root->val;
if (root->val>L)
dfs(root->left, L,R);
if (root->val<R)
dfs(root->right, L,R);
}
}
};

注意-有返回值的递归:由于rangeSumBST()函数有返回值,若将递归实现在其中,会有问题:递归函数与其返回值不一致。什么意思?具体说是,函数是递归的,自己会调用自己,而返回值是一个最终结果,一个是过程一个是结果。可能没法处理!所以考虑将递归函数和返回值分离,其中递归函数没有返回值,如上。

701

描述:

将一个结点掺入到BST中(新节点不与原BST中结点重复),要保持BST的性质不变。BST的基本操作

思路:

分析题意,一定可以在BST的叶节点的左右子节点找到新节点的位置,将其作为这个叶节点的子节点所以递归地找到带插入结点的位置即可。

实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20

// #701 BST insert operation
class Sol701 {
public:
// 在现有的 BST 的叶节点添加新的结点
// time: O(logN)
TreeNode* insertIntoBST(TreeNode* root, const int val) {
// 向叶节点添加子结点
if (root==nullptr){
return new TreeNode(val);
}
if (root->val == val)
cout<<"Not going to happen"<<endl;
else if(root->val > val){
root->left = insertIntoBST(root->left, val);
}else
root->right = insertIntoBST(root->right, val);
return root;
}
};

注意-有返回值的递归:这个问题的返回值是节点指针TreeNode*,而递归函数每次递归的判断操作对象也是一个TreeNode*。所以这里的返回值不会影响到递归函数。所以是递归函数与其返回值一致。与上一个问题对比。

501

描述:

找到BST中出现频数最多的节点值。

思路:

第一步遍历,过程中将每个节点值放入哈希表。第二步是对哈希表的操作,找到找到频数最多的key值,可能有重复。

实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
// #501 BST
class Sol501 {
public:
vector<int> findMode(TreeNode* root) {
inOrder(root)
int secondMax = 0;
vector<int> res;
for (auto item: mp){
if (item.second>secondMax){
res.clear();
secondMax = item.second;
}
if (item.second == secondMax)
res.push_back(item.first);
}
return res;

}
private:
unordered_map<int, int> mp;
void inOrder(TreeNode* root){
if (root!=NULL){
inOrder(root->left);
mp[root->val]++;
inOrder(root->right);
}
}
};

173

描述:

1
2
3
4
Implement an iterator over a binary search tree (BST). Your iterator will be initialized with the root node of a BST.

Calling next() will return the next smallest number in the BST.
hasNext() return whether we have a next smallest number

思路:

right->root->left 遍历,得到递减序列。next()返回序列最有一个值,后删除最后一个值。hasNext()返回这个序列是否为空。

实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
// # 173
class BSTIterator {
private:
vector<int> nodeVal;
void inOrder(TreeNode* root){
if(root!=nullptr){
inOrder(root->right);
nodeVal.push_back(root->val);
inOrder(root->left);
}
}
public:
// time: O(N)
// right->root->left
BSTIterator(TreeNode* root) {
inOrder(root);
}

/** @return the next smallest number */
// time: O(1)
int next() {
int res = nodeVal[nodeVal.size()-1];
if (nodeVal.size()!=0)
nodeVal.pop_back();
return res;
}

/** @return whether we have a next smallest number */
// time: O(1)
bool hasNext() {
if(nodeVal.size()!=0) return true;
return false;
}
};

538

描述:

BST中每个结点的值,类加上所有比这个结点大的值,得到新的二叉树。

1
2
3
4
5
6
7
8
9
Input: The root of a Binary Search Tree like this:
5
/ \
2 13

Output: The root of a Greater Tree like this:
18
/ \
20 13

思路:

先定义一个变量value=0,后进行right->root->left遍历,每遍历到一个结点,就将当前节点值累加到value,之后将更新了的value赋值给当前结点。

实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// #538
class Sol538 {
private:
int value=0;
public:
// time: O(N)
TreeNode* convertBST(TreeNode* root) {
if (root!=nullptr){
root->right = convertBST(root->right);
value+=root->val;
root->val = value;
root->left = convertBST(root->left);
}
return root;
}
};

注意-有返回值的递归

98 合法BST

描述:

判断一个BST是否合法

方法:

left->root->right 顺序递归地判断这个树的心虚遍历是否单调递增。只要中途没有返回false,就一直递归下去,所以需要在code中所有可能返回false的情况都写出。

实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class Solution {
public:
bool isValidBST(TreeNode* root) {
TreeNode* pre = NULL;
return dfs(root, pre);
}
private:
bool dfs(TreeNode* root, TreeNode* &preValue){
// 递归终止
if (!root) return true;
// 情况1. 返回false
if (!dfs(root->left, preValue))
return false;
// 情况2. 返回false
if (preValue && preValue->val>= root->val )
return false;
preValue = root;

return dfs(root->right, preValue);
}
};

如果上述逻辑不能马上写出,可以先实现一个判断单调递增链表的递归实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class Solution_list{
public:
bool func(ListNode* head){
ListNode* pre = nullptr;
return isGreaterThanPre(head, pre);
}
// pre是要被更改的
bool isGreaterThanPre(ListNode* head, ListNode* &pre){
// 终止条件
if (head == nullptr) return true;

// 返回false的情况
if (pre && head->val<=pre->val)
return false;
pre = head;
return isGreaterThanPre(head->next, pre);
}
};

注意参数的 这种表达: TreeNode* &preListNode* &pre 它表示,pre这个指针所存地址中的内容是会在函数中被更改的。

449

描述:

将内存中数据序列化到磁盘,后将磁盘中的数据反序列化到内存。
应用时,反序列化后的树要与序列化前的树是同一颗树,才是这个问题的目的。

1
2
Codec codec;
codec.deserialize(codec.serialize(root));

思路:

困难的是

实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class Codec449 {
public:
// Encodes a tree to a single string.
string serialize(TreeNode* root) {
string res;
preOrder(root, res);
return res;
}

// Decodes your encoded data to tree.
TreeNode* deserialize(string data) {
istringstream str(data);
queue<int> q;
string s;
while(str>>s)
q.push(stoi(s)); // string to int
return des(q, INT_MIN, INT_MAX);
}
private:
void preOrder(TreeNode* root, string& res){
if (root==nullptr) return;
res += to_string(root->val); // anything to string
res += ' ';
preOrder(root->left, res);
preOrder(root->right, res);
}

TreeNode* des(queue<int>& q, int low, int hig){
if (q.empty()) return nullptr;
int val = q.front();
if (val<low || val>hig) return nullptr;
q.pop();
// 构建一个结点和其左右子节点
TreeNode* node = new TreeNode(val);
node->left = des(q, low, val);
node->right = des(q, val, hig);
return node;
}
};

构建一颗树,肯定是从root开始,所以先要有root后才可以构建其左右子节点。

敲黑板 注意-递归函数的 返回值

LeetCode-top-k-Kth-lagest

Top-k 问题有多种解法。LeetCode #215 #347

k次冒泡法

就是冒泡排序了,执行看次冒泡操作:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
vector<int> kBubble(vector<int>& nums, int k){
bool swapped;
int n = nums.size();
int topK = k;
// 每次do{} 都是一次冒泡,即将最大值移到数组最右端
do {
if (topK==0) break;
swapped= false;
for (int i=1;i<n;i++){
if (nums[i-1]>nums[i]){
swap(nums[i-1], nums[i]);
swapped = true;
}
}
n--;
topK--;
}while(swapped);
// 保存top k
vector<int> res(nums.rbegin(), nums.rbegin()+k);
return res;
}

测试:

1
2
3
4
5
6
7
8
9
10
int main(int argc, char** argv){

vector<int> nums = {2,5,8,3,4,5,9,12,45,78,340,5,2};
vector<int> res = kBubble(nums, 3);
for (auto item:res)
cout<<item<<" ";
cout<<endl;

return 0;
}

返回:

1
2
3
4
// 当k = 3
340 78 45
// 当k = 数组大小
340 78 45 12 9 8 5 5 5 4 3 2 2

时间复杂度:O(N*K)

快排中的Partition操作法

先看,快排中的partition操作返回第k大的元素。其思路是,每次partition返回的值是一个索引值index,它对应的元素是在排好的序列中的正确位置。对于下面的例子,元素个数为9,k为2。一次partition操作后判断这个索引值index=4,对应元素5,5被放到它应该在的位置。而且9-4==5 != k,所以5不是第k个元素。

1
2
3
4
k=2
5 4 3 2 1 8 9 6 7
// 一次partition操作后
4 3 2 1 [5] 8 9 6 7

则此时第2大的元素在以5为界的右半边,所以对右半边进行partition操作。假如返回index为7,而9-7==k,所以找到第2大元素8.

1
4 3 2 1 [5] 6 7 [8] 9

上述是这个问题的关键

实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
// partition 操作
int partitionOpe(vector<int>& nums, int l, int r){
swap(nums[l], nums[rand()%(r-l+1)+l]);
int p = nums[l];
int j = l;
for (int i=j+1;i<=r;i++){
if (nums[i]<p){
swap(nums[i], nums[++j]);
}
}
swap(nums[l], nums[j]);
return j;
}

// 查找
int kthItem(vector<int>& nums, int k){
int n = nums.size();
int l=0, r=n-1, index=0;
index = partitionOpe(nums,l,r);
while(true){
// 对右边进行partition
if (n-index>k){
l = index+1;
index = partitionOpe(nums,l,r);
}
// 对左边进行partition
else if(n-index<k){
r = index-1;
index = partitionOpe(nums,l,r);
}
// 直到找到
else if(n-index==k){
return nums[index];
}
}
}

测试对象[5,7,3,8,2,9,1,6,4], 返回第3大元素,结果:

1
2
3
4
// 结果
7
// 打印数组,7前元素比7小,7后元素比7大。
4 5 3 2 1 6 7 9 8

进而Top-k问题,就是将此时的数组返回7之后的所有元素(包括7)即可。

时间复杂度:

O(N(1+1/2+1/4+1/8+…)), 根据数列和的极限,1+1/2+1/4+1/8+… 趋近于2,所以partition操作的top-k问题时间复杂度为O(2N).

注意:

  1. 此法可以处理含有重复元素的数组。
  2. 不要忘了快排递归终止条件。

附件 快排

注意递归终止条件。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
int partitionOpe(vector<int>& nums, int l, int r){
swap(nums[l], nums[rand()%(r-l+1)+l]);
int p = nums[l];
int j = l;
for (int i=j+1;i<=r;i++){
if (nums[i]<p){
swap(nums[i], nums[++j]);
}
}
swap(nums[l], nums[j]);
return j;
}
// 找到p值后分左右继续找。
void findPSplit(vector<int>& nums, int l, int r){
// 递归终止条件!
if (l>=r) return;
int p = partitionOpe(nums, l, r);
findPSplit(nums, l, p-1);
findPSplit(nums, p+1, r);
return;
}
// 入口函数
void quickSort(vector<int>& nums){
findPSplit(nums, 0, nums.size()-1);
return;
}

优先队列法

找到第k的元素:

1
2
3
4
5
6
7
8
9
10
11
int findKthLargest(vector<int>& nums, int k) {
// 1) 构造优先队列
priority_queue<int> pq;
for(auto itr=nums.begin(); itr!=nums.end(); itr++)
pq.push(*itr);

/// 求kth 元素
for(int i=1;i<k;i++)
pq.pop();
return pq.top();
}

如果是求Top-k:

1
2
3
4
5
6
7
8
9
10
11
12
13
vector<int> topKQueue(vector<int>& nums, int k){
// 1) build a map:
priority_queue<int> pq;
for(auto item: nums)
pq.push(item);

vector<int> res;
for(int i=0; i<k; i++){
res.push_back(pq.top());
pq.pop();
}
return res;
}

储存pop的结果得到top-k问题结果。

时间复杂度,与数据结构有关。

LeetCode-153-找到最小值

二分查找。

  • 描述:

    一个递增序列,中间某个位置切一刀,之后前后两部分交换位置。找到此时的最小值。如[4,5,6,7,0,1,2]返回0。

  • 思路:

    两种情况:

    第一,数组中没有重复元素时,此情况容易,当这一刀没有切,数组递增,第一个元素为最小值。如果切下去了,那么最小值应该在数组切面处,如上述例子,切面在7和0,而最小值就在切面处。所以接下来的步骤就是二分查找法找切面。

    第二,数组中含有重复元素,此时情况较多,需要将所有情况都找到,较困难,这里是别人的如何找到所有情况的方法。

  • 实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class Solution {
public:
// 不含重复元素
int findMin(vector<int>& nums) {
if(nums.size()==1) return nums[0];
int head = 0;
int tail = nums.size()-1;
// 如果这一刀没有切
if (nums[head] < nums[tail])
return nums[head];

while(head <= tail){
int mid = head+(tail-head)/2;
// 找到切面
if(nums[mid] > nums[mid+1])
return nums[mid+1];
if(nums[mid-1] > nums[mid])
return nums[mid];

// 没有找到切面
if(nums[mid] > nums[head])
head = mid+1;
else if(nums[mid] < nums[head])
tail = mid-1;
}
return -1;
}
// 含重复元素
int findMin(vector<int>& nums){
int head=0;
int tail=nums.size()-1;
int mid;
while(head<=tail ){
mid = head+(tail-head)/2;
if(head==mid || mid==tail)
return min(min(nums[head], nums[mid]), nums[tail]);
if(nums[head]<=nums[mid] && nums[mid]<=nums[tail])
tail--;
else if(nums[mid]<=nums[tail] && nums[tail]<=nums[head])
tail = mid;
else if(nums[tail]<=nums[head] && nums[head]<=nums[mid])
head = mid;
}
return -1;
}
};

LeetCode-34-find-positions-of-elements

二分查找。

  • 描述

    在一个有序的数组中找到指定元素的起始和结束位置,如下:

    1
    2
    Input: nums = [5,7,7,8,8,10], target = 8
    Output: [3,4]

    如果没有找到target,返回[-1, -1]。

  • 思路

    分两步,首先用二分查找法找到数组中的一个target,然后分别从这个target开始先前和向后移动,找到其开始和结束的位置。实现如下,时间复杂度O(m*logN),m为target元素个数:

  • 实现

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    class Solution {
    public:
    // 二分查找
    vector<int> searchRange(vector<int>& nums, int target) {
    if(nums.size()==0) return vector<int>{-1,-1};

    int head = 0;
    int tail = nums.size()-1;
    while(head<=tail){
    int m = (tail+head)/2;
    if(nums[m] == target) {
    return mySearch(m, nums);
    }
    else if(nums[m] < target)
    head = m + 1;
    else
    tail = m - 1;
    }
    // 当没有找到target
    return vector<int> {-1, -1};
    }
    private:
    // 搜索起始和结束位置
    vector<int> mySearch(int m, vector<int>& nums){
    int mUp = m;
    int mDown = m;
    // 向后搜索
    while(mUp+1<nums.size() && nums[mUp] == nums[mUp+1])
    mUp++;
    // 向前搜索
    while(mDown-1>=0 && nums[mDown] == nums[mDown-1])
    mDown--;
    return vector<int>{mDown, mUp};
    }
    };

LeetCode-求一个数的平方根sqrt(x)

  • 描述

    如题

  • 思路

    使用基础算法,二分查找容易求出sqrt(x)。如果对精度有要求,即要求精确到小数点后n位:sqrt(x, n),只不过此时二分查找的headtail的偏移不再是1,而是与精度有关。

  • 实现

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    // 没有精度的要求
    class Solution {
    private:
    //int compute_count=0;
    public:
    // my sqrt without accuracy
    int mySqrt(int x) {
    int head = 0, tail=x;
    while(head<=tail){
    long long int middle = head + (tail-head)/2;
    if(middle * middle == x)
    return middle;
    else if(middle * middle < x)
    head = middle+1;
    else
    tail = middle-1;
    }
    return tail;
    }

    // my sqrt with accuracy
    double mySqrt(double x, int acc){
    double head = 0.0;
    double offset = pow(10, -acc);
    double tail = x - offset;
    cout<<"Accuracy: "<<offset<<" result: ";

    while(tail>=head){
    //compute_count++;
    double mid = head + (tail-head)/2;
    if(mid * mid == x)
    return mid;
    else if(mid * mid < x)
    head = mid + offset;
    else
    tail = mid - offset;
    }
    return return offset!=1?min(tail,head):tail;
    }
    };

    int main(){
    Solution sol;
    double in;
    cin>>in;
    // 调用无精度要求的函数
    cout<<sol.mySqrt(static_cast<int>(in))<<endl;
    // 不同的精度要求
    cout<<sol.mySqrt(in,0)<<endl;
    cout<<sol.mySqrt(in,1)<<endl;
    cout<<sol.mySqrt(in,2)<<endl;
    cout<<sol.mySqrt(in,3)<<endl;
    cout<<sol.mySqrt(in,4)<<endl;
    cout<<sol.mySqrt(in,5)<<endl;
    cout<<sol.mySqrt(in,6)<<" | true: "<<sqrt(in)<<endl;

    return 0;
    }

    返回如下,可以看出,结果的精度在随要求精度提高:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    87 # 输入87
    9
    Accuracy: 1 result: 9
    Accuracy: 0.1 result: 9.32656
    Accuracy: 0.01 result: 9.32678
    Accuracy: 0.001 result: 9.32655
    Accuracy: 0.0001 result: 9.32743
    Accuracy: 1e-05 result: 9.32738
    Accuracy: 1e-06 result: 9.32738 | true: 9.32738

敲黑板:注意边界条件,注意特殊情况。

LeetCode-求一个数的n次幂

递归

  • 问题描述:

    求x的n次方power(x, n),如power(2,4) = 16

  • 思路:

    很直接的思路,x与自己相乘n次的结果,事假复杂度为O(N)。这个方法有个问题:很多的重复计算。以power(2,4)为例,power(2,4) = (2x2)x(2x2)。(2x2)被计算了两次,当n很大时,就会由更多的重复计算。其实power(2,4) = power(2,2)xpower(2,2)。所以可以很容易写出递归的主体:

    1
    2
    double half = power(x, n/2);
    double res = half*half;

    时间复杂度为O(logN)。而递归终止条件和其他细节可从下图中自然得到:

    【过程图】

  • 实现:

    由过程图可以实现如下

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    double myPow3(double x, int n){ 
    // 递归终止条件:
    if (n==0) return 1.0;
    // 特殊情况:
    if (x==0.0) return 1.0;
    long long int nn=n;
    // 当x为负数
    if(nn<0){
    nn=-nn;
    x=1/x;
    }
    // 递归主体:
    double half = myPow3(x, nn/2);
    double ans = half*half;
    // 这个细节在过程图中很容易看出:
    return n%2==0 ? ans : ans*x;
    }

敲黑板:递归直接写很容易出错,把过程图画出来,一些细节就会很容易发现。

caffe-阅读Log运行日志

接着笔记caffe-数据&模型-模型输出log,继续阅读Log。是关于内存磁盘间的通信。

从prototxt描述转换城内存表示方式

要构建一个Net,就需要解析两个超参数文件。对于模型结构参数文件:

1
2
3
4
5
6
# 解析
I0610 04:53:36.096751 13919 solver.cpp:102] Creating training net from net file: /media/junhui/DATA/caffe_workspace/my_linearReggresion/mylr.prototxt
# 这两行前面解释过了,区分哪些层用于TRAIN,哪些层用于TEST
I0610 04:53:36.097002 13919 net.cpp:296] The NetState phase (0) differed from the phase (1) specified by a rule in layer mnist
I0610 04:53:36.097012 13919 net.cpp:296] The NetState phase (0) differed from the phase (1) specified by a rule in layer accuracy
I0610 04:53:36.097085 13919 net.cpp:53] Initializing net from parameters:

找到对应位置:

1
2
3
4
5
6
7
8
9
10
void Solver<Dtype>::InitTrainNet() {
...
// param_是一个SolverParameter对象,如果mylr.prototxt文件中定义了Net结构,则如下
if (param_.has_net()) {
LOG_IF(INFO, Caffe::root_solver()) // 打印Log
<< "Creating training net from net file: " << param_.net();
// 解析mylr.prototxt中内容,将其内容存入 NetParameter 对象 net_param 中。
// 转换过程由ProtoBuffer工具完成的。
ReadNetParamsFromTextFileOrDie(param_.net(), &net_param);
}

如此就将磁盘中的prototxt描述转换到内存。

将内存中模型存储到磁盘

当需要保存当前模型和快照时:

1
2
I0610 04:53:38.909629 13919 solver.cpp:464] Snapshotting to binary proto file my_lr_iter_5000.caffemodel
I0610 04:53:38.910568 13919 sgd_solver.cpp:284] Snapshotting solver state to binary proto file my_lr_iter_5000.solverstate

solver.cpp:464 将模型存入磁盘:

1
2
3
4
5
6
7
8
9
10
11
12
13
template <typename Dtype>
string Solver<Dtype>::SnapshotToBinaryProto() {
string model_filename = SnapshotFilename(".caffemodel");
// 文件名,只用提供后缀
LOG(INFO) << "Snapshotting to binary proto file " << model_filename;
// NetParameter 定义在 caffe.pb.h 中
NetParameter net_param;
// 序列化到 net_param,一个 ProtoBuffer 对象
net_->ToProto(&net_param, param_.snapshot_diff());
// 将这个 ProtoBuffer 对象写入磁盘
WriteProtoToBinaryFile(net_param, model_filename);
return model_filename;
}

上述写入磁盘过程在caffe-Blob-(1)解释过。

sgd_solver.cpp:284 将快照存入磁盘:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
template <typename Dtype>
void SGDSolver<Dtype>::SnapshotSolverStateToBinaryProto(
const string& model_filename) {
// SolverState 定义在 caffe.pb.h 中
// 创建一个序列化对象
SolverState state;
// 获得当前网络迭代参数
state.set_iter(this->iter_);
state.set_learned_net(model_filename);
state.set_current_step(this->current_step_);
state.clear_history();
for (int i = 0; i < history_.size(); ++i) {
// Add history
BlobProto* history_blob = state.add_history();
// 写入 ProtoBuffer 对象
history_[i]->ToProto(history_blob);
}
string snapshot_filename = Solver<Dtype>::SnapshotFilename(".solverstate");
LOG(INFO)
<< "Snapshotting solver state to binary proto file " << snapshot_filename;
// 将 ProtoBuffer 对象 写入磁盘
WriteProtoToBinaryFile(state, snapshot_filename.c_str());
}

NetParameter和类SolverState都是在caffe.pb.h文件中定义的。这个文件是编译时由 ProtoBuffer 的编译器自动生成的。其中定义了很多类,包括NetParameterSolverState。为什么会自动成成的文件中会有特定的类名? 和可能是ProtoBuffer根据caffe.proto协议生成的。

对比发现,caffe.pb.h 中所有的类名在 caffe.proto 中都是一个message 对象: message NAME {}

上述属于内存磁盘通信的内容。