1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "glow/Quantization/Quantization.h"
18
19#include "glow/Backend/Backend.h"
20#include "glow/Converter/FunctionConverter.h"
21
22#include <cmath>
23#include <unordered_set>
24#include <vector>
25
26using llvm::cast;
27
28namespace {
29
30using namespace glow;
31using namespace glow::quantization;
32
33/// \returns whether BatchedAddNode \p baN was originally lowered from a
34/// FullyConnectedNode based on the given \p loweredMap.
35static bool isBAFromLoweredFC(const BatchedAddNode *baN,
36 const LoweredInfoMap &loweredMap) {
37 // Look for the set of NodeNameAndKinds corresponding to the
38 // BatchedAdd. If one exists, this means it was lowered.
39 auto it = loweredMap.find(baN->getResult().generateNodeOutputName());
40 if (it == loweredMap.end()) {
41 return false;
42 }
43
  // Look through the set to see if the BatchedAdd was lowered from a
  // FullyConnectedNode.
46 auto &set = it->getValue();
47 for (auto i = set.begin(), e = set.end(); i != e; ++i) {
48 if (i->getKind() == glow::Kinded::Kind::FullyConnectedNodeKind) {
49 return true;
50 }
51 }
52 return false;
53}
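
// Illustrative example (names are hypothetical): if a FullyConnectedNode with
// output "fc:0" was lowered into a MatMul followed by a BatchedAdd "fc_add",
// then loweredMap would contain an entry keyed by "fc_add:0" whose set
// includes {"fc:0", FullyConnectedNodeKind}, which is what the loop above
// detects.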
54
55/// This class produces a quantized function based on a provided profile.
56class FunctionQuantizer : public FunctionConverter {
57protected:
  /// Get the type that \p out should have at the end of the conversion
  /// process, regardless of its current type.
  /// This is similar to ::getTargetTypeForOutput except that \p out
  /// doesn't need to be a floating point type for this method to
  /// return the target type.
  /// The reason we need this method is that we may morph the type
  /// of \p out to match some IR constraints, but the final type still
  /// needs to be known to insert rescale nodes.
66 TypeRef getTargetTypeForOutputImpl(const NodeValue &out) const {
67 auto outTQPIt = nodeToTQP_.find(out.generateNodeOutputName());
68 assert(outTQPIt != nodeToTQP_.end() &&
69 "Missing quantization params for a node");
70
71 const TensorQuantizationParams &TQP = outTQPIt->second;
72 return mod_.uniqueType(quantizationPrecision_, out.dims(), TQP.scale,
73 TQP.offset);
74 }
75
76 /// \see FunctionConverter::getTargetTypeForOutput.
77 /// \returns quantized type for \p out if any; if not quantizable, then
78 /// \returns the original type.
79 TypeRef getTargetTypeForOutput(const NodeValue &out) const override {
80 if (out.getElementType() != ElemKind::FloatTy) {
81 return out.getType();
82 }
83 return getTargetTypeForOutputImpl(out);
84 }
85
86 /// \see FunctionConverter::getTargetTypeForInput.
87 /// \returns the quantized type for the \p idx-th argument of \p use, if any;
88 /// if not quantizable, then \returns the original type.
89 TypeRef getTargetTypeForInput(const Node &use, unsigned idx) const override {
90 NodeValue val = use.getNthInput(idx);
91
    // Do not quantize non-floating-point types, e.g., Index types.
93 if (val.getElementType() != ElemKind::FloatTy) {
94 return val.getType();
95 }
96
97 auto valTQPIt = nodeToTQP_.find(val.generateNodeOutputName());
98 assert(valTQPIt != nodeToTQP_.end() &&
99 "Missing quantization params for a node");
100
101 const TensorQuantizationParams &TQP = valTQPIt->second;
102
103 // Local lambda to specialize the bias quantization parameters.
104 auto getBiasType = [&](TypeRef inputTy, TypeRef weightsTy) -> TypeRef {
105 TensorQuantizationParams inputTQP = {inputTy->getScale(),
106 inputTy->getOffset()};
107 TensorQuantizationParams weightsTQP = {weightsTy->getScale(),
108 weightsTy->getOffset()};
109 auto biasTQP = specializeBiasQuantizationParams(
110 TQP, inputTQP, weightsTQP, schema_, quantizationPrecisionBias_);
111 return mod_.uniqueType(quantizationPrecisionBias_, val.dims(),
112 biasTQP.scale, biasTQP.offset);
113 };
114
    // NOTE: For every node whose bias is specialized here, add the matching
    // logic to the 'generateNodeQuantizationInfos' function.
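    // As a purely illustrative example: under a symmetric schema the bias
    // scale is commonly derived as inputScale * weightsScale with a zero
    // offset, so an input scale of 0.02 and a weights scale of 0.01 would
    // typically give a bias scale of about 0.0002. The exact behavior depends
    // on the schema and on specializeBiasQuantizationParams.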
117 if (use.getKind() == glow::Kinded::Kind::ConvolutionNodeKind &&
118 idx == ConvolutionNode::BiasIdx) {
119 // Get the input and weights types. This ensures the types will be
120 // quantized. This is often the case when calling into this function from
121 // canConvert(), as we have not yet converted the inputs.
122 return getBiasType(
123 getTargetTypeForInput(use, ConvolutionNode::InputIdx),
124 getTargetTypeForInput(use, ConvolutionNode::FilterIdx));
125 } else if (use.getKind() == glow::Kinded::Kind::Convolution3DNodeKind &&
126 idx == Convolution3DNode::BiasIdx) {
127 // Get the input and weights types. This ensures the types will be
128 // quantized. This is often the case when calling into this function from
129 // canConvert(), as we have not yet converted the inputs.
130 return getBiasType(
131 getTargetTypeForInput(use, Convolution3DNode::InputIdx),
132 getTargetTypeForInput(use, Convolution3DNode::FilterIdx));
133 } else if (use.getKind() == glow::Kinded::Kind::ConvTransposeNodeKind &&
134 idx == ConvTransposeNode::BiasIdx) {
135 // Get the input and weights types. This ensures the types will be
136 // quantized. This is often the case when calling into this function from
137 // canConvert(), as we have not yet converted the inputs.
138 return getBiasType(
139 getTargetTypeForInput(use, ConvTransposeNode::InputIdx),
140 getTargetTypeForInput(use, ConvTransposeNode::FilterIdx));
141 } else if (use.getKind() == glow::Kinded::Kind::FullyConnectedNodeKind &&
142 idx == FullyConnectedNode::BiasIdx) {
143 // Return the original type if we don't want to convert FC biases.
144 if (skipQuantizeFCBias_) {
145 return val.getType();
146 }
147 // Get the input and weights types. This ensures the types will be
148 // quantized. This is often the case when calling into this function from
149 // canConvert(), as we have not yet converted the inputs.
150 return getBiasType(
151 getTargetTypeForInput(use, FullyConnectedNode::InputIdx),
152 getTargetTypeForInput(use, FullyConnectedNode::WeightsIdx));
153 } else if (use.getKind() == glow::Kinded::Kind::BatchedAddNodeKind &&
154 idx == BatchedAddNode::SliceIdx) {
155 // Check if this BatchedAdd was lowered from a FullyConnectedNode.
156 const auto *baN = llvm::cast<BatchedAddNode>(&use);
157 if (isBAFromLoweredFC(baN, loweredMap_)) {
        // If it came from a FullyConnected node then we need to backtrack to
        // the matrix multiplication to calculate the new scale for the batched
        // add slice. The batch input must be a MatMul if this was lowered from
        // a FullyConnected, though it may have already been quantized.
162 NodeValue batch = baN->getBatch();
163 assert(
164 (llvm::isa<MatMulNode>(batch) || llvm::isa<QuantizeNode>(batch)) &&
165 "Batch must be either a MatMul or a Quantize.");
166 MatMulNode *MM = llvm::dyn_cast<MatMulNode>(batch);
167 if (!MM) {
168 QuantizeNode *QN = llvm::cast<QuantizeNode>(batch);
169 assert(llvm::isa<MatMulNode>(QN->getInput()) &&
170 "MM must be input of BA if lowered from FC.");
171 MM = llvm::cast<MatMulNode>(QN->getInput());
172 }
173 return getBiasType(getTargetTypeForOutput(MM->getLHS()),
174 getTargetTypeForOutput(MM->getRHS()));
175 }
176 }
177 return mod_.uniqueType(quantizationPrecision_, val.dims(), TQP.scale,
178 TQP.offset);
179 }
180
181 /// Macro to be put in a switch for all nodes that may need to be replaced by
182 /// a LookupTable if the backend doesn't support the quantized node directly.
#define CASES_FOR_INT_LOOKUP_TABLE_REPLACEMENT                                 \
  case Kinded::Kind::LogNodeKind:                                              \
  case Kinded::Kind::ExpNodeKind:                                              \
  case Kinded::Kind::TanhNodeKind:                                             \
  case Kinded::Kind::SigmoidNodeKind
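
  /// In Int8, for example, such a replacement boils down to a precomputed
  /// 256-entry table mapping each quantized input value to a quantized output
  /// value (see the replaceQuantized*WithLookupTable helpers further below).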
187
188 /// \see FunctionConverter::canConvert.
189 /// Only convert nodes that use floating point types and that
190 /// weren't specifically marked as to-ignore with doNotQuantizeKinds_.
191 bool canConvert(const Node &node) const override {
192 // Check if the node is one that we never want to convert, e.g. SaveNode.
193 if (!FunctionConverter::canConvert(node)) {
194 return false;
195 }
196
    // Check if the node kind should not be converted based on the kinds
    // supplied to the converter.
199 if (doNotQuantizeKinds_.count(node.getKind())) {
200 return false;
201 }
202
    // Gather the input and output types that we will have once we quantize the
    // node, and check if the backend supports such a node. Note that if a node
    // has float inputs or outputs then we must have quantization parameters
    // for them. Non-float inputs and outputs keep their original element type.
208 bool needsQuantization = false;
209 std::vector<TypeRef> inputTypes, outputTypes;
210 for (unsigned idx = 0, end = node.getNumInputs(); idx != end; ++idx) {
211 NodeValue val = node.getNthInput(idx);
212 if (val.getElementType() == ElemKind::FloatTy) {
213 needsQuantization = true;
214 if (!quantizationParamsExist(val)) {
215 CHECK(!assertAllNodesQuantized_)
216 << "Quantization parameters did not exist for an input NodeValue "
217 "that should have been quantized; input number "
218 << idx << " of node:\n"
219 << node.getDebugDesc();
220 return false;
221 }
222 }
223 TypeRef targetTy = getTargetTypeForInput(node, idx);
224 inputTypes.push_back(targetTy);
225 }
226 for (unsigned idx = 0, end = node.getNumResults(); idx != end; ++idx) {
227 NodeValue val = node.getNthResult(idx);
228 if (val.getElementType() == ElemKind::FloatTy) {
229 needsQuantization = true;
230 if (!quantizationParamsExist(val)) {
231 CHECK(!assertAllNodesQuantized_)
232 << "Quantization parameters did not exist for a result of a Node "
233 "that should have been quantized; result number "
234 << idx << " of node:\n"
235 << node.getDebugDesc();
236 return false;
237 }
238 }
239 TypeRef targetTy = getTargetTypeForOutput(val);
240 outputTypes.push_back(targetTy);
241 }
242
    // If none of the inputs or results were float then there is no
    // quantization to perform, so we return that we cannot convert this node.
245 if (!needsQuantization) {
246 return false;
247 }
248
249 // Only convert the node if the backend supports the newly converted node.
250 bool isOpSupported =
251 B_.isOpSupported(NodeInfo(node.getKind(), inputTypes, outputTypes));
252
253 // Some nodes are only supported as quantized via lookup tables. Here we
254 // check if such nodes are supported without lookup tables; if so, then we
255 // convert them. Otherwise, return whether we can support them as lookup
256 // tables instead, and they will be quantized as lookup tables.
257 switch (node.getKind()) {
258 CASES_FOR_INT_LOOKUP_TABLE_REPLACEMENT:
259 if (!isOpSupported) {
260 isOpSupported = B_.isOpSupported(NodeInfo(
261 Kinded::Kind::IntLookupTableNodeKind, inputTypes, outputTypes));
262 }
263 break;
264 default:
265 break;
266 }
267
    // The quantizer may be configured to abort if a node is skipped during
    // quantization only because the backend does not support it as quantized.
270 if (assertAllNodesQuantized_) {
271 CHECK(isOpSupported) << B_.getBackendName()
272 << " Backend does not support node as quantized in "
273 << Type::getElementName(quantizationPrecision_).str()
274 << ":\n"
275 << node.getDebugDesc();
276 }
277
278 return isOpSupported;
279 }
280
281 /// Helper that \returns whether quantization parameters exist
282 /// in \ref nodeToTQP_ given the name and result number of \p val.
283 bool quantizationParamsExist(const NodeValue &val) const {
284 auto valTQPIt = nodeToTQP_.find(val.generateNodeOutputName());
285 return valTQPIt != nodeToTQP_.end();
286 }
287
288 /// Create either a QuantizeNode or DequantizeNode in \p function based on the
289 /// \p destTy and the type of \p val.
290 /// Basically, if \p val's type is floating point, this creates a
291 /// QuantizeNode of \p val.
292 /// If \p val's type is a quantized type, this creates a
293 /// DequantizeNode of \p val.
294 ///
  /// \pre One of \p val's type and \p destTy must be FloatTy and
296 /// the other must be a quantized type.
297 Node *createConversion(Function &function, const Node &node, NodeValue &val,
298 TypeRef destTy, bool /* isInput */) override {
299 assert((&function == &function_) &&
300 "Trying to add quantize/dequantize conversion to a function other "
301 "than the function being quantized.");
302 std::string nodeName = node.getName().str();
303 if (destTy->isQuantizedType()) {
304 return function.createQuantize(nodeName + "_quantize", val, destTy);
305 }
306 return function.createDequantize(nodeName + "_dequantize", val, destTy);
307 }
308
  /// All IRConstraint cases below assume that the input and output whose types
  /// they inspect are at index 0. We statically assert that here along with
  /// each case.
312 static constexpr unsigned SingleMatchingInOutTypeInputIdx = 0;
313 static constexpr unsigned SingleMatchingInOutTypeResultIdx = 0;
314#define CASE_SINGLE_MATCHING_INOUT_TYPE(NODE_NAME_, INPUT_NAME_, OUTPUT_NAME_) \
315 static_assert((NODE_NAME_##Node::INPUT_NAME_##Idx == \
316 SingleMatchingInOutTypeInputIdx && \
317 NODE_NAME_##Node::OUTPUT_NAME_##Idx == \
318 SingleMatchingInOutTypeResultIdx), \
319 #NODE_NAME_ "Node format is unexpected."); \
320 case Kinded::Kind::NODE_NAME_##NodeKind
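  /// For example, CASE_SINGLE_MATCHING_INOUT_TYPE(Slice, Input, Result)
  /// expands to a static_assert that SliceNode's Input and Result are both at
  /// index 0, followed by "case Kinded::Kind::SliceNodeKind".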
321
  /// Macro to be put in a switch for all the nodes that have a constraint
  /// requiring the input and output types to be equal.
  /// Note: The last case of the macro doesn't have ':' so that the ':' can be
  /// added at the point where the macro is used, keeping the code formatting
  /// nice.
326 // clang-format off
327#define CASES_FOR_SINGLE_MATCHING_IN_OUT_TYPE \
328 CASE_SINGLE_MATCHING_INOUT_TYPE(LocalResponseNormalization, Input, Result): \
329 CASE_SINGLE_MATCHING_INOUT_TYPE(Slice, Input, Result): \
330 CASE_SINGLE_MATCHING_INOUT_TYPE(Reshape, Input, Result): \
331 CASE_SINGLE_MATCHING_INOUT_TYPE(Transpose, Input, Result): \
332 CASE_SINGLE_MATCHING_INOUT_TYPE(TopK, Input, Values): \
333 CASE_SINGLE_MATCHING_INOUT_TYPE(Gather, Data, Result): \
334 CASE_SINGLE_MATCHING_INOUT_TYPE(MaxPool, Input, Result): \
335 CASE_SINGLE_MATCHING_INOUT_TYPE(ResizeNearest, Input, Result): \
336 CASE_SINGLE_MATCHING_INOUT_TYPE(ResizeBilinear, Input, Result): \
337 CASE_SINGLE_MATCHING_INOUT_TYPE(SpaceToDepth, Input, Result)
338 // clang-format on
339
340 /// \see FunctionConverter::morphNode.
341 /// This method does the final adjustment to the output types
342 /// when the profile and the IR constraints do not agree.
343 /// E.g., the profile of LocalResponseNormalizationNode may
344 /// give a range that is different from the range its input.
345 /// However, the IR expects that both the input and output of
346 /// this node have the same type.
347 Node &morphNode(Node &node) override {
348 // FIXME: Right now, the TensorQuantizationParams only tracks one
349 // type per NodeValue, whereas we would need one type for the output and
350 // one for each user of that value. E.g., for
351 // val = node
352 // = use1 val
353 // = use2 val
    //
    // We would want to track independently the types for the result of node,
    // the use of val in use1, and the use of val in use2, in respectively
    // outTy, inTy1, and inTy2, so that we can generate:
    // outTy = node
    // inTy1 = cast outTy
    //       = use1 inTy1
    // inTy2 = cast outTy
    //       = use2 inTy2
    //
    // But instead we track only one type, like this:
363 // outTy = node
364 // = use1 outTy
365 // = use2 outTy
366 //
367 // However, sometimes outTy is not suitable for the input (we fix those in
368 // postProcessing) and sometimes, the outTy is not suitable for the output
369 // itself!
    // In effect, outTy must also encode the inTy constraints, even though the
    // two may disagree.
372 // The following switch fixes outTy for the few nodes where inTy and outTy
373 // can disagree.
374 // This happens for cases where for instance, the quantized parameters for
375 // the output have a different scale than the input, whereas the operation
376 // itself doesn't allow that.
377 //
378 // E.g., the constraints for `outTy = op inTy` are `outTy == inTy`, but the
379 // quantization profiling gave different types to outTy and inTy.
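    //
    // Illustrative example (made-up numbers): if profiling gave a MaxPool
    // input a scale of 0.07 but its output a scale of 0.05, the IR constraint
    // that input and output share a type is violated; we force the output type
    // to use the input's scale here and let postProcessing insert a rescale
    // back to the profiled output type.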
380 switch (node.getKind()) {
381 // Those cases need to be in sync with postProcessing, so we generate them
382 // using macros.
383 CASES_FOR_SINGLE_MATCHING_IN_OUT_TYPE : {
      // The IR constraints say that the input type must be the same as the
      // output type.
386 TypeRef inTy =
387 node.getNthInput(SingleMatchingInOutTypeInputIdx).getType();
388 TypeRef fixedTy = mod_.uniqueType(
389 quantizationPrecision_,
390 node.getNthResult(SingleMatchingInOutTypeResultIdx).dims(),
391 inTy->getScale(), inTy->getOffset());
392
393 node.setType(SingleMatchingInOutTypeResultIdx, fixedTy);
394 assert(!lastMorphedNodeWithTypeChanges &&
395 "Missed one node to rescale in postprocessing");
396 lastMorphedNodeWithTypeChanges = &node;
397 return node;
398 }
399 default:
400 return node;
401 }
402 }
403
  /// Perform post processing for \p node. Handles special cases, e.g.
  /// requirements for input/output quantization parameters, converting to
  /// lookup tables, etc. Also updates nodeToTQP_ with the dequantization
  /// nodes added for \p node.
408 void postProcessing(Node &node) override {
409 Node *quantizedNode = &node;
410 switch (node.getKind()) {
411 default:
412 break;
413
414 // Cases for nodes where all inputs should use the same scale/offset as
415 // the output.
416#define CASE_ALL_INS_MATCH_SINGLE_OUT(NODE_KIND_) \
417 case Kinded::Kind::NODE_KIND_##NodeKind: { \
418 auto *N = cast<NODE_KIND_##Node>(&node); \
419 TypeRef outputTy = N->getResult().getType(); \
420 assert(outputTy->isQuantizedType() && "Node hasn't been quantized yet?!"); \
421 unsigned idx = 0; \
422 for (size_t i = 0, e = N->getNumInputs(); i < e; ++i) { \
423 NodeValue input = N->getNthInput(i); \
424 auto argOutTy = \
425 mod_.uniqueType(quantizationPrecision_, input.dims(), \
426 outputTy->getScale(), outputTy->getOffset()); \
427 auto *rescale = function_.createRescaleQuantized( \
428 input.getNode()->getName(), input, argOutTy); \
429 function_.getLogContext()->logNodeInputChange(*N, N->getNthInput(idx), \
430 rescale); \
431 N->setNthInput(idx++, rescale); \
432 } \
433 break; \
434 }
435 CASE_ALL_INS_MATCH_SINGLE_OUT(Concat);
436 CASE_ALL_INS_MATCH_SINGLE_OUT(InsertTensor);
437#undef CASE_ALL_INS_MATCH_SINGLE_OUT
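
    // Note (informal): some of the rescales inserted above may end up being
    // no-ops when an input already has the output's scale and offset; such
    // redundant rescales are typically removed by later graph optimizations.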
438
439 CASES_FOR_SINGLE_MATCHING_IN_OUT_TYPE : {
440 // Check that the main loop hands us the node in the order we expect:
441 // morph then postprocessing.
442 // If the assert breaks, that means that morphNode and postprocessing
      // are out of sync (we probably missed a case in the switch).
444 assert(lastMorphedNodeWithTypeChanges == &node &&
445 "Mismatching last node changed");
446 lastMorphedNodeWithTypeChanges = nullptr;
      // These nodes do not change the {S,O} of the output; they use the same
      // {S,O} as the input. Make sure that a rescale is applied to comply with
      // the profile taken for the node.
450 TypeRef outputTy = getTargetTypeForOutputImpl(
451 NodeValue(&node, SingleMatchingInOutTypeResultIdx));
452 assert(outputTy->isQuantizedType() && "Node hasn't been quantized yet?!");
453 auto outTy = mod_.uniqueType(quantizationPrecision_, outputTy->dims(),
454 outputTy->getScale(), outputTy->getOffset());
455 NodeValue val = node.getNthResult(SingleMatchingInOutTypeResultIdx);
      // "val" may not have any users if the output goes unused, e.g. if we are
      // quantizing a TopKNode and only the Indices output is used.
458 if (val.getNumUsers() == 0) {
459 break;
460 }
461 // "node" should have only one use, the dequantize node.
462 // Update this use.
463 assert(
464 val.hasOneUse() &&
465 llvm::dyn_cast<DequantizeNode>((*val.getUsers().begin()).getUser()) &&
466 "This node should only be used by the dequantize node");
467 auto *dequantize =
468 llvm::dyn_cast<DequantizeNode>((*val.getUsers().begin()).getUser());
469 auto *rescale =
470 function_.createRescaleQuantized(node.getName(), val, outTy);
471 quantizedNode = rescale;
472
473 function_.getLogContext()->logNodeInputChange(
474 *dequantize, dequantize->getNthInput(DequantizeNode::InputIdx),
475 rescale);
476 dequantize->setNthInput(DequantizeNode::InputIdx, rescale);
477 break;
478 }
479
480 CASES_FOR_INT_LOOKUP_TABLE_REPLACEMENT : {
481 // If these nodes aren't supported then we convert them to a lookup table.
482 NodeInfo NI(node);
483 if (B_.isOpSupported(NI)) {
484 break;
485 }
486 assert(B_.isOpSupported(NodeInfo(Kinded::Kind::IntLookupTableNodeKind,
487 NI.getInTypes(), NI.getOutTypes())) &&
488 "Backend should support IntLookupTable at this point.");
489 switch (node.getKind()) {
490 case Kinded::Kind::LogNodeKind:
491 quantizedNode = replaceQuantizedLogWithLookupTable(
492 function_, llvm::cast<LogNode>(node), schema_);
493 break;
494 case Kinded::Kind::ExpNodeKind:
495 quantizedNode = replaceQuantizedExpWithLookupTable(
496 function_, llvm::cast<ExpNode>(node), schema_);
497 break;
498 case Kinded::Kind::TanhNodeKind:
499 quantizedNode = replaceQuantizedTanhWithLookupTable(
500 function_, llvm::cast<TanhNode>(node), schema_);
501 break;
502 case Kinded::Kind::SigmoidNodeKind:
503 quantizedNode = replaceQuantizedSigmoidWithLookupTable(
504 function_, llvm::cast<SigmoidNode>(node), schema_);
505 break;
506 default:
507 llvm_unreachable("Unsupported case for converting to lookup table.");
508 }
509 }
510 }
511 assert(!lastMorphedNodeWithTypeChanges && "Type not fixed");
512
    // Update nodeToTQP_ since we've added dequantize nodes to the outputs of
    // now-quantized nodes. This is necessary because later we try to quantize
    // a node only if we have a quantized type for its operands (i.e. a profile
    // in nodeToTQP_). However, its inputs may already have been quantized,
    // which means they are now fed by dequantize nodes, and no profile would
    // be found in nodeToTQP_ for those dequantize nodes. Thus we add TQPs for
    // each dequantize node given the scale/offset it is dequantizing from.
520 for (unsigned outNum = 0, e = quantizedNode->getNumResults(); outNum != e;
521 ++outNum) {
522 NodeValue val = quantizedNode->getNthResult(outNum);
523 if (!val.getType()->isQuantizedType()) {
524 continue;
525 }
      // Not all float outputs will have a dequantize added to their quantized
      // counterparts, as we may only be using some outputs of a quantized node
      // (e.g. when quantizing a TopK but only using the Indices output, no
      // dequantize node is added to Values).
530 if (val.getNumUsers() == 0) {
531 continue;
532 }
533 assert(
534 val.hasOneUse() &&
535 llvm::dyn_cast<DequantizeNode>((*val.getUsers().begin()).getUser()) &&
536 "This node should only be used by the dequantize node");
537 auto *dequantize =
538 llvm::dyn_cast<DequantizeNode>((*val.getUsers().begin()).getUser());
539 TypeRef outTy = val.getType();
540 auto name = dequantize->getResult().generateNodeOutputName();
541 nodeToTQP_[name] = {outTy->getScale(), outTy->getOffset()};
542 }
  }
544
545 void convertTensor(Tensor &tensor, TypeRef destTy) override {
546 assert(tensor.getElementType() == ElemKind::FloatTy &&
547 (destTy->getElementType() == ElemKind::Int8QTy ||
548 destTy->getElementType() == ElemKind::Int16QTy) &&
549 "Dequantization not implemented");
550
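    // For example (illustrative numbers): with scale 0.1 and offset 0, the
    // float value 1.27 maps to round(1.27 / 0.1) + 0 = 13 in Int8QTy.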
551 tensor = quantizeTensor(tensor, {destTy->getScale(), destTy->getOffset()},
552 destTy->getElementType());
553 }
554
555private:
556 /// Shortcut to the module of function_.
557 Module &mod_;
  /// Backend used to check if a quantized operator is supported.
559 const Backend &B_;
560 /// Quantization schema.
561 quantization::Schema schema_;
562 /// Quantization precision.
563 const ElemKind quantizationPrecision_;
564 /// Set of node kinds that should not be quantized.
565 const KindSet &doNotQuantizeKinds_;
  /// Maps the output name of a node (node name plus result index) to its
  /// quantization parameters.
567 std::unordered_map<std::string, TensorQuantizationParams> nodeToTQP_;
568 /// For debug, keep track of the last node that we changed because of IR
569 /// constraints.
570 Node *lastMorphedNodeWithTypeChanges;
  /// A map from the quantization profiling name of a NodeValue to the set of
  /// names and NodeKinds of the NodeValues that it replaced, i.e. the
  /// NodeValues it was lowered from.
575 const LoweredInfoMap &loweredMap_;
  /// Whether we expect all nodes to be quantized by the quantizer; used for
  /// debugging.
578 bool assertAllNodesQuantized_;
579 /// Precision used for bias quantization for Convolution and FullyConnected.
580 /// This allows specializing the bias quantization.
581 const ElemKind quantizationPrecisionBias_;
582
  /// If true, don't apply quantization to FC bias inputs.
584 const bool skipQuantizeFCBias_;
585
586public:
587 /// Creates a function quantizer for function \p F using the quantization
588 /// configuration \p quantConfig. This method quantizes as many nodes as
589 /// permitted by the backend \p B. The map \p loweredMap contains info about
590 /// what nodes were lowered from what, to be used during quantization.
591 /// \p doNotQuantizeKinds lists kinds to not quantize, even if a profile was
592 /// gathered for them and the backend supports the quantized operation.
593 FunctionQuantizer(Function &F, const Backend &B,
594 const QuantizationConfiguration &quantConfig,
595 const KindSet &doNotQuantizeKinds,
596 const LoweredInfoMap &loweredMap)
597 : FunctionConverter(F), mod_(*F.getParent()), B_(B),
598 schema_(quantConfig.schema),
599 quantizationPrecision_(quantConfig.precision),
600 doNotQuantizeKinds_(doNotQuantizeKinds), loweredMap_(loweredMap),
601 assertAllNodesQuantized_(quantConfig.assertAllNodesQuantized),
602 quantizationPrecisionBias_(quantConfig.precisionBias),
603 skipQuantizeFCBias_(quantConfig.skipQuantizeFCBias) {
604
605 // Compute the TensorQuantizationParams using the profiling infos.
606 auto quantizationInfos =
607 generateNodeQuantizationInfos(&F, quantConfig, loweredMap);
608
609 // Build a mapping between node name and TensorQuantizatonParams.
610 for (const auto &quantizationInfo : quantizationInfos) {
611 nodeToTQP_.emplace(quantizationInfo.nodeOutputName_,
612 quantizationInfo.tensorQuantizationParams_);
613 }
614
    // Used for debugging purposes.
616 lastMorphedNodeWithTypeChanges = nullptr;
617 (void)assertAllNodesQuantized_;
618 }
619
620 /// Traverse all nodes to find applicable quantized nodes, and convert them
621 /// to RowwiseQuantized versions if required inputs are Constant.
622 void enableRowwise() {
623 auto nodeIt = function_.getNodes().end();
624 auto stopIt = function_.getNodes().begin();
625 do {
626 --nodeIt;
627 Node &node = *nodeIt;
628 auto *Q = llvm::dyn_cast<DequantizeNode>(&node);
629 if (!Q) {
630 continue;
631 }
632
633 // ----------------------------------------------------------------------
634 // After function "convert()" is called, one FullyConnectedNode is
635 // converted into:
636 // [fp32 input] [fp32 weights] [fp32 bias]
637 // | | |
638 // [QuantizeNode] [QuantizeNode] [QuantizeNode (optional)]
639 // \ | /
640 // [ FullyConnectedNode ]
641 // |
642 // [DequantizeNode]
643 // We need to find the above pattern and convert it to:
644 // [fp32 input] [fp32 weights] [fp32 bias]
645 // | / | \ |
646 // | [int8 weights] [scales] [offsets] |
647 // [QuantizeNode] | | | [QuantizeNode]
648 // \ | | | /
649 // [ RowwiseQuantizedFullyConnectedNode ]
650 // |
651 // [DequantizeNode]
652 // ----------------------------------------------------------------------
653 bool foundFC = false;
654 NodeValue input, weights, bias, result;
655 if (auto *fcN = llvm::dyn_cast<FullyConnectedNode>(Q->getInput())) {
656 foundFC = true;
657 input = fcN->getInput();
658 weights = fcN->getWeights();
659 bias = fcN->getBias();
660 result = fcN->getResult();
661 } else if (const auto *baN =
662 llvm::dyn_cast<BatchedAddNode>(Q->getInput())) {
663 if (isBAFromLoweredFC(baN, loweredMap_)) {
664 foundFC = true;
665 NodeValue batch = baN->getBatch();
666
667 // All quantization has occurred at this point, but optimizations
668 // haven't eliminated extra quantize/dequantize nodes. Look
669 // backwards through them to find the MatMul of the FC.
670 assert(llvm::isa<QuantizeNode>(batch));
671 QuantizeNode *QN = llvm::cast<QuantizeNode>(batch);
672 assert(llvm::isa<DequantizeNode>(QN->getInput()));
673 DequantizeNode *DQN = llvm::cast<DequantizeNode>(QN->getInput());
674 assert(llvm::isa<MatMulNode>(DQN->getInput()));
675 MatMulNode *MM = llvm::cast<MatMulNode>(DQN->getInput());
676
677 input = MM->getLHS();
678 weights = MM->getRHS();
679 bias = baN->getSlice();
680 result = baN->getResult();
681 }
682 }
683 if (foundFC) {
        // Only convert a quantized FullyConnectedNode (or its equivalent
        // lowered representation in MatMul + BatchedAdd form).
686 if (input.getType()->isQuantizedType() &&
687 llvm::isa<QuantizeNode>(weights.getNode()) &&
688 result.getType()->isQuantizedType()) {
689 auto *wq = llvm::dyn_cast<QuantizeNode>(weights.getNode());
690 // For RowwiseQuantizedFullyConnected, the weights need to be
691 // constant.
692 if (Constant *wc = llvm::dyn_cast<Constant>(wq->getInput())) {
693 auto *fcq = function_.createRowwiseQuantizedFullyConnected(
694 "rowwiseqfc", input, wc, bias, result.getType(), schema_,
695 /* transposeWeight */ true);
696 // Replace usages of quantized FC node (or its equivalent lowered
697 // representation MM + BA) to RowwiseQuantizedFullyConnectedNode.
698 result.replaceAllUsesOfWith(fcq->getResult());
699 }
700 }
701 }
702
      // Convert SLWS from the normal version to the fused rowwise-quantized
      // version if applicable. Data must be a Constant for this to occur.
      // Unlike the default quantized SLWS, we do not quantize the weights,
      // since the rowwise version uses float weights.
707 if (auto *SLWS =
708 llvm::dyn_cast<SparseLengthsWeightedSumNode>(Q->getInput())) {
709 NodeValue data = SLWS->getData();
710
711 // It's possible we skipped quantizing this node due to
712 // doNotQuantizeKinds, and so may not need to process it.
713 auto *dataQN = llvm::dyn_cast<QuantizeNode>(data.getNode());
714 if (!dataQN) {
715 continue;
716 }
717
718 // Can only convert to rowwise-quantized version if the data input is
719 // Constant.
720 auto *dataC = llvm::dyn_cast<Constant>(dataQN->getInput());
721 if (!dataC) {
722 continue;
723 }
724
725 // Right now we quantize the weights input for SLWS. However, the
726 // rowwise-quantized version does not, so we will skip the QN. At this
727 // point we know the SLWS was quantized, so the weights input must be a
728 // quantize node.
729 auto *weightsQN = llvm::dyn_cast<QuantizeNode>(SLWS->getWeights());
730 assert(weightsQN && "Weights should have been quantized");
731 NodeValue weightsF = weightsQN->getInput();
732
733 auto *FRWQSLWS =
734 function_.createFusedRowwiseQuantizedSparseLengthsWeightedSum(
735 SLWS->getName(), dataC->getPayloadMutable(), weightsF,
736 SLWS->getIndices(), SLWS->getLengths(),
737 /* fusedElemKind */ ElemKind::UInt8FusedQTy,
738 /* useFP16Accumulation */ false, SLWS->getLengthsMode(),
739 SLWS->getAvgLength());
740
        // The fused RWQSLWS stores the fused scales and offsets in trailing
        // columns. If the input was single-dimensional then extra dimensions
        // are added to both input and output. Therefore, reshape back to the
        // expected output shape in case the input to the SLWS did not have a
        // second dimension but the fused version added one to hold the extra
        // columns.
746 auto *RN = function_.createReshape("reshape", FRWQSLWS,
747 SLWS->getResult().dims());
748
749 // Replace the dequantize node of the original SLWS with the FRWQSLWS,
750 // as its output is already in float.
751 Q->getResult().replaceAllUsesOfWith(RN->getResult());
752 }
753
754 } while (nodeIt != stopIt);
755
756 cleanUp();
757 assert(function_.verify() && "Conversion led to invalid function");
758 }
759
760 /// Traverse all nodes to find applicable quantized nodes, and convert them
761 /// to ChannelwiseQuantized versions if required inputs are Constant.
762 void enableChannelwise() {
763 auto nodeIt = function_.getNodes().end();
764 auto stopIt = function_.getNodes().begin();
765 do {
766 --nodeIt;
767 Node &node = *nodeIt;
768 auto *Q = llvm::dyn_cast<DequantizeNode>(&node);
769 if (!Q) {
770 continue;
771 }
772
773 // ----------------------------------------------------------------------
774 // After function "convert()" is called, one ConvolutionNode is
775 // converted into:
776 // [fp32 input] [fp32 filter] [fp32 bias]
777 // | | |
778 // [QuantizeNode] [QuantizeNode] [QuantizeNode]
779 // \ | /
780 // [ ConvolutionNode ]
781 // |
782 // [DequantizeNode]
783 // We need to find the above pattern and convert it to:
784 // [fp32 input] [fp32 filter] [fp32 bias]
785 // | / | \ / | \
786 // | [int8 filter|scales|offsets] [int8 bias|scales|offsets]
787 // [QuantizeNode] | | | | | |
788 // \ | | | / / /
789 // [ ChannelwiseQuantizedConvolutionNode ]
790 // |
791 // [DequantizeNode]
792 // ----------------------------------------------------------------------
793
794 // Replace ConvolutionNode with ChannelwiseQuantizedConvolutionNode
795 // if the filter and bias operands are constant. The node creation
796 // function will be provided with the floating-point filter and
      // bias constants and will perform channel-wise quantization.
798 if (auto *convNode = llvm::dyn_cast<ConvolutionNode>(Q->getInput())) {
799
800 NodeValue input = convNode->getInput();
801 NodeValue filter = convNode->getFilter();
802 NodeValue bias = convNode->getBias();
803 NodeValue result = convNode->getResult();
804
805 if (input.getType()->isQuantizedType() &&
806 llvm::isa<QuantizeNode>(filter.getNode()) &&
807 llvm::isa<QuantizeNode>(bias.getNode()) &&
808 result.getType()->isQuantizedType()) {
809
810 auto *filterQ = llvm::dyn_cast<QuantizeNode>(filter.getNode());
811 Constant *filterC = llvm::dyn_cast<Constant>(filterQ->getInput());
812 auto *biasQ = llvm::dyn_cast<QuantizeNode>(bias.getNode());
813 Constant *biasC = llvm::dyn_cast<Constant>(biasQ->getInput());
814
815 if (filterC && biasC) {
816 // When the overall requested quantization schema is asymmetric
817 // we use symmetric quantization schema for the channelwise filter
818 // and bias in order to be closer to the TFLite quantization specs:
819 // https://www.tensorflow.org/lite/performance/quantization_spec
820 quantization::Schema quantSchema = schema_;
821 if (quantSchema == quantization::Schema::Asymmetric) {
822 quantSchema = quantization::Schema::Symmetric;
823 }
824 // Create per channel quantized Convolution.
825 auto *convNodeCWQ = function_.createChannelwiseQuantizedConv(
826 "ChannelwiseQuantizedConv", input, filterC, biasC,
827 /* filterScales */ nullptr, /* filterOffsets */ nullptr,
828 /* biasScales */ nullptr, /* biasOffsets */ nullptr,
829 result.getType(), convNode->getKernels(),
830 convNode->getStrides(), convNode->getPads(),
831 convNode->getGroup(), convNode->getDilation(),
832 /* quantizeFilter */ true, /* quantizeBias */ true, quantSchema,
833 quantizationPrecision_, quantizationPrecisionBias_);
834 convNodeCWQ->setFusedActivation(convNode->getFusedActivation());
835 convNodeCWQ->setFusedActivationArgs(
836 convNode->getFusedActivationArgs());
837 result.replaceAllUsesOfWith(convNodeCWQ->getResult());
838 }
839 }
840 }
841 } while (nodeIt != stopIt);
842 cleanUp();
843 assert(function_.verify() && "Conversion led to invalid function");
844 }
}; // class FunctionQuantizer
846
847} // namespace
848
849namespace glow {
850namespace quantization {
851
852Node *replaceQuantizedLogWithLookupTable(Function &F, const LogNode &LN,
853 Schema schema) {
854 IntLookupTableNode *ILT = F.createIntLog(
855 LN.getName().str() + ".log", LN.getInput(), LN.getResult().getType());
856 LN.getResult().replaceAllUsesOfWith(ILT);
857 return ILT;
858}
859
860Node *replaceQuantizedExpWithLookupTable(Function &F, const ExpNode &EN,
861 Schema schema) {
862 IntLookupTableNode *ELT = F.createIntExp(
863 EN.getName().str() + ".exp", EN.getInput(), EN.getResult().getType());
864 EN.getResult().replaceAllUsesOfWith(ELT);
865 return ELT;
866}
867
868Node *replaceQuantizedTanhWithLookupTable(Function &F, const TanhNode &TN,
869 Schema schema) {
870 // Quantized tanh operator expects input to be in a certain floating point
871 // range. This operator works based on the precomputed table and has to
872 // process input in a range of [-3.0, 3.0]. Tanh asymptotically approaches
873 // +/-1.0 and is already +/-.995 at +/-3.0.
874 // The output quantization parameters are chosen to represent the floating
875 // point range of [-1.0, 1.0].
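  // For instance (illustrative): with an Int8 output, representing [-1.0, 1.0]
  // corresponds to a scale of roughly 2.0 / 255 ~= 0.0078; the exact scale and
  // offset depend on the chosen schema.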
876 TypeRef inpTy = TN.getInput().getType();
877 TypeRef outTy = TN.getResult().getType();
878 auto inputQuantizationParams = glow::quantization::chooseQuantizationParams(
879 {-3.0, 3.0}, schema, inpTy->getElementType());
880 auto tanhInTy = F.getParent()->uniqueType(
881 inpTy->getElementType(), TN.getResult().dims(),
882 inputQuantizationParams.scale, inputQuantizationParams.offset);
883
884 // Make sure input is clipped in [-3.0, 3.0] floating point range.
885 auto *rescaleInputNode =
886 F.createRescaleQuantized(TN.getName(), TN.getInput(), tanhInTy);
887
888 // Make sure output is clipped in [-1.0, 1.0] floating point range.
889 auto outputQuantizationParams = glow::quantization::chooseQuantizationParams(
890 {-1.0, 1.0}, schema, outTy->getElementType());
891 auto resultOutTy = F.getParent()->uniqueType(
892 outTy->getElementType(), rescaleInputNode->getResult().dims(),
893 outputQuantizationParams.scale, outputQuantizationParams.offset);
894
895 // Note: The actual lookup table is created inside this call.
896 auto *quantizedNode =
897 F.createIntTanh(TN.getName(), rescaleInputNode, resultOutTy);
898
899 auto *rescaleOutputNode = F.createRescaleQuantized(
900 TN.getName(), quantizedNode, TN.getResult().getType());
901
902 TN.getResult().replaceAllUsesOfWith(rescaleOutputNode);
903 return rescaleOutputNode;
904}
905
906Node *replaceQuantizedSigmoidWithLookupTable(Function &F, const SigmoidNode &SN,
907 Schema schema) {
908 // Quantized sigmoid operator expects input to be in a certain floating
909 // point range. This operator works based on the precomputed table and has
910 // to process input in a range of [-6.0, 6.0]. Sigmoid asymptotically
911 // approaches 0 at -inf and 1 at +inf. It has values of 0.00247262 and
  // 0.997527 at -6.0 and 6.0, respectively. The output quantization
913 // parameters are chosen to represent the floating point range of [0, 1.0].
914 TypeRef inpTy = SN.getInput().getType();
915 TypeRef outTy = SN.getResult().getType();
916 auto inputQuantizationParams = glow::quantization::chooseQuantizationParams(
917 {-6.0, 6.0}, schema, inpTy->getElementType());
918 auto sigmoidInTy = F.getParent()->uniqueType(
919 inpTy->getElementType(), SN.getResult().dims(),
920 inputQuantizationParams.scale, inputQuantizationParams.offset);
921
922 // Make sure input is clipped in [-6.0, 6.0] floating point range.
923 auto *rescaleInputNode =
924 F.createRescaleQuantized(SN.getName(), SN.getInput(), sigmoidInTy);
925
926 // Make sure output is clipped in [0.0, 1.0] floating point range.
927 auto outputQuantizationParams = glow::quantization::chooseQuantizationParams(
928 {0.0, 1.0}, schema, outTy->getElementType());
929 auto resultOutTy = F.getParent()->uniqueType(
930 outTy->getElementType(), rescaleInputNode->getResult().dims(),
931 outputQuantizationParams.scale, outputQuantizationParams.offset);
932
933 // Note: The actual lookup table is created inside this call.
934 auto *quantizedNode =
935 F.createIntSigmoid(SN.getName(), rescaleInputNode, resultOutTy);
936
937 auto *rescaleOutputNode = F.createRescaleQuantized(
938 SN.getName(), quantizedNode, SN.getResult().getType());
939
940 SN.getResult().replaceAllUsesOfWith(rescaleOutputNode);
941
  return rescaleOutputNode;
943}
944
945/// Helper which, given the output name \p currName of some node, looks for
946/// corresponding names in \p loweredMap which represent any names that this
947/// node was lowered from. If any are found then they are inserted into \p
948/// profilingInfos along with \p TPP.
949static void
950findAndInsertLoweredInfos(llvm::StringRef currName,
951 const LoweredInfoMap &loweredMap,
952 std::vector<NodeProfilingInfo> &profilingInfos,
953 const TensorProfilingParams &TPP) {
954 auto currSetIt = loweredMap.find(currName);
955 if (currSetIt == loweredMap.end()) {
956 return;
957 }
958
959 // Get the set of names corresponding to currName. All names in the set are
960 // names that were originally lowered into currName.
961 auto &currSet = currSetIt->getValue();
962
963 // For each of the names (currOrigName), insert them into profilingInfos,
964 // and then recursively find and insert other names in case currOrigName was
965 // also lowered from a previous node.
966 for (auto i = currSet.begin(), e = currSet.end(); i != e; ++i) {
967 llvm::StringRef currOrigName = i->getName();
968 profilingInfos.emplace_back(currOrigName.str(), TPP);
969 findAndInsertLoweredInfos(currOrigName, loweredMap, profilingInfos, TPP);
970 }
971}
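
// Illustrative example (names are hypothetical): if profiling recorded
// parameters for "add1:0", and "add1:0" is the result of lowering "fc1:0",
// which in turn was the result of lowering "fc_orig:0", then the same
// TensorProfilingParams are also registered under "fc1:0" and "fc_orig:0".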
972
973std::vector<NodeProfilingInfo>
974generateNodeProfilingInfos(PlaceholderBindings &bindings, const Function *F,
975 const LoweredInfoMap &loweredMap) {
976 std::vector<NodeProfilingInfo> profilingInfos;
977 for (auto &node : F->getNodes()) {
978 auto *QPN = llvm::dyn_cast<QuantizationProfileNode>(&node);
979 if (QPN) {
980
981 // Extract the profiling information from the placeholders after running
982 // the network in profiling mode.
983 auto compInfoH = bindings.get(QPN->getComputationInfoPlaceholder())
984 ->getHandle<float>();
985 auto *histogramT = bindings.get(QPN->getHistogramPlaceholder());
986 float min = compInfoH.raw(0);
987 float max = compInfoH.raw(1);
988
989 // Generate a name to be used as profiling information identifier.
990 std::string fullOutputName = NodeValue::generateNodeOutputName(
991 QPN->getProfiledNodeName(), QPN->getProfiledOutputNumber());
992
993 // Set TensorProfilingParams for this node output.
994 TensorProfilingParams TPP(min, max, *histogramT);
995 profilingInfos.emplace_back(fullOutputName, TPP);
996
997 // If the NodeValue represented by fullOutputName was created via lowering
998 // of another original NodeValue, then generate node profiling info for
999 // the original NodeValue using the same profiling parameters.
1000 findAndInsertLoweredInfos(fullOutputName, loweredMap, profilingInfos,
1001 TPP);
1002 }
1003 }
1004 return profilingInfos;
1005}
1006
1007std::vector<NodeQuantizationInfo>
1008generateNodeQuantizationInfos(Function *F,
1009 const QuantizationConfiguration &quantConfig,
1010 const LoweredInfoMap &loweredMap) {
1011 std::vector<NodeQuantizationInfo> quantizationInfos;
1012 for (const auto &profilingInfo : quantConfig.infos) {
1013 // Get node value from node output name.
1014 std::string nodeOutputName = profilingInfo.nodeOutputName_;
1015 NodeValue nodeOutput = F->getNodeValueByName(nodeOutputName);
1016
1017 // Skip if the node is not part of the graph.
1018 if (!nodeOutput.getNode()) {
1019 continue;
1020 }
1021
1022 // Default quantization schema.
1023 Schema schema = quantConfig.schema;
1024
1025 // Default target precision.
1026 ElemKind precision = quantConfig.precision;
1027
1028 // Default calibration mode.
1029 Calibration calibration = quantConfig.calibration;
1030
1031 // The TensorQuantizationParams must be computed using the target
1032 // precision used during the actual quantization. The code below
1033 // reflects the same logic as the one used in the function
1034 // FunctionQuantizer::getTargetTypeForInput for specializing the bias
1035 // quantization precision. Since bias quantization is sensitive we will
1036 // choose to use no calibration.
1037 for (const auto &use : nodeOutput.getUsers()) {
1038 const auto *user = use.getUser();
1039 if ((user->getKind() == glow::Kinded::Kind::ConvolutionNodeKind) &&
1040 (user->getNthInput(ConvolutionNode::BiasIdx) == nodeOutput)) {
1041 // Found bias for ConvolutionNode.
1042 precision = quantConfig.precisionBias;
1043 calibration = Calibration::None;
1044 continue;
1045 }
1046 if ((user->getKind() == glow::Kinded::Kind::Convolution3DNodeKind) &&
1047 (user->getNthInput(Convolution3DNode::BiasIdx) == nodeOutput)) {
1048 // Found bias for Convolution3DNode.
1049 precision = quantConfig.precisionBias;
1050 calibration = Calibration::None;
1051 continue;
1052 }
1053 if ((user->getKind() == glow::Kinded::Kind::ConvTransposeNodeKind) &&
1054 (user->getNthInput(ConvTransposeNode::BiasIdx) == nodeOutput)) {
1055 // Found bias for ConvTranspose.
1056 precision = quantConfig.precisionBias;
1057 calibration = Calibration::None;
1058 continue;
1059 }
1060 if ((user->getKind() == glow::Kinded::Kind::FullyConnectedNodeKind) &&
1061 (user->getNthInput(FullyConnectedNode::BiasIdx) == nodeOutput)) {
1062 // Found bias for FullyConnectedNode.
1063 precision = quantConfig.precisionBias;
1064 calibration = Calibration::None;
1065 continue;
1066 }
1067 if ((user->getKind() == glow::Kinded::Kind::BatchedAddNodeKind) &&
1068 (user->getNthInput(BatchedAddNode::SliceIdx) == nodeOutput)) {
        // Find out if this BatchedAddNode was lowered from a
        // FullyConnectedNode.
1070 const auto *baN = llvm::cast<BatchedAddNode>(user);
1071 if (isBAFromLoweredFC(baN, loweredMap)) {
1072 // Found bias for lowered FullyConnectedNode.
1073 precision = quantConfig.precisionBias;
1074 calibration = Calibration::None;
1075 continue;
1076 }
1077 }
1078 }
1079
1080 // Do not calibrate the quantization parameters for scalars.
1081 if (nodeOutput.getType()->size() == 1) {
1082 calibration = Calibration::None;
1083 }
1084
1085 // Disable the quantization calibration for constant weights.
1086 if (!quantConfig.calibrateConstants &&
1087 llvm::isa<Constant>(nodeOutput.getNode())) {
1088 calibration = Calibration::None;
1089 }
1090
1091 // Compute the TensorQuantizationParams using the profiling information
1092 // and the target precision and calibration.
1093 TensorProfilingParams TPP = profilingInfo.tensorProfilingParams_;
1094 TensorQuantizationParams TQP =
1095 chooseQuantizationParams(TPP, schema, precision, calibration);
1096 quantizationInfos.emplace_back(nodeOutputName, TQP);
1097 }
1098
1099 return quantizationInfos;
1100}
1101
1102void quantizeFunction(Function *F, const QuantizationConfiguration &quantConfig,
1103 const Backend &B, const LoweredInfoMap &loweredMap,
1104 const KindSet &doNotQuantizeKinds) {
1105 DCHECK(quantConfig.precision == ElemKind::Int8QTy ||
1106 quantConfig.precision == ElemKind::UInt8QTy ||
1107 quantConfig.precision == ElemKind::Int16QTy)
1108 << "Only Int8, UInt8, and Int16 quantization supported";
1109
1110 FunctionQuantizer quantizer(*F, B, quantConfig, doNotQuantizeKinds,
1111 loweredMap);
1112 quantizer.convert();
1113
1114 // Enable rowwise quantization for FullyConnected node.
1115 if (quantConfig.enableRowwise) {
1116 quantizer.enableRowwise();
1117 }
1118
1119 // Enable channelwise quantization for Convolution node.
1120 if (quantConfig.enableChannelwise) {
1121 quantizer.enableChannelwise();
1122 }
1123}
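
// A minimal usage sketch (assuming a profiled network; variable names are
// illustrative, not part of this API):
//
//   QuantizationConfiguration quantConfig;
//   quantConfig.infos = profilingInfos; // e.g. from generateNodeProfilingInfos
//   quantConfig.precision = ElemKind::Int8QTy;
//   quantization::quantizeFunction(F, quantConfig, *backend, loweredMap);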
1124
1125} // namespace quantization
1126} // namespace glow
1127