1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Optimizer/GraphOptimizer/TrainingPreparation.h" |
18 | |
19 | #include "glow/Base/Tensor.h" |
20 | #include "glow/Graph/PlaceholderBindings.h" |
21 | |
22 | namespace glow { |
23 | |
24 | namespace { |
25 | void defaultTensorInitializer(Function *F, Node *node, unsigned inputIdx, |
26 | Tensor *tensor) { |
27 | switch (node->getKind()) { |
28 | case Kinded::Kind::ConvolutionNodeKind: { |
29 | if (ConvolutionNode::FilterIdx == inputIdx) { |
30 | ConvolutionNode *CN = llvm::cast<ConvolutionNode>(node); |
31 | ShapeNHWC idim = ShapeNHWC(CN->getInput().dims()); |
32 | ShapeHW kdim(CN->getKernels()); |
33 | size_t fanIn = kdim.height * kdim.width * idim.c; |
34 | tensor->init(Tensor::InitKind::Xavier, fanIn, F->getPRNG()); |
35 | } else if (ConvolutionNode::BiasIdx == inputIdx) { |
36 | tensor->init(Tensor::InitKind::Broadcast, 0.1, F->getPRNG()); |
37 | } |
38 | break; |
39 | } |
40 | case Kinded::Kind::BatchNormalizationNodeKind: { |
41 | if (BatchNormalizationNode::ScaleIdx == inputIdx) { |
42 | tensor->init(Tensor::InitKind::Zero, 0, F->getPRNG()); |
43 | } else if (BatchNormalizationNode::BiasIdx == inputIdx) { |
44 | tensor->init(Tensor::InitKind::Broadcast, 0.1, F->getPRNG()); |
45 | } else if (BatchNormalizationNode::MeanIdx == inputIdx) { |
46 | tensor->init(Tensor::InitKind::Zero, 0, F->getPRNG()); |
47 | } else if (BatchNormalizationNode::VarIdx == inputIdx) { |
48 | tensor->init(Tensor::InitKind::Broadcast, 1.0, F->getPRNG()); |
49 | } |
50 | break; |
51 | } |
52 | case Kinded::Kind::FullyConnectedNodeKind: { |
53 | if (FullyConnectedNode::WeightsIdx == inputIdx) { |
54 | FullyConnectedNode *FCN = llvm::cast<FullyConnectedNode>(node); |
55 | auto in = FCN->getInput(); |
56 | tensor->init(Tensor::InitKind::Xavier, in.dims()[1], F->getPRNG()); |
57 | } else if (FullyConnectedNode::BiasIdx == inputIdx) { |
58 | tensor->init(Tensor::InitKind::Broadcast, 0.1, F->getPRNG()); |
59 | } |
60 | break; |
61 | } |
62 | case Kinded::Kind::SoftMaxNodeKind: { |
63 | if (SoftMaxNode::SelectedIdx == inputIdx) { |
64 | tensor->zero(); |
65 | } |
66 | break; |
67 | } |
68 | default: |
69 | break; |
70 | } |
71 | } |
72 | } // namespace |
73 | |
74 | TensorInitializer getDefaultTensorInitializer() { |
75 | return defaultTensorInitializer; |
76 | } |
77 | |
78 | Error prepareFunctionForTraining(Function *F, PlaceholderBindings &bindings, |
79 | Placeholder *&selected, |
80 | TensorInitializer &&initializer) { |
81 | |
82 | auto &nodes = F->getNodes(); |
83 | |
84 | selected = nullptr; |
85 | // Lookup all nodes, skip Storage types, enumerate inputs, |
86 | // replace Constant type with trainable Placeholders except special cases, |
87 | // like BatchNormalization inputs (mean and variance). In special cases |
88 | // replace Constant type with non-trainable Placeholders. |
89 | for (auto &node : nodes) { |
90 | // Skip storages. |
91 | if (llvm::isa<Storage>(&node)) { |
92 | continue; |
93 | } |
94 | |
95 | const bool isSoftMax = node.getKind() == Kinded::Kind::SoftMaxNodeKind; |
96 | const bool isBatchNormalization = |
97 | node.getKind() == Kinded::Kind::BatchNormalizationNodeKind; |
98 | |
99 | for (unsigned idx = 0, e = node.getNumInputs(); idx < e; ++idx) { |
100 | auto *IN = node.getNthInput(idx).getNode(); |
101 | Constant *C = llvm::dyn_cast<Constant>(IN); |
102 | if (!C) { |
103 | continue; |
104 | } |
105 | |
106 | // Condition for NON trainable case |
107 | // isSoftMax || isBatchNormalization && |
108 | // (BatchNormalizationNode::MeanIdx == idx || |
109 | // BatchNormalizationNode::VarIdx == idx) |
110 | |
111 | const bool isTrainable = |
112 | !isSoftMax && |
113 | (!isBatchNormalization || (BatchNormalizationNode::MeanIdx != idx && |
114 | BatchNormalizationNode::VarIdx != idx)); |
115 | |
116 | auto *PH = F->getParent()->createPlaceholder(C->getType(), C->getName(), |
117 | isTrainable); |
118 | |
119 | if (isSoftMax) { |
120 | selected = PH; |
121 | } |
122 | C->getOutput().replaceAllUsesOfWith(PH, F); |
123 | auto &tensor = C->getPayloadMutable(); |
124 | initializer(F, &node, idx, &tensor); |
125 | bindings.insert(PH, std::move(tensor)); |
126 | RETURN_ERR_IF_NOT(!C->hasUsers(), "Constant is still in use." ); |
127 | F->getParent()->eraseConstant(C); |
128 | } |
129 | } |
130 | |
131 | return Error::success(); |
132 | } |
133 | } // namespace glow |
134 | |