forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: workspace_test.cc
149 lines (127 loc) · 4.33 KB
/
workspace_test.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#include <iostream>
#include "caffe2/core/operator.h"
#include <gtest/gtest.h>
namespace caffe2 {
// Empty placeholder type. The BlobAccess test uses it as a type that an
// int-holding blob must NOT report itself as (blob->IsType<WorkspaceTestFoo>()).
class WorkspaceTestFoo {};
// Registers WorkspaceTestFoo with caffe2's type machinery; presumably this is
// what allows it to be used as the template argument of Blob::IsType<T>() —
// TODO confirm against the CAFFE_KNOWN_TYPE definition.
CAFFE_KNOWN_TYPE(WorkspaceTestFoo);
// Exercises the basic blob lifecycle on a single workspace: lookup of missing
// names, creation, typed access, idempotent re-creation and removal.
TEST(WorkspaceTest, BlobAccess) {
  Workspace ws;
  EXPECT_FALSE(ws.HasBlob("nonexisting"));
  EXPECT_EQ(ws.GetBlob("nonexisting"), nullptr);

  // Creating a blob makes it visible through HasBlob, and GetBlob hands back
  // the same pointer CreateBlob returned.
  EXPECT_EQ(ws.GetBlob("newblob"), nullptr);
  Blob* blob = ws.CreateBlob("newblob");
  // Fatal check: everything below dereferences `blob`, so continuing after a
  // null result would crash the test binary instead of reporting a failure.
  ASSERT_NE(nullptr, blob);
  EXPECT_EQ(blob, ws.GetBlob("newblob"));
  EXPECT_TRUE(ws.HasBlob("newblob"));
  // Different names should still be not created.
  EXPECT_FALSE(ws.HasBlob("nonexisting"));
  EXPECT_EQ(ws.GetBlob("nonexisting"), nullptr);

  // Check if the returned Blob is OK for all operations
  int* int_unused CAFFE2_UNUSED = blob->GetMutable<int>();
  EXPECT_TRUE(blob->IsType<int>());
  EXPECT_FALSE(blob->IsType<WorkspaceTestFoo>());
  EXPECT_NE(&blob->Get<int>(), nullptr);

  // Re-creating the blob does not change the content as long as it already
  // exists. Comparing pointers is what actually detects a replacement; if the
  // blob had been replaced, `blob` would dangle, so stop (ASSERT) rather than
  // go on to dereference it.
  ASSERT_EQ(blob, ws.CreateBlob("newblob"));
  EXPECT_TRUE(blob->IsType<int>());
  EXPECT_FALSE(blob->IsType<WorkspaceTestFoo>());
  // When not null, we should only call with the right type.
  EXPECT_NE(&blob->Get<int>(), nullptr);

  // Re-creating the blob through CreateLocalBlob does not change the content
  // either: it must also return the existing blob.
  ASSERT_EQ(blob, ws.CreateLocalBlob("newblob"));
  EXPECT_TRUE(blob->IsType<int>());
  EXPECT_NE(&blob->Get<int>(), nullptr);

  // test removing blob
  EXPECT_FALSE(ws.HasBlob("nonexisting"));
  EXPECT_FALSE(ws.RemoveBlob("nonexisting"));
  EXPECT_TRUE(ws.HasBlob("newblob"));
  EXPECT_TRUE(ws.RemoveBlob("newblob"));
  EXPECT_FALSE(ws.HasBlob("newblob"));
}
// Running a default-constructed (empty) PlanDef should report success.
TEST(WorkspaceTest, RunEmptyPlan) {
  Workspace ws;
  PlanDef empty_plan;
  EXPECT_TRUE(ws.RunPlan(empty_plan));
}
// Checks name resolution between a parent workspace and a child created with
// Workspace(&parent): the child sees parent blobs, local blobs shadow parent
// ones, and nothing the child does leaks into or mutates the parent.
TEST(WorkspaceTest, Sharing) {
  Workspace parent;
  EXPECT_FALSE(parent.HasBlob("a"));
  Blob* parent_a = parent.CreateBlob("a");
  EXPECT_NE(nullptr, parent_a);
  EXPECT_EQ(parent_a, parent.GetBlob("a"));
  {
    Workspace child(&parent);
    // Child can access parent blobs, and sees the parent's own Blob object.
    EXPECT_TRUE(child.HasBlob("a"));
    EXPECT_EQ(parent_a, child.GetBlob("a"));

    // Child can create local blobs
    EXPECT_FALSE(child.HasBlob("b"));
    EXPECT_EQ(nullptr, child.GetBlob("b"));
    Blob* child_b = child.CreateBlob("b");
    EXPECT_NE(nullptr, child_b);
    EXPECT_EQ(child_b, child.GetBlob("b"));

    // Parent cannot access child blobs
    EXPECT_EQ(nullptr, parent.GetBlob("b"));
    EXPECT_FALSE(parent.HasBlob("b"));

    // Parent can create duplicate names
    Blob* parent_b = parent.CreateBlob("b");
    EXPECT_NE(nullptr, parent_b);
    // But child has local overrides: it keeps resolving "b" to its own blob.
    EXPECT_EQ(child_b, child.GetBlob("b"));
    EXPECT_NE(child_b, parent_b);

    // Child can create a blob that already exists in the parent; it gets the
    // parent's blob back rather than a new one.
    EXPECT_EQ(parent_a, child.CreateBlob("a"));
    EXPECT_EQ(parent_a, child.GetBlob("a"));

    // Child can create a local blob for a name that already exists in the
    // parent, and lookups through the child now resolve to that local blob.
    Blob* child_a = child.CreateLocalBlob("a");
    EXPECT_NE(nullptr, child_a);
    EXPECT_EQ(child_a, child.GetBlob("a"));
    // But the local blob will be different from the one in parent workspace,
    // and the parent still resolves "a" to its original blob.
    EXPECT_NE(child_a, parent_a);
    EXPECT_EQ(parent_a, parent.GetBlob("a"));
  }
  // Destroying the child must leave the parent's blobs in place.
  EXPECT_EQ(parent_a, parent.GetBlob("a"));
}
// Checks forwarded-blob mappings: a child built with a name map, and a mapping
// added later with AddBlobMapping, expose outer blobs only under their inner
// names, and those inner names resolve to the outer workspaces' blobs.
TEST(WorkspaceTest, BlobMapping) {
  Workspace parent;
  EXPECT_FALSE(parent.HasBlob("a"));
  EXPECT_NE(nullptr, parent.CreateBlob("a"));
  EXPECT_NE(nullptr, parent.GetBlob("a"));
  {
    // Declared before `child` so that it is destroyed after `child` (locals
    // are destroyed in reverse declaration order). After AddBlobMapping below,
    // `child` resolves "inner_b" through `ws`, so `ws` must outlive it.
    Workspace ws;

    std::unordered_map<string, string> forwarded_blobs;
    forwarded_blobs["inner_a"] = "a";
    Workspace child(&parent, forwarded_blobs);
    // Only the inner name is visible, and it is the parent's actual blob.
    EXPECT_FALSE(child.HasBlob("a"));
    EXPECT_TRUE(child.HasBlob("inner_a"));
    EXPECT_NE(nullptr, child.GetBlob("inner_a"));
    EXPECT_EQ(parent.GetBlob("a"), child.GetBlob("inner_a"));

    // Mappings can also be added after construction, from another workspace.
    EXPECT_NE(nullptr, ws.CreateBlob("b"));
    forwarded_blobs.clear();
    forwarded_blobs["inner_b"] = "b";
    child.AddBlobMapping(&ws, forwarded_blobs);
    EXPECT_FALSE(child.HasBlob("b"));
    EXPECT_TRUE(child.HasBlob("inner_b"));
    EXPECT_NE(nullptr, child.GetBlob("inner_b"));
    EXPECT_EQ(ws.GetBlob("b"), child.GetBlob("inner_b"));
  }
}
/**
 * Verifies that Workspace::ForEach visits exactly the workspaces listed in
 * `workspaces`: each of them once, no others, in no particular order.
 */
static void forEachCheck(std::initializer_list<Workspace*> workspaces) {
  std::unordered_set<Workspace*> visited;
  Workspace::ForEach([&visited](Workspace* current) {
    // Being reported a second time counts as a failure on its own.
    const bool first_visit = visited.insert(current).second;
    EXPECT_TRUE(first_visit);
  });
  const std::unordered_set<Workspace*> wanted(
      workspaces.begin(), workspaces.end());
  EXPECT_EQ(visited, wanted);
}
// Workspace::ForEach must track workspaces as they are created and destroyed.
TEST(WorkspaceTest, ForEach) {
  // Expect no live workspaces when this test starts.
  forEachCheck({});
  {
    Workspace outer;
    forEachCheck({&outer});
    {
      Workspace inner;
      forEachCheck({&outer, &inner});
    }
    // `inner` is out of scope and must no longer be visited.
    forEachCheck({&outer});
  }
  // Every workspace created here is gone again.
  forEachCheck({});
}
} // namespace caffe2