1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Quantization/Quantization.h" |
18 | |
19 | #include "glow/Backend/Backend.h" |
20 | #include "glow/Converter/FunctionConverter.h" |
21 | |
22 | #include <cmath> |
23 | #include <unordered_set> |
24 | #include <vector> |
25 | |
26 | using llvm::cast; |
27 | |
28 | namespace { |
29 | |
30 | using namespace glow; |
31 | using namespace glow::quantization; |
32 | |
33 | /// \returns whether BatchedAddNode \p baN was originally lowered from a |
34 | /// FullyConnectedNode based on the given \p loweredMap. |
35 | static bool isBAFromLoweredFC(const BatchedAddNode *baN, |
36 | const LoweredInfoMap &loweredMap) { |
37 | // Look for the set of NodeNameAndKinds corresponding to the |
38 | // BatchedAdd. If one exists, this means it was lowered. |
39 | auto it = loweredMap.find(baN->getResult().generateNodeOutputName()); |
40 | if (it == loweredMap.end()) { |
41 | return false; |
42 | } |
43 | |
44 | // Look through the set looking to see if the BatchedAdd was lowered |
45 | // from a FullyConnectedNode. |
46 | auto &set = it->getValue(); |
47 | for (auto i = set.begin(), e = set.end(); i != e; ++i) { |
48 | if (i->getKind() == glow::Kinded::Kind::FullyConnectedNodeKind) { |
49 | return true; |
50 | } |
51 | } |
52 | return false; |
53 | } |
54 | |
55 | /// This class produces a quantized function based on a provided profile. |
56 | class FunctionQuantizer : public FunctionConverter { |
57 | protected: |
58 | /// Get the type that \p out should have at the end of the conversion |
59 | /// process regardless of what is its current type. |
60 | /// This is similar to ::getTargetTypeForOutput except that \p out |
61 | /// doesn't need to be a floating point type for this method to |
62 | /// return the target type. |
63 | /// The reason we need this method is because we may morph the type |
64 | /// of \p out to match some IR constraints, but the final type still |
65 | /// needs to be known to insert rescale nodes. |
66 | TypeRef getTargetTypeForOutputImpl(const NodeValue &out) const { |
67 | auto outTQPIt = nodeToTQP_.find(out.generateNodeOutputName()); |
68 | assert(outTQPIt != nodeToTQP_.end() && |
69 | "Missing quantization params for a node" ); |
70 | |
71 | const TensorQuantizationParams &TQP = outTQPIt->second; |
72 | return mod_.uniqueType(quantizationPrecision_, out.dims(), TQP.scale, |
73 | TQP.offset); |
74 | } |
75 | |
76 | /// \see FunctionConverter::getTargetTypeForOutput. |
77 | /// \returns quantized type for \p out if any; if not quantizable, then |
78 | /// \returns the original type. |
79 | TypeRef getTargetTypeForOutput(const NodeValue &out) const override { |
80 | if (out.getElementType() != ElemKind::FloatTy) { |
81 | return out.getType(); |
82 | } |
83 | return getTargetTypeForOutputImpl(out); |
84 | } |
85 | |
  /// \see FunctionConverter::getTargetTypeForInput.
  /// \returns the quantized type for the \p idx-th argument of \p use, if any;
  /// if not quantizable, then \returns the original type.
  /// Bias operands of Convolution/Convolution3D/ConvTranspose/FullyConnected
  /// nodes (and of BatchedAdd nodes lowered from a FullyConnected) receive a
  /// specialized bias type derived from the input and weights quantization
  /// parameters instead of the plainly profiled one.
  TypeRef getTargetTypeForInput(const Node &use, unsigned idx) const override {
    NodeValue val = use.getNthInput(idx);

    // Do not quantize non floating point type, e.g., Index type.
    if (val.getElementType() != ElemKind::FloatTy) {
      return val.getType();
    }

    auto valTQPIt = nodeToTQP_.find(val.generateNodeOutputName());
    assert(valTQPIt != nodeToTQP_.end() &&
           "Missing quantization params for a node");

    // Profiled quantization parameters for this input value.
    const TensorQuantizationParams &TQP = valTQPIt->second;

    // Local lambda to specialize the bias quantization parameters.
    // Captures TQP (the profiled params of the bias itself) by reference and
    // combines it with the given input/weights types at the bias precision.
    auto getBiasType = [&](TypeRef inputTy, TypeRef weightsTy) -> TypeRef {
      TensorQuantizationParams inputTQP = {inputTy->getScale(),
                                           inputTy->getOffset()};
      TensorQuantizationParams weightsTQP = {weightsTy->getScale(),
                                             weightsTy->getOffset()};
      auto biasTQP = specializeBiasQuantizationParams(
          TQP, inputTQP, weightsTQP, schema_, quantizationPrecisionBias_);
      return mod_.uniqueType(quantizationPrecisionBias_, val.dims(),
                             biasTQP.scale, biasTQP.offset);
    };

    // NOTE: For every node for which the bias is specialized add the similar
    // logic in the 'generateNodeQuantizationInfos' function.
    if (use.getKind() == glow::Kinded::Kind::ConvolutionNodeKind &&
        idx == ConvolutionNode::BiasIdx) {
      // Get the input and weights types. This ensures the types will be
      // quantized. This is often the case when calling into this function from
      // canConvert(), as we have not yet converted the inputs.
      return getBiasType(
          getTargetTypeForInput(use, ConvolutionNode::InputIdx),
          getTargetTypeForInput(use, ConvolutionNode::FilterIdx));
    } else if (use.getKind() == glow::Kinded::Kind::Convolution3DNodeKind &&
               idx == Convolution3DNode::BiasIdx) {
      // Get the input and weights types. This ensures the types will be
      // quantized. This is often the case when calling into this function from
      // canConvert(), as we have not yet converted the inputs.
      return getBiasType(
          getTargetTypeForInput(use, Convolution3DNode::InputIdx),
          getTargetTypeForInput(use, Convolution3DNode::FilterIdx));
    } else if (use.getKind() == glow::Kinded::Kind::ConvTransposeNodeKind &&
               idx == ConvTransposeNode::BiasIdx) {
      // Get the input and weights types. This ensures the types will be
      // quantized. This is often the case when calling into this function from
      // canConvert(), as we have not yet converted the inputs.
      return getBiasType(
          getTargetTypeForInput(use, ConvTransposeNode::InputIdx),
          getTargetTypeForInput(use, ConvTransposeNode::FilterIdx));
    } else if (use.getKind() == glow::Kinded::Kind::FullyConnectedNodeKind &&
               idx == FullyConnectedNode::BiasIdx) {
      // Return the original type if we don't want to convert FC biases.
      if (skipQuantizeFCBias_) {
        return val.getType();
      }
      // Get the input and weights types. This ensures the types will be
      // quantized. This is often the case when calling into this function from
      // canConvert(), as we have not yet converted the inputs.
      return getBiasType(
          getTargetTypeForInput(use, FullyConnectedNode::InputIdx),
          getTargetTypeForInput(use, FullyConnectedNode::WeightsIdx));
    } else if (use.getKind() == glow::Kinded::Kind::BatchedAddNodeKind &&
               idx == BatchedAddNode::SliceIdx) {
      // Check if this BatchedAdd was lowered from a FullyConnectedNode.
      const auto *baN = llvm::cast<BatchedAddNode>(&use);
      if (isBAFromLoweredFC(baN, loweredMap_)) {
        // If it came from a FullyConnected node then we need to backtrack to
        // the matrix multiplication to calculate the new scale for the batched
        // add slice. Slice must be a MatMul if this was lowered from a
        // FullyConnected. Batch may have already been quantized.
        NodeValue batch = baN->getBatch();
        assert(
            (llvm::isa<MatMulNode>(batch) || llvm::isa<QuantizeNode>(batch)) &&
            "Batch must be either a MatMul or a Quantize.");
        MatMulNode *MM = llvm::dyn_cast<MatMulNode>(batch);
        if (!MM) {
          // The batch was already quantized; look through the QuantizeNode to
          // reach the MatMul it wraps.
          QuantizeNode *QN = llvm::cast<QuantizeNode>(batch);
          assert(llvm::isa<MatMulNode>(QN->getInput()) &&
                 "MM must be input of BA if lowered from FC.");
          MM = llvm::cast<MatMulNode>(QN->getInput());
        }
        // Treat the MatMul's LHS/RHS as the FC's input/weights for bias
        // specialization.
        return getBiasType(getTargetTypeForOutput(MM->getLHS()),
                           getTargetTypeForOutput(MM->getRHS()));
      }
    }
    // Default: plain quantized type at the configured precision with the
    // profiled scale/offset.
    return mod_.uniqueType(quantizationPrecision_, val.dims(), TQP.scale,
                           TQP.offset);
  }
180 | |
/// Macro to be put in a switch for all nodes that may need to be replaced by
/// a LookupTable if the backend doesn't support the quantized node directly.
/// Must stay in sync with the lookup-table conversion switch in
/// postProcessing(): every kind listed there must be routed here, otherwise
/// its conversion case is unreachable. ExpNodeKind was missing even though
/// postProcessing handles it via replaceQuantizedExpWithLookupTable.
#define CASES_FOR_INT_LOOKUP_TABLE_REPLACEMENT                                 \
  case Kinded::Kind::LogNodeKind:                                              \
  case Kinded::Kind::ExpNodeKind:                                              \
  case Kinded::Kind::TanhNodeKind:                                             \
  case Kinded::Kind::SigmoidNodeKind
187 | |
  /// \see FunctionConverter::canConvert.
  /// Only convert nodes that use floating point types and that
  /// weren't specifically marked as to-ignore with doNotQuantizeKinds_.
  /// A node is convertible only if quantization parameters exist for all of
  /// its float inputs/outputs and the backend supports the quantized form
  /// (possibly as an IntLookupTable).
  bool canConvert(const Node &node) const override {
    // Check if the node is one that we never want to convert, e.g. SaveNode.
    if (!FunctionConverter::canConvert(node)) {
      return false;
    }

    // Check if the node kind should not be converted based on supplied kinds
    // informed to the converter.
    if (doNotQuantizeKinds_.count(node.getKind())) {
      return false;
    }

    // Gather the input and output types that we will have once we quantize the
    // node, and check if the backend supports such a node. Note that if a node
    // has float inputs or outputs then we must have quantization parameters for
    // them. For inputs and outputs without quantization parameters, we keep
    // their original element type.
    bool needsQuantization = false;
    std::vector<TypeRef> inputTypes, outputTypes;
    for (unsigned idx = 0, end = node.getNumInputs(); idx != end; ++idx) {
      NodeValue val = node.getNthInput(idx);
      if (val.getElementType() == ElemKind::FloatTy) {
        needsQuantization = true;
        if (!quantizationParamsExist(val)) {
          // Missing profile for a float input: either die (when the quantizer
          // is configured to assert) or simply skip this node.
          CHECK(!assertAllNodesQuantized_)
              << "Quantization parameters did not exist for an input NodeValue "
                 "that should have been quantized; input number "
              << idx << " of node:\n"
              << node.getDebugDesc();
          return false;
        }
      }
      TypeRef targetTy = getTargetTypeForInput(node, idx);
      inputTypes.push_back(targetTy);
    }
    for (unsigned idx = 0, end = node.getNumResults(); idx != end; ++idx) {
      NodeValue val = node.getNthResult(idx);
      if (val.getElementType() == ElemKind::FloatTy) {
        needsQuantization = true;
        if (!quantizationParamsExist(val)) {
          // Missing profile for a float result: same policy as for inputs.
          CHECK(!assertAllNodesQuantized_)
              << "Quantization parameters did not exist for a result of a Node "
                 "that should have been quantized; result number "
              << idx << " of node:\n"
              << node.getDebugDesc();
          return false;
        }
      }
      TypeRef targetTy = getTargetTypeForOutput(val);
      outputTypes.push_back(targetTy);
    }

    // If none of the inputs were FPType then there's no quantization to
    // perform, so we return that we cannot convert this node.
    if (!needsQuantization) {
      return false;
    }

    // Only convert the node if the backend supports the newly converted node.
    bool isOpSupported =
        B_.isOpSupported(NodeInfo(node.getKind(), inputTypes, outputTypes));

    // Some nodes are only supported as quantized via lookup tables. Here we
    // check if such nodes are supported without lookup tables; if so, then we
    // convert them. Otherwise, return whether we can support them as lookup
    // tables instead, and they will be quantized as lookup tables.
    switch (node.getKind()) {
    CASES_FOR_INT_LOOKUP_TABLE_REPLACEMENT:
      if (!isOpSupported) {
        isOpSupported = B_.isOpSupported(NodeInfo(
            Kinded::Kind::IntLookupTableNodeKind, inputTypes, outputTypes));
      }
      break;
    default:
      break;
    }

    // Quantizer may be set up to die if a node is only skipped during
    // quantization because the backend does not support it as quantized.
    if (assertAllNodesQuantized_) {
      CHECK(isOpSupported) << B_.getBackendName()
                           << " Backend does not support node as quantized in "
                           << Type::getElementName(quantizationPrecision_).str()
                           << ":\n"
                           << node.getDebugDesc();
    }

    return isOpSupported;
  }
280 | |
281 | /// Helper that \returns whether quantization parameters exist |
282 | /// in \ref nodeToTQP_ given the name and result number of \p val. |
283 | bool quantizationParamsExist(const NodeValue &val) const { |
284 | auto valTQPIt = nodeToTQP_.find(val.generateNodeOutputName()); |
285 | return valTQPIt != nodeToTQP_.end(); |
286 | } |
287 | |
288 | /// Create either a QuantizeNode or DequantizeNode in \p function based on the |
289 | /// \p destTy and the type of \p val. |
290 | /// Basically, if \p val's type is floating point, this creates a |
291 | /// QuantizeNode of \p val. |
292 | /// If \p val's type is a quantized type, this creates a |
293 | /// DequantizeNode of \p val. |
294 | /// |
295 | /// \pre One of t\p val's type and \p destTy must be FloatTy and |
296 | /// the other must be a quantized type. |
297 | Node *createConversion(Function &function, const Node &node, NodeValue &val, |
298 | TypeRef destTy, bool /* isInput */) override { |
299 | assert((&function == &function_) && |
300 | "Trying to add quantize/dequantize conversion to a function other " |
301 | "than the function being quantized." ); |
302 | std::string nodeName = node.getName().str(); |
303 | if (destTy->isQuantizedType()) { |
304 | return function.createQuantize(nodeName + "_quantize" , val, destTy); |
305 | } |
306 | return function.createDequantize(nodeName + "_dequantize" , val, destTy); |
307 | } |
308 | |
  /// All IRConstraint cases below assume that the input and output index that
  /// they are looking for the type is at idx 0. We statically assert that here
  /// along with the case.
  static constexpr unsigned SingleMatchingInOutTypeInputIdx = 0;
  static constexpr unsigned SingleMatchingInOutTypeResultIdx = 0;
#define CASE_SINGLE_MATCHING_INOUT_TYPE(NODE_NAME_, INPUT_NAME_, OUTPUT_NAME_) \
  static_assert((NODE_NAME_##Node::INPUT_NAME_##Idx ==                         \
                     SingleMatchingInOutTypeInputIdx &&                        \
                 NODE_NAME_##Node::OUTPUT_NAME_##Idx ==                        \
                     SingleMatchingInOutTypeResultIdx),                        \
                #NODE_NAME_ "Node format is unexpected.");                     \
  case Kinded::Kind::NODE_NAME_##NodeKind

  /// Macro to be put in a switch for all the nodes that have a constraint
  /// where the input and output type must be equals.
  /// Used by both morphNode and postProcessing, which must stay in sync.
  /// Note: The last case of the macro doesn't have ':' so we can put it
  /// where the macro is inserted to keep the nice code formatting.
  // clang-format off
#define CASES_FOR_SINGLE_MATCHING_IN_OUT_TYPE                                  \
  CASE_SINGLE_MATCHING_INOUT_TYPE(LocalResponseNormalization, Input, Result):  \
  CASE_SINGLE_MATCHING_INOUT_TYPE(Slice, Input, Result):                       \
  CASE_SINGLE_MATCHING_INOUT_TYPE(Reshape, Input, Result):                     \
  CASE_SINGLE_MATCHING_INOUT_TYPE(Transpose, Input, Result):                   \
  CASE_SINGLE_MATCHING_INOUT_TYPE(TopK, Input, Values):                        \
  CASE_SINGLE_MATCHING_INOUT_TYPE(Gather, Data, Result):                       \
  CASE_SINGLE_MATCHING_INOUT_TYPE(MaxPool, Input, Result):                     \
  CASE_SINGLE_MATCHING_INOUT_TYPE(ResizeNearest, Input, Result):               \
  CASE_SINGLE_MATCHING_INOUT_TYPE(ResizeBilinear, Input, Result):              \
  CASE_SINGLE_MATCHING_INOUT_TYPE(SpaceToDepth, Input, Result)
  // clang-format on
339 | |
  /// \see FunctionConverter::morphNode.
  /// This method does the final adjustment to the output types
  /// when the profile and the IR constraints do not agree.
  /// E.g., the profile of LocalResponseNormalizationNode may
  /// give a range that is different from the range its input.
  /// However, the IR expects that both the input and output of
  /// this node have the same type.
  Node &morphNode(Node &node) override {
    // FIXME: Right now, the TensorQuantizationParams only tracks one
    // type per NodeValue, whereas we would need one type for the output and
    // one for each user of that value. E.g., for
    //   val = node
    //       = use1 val
    //       = use2 val
    //
    // we would want to track independently the types for the result of node,
    // the use of val in use1, and the use of val in use2, in respectively
    // outTy, inTy1, and inTy2, so that we could generate:
    //   outTy = node
    //   inTy1 = cast outTy
    //         = use1 inTy1
    //   inTy2 = cast outTy
    //         = use2 inTy2
    //
    // But instead what we track only one type like this:
    //   outTy = node
    //         = use1 outTy
    //         = use2 outTy
    //
    // However, sometimes outTy is not suitable for the input (we fix those in
    // postProcessing) and sometimes, the outTy is not suitable for the output
    // itself!
    // What this means basically is outTy encodes the inTy constraints whereas
    // they may disagree.
    // The following switch fixes outTy for the few nodes where inTy and outTy
    // can disagree.
    // This happens for cases where for instance, the quantized parameters for
    // the output have a different scale than the input, whereas the operation
    // itself doesn't allow that.
    //
    // E.g., the constraints for `outTy = op inTy` are `outTy == inTy`, but the
    // quantization profiling gave different types to outTy and inTy.
    switch (node.getKind()) {
    // Those cases need to be in sync with postProcessing, so we generate them
    // using macros.
    CASES_FOR_SINGLE_MATCHING_IN_OUT_TYPE : {
      // The constraints on the IR says that the input type must
      // be the same as the output type.
      TypeRef inTy =
          node.getNthInput(SingleMatchingInOutTypeInputIdx).getType();
      TypeRef fixedTy = mod_.uniqueType(
          quantizationPrecision_,
          node.getNthResult(SingleMatchingInOutTypeResultIdx).dims(),
          inTy->getScale(), inTy->getOffset());

      node.setType(SingleMatchingInOutTypeResultIdx, fixedTy);
      // Remember this node so that postProcessing can insert the rescale that
      // restores the profiled output type; the assert guarantees morph and
      // post-process are handled strictly in pairs.
      assert(!lastMorphedNodeWithTypeChanges &&
             "Missed one node to rescale in postprocessing");
      lastMorphedNodeWithTypeChanges = &node;
      return node;
    }
    default:
      return node;
    }
  }
403 | |
  /// Perform post processing for \p node. Handles special cases, e.g.
  /// requirements for input/output quantization parameters, converting to
  /// lookup tables, etc. Also updates nodeToTQP_ with the added dequantization
  /// nodes added for \p node.
  void postProcessing(Node &node) override {
    // Tracks the node whose quantized outputs now feed the dequantize users;
    // it may be replaced below by a rescale or a lookup-table node.
    Node *quantizedNode = &node;
    switch (node.getKind()) {
    default:
      break;

    // Cases for nodes where all inputs should use the same scale/offset as
    // the output.
#define CASE_ALL_INS_MATCH_SINGLE_OUT(NODE_KIND_)                              \
  case Kinded::Kind::NODE_KIND_##NodeKind: {                                   \
    auto *N = cast<NODE_KIND_##Node>(&node);                                   \
    TypeRef outputTy = N->getResult().getType();                               \
    assert(outputTy->isQuantizedType() && "Node hasn't been quantized yet?!"); \
    unsigned idx = 0;                                                          \
    for (size_t i = 0, e = N->getNumInputs(); i < e; ++i) {                    \
      NodeValue input = N->getNthInput(i);                                     \
      auto argOutTy =                                                          \
          mod_.uniqueType(quantizationPrecision_, input.dims(),                \
                          outputTy->getScale(), outputTy->getOffset());        \
      auto *rescale = function_.createRescaleQuantized(                        \
          input.getNode()->getName(), input, argOutTy);                        \
      function_.getLogContext()->logNodeInputChange(*N, N->getNthInput(idx),   \
                                                    rescale);                  \
      N->setNthInput(idx++, rescale);                                          \
    }                                                                          \
    break;                                                                     \
  }
      CASE_ALL_INS_MATCH_SINGLE_OUT(Concat);
      CASE_ALL_INS_MATCH_SINGLE_OUT(InsertTensor);
#undef CASE_ALL_INS_MATCH_SINGLE_OUT

    CASES_FOR_SINGLE_MATCHING_IN_OUT_TYPE : {
      // Check that the main loop hands us the node in the order we expect:
      // morph then postprocessing.
      // If the assert breaks, that means that morphNode and postprocessing
      // are out of sync (we probably miss a case in the switch).
      assert(lastMorphedNodeWithTypeChanges == &node &&
             "Mismatching last node changed");
      lastMorphedNodeWithTypeChanges = nullptr;
      // These nodes do not change {S,O} of the output, they use the same
      // {S,O} as the input. Make sure that rescale is applied to comply with
      // the taken profile from the node.
      TypeRef outputTy = getTargetTypeForOutputImpl(
          NodeValue(&node, SingleMatchingInOutTypeResultIdx));
      assert(outputTy->isQuantizedType() && "Node hasn't been quantized yet?!");
      auto outTy = mod_.uniqueType(quantizationPrecision_, outputTy->dims(),
                                   outputTy->getScale(), outputTy->getOffset());
      NodeValue val = node.getNthResult(SingleMatchingInOutTypeResultIdx);
      // "val" may not have any users if the output goes unused, e.g. if we are
      // quantizing a TopKNode and only indices is used.
      if (val.getNumUsers() == 0) {
        break;
      }
      // "node" should have only one use, the dequantize node.
      // Update this use.
      assert(
          val.hasOneUse() &&
          llvm::dyn_cast<DequantizeNode>((*val.getUsers().begin()).getUser()) &&
          "This node should only be used by the dequantize node");
      auto *dequantize =
          llvm::dyn_cast<DequantizeNode>((*val.getUsers().begin()).getUser());
      auto *rescale =
          function_.createRescaleQuantized(node.getName(), val, outTy);
      quantizedNode = rescale;

      function_.getLogContext()->logNodeInputChange(
          *dequantize, dequantize->getNthInput(DequantizeNode::InputIdx),
          rescale);
      dequantize->setNthInput(DequantizeNode::InputIdx, rescale);
      break;
    }

    CASES_FOR_INT_LOOKUP_TABLE_REPLACEMENT : {
      // If these nodes aren't supported then we convert them to a lookup table.
      NodeInfo NI(node);
      if (B_.isOpSupported(NI)) {
        break;
      }
      // canConvert() only accepted this node because the backend supports the
      // lookup-table form, so this must hold here.
      assert(B_.isOpSupported(NodeInfo(Kinded::Kind::IntLookupTableNodeKind,
                                       NI.getInTypes(), NI.getOutTypes())) &&
             "Backend should support IntLookupTable at this point.");
      switch (node.getKind()) {
      case Kinded::Kind::LogNodeKind:
        quantizedNode = replaceQuantizedLogWithLookupTable(
            function_, llvm::cast<LogNode>(node), schema_);
        break;
      case Kinded::Kind::ExpNodeKind:
        quantizedNode = replaceQuantizedExpWithLookupTable(
            function_, llvm::cast<ExpNode>(node), schema_);
        break;
      case Kinded::Kind::TanhNodeKind:
        quantizedNode = replaceQuantizedTanhWithLookupTable(
            function_, llvm::cast<TanhNode>(node), schema_);
        break;
      case Kinded::Kind::SigmoidNodeKind:
        quantizedNode = replaceQuantizedSigmoidWithLookupTable(
            function_, llvm::cast<SigmoidNode>(node), schema_);
        break;
      default:
        llvm_unreachable("Unsupported case for converting to lookup table.");
      }
    }
    }
    assert(!lastMorphedNodeWithTypeChanges && "Type not fixed");

    // Update nodeToTQP_ since we've added in dequantized nodes to the output of
    // now-quantized nodes. This is necessary because later we try to quantize
    // nodes only if we have a quantized type for its operands (i.e. a profile
    // in nodeToTQP_). However its inputs may already have been quantized, which
    // means its inputs are replaced by a dequantize node, and no profile would
    // be found in nodeToTQP_ for the dequantize node_. Thus we add TQPs for
    // the dequantize node given the scale/offset it is dequantizing from.
    for (unsigned outNum = 0, e = quantizedNode->getNumResults(); outNum != e;
         ++outNum) {
      NodeValue val = quantizedNode->getNthResult(outNum);
      if (!val.getType()->isQuantizedType()) {
        continue;
      }
      // Not all float outputs will have a dequantize added to its quantized
      // output, as we may just be using some outputs of a quantized node
      // (e.g. when quantizing a TopK but only using the Indices output, no
      // dequantize node is added to Values).
      if (val.getNumUsers() == 0) {
        continue;
      }
      assert(
          val.hasOneUse() &&
          llvm::dyn_cast<DequantizeNode>((*val.getUsers().begin()).getUser()) &&
          "This node should only be used by the dequantize node");
      auto *dequantize =
          llvm::dyn_cast<DequantizeNode>((*val.getUsers().begin()).getUser());
      TypeRef outTy = val.getType();
      auto name = dequantize->getResult().generateNodeOutputName();
      nodeToTQP_[name] = {outTy->getScale(), outTy->getOffset()};
    }
  }
544 | |
545 | void convertTensor(Tensor &tensor, TypeRef destTy) override { |
546 | assert(tensor.getElementType() == ElemKind::FloatTy && |
547 | (destTy->getElementType() == ElemKind::Int8QTy || |
548 | destTy->getElementType() == ElemKind::Int16QTy) && |
549 | "Dequantization not implemented" ); |
550 | |
551 | tensor = quantizeTensor(tensor, {destTy->getScale(), destTy->getOffset()}, |
552 | destTy->getElementType()); |
553 | } |
554 | |
private:
  /// Shortcut to the module of function_.
  Module &mod_;
  /// Backend used to check if a quantized operator is supported.
  const Backend &B_;
  /// Quantization schema (e.g. asymmetric/symmetric).
  quantization::Schema schema_;
  /// Element kind used for quantized activations/weights.
  const ElemKind quantizationPrecision_;
  /// Set of node kinds that should not be quantized.
  const KindSet &doNotQuantizeKinds_;
  /// Map the (name of a node, idx) to its quantization parameters.
  std::unordered_map<std::string, TensorQuantizationParams> nodeToTQP_;
  /// For debug, keep track of the last node that we changed because of IR
  /// constraints; morphNode sets it and postProcessing must clear it.
  Node *lastMorphedNodeWithTypeChanges;
  /// A map between quantization profiling names of NodeValues that were lowered
  /// from each other. Maps to a set of names of NodeValues and their NodeKinds
  /// that were replaced by the NodeValue (whose output name is the key) that
  /// replaced them.
  const LoweredInfoMap &loweredMap_;
  /// Used for debugging if we expect all nodes to be quantized by the
  /// quantizer; when set, failures to quantize are fatal instead of skipped.
  bool assertAllNodesQuantized_;
  /// Precision used for bias quantization for Convolution and FullyConnected.
  /// This allows specializing the bias quantization.
  const ElemKind quantizationPrecisionBias_;

  /// If true, don't apply quantization to FC bias inputs.
  const bool skipQuantizeFCBias_;
585 | |
public:
  /// Creates a function quantizer for function \p F using the quantization
  /// configuration \p quantConfig. This method quantizes as many nodes as
  /// permitted by the backend \p B. The map \p loweredMap contains info about
  /// what nodes were lowered from what, to be used during quantization.
  /// \p doNotQuantizeKinds lists kinds to not quantize, even if a profile was
  /// gathered for them and the backend supports the quantized operation.
  FunctionQuantizer(Function &F, const Backend &B,
                    const QuantizationConfiguration &quantConfig,
                    const KindSet &doNotQuantizeKinds,
                    const LoweredInfoMap &loweredMap)
      : FunctionConverter(F), mod_(*F.getParent()), B_(B),
        schema_(quantConfig.schema),
        quantizationPrecision_(quantConfig.precision),
        doNotQuantizeKinds_(doNotQuantizeKinds), loweredMap_(loweredMap),
        assertAllNodesQuantized_(quantConfig.assertAllNodesQuantized),
        quantizationPrecisionBias_(quantConfig.precisionBias),
        skipQuantizeFCBias_(quantConfig.skipQuantizeFCBias) {

    // Compute the TensorQuantizationParams using the profiling infos.
    auto quantizationInfos =
        generateNodeQuantizationInfos(&F, quantConfig, loweredMap);

    // Build a mapping between node name and TensorQuantizatonParams.
    for (const auto &quantizationInfo : quantizationInfos) {
      nodeToTQP_.emplace(quantizationInfo.nodeOutputName_,
                         quantizationInfo.tensorQuantizationParams_);
    }

    // Use for debug purposes: no node is pending a post-processing rescale.
    lastMorphedNodeWithTypeChanges = nullptr;
    // Silence unused-member warnings in builds where CHECK compiles away.
    (void)assertAllNodesQuantized_;
  }
619 | |
620 | /// Traverse all nodes to find applicable quantized nodes, and convert them |
621 | /// to RowwiseQuantized versions if required inputs are Constant. |
622 | void enableRowwise() { |
623 | auto nodeIt = function_.getNodes().end(); |
624 | auto stopIt = function_.getNodes().begin(); |
625 | do { |
626 | --nodeIt; |
627 | Node &node = *nodeIt; |
628 | auto *Q = llvm::dyn_cast<DequantizeNode>(&node); |
629 | if (!Q) { |
630 | continue; |
631 | } |
632 | |
633 | // ---------------------------------------------------------------------- |
634 | // After function "convert()" is called, one FullyConnectedNode is |
635 | // converted into: |
636 | // [fp32 input] [fp32 weights] [fp32 bias] |
637 | // | | | |
638 | // [QuantizeNode] [QuantizeNode] [QuantizeNode (optional)] |
639 | // \ | / |
640 | // [ FullyConnectedNode ] |
641 | // | |
642 | // [DequantizeNode] |
643 | // We need to find the above pattern and convert it to: |
644 | // [fp32 input] [fp32 weights] [fp32 bias] |
645 | // | / | \ | |
646 | // | [int8 weights] [scales] [offsets] | |
647 | // [QuantizeNode] | | | [QuantizeNode] |
648 | // \ | | | / |
649 | // [ RowwiseQuantizedFullyConnectedNode ] |
650 | // | |
651 | // [DequantizeNode] |
652 | // ---------------------------------------------------------------------- |
653 | bool foundFC = false; |
654 | NodeValue input, weights, bias, result; |
655 | if (auto *fcN = llvm::dyn_cast<FullyConnectedNode>(Q->getInput())) { |
656 | foundFC = true; |
657 | input = fcN->getInput(); |
658 | weights = fcN->getWeights(); |
659 | bias = fcN->getBias(); |
660 | result = fcN->getResult(); |
661 | } else if (const auto *baN = |
662 | llvm::dyn_cast<BatchedAddNode>(Q->getInput())) { |
663 | if (isBAFromLoweredFC(baN, loweredMap_)) { |
664 | foundFC = true; |
665 | NodeValue batch = baN->getBatch(); |
666 | |
667 | // All quantization has occurred at this point, but optimizations |
668 | // haven't eliminated extra quantize/dequantize nodes. Look |
669 | // backwards through them to find the MatMul of the FC. |
670 | assert(llvm::isa<QuantizeNode>(batch)); |
671 | QuantizeNode *QN = llvm::cast<QuantizeNode>(batch); |
672 | assert(llvm::isa<DequantizeNode>(QN->getInput())); |
673 | DequantizeNode *DQN = llvm::cast<DequantizeNode>(QN->getInput()); |
674 | assert(llvm::isa<MatMulNode>(DQN->getInput())); |
675 | MatMulNode *MM = llvm::cast<MatMulNode>(DQN->getInput()); |
676 | |
677 | input = MM->getLHS(); |
678 | weights = MM->getRHS(); |
679 | bias = baN->getSlice(); |
680 | result = baN->getResult(); |
681 | } |
682 | } |
683 | if (foundFC) { |
684 | // Only convert quantized FullyConnected Node (or its equivalent lowered |
685 | // representation in MatMul + BatchedAdd form). |
686 | if (input.getType()->isQuantizedType() && |
687 | llvm::isa<QuantizeNode>(weights.getNode()) && |
688 | result.getType()->isQuantizedType()) { |
689 | auto *wq = llvm::dyn_cast<QuantizeNode>(weights.getNode()); |
690 | // For RowwiseQuantizedFullyConnected, the weights need to be |
691 | // constant. |
692 | if (Constant *wc = llvm::dyn_cast<Constant>(wq->getInput())) { |
693 | auto *fcq = function_.createRowwiseQuantizedFullyConnected( |
694 | "rowwiseqfc" , input, wc, bias, result.getType(), schema_, |
695 | /* transposeWeight */ true); |
696 | // Replace usages of quantized FC node (or its equivalent lowered |
697 | // representation MM + BA) to RowwiseQuantizedFullyConnectedNode. |
698 | result.replaceAllUsesOfWith(fcq->getResult()); |
699 | } |
700 | } |
701 | } |
702 | |
703 | // Convert SLWS from normal version to fused rowwise-quantized version if |
704 | // applicable. Data must be Constant for this to occur. We also will not |
705 | // quantize the weights as we do for the default normal quantized SLWS, as |
706 | // the rowwise version uses float weights. |
707 | if (auto *SLWS = |
708 | llvm::dyn_cast<SparseLengthsWeightedSumNode>(Q->getInput())) { |
709 | NodeValue data = SLWS->getData(); |
710 | |
711 | // It's possible we skipped quantizing this node due to |
712 | // doNotQuantizeKinds, and so may not need to process it. |
713 | auto *dataQN = llvm::dyn_cast<QuantizeNode>(data.getNode()); |
714 | if (!dataQN) { |
715 | continue; |
716 | } |
717 | |
718 | // Can only convert to rowwise-quantized version if the data input is |
719 | // Constant. |
720 | auto *dataC = llvm::dyn_cast<Constant>(dataQN->getInput()); |
721 | if (!dataC) { |
722 | continue; |
723 | } |
724 | |
725 | // Right now we quantize the weights input for SLWS. However, the |
726 | // rowwise-quantized version does not, so we will skip the QN. At this |
727 | // point we know the SLWS was quantized, so the weights input must be a |
728 | // quantize node. |
729 | auto *weightsQN = llvm::dyn_cast<QuantizeNode>(SLWS->getWeights()); |
730 | assert(weightsQN && "Weights should have been quantized" ); |
731 | NodeValue weightsF = weightsQN->getInput(); |
732 | |
733 | auto *FRWQSLWS = |
734 | function_.createFusedRowwiseQuantizedSparseLengthsWeightedSum( |
735 | SLWS->getName(), dataC->getPayloadMutable(), weightsF, |
736 | SLWS->getIndices(), SLWS->getLengths(), |
737 | /* fusedElemKind */ ElemKind::UInt8FusedQTy, |
738 | /* useFP16Accumulation */ false, SLWS->getLengthsMode(), |
739 | SLWS->getAvgLength()); |
740 | |
741 | // Fused RWQSLWS stores the fused scales and offsets in trailing |
742 | // columns. If the input was single dimensional then it adds extra |
743 | // dimensions to both input and output. Therefore reshape back to the |
744 | // expected output shape in case the input to the SLWS did not have a |
745 | // second dimension but the fused version added one to insert columns. |
746 | auto *RN = function_.createReshape("reshape" , FRWQSLWS, |
747 | SLWS->getResult().dims()); |
748 | |
749 | // Replace the dequantize node of the original SLWS with the FRWQSLWS, |
750 | // as its output is already in float. |
751 | Q->getResult().replaceAllUsesOfWith(RN->getResult()); |
752 | } |
753 | |
754 | } while (nodeIt != stopIt); |
755 | |
756 | cleanUp(); |
757 | assert(function_.verify() && "Conversion led to invalid function" ); |
758 | } |
759 | |
  /// Traverse all nodes to find applicable quantized nodes, and convert them
  /// to ChannelwiseQuantized versions if required inputs are Constant.
  /// Specifically, rewrites quantized ConvolutionNodes whose filter and bias
  /// come from Quantize(Constant) into ChannelwiseQuantizedConvolutionNodes.
  void enableChannelwise() {
    // Iterate the node list in reverse. New nodes are appended by the
    // create* calls below, so walking backwards avoids revisiting them.
    auto nodeIt = function_.getNodes().end();
    auto stopIt = function_.getNodes().begin();
    // NOTE(review): this do/while decrements end() before any check, so it
    // assumes function_ holds at least one node — confirm callers guarantee
    // a non-empty function.
    do {
      --nodeIt;
      Node &node = *nodeIt;
      // The pattern below is anchored on a DequantizeNode; skip all others.
      auto *Q = llvm::dyn_cast<DequantizeNode>(&node);
      if (!Q) {
        continue;
      }

      // ----------------------------------------------------------------------
      // After function "convert()" is called, one ConvolutionNode is
      // converted into:
      // [fp32 input]  [fp32 filter]  [fp32 bias]
      //      |              |             |
      // [QuantizeNode] [QuantizeNode] [QuantizeNode]
      //        \            |            /
      //               [ ConvolutionNode ]
      //                       |
      //               [DequantizeNode]
      // We need to find the above pattern and convert it to:
      // [fp32 input]        [fp32 filter]            [fp32 bias]
      //      |             /      |      \            /    |    \
      //      | [int8 filter|scales|offsets] [int8 bias|scales|offsets]
      // [QuantizeNode]   |        |       |        |        |    |
      //      \           |        |       |       /        /    /
      //           [ ChannelwiseQuantizedConvolutionNode ]
      //                            |
      //                    [DequantizeNode]
      // ----------------------------------------------------------------------

      // Replace ConvolutionNode with ChannelwiseQuantizedConvolutionNode
      // if the filter and bias operands are constant. The node creation
      // function will be provided with the floating-point filter and
      // bias constants and will perform channel wise quantization.
      if (auto *convNode = llvm::dyn_cast<ConvolutionNode>(Q->getInput())) {

        NodeValue input = convNode->getInput();
        NodeValue filter = convNode->getFilter();
        NodeValue bias = convNode->getBias();
        NodeValue result = convNode->getResult();

        // Only rewrite fully quantized convolutions: quantized input/result
        // and filter/bias still fed through QuantizeNodes (i.e. not yet
        // optimized away).
        if (input.getType()->isQuantizedType() &&
            llvm::isa<QuantizeNode>(filter.getNode()) &&
            llvm::isa<QuantizeNode>(bias.getNode()) &&
            result.getType()->isQuantizedType()) {

          // The isa<> checks above guarantee these two dyn_casts succeed;
          // the Constant casts below may still legitimately fail.
          auto *filterQ = llvm::dyn_cast<QuantizeNode>(filter.getNode());
          Constant *filterC = llvm::dyn_cast<Constant>(filterQ->getInput());
          auto *biasQ = llvm::dyn_cast<QuantizeNode>(bias.getNode());
          Constant *biasC = llvm::dyn_cast<Constant>(biasQ->getInput());

          if (filterC && biasC) {
            // When the overall requested quantization schema is asymmetric
            // we use symmetric quantization schema for the channelwise filter
            // and bias in order to be closer to the TFLite quantization specs:
            // https://www.tensorflow.org/lite/performance/quantization_spec
            quantization::Schema quantSchema = schema_;
            if (quantSchema == quantization::Schema::Asymmetric) {
              quantSchema = quantization::Schema::Symmetric;
            }
            // Create per channel quantized Convolution. Passing nullptr for
            // the scales/offsets makes the builder compute them from the
            // float constants (quantizeFilter/quantizeBias are true).
            auto *convNodeCWQ = function_.createChannelwiseQuantizedConv(
                "ChannelwiseQuantizedConv" , input, filterC, biasC,
                /* filterScales */ nullptr, /* filterOffsets */ nullptr,
                /* biasScales */ nullptr, /* biasOffsets */ nullptr,
                result.getType(), convNode->getKernels(),
                convNode->getStrides(), convNode->getPads(),
                convNode->getGroup(), convNode->getDilation(),
                /* quantizeFilter */ true, /* quantizeBias */ true, quantSchema,
                quantizationPrecision_, quantizationPrecisionBias_);
            // Carry over any fused activation so behavior is unchanged.
            convNodeCWQ->setFusedActivation(convNode->getFusedActivation());
            convNodeCWQ->setFusedActivationArgs(
                convNode->getFusedActivationArgs());
            // Reroute users of the old conv; it becomes dead (presumably
            // removed by cleanUp() below — confirm).
            result.replaceAllUsesOfWith(convNodeCWQ->getResult());
          }
        }
      }
    } while (nodeIt != stopIt);
    cleanUp();
    assert(function_.verify() && "Conversion led to invalid function" );
  }
}; // end of quantizer class
846 | |
847 | } // namespace |
848 | |
849 | namespace glow { |
850 | namespace quantization { |
851 | |
852 | Node *replaceQuantizedLogWithLookupTable(Function &F, const LogNode &LN, |
853 | Schema schema) { |
854 | IntLookupTableNode *ILT = F.createIntLog( |
855 | LN.getName().str() + ".log" , LN.getInput(), LN.getResult().getType()); |
856 | LN.getResult().replaceAllUsesOfWith(ILT); |
857 | return ILT; |
858 | } |
859 | |
860 | Node *replaceQuantizedExpWithLookupTable(Function &F, const ExpNode &EN, |
861 | Schema schema) { |
862 | IntLookupTableNode *ELT = F.createIntExp( |
863 | EN.getName().str() + ".exp" , EN.getInput(), EN.getResult().getType()); |
864 | EN.getResult().replaceAllUsesOfWith(ELT); |
865 | return ELT; |
866 | } |
867 | |
868 | Node *replaceQuantizedTanhWithLookupTable(Function &F, const TanhNode &TN, |
869 | Schema schema) { |
870 | // Quantized tanh operator expects input to be in a certain floating point |
871 | // range. This operator works based on the precomputed table and has to |
872 | // process input in a range of [-3.0, 3.0]. Tanh asymptotically approaches |
873 | // +/-1.0 and is already +/-.995 at +/-3.0. |
874 | // The output quantization parameters are chosen to represent the floating |
875 | // point range of [-1.0, 1.0]. |
876 | TypeRef inpTy = TN.getInput().getType(); |
877 | TypeRef outTy = TN.getResult().getType(); |
878 | auto inputQuantizationParams = glow::quantization::chooseQuantizationParams( |
879 | {-3.0, 3.0}, schema, inpTy->getElementType()); |
880 | auto tanhInTy = F.getParent()->uniqueType( |
881 | inpTy->getElementType(), TN.getResult().dims(), |
882 | inputQuantizationParams.scale, inputQuantizationParams.offset); |
883 | |
884 | // Make sure input is clipped in [-3.0, 3.0] floating point range. |
885 | auto *rescaleInputNode = |
886 | F.createRescaleQuantized(TN.getName(), TN.getInput(), tanhInTy); |
887 | |
888 | // Make sure output is clipped in [-1.0, 1.0] floating point range. |
889 | auto outputQuantizationParams = glow::quantization::chooseQuantizationParams( |
890 | {-1.0, 1.0}, schema, outTy->getElementType()); |
891 | auto resultOutTy = F.getParent()->uniqueType( |
892 | outTy->getElementType(), rescaleInputNode->getResult().dims(), |
893 | outputQuantizationParams.scale, outputQuantizationParams.offset); |
894 | |
895 | // Note: The actual lookup table is created inside this call. |
896 | auto *quantizedNode = |
897 | F.createIntTanh(TN.getName(), rescaleInputNode, resultOutTy); |
898 | |
899 | auto *rescaleOutputNode = F.createRescaleQuantized( |
900 | TN.getName(), quantizedNode, TN.getResult().getType()); |
901 | |
902 | TN.getResult().replaceAllUsesOfWith(rescaleOutputNode); |
903 | return rescaleOutputNode; |
904 | } |
905 | |
906 | Node *replaceQuantizedSigmoidWithLookupTable(Function &F, const SigmoidNode &SN, |
907 | Schema schema) { |
908 | // Quantized sigmoid operator expects input to be in a certain floating |
909 | // point range. This operator works based on the precomputed table and has |
910 | // to process input in a range of [-6.0, 6.0]. Sigmoid asymptotically |
911 | // approaches 0 at -inf and 1 at +inf. It has values of 0.00247262 and |
912 | // 0.997527 at -6.0 and 6.0 correspondingly. The output quantization |
913 | // parameters are chosen to represent the floating point range of [0, 1.0]. |
914 | TypeRef inpTy = SN.getInput().getType(); |
915 | TypeRef outTy = SN.getResult().getType(); |
916 | auto inputQuantizationParams = glow::quantization::chooseQuantizationParams( |
917 | {-6.0, 6.0}, schema, inpTy->getElementType()); |
918 | auto sigmoidInTy = F.getParent()->uniqueType( |
919 | inpTy->getElementType(), SN.getResult().dims(), |
920 | inputQuantizationParams.scale, inputQuantizationParams.offset); |
921 | |
922 | // Make sure input is clipped in [-6.0, 6.0] floating point range. |
923 | auto *rescaleInputNode = |
924 | F.createRescaleQuantized(SN.getName(), SN.getInput(), sigmoidInTy); |
925 | |
926 | // Make sure output is clipped in [0.0, 1.0] floating point range. |
927 | auto outputQuantizationParams = glow::quantization::chooseQuantizationParams( |
928 | {0.0, 1.0}, schema, outTy->getElementType()); |
929 | auto resultOutTy = F.getParent()->uniqueType( |
930 | outTy->getElementType(), rescaleInputNode->getResult().dims(), |
931 | outputQuantizationParams.scale, outputQuantizationParams.offset); |
932 | |
933 | // Note: The actual lookup table is created inside this call. |
934 | auto *quantizedNode = |
935 | F.createIntSigmoid(SN.getName(), rescaleInputNode, resultOutTy); |
936 | |
937 | auto *rescaleOutputNode = F.createRescaleQuantized( |
938 | SN.getName(), quantizedNode, SN.getResult().getType()); |
939 | |
940 | SN.getResult().replaceAllUsesOfWith(rescaleOutputNode); |
941 | |
942 | return rescaleOutputNode->getResult(); |
943 | } |
944 | |
945 | /// Helper which, given the output name \p currName of some node, looks for |
946 | /// corresponding names in \p loweredMap which represent any names that this |
947 | /// node was lowered from. If any are found then they are inserted into \p |
948 | /// profilingInfos along with \p TPP. |
949 | static void |
950 | findAndInsertLoweredInfos(llvm::StringRef currName, |
951 | const LoweredInfoMap &loweredMap, |
952 | std::vector<NodeProfilingInfo> &profilingInfos, |
953 | const TensorProfilingParams &TPP) { |
954 | auto currSetIt = loweredMap.find(currName); |
955 | if (currSetIt == loweredMap.end()) { |
956 | return; |
957 | } |
958 | |
959 | // Get the set of names corresponding to currName. All names in the set are |
960 | // names that were originally lowered into currName. |
961 | auto &currSet = currSetIt->getValue(); |
962 | |
963 | // For each of the names (currOrigName), insert them into profilingInfos, |
964 | // and then recursively find and insert other names in case currOrigName was |
965 | // also lowered from a previous node. |
966 | for (auto i = currSet.begin(), e = currSet.end(); i != e; ++i) { |
967 | llvm::StringRef currOrigName = i->getName(); |
968 | profilingInfos.emplace_back(currOrigName.str(), TPP); |
969 | findAndInsertLoweredInfos(currOrigName, loweredMap, profilingInfos, TPP); |
970 | } |
971 | } |
972 | |
973 | std::vector<NodeProfilingInfo> |
974 | generateNodeProfilingInfos(PlaceholderBindings &bindings, const Function *F, |
975 | const LoweredInfoMap &loweredMap) { |
976 | std::vector<NodeProfilingInfo> profilingInfos; |
977 | for (auto &node : F->getNodes()) { |
978 | auto *QPN = llvm::dyn_cast<QuantizationProfileNode>(&node); |
979 | if (QPN) { |
980 | |
981 | // Extract the profiling information from the placeholders after running |
982 | // the network in profiling mode. |
983 | auto compInfoH = bindings.get(QPN->getComputationInfoPlaceholder()) |
984 | ->getHandle<float>(); |
985 | auto *histogramT = bindings.get(QPN->getHistogramPlaceholder()); |
986 | float min = compInfoH.raw(0); |
987 | float max = compInfoH.raw(1); |
988 | |
989 | // Generate a name to be used as profiling information identifier. |
990 | std::string fullOutputName = NodeValue::generateNodeOutputName( |
991 | QPN->getProfiledNodeName(), QPN->getProfiledOutputNumber()); |
992 | |
993 | // Set TensorProfilingParams for this node output. |
994 | TensorProfilingParams TPP(min, max, *histogramT); |
995 | profilingInfos.emplace_back(fullOutputName, TPP); |
996 | |
997 | // If the NodeValue represented by fullOutputName was created via lowering |
998 | // of another original NodeValue, then generate node profiling info for |
999 | // the original NodeValue using the same profiling parameters. |
1000 | findAndInsertLoweredInfos(fullOutputName, loweredMap, profilingInfos, |
1001 | TPP); |
1002 | } |
1003 | } |
1004 | return profilingInfos; |
1005 | } |
1006 | |
1007 | std::vector<NodeQuantizationInfo> |
1008 | generateNodeQuantizationInfos(Function *F, |
1009 | const QuantizationConfiguration &quantConfig, |
1010 | const LoweredInfoMap &loweredMap) { |
1011 | std::vector<NodeQuantizationInfo> quantizationInfos; |
1012 | for (const auto &profilingInfo : quantConfig.infos) { |
1013 | // Get node value from node output name. |
1014 | std::string nodeOutputName = profilingInfo.nodeOutputName_; |
1015 | NodeValue nodeOutput = F->getNodeValueByName(nodeOutputName); |
1016 | |
1017 | // Skip if the node is not part of the graph. |
1018 | if (!nodeOutput.getNode()) { |
1019 | continue; |
1020 | } |
1021 | |
1022 | // Default quantization schema. |
1023 | Schema schema = quantConfig.schema; |
1024 | |
1025 | // Default target precision. |
1026 | ElemKind precision = quantConfig.precision; |
1027 | |
1028 | // Default calibration mode. |
1029 | Calibration calibration = quantConfig.calibration; |
1030 | |
1031 | // The TensorQuantizationParams must be computed using the target |
1032 | // precision used during the actual quantization. The code below |
1033 | // reflects the same logic as the one used in the function |
1034 | // FunctionQuantizer::getTargetTypeForInput for specializing the bias |
1035 | // quantization precision. Since bias quantization is sensitive we will |
1036 | // choose to use no calibration. |
1037 | for (const auto &use : nodeOutput.getUsers()) { |
1038 | const auto *user = use.getUser(); |
1039 | if ((user->getKind() == glow::Kinded::Kind::ConvolutionNodeKind) && |
1040 | (user->getNthInput(ConvolutionNode::BiasIdx) == nodeOutput)) { |
1041 | // Found bias for ConvolutionNode. |
1042 | precision = quantConfig.precisionBias; |
1043 | calibration = Calibration::None; |
1044 | continue; |
1045 | } |
1046 | if ((user->getKind() == glow::Kinded::Kind::Convolution3DNodeKind) && |
1047 | (user->getNthInput(Convolution3DNode::BiasIdx) == nodeOutput)) { |
1048 | // Found bias for Convolution3DNode. |
1049 | precision = quantConfig.precisionBias; |
1050 | calibration = Calibration::None; |
1051 | continue; |
1052 | } |
1053 | if ((user->getKind() == glow::Kinded::Kind::ConvTransposeNodeKind) && |
1054 | (user->getNthInput(ConvTransposeNode::BiasIdx) == nodeOutput)) { |
1055 | // Found bias for ConvTranspose. |
1056 | precision = quantConfig.precisionBias; |
1057 | calibration = Calibration::None; |
1058 | continue; |
1059 | } |
1060 | if ((user->getKind() == glow::Kinded::Kind::FullyConnectedNodeKind) && |
1061 | (user->getNthInput(FullyConnectedNode::BiasIdx) == nodeOutput)) { |
1062 | // Found bias for FullyConnectedNode. |
1063 | precision = quantConfig.precisionBias; |
1064 | calibration = Calibration::None; |
1065 | continue; |
1066 | } |
1067 | if ((user->getKind() == glow::Kinded::Kind::BatchedAddNodeKind) && |
1068 | (user->getNthInput(BatchedAddNode::SliceIdx) == nodeOutput)) { |
1069 | // Find out if this BatchAddNode was lowered from FullyConnectedNode. |
1070 | const auto *baN = llvm::cast<BatchedAddNode>(user); |
1071 | if (isBAFromLoweredFC(baN, loweredMap)) { |
1072 | // Found bias for lowered FullyConnectedNode. |
1073 | precision = quantConfig.precisionBias; |
1074 | calibration = Calibration::None; |
1075 | continue; |
1076 | } |
1077 | } |
1078 | } |
1079 | |
1080 | // Do not calibrate the quantization parameters for scalars. |
1081 | if (nodeOutput.getType()->size() == 1) { |
1082 | calibration = Calibration::None; |
1083 | } |
1084 | |
1085 | // Disable the quantization calibration for constant weights. |
1086 | if (!quantConfig.calibrateConstants && |
1087 | llvm::isa<Constant>(nodeOutput.getNode())) { |
1088 | calibration = Calibration::None; |
1089 | } |
1090 | |
1091 | // Compute the TensorQuantizationParams using the profiling information |
1092 | // and the target precision and calibration. |
1093 | TensorProfilingParams TPP = profilingInfo.tensorProfilingParams_; |
1094 | TensorQuantizationParams TQP = |
1095 | chooseQuantizationParams(TPP, schema, precision, calibration); |
1096 | quantizationInfos.emplace_back(nodeOutputName, TQP); |
1097 | } |
1098 | |
1099 | return quantizationInfos; |
1100 | } |
1101 | |
1102 | void quantizeFunction(Function *F, const QuantizationConfiguration &quantConfig, |
1103 | const Backend &B, const LoweredInfoMap &loweredMap, |
1104 | const KindSet &doNotQuantizeKinds) { |
1105 | DCHECK(quantConfig.precision == ElemKind::Int8QTy || |
1106 | quantConfig.precision == ElemKind::UInt8QTy || |
1107 | quantConfig.precision == ElemKind::Int16QTy) |
1108 | << "Only Int8, UInt8, and Int16 quantization supported" ; |
1109 | |
1110 | FunctionQuantizer quantizer(*F, B, quantConfig, doNotQuantizeKinds, |
1111 | loweredMap); |
1112 | quantizer.convert(); |
1113 | |
1114 | // Enable rowwise quantization for FullyConnected node. |
1115 | if (quantConfig.enableRowwise) { |
1116 | quantizer.enableRowwise(); |
1117 | } |
1118 | |
1119 | // Enable channelwise quantization for Convolution node. |
1120 | if (quantConfig.enableChannelwise) { |
1121 | quantizer.enableChannelwise(); |
1122 | } |
1123 | } |
1124 | |
1125 | } // namespace quantization |
1126 | } // namespace glow |
1127 | |