1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | #include "glow/Converter/FunctionConverter.h" |
17 | |
18 | #include "glow/Graph/Graph.h" // For Function. |
19 | #include "glow/Graph/Node.h" // For Node. |
20 | #include "glow/Graph/Nodes.h" // For Placeholder and Constant. |
21 | #include "glow/Graph/PlaceholderBindings.h" |
22 | |
23 | #include "llvm/ADT/DenseMap.h" |
24 | |
25 | using namespace glow; |
26 | |
27 | TypeRef |
28 | FunctionConverter::getTargetTypeForOutput(const NodeValue &nodeVal) const { |
29 | // Default implementation says there is nothing to do. |
30 | return nullptr; |
31 | } |
32 | |
33 | TypeRef FunctionConverter::getTargetTypeForInput(const Node &use, |
34 | unsigned idx) const { |
35 | // Default implementation says there is nothing to do. |
36 | return nullptr; |
37 | } |
38 | |
39 | bool FunctionConverter::canConvert(const Node &node) const { |
40 | // By default, we assume everything is convertible. |
41 | switch (node.getKind()) { |
42 | default: |
43 | return true; |
44 | case Kinded::Kind::PlaceholderKind: |
45 | case Kinded::Kind::SaveNodeKind: |
46 | // Save node or placeholder special because |
47 | // they are or their effects are visible from |
48 | // the outside of the function being converted. |
49 | // Thus, we cannot convert them, unless we change |
50 | // the semantic of this function and the related |
51 | // placeholder. |
52 | return false; |
53 | } |
54 | } |
55 | |
56 | NodeValue FunctionConverter::getConversionOutput(Node &conversion) const { |
57 | assert(conversion.getNumResults() == 1 && "This method should be overloaded" ); |
58 | return NodeValue(&conversion, 0); |
59 | } |
60 | |
61 | Node &FunctionConverter::morphNode(Node &node) { return node; } |
62 | |
63 | void FunctionConverter::postProcessing(Node &node) {} |
64 | |
65 | void FunctionConverter::convertOutputs(Node &node) { |
66 | using FunctionAndValIdx = std::pair<Function *, unsigned>; |
67 | llvm::DenseMap<FunctionAndValIdx, NodeValue> functionAndValToConversion; |
68 | for (unsigned idx = 0, end = node.getNumResults(); idx != end; ++idx) { |
69 | NodeValue val = node.getNthResult(idx); |
70 | TypeRef targetTy = getTargetTypeForOutput(val); |
71 | if (!targetTy || targetTy == val.getType()) { |
72 | continue; |
73 | } |
74 | // convert the node and create a conversion to keep the users happy. |
75 | assert(targetTy->dims() == val.getType()->dims() && |
76 | "Conversion does not preserve shape" ); |
77 | TypeRef origTy = val.getType(); |
78 | // Fake the morphing of the node so that the creation |
79 | // of the conversion works properly. |
80 | val.setType(targetTy); |
81 | // Store the users in a temporary object because setOperand |
82 | // will invalidate the iterator. |
83 | llvm::SmallVector<NodeUse, 4> users(val.getUsers().begin(), |
84 | val.getUsers().end()); |
85 | // We cannot use replaceAllUsesWith here because: |
86 | // 1. At this point, val and conversion don't have the same type |
87 | // (one is converted the other is the original type), and that |
88 | // would trigger an assertion. |
89 | // 2. We would end up replacing the use of val in "conversion" by |
90 | // "conversion". |
91 | // 3. Node may be a module level value and we need one conversion per |
92 | // function. |
93 | for (auto use : users) { |
94 | Node *user = use.getUser(); |
95 | Function *parent = user->getParent(); |
96 | assert(parent && "User not in a function?!" ); |
97 | |
98 | SaveNode *saveNode = llvm::dyn_cast<SaveNode>(user); |
99 | // The output of save nodes is special because it doesn't use |
100 | // the value of the node, but its address. |
101 | // Thus, if we want to change the value of the output of |
102 | // a save node, we actually have to convert the input. |
103 | if (saveNode && saveNode->getOutput() == val) { |
104 | NodeValue input = saveNode->getInput(); |
105 | Node *conversion = createConversion(*parent, node, input, targetTy, |
106 | /* isInput */ false); |
107 | saveNode->setNthInput(SaveNode::InputIdx, |
108 | getConversionOutput(*conversion)); |
109 | continue; |
110 | } |
111 | |
112 | FunctionAndValIdx functionAndVal = std::make_pair(parent, idx); |
113 | auto conversionValIt = functionAndValToConversion.find(functionAndVal); |
114 | if (conversionValIt == functionAndValToConversion.end()) { |
115 | // Create the conversion. |
116 | Node *conversion = |
117 | createConversion(*parent, node, val, origTy, /* isInput */ false); |
118 | // "conversion" uses val so after this call, |
119 | // we will get a use of conversion inside conversion. |
120 | NodeValue conversionVal = getConversionOutput(*conversion); |
121 | auto insertion = |
122 | functionAndValToConversion.insert({functionAndVal, conversionVal}); |
123 | assert(insertion.second && "Conversion already there?!" ); |
124 | conversionValIt = insertion.first; |
125 | } |
126 | |
127 | NodeValue conversionVal = conversionValIt->second; |
128 | if (user == conversionVal.getNode()) { |
129 | continue; |
130 | } |
131 | // Log the change of node input(operand). |
132 | if (Function *F = node.getParent()) { |
133 | F->getLogContext()->logNodeInputChange(*user, *(use.get()), |
134 | conversionVal); |
135 | } |
136 | |
137 | use.get()->setOperand(conversionVal.getNode(), conversionVal.getResNo()); |
138 | } |
139 | } |
140 | } |
141 | |
142 | void FunctionConverter::convertInputs(Node &node) { |
143 | // We shouldn't have to convert the inputs of something that is not in |
144 | // function_. |
145 | assert((node.getNumInputs() == 0 || node.getParent() == &function_) && |
146 | "Invalid requested conversion" ); |
147 | for (unsigned idx = 0, end = node.getNumInputs(); idx != end; ++idx) { |
148 | NodeValue val = node.getNthInput(idx); |
149 | TypeRef targetTy = getTargetTypeForInput(node, idx); |
150 | if (!targetTy || targetTy == val.getType()) { |
151 | continue; |
152 | } |
153 | // convert the node and create a conversion to keep the users happy. |
154 | assert(targetTy->dims() == val.getType()->dims() && |
155 | "Conversion does not preserve shape" ); |
156 | // Create the conversion. |
157 | Node *conversion = |
158 | createConversion(function_, node, val, targetTy, /* isInput */ true); |
159 | node.setNthInput(idx, getConversionOutput(*conversion)); |
160 | } |
161 | } |
162 | |
163 | void FunctionConverter::convert() { |
164 | assert(function_.verify() && "Input function must be valid" ); |
165 | |
166 | // Traverse all nodes. |
167 | // Check what the conversion should look like, if any. |
168 | // Convert the node appropriately. |
169 | |
170 | // For every unprocessed node in the graph we keep the invariant of having |
171 | // all inputs to be of the uncovered type. |
172 | // I.e., if we have: |
173 | // res(outTy) = node arg1(in2Ty), arg2(in2Ty) |
174 | // |
175 | // after converting "node", we will have something that looks like: |
176 | // newArg1(convertedIn1Ty) = conversion arg1 |
177 | // newArg2(convertedIn2Ty) = conversion arg2 |
178 | // newRes(convertedOutTy) = node newArg1, newArg2 |
179 | // res(outTy) = conversion newRes |
180 | // |
181 | // In other words, the boundaries (in and out) are unchanged. |
182 | |
183 | // The iterator looks weird because we only want to iterate through |
184 | // the original nodes. |
185 | auto nodeIt = function_.getNodes().end(); |
186 | auto stopIt = function_.getNodes().begin(); |
187 | do { |
188 | --nodeIt; |
189 | Node &node = *nodeIt; |
190 | if (!canConvert(node)) { |
191 | continue; |
192 | } |
193 | // Mutate the output types and insert the conversion to keep our |
194 | // invariant. |
195 | convertOutputs(node); |
196 | // Convert the inputs of the node. |
197 | convertInputs(node); |
198 | // All the surrounding code is properly typed, finally the morph node. |
199 | Node &morphedNode = morphNode(node); |
200 | // Do some post processing if need be. |
201 | postProcessing(morphedNode); |
202 | } while (nodeIt != stopIt); |
203 | |
204 | // Allow a late clean-up before verifying the conversation produced a valid |
205 | // function. |
206 | cleanUp(); |
207 | |
208 | assert(function_.verify() && "Conversion led to invalid function" ); |
209 | } |
210 | |
211 | void FunctionConverter::convertPlaceholder(Placeholder &placeholder, |
212 | PlaceholderBindings *bindings) { |
213 | TypeRef destTy = getTargetTypeForOutput(placeholder.getOutput()); |
214 | if (!destTy || destTy == placeholder.getType()) { |
215 | return; |
216 | } |
217 | convertOutputs(placeholder); |
218 | if (!bindings) { |
219 | return; |
220 | } |
221 | Tensor *tensor = bindings->get(&placeholder); |
222 | if (tensor) { |
223 | convertTensor(*tensor, destTy); |
224 | } |
225 | } |
226 | |