1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h" |
18 | |
19 | #include "glow/Backend/Backend.h" |
20 | #include "glow/Backends/Interpreter/Interpreter.h" |
21 | #include "glow/Graph/Graph.h" |
22 | #include "glow/Graph/Log.h" |
23 | #include "glow/Graph/Node.h" |
24 | #include "glow/Graph/Nodes.h" |
25 | #include "glow/Graph/PlaceholderBindings.h" |
26 | #include "glow/Graph/TensorLayout.h" |
27 | #include "glow/Graph/Utils.h" |
28 | #include "glow/Optimizer/GraphOptimizer/FunctionPasses.h" |
29 | |
30 | #include "llvm/Support/Casting.h" |
31 | |
32 | using namespace glow; |
33 | using llvm::cast; |
34 | using llvm::dyn_cast; |
35 | using llvm::isa; |
36 | |
37 | namespace { |
/// Common prefix of the names given to the temporary functions that are
/// created to evaluate constant operations at compile time.
constexpr const char *constEvaluationFunctionName{
    "__constEvaluationFunction__"};
41 | |
42 | /// \returns true if a node \p N is a constant operation, i.e. it is a trivial |
43 | /// constant like Constant or Splat or all of its inputs are recursively |
44 | /// constant operations, and it has no side-effects and supported by the \p |
45 | /// backend. If \p enableQuantizeConstFolding then QuantizeNodes are considered |
46 | /// a valid constant operation to fold. |
47 | bool isConstantOperation(const Node *N, const Backend &backend, |
48 | bool enableQuantizeConstFolding) { |
49 | // An operation with side-effects cannot be computed at compile-time. |
50 | if (N->hasSideEffects()) { |
51 | return false; |
52 | } |
53 | // Quantize nodes are not handled by ConstantFolding but by a specific |
54 | // quantization specific optimization. |
55 | if (!enableQuantizeConstFolding && isa<QuantizeNode>(N)) { |
56 | return false; |
57 | } |
58 | // Constant and splat nodes are trivially constant operations. |
59 | if (isa<Constant>(N) || isa<SplatNode>(N)) { |
60 | return true; |
61 | } |
62 | // If the node is backend specific, we cannot safely do constant folding using |
63 | // the interpreter. |
64 | if (!N->isCanonical()) { |
65 | return false; |
66 | } |
67 | // If the operation is not supported by the backend, it cannot be computed at |
68 | // compile-time. |
69 | if (!backend.shouldLower(N) && !backend.isOpSupported(NodeInfo(*N))) { |
70 | return false; |
71 | } |
72 | if (isa<Placeholder>(N)) { |
73 | return false; |
74 | } |
75 | for (size_t idx = 0, e = N->getNumInputs(); idx < e; ++idx) { |
76 | auto input = N->getNthInput(idx); |
77 | if (!isConstantOperation(input.getNode(), backend, |
78 | enableQuantizeConstFolding)) { |
79 | return false; |
80 | } |
81 | } |
82 | return true; |
83 | } |
84 | |
85 | /// \returns true if node \p N has at least one non-constant operation user. |
86 | /// \p backend and \p enableQuantizeConstFolding are used to determine what is |
87 | /// valid for folding. |
88 | bool hasNonConstantOperationUser(const Node *N, const Backend &backend, |
89 | bool enableQuantizeConstFolding) { |
90 | assert(isConstantOperation(N, backend, enableQuantizeConstFolding) && |
91 | "Expected constant operation" ); |
92 | for (auto &use : N->getUsers()) { |
93 | auto *user = use.getUser(); |
94 | // Only consider users in the current function. |
95 | if (user->getParent() != N->getParent()) { |
96 | continue; |
97 | } |
98 | if (!isConstantOperation(user, backend, enableQuantizeConstFolding)) { |
99 | return true; |
100 | } |
101 | } |
102 | return false; |
103 | } |
104 | |
/// Compile the function \p F for the provided \p backend using the compilation
/// context \p cctx. \p F is first run through ::glow::optimizeFunction (which
/// may modify it) and is then handed to the backend's own compile method
/// together with the backend options stored in \p cctx.
/// \returns compiled function, or the error produced while optimizing or
/// compiling.
Expected<std::unique_ptr<CompiledFunction>>
compile(Backend &backend, Function &F, CompilationContext &cctx) {
  // Propagate an optimization failure to the caller instead of compiling.
  RETURN_IF_ERR(::glow::optimizeFunction(&F, backend, cctx));
  return backend.compile(&F, cctx.backendOpts);
}
113 | |
114 | /// Runs the compiled function \p compiledF on the \p backend using provided \p |
115 | /// bindings. |
116 | Error run(Backend &backend, CompiledFunction &compiledF, |
117 | PlaceholderBindings &bindings) { |
118 | std::unique_ptr<PlaceholderBindings> bindingsPtr(&bindings); |
119 | ExecutionContext context(std::move(bindingsPtr)); |
120 | // TODO: Add only constants used by F to the compiled function. This should |
121 | // reduce the amount of data that needs to be copied. |
122 | auto executeErr = compiledF.execute(&context); |
123 | RETURN_IF_ERR(std::move(executeErr)); |
124 | // Don't delete bindings. |
125 | context.movePlaceholderBindings().release(); |
126 | return Error::success(); |
127 | } |
128 | |
129 | static bool isCanonicalLayout(const NodeValue &RN, Backend &backend, |
130 | Node *clonedC, size_t idx) { |
131 | auto resultLayoutStr = |
132 | backend.getTensorLayoutRequirements().getNthResultLayoutRequirements( |
133 | clonedC, idx); |
134 | auto resultLayout = TensorLayoutDescription(resultLayoutStr); |
135 | auto &canInstance = CanonicalTensorLayout::getInstance(); |
136 | auto default4DStr = canInstance.getDefaultNDLayout(4); |
137 | auto default4D = TensorLayoutDescription(default4DStr); |
138 | if (resultLayout.getDims().size() == 4 && |
139 | !canInstance.isSatisfiedBy(RN.getType(), default4D, &resultLayout)) { |
140 | return false; |
141 | } |
142 | return true; |
143 | } |
144 | |
145 | // Bail on constant folding post-lowering for backends that break assumptions. |
146 | static void bailOnNonCanonicalLayout( |
147 | Function *constEvaluationF, Module &mod, |
148 | const llvm::SmallVectorImpl<SaveNode *> &savedResults) { |
149 | // Some results may be in a non-canonical format post-lowering. |
150 | // For example, if we are trying to constant fold an OpenCL 'Reshape' that |
151 | // has NCHW layout. We cannot transpose it back to canonical layout for |
152 | // two reasons: 1) Need to add a solver that supports weird non-NCHW2NHWC |
153 | // backends. 2) Even if we get a constant tensor as a new "save" of the |
154 | // transpose, the new constant tensor will have the wrong shape. We'd |
155 | // actually need to transpose it back to its pre-modification shape. These |
156 | // issues may be solved in the future (TODO), for now bail on such corner |
157 | // cases. Clean-up before bailing: |
158 | for (auto *SN : savedResults) { |
159 | // Now erase the Placeholder that we created for the SaveNode. |
160 | auto &vars = mod.getPlaceholders(); |
161 | mod.erasePlaceholder( |
162 | std::find(vars.begin(), vars.end(), SN->getPlaceholder())); |
163 | } |
164 | mod.eraseFunction(constEvaluationF); |
165 | } |
166 | |
167 | /// \returns whether \p N should be folded based on \p cctx's |
168 | /// optimizationOpts.materializeSplatsUsedBySet.count, where the backend may |
169 | /// specify what Splats should be materialized into Constants based on if |
170 | /// they're used by other op kinds. |
171 | static bool isSplatToFold(SplatNode *N, const CompilationContext &cctx) { |
172 | for (const auto &U : N->getUsers()) { |
173 | if (cctx.optimizationOpts.materializeSplatsUsedBySet.count( |
174 | U.getUser()->getKind())) { |
175 | return true; |
176 | } |
177 | } |
178 | return false; |
179 | } |
180 | |
/// Use to make sure we don't reuse the same name for const fold Functions.
/// The current value is embedded in the name of each temporary const fold
/// Function and the counter is incremented every time such a name is built.
/// NOTE(review): this is a plain, non-atomic static; presumably constant
/// folding is never run from several threads at once -- confirm.
static uint64_t numFolds = 0;
183 | |
184 | /// Evaluates a provided constant operation \p C using the provided \p backend |
185 | /// and using the compilation context \p cctx. If \p record is not a nullptr |
186 | /// then the Constant created is added to the map, pointing to the SaveNode that |
187 | /// generated that Constant. Additionally if \p record is not a nullptr then the |
188 | /// constEvaluationF its associated Placeholders for saving results will not be |
189 | /// deleted, and the caller is responsble for deleting them if necessary. |
190 | /// \returns constant results. If \p foldSingleSplats then single splat |
191 | /// subgraphs will be forced to fold. |
192 | bool evaluateConstantOperation(Backend &backend, CompilationContext &cctx, |
193 | Node *C, std::vector<Constant *> &constResults, |
194 | ConstantFoldingRecordMap *record, |
195 | bool foldSingleSplats) { |
196 | // Allow for quantize folding when we have a const folding record. |
197 | const bool enableQuantizeConstFolding = record != nullptr; |
198 | PlaceholderBindings bindings; |
199 | assert(isConstantOperation(C, backend, enableQuantizeConstFolding) && |
200 | "Expected a constant expression" ); |
201 | // Constants and splats do not need to be constant evaluated. |
202 | if (isa<Constant>(C) || (isa<SplatNode>(C) && !foldSingleSplats && |
203 | !isSplatToFold(cast<SplatNode>(C), cctx))) { |
204 | return true; |
205 | } |
206 | Module &mod = *C->getParent()->getParent(); |
207 | const std::string funName = std::string(constEvaluationFunctionName) + |
208 | std::to_string(numFolds++) + "__" + |
209 | C->getName().data(); |
210 | // Create a temporary function to perform the constant operation. |
211 | Function *constEvaluationF = mod.createFunction(funName); |
212 | // Mapping from existing nodes to the new ones. |
213 | NodeMap currToNew; |
214 | // Clone the constant operation and some of its inputs if necessary. |
215 | auto *clonedC = recursiveClone(constEvaluationF, C, currToNew); |
216 | // Create save nodes for each of the results. |
217 | llvm::SmallVector<SaveNode *, 16> savedResults; |
218 | |
219 | // If we're recording constant folding, only lower the const fold subgraph |
220 | // (i.e. do not run optimizations when compiling the const fold subgraph). |
221 | // Otherwise some graph optimizations may do folding themselves, meaning that |
222 | // the the subgraph will not contain all folding that occurs. |
223 | if (record) { |
224 | cctx.optimizationOpts.onlyLowerFuns.insert(constEvaluationF); |
225 | } |
226 | ScopeGuard cleanupOnlyLowerFuns([&]() { |
227 | if (record) { |
228 | cctx.optimizationOpts.onlyLowerFuns.erase(constEvaluationF); |
229 | } |
230 | }); |
231 | |
232 | for (size_t idx = 0, e = clonedC->getNumResults(); idx < e; ++idx) { |
233 | auto RN = clonedC->getNthResult(idx); |
234 | auto *SN = constEvaluationF->createSave(clonedC->getName(), RN); |
235 | if (!isCanonicalLayout(RN, backend, clonedC, idx)) { |
236 | bailOnNonCanonicalLayout(constEvaluationF, mod, savedResults); |
237 | return false; |
238 | } |
239 | savedResults.emplace_back(SN); |
240 | bindings.allocate(SN->getPlaceholder()); |
241 | } |
242 | // Run the temporary backend to perform this constant operation |
243 | // evaluation. |
244 | if (ERR_TO_BOOL(executeConstantFunction(backend, *constEvaluationF, bindings, |
245 | cctx, enableQuantizeConstFolding))) { |
246 | mod.eraseFunction(constEvaluationF); |
247 | return false; |
248 | } |
249 | |
250 | // Get the results of the constant operation compile-time computation and |
251 | // create new constants from it. |
252 | constResults.reserve(savedResults.size()); |
253 | for (auto *SN : savedResults) { |
254 | Tensor *outputTensor = bindings.get(SN->getPlaceholder()); |
255 | auto *constResult = mod.createConstant( |
256 | SN->getInput().getNode()->getName().str() + ".constfold" , |
257 | std::move(*outputTensor)); |
258 | constResults.emplace_back(constResult); |
259 | |
260 | if (record) { |
261 | // Note: we skip erasing the Placeholders during recording (someone else's |
262 | // responsibility to delete them). |
263 | (*record)[constResult] = SN; |
264 | } else { |
265 | // Now erase the Placeholder that we created for the SaveNode. |
266 | auto &vars = mod.getPlaceholders(); |
267 | mod.erasePlaceholder( |
268 | std::find(vars.begin(), vars.end(), SN->getPlaceholder())); |
269 | } |
270 | } |
271 | // Remove the temporary function, unless we're recording the changes (someone |
272 | // else's responsibility to delete them). |
273 | if (!record) { |
274 | mod.eraseFunction(constEvaluationF); |
275 | } |
276 | return true; |
277 | } |
278 | |
279 | /// Check if function \p F consists of constant operations only. |
280 | LLVM_ATTRIBUTE_USED |
281 | Error verifyConstantFunction(Backend &backend, Function &F, |
282 | bool enableQuantizeConstFolding) { |
283 | // Perform the checks in DEBUG builds only. |
284 | for (auto &N : F.getNodes()) { |
285 | // Saving results is fine. |
286 | if (isa<SaveNode>(&N)) { |
287 | continue; |
288 | } |
289 | // Placeholders can be used just to save results. |
290 | if (!isa<Placeholder>(&N)) { |
291 | RETURN_ERR_IF_NOT( |
292 | isConstantOperation(&N, backend, enableQuantizeConstFolding), |
293 | "Expected constant operation" ); |
294 | continue; |
295 | } |
296 | if (!N.hasOneUse()) { |
297 | return MAKE_ERR("Expected constant operation" ); |
298 | } |
299 | auto *SN = dyn_cast<SaveNode>(N.getUsers().begin()->getUser()); |
300 | if (SN && SN->getPlaceholder() == &N) { |
301 | continue; |
302 | } |
303 | return MAKE_ERR("Expected constant operation" ); |
304 | } |
305 | return Error::success(); |
306 | } |
307 | |
308 | /// Perform a compile-time constant folding of the node \p N using the provided |
309 | /// \p backend. If \p record is not a nullptr then the Constant created is added |
310 | /// to the map, pointing to the SaveNode that generated that Constant. |
311 | /// \returns list of constants which are the result of the |
312 | /// constant-folding. These constants correspond to results of the node. If no |
313 | /// constant folding was possible an empty vector will be returned. If |
314 | /// \p foldSingleSplats then single splat subgraphs will be forced to fold. |
315 | bool constantFoldNodeImpl( |
316 | Backend &backend, Node *N, std::vector<Constant *> &constResults, |
317 | ConstantFoldingRecordMap *record = nullptr, |
318 | const CompilationContext &origCctx = CompilationContext(), |
319 | bool foldSingleSplats = false) { |
320 | CompilationContext cctx; |
321 | // Do not recursively call constant folding. |
322 | cctx.optimizationOpts.enableConstantFolding = false; |
323 | cctx.optimizationOpts.enableConstantDeduplication = false; |
324 | cctx.backendOpts.collectConstants = true; |
325 | // Do not print out compilation errors encountered, as constant folding is a |
326 | // best effort; simply silently give up and continue with compilation. |
327 | cctx.verboseCompile = false; |
328 | // Signal to the graph optimizer that it should not be deleting unused |
329 | // Constants in the module. |
330 | cctx.optimizationOpts.delayAndRecordConstantModification = true; |
331 | // Copy over the splats to materialize from the original cctx. |
332 | cctx.optimizationOpts.materializeSplatsUsedBySet = |
333 | origCctx.optimizationOpts.materializeSplatsUsedBySet; |
334 | assert(!ERR_TO_BOOL(cctx.verify()) && "cctx for const folding must be valid" ); |
335 | return evaluateConstantOperation(backend, cctx, N, constResults, record, |
336 | foldSingleSplats); |
337 | } |
338 | |
339 | } // namespace |
340 | |
/// Compiles the function \p F for \p backend using the compilation context
/// \p cctx and then runs it with the provided \p bindings. In builds where
/// NDEBUG is not defined, \p F is first verified to consist of constant
/// operations only; \p enableQuantizeConstFolding is forwarded to that
/// verification.
/// \returns the first error encountered during verification, compilation or
/// execution, or success.
Error glow::executeConstantFunction(Backend &backend, Function &F,
                                    PlaceholderBindings &bindings,
                                    CompilationContext &cctx,
                                    bool enableQuantizeConstFolding) {
// Perform the checks in DEBUG builds only.
#ifndef NDEBUG
  RETURN_IF_ERR(verifyConstantFunction(backend, F, enableQuantizeConstFolding));
#endif
  // Optimize and compile F; a failure is returned to the caller.
  std::unique_ptr<CompiledFunction> compiledF;
  ASSIGN_VALUE_OR_RETURN_ERR(compiledF, compile(backend, F, cctx));
  // Execute the compiled function, writing the results into bindings.
  return run(backend, *compiledF, bindings);
}
353 | |
354 | /// Perform constant folding in the function \p F . Any non-trivial node (i.e. |
355 | /// not a constant or a splat) that can be computed at compile-time is going to |
356 | /// be computed at compile-time. \returns true if any foldings were performed. |
357 | /// If \p record is not a nullptr then the Constants created for any constant |
358 | /// chain of Nodes is added to the map, pointing to the SaveNode that generated |
359 | /// that Constant. |
360 | static bool constantFoldFun(Function *F, const CompilationContext &cctx, |
361 | ConstantFoldingRecordMap *record = nullptr) { |
362 | // Skip if specified in the cctx. |
363 | if (!cctx.optimizationOpts.enableConstantFolding) { |
364 | return false; |
365 | } |
366 | |
367 | // Allow for quantize folding when we have a const folding record. |
368 | const bool enableQuantizeConstFolding = record != nullptr; |
369 | |
370 | LOG_SCOPE(F->getLogContext(), "glow::constantFold" ) |
371 | bool changed = false; |
372 | // Backend to be used for compile-time computations. |
373 | std::unique_ptr<Backend> backend(new Interpreter()); |
374 | // Traverse nodes in post-order, so that children are seen before parents. |
375 | GraphPostOrderVisitor postOrderVisitor(*F); |
376 | auto nodes = postOrderVisitor.getPostOrder(); |
377 | // Collect all non-trivial constant operations. |
378 | for (auto *N : nodes) { |
379 | // Skip trivial nodes/operations that do not require any constant |
380 | // computations. |
381 | if (isa<Storage>(N) || isa<Constant>(N) || |
382 | (isa<SplatNode>(N) && !isSplatToFold(cast<SplatNode>(N), cctx)) || |
383 | isa<TouchNode>(N)) { |
384 | continue; |
385 | } |
386 | |
387 | // Skip nodes that are not constant operations. |
388 | if (!isConstantOperation(N, *backend, enableQuantizeConstFolding)) { |
389 | continue; |
390 | } |
391 | |
392 | // Add only a constant operation node whose value is used by at least |
393 | // one non constant-operation node, because no other bigger constant |
394 | // operation containing the current node can completely replace the result |
395 | // of its computation. Doing this check allows for performing a smaller |
396 | // number of evaluateConstantOperation calls later and thus reduces the |
397 | // overhead. |
398 | if (!hasNonConstantOperationUser(N, *backend, enableQuantizeConstFolding)) { |
399 | continue; |
400 | } |
401 | |
402 | // Compute the constant value of the node. |
403 | std::vector<Constant *> constResults; |
404 | if (!constantFoldNodeImpl(*backend, N, constResults, record, cctx)) { |
405 | continue; |
406 | } |
407 | // Replace all results of the original operation by the computed |
408 | // compile-time results of this operation. |
409 | for (size_t idx = 0, e = constResults.size(); idx < e; ++idx) { |
410 | auto constResult = constResults[idx]; |
411 | assert(N->getNthResult(idx).getType() == |
412 | constResult->getOutput().getType() && |
413 | "Constant replacement type must match." ); |
414 | // Replace the old result by the new constant result. |
415 | N->getNthResult(idx).replaceAllUsesOfWith(constResult); |
416 | } |
417 | // Perform Dead Code Elimination. |
418 | runDCEPass(F, cctx); |
419 | changed = true; |
420 | } |
421 | return changed; |
422 | } |
423 | |
/// Perform constant folding in the function \p F . Any non-trivial node (i.e.
/// not a constant or a splat) that can be computed at compile-time is going to
/// be computed at compile-time. \returns true if any foldings were performed.
/// No folding record is passed on, so QuantizeNodes are not folded and all
/// temporary functions and Placeholders are cleaned up by the folding itself.
bool glow::ConstantFold::run(Function *F, const CompilationContext &cctx) {
  return constantFoldFun(F, cctx);
}
430 | |
/// Performs constant folding in the function \p F using \p cctx and records
/// what was folded.
/// \returns a map from every Constant created by the folding to the SaveNode
/// that generated it. The temporary functions holding those SaveNodes and the
/// Placeholders they save into are not deleted here; cleanupConstantFolding
/// can be used to erase them.
ConstantFoldingRecordMap
glow::constantFoldAndRecord(Function *F, const CompilationContext &cctx) {
  ConstantFoldingRecordMap record;
  // Whether any folding happened is not needed here; only the record is.
  constantFoldFun(F, cctx, &record);
  return record;
}
437 | |
438 | std::vector<Constant *> glow::constantFold(Node *N, bool foldSingleSplats) { |
439 | LOG_SCOPE(N->getParent()->getLogContext(), "glow::constantFold" ) |
440 | std::unique_ptr<Backend> backend(new Interpreter()); |
441 | if (!isConstantOperation(N, *backend, |
442 | /* enableQuantizeConstFolding */ false)) { |
443 | return {}; |
444 | } |
445 | std::vector<Constant *> constResults; |
446 | if (!constantFoldNodeImpl(*backend, N, constResults, nullptr, |
447 | CompilationContext(), foldSingleSplats)) { |
448 | return {}; |
449 | } |
450 | return constResults; |
451 | } |
452 | |
453 | void glow::cleanupConstantFolding(Module &mod, |
454 | const ConstantFoldingRecordMap &record, |
455 | PlaceholderBindings *bindings) { |
456 | auto &PHs = mod.getPlaceholders(); |
457 | std::unordered_set<Function *> funsToErase; |
458 | for (auto &r : record) { |
459 | SaveNode *SN = r.second; |
460 | if (bindings && bindings->count(SN->getPlaceholder())) { |
461 | bindings->erase(SN->getPlaceholder()); |
462 | } |
463 | mod.erasePlaceholder( |
464 | std::find(PHs.begin(), PHs.end(), SN->getPlaceholder())); |
465 | funsToErase.insert(SN->getParent()); |
466 | } |
467 | for (Function *F : funsToErase) { |
468 | mod.eraseFunction(F); |
469 | } |
470 | } |
471 | |