1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
 * \file src/relay/transforms/fake_quantization_to_integer.cc
22 * \brief A pass for taking fake quantized graphs and converting them
23 * to actual integer operations.
24 */
25
26#include "fake_quantization_to_integer.h"
27
28#include <tvm/ir/affine_type.h>
29#include <tvm/relay/attrs/nn.h>
30#include <tvm/relay/dataflow_matcher.h>
31#include <tvm/relay/expr.h>
32#include <tvm/relay/expr_functor.h>
33#include <tvm/relay/qnn/attrs.h>
34#include <tvm/relay/transform.h>
35
36#include <unordered_map>
37
38#include "../qnn/utils.h"
39
40namespace tvm {
41namespace relay {
42
43/* Description of FakeQuantizationToInteger
44 *
45 * The purpose of this pass is to find regions of the graph that follow
46 * the general pattern:
47 *
48 * x w
49 * | |
50 * dq dq
51 * \ /
52 * op1
53 * |
54 * op2
55 * |
56 * q
57 *
58 * and convert them into subgraphs with actual integer operations on x and w
59 *
60 * The pass does this via a multi-pass approach:
61 *
62 * The main pass is a MixedModeMutator that traverses the full graph searching for
63 * quantize operations
64 *
65 * The second pass is an ExprVisitor that recursively searches for subgraphs leading to the
 * quantize for subgraphs bounded by dequantize operations. This pass extracts the affine
67 * types of the inputs for later processing, where affine denotes the transformation
68 * x_real = (x_affine - zero_point) * scale
69 *
70 * The third pass is an ExprMutator that recursively rewrites the subgraphs using packed funcs
71 * registered with the FTVMFakeQuantizationToInteger attribute. These packed funcs rewrite
72 * the ops based on the affine types of their inputs and then return the affine types of the
 * new rewritten ops to pass that information down the stack during rewrite.
74 *
75 * After the second and third passes run, the first pass replaces the quantize with the
76 * rewritten subgraph and the processing continues
77 *
78 *
79 * After that an additional QAT pass can be enabled by use_qat flag. The goal of the pass is to find
 * operations in those regions (which were not successfully converted by the main pass) that can
81 * still be converted into quantized form. The idea is to find and transform operations with
82 * dequantized inputs one by one individually. Only operations for which all parameters can be
83 * explicitly calculated are allowed. For example, if on the above general pattern op2 is not
84 * registered with the FTVMFakeQuantizationToInteger attribute, op1 operation can still be
85 * converted. Converted pattern below:
86 *
87 * x w
88 * | |
89 * \ /
90 * op1
91 * |
92 * dq
93 * |
94 * op2
95 * |
96 * q
97 *
98 * This pass works in the same multi-pass approach.
99 */
100
// Set of expressions, hashed and compared by object identity.
using ExprSet = std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual>;
// Maps expressions to expressions, keyed by object identity.
using ExprMap = std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual>;
// Maps an expression to its affine type, where affine denotes the
// transformation x_real = (x_affine - zero_point) * scale.
using AffineTypeMap = Map<Expr, AffineType>;

// Signature of the per-op rewrite functions registered under the
// "FTVMFakeQuantizationToInteger" op attribute: given a call and the affine
// types of its inputs, return {rewritten_expr, affine_type_of_output}.
using FTVMFakeQuantizationToInteger =
    runtime::TypedPackedFunc<Array<ObjectRef>(const Expr& expr, const AffineTypeMap& map)>;
107
108const ExprSet SubgraphExtractor::GetSubgraph(const Expr& expr) {
109 VisitExpr(expr);
110 ExprSet subgraph;
111 if (is_fake_quantized_) {
112 for (auto kv : this->visit_counter_) {
113 if (auto call_node = GetRef<ObjectRef>(kv.first).as<CallNode>()) {
114 if (call_node->op != quantize_op_) {
115 subgraph.insert(Downcast<Expr>(GetRef<ObjectRef>(kv.first)));
116 }
117 }
118 }
119 }
120 return subgraph;
121}
122const AffineTypeMap SubgraphExtractor::GetAffineTypes() { return affine_types_; }
123void SubgraphExtractor::VisitExpr(const Expr& expr) {
124 // When looking for fake quantized subgraphs, we only support data-flow regions of the graph,
125 // i.e. call nodes/tuples/constants/etc. If we see anything else (like control flow) we
126 // abort the rewrite.
127 if (expr.as<CallNode>() == nullptr && expr.as<OpNode>() == nullptr &&
128 expr.as<TupleNode>() == nullptr && expr.as<TupleGetItemNode>() == nullptr &&
129 expr.as<ConstantNode>() == nullptr) {
130 DLOG(INFO) << "FakeQuantizationToInteger found a non-dataflow op inside"
131 << " a fake quantize region, aborting this rewrite";
132 is_fake_quantized_ = false;
133 } else {
134 ExprVisitor::VisitExpr(expr);
135 }
136}
137
138void SubgraphExtractor::VisitExpr_(const CallNode* call_node) {
139 const Op test_op = Downcast<Op>(call_node->op);
140 if (call_node->op == quantize_op_) {
141 const auto* attrs = call_node->attrs.as<qnn::QuantizeAttrs>();
142 ICHECK(attrs != nullptr);
143 // Only look at arg0 for quantize
144 VisitExpr(call_node->args[0]);
145 // Collect type of quantize ops
146 affine_types_.Set(
147 GetRef<Expr>(call_node),
148 TensorAffineType(call_node->args[1], call_node->args[2], attrs->out_dtype, attrs->axis));
149 } else if (call_node->op == dequantize_op_) {
150 const auto* attrs = call_node->attrs.as<qnn::DequantizeAttrs>();
151 ICHECK(attrs != nullptr);
152 // Collect type of dequantize ops
153 affine_types_.Set(
154 GetRef<Expr>(call_node),
155 TensorAffineType(call_node->args[1], call_node->args[2],
156 call_node->args[0]->checked_type().as<TensorTypeNode>()->dtype,
157 attrs->axis));
158 } else {
159 // run normally on everything else.
160 ExprVisitor::VisitExpr_(call_node);
161 }
162}
163
164class SubgraphMutator : public ExprMutator {
165 public:
166 SubgraphMutator(ExprSet subgraph, AffineTypeMap affine_types, bool hard_fail)
167 : subgraph_(subgraph), affine_types_(affine_types), hard_fail_(hard_fail) {}
168
169 Expr MutateSubgraph(const Expr& expr) {
170 if (subgraph_.size() == 0) {
171 return expr;
172 }
173 const CallNode* quantize_node = expr.as<CallNode>();
174 ICHECK(quantize_node);
175 ICHECK(quantize_node->op == quantize_op_);
176 out_type_ = affine_types_[expr];
177 static auto fqfq =
178 Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMFakeQuantizationToInteger");
179 for (auto node : subgraph_) {
180 const Op op = Downcast<Op>(node.as<CallNode>()->op);
181 if (!fqfq.count(Downcast<Op>(op))) {
182 // Only modify the subgraph if we have translation
183 // rules for every op
184 if (hard_fail_) {
185 LOG(FATAL) << "Found no rewrite rule for " << AsText(op, false) << std::endl;
186 } else {
187 DLOG(INFO) << "Found no rewrite rule for " << AsText(op, false) << std::endl;
188 return expr;
189 }
190 }
191 }
192 try {
193 return Mutate(expr);
194 } catch (std::exception& e) {
195 if (hard_fail_) {
196 LOG(FATAL) << e.what();
197 } else {
198 DLOG(INFO) << "Ran into an error rewriting a subgraph, skipping" << expr << std::endl;
199 return expr;
200 }
201 }
202 }
203
204 protected:
205 Expr VisitExpr_(const CallNode* call_node) {
206 Expr out;
207
208 static auto fqfq =
209 Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMFakeQuantizationToInteger");
210 Op op = Downcast<Op>(call_node->op);
211 if (fqfq.count(op)) {
212 Expr expr;
213 if (op == dequantize_op_) {
214 expr = GetRef<Expr>(call_node);
215 } else {
216 expr = ExprMutator::VisitExpr_(call_node);
217 // Set the current op to the output type, useful if we can't deduce output parameters
218 // from input parameters
219 affine_types_.Set(expr, out_type_);
220 }
221 // Call the rewrite
222 Array<ObjectRef> vals = fqfq[op](expr, affine_types_);
223 // Save the outputs of the rewrite
224 ICHECK(vals.size() == 2)
225 << "got the wrong number of returned arguments from FTVMFakeQuantizationToInteger for "
226 << AsText(op, false);
227 out = Downcast<Expr>(vals[0]);
228 affine_types_.Set(out, Downcast<AffineType>(vals[1]));
229 } else {
230 ICHECK(false) << "When rewriting a fake quantized graph, found an invalid node "
231 << AsText(GetRef<Expr>(call_node), false);
232 }
233 return out;
234 }
235
236 Expr VisitExpr_(const TupleNode* node) {
237 Expr expr = ExprMutator::VisitExpr_(node);
238 auto new_node = expr.as<TupleNode>();
239 Array<TensorAffineType> types;
240 for (Expr field : new_node->fields) {
241 ICHECK(affine_types_[field].as<TensorAffineTypeNode>());
242 types.push_back(Downcast<TensorAffineType>(affine_types_[field]));
243 }
244 affine_types_.Set(expr, TupleAffineType(types));
245 return expr;
246 }
247
248 Expr VisitExpr_(const TupleGetItemNode* node) {
249 Expr expr = ExprMutator::VisitExpr_(node);
250 auto tuple_type = affine_types_[expr.as<TupleGetItemNode>()->tuple].as<TupleAffineTypeNode>();
251 affine_types_.Set(expr, tuple_type->types[node->index]);
252 return expr;
253 }
254
255 ExprSet subgraph_;
256 AffineTypeMap affine_types_;
257 AffineType out_type_;
258 const bool hard_fail_;
259 const Op quantize_op_ = Op::Get("qnn.quantize");
260 const Op dequantize_op_ = Op::Get("qnn.dequantize");
261};
262
263class FakeQuantizationRewriter : public MixedModeMutator {
264 public:
265 explicit FakeQuantizationRewriter(bool hard_fail) : hard_fail_(hard_fail) {}
266
267 protected:
268 Expr Rewrite_(const CallNode* pre, const Expr& post) override {
269 if (const CallNode* call_node = post.as<CallNode>()) {
270 if (call_node->op == quantize_op_) {
271 SubgraphExtractor extractor;
272 ExprSet subgraph = extractor.GetSubgraph(GetRef<Expr>(pre));
273 AffineTypeMap affine_types = extractor.GetAffineTypes();
274
275 ExprSet post_subgraph;
276 AffineTypeMap post_affine_types;
277
278 for (auto kv : affine_types) {
279 if (pre == kv.first.as<CallNode>()) {
280 // we havent memoized the current op yet
281 post_affine_types.Set(post, kv.second);
282 } else {
283 post_affine_types.Set(memo_.at(kv.first), kv.second);
284 }
285 }
286 for (auto expr : subgraph) {
287 post_subgraph.insert(memo_[expr]);
288 }
289 Expr out =
290 SubgraphMutator(post_subgraph, post_affine_types, hard_fail_).MutateSubgraph(post);
291 return out;
292 }
293 }
294 return post;
295 }
296 const Op quantize_op_ = Op::Get("qnn.quantize");
297 const bool hard_fail_;
298};
299
300/* Checks if the operation to convert QAT pass is enabled.
301 * The following conditions must be satisfied:
302 * 1. operations registered for FTVMFakeQuantizationToInteger;
303 * 2. Unary operators or operators with the TensorAffineType calculated during
304 * FTVMFakeQuantizationToInteger conversion;
305 * 3. Not one of the "key" operations: requantize,quantize and dequantize(they are at the boundaries
306 * of regions defined to be quantized).
307 */
308bool is_op_enabled_for_optional_fq2i(const CallNode* call_node) {
309 const Op op = Downcast<Op>(call_node->op);
310 static auto fqfq = Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMFakeQuantizationToInteger");
311 static std::unordered_set<relay::Expr, tvm::ObjectHash, tvm::ObjectEqual> ops = {
312 Op::Get("broadcast_to"),
313 Op::Get("clip"),
314 Op::Get("expand_dims"),
315 Op::Get("max"),
316 Op::Get("maximum"),
317 Op::Get("min"),
318 Op::Get("minimum"),
319 Op::Get("nn.avg_pool2d"),
320 Op::Get("nn.batch_flatten"),
321 Op::Get("nn.batch_matmul"),
322 Op::Get("nn.bias_add"),
323 Op::Get("nn.conv2d"),
324 Op::Get("nn.conv2d_transpose"),
325 Op::Get("nn.dense"),
326 Op::Get("nn.depth_to_space"),
327 Op::Get("nn.global_avg_pool2d"),
328 Op::Get("nn.max_pool2d"),
329 Op::Get("nn.pad"),
330 Op::Get("nn.relu"),
331 Op::Get("reshape"),
332 Op::Get("split"),
333 Op::Get("squeeze"),
334 Op::Get("strided_slice"),
335 Op::Get("transpose")};
336
337 return ops.find(call_node->op) != ops.end() && fqfq.count(Downcast<Op>(op));
338}
339
340class QATSubgraphExtractor : public ExprVisitor {
341 public:
342 const ExprSet GetSubgraph(const Expr& expr) {
343 expr_call_node_ = expr.as<CallNode>();
344 ICHECK(expr_call_node_ != nullptr);
345 ICHECK(is_op_enabled_for_optional_fq2i(expr_call_node_));
346
347 VisitExpr(expr);
348
349 ExprSet subgraph;
350 if (is_fake_quantized_) {
351 for (auto kv : this->visit_counter_) {
352 if (auto call_node = GetRef<ObjectRef>(kv.first).as<CallNode>()) {
353 if (call_node != expr_call_node_) {
354 subgraph.insert(Downcast<Expr>(GetRef<ObjectRef>(kv.first)));
355 }
356 }
357 }
358 }
359 return subgraph;
360 }
361 const AffineTypeMap GetAffineTypes() { return affine_types_; }
362 void VisitExpr(const Expr& expr) override {
363 // When looking for fake quantized subgraphs, we only support data-flow regions of the graph,
364 // i.e. call nodes/tuples/constants/etc. If we see anything else (like control flow) we
365 // abort the rewrite.
366 if (expr.as<CallNode>() == nullptr && expr.as<OpNode>() == nullptr &&
367 expr.as<TupleNode>() == nullptr && expr.as<TupleGetItemNode>() == nullptr &&
368 expr.as<ConstantNode>() == nullptr) {
369 DLOG(INFO) << "FakeQuantizationToInteger found a non - dataflow op inside a fake quantize "
370 "region, aborting this rewrite";
371 is_fake_quantized_ = false;
372 } else {
373 ExprVisitor::VisitExpr(expr);
374 }
375 }
376
377 protected:
378 void VisitExpr_(const CallNode* call_node) override {
379 if (call_node->op == dequantize_op_) {
380 const auto* attrs = call_node->attrs.as<qnn::DequantizeAttrs>();
381 ICHECK(attrs != nullptr);
382
383 affine_types_.Set(
384 GetRef<Expr>(call_node),
385 TensorAffineType(
386 call_node->args[1], call_node->args[2],
387 tvm::relay::transform::InferTypeLocal(call_node->args[0]).as<TensorTypeNode>()->dtype,
388 attrs->axis));
389 } else if (call_node == expr_call_node_) {
390 for (auto arg : call_node->args) {
391 VisitExpr(arg);
392 }
393 } else {
394 // run normally on everything else.
395 ExprVisitor::VisitExpr_(call_node);
396 }
397 }
398
399 const Op dequantize_op_ = Op::Get("qnn.dequantize");
400 bool is_fake_quantized_ = true;
401 AffineTypeMap affine_types_;
402 const CallNode* expr_call_node_ = nullptr;
403};
404
405class QATSubgraphMutator : public ExprMutator {
406 public:
407 QATSubgraphMutator(ExprSet subgraph, AffineTypeMap affine_types, bool hard_fail)
408 : subgraph_(subgraph), affine_types_(affine_types), hard_fail_(hard_fail) {}
409
410 Expr MutateSubgraph(const Expr& expr) {
411 if (subgraph_.size() == 0) {
412 return expr;
413 }
414
415 quantize_node_ = expr.as<CallNode>();
416 ICHECK(quantize_node_);
417 ICHECK(is_op_enabled_for_optional_fq2i(quantize_node_));
418
419 for (auto node : subgraph_) {
420 const Op op = Downcast<Op>(node.as<CallNode>()->op);
421
422 if (node.as<CallNode>()->op != dequantize_op_) {
423 if (hard_fail_) {
424 LOG(FATAL) << "Not dequantization was found in the input arguments for"
425 << AsText(op, false) << std::endl;
426 } else {
427 DLOG(INFO) << "Not dequantization was found in the input arguments for "
428 << AsText(op, false) << std::endl;
429 return expr;
430 }
431 }
432 }
433 try {
434 return Mutate(expr);
435 } catch (std::exception& e) {
436 if (hard_fail_) {
437 throw e;
438 } else {
439 DLOG(INFO) << "Ran into an error rewriting a subgraph, skipping" << expr << std::endl;
440 return expr;
441 }
442 }
443 }
444
445 protected:
446 Expr VisitExpr_(const CallNode* call_node) {
447 Expr out;
448 static auto fqfq =
449 Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMFakeQuantizationToInteger");
450
451 Op op = Downcast<Op>(call_node->op);
452 if (fqfq.count(op)) {
453 Expr expr;
454 if (op == dequantize_op_) {
455 expr = GetRef<Expr>(call_node);
456 } else {
457 expr = ExprMutator::VisitExpr_(call_node);
458 }
459 // Call the rewrite
460 Array<ObjectRef> vals = fqfq[op](expr, affine_types_);
461 // Save the outputs of the rewrite
462 ICHECK(vals.size() == 2)
463 << "got the wrong number of returned arguments from FTVMFakeQuantizationToInteger for "
464 << AsText(op, false);
465 out = Downcast<Expr>(vals[0]);
466
467 affine_types_.Set(out, Downcast<AffineType>(vals[1]));
468
469 if (call_node == quantize_node_) {
470 out = qnn::MakeDequantize(out, vals[1].as<TensorAffineTypeNode>()->scale,
471 vals[1].as<TensorAffineTypeNode>()->zero_point,
472 vals[1].as<TensorAffineTypeNode>()->axis);
473 }
474 } else {
475 ICHECK(false) << "When rewriting a fake quantized graph, found an invalid node "
476 << AsText(GetRef<Expr>(call_node), false);
477 }
478 return out;
479 }
480
481 Expr VisitExpr_(const TupleNode* node) {
482 Expr expr = ExprMutator::VisitExpr_(node);
483 auto new_node = expr.as<TupleNode>();
484 Array<TensorAffineType> types;
485 for (Expr field : new_node->fields) {
486 ICHECK(affine_types_[field].as<TensorAffineTypeNode>());
487 types.push_back(Downcast<TensorAffineType>(affine_types_[field]));
488 }
489 affine_types_.Set(expr, TupleAffineType(types));
490 return expr;
491 }
492
493 Expr VisitExpr_(const TupleGetItemNode* node) {
494 Expr expr = ExprMutator::VisitExpr_(node);
495 auto tuple_type = affine_types_[expr.as<TupleGetItemNode>()->tuple].as<TupleAffineTypeNode>();
496 affine_types_.Set(expr, tuple_type->types[node->index]);
497 return expr;
498 }
499
500 ExprSet subgraph_;
501 AffineTypeMap affine_types_;
502 const bool hard_fail_;
503 const Op dequantize_op_ = Op::Get("qnn.dequantize");
504 const CallNode* quantize_node_ = nullptr;
505};
506
507class QATRewriter : public MixedModeMutator {
508 public:
509 explicit QATRewriter(bool hard_fail) : hard_fail_(hard_fail) {}
510
511 protected:
512 Expr Rewrite_(const CallNode* pre, const Expr& post) override {
513 if (const CallNode* call_node = post.as<CallNode>()) {
514 const Op op = Downcast<Op>(call_node->op);
515 if (is_op_enabled_for_optional_fq2i(call_node)) {
516 QATSubgraphExtractor extractor;
517 ExprSet subgraph = extractor.GetSubgraph(post);
518 AffineTypeMap affine_types = extractor.GetAffineTypes();
519 Expr out = QATSubgraphMutator(subgraph, affine_types, hard_fail_).MutateSubgraph(post);
520 return out;
521 }
522 }
523 return post;
524 }
525 const bool hard_fail_;
526};
527
528Expr FakeQuantizationToInteger(const Expr& expr, const IRModule& mod, bool hard_fail,
529 bool use_qat) {
530 auto fq_expr = FakeQuantizationRewriter(hard_fail).Mutate(expr);
531 if (use_qat) {
532 fq_expr = tvm::relay::InferType(fq_expr);
533 fq_expr = QATRewriter(hard_fail).Mutate(fq_expr);
534 }
535 return fq_expr;
536}
537
538namespace transform {
539
// Wrap the expression-level rewrite in a function-level pass. Declares
// InferType and DivToMul as required prior passes.
Pass FakeQuantizationToInteger(bool hard_fail, bool use_qat) {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        return Downcast<Function>(FakeQuantizationToInteger(f, m, hard_fail, use_qat));
      };
  return CreateFunctionPass(pass_func, 0, "FakeQuantizationToInteger", {"InferType", "DivToMul"});
}

// Expose the pass constructor to the Python frontend.
TVM_REGISTER_GLOBAL("relay._transform.FakeQuantizationToInteger")
    .set_body_typed(FakeQuantizationToInteger);
550
551} // namespace transform
552
553} // namespace relay
554} // namespace tvm
555