1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16#include "glow/Converter/FunctionConverter.h"
17
18#include "glow/Graph/Graph.h" // For Function.
19#include "glow/Graph/Node.h" // For Node.
20#include "glow/Graph/Nodes.h" // For Placeholder and Constant.
21#include "glow/Graph/PlaceholderBindings.h"
22
23#include "llvm/ADT/DenseMap.h"
24
25using namespace glow;
26
27TypeRef
28FunctionConverter::getTargetTypeForOutput(const NodeValue &nodeVal) const {
29 // Default implementation says there is nothing to do.
30 return nullptr;
31}
32
33TypeRef FunctionConverter::getTargetTypeForInput(const Node &use,
34 unsigned idx) const {
35 // Default implementation says there is nothing to do.
36 return nullptr;
37}
38
39bool FunctionConverter::canConvert(const Node &node) const {
40 // By default, we assume everything is convertible.
41 switch (node.getKind()) {
42 default:
43 return true;
44 case Kinded::Kind::PlaceholderKind:
45 case Kinded::Kind::SaveNodeKind:
46 // Save node or placeholder special because
47 // they are or their effects are visible from
48 // the outside of the function being converted.
49 // Thus, we cannot convert them, unless we change
50 // the semantic of this function and the related
51 // placeholder.
52 return false;
53 }
54}
55
56NodeValue FunctionConverter::getConversionOutput(Node &conversion) const {
57 assert(conversion.getNumResults() == 1 && "This method should be overloaded");
58 return NodeValue(&conversion, 0);
59}
60
61Node &FunctionConverter::morphNode(Node &node) { return node; }
62
63void FunctionConverter::postProcessing(Node &node) {}
64
65void FunctionConverter::convertOutputs(Node &node) {
66 using FunctionAndValIdx = std::pair<Function *, unsigned>;
67 llvm::DenseMap<FunctionAndValIdx, NodeValue> functionAndValToConversion;
68 for (unsigned idx = 0, end = node.getNumResults(); idx != end; ++idx) {
69 NodeValue val = node.getNthResult(idx);
70 TypeRef targetTy = getTargetTypeForOutput(val);
71 if (!targetTy || targetTy == val.getType()) {
72 continue;
73 }
74 // convert the node and create a conversion to keep the users happy.
75 assert(targetTy->dims() == val.getType()->dims() &&
76 "Conversion does not preserve shape");
77 TypeRef origTy = val.getType();
78 // Fake the morphing of the node so that the creation
79 // of the conversion works properly.
80 val.setType(targetTy);
81 // Store the users in a temporary object because setOperand
82 // will invalidate the iterator.
83 llvm::SmallVector<NodeUse, 4> users(val.getUsers().begin(),
84 val.getUsers().end());
85 // We cannot use replaceAllUsesWith here because:
86 // 1. At this point, val and conversion don't have the same type
87 // (one is converted the other is the original type), and that
88 // would trigger an assertion.
89 // 2. We would end up replacing the use of val in "conversion" by
90 // "conversion".
91 // 3. Node may be a module level value and we need one conversion per
92 // function.
93 for (auto use : users) {
94 Node *user = use.getUser();
95 Function *parent = user->getParent();
96 assert(parent && "User not in a function?!");
97
98 SaveNode *saveNode = llvm::dyn_cast<SaveNode>(user);
99 // The output of save nodes is special because it doesn't use
100 // the value of the node, but its address.
101 // Thus, if we want to change the value of the output of
102 // a save node, we actually have to convert the input.
103 if (saveNode && saveNode->getOutput() == val) {
104 NodeValue input = saveNode->getInput();
105 Node *conversion = createConversion(*parent, node, input, targetTy,
106 /* isInput */ false);
107 saveNode->setNthInput(SaveNode::InputIdx,
108 getConversionOutput(*conversion));
109 continue;
110 }
111
112 FunctionAndValIdx functionAndVal = std::make_pair(parent, idx);
113 auto conversionValIt = functionAndValToConversion.find(functionAndVal);
114 if (conversionValIt == functionAndValToConversion.end()) {
115 // Create the conversion.
116 Node *conversion =
117 createConversion(*parent, node, val, origTy, /* isInput */ false);
118 // "conversion" uses val so after this call,
119 // we will get a use of conversion inside conversion.
120 NodeValue conversionVal = getConversionOutput(*conversion);
121 auto insertion =
122 functionAndValToConversion.insert({functionAndVal, conversionVal});
123 assert(insertion.second && "Conversion already there?!");
124 conversionValIt = insertion.first;
125 }
126
127 NodeValue conversionVal = conversionValIt->second;
128 if (user == conversionVal.getNode()) {
129 continue;
130 }
131 // Log the change of node input(operand).
132 if (Function *F = node.getParent()) {
133 F->getLogContext()->logNodeInputChange(*user, *(use.get()),
134 conversionVal);
135 }
136
137 use.get()->setOperand(conversionVal.getNode(), conversionVal.getResNo());
138 }
139 }
140}
141
142void FunctionConverter::convertInputs(Node &node) {
143 // We shouldn't have to convert the inputs of something that is not in
144 // function_.
145 assert((node.getNumInputs() == 0 || node.getParent() == &function_) &&
146 "Invalid requested conversion");
147 for (unsigned idx = 0, end = node.getNumInputs(); idx != end; ++idx) {
148 NodeValue val = node.getNthInput(idx);
149 TypeRef targetTy = getTargetTypeForInput(node, idx);
150 if (!targetTy || targetTy == val.getType()) {
151 continue;
152 }
153 // convert the node and create a conversion to keep the users happy.
154 assert(targetTy->dims() == val.getType()->dims() &&
155 "Conversion does not preserve shape");
156 // Create the conversion.
157 Node *conversion =
158 createConversion(function_, node, val, targetTy, /* isInput */ true);
159 node.setNthInput(idx, getConversionOutput(*conversion));
160 }
161}
162
163void FunctionConverter::convert() {
164 assert(function_.verify() && "Input function must be valid");
165
166 // Traverse all nodes.
167 // Check what the conversion should look like, if any.
168 // Convert the node appropriately.
169
170 // For every unprocessed node in the graph we keep the invariant of having
171 // all inputs to be of the uncovered type.
172 // I.e., if we have:
173 // res(outTy) = node arg1(in2Ty), arg2(in2Ty)
174 //
175 // after converting "node", we will have something that looks like:
176 // newArg1(convertedIn1Ty) = conversion arg1
177 // newArg2(convertedIn2Ty) = conversion arg2
178 // newRes(convertedOutTy) = node newArg1, newArg2
179 // res(outTy) = conversion newRes
180 //
181 // In other words, the boundaries (in and out) are unchanged.
182
183 // The iterator looks weird because we only want to iterate through
184 // the original nodes.
185 auto nodeIt = function_.getNodes().end();
186 auto stopIt = function_.getNodes().begin();
187 do {
188 --nodeIt;
189 Node &node = *nodeIt;
190 if (!canConvert(node)) {
191 continue;
192 }
193 // Mutate the output types and insert the conversion to keep our
194 // invariant.
195 convertOutputs(node);
196 // Convert the inputs of the node.
197 convertInputs(node);
198 // All the surrounding code is properly typed, finally the morph node.
199 Node &morphedNode = morphNode(node);
200 // Do some post processing if need be.
201 postProcessing(morphedNode);
202 } while (nodeIt != stopIt);
203
204 // Allow a late clean-up before verifying the conversation produced a valid
205 // function.
206 cleanUp();
207
208 assert(function_.verify() && "Conversion led to invalid function");
209}
210
211void FunctionConverter::convertPlaceholder(Placeholder &placeholder,
212 PlaceholderBindings *bindings) {
213 TypeRef destTy = getTargetTypeForOutput(placeholder.getOutput());
214 if (!destTy || destTy == placeholder.getType()) {
215 return;
216 }
217 convertOutputs(placeholder);
218 if (!bindings) {
219 return;
220 }
221 Tensor *tensor = bindings->get(&placeholder);
222 if (tensor) {
223 convertTensor(*tensor, destTy);
224 }
225}
226