1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "glow/Optimizer/GraphOptimizer/TrainingPreparation.h"
18
19#include "glow/Base/Tensor.h"
20#include "glow/Graph/PlaceholderBindings.h"
21
22namespace glow {
23
24namespace {
25void defaultTensorInitializer(Function *F, Node *node, unsigned inputIdx,
26 Tensor *tensor) {
27 switch (node->getKind()) {
28 case Kinded::Kind::ConvolutionNodeKind: {
29 if (ConvolutionNode::FilterIdx == inputIdx) {
30 ConvolutionNode *CN = llvm::cast<ConvolutionNode>(node);
31 ShapeNHWC idim = ShapeNHWC(CN->getInput().dims());
32 ShapeHW kdim(CN->getKernels());
33 size_t fanIn = kdim.height * kdim.width * idim.c;
34 tensor->init(Tensor::InitKind::Xavier, fanIn, F->getPRNG());
35 } else if (ConvolutionNode::BiasIdx == inputIdx) {
36 tensor->init(Tensor::InitKind::Broadcast, 0.1, F->getPRNG());
37 }
38 break;
39 }
40 case Kinded::Kind::BatchNormalizationNodeKind: {
41 if (BatchNormalizationNode::ScaleIdx == inputIdx) {
42 tensor->init(Tensor::InitKind::Zero, 0, F->getPRNG());
43 } else if (BatchNormalizationNode::BiasIdx == inputIdx) {
44 tensor->init(Tensor::InitKind::Broadcast, 0.1, F->getPRNG());
45 } else if (BatchNormalizationNode::MeanIdx == inputIdx) {
46 tensor->init(Tensor::InitKind::Zero, 0, F->getPRNG());
47 } else if (BatchNormalizationNode::VarIdx == inputIdx) {
48 tensor->init(Tensor::InitKind::Broadcast, 1.0, F->getPRNG());
49 }
50 break;
51 }
52 case Kinded::Kind::FullyConnectedNodeKind: {
53 if (FullyConnectedNode::WeightsIdx == inputIdx) {
54 FullyConnectedNode *FCN = llvm::cast<FullyConnectedNode>(node);
55 auto in = FCN->getInput();
56 tensor->init(Tensor::InitKind::Xavier, in.dims()[1], F->getPRNG());
57 } else if (FullyConnectedNode::BiasIdx == inputIdx) {
58 tensor->init(Tensor::InitKind::Broadcast, 0.1, F->getPRNG());
59 }
60 break;
61 }
62 case Kinded::Kind::SoftMaxNodeKind: {
63 if (SoftMaxNode::SelectedIdx == inputIdx) {
64 tensor->zero();
65 }
66 break;
67 }
68 default:
69 break;
70 }
71}
72} // namespace
73
74TensorInitializer getDefaultTensorInitializer() {
75 return defaultTensorInitializer;
76}
77
78Error prepareFunctionForTraining(Function *F, PlaceholderBindings &bindings,
79 Placeholder *&selected,
80 TensorInitializer &&initializer) {
81
82 auto &nodes = F->getNodes();
83
84 selected = nullptr;
85 // Lookup all nodes, skip Storage types, enumerate inputs,
86 // replace Constant type with trainable Placeholders except special cases,
87 // like BatchNormalization inputs (mean and variance). In special cases
88 // replace Constant type with non-trainable Placeholders.
89 for (auto &node : nodes) {
90 // Skip storages.
91 if (llvm::isa<Storage>(&node)) {
92 continue;
93 }
94
95 const bool isSoftMax = node.getKind() == Kinded::Kind::SoftMaxNodeKind;
96 const bool isBatchNormalization =
97 node.getKind() == Kinded::Kind::BatchNormalizationNodeKind;
98
99 for (unsigned idx = 0, e = node.getNumInputs(); idx < e; ++idx) {
100 auto *IN = node.getNthInput(idx).getNode();
101 Constant *C = llvm::dyn_cast<Constant>(IN);
102 if (!C) {
103 continue;
104 }
105
106 // Condition for NON trainable case
107 // isSoftMax || isBatchNormalization &&
108 // (BatchNormalizationNode::MeanIdx == idx ||
109 // BatchNormalizationNode::VarIdx == idx)
110
111 const bool isTrainable =
112 !isSoftMax &&
113 (!isBatchNormalization || (BatchNormalizationNode::MeanIdx != idx &&
114 BatchNormalizationNode::VarIdx != idx));
115
116 auto *PH = F->getParent()->createPlaceholder(C->getType(), C->getName(),
117 isTrainable);
118
119 if (isSoftMax) {
120 selected = PH;
121 }
122 C->getOutput().replaceAllUsesOfWith(PH, F);
123 auto &tensor = C->getPayloadMutable();
124 initializer(F, &node, idx, &tensor);
125 bindings.insert(PH, std::move(tensor));
126 RETURN_ERR_IF_NOT(!C->hasUsers(), "Constant is still in use.");
127 F->getParent()->eraseConstant(C);
128 }
129 }
130
131 return Error::success();
132}
133} // namespace glow
134