1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
#include "glow/Graph/Graph.h"
#include "glow/Graph/Node.h"
#include "glow/Graph/Nodes.h"
#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"

#include <unordered_set>
#include <vector>
23 | |
24 | using namespace glow; |
25 | |
26 | void glow::profileQuantization( |
27 | PlaceholderBindings &bindings, Function *F, |
28 | const quantization::ProfilingConfiguration &profConfig) { |
29 | // Iterate over all nodes in the graph and insert QuantizationProfile nodes |
30 | // to observe tensor values from every node's output. |
31 | std::unordered_set<NodeValue> nodesToInstrument; |
32 | |
33 | // Add Quantization Profile node to all of the floating point outputs. |
34 | for (auto &node : F->getNodes()) { |
35 | for (unsigned i = 0, e = node.getNumResults(); i < e; ++i) { |
36 | if (node.getElementType(i) != ElemKind::FloatTy) { |
37 | continue; |
38 | } |
39 | nodesToInstrument.insert(node.getNthResult(i)); |
40 | } |
41 | } |
42 | |
43 | // Add Quantization Profile node to all floating point vars. |
44 | for (const auto &var : F->getParent()->getConstants()) { |
45 | if (var->getOutput().getElementType() != ElemKind::FloatTy) { |
46 | continue; |
47 | } |
48 | nodesToInstrument.insert(var->getOutput()); |
49 | } |
50 | |
51 | // Add Quantization Profile node to all floating point placeholders. |
52 | for (const auto &PH : F->getParent()->getPlaceholders()) { |
53 | if (PH->getOutput().getElementType() != ElemKind::FloatTy) { |
54 | continue; |
55 | } |
56 | |
57 | /// Don't profile output nodes. |
58 | if (!PH->getUsers().empty()) { |
59 | auto *SN = llvm::dyn_cast<SaveNode>(PH->getUsers().begin()->getUser()); |
60 | if (SN) { |
61 | continue; |
62 | } |
63 | } |
64 | nodesToInstrument.insert(PH->getOutput()); |
65 | } |
66 | |
67 | for (const auto &NV : nodesToInstrument) { |
68 | F->createQuantizationProfile(bindings, |
69 | "QP_" + NV.getNode()->getName().str(), NV, |
70 | profConfig.numHistogramBins); |
71 | } |
72 | } |
73 | |