1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/ExecutionEngine/ExecutionEngine.h" |
18 | #include "glow/Graph/Graph.h" |
19 | #include "glow/IR/IR.h" |
20 | #include "glow/IR/IRBuilder.h" |
21 | #include "glow/IR/Instrs.h" |
22 | |
23 | #include "gtest/gtest.h" |
24 | |
25 | #include <cassert> |
26 | #include <string> |
27 | |
28 | using namespace glow; |
29 | |
30 | TEST(GraphAutoGrad, autoGrad) { |
31 | ExecutionEngine EE; |
32 | PlaceholderBindings bindings; |
33 | |
34 | TrainingConfig TC; |
35 | |
36 | // Construct the network: |
37 | TC.learningRate = 0.001; |
38 | TC.momentum = 0.9; |
39 | TC.L2Decay = 0.001; |
40 | TC.L1Decay = 0.001; |
41 | |
42 | auto &mod = EE.getModule(); |
43 | Function *F = mod.createFunction("main" ); |
44 | auto *A = |
45 | mod.createPlaceholder(ElemKind::FloatTy, {10, 28, 28, 1}, "input" , false); |
46 | |
47 | auto *CV0 = F->createConv(bindings, "conv1" , A, 16, 5, 1, 2, 1); |
48 | auto *RL0 = F->createRELU("relu1" , CV0); |
49 | auto *MP0 = F->createMaxPool("pool1" , RL0, 3, 3, 0); |
50 | |
51 | auto *CV1 = |
52 | F->createConv(bindings, "conv2" , MP0->getResult(), 16, 5, 1, 2, 1); |
53 | auto *RL1 = F->createRELU("conv23" , CV1); |
54 | auto *MP1 = F->createMaxPool("pool2" , RL1, 3, 3, 0); |
55 | |
56 | auto *FCL1 = F->createFullyConnected(bindings, "fc3" , MP1->getResult(), 10); |
57 | auto *RL2 = F->createRELU("relu3" , FCL1); |
58 | auto *selected = |
59 | mod.createPlaceholder(ElemKind::Int64ITy, {10, 1}, "selected" , false); |
60 | |
61 | auto *SM = F->createSoftMax("sm" , RL2, selected); |
62 | |
63 | auto *result = F->createSave("return" , SM); |
64 | (void)result; |
65 | |
66 | glow::differentiate(F, TC); |
67 | EE.compile(CompilationMode::Train); |
68 | } |
69 | |
70 | TEST(GraphAutoGrad, checkLRNGen) { |
71 | ExecutionEngine EE; |
72 | TrainingConfig TC; |
73 | PlaceholderBindings bindings; |
74 | |
75 | // Construct the network: |
76 | TC.learningRate = 0.001; |
77 | TC.momentum = 0.9; |
78 | TC.L2Decay = 0.001; |
79 | |
80 | auto &mod = EE.getModule(); |
81 | Function *F = mod.createFunction("main" ); |
82 | |
83 | auto *A = |
84 | mod.createPlaceholder(ElemKind::FloatTy, {10, 28, 28, 1}, "input" , false); |
85 | auto *CV0 = F->createLocalResponseNormalization("LRN" , A); |
86 | auto *FCL1 = F->createFullyConnected(bindings, "fc3" , CV0, 10); |
87 | auto *RL2 = F->createRELU("relu3" , FCL1); |
88 | auto *selected = |
89 | mod.createPlaceholder(ElemKind::Int64ITy, {10, 1}, "selected" , false); |
90 | |
91 | auto *SM = F->createSoftMax("sm" , RL2, selected); |
92 | |
93 | auto *result = F->createSave("return" , SM); |
94 | (void)result; |
95 | glow::differentiate(F, TC); |
96 | EE.compile(CompilationMode::Train); |
97 | } |
98 | |
99 | TEST(GraphAutoGrad, cloneAndDiff) { |
100 | // The test ensures that unused variables are not touched in differentiation. |
101 | ExecutionEngine EE; |
102 | TrainingConfig TC; |
103 | PlaceholderBindings bindings; |
104 | Module M; |
105 | |
106 | auto *F = M.createFunction("main" ); |
107 | Node *A = M.createPlaceholder(ElemKind::FloatTy, {1}, "A" , true); |
108 | Node *B = M.createPlaceholder(ElemKind::FloatTy, {1}, "B" , true); |
109 | Node *AplusB_F = F->createAdd("AplusB" , A, B); |
110 | |
111 | EXPECT_EQ(M.getPlaceholders().size(), 2); |
112 | |
113 | auto *G = F->clone("G" ); |
114 | |
115 | EXPECT_EQ(M.getPlaceholders().size(), 2); |
116 | EXPECT_EQ(G->getNodes().size(), 1); |
117 | |
118 | Node *C = M.createPlaceholder(ElemKind::FloatTy, {1}, "C" , true); |
119 | Node *AplusB_G = &G->getNodes().back(); |
120 | G->createAdd("totalSum" , AplusB_G, C); |
121 | |
122 | EXPECT_EQ(M.getPlaceholders().size(), 3); |
123 | |
124 | Node *label = M.createPlaceholder(ElemKind::FloatTy, {1}, "label" , false); |
125 | Node *reg = F->createRegression("reg" , AplusB_F, label); |
126 | F->createSave("return" , reg); |
127 | |
128 | EXPECT_EQ(M.getPlaceholders().size(), 5); |
129 | |
130 | auto *diffF = differentiate(F, TC); |
131 | |
132 | EXPECT_TRUE(diffF->verify()); |
133 | |
134 | EXPECT_EQ(M.getFunctions().size(), 3); |
135 | EXPECT_EQ(M.getPlaceholders().size(), 5); |
136 | // Check that we have as many SGD node as variables that need to be trained. |
137 | unsigned nbSGDs = 0; |
138 | unsigned nbSGDA = 0; |
139 | unsigned nbSGDB = 0; |
140 | for (auto &node : diffF->getNodes()) { |
141 | SGDNode *SGD = llvm::dyn_cast<SGDNode>(&node); |
142 | if (!SGD) |
143 | continue; |
144 | ++nbSGDs; |
145 | if (A == SGD->getWeight()) |
146 | ++nbSGDA; |
147 | else if (B == SGD->getWeight()) |
148 | ++nbSGDB; |
149 | } |
150 | EXPECT_EQ(nbSGDs, 2); |
151 | EXPECT_EQ(nbSGDA, 1); |
152 | EXPECT_EQ(nbSGDB, 1); |
153 | } |
154 | |
155 | /// Check that we can differentiate functions that update Placeholder graphs. |
156 | TEST(GraphAutoGrad, checkPlaceholderGradTest) { |
157 | ExecutionEngine EE; |
158 | TrainingConfig TC; |
159 | PlaceholderBindings bindings; |
160 | |
161 | // Construct the network: |
162 | TC.learningRate = 0.001; |
163 | |
164 | auto &mod = EE.getModule(); |
165 | Function *F = mod.createFunction("main" ); |
166 | |
167 | Placeholder *A = |
168 | mod.createPlaceholder(ElemKind::FloatTy, {10, 28, 28, 1}, "input" , true); |
169 | auto *RL = F->createRELU("relu" , A); |
170 | F->createSave("return" , RL); |
171 | |
172 | // Expect a single user to the trainable input placeholder. |
173 | EXPECT_EQ(A->getNumUsers(), 1); |
174 | |
175 | glow::differentiate(F, TC); |
176 | EE.compile(CompilationMode::Train); |
177 | |
178 | // Check that the Placeholder has multiple users, because at least one write |
179 | // node will be added. |
180 | EXPECT_GT(A->getNumUsers(), 1); |
181 | } |
182 | |
183 | /// Check that we can differentiate functions that use ConvertToNode. |
184 | TEST(GraphAutoGrad, checkConvertToGradTest) { |
185 | ExecutionEngine EE; |
186 | TrainingConfig TC; |
187 | PlaceholderBindings bindings; |
188 | |
189 | // Construct the network: |
190 | TC.learningRate = 0.001; |
191 | |
192 | auto &mod = EE.getModule(); |
193 | Function *F = mod.createFunction("main" ); |
194 | |
195 | auto *A = mod.createPlaceholder(ElemKind::FloatTy, {20, 13}, "A" , false); |
196 | auto inputHandle = bindings.allocate(A)->getHandle<float>(); |
197 | inputHandle.randomize(-3.0, 3.0, mod.getPRNG()); |
198 | |
199 | auto *convertTo = F->createConvertTo("convertTo" , A, ElemKind::Float16Ty); |
200 | auto *result = F->createSave("save" , convertTo); |
201 | bindings.allocate(result->getPlaceholder()); |
202 | |
203 | glow::differentiate(F, TC); |
204 | EE.compile(CompilationMode::Train); |
205 | } |
206 | |
207 | /// Check that we can differentiate functions that use MatMulNode. |
208 | TEST(GraphAutoGrad, checkMatMulGradTest) { |
209 | ExecutionEngine EE; |
210 | TrainingConfig TC; |
211 | PlaceholderBindings Bindings; |
212 | |
213 | // Construct the network: |
214 | TC.learningRate = 0.001; |
215 | |
216 | auto &Mod = EE.getModule(); |
217 | Function *F = Mod.createFunction("main" ); |
218 | |
219 | auto *A = Mod.createPlaceholder(ElemKind::FloatTy, {20, 13}, "A" , false); |
220 | auto HandleA = Bindings.allocate(A)->getHandle<float>(); |
221 | HandleA.randomize(-3.0, 3.0, Mod.getPRNG()); |
222 | |
223 | auto *B = Mod.createPlaceholder(ElemKind::FloatTy, {13, 30}, "B" , false); |
224 | auto HandleB = Bindings.allocate(B)->getHandle<float>(); |
225 | HandleB.randomize(-3.0, 3.0, Mod.getPRNG()); |
226 | |
227 | auto *MatMul = F->createMatMul("matMul" , A, B); |
228 | auto *R = F->createSave("save" , MatMul); |
229 | Bindings.allocate(R->getPlaceholder()); |
230 | |
231 | glow::differentiate(F, TC); |
232 | EE.compile(CompilationMode::Train); |
233 | } |
234 | |
235 | /// Check that we can differentiate functions that use BatchMatMul. |
236 | TEST(GraphAutoGrad, checkBatchMatMulGradTest) { |
237 | ExecutionEngine EE; |
238 | TrainingConfig TC; |
239 | |
240 | auto &Mod = EE.getModule(); |
241 | Function *F = Mod.createFunction("main" ); |
242 | |
243 | auto *A = Mod.createPlaceholder(ElemKind::FloatTy, {5, 20, 13}, "A" , |
244 | /*isTrainable=*/false); |
245 | auto *B = Mod.createPlaceholder(ElemKind::FloatTy, {13, 30}, "B" , |
246 | /*isTrainable=*/false); |
247 | auto *BatchMatMul = F->createBatchMatMul("batchMatMul" , A, B); |
248 | |
249 | F->createSave("save" , BatchMatMul); |
250 | |
251 | glow::differentiate(F, TC); |
252 | EE.compile(CompilationMode::Train); |
253 | } |
254 | |
255 | // Check that we can differentiate functions that use Tile. |
256 | TEST(GraphAutoGrad, checkTileGradTest) { |
257 | ExecutionEngine EE; |
258 | TrainingConfig TC; |
259 | |
260 | auto &Mod = EE.getModule(); |
261 | Function *F = Mod.createFunction("main" ); |
262 | |
263 | auto *A = Mod.createPlaceholder(ElemKind::FloatTy, {10, 10}, "A" , false); |
264 | auto *Tile = F->createTile("tile" , A, /*tiles=*/5, /*axis=*/1); |
265 | |
266 | F->createSave("save" , Tile); |
267 | |
268 | glow::differentiate(F, TC); |
269 | EE.compile(CompilationMode::Train); |
270 | } |
271 | |
272 | /// Check that we can differentiate functions that use BatchedReduceAddNode. |
273 | TEST(GraphAutoGrad, checkBatchedReduceAddGradTest) { |
274 | ExecutionEngine EE; |
275 | TrainingConfig TC; |
276 | PlaceholderBindings Bindings; |
277 | |
278 | auto &Mod = EE.getModule(); |
279 | Function *F = Mod.createFunction("main" ); |
280 | |
281 | TypeRef Ty = Mod.uniqueType(ElemKind::FloatTy, {1, 10}); |
282 | auto *A = Mod.createPlaceholder(ElemKind::FloatTy, {10, 10}, "A" , false); |
283 | auto HandleA = Bindings.allocate(A)->getHandle<float>(); |
284 | HandleA.randomize(-3.0, 3.0, Mod.getPRNG()); |
285 | |
286 | auto *BRA = F->createBatchedReduceAdd("BRA" , Ty, A, 0 /*axis*/); |
287 | auto *R = F->createSave("save" , BRA); |
288 | Bindings.allocate(R->getPlaceholder()); |
289 | |
290 | glow::differentiate(F, TC); |
291 | EE.compile(CompilationMode::Train); |
292 | } |
293 | |
294 | /// Check that we can differentiate functions that use GatherNode. |
295 | TEST(GraphAutoGrad, checkGatherGrad1DIndexTest) { |
296 | ExecutionEngine EE; |
297 | TrainingConfig TC; |
298 | PlaceholderBindings Bindings; |
299 | |
300 | auto &Mod = EE.getModule(); |
301 | Function *F = Mod.createFunction("main" ); |
302 | |
303 | auto *Data = Mod.createPlaceholder(ElemKind::FloatTy, {3, 4}, "Data" , false); |
304 | auto *Indices = |
305 | Mod.createPlaceholder(ElemKind::Int64ITy, {2}, "Indices" , false); |
306 | |
307 | auto HandleData = Bindings.allocate(Data)->getHandle<float>(); |
308 | HandleData.randomize(-3.0, 3.0, Mod.getPRNG()); |
309 | |
310 | Bindings.allocate(Indices)->getHandle<int64_t>() = {0, 2}; |
311 | |
312 | auto *G = F->createGather("gather" , Data, Indices, 0 /*batchDims*/); |
313 | auto *R = F->createSave("save" , G); |
314 | Bindings.allocate(R->getPlaceholder()); |
315 | |
316 | glow::differentiate(F, TC); |
317 | EE.compile(CompilationMode::Train); |
318 | } |
319 | |
320 | TEST(GraphAutoGrad, checkGatherGrad2DIndexTest) { |
321 | ExecutionEngine EE; |
322 | TrainingConfig TC; |
323 | PlaceholderBindings Bindings; |
324 | |
325 | auto &Mod = EE.getModule(); |
326 | Function *F = Mod.createFunction("main" ); |
327 | |
328 | auto *Data = Mod.createPlaceholder(ElemKind::FloatTy, {8, 4}, "Data" , false); |
329 | auto *Indices = |
330 | Mod.createPlaceholder(ElemKind::Int64ITy, {2, 2}, "Indices" , false); |
331 | |
332 | auto HandleData = Bindings.allocate(Data)->getHandle<float>(); |
333 | HandleData.randomize(-3.0, 3.0, Mod.getPRNG()); |
334 | |
335 | Bindings.allocate(Indices)->getHandle<int64_t>() = {0, 2, 1, 3}; |
336 | |
337 | auto *G = F->createGather("gather" , Data, Indices, 0 /*batchDims*/); |
338 | auto *R = F->createSave("save" , G); |
339 | Bindings.allocate(R->getPlaceholder()); |
340 | |
341 | glow::differentiate(F, TC); |
342 | EE.compile(CompilationMode::Train); |
343 | } |
344 | |
345 | TEST(GraphAutoGrad, checkGatherGrad3DIndexTest) { |
346 | ExecutionEngine EE; |
347 | TrainingConfig TC; |
348 | PlaceholderBindings Bindings; |
349 | |
350 | auto &Mod = EE.getModule(); |
351 | Function *F = Mod.createFunction("main" ); |
352 | |
353 | auto *Data = Mod.createPlaceholder(ElemKind::FloatTy, {8, 4}, "Data" , false); |
354 | auto *Indices = |
355 | Mod.createPlaceholder(ElemKind::Int64ITy, {2, 2, 2}, "Indices" , false); |
356 | |
357 | auto HandleData = Bindings.allocate(Data)->getHandle<float>(); |
358 | HandleData.randomize(-3.0, 3.0, Mod.getPRNG()); |
359 | |
360 | Bindings.allocate(Indices)->getHandle<int64_t>() = {0, 2, 1, 3, 4, 5, 7, 6}; |
361 | |
362 | auto *G = F->createGather("gather" , Data, Indices, 0 /*batchDims*/); |
363 | auto *R = F->createSave("save" , G); |
364 | Bindings.allocate(R->getPlaceholder()); |
365 | |
366 | glow::differentiate(F, TC); |
367 | EE.compile(CompilationMode::Train); |
368 | } |
369 | |
370 | TEST(GraphAutoGrad, checkAdaptiveAvgPoolGradTest) { |
371 | ExecutionEngine EE; |
372 | TrainingConfig TC; |
373 | PlaceholderBindings Bindings; |
374 | |
375 | auto &Mod = EE.getModule(); |
376 | Function *F = Mod.createFunction("main" ); |
377 | |
378 | auto *Data = |
379 | Mod.createPlaceholder(ElemKind::FloatTy, {1, 8, 4, 1}, "Data" , false); |
380 | |
381 | auto HandleData = Bindings.allocate(Data)->getHandle<float>(); |
382 | HandleData.randomize(-3.0, 3.0, Mod.getPRNG()); |
383 | |
384 | auto outTy = Mod.uniqueType(ElemKind::FloatTy, {1, 3, 3, 1}); |
385 | Node *A = F->createAdaptiveAvgPool("pool" , Data, outTy); |
386 | auto *R = F->createSave("save" , A); |
387 | Bindings.allocate(R->getPlaceholder()); |
388 | |
389 | glow::differentiate(F, TC); |
390 | EE.compile(CompilationMode::Train); |
391 | } |
392 | |