/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "glow/ExecutionEngine/ExecutionEngine.h"
#include "glow/Graph/Graph.h"
#include "glow/IR/IR.h"
#include "glow/IR/IRBuilder.h"
#include "glow/IR/Instrs.h"

#include "gtest/gtest.h"

#include <cassert>
#include <string>

using namespace glow;

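// Each test below builds a small graph, generates a gradient function for it
// with glow::differentiate(), and compiles in Train mode to verify that
// automatic differentiation produces a valid graph for the operator under
// test.
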
TEST(GraphAutoGrad, autoGrad) {
  ExecutionEngine EE;
  PlaceholderBindings bindings;

  TrainingConfig TC;

  // Construct the network:
  TC.learningRate = 0.001;
  TC.momentum = 0.9;
  TC.L2Decay = 0.001;
  TC.L1Decay = 0.001;

  auto &mod = EE.getModule();
  Function *F = mod.createFunction("main");
  auto *A =
      mod.createPlaceholder(ElemKind::FloatTy, {10, 28, 28, 1}, "input", false);

  auto *CV0 = F->createConv(bindings, "conv1", A, 16, 5, 1, 2, 1);
  auto *RL0 = F->createRELU("relu1", CV0);
  auto *MP0 = F->createMaxPool("pool1", RL0, 3, 3, 0);

  auto *CV1 =
      F->createConv(bindings, "conv2", MP0->getResult(), 16, 5, 1, 2, 1);
  auto *RL1 = F->createRELU("relu2", CV1);
  auto *MP1 = F->createMaxPool("pool2", RL1, 3, 3, 0);

  auto *FCL1 = F->createFullyConnected(bindings, "fc3", MP1->getResult(), 10);
  auto *RL2 = F->createRELU("relu3", FCL1);
  auto *selected =
      mod.createPlaceholder(ElemKind::Int64ITy, {10, 1}, "selected", false);

  auto *SM = F->createSoftMax("sm", RL2, selected);

  auto *result = F->createSave("return", SM);
  (void)result;

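  // differentiate() returns a new function that contains the forward nodes
  // plus the generated gradient and SGD update nodes; compiling in Train mode
  // verifies that the generated graph is well formed.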
  glow::differentiate(F, TC);
  EE.compile(CompilationMode::Train);
}

TEST(GraphAutoGrad, checkLRNGen) {
  ExecutionEngine EE;
  TrainingConfig TC;
  PlaceholderBindings bindings;

  // Construct the network:
  TC.learningRate = 0.001;
  TC.momentum = 0.9;
  TC.L2Decay = 0.001;

  auto &mod = EE.getModule();
  Function *F = mod.createFunction("main");

  auto *A =
      mod.createPlaceholder(ElemKind::FloatTy, {10, 28, 28, 1}, "input", false);
  auto *CV0 = F->createLocalResponseNormalization("LRN", A);
  auto *FCL1 = F->createFullyConnected(bindings, "fc3", CV0, 10);
  auto *RL2 = F->createRELU("relu3", FCL1);
  auto *selected =
      mod.createPlaceholder(ElemKind::Int64ITy, {10, 1}, "selected", false);

  auto *SM = F->createSoftMax("sm", RL2, selected);

  auto *result = F->createSave("return", SM);
  (void)result;
  glow::differentiate(F, TC);
  EE.compile(CompilationMode::Train);
}

TEST(GraphAutoGrad, cloneAndDiff) {
  // Ensure that variables not used by the differentiated function are left
  // untouched by differentiation.
  ExecutionEngine EE;
  TrainingConfig TC;
  PlaceholderBindings bindings;
  Module M;

  auto *F = M.createFunction("main");
  Node *A = M.createPlaceholder(ElemKind::FloatTy, {1}, "A", true);
  Node *B = M.createPlaceholder(ElemKind::FloatTy, {1}, "B", true);
  Node *AplusB_F = F->createAdd("AplusB", A, B);

  EXPECT_EQ(M.getPlaceholders().size(), 2);

  auto *G = F->clone("G");

  EXPECT_EQ(M.getPlaceholders().size(), 2);
  EXPECT_EQ(G->getNodes().size(), 1);

  Node *C = M.createPlaceholder(ElemKind::FloatTy, {1}, "C", true);
  Node *AplusB_G = &G->getNodes().back();
  G->createAdd("totalSum", AplusB_G, C);

  EXPECT_EQ(M.getPlaceholders().size(), 3);

  Node *label = M.createPlaceholder(ElemKind::FloatTy, {1}, "label", false);
  Node *reg = F->createRegression("reg", AplusB_F, label);
  F->createSave("return", reg);

  EXPECT_EQ(M.getPlaceholders().size(), 5);

  auto *diffF = differentiate(F, TC);

  EXPECT_TRUE(diffF->verify());

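  // The module should now hold three functions: the original "main", its
  // clone "G", and the gradient function generated by differentiate().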
  EXPECT_EQ(M.getFunctions().size(), 3);
  EXPECT_EQ(M.getPlaceholders().size(), 5);
  // Check that we have as many SGD nodes as there are variables that need to
  // be trained: A and B. C is trainable but only used by the clone G, so it
  // must not receive an update.
  unsigned nbSGDs = 0;
  unsigned nbSGDA = 0;
  unsigned nbSGDB = 0;
  for (auto &node : diffF->getNodes()) {
    SGDNode *SGD = llvm::dyn_cast<SGDNode>(&node);
    if (!SGD)
      continue;
    ++nbSGDs;
    if (A == SGD->getWeight())
      ++nbSGDA;
    else if (B == SGD->getWeight())
      ++nbSGDB;
  }
  EXPECT_EQ(nbSGDs, 2);
  EXPECT_EQ(nbSGDA, 1);
  EXPECT_EQ(nbSGDB, 1);
}

/// Check that we can differentiate functions whose gradients update
/// Placeholders.
TEST(GraphAutoGrad, checkPlaceholderGradTest) {
  ExecutionEngine EE;
  TrainingConfig TC;
  PlaceholderBindings bindings;

  // Construct the network:
  TC.learningRate = 0.001;

  auto &mod = EE.getModule();
  Function *F = mod.createFunction("main");

  Placeholder *A =
      mod.createPlaceholder(ElemKind::FloatTy, {10, 28, 28, 1}, "input", true);
  auto *RL = F->createRELU("relu", A);
  F->createSave("return", RL);

  // Expect a single user of the trainable input placeholder.
  EXPECT_EQ(A->getNumUsers(), 1);

  glow::differentiate(F, TC);
  EE.compile(CompilationMode::Train);

  // Check that the Placeholder has multiple users, because at least one write
  // node will be added.
  EXPECT_GT(A->getNumUsers(), 1);
}

/// Check that we can differentiate functions that use ConvertToNode.
TEST(GraphAutoGrad, checkConvertToGradTest) {
  ExecutionEngine EE;
  TrainingConfig TC;
  PlaceholderBindings bindings;

  // Construct the network:
  TC.learningRate = 0.001;

  auto &mod = EE.getModule();
  Function *F = mod.createFunction("main");

  auto *A = mod.createPlaceholder(ElemKind::FloatTy, {20, 13}, "A", false);
  auto inputHandle = bindings.allocate(A)->getHandle<float>();
  inputHandle.randomize(-3.0, 3.0, mod.getPRNG());

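  // The gradient of a type conversion is the incoming gradient converted back
  // to the source type, so differentiation should flow through the FP16 cast.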
  auto *convertTo = F->createConvertTo("convertTo", A, ElemKind::Float16Ty);
  auto *result = F->createSave("save", convertTo);
  bindings.allocate(result->getPlaceholder());

  glow::differentiate(F, TC);
  EE.compile(CompilationMode::Train);
}

/// Check that we can differentiate functions that use MatMulNode.
TEST(GraphAutoGrad, checkMatMulGradTest) {
  ExecutionEngine EE;
  TrainingConfig TC;
  PlaceholderBindings Bindings;

  // Construct the network:
  TC.learningRate = 0.001;

  auto &Mod = EE.getModule();
  Function *F = Mod.createFunction("main");

  auto *A = Mod.createPlaceholder(ElemKind::FloatTy, {20, 13}, "A", false);
  auto HandleA = Bindings.allocate(A)->getHandle<float>();
  HandleA.randomize(-3.0, 3.0, Mod.getPRNG());

  auto *B = Mod.createPlaceholder(ElemKind::FloatTy, {13, 30}, "B", false);
  auto HandleB = Bindings.allocate(B)->getHandle<float>();
  HandleB.randomize(-3.0, 3.0, Mod.getPRNG());

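  // For C = A * B, the gradients are dA = dC * B^T and dB = A^T * dC.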
  auto *MatMul = F->createMatMul("matMul", A, B);
  auto *R = F->createSave("save", MatMul);
  Bindings.allocate(R->getPlaceholder());

  glow::differentiate(F, TC);
  EE.compile(CompilationMode::Train);
}

/// Check that we can differentiate functions that use BatchMatMul.
TEST(GraphAutoGrad, checkBatchMatMulGradTest) {
  ExecutionEngine EE;
  TrainingConfig TC;

  auto &Mod = EE.getModule();
  Function *F = Mod.createFunction("main");

  auto *A = Mod.createPlaceholder(ElemKind::FloatTy, {5, 20, 13}, "A",
                                  /*isTrainable=*/false);
  auto *B = Mod.createPlaceholder(ElemKind::FloatTy, {13, 30}, "B",
                                  /*isTrainable=*/false);
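  // B has no batch dimension, so it multiplies every batch slice of A; each
  // slice follows the ordinary matmul gradient.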
  auto *BatchMatMul = F->createBatchMatMul("batchMatMul", A, B);

  F->createSave("save", BatchMatMul);

  glow::differentiate(F, TC);
  EE.compile(CompilationMode::Train);
}

/// Check that we can differentiate functions that use Tile.
TEST(GraphAutoGrad, checkTileGradTest) {
  ExecutionEngine EE;
  TrainingConfig TC;

  auto &Mod = EE.getModule();
  Function *F = Mod.createFunction("main");

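  // Tile replicates A five times along axis 1; the gradient sums the incoming
  // gradient over the replicated copies back into A's shape.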
  auto *A = Mod.createPlaceholder(ElemKind::FloatTy, {10, 10}, "A", false);
  auto *Tile = F->createTile("tile", A, /*tiles=*/5, /*axis=*/1);

  F->createSave("save", Tile);

  glow::differentiate(F, TC);
  EE.compile(CompilationMode::Train);
}

/// Check that we can differentiate functions that use BatchedReduceAddNode.
TEST(GraphAutoGrad, checkBatchedReduceAddGradTest) {
  ExecutionEngine EE;
  TrainingConfig TC;
  PlaceholderBindings Bindings;

  auto &Mod = EE.getModule();
  Function *F = Mod.createFunction("main");

  TypeRef Ty = Mod.uniqueType(ElemKind::FloatTy, {1, 10});
  auto *A = Mod.createPlaceholder(ElemKind::FloatTy, {10, 10}, "A", false);
  auto HandleA = Bindings.allocate(A)->getHandle<float>();
  HandleA.randomize(-3.0, 3.0, Mod.getPRNG());

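  // Reducing over axis 0 collapses {10, 10} to {1, 10}; the gradient
  // broadcasts the output gradient back along the reduced axis.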
  auto *BRA = F->createBatchedReduceAdd("BRA", Ty, A, 0 /*axis*/);
  auto *R = F->createSave("save", BRA);
  Bindings.allocate(R->getPlaceholder());

  glow::differentiate(F, TC);
  EE.compile(CompilationMode::Train);
}

/// Check that we can differentiate functions that use GatherNode.
TEST(GraphAutoGrad, checkGatherGrad1DIndexTest) {
  ExecutionEngine EE;
  TrainingConfig TC;
  PlaceholderBindings Bindings;

  auto &Mod = EE.getModule();
  Function *F = Mod.createFunction("main");

  auto *Data = Mod.createPlaceholder(ElemKind::FloatTy, {3, 4}, "Data", false);
  auto *Indices =
      Mod.createPlaceholder(ElemKind::Int64ITy, {2}, "Indices", false);

  auto HandleData = Bindings.allocate(Data)->getHandle<float>();
  HandleData.randomize(-3.0, 3.0, Mod.getPRNG());

  Bindings.allocate(Indices)->getHandle<int64_t>() = {0, 2};

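  // Gather copies rows 0 and 2 of Data; the gradient scatters the incoming
  // gradient back into those rows and leaves all other rows at zero.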
  auto *G = F->createGather("gather", Data, Indices, 0 /*batchDims*/);
  auto *R = F->createSave("save", G);
  Bindings.allocate(R->getPlaceholder());

  glow::differentiate(F, TC);
  EE.compile(CompilationMode::Train);
}

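/// Check that Gather differentiation also handles 2-D index tensors.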
TEST(GraphAutoGrad, checkGatherGrad2DIndexTest) {
  ExecutionEngine EE;
  TrainingConfig TC;
  PlaceholderBindings Bindings;

  auto &Mod = EE.getModule();
  Function *F = Mod.createFunction("main");

  auto *Data = Mod.createPlaceholder(ElemKind::FloatTy, {8, 4}, "Data", false);
  auto *Indices =
      Mod.createPlaceholder(ElemKind::Int64ITy, {2, 2}, "Indices", false);

  auto HandleData = Bindings.allocate(Data)->getHandle<float>();
  HandleData.randomize(-3.0, 3.0, Mod.getPRNG());

  Bindings.allocate(Indices)->getHandle<int64_t>() = {0, 2, 1, 3};

  auto *G = F->createGather("gather", Data, Indices, 0 /*batchDims*/);
  auto *R = F->createSave("save", G);
  Bindings.allocate(R->getPlaceholder());

  glow::differentiate(F, TC);
  EE.compile(CompilationMode::Train);
}

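/// Check that Gather differentiation also handles 3-D index tensors.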
TEST(GraphAutoGrad, checkGatherGrad3DIndexTest) {
  ExecutionEngine EE;
  TrainingConfig TC;
  PlaceholderBindings Bindings;

  auto &Mod = EE.getModule();
  Function *F = Mod.createFunction("main");

  auto *Data = Mod.createPlaceholder(ElemKind::FloatTy, {8, 4}, "Data", false);
  auto *Indices =
      Mod.createPlaceholder(ElemKind::Int64ITy, {2, 2, 2}, "Indices", false);

  auto HandleData = Bindings.allocate(Data)->getHandle<float>();
  HandleData.randomize(-3.0, 3.0, Mod.getPRNG());

  Bindings.allocate(Indices)->getHandle<int64_t>() = {0, 2, 1, 3, 4, 5, 7, 6};

  auto *G = F->createGather("gather", Data, Indices, 0 /*batchDims*/);
  auto *R = F->createSave("save", G);
  Bindings.allocate(R->getPlaceholder());

  glow::differentiate(F, TC);
  EE.compile(CompilationMode::Train);
}

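/// Check that we can differentiate functions that use AdaptiveAvgPool. The
/// gradient of average pooling spreads each output gradient evenly over the
/// input window that produced it.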
TEST(GraphAutoGrad, checkAdaptiveAvgPoolGradTest) {
  ExecutionEngine EE;
  TrainingConfig TC;
  PlaceholderBindings Bindings;

  auto &Mod = EE.getModule();
  Function *F = Mod.createFunction("main");

  auto *Data =
      Mod.createPlaceholder(ElemKind::FloatTy, {1, 8, 4, 1}, "Data", false);

  auto HandleData = Bindings.allocate(Data)->getHandle<float>();
  HandleData.randomize(-3.0, 3.0, Mod.getPRNG());

  auto outTy = Mod.uniqueType(ElemKind::FloatTy, {1, 3, 3, 1});
  Node *A = F->createAdaptiveAvgPool("pool", Data, outTy);
  auto *R = F->createSave("save", A);
  Bindings.allocate(R->getPlaceholder());

  glow::differentiate(F, TC);
  EE.compile(CompilationMode::Train);
}