/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "glow/Graph/Grad.h"
#include "glow/Base/Train.h"
#include "glow/Graph/Graph.h"
#include "glow/Graph/Nodes.h"
#include "glow/Graph/Utils.h"
#include "glow/Support/Support.h"

#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"

using namespace glow;

using llvm::cast;
using llvm::isa;

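// Builds a derived node name by joining the given node's name with the extra
// suffixes using "_", e.g. a node named "fc" becomes "fc_grad_lhs".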
#define DECORATE_NODE_NAME(Node, ...)                                         \
  llvm::join_items("_", Node->getName(), __VA_ARGS__)

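// Registers \p grad as the gradient of \p activation. If a gradient was
// already registered for \p activation, the two gradients are summed.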
void GraphGradMapper::addGradient(NodeValue activation, NodeValue grad) {
  auto p = map_.insert({activation, grad});
  if (!p.second) {
    auto curr = p.first->second;
    auto *sum = F_->createAdd("updateGrad", curr, grad);
    p.first->second = sum;
  }
}

bool GraphGradMapper::hasGradient(NodeValue activation) {
  return map_.count(activation);
}

NodeValue GraphGradMapper::getGradient(NodeValue activation) {
  return map_[activation];
}

//===----------------------------------------------------------------------===//
// Code for automatically generating the back propagation code.
//===----------------------------------------------------------------------===//

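// Clones \p F into a new function (named \p newFuncName, or "<name>_grad" by
// default) and appends the nodes that compute gradients for every node in the
// clone. Unless \p varGrads is provided, trainable Placeholders also receive
// SGD update nodes configured from \p conf.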
Function *glow::differentiate(Function *F, const TrainingConfig &conf,
                              llvm::StringRef newFuncName,
                              VariableGradientsList *varGrads) {
  // Create a new name for the differentiated function, if none is given.
  std::string tmpName;
  if (newFuncName.empty()) {
    tmpName = std::string(F->getName()) + "_grad";
    newFuncName = tmpName;
  }

  // Clone the function.
  Function *G = F->clone(newFuncName);

  using Kind = glow::Kinded::Kind;
  GraphGradMapper map(G);

  // A list of nodes to add to the Function.
  std::vector<Node *> toAppend;

  // Generate the gradient nodes for each one of the nodes in the function.

  PostOrderVisitor pov;
  for (auto &N : G->getNodes()) {
    N.visit(nullptr, &pov);
  }

  auto nodes = pov.getPostOrder();

  for (auto it = nodes.rbegin(), e = nodes.rend(); it != e; it++) {
    Node *N = *it;
    if (isa<Storage>(N)) {
      continue;
    }

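// For these node kinds, getGrad(map) builds the corresponding gradient node,
// registers the gradients of the node's inputs in the mapper, and returns the
// new node so it can be queued for insertion into G.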
#define CONVERT_TO_GRAD_NODE(NodeKind)                                        \
  if (N->getKind() == Kind::NodeKind##Kind) {                                 \
    toAppend.push_back(cast<NodeKind>(N)->getGrad(map));                      \
    continue;                                                                  \
  }

    CONVERT_TO_GRAD_NODE(ConvolutionNode)
    CONVERT_TO_GRAD_NODE(AvgPoolNode)
    CONVERT_TO_GRAD_NODE(AdaptiveAvgPoolNode)
    CONVERT_TO_GRAD_NODE(FullyConnectedNode)
    CONVERT_TO_GRAD_NODE(LocalResponseNormalizationNode)
    CONVERT_TO_GRAD_NODE(SoftMaxNode)
    CONVERT_TO_GRAD_NODE(CrossEntropyLossNode)
    CONVERT_TO_GRAD_NODE(RegressionNode)
    CONVERT_TO_GRAD_NODE(AddNode)
    CONVERT_TO_GRAD_NODE(MulNode)
    CONVERT_TO_GRAD_NODE(SubNode)
    CONVERT_TO_GRAD_NODE(DivNode)
    CONVERT_TO_GRAD_NODE(ReluNode)
    CONVERT_TO_GRAD_NODE(SigmoidNode)
    CONVERT_TO_GRAD_NODE(TanhNode)
    CONVERT_TO_GRAD_NODE(SparseLengthsWeightedSumNode)
    CONVERT_TO_GRAD_NODE(SparseLengthsSumNode)

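    // BatchedPairwiseDotProduct: the grad node exposes one extra result per
    // original input; wire each input's gradient to the matching result.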
    if (N->getKind() == Kind::BatchedPairwiseDotProductNodeKind) {
      BatchedPairwiseDotProductNode *BPDPN =
          cast<BatchedPairwiseDotProductNode>(N);
      auto outputGrad = map.getGradient(BPDPN->getResult());

      auto *X = new BatchedPairwiseDotProductGradNode(
          DECORATE_NODE_NAME(N, "grad"), outputGrad, BPDPN->getInputs());

      size_t i = 0;
      for (auto &in : BPDPN->getInputs()) {
        X->addExtraResult(in.getType());
        map.addGradient(in, X->getNthResult(i));
        ++i;
      }

      toAppend.push_back(X);
      continue;
    }

    if (N->getKind() == Kind::SaveNodeKind) {
      // Swap the src and dest. Send the Zero value as gradient for both sides.
      auto *X = new SplatNode(DECORATE_NODE_NAME(N, "grad"),
                              cast<SaveNode>(N)->getInput().getType(), 0);
      toAppend.push_back(X);
      map.addGradient(cast<SaveNode>(N)->getInput(), X);
      map.addGradient(cast<SaveNode>(N)->getOutput(), X);
      continue;
    }

    if (N->getKind() == Kind::MaxPoolNodeKind) {
      auto *MPN = llvm::cast<MaxPoolNode>(N);
      // Argmax cannot be differentiated. Assert it has no users, and use a
      // zero Splat for its grad input so it doesn't have a null input.
      assert(MPN->getArgmax().getNumUsers() == 0 &&
             "Argmax cannot be differentiated; must go unused.");
      auto *ZSN = new SplatNode(DECORATE_NODE_NAME(N, "grad"),
                                MPN->getArgmax().getType(), 0);
      toAppend.push_back(ZSN);
      map.addGradient(MPN->getArgmax(), ZSN);
      toAppend.push_back(MPN->getGrad(map));
      continue;
    }

    if (N->getKind() == Kind::ReshapeNodeKind) {
      ReshapeNode *RN = cast<ReshapeNode>(N);
      NodeValue outputG = map.getGradient(RN->getResult());
      NodeValue inputW = RN->getInput();

      // Swap the src and dest.
      auto *X = new ReshapeNode(DECORATE_NODE_NAME(RN, "grad", "reshape"),
                                inputW.getType(), outputG,
                                inputW.getType()->dims(), RN->getLayout());
      toAppend.push_back(X);
      map.addGradient(RN->getInput(), X);
      continue;
    }

    if (N->getKind() == Kind::TileNodeKind) {
      TileNode *TN = cast<TileNode>(N);
      NodeValue outputG = map.getGradient(TN->getResult());

      // To compute the gradient with respect to the input of the TileNode, all
      // of the slices in outputG corresponding to the tiled slices in the
      // forward pass need to be added together. This is achieved by reshaping
      // outputG to replace the tiling axis with {numTiles, tileDim}, and then
      // performing a BatchedReduceAdd on the axis with numTiles elements. For
      // example, if the tile creates a {n,x,h,w} output with a {n,c,h,w}
      // input where x = c * numTiles, then the {n,x,h,w} gradient with respect
      // to the output is reshaped to {n, numTiles, c, h, w} so that
      // BatchedReduceAddNode eliminates the numTiles axis and produces a
      // {n,c,h,w} output.
      auto *TNInputType = TN->getInput().getType();
      std::vector<dim_t> BRAInputDims{TNInputType->dims()};
      BRAInputDims.insert(BRAInputDims.begin() + TN->getAxis(), TN->getCount());
      auto *BRAInputType =
          F->getParent()->uniqueTypeWithNewShape(TNInputType, BRAInputDims);

      auto *RN =
          new ReshapeNode(DECORATE_NODE_NAME(TN, "grad", "reshape"),
                          BRAInputType, outputG, BRAInputType->dims(), "*");
      auto *BRA =
          new BatchedReduceAddNode(DECORATE_NODE_NAME(TN, "grad", "bra"),
                                   TN->getInput().getType(), RN, TN->getAxis());

      toAppend.push_back(RN);
      toAppend.push_back(BRA);
      map.addGradient(TN->getInput(), BRA);
      continue;
    }

    if (N->getKind() == Kind::ConvertToNodeKind) {
      auto *RN = cast<ConvertToNode>(N);
      NodeValue outputG = map.getGradient(RN->getResult());
      NodeValue inputW = RN->getInput();

      // Swap the src and dest.
      auto *X = new ConvertToNode(DECORATE_NODE_NAME(N, "grad"),
                                  inputW.getType(), outputG);
      toAppend.push_back(X);
      map.addGradient(RN->getInput(), X);
      continue;
    }

    if (N->getKind() == Kind::TransposeNodeKind) {
      TransposeNode *TN = cast<TransposeNode>(N);
      NodeValue outputG = map.getGradient(TN->getResult());
      NodeValue inputW = TN->getInput();

      // Generate the reverse shuffle.
      auto shuffle = TN->getShuffle();
      auto layout = TN->getLayout();
      std::string reverseLayout;
      reverseLayout.resize(TN->getLayout().size());
      std::vector<unsigned_t> reverseShuffle(shuffle.begin(), shuffle.end());
      for (unsigned int i = 0; i < shuffle.size(); i++) {
        reverseShuffle[shuffle[i]] = i;
        reverseLayout[shuffle[i]] = layout[i];
      }

      // Swap the src and dest.
      auto *X =
          new TransposeNode(DECORATE_NODE_NAME(N, "grad"), inputW.getType(),
                            outputG, reverseShuffle, reverseLayout);
      toAppend.push_back(X);
      map.addGradient(TN->getInput(), X);
      continue;
    }

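    // Slice: the input's gradient is the output gradient inserted back into a
    // zero tensor of the input's shape at the original slice offsets.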
    if (N->getKind() == Kind::SliceNodeKind) {
      SliceNode *SN = cast<SliceNode>(N);
      auto *zero = new SplatNode(DECORATE_NODE_NAME(SN, "expand"),
                                 SN->getInput().getType(), 0);
      auto *insert =
          new InsertTensorNode(DECORATE_NODE_NAME(SN, "grad"), zero,
                               map.getGradient(SN->getResult()), SN->getStart(),
                               /* count */ 1, /* axis */ 0);

      toAppend.push_back(zero);
      toAppend.push_back(insert);
      map.addGradient(SN->getInput(), insert);
      continue;
    }

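    // Concat: each input's gradient is the slice of the output gradient that
    // occupies the same region along the concatenation dimension.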
    if (N->getKind() == Kind::ConcatNodeKind) {
      auto *CC = cast<ConcatNode>(N);
      auto inputs = CC->getInputs();
      NodeValue outputG = map.getGradient(CC->getResult());

      // We start extracting the shape at (0,0, ... ).
      std::vector<dim_t> offsets(CC->getResult().dims().size(), 0);
      unsigned_t dim = CC->getDim();
      for (auto &N : inputs) {
        // SliceNode's name will be auto incremented due to name uniqueness.
        auto *X = new SliceNode(DECORATE_NODE_NAME(CC, "extract"), N.getType(),
                                outputG, offsets);
        toAppend.push_back(X);
        // The inputs were stacked along a specific dimension, so advance the
        // offset along that dimension by the size of the current input.
        offsets[dim] += N.dims()[dim];
        map.addGradient(N, X);
      }
      continue;
    }

    if (N->getKind() == Kind::SplatNodeKind)
      // Constant nodes don't have inputs, and therefore don't need grad
      // calculations.
      continue;

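    // BatchNormalization needs extra bookkeeping beyond its gradient node: the
    // running mean and variance are recomputed from the batch and saved back
    // into their Placeholders for the next training iteration.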
    if (N->getKind() == Kind::BatchNormalizationNodeKind) {
      auto *BN = cast<BatchNormalizationNode>(N);
      auto in = BN->getInput();
      auto mean = BN->getMean();
      auto var = BN->getVar();
      auto channelIdx = BN->getChannelIdx();
      auto momentum = BN->getMomentum();

      // Update the mean and variance via the MeanVarNormalizationNode.
      auto *MVN = new MeanVarNormalizationNode(
          DECORATE_NODE_NAME(BN, "grad"), in, mean, var, channelIdx, momentum);
      toAppend.push_back(MVN);

      // Save the newly calculated mean and variance to the mean and variance
      // variables. These will be used during the next iteration of training.
      G->createSave(DECORATE_NODE_NAME(MVN, "mean"), MVN->getNewMean(),
                    llvm::cast<Placeholder>(mean.getNode()));
      G->createSave(DECORATE_NODE_NAME(MVN, "var"), MVN->getNewVar(),
                    llvm::cast<Placeholder>(var.getNode()));

      // Replace the BN's mean and variance with the new mean and variance
      // calculated from MVN.
      BN->getParent()->getLogContext()->logNodeInputChange(
          *BN, BN->getNthInput(BatchNormalizationNode::MeanIdx),
          MVN->getNewMean());
      BN->setNthInput(BatchNormalizationNode::MeanIdx, MVN->getNewMean());
      BN->getParent()->getLogContext()->logNodeInputChange(
          *BN, BN->getNthInput(BatchNormalizationNode::VarIdx),
          MVN->getNewVar());
      BN->setNthInput(BatchNormalizationNode::VarIdx, MVN->getNewVar());

      toAppend.push_back(BN->getGrad(map));

      continue;
    }

    if (N->getKind() == Kind::MatMulNodeKind) {
      MatMulNode *MMN = cast<MatMulNode>(N);
      // Get gradient.
      NodeValue OutputG = map.getGradient(MMN->getResult());

      // Get LHS/RHS inputs and their transposed versions.
      NodeValue InputLHS = MMN->getLHS();
      NodeValue InputRHS = MMN->getRHS();
      auto *LT = G->createTranspose(
          DECORATE_NODE_NAME(MMN, "grad", "lhs", "transpose"), InputLHS,
          {1, 0});
      auto *RT = G->createTranspose(
          DECORATE_NODE_NAME(MMN, "grad", "rhs", "transpose"), InputRHS,
          {1, 0});

      // Grad for LHS = outputG x transpose(RHS).
      auto *GradLHS = new MatMulNode(DECORATE_NODE_NAME(MMN, "grad", "lhs"),
                                     InputLHS.getType(), OutputG, RT);
      // Grad for RHS = transpose(LHS) x outputG.
      auto *GradRHS = new MatMulNode(DECORATE_NODE_NAME(MMN, "grad", "rhs"),
                                     InputRHS.getType(), LT, OutputG);

      toAppend.push_back(GradLHS);
      map.addGradient(InputLHS, GradLHS);
      toAppend.push_back(GradRHS);
      map.addGradient(InputRHS, GradRHS);
      continue;
    }

    if (N->getKind() == Kind::BatchMatMulNodeKind) {
      BatchMatMulNode *BMMN = cast<BatchMatMulNode>(N);
      // Get gradient.
      NodeValue OutputG = map.getGradient(BMMN->getResult());

      // The implementation below is a batched version of the gradient
      // computation for MatMul.
      NodeValue InputLHS = BMMN->getLHS();
      NodeValue InputRHS = BMMN->getRHS();
      auto *LT = G->createTranspose(
          DECORATE_NODE_NAME(BMMN, "grad", "lhs", "transpose"), InputLHS,
          {0, 2, 1});
      auto *RT = G->createTranspose(
          DECORATE_NODE_NAME(BMMN, "grad", "rhs", "transpose"), InputRHS,
          {0, 2, 1});

      // Grad for LHS = outputG x transpose(RHS).
      auto *GradLHS =
          new BatchMatMulNode(DECORATE_NODE_NAME(BMMN, "grad", "lhs"),
                              InputLHS.getType(), OutputG, RT);
      // Grad for RHS = transpose(LHS) x outputG.
      auto *GradRHS =
          new BatchMatMulNode(DECORATE_NODE_NAME(BMMN, "grad", "rhs"),
                              InputRHS.getType(), LT, OutputG);

      toAppend.push_back(GradLHS);
      map.addGradient(InputLHS, GradLHS);
      toAppend.push_back(GradRHS);
      map.addGradient(InputRHS, GradRHS);
      continue;
    }

    if (N->getKind() == Kind::BatchedReduceAddNodeKind) {
      BatchedReduceAddNode *BRA = cast<BatchedReduceAddNode>(N);
      // Get gradient.
      NodeValue OutputG = map.getGradient(BRA->getResult());
      // Get input value.
      NodeValue Input = BRA->getBatch();

      // The gradient for BatchedReduceAddNode is a TileNode that repeats
      // OutputG along the reduced axis to restore the input's shape.
      auto Axis = BRA->getAxis();
      // Copy input dimensions first.
      std::vector<dim_t> Dims{Input.dims()};
      // Then set the dimension size on the axis to 1.
      Dims[Axis] = 1;
      auto *RSN = G->createReshape(DECORATE_NODE_NAME(BRA, "grad", "reshape"),
                                   OutputG, Dims);
      auto *TN =
          new TileNode(DECORATE_NODE_NAME(BRA, "grad", "tile"), Input.getType(),
                       RSN->getResult(), Input.dims()[Axis], Axis);

      toAppend.push_back(TN);
      map.addGradient(Input, TN);
      continue;
    }

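    // Gather: the gradient w.r.t. Data is formed by scattering the output
    // gradient back into a zero tensor of Data's shape at the gathered
    // indices.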
    if (N->getKind() == Kind::GatherNodeKind) {
      GatherNode *GN = cast<GatherNode>(N);
      // Get gradient.
      NodeValue Result = GN->getResult();
      NodeValue OutputG = map.getGradient(Result);
      // Get Data & Indices.
      NodeValue Data = GN->getData();
      NodeValue Indices = GN->getIndices();

      // Reshape Indices into a two-dimensional tensor (a column vector).
      std::vector<dim_t> IndicesDims{Indices.getType()->size(), 1};
      auto *RI =
          G->createReshape(DECORATE_NODE_NAME(GN, "grad", "reshape", "indices"),
                           Indices, IndicesDims);

      // Flatten the first k dimensions of the gradient into one, where k is
      // the number of Indices dimensions, except when Indices is already
      // one-dimensional.
      ReshapeNode *RG = nullptr;
      auto K = Indices.dims().size();
      if (K != 1) {
        const auto &OrgDims = OutputG.dims();
        std::vector<dim_t> GDims{OrgDims.begin() + K - 1, OrgDims.end()};
        for (dim_t k = 0; k < K - 1; ++k) {
          GDims[0] *= OrgDims[k];
        }
        RG = G->createReshape(
            DECORATE_NODE_NAME(GN, "grad", "reshape", "output"), OutputG,
            GDims);
      }
      // Each value in the reshaped Indices vector selects the Data slice that
      // receives the corresponding slice of the reshaped gradient; scatter the
      // gradient slices into a zero tensor accordingly.
      auto *SN =
          G->createSplat(DECORATE_NODE_NAME(GN, "splat"), Data.getType(), 0);
      auto *SA = new ScatterDataNode(DECORATE_NODE_NAME(GN, "scatter_assign"),
                                     SN->getResult(), RI->getResult(),
                                     RG ? RG->getResult() : OutputG,
                                     /*cumulative*/ false);
      toAppend.push_back(SA);
      map.addGradient(Data, SA);
      continue;
    }

    llvm_unreachable("Invalid instruction type.");
  } // End of the for-each instr loop.

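  // Second pass: for every Placeholder used by the function, either record its
  // gradient (test mode) or attach an SGD update that writes the trained
  // weight back into the Placeholder.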
  for (auto N : nodes) {
    // Iterate only through Placeholders used by the Function. These are
    // inserted during the post-order walk.
    Placeholder *PH = llvm::dyn_cast<Placeholder>(N);
    if (!PH)
      continue;

    // In this special differentiation mode we record the last gradient value
    // without performing the SGD update. This mode is used by the unit tests.
    if (varGrads) {
      if (map.hasGradient(PH)) {
        std::string nodeName = "_grad_" + PH->getName().str();
        // Save the gradient and return the destination variable.
        auto *saveNode = G->createSave(nodeName, map.getGradient(PH));
        Placeholder *GradV = saveNode->getPlaceholder();
        varGrads->push_back({PH, GradV});
      }
      continue;
    }

    // Don't update nodes that are not marked as trainable.
    if (!PH->isTraining()) {
      continue;
    }

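    // Emit an SGD node that folds the accumulated gradient into the weight
    // using the learning rate, momentum, and decay settings from the
    // TrainingConfig.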
    auto *X = new SGDNode(PH->getName(), map.getGradient(PH), PH, conf.L1Decay,
                          conf.L2Decay, conf.learningRate, conf.momentum,
                          conf.batchSize);
    toAppend.push_back(X);
    // Now update the weight with the value computed by SGD.
    auto *save =
        new SaveNode(DECORATE_NODE_NAME(PH, "save", "grad"), {X, 0}, PH);
    toAppend.push_back(save);
  }

  // Add all of the new variables and instructions.
  for (auto &I : toAppend) {
    G->addNode(I);
  }

  return G;
}

#undef DECORATE_NODE_NAME