1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Graph/Grad.h" |
18 | #include "glow/Base/Train.h" |
19 | #include "glow/Graph/Graph.h" |
20 | #include "glow/Graph/Nodes.h" |
21 | #include "glow/Graph/Utils.h" |
22 | #include "glow/Support/Support.h" |
23 | |
24 | #include "llvm/ADT/SmallVector.h" |
25 | #include "llvm/ADT/StringExtras.h" |
26 | #include "llvm/Support/Casting.h" |
27 | #include "llvm/Support/raw_ostream.h" |
28 | |
29 | using namespace glow; |
30 | |
31 | using llvm::cast; |
32 | using llvm::isa; |
33 | |
34 | #define DECORATE_NODE_NAME(Node, ...) \ |
35 | llvm::join_items("_", Node->getName(), __VA_ARGS__) |
36 | |
37 | void GraphGradMapper::addGradient(NodeValue activation, NodeValue grad) { |
38 | auto p = map_.insert({activation, grad}); |
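  // If this activation already has a recorded gradient (i.e. the value feeds
  // multiple consumers), accumulate the contributions by summing them.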
39 | if (!p.second) { |
40 | auto curr = p.first->second; |
    auto *sum = F_->createAdd("updateGrad", curr, grad);
42 | p.first->second = sum; |
43 | } |
44 | } |
45 | |
46 | bool GraphGradMapper::hasGradient(NodeValue activation) { |
47 | return map_.count(activation); |
48 | } |
49 | |
50 | NodeValue GraphGradMapper::getGradient(NodeValue activation) { |
51 | return map_[activation]; |
52 | } |
53 | |
54 | //===----------------------------------------------------------------------===// |
55 | // Code for automatically generating the back propagation code. |
56 | //===----------------------------------------------------------------------===// |
57 | |
58 | Function *glow::differentiate(Function *F, const TrainingConfig &conf, |
59 | llvm::StringRef newFuncName, |
60 | VariableGradientsList *varGrads) { |
61 | // Create a new name for the differentiated function, if none is given. |
62 | std::string tmpName; |
63 | if (newFuncName.empty()) { |
    tmpName = std::string(F->getName()) + "_grad";
65 | newFuncName = tmpName; |
66 | } |
67 | |
68 | // Clone the function. |
69 | Function *G = F->clone(newFuncName); |
70 | |
71 | using Kind = glow::Kinded::Kind; |
72 | GraphGradMapper map(G); |
73 | |
74 | // A list of nodes to add to the Function. |
75 | std::vector<Node *> toAppend; |
76 | |
77 | // Generate the gradient nodes for each one of the nodes in the function. |
78 | |
79 | PostOrderVisitor pov; |
80 | for (auto &N : G->getNodes()) { |
81 | N.visit(nullptr, &pov); |
82 | } |
83 | |
84 | auto nodes = pov.getPostOrder(); |
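  // Process the nodes in reverse post-order so that every node is visited
  // after all of its users; by the time a node is reached, the gradient of
  // its output has been fully accumulated.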
85 | |
86 | for (auto it = nodes.rbegin(), e = nodes.rend(); it != e; it++) { |
87 | Node *N = *it; |
88 | if (isa<Storage>(N)) { |
89 | continue; |
90 | } |
91 | |
92 | #define CONVERT_TO_GRAD_NODE(NodeKind) \ |
93 | if (N->getKind() == Kind::NodeKind##Kind) { \ |
94 | toAppend.push_back(cast<NodeKind>(N)->getGrad(map)); \ |
95 | continue; \ |
96 | } |
97 | |
98 | CONVERT_TO_GRAD_NODE(ConvolutionNode) |
99 | CONVERT_TO_GRAD_NODE(AvgPoolNode) |
100 | CONVERT_TO_GRAD_NODE(AdaptiveAvgPoolNode) |
101 | CONVERT_TO_GRAD_NODE(FullyConnectedNode) |
102 | CONVERT_TO_GRAD_NODE(LocalResponseNormalizationNode) |
103 | CONVERT_TO_GRAD_NODE(SoftMaxNode) |
104 | CONVERT_TO_GRAD_NODE(CrossEntropyLossNode) |
105 | CONVERT_TO_GRAD_NODE(RegressionNode) |
106 | CONVERT_TO_GRAD_NODE(AddNode) |
107 | CONVERT_TO_GRAD_NODE(MulNode) |
108 | CONVERT_TO_GRAD_NODE(SubNode) |
109 | CONVERT_TO_GRAD_NODE(DivNode) |
110 | CONVERT_TO_GRAD_NODE(ReluNode) |
111 | CONVERT_TO_GRAD_NODE(SigmoidNode) |
112 | CONVERT_TO_GRAD_NODE(TanhNode) |
113 | CONVERT_TO_GRAD_NODE(SparseLengthsWeightedSumNode) |
114 | CONVERT_TO_GRAD_NODE(SparseLengthsSumNode) |
115 | |
116 | if (N->getKind() == Kind::BatchedPairwiseDotProductNodeKind) { |
117 | BatchedPairwiseDotProductNode *BPDPN = |
118 | cast<BatchedPairwiseDotProductNode>(N); |
119 | auto outputGrad = map.getGradient(BPDPN->getResult()); |
120 | |
121 | auto *X = new BatchedPairwiseDotProductGradNode( |
          DECORATE_NODE_NAME(N, "grad"), outputGrad, BPDPN->getInputs());
123 | |
124 | size_t i = 0; |
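      // The grad node carries one extra result per forward input; register
      // each result as the gradient of the corresponding input.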
125 | for (auto &in : BPDPN->getInputs()) { |
126 | X->addExtraResult(in.getType()); |
127 | map.addGradient(in, X->getNthResult(i)); |
128 | ++i; |
129 | } |
130 | |
131 | toAppend.push_back(X); |
132 | continue; |
133 | } |
134 | |
135 | if (N->getKind() == Kind::SaveNodeKind) { |
      // A SaveNode terminates the graph, so no gradient flows back through
      // it. Send a zero Splat as the gradient for both its input and output.
      auto *X = new SplatNode(DECORATE_NODE_NAME(N, "grad"),
                              cast<SaveNode>(N)->getInput().getType(), 0);
139 | toAppend.push_back(X); |
140 | map.addGradient(cast<SaveNode>(N)->getInput(), X); |
141 | map.addGradient(cast<SaveNode>(N)->getOutput(), X); |
142 | continue; |
143 | } |
144 | |
145 | if (N->getKind() == Kind::MaxPoolNodeKind) { |
146 | auto *MPN = llvm::cast<MaxPoolNode>(N); |
147 | // Argmax cannot be differentiated. Assert it has no users, and use a zero |
148 | // Splat for its grad input so it doesn't have a null input. |
149 | assert(MPN->getArgmax().getNumUsers() == 0 && |
150 | "Argmax cannot be differentiated; must go unused." ); |
151 | auto *ZSN = new SplatNode(DECORATE_NODE_NAME(N, "grad" ), |
152 | MPN->getArgmax().getType(), 0); |
153 | toAppend.push_back(ZSN); |
154 | map.addGradient(MPN->getArgmax(), ZSN); |
155 | toAppend.push_back(MPN->getGrad(map)); |
156 | continue; |
157 | } |
158 | |
159 | if (N->getKind() == Kind::ReshapeNodeKind) { |
160 | ReshapeNode *RN = cast<ReshapeNode>(N); |
161 | NodeValue outputG = map.getGradient(RN->getResult()); |
162 | NodeValue inputW = RN->getInput(); |
163 | |
164 | // Swap the src and dest. |
      auto *X = new ReshapeNode(DECORATE_NODE_NAME(RN, "grad", "reshape"),
166 | inputW.getType(), outputG, |
167 | inputW.getType()->dims(), RN->getLayout()); |
168 | toAppend.push_back(X); |
169 | map.addGradient(RN->getInput(), X); |
170 | continue; |
171 | } |
172 | |
173 | if (N->getKind() == Kind::TileNodeKind) { |
174 | TileNode *TN = cast<TileNode>(N); |
175 | NodeValue outputG = map.getGradient(TN->getResult()); |
176 | |
177 | // To compute the gradient with respect to the input of the TileNode, all |
178 | // of the slices in outputG corresponding to the tiled slices in the |
179 | // forward pass need to be added together. This is achieved by reshaping |
180 | // outputG to replace the tiling axis with {numTiles, tileDim}, and then |
181 | // performing a BatchedReduceAdd on the axis with numTiles elements. For |
182 | // example, if the tile creates a {n,x,h,w} output with a {n,c,h,w} |
183 | // input where x = c * numTiles, then the {n,x,h,w} gradient with respect |
184 | // to the output is reshaped to {n, numTiles, c, h, w} so that |
185 | // BatchedReduceAddNode eliminates the numTiles axis and produces a |
186 | // {n,c,h,w} output. |
187 | auto *TNInputType = TN->getInput().getType(); |
188 | std::vector<dim_t> BRAInputDims{TNInputType->dims()}; |
189 | BRAInputDims.insert(BRAInputDims.begin() + TN->getAxis(), TN->getCount()); |
190 | auto *BRAInputType = |
191 | F->getParent()->uniqueTypeWithNewShape(TNInputType, BRAInputDims); |
192 | |
193 | auto *RN = |
          new ReshapeNode(DECORATE_NODE_NAME(TN, "grad", "reshape"),
                          BRAInputType, outputG, BRAInputType->dims(), "*");
196 | auto *BRA = |
          new BatchedReduceAddNode(DECORATE_NODE_NAME(TN, "grad", "bra"),
198 | TN->getInput().getType(), RN, TN->getAxis()); |
199 | |
200 | toAppend.push_back(RN); |
201 | toAppend.push_back(BRA); |
202 | map.addGradient(TN->getInput(), BRA); |
203 | continue; |
204 | } |
205 | |
206 | if (N->getKind() == Kind::ConvertToNodeKind) { |
207 | auto *RN = cast<ConvertToNode>(N); |
208 | NodeValue outputG = map.getGradient(RN->getResult()); |
209 | NodeValue inputW = RN->getInput(); |
210 | |
211 | // Swap the src and dest. |
      auto *X = new ConvertToNode(DECORATE_NODE_NAME(N, "grad"),
213 | inputW.getType(), outputG); |
214 | toAppend.push_back(X); |
215 | map.addGradient(RN->getInput(), X); |
216 | continue; |
217 | } |
218 | |
219 | if (N->getKind() == Kind::TransposeNodeKind) { |
220 | TransposeNode *TN = cast<TransposeNode>(N); |
221 | NodeValue outputG = map.getGradient(TN->getResult()); |
222 | NodeValue inputW = TN->getInput(); |
223 | |
224 | // Generate the reverse shuffle. |
225 | auto shuffle = TN->getShuffle(); |
226 | auto layout = TN->getLayout(); |
227 | std::string reverseLayout; |
228 | reverseLayout.resize(TN->getLayout().size()); |
229 | std::vector<unsigned_t> reverseShuffle(shuffle.begin(), shuffle.end()); |
230 | for (unsigned int i = 0; i < shuffle.size(); i++) { |
231 | reverseShuffle[shuffle[i]] = i; |
232 | reverseLayout[shuffle[i]] = layout[i]; |
233 | } |
234 | |
235 | // Swap the src and dest. |
236 | auto *X = |
          new TransposeNode(DECORATE_NODE_NAME(N, "grad"), inputW.getType(),
238 | outputG, reverseShuffle, reverseLayout); |
239 | toAppend.push_back(X); |
240 | map.addGradient(TN->getInput(), X); |
241 | continue; |
242 | } |
243 | |
244 | if (N->getKind() == Kind::SliceNodeKind) { |
245 | SliceNode *SN = cast<SliceNode>(N); |
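      // The gradient of a Slice is the output gradient scattered back into a
      // zero tensor of the input's shape, at the offsets used by the forward
      // slice.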
      auto *zero = new SplatNode(DECORATE_NODE_NAME(SN, "expand"),
247 | SN->getInput().getType(), 0); |
248 | auto *insert = |
          new InsertTensorNode(DECORATE_NODE_NAME(SN, "grad"), zero,
250 | map.getGradient(SN->getResult()), SN->getStart(), |
251 | /* count */ 1, /* axis */ 0); |
252 | |
253 | toAppend.push_back(zero); |
254 | toAppend.push_back(insert); |
255 | map.addGradient(SN->getInput(), insert); |
256 | continue; |
257 | } |
258 | |
259 | if (N->getKind() == Kind::ConcatNodeKind) { |
260 | auto *CC = cast<ConcatNode>(N); |
261 | auto inputs = CC->getInputs(); |
262 | NodeValue outputG = map.getGradient(CC->getResult()); |
263 | |
      // We start extracting slices at offset (0, 0, ...).
265 | std::vector<dim_t> offsets(CC->getResult().dims().size(), 0); |
266 | unsigned_t dim = CC->getDim(); |
      for (auto &in : inputs) {
        // SliceNode's name will be auto incremented due to name uniqueness.
        auto *X = new SliceNode(DECORATE_NODE_NAME(CC, "extract"),
                                in.getType(), outputG, offsets);
        toAppend.push_back(X);
        // The inputs were stacked along dimension 'dim' in the forward pass,
        // so advance the offset by this input's size along that dimension.
        offsets[dim] += in.dims()[dim];
        map.addGradient(in, X);
      }
277 | continue; |
278 | } |
279 | |
    if (N->getKind() == Kind::SplatNodeKind) {
      // Constant nodes don't have inputs and therefore don't need grad
      // calculations.
      continue;
    }
284 | |
285 | if (N->getKind() == Kind::BatchNormalizationNodeKind) { |
286 | auto *BN = cast<BatchNormalizationNode>(N); |
287 | auto in = BN->getInput(); |
288 | auto mean = BN->getMean(); |
289 | auto var = BN->getVar(); |
290 | auto channelIdx = BN->getChannelIdx(); |
291 | auto momentum = BN->getMomentum(); |
292 | |
293 | // Update the mean and variance via the MeanVarNormalizationNode. |
294 | auto *MVN = new MeanVarNormalizationNode( |
          DECORATE_NODE_NAME(BN, "grad"), in, mean, var, channelIdx, momentum);
296 | toAppend.push_back(MVN); |
297 | |
298 | // Save the newly calculated mean and variance to the mean and variance |
299 | // variables. These will be used during the next iteration of training. |
      G->createSave(DECORATE_NODE_NAME(MVN, "mean"), MVN->getNewMean(),
                    llvm::cast<Placeholder>(mean.getNode()));
      G->createSave(DECORATE_NODE_NAME(MVN, "var"), MVN->getNewVar(),
                    llvm::cast<Placeholder>(var.getNode()));
304 | |
305 | // Replace the BN's mean and variance with the new mean and variance |
306 | // calculated from MVN. |
      BN->getParent()->getLogContext()->logNodeInputChange(
          *BN, BN->getNthInput(BatchNormalizationNode::MeanIdx),
          MVN->getNewMean());
      BN->setNthInput(BatchNormalizationNode::MeanIdx, MVN->getNewMean());
      BN->getParent()->getLogContext()->logNodeInputChange(
          *BN, BN->getNthInput(BatchNormalizationNode::VarIdx),
          MVN->getNewVar());
      BN->setNthInput(BatchNormalizationNode::VarIdx, MVN->getNewVar());
313 | |
314 | toAppend.push_back(BN->getGrad(map)); |
315 | |
316 | continue; |
317 | } |
318 | |
319 | if (N->getKind() == Kind::MatMulNodeKind) { |
320 | MatMulNode *MMN = cast<MatMulNode>(N); |
321 | // Get gradient. |
322 | NodeValue OutputG = map.getGradient(MMN->getResult()); |
323 | |
324 | // Get LHS/RHS inputs and their transpose presentations. |
325 | NodeValue InputLHS = MMN->getLHS(); |
326 | NodeValue InputRHS = MMN->getRHS(); |
      auto *LT = G->createTranspose(
          DECORATE_NODE_NAME(MMN, "grad", "lhs", "transpose"), InputLHS,
          {1, 0});
      auto *RT = G->createTranspose(
          DECORATE_NODE_NAME(MMN, "grad", "rhs", "transpose"), InputRHS,
          {1, 0});
333 | |
334 | // Grad for LHS = outputG x transpose(RHS). |
      auto *GradLHS = new MatMulNode(DECORATE_NODE_NAME(MMN, "grad", "lhs"),
                                     InputLHS.getType(), OutputG, RT);
      // Grad for RHS = transpose(LHS) x outputG.
      auto *GradRHS = new MatMulNode(DECORATE_NODE_NAME(MMN, "grad", "rhs"),
                                     InputRHS.getType(), LT, OutputG);
340 | |
341 | toAppend.push_back(GradLHS); |
342 | map.addGradient(InputLHS, GradLHS); |
343 | toAppend.push_back(GradRHS); |
344 | map.addGradient(InputRHS, GradRHS); |
345 | continue; |
346 | } |
347 | |
348 | if (N->getKind() == Kind::BatchMatMulNodeKind) { |
349 | BatchMatMulNode *BMMN = cast<BatchMatMulNode>(N); |
350 | // Get gradient. |
351 | NodeValue OutputG = map.getGradient(BMMN->getResult()); |
352 | |
353 | // The implementation below is a batched version of the gradient |
354 | // computation for MatMul. |
355 | NodeValue InputLHS = BMMN->getLHS(); |
356 | NodeValue InputRHS = BMMN->getRHS(); |
      auto *LT = G->createTranspose(
          DECORATE_NODE_NAME(BMMN, "grad", "lhs", "transpose"), InputLHS,
          {0, 2, 1});
      auto *RT = G->createTranspose(
          DECORATE_NODE_NAME(BMMN, "grad", "rhs", "transpose"), InputRHS,
          {0, 2, 1});
363 | |
      // Grad for LHS = outputG x transpose(RHS).
      auto *GradLHS =
          new BatchMatMulNode(DECORATE_NODE_NAME(BMMN, "grad", "lhs"),
                              InputLHS.getType(), OutputG, RT);
      // Grad for RHS = transpose(LHS) x outputG.
      auto *GradRHS =
          new BatchMatMulNode(DECORATE_NODE_NAME(BMMN, "grad", "rhs"),
                              InputRHS.getType(), LT, OutputG);
372 | |
373 | toAppend.push_back(GradLHS); |
374 | map.addGradient(InputLHS, GradLHS); |
375 | toAppend.push_back(GradRHS); |
376 | map.addGradient(InputRHS, GradRHS); |
377 | continue; |
378 | } |
379 | |
380 | if (N->getKind() == Kind::BatchedReduceAddNodeKind) { |
381 | BatchedReduceAddNode *BRA = cast<BatchedReduceAddNode>(N); |
382 | // Get gradient. |
383 | NodeValue OutputG = map.getGradient(BRA->getResult()); |
384 | // Get input value. |
385 | NodeValue Input = BRA->getBatch(); |
386 | |
387 | // Gradient for BatchedReduceAddNode is TileNode, |
388 | // repeating OutputG batch times. |
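      // For example, reducing a {n, c, h, w} batch over axis 0 yields
      // {c, h, w}; the gradient reshapes OutputG to {1, c, h, w} and tiles it
      // n times back to {n, c, h, w}.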
389 | auto Axis = BRA->getAxis(); |
390 | // Copy input dimensions first. |
391 | std::vector<dim_t> Dims{Input.dims()}; |
392 | // Then set to 1 dimension size on axis. |
393 | Dims[Axis] = 1; |
      auto *RSN = G->createReshape(DECORATE_NODE_NAME(BRA, "grad", "reshape"),
395 | OutputG, Dims); |
396 | auto *TN = |
          new TileNode(DECORATE_NODE_NAME(BRA, "grad", "tile"), Input.getType(),
398 | RSN->getResult(), Input.dims()[Axis], Axis); |
399 | |
400 | toAppend.push_back(TN); |
401 | map.addGradient(Input, TN); |
402 | continue; |
403 | } |
404 | |
405 | if (N->getKind() == Kind::GatherNodeKind) { |
406 | GatherNode *GN = cast<GatherNode>(N); |
407 | // Get gradient. |
408 | NodeValue Result = GN->getResult(); |
409 | NodeValue OutputG = map.getGradient(Result); |
410 | // Get Data & Indices. |
411 | NodeValue Data = GN->getData(); |
412 | NodeValue Indices = GN->getIndices(); |
413 | |
      // Flatten Indices into a two-dimensional {numIndices, 1} tensor.
415 | std::vector<dim_t> IndicesDims{Indices.getType()->size(), 1}; |
416 | auto *RI = |
          G->createReshape(DECORATE_NODE_NAME(GN, "grad", "reshape", "indices"),
418 | Indices, IndicesDims); |
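      // For example, with Data {5, 4} and Indices {2, 3}, the result (and
      // thus OutputG) is {2, 3, 4}; OutputG is flattened to {6, 4}, Indices
      // to {6, 1}, and the six rows are scattered into a zero {5, 4} tensor.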
419 | |
      // Collapse the first k dimensions of the gradient (those contributed by
      // the Indices shape) into a single dimension, unless Indices is already
      // one-dimensional.
422 | ReshapeNode *RG = nullptr; |
423 | auto K = Indices.dims().size(); |
424 | if (K != 1) { |
425 | const auto &OrgDims = OutputG.dims(); |
426 | std::vector<dim_t> GDims{OrgDims.begin() + K - 1, OrgDims.end()}; |
427 | for (dim_t k = 0; k < K - 1; ++k) { |
428 | GDims[0] *= OrgDims[k]; |
429 | } |
        RG = G->createReshape(
            DECORATE_NODE_NAME(GN, "grad", "reshape", "output"), OutputG,
            GDims);
433 | } |
      // Scatter the gradient slices into a zero tensor shaped like Data: each
      // entry of the flattened indices vector selects the row of Data that
      // receives the corresponding gradient slice.
      auto *SN =
          G->createSplat(DECORATE_NODE_NAME(GN, "splat"), Data.getType(), 0);
      auto *SA = new ScatterDataNode(DECORATE_NODE_NAME(GN, "scatter_assign"),
440 | SN->getResult(), RI->getResult(), |
441 | RG ? RG->getResult() : OutputG, |
442 | /*cumulative*/ false); |
443 | toAppend.push_back(SA); |
444 | map.addGradient(Data, SA); |
445 | continue; |
446 | } |
447 | |
448 | llvm_unreachable("Invalid instruction type." ); |
449 | } // End of the for-each instr loop. |
450 | |
451 | for (auto N : nodes) { |
452 | // Iterate only through Placeholders used by the Function. These are |
453 | // inserted during the post-order walk. |
454 | Placeholder *PH = llvm::dyn_cast<Placeholder>(N); |
455 | if (!PH) |
456 | continue; |
457 | |
458 | // In this special differentiation mode we record the last gradient value |
459 | // without performing the SGD update. This mode is used by the unit tests. |
460 | if (varGrads) { |
461 | if (map.hasGradient(PH)) { |
462 | std::string nodeName = "_grad_" + PH->getName().str(); |
463 | // Save the gradient and return the destination variable. |
464 | auto *saveNode = G->createSave(nodeName, map.getGradient(PH)); |
465 | Placeholder *GradV = saveNode->getPlaceholder(); |
466 | varGrads->push_back({PH, GradV}); |
467 | } |
468 | continue; |
469 | } |
470 | |
471 | // Don't update nodes that are not marked as trainable. |
472 | if (!PH->isTraining()) { |
473 | continue; |
474 | } |
475 | |
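    // Emit an SGD node that folds the accumulated gradient into the weight
    // using the configured learning rate, momentum, weight decay, and batch
    // size.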
    auto *X = new SGDNode(PH->getName(), map.getGradient(PH), PH, conf.L1Decay,
477 | conf.L2Decay, conf.learningRate, conf.momentum, |
478 | conf.batchSize); |
479 | toAppend.push_back(X); |
480 | // Now update the weight with the value computed by SGD. |
481 | auto *save = |
        new SaveNode(DECORATE_NODE_NAME(PH, "save", "grad"), {X, 0}, PH);
483 | toAppend.push_back(save); |
484 | } |
485 | |
  // Add all of the newly created gradient nodes to the function.
487 | for (auto &I : toAppend) { |
488 | G->addNode(I); |
489 | } |
490 | |
491 | return G; |
492 | } |
493 | |
494 | #undef DECORATE_NODE_NAME |
495 | |