/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"

#include "glow/Backend/Backend.h"
#include "glow/Backends/Interpreter/Interpreter.h"
#include "glow/Graph/Graph.h"
#include "glow/Graph/Log.h"
#include "glow/Graph/Node.h"
#include "glow/Graph/Nodes.h"
#include "glow/Graph/PlaceholderBindings.h"
#include "glow/Graph/TensorLayout.h"
#include "glow/Graph/Utils.h"
#include "glow/Optimizer/GraphOptimizer/FunctionPasses.h"

#include "llvm/Support/Casting.h"

using namespace glow;
using llvm::cast;
using llvm::dyn_cast;
using llvm::isa;

namespace {
/// The name of the temporary function to be used to perform constant folding.
constexpr const char *constEvaluationFunctionName =
    "__constEvaluationFunction__";

/// \returns true if a node \p N is a constant operation, i.e. it is a trivial
/// constant like Constant or Splat, or all of its inputs are recursively
/// constant operations, and it has no side-effects and is supported by the
/// \p backend. If \p enableQuantizeConstFolding is true then QuantizeNodes are
/// considered valid constant operations to fold.
bool isConstantOperation(const Node *N, const Backend &backend,
                         bool enableQuantizeConstFolding) {
  // An operation with side-effects cannot be computed at compile-time.
  if (N->hasSideEffects()) {
    return false;
  }
  // Quantize nodes are not handled by ConstantFolding but by a dedicated
  // quantization-specific optimization.
  if (!enableQuantizeConstFolding && isa<QuantizeNode>(N)) {
    return false;
  }
  // Constant and splat nodes are trivially constant operations.
  if (isa<Constant>(N) || isa<SplatNode>(N)) {
    return true;
  }
  // If the node is backend-specific, we cannot safely do constant folding
  // using the interpreter.
  if (!N->isCanonical()) {
    return false;
  }
  // If the operation is not supported by the backend, it cannot be computed at
  // compile-time.
  if (!backend.shouldLower(N) && !backend.isOpSupported(NodeInfo(*N))) {
    return false;
  }
  // A Placeholder's value is not known at compile-time.
  if (isa<Placeholder>(N)) {
    return false;
  }
  for (size_t idx = 0, e = N->getNumInputs(); idx < e; ++idx) {
    auto input = N->getNthInput(idx);
    if (!isConstantOperation(input.getNode(), backend,
                             enableQuantizeConstFolding)) {
      return false;
    }
  }
  return true;
}

/// \returns true if node \p N has at least one non-constant operation user.
/// \p backend and \p enableQuantizeConstFolding are used to determine what is
/// valid for folding.
bool hasNonConstantOperationUser(const Node *N, const Backend &backend,
                                 bool enableQuantizeConstFolding) {
  assert(isConstantOperation(N, backend, enableQuantizeConstFolding) &&
         "Expected constant operation");
  for (auto &use : N->getUsers()) {
    auto *user = use.getUser();
    // Only consider users in the current function.
    if (user->getParent() != N->getParent()) {
      continue;
    }
    if (!isConstantOperation(user, backend, enableQuantizeConstFolding)) {
      return true;
    }
  }
  return false;
}

/// Compile the function \p F for the provided \p backend using the compilation
/// context \p cctx.
/// \returns compiled function.
Expected<std::unique_ptr<CompiledFunction>>
compile(Backend &backend, Function &F, CompilationContext &cctx) {
  RETURN_IF_ERR(::glow::optimizeFunction(&F, backend, cctx));
  return backend.compile(&F, cctx.backendOpts);
}

/// Runs the compiled function \p compiledF on the \p backend using the
/// provided \p bindings.
Error run(Backend &backend, CompiledFunction &compiledF,
          PlaceholderBindings &bindings) {
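  // The ExecutionContext owns its bindings via a unique_ptr, but the bindings
  // passed in here are owned by the caller; wrap them and release the pointer
  // again below so they are not destroyed when the context goes out of scope.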
  std::unique_ptr<PlaceholderBindings> bindingsPtr(&bindings);
  ExecutionContext context(std::move(bindingsPtr));
  // TODO: Add only constants used by F to the compiled function. This should
  // reduce the amount of data that needs to be copied.
  auto executeErr = compiledF.execute(&context);
  RETURN_IF_ERR(std::move(executeErr));
  // Don't delete bindings.
  context.movePlaceholderBindings().release();
  return Error::success();
}

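/// \returns true unless result \p idx of \p clonedC requires a 4-D layout that
/// is not satisfied by the canonical default 4-D layout, in which case the
/// folded result cannot safely be saved back as a Constant.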
static bool isCanonicalLayout(const NodeValue &RN, Backend &backend,
                              Node *clonedC, size_t idx) {
  auto resultLayoutStr =
      backend.getTensorLayoutRequirements().getNthResultLayoutRequirements(
          clonedC, idx);
  auto resultLayout = TensorLayoutDescription(resultLayoutStr);
  auto &canInstance = CanonicalTensorLayout::getInstance();
  auto default4DStr = canInstance.getDefaultNDLayout(4);
  auto default4D = TensorLayoutDescription(default4DStr);
  if (resultLayout.getDims().size() == 4 &&
      !canInstance.isSatisfiedBy(RN.getType(), default4D, &resultLayout)) {
    return false;
  }
  return true;
}

// Bail on constant folding post-lowering for backends that break assumptions.
static void bailOnNonCanonicalLayout(
    Function *constEvaluationF, Module &mod,
    const llvm::SmallVectorImpl<SaveNode *> &savedResults) {
  // Some results may be in a non-canonical format post-lowering.
  // For example, if we are trying to constant fold an OpenCL 'Reshape' that
  // has NCHW layout. We cannot transpose it back to canonical layout for
  // two reasons: 1) Need to add a solver that supports weird non-NCHW2NHWC
  // backends. 2) Even if we get a constant tensor as a new "save" of the
  // transpose, the new constant tensor will have the wrong shape. We'd
  // actually need to transpose it back to its pre-modification shape. These
  // issues may be solved in the future (TODO); for now bail on such corner
  // cases. Clean-up before bailing:
  for (auto *SN : savedResults) {
    // Now erase the Placeholder that we created for the SaveNode.
    auto &vars = mod.getPlaceholders();
    mod.erasePlaceholder(
        std::find(vars.begin(), vars.end(), SN->getPlaceholder()));
  }
  mod.eraseFunction(constEvaluationF);
}

/// \returns whether \p N should be folded based on \p cctx's
/// optimizationOpts.materializeSplatsUsedBySet, where the backend may specify
/// which Splats should be materialized into Constants based on the kinds of
/// nodes that use them.
static bool isSplatToFold(SplatNode *N, const CompilationContext &cctx) {
  for (const auto &U : N->getUsers()) {
    if (cctx.optimizationOpts.materializeSplatsUsedBySet.count(
            U.getUser()->getKind())) {
      return true;
    }
  }
  return false;
}

/// Used to make sure we don't reuse the same name for const fold Functions.
static uint64_t numFolds = 0;

/// Evaluates a provided constant operation \p C using the provided \p backend
/// and using the compilation context \p cctx. If \p record is not a nullptr
/// then each Constant created is added to the map, pointing to the SaveNode
/// that generated that Constant. Additionally, if \p record is not a nullptr
/// then the constEvaluationF and its associated Placeholders for saving
/// results will not be deleted, and the caller is responsible for deleting
/// them if necessary. If \p foldSingleSplats then single splat subgraphs will
/// be forced to fold. \returns true on success, filling \p constResults with
/// the resulting Constants.
bool evaluateConstantOperation(Backend &backend, CompilationContext &cctx,
                               Node *C, std::vector<Constant *> &constResults,
                               ConstantFoldingRecordMap *record,
                               bool foldSingleSplats) {
  // Allow for quantize folding when we have a const folding record.
  const bool enableQuantizeConstFolding = record != nullptr;
  PlaceholderBindings bindings;
  assert(isConstantOperation(C, backend, enableQuantizeConstFolding) &&
         "Expected a constant expression");
  // Constants and splats do not need to be constant evaluated.
  if (isa<Constant>(C) || (isa<SplatNode>(C) && !foldSingleSplats &&
                           !isSplatToFold(cast<SplatNode>(C), cctx))) {
    return true;
  }
  Module &mod = *C->getParent()->getParent();
  const std::string funName = std::string(constEvaluationFunctionName) +
                              std::to_string(numFolds++) + "__" +
                              C->getName().data();
  // Create a temporary function to perform the constant operation.
  Function *constEvaluationF = mod.createFunction(funName);
  // Mapping from existing nodes to the new ones.
  NodeMap currToNew;
  // Clone the constant operation and some of its inputs if necessary.
  auto *clonedC = recursiveClone(constEvaluationF, C, currToNew);
  // Create save nodes for each of the results.
  llvm::SmallVector<SaveNode *, 16> savedResults;

  // If we're recording constant folding, only lower the const fold subgraph
  // (i.e. do not run optimizations when compiling the const fold subgraph).
  // Otherwise some graph optimizations may do folding themselves, meaning that
  // the subgraph will not contain all folding that occurs.
  if (record) {
    cctx.optimizationOpts.onlyLowerFuns.insert(constEvaluationF);
  }
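  // Make sure the onlyLowerFuns entry is removed again on every exit path,
  // including the early bail-outs below.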
  ScopeGuard cleanupOnlyLowerFuns([&]() {
    if (record) {
      cctx.optimizationOpts.onlyLowerFuns.erase(constEvaluationF);
    }
  });

  for (size_t idx = 0, e = clonedC->getNumResults(); idx < e; ++idx) {
    auto RN = clonedC->getNthResult(idx);
    auto *SN = constEvaluationF->createSave(clonedC->getName(), RN);
    if (!isCanonicalLayout(RN, backend, clonedC, idx)) {
      bailOnNonCanonicalLayout(constEvaluationF, mod, savedResults);
      return false;
    }
    savedResults.emplace_back(SN);
    bindings.allocate(SN->getPlaceholder());
  }
  // Run the temporary function on the backend to perform this constant
  // operation evaluation.
  if (ERR_TO_BOOL(executeConstantFunction(backend, *constEvaluationF, bindings,
                                          cctx, enableQuantizeConstFolding))) {
    mod.eraseFunction(constEvaluationF);
    return false;
  }

  // Get the results of the constant operation compile-time computation and
  // create new constants from them.
  constResults.reserve(savedResults.size());
  for (auto *SN : savedResults) {
    Tensor *outputTensor = bindings.get(SN->getPlaceholder());
    auto *constResult = mod.createConstant(
        SN->getInput().getNode()->getName().str() + ".constfold",
        std::move(*outputTensor));
    constResults.emplace_back(constResult);

    if (record) {
      // Note: we skip erasing the Placeholders during recording (someone
      // else's responsibility to delete them).
      (*record)[constResult] = SN;
    } else {
      // Now erase the Placeholder that we created for the SaveNode.
      auto &vars = mod.getPlaceholders();
      mod.erasePlaceholder(
          std::find(vars.begin(), vars.end(), SN->getPlaceholder()));
    }
  }
  // Remove the temporary function, unless we're recording the changes (in that
  // case it is someone else's responsibility to delete it).
  if (!record) {
    mod.eraseFunction(constEvaluationF);
  }
  return true;
}

/// Check if function \p F consists of constant operations only.
LLVM_ATTRIBUTE_USED
Error verifyConstantFunction(Backend &backend, Function &F,
                             bool enableQuantizeConstFolding) {
  // Perform the checks in DEBUG builds only.
  for (auto &N : F.getNodes()) {
    // Saving results is fine.
    if (isa<SaveNode>(&N)) {
      continue;
    }
    // Placeholders can be used just to save results.
    if (!isa<Placeholder>(&N)) {
      RETURN_ERR_IF_NOT(
          isConstantOperation(&N, backend, enableQuantizeConstFolding),
          "Expected constant operation");
      continue;
    }
    if (!N.hasOneUse()) {
      return MAKE_ERR("Expected constant operation");
    }
    auto *SN = dyn_cast<SaveNode>(N.getUsers().begin()->getUser());
    if (SN && SN->getPlaceholder() == &N) {
      continue;
    }
    return MAKE_ERR("Expected constant operation");
  }
  return Error::success();
}

/// Perform a compile-time constant folding of the node \p N using the provided
/// \p backend. If \p record is not a nullptr then each Constant created is
/// added to the map, pointing to the SaveNode that generated that Constant.
/// If \p foldSingleSplats then single splat subgraphs will be forced to fold.
/// \returns true on success, filling \p constResults with Constants that
/// correspond to the results of \p N; \returns false if no constant folding
/// was possible.
bool constantFoldNodeImpl(
    Backend &backend, Node *N, std::vector<Constant *> &constResults,
    ConstantFoldingRecordMap *record = nullptr,
    const CompilationContext &origCctx = CompilationContext(),
    bool foldSingleSplats = false) {
  CompilationContext cctx;
  // Do not recursively call constant folding.
  cctx.optimizationOpts.enableConstantFolding = false;
  cctx.optimizationOpts.enableConstantDeduplication = false;
  cctx.backendOpts.collectConstants = true;
  // Do not print out compilation errors encountered, as constant folding is a
  // best effort; simply silently give up and continue with compilation.
  cctx.verboseCompile = false;
  // Signal to the graph optimizer that it should not be deleting unused
  // Constants in the module.
  cctx.optimizationOpts.delayAndRecordConstantModification = true;
  // Copy over the splats to materialize from the original cctx.
  cctx.optimizationOpts.materializeSplatsUsedBySet =
      origCctx.optimizationOpts.materializeSplatsUsedBySet;
  assert(!ERR_TO_BOOL(cctx.verify()) && "cctx for const folding must be valid");
  return evaluateConstantOperation(backend, cctx, N, constResults, record,
                                   foldSingleSplats);
}

} // namespace

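/// Verify (in DEBUG builds), compile, and run the constant function \p F on
/// the given \p backend, writing the results into \p bindings.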
Error glow::executeConstantFunction(Backend &backend, Function &F,
                                    PlaceholderBindings &bindings,
                                    CompilationContext &cctx,
                                    bool enableQuantizeConstFolding) {
// Perform the checks in DEBUG builds only.
#ifndef NDEBUG
  RETURN_IF_ERR(verifyConstantFunction(backend, F, enableQuantizeConstFolding));
#endif
  std::unique_ptr<CompiledFunction> compiledF;
  ASSIGN_VALUE_OR_RETURN_ERR(compiledF, compile(backend, F, cctx));
  return run(backend, *compiledF, bindings);
}

/// Perform constant folding in the function \p F . Any non-trivial node (i.e.
/// not a constant or a splat) that can be computed at compile-time is going to
/// be computed at compile-time. \returns true if any foldings were performed.
/// If \p record is not a nullptr then the Constants created for any constant
/// chain of Nodes are added to the map, pointing to the SaveNode that
/// generated each Constant.
static bool constantFoldFun(Function *F, const CompilationContext &cctx,
                            ConstantFoldingRecordMap *record = nullptr) {
  // Skip if specified in the cctx.
  if (!cctx.optimizationOpts.enableConstantFolding) {
    return false;
  }

  // Allow for quantize folding when we have a const folding record.
  const bool enableQuantizeConstFolding = record != nullptr;

  LOG_SCOPE(F->getLogContext(), "glow::constantFold")
  bool changed = false;
  // Backend to be used for compile-time computations.
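  // An Interpreter instance is created here so the folded sub-graphs can be
  // evaluated at compile time, independently of the compilation target.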
  std::unique_ptr<Backend> backend(new Interpreter());
  // Traverse nodes in post-order, so that children are seen before parents.
  GraphPostOrderVisitor postOrderVisitor(*F);
  auto nodes = postOrderVisitor.getPostOrder();
  // Collect all non-trivial constant operations.
  for (auto *N : nodes) {
    // Skip trivial nodes/operations that do not require any constant
    // computations.
    if (isa<Storage>(N) || isa<Constant>(N) ||
        (isa<SplatNode>(N) && !isSplatToFold(cast<SplatNode>(N), cctx)) ||
        isa<TouchNode>(N)) {
      continue;
    }

    // Skip nodes that are not constant operations.
    if (!isConstantOperation(N, *backend, enableQuantizeConstFolding)) {
      continue;
    }

    // Only fold a constant operation node whose value is used by at least one
    // non-constant-operation node; if all of its users are themselves constant
    // operations, then a bigger constant operation containing the current node
    // will replace the result of its computation anyway. Doing this check
    // allows for performing a smaller number of evaluateConstantOperation
    // calls later and thus reduces the overhead.
    if (!hasNonConstantOperationUser(N, *backend, enableQuantizeConstFolding)) {
      continue;
    }

    // Compute the constant value of the node.
    std::vector<Constant *> constResults;
    if (!constantFoldNodeImpl(*backend, N, constResults, record, cctx)) {
      continue;
    }
    // Replace all results of the original operation by the computed
    // compile-time results of this operation.
    for (size_t idx = 0, e = constResults.size(); idx < e; ++idx) {
      auto constResult = constResults[idx];
      assert(N->getNthResult(idx).getType() ==
                 constResult->getOutput().getType() &&
             "Constant replacement type must match.");
      // Replace the old result by the new constant result.
      N->getNthResult(idx).replaceAllUsesOfWith(constResult);
    }
    // Perform Dead Code Elimination to remove the nodes that became dead after
    // the replacement.
    runDCEPass(F, cctx);
    changed = true;
  }
  return changed;
}

/// Perform constant folding in the function \p F . Any non-trivial node (i.e.
/// not a constant or a splat) that can be computed at compile-time is going to
/// be computed at compile-time. \returns true if any foldings were performed.
bool glow::ConstantFold::run(Function *F, const CompilationContext &cctx) {
  return constantFoldFun(F, cctx);
}

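/// Run constant folding on \p F, recording for each created Constant the
/// SaveNode in the temporary constant-folding Function that produced it. The
/// temporary Functions and Placeholders are left in place for the caller to
/// clean up (see cleanupConstantFolding()).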
ConstantFoldingRecordMap
glow::constantFoldAndRecord(Function *F, const CompilationContext &cctx) {
  ConstantFoldingRecordMap record;
  constantFoldFun(F, cctx, &record);
  return record;
}

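/// Fold the constant chain rooted at \p N using the Interpreter backend.
/// \returns the resulting Constants, or an empty vector if \p N is not a
/// constant operation or the folding failed. If \p foldSingleSplats then
/// single splat subgraphs will be forced to fold.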
std::vector<Constant *> glow::constantFold(Node *N, bool foldSingleSplats) {
  LOG_SCOPE(N->getParent()->getLogContext(), "glow::constantFold")
  std::unique_ptr<Backend> backend(new Interpreter());
  if (!isConstantOperation(N, *backend,
                           /* enableQuantizeConstFolding */ false)) {
    return {};
  }
  std::vector<Constant *> constResults;
  if (!constantFoldNodeImpl(*backend, N, constResults, nullptr,
                            CompilationContext(), foldSingleSplats)) {
    return {};
  }
  return constResults;
}

void glow::cleanupConstantFolding(Module &mod,
                                  const ConstantFoldingRecordMap &record,
                                  PlaceholderBindings *bindings) {
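  // For every recorded fold, drop any binding for the temporary output
  // Placeholder, erase that Placeholder from the Module, and remember its
  // parent Function so all temporary const-fold Functions can be erased below.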
  auto &PHs = mod.getPlaceholders();
  std::unordered_set<Function *> funsToErase;
  for (auto &r : record) {
    SaveNode *SN = r.second;
    if (bindings && bindings->count(SN->getPlaceholder())) {
      bindings->erase(SN->getPlaceholder());
    }
    mod.erasePlaceholder(
        std::find(PHs.begin(), PHs.end(), SN->getPlaceholder()));
    funsToErase.insert(SN->getParent());
  }
  for (Function *F : funsToErase) {
    mod.eraseFunction(F);
  }
}