1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/transforms/quantize_fake_quantization.cc |
22 | * \brief A pass for taking fake quantized graphs and converting them |
23 | * to actual integer operations. |
24 | */ |
25 | |
26 | #include "fake_quantization_to_integer.h" |
27 | |
28 | #include <tvm/ir/affine_type.h> |
29 | #include <tvm/relay/attrs/nn.h> |
30 | #include <tvm/relay/dataflow_matcher.h> |
31 | #include <tvm/relay/expr.h> |
32 | #include <tvm/relay/expr_functor.h> |
33 | #include <tvm/relay/qnn/attrs.h> |
34 | #include <tvm/relay/transform.h> |
35 | |
36 | #include <unordered_map> |
37 | |
38 | #include "../qnn/utils.h" |
39 | |
40 | namespace tvm { |
41 | namespace relay { |
42 | |
43 | /* Description of FakeQuantizationToInteger |
44 | * |
45 | * The purpose of this pass is to find regions of the graph that follow |
46 | * the general pattern: |
47 | * |
48 | * x w |
49 | * | | |
50 | * dq dq |
51 | * \ / |
52 | * op1 |
53 | * | |
54 | * op2 |
55 | * | |
56 | * q |
57 | * |
58 | * and convert them into subgraphs with actual integer operations on x and w |
59 | * |
60 | * The pass does this via a multi-pass approach: |
61 | * |
62 | * The main pass is a MixedModeMutator that traverses the full graph searching for |
63 | * quantize operations |
64 | * |
 * The second pass is an ExprVisitor that recursively searches for subgraphs leading to the
 * quantize for subgraphs bounded by dequantize operations. This pass extracts the affine
 * types of the inputs for later processing, where affine denotes the transformation
 * x_real = (x_affine - zero_point) * scale
69 | * |
 * The third pass is an ExprMutator that recursively rewrites the subgraphs using packed funcs
 * registered with the FTVMFakeQuantizationToInteger attribute. These packed funcs rewrite
 * the ops based on the affine types of their inputs and then return the affine types of the
 * new rewritten ops to pass that information down the stack during rewrite.
74 | * |
75 | * After the second and third passes run, the first pass replaces the quantize with the |
76 | * rewritten subgraph and the processing continues |
77 | * |
78 | * |
79 | * After that an additional QAT pass can be enabled by use_qat flag. The goal of the pass is to find |
80 | * operations in those regions(which were not successfully converted by the main pass) that can |
81 | * still be converted into quantized form. The idea is to find and transform operations with |
82 | * dequantized inputs one by one individually. Only operations for which all parameters can be |
83 | * explicitly calculated are allowed. For example, if on the above general pattern op2 is not |
84 | * registered with the FTVMFakeQuantizationToInteger attribute, op1 operation can still be |
85 | * converted. Converted pattern below: |
86 | * |
87 | * x w |
88 | * | | |
89 | * \ / |
90 | * op1 |
91 | * | |
92 | * dq |
93 | * | |
94 | * op2 |
95 | * | |
96 | * q |
97 | * |
98 | * This pass works in the same multi-pass approach. |
99 | */ |
100 | |
101 | using ExprSet = std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual>; |
102 | using ExprMap = std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual>; |
103 | using AffineTypeMap = Map<Expr, AffineType>; |
104 | |
105 | using FTVMFakeQuantizationToInteger = |
106 | runtime::TypedPackedFunc<Array<ObjectRef>(const Expr& expr, const AffineTypeMap& map)>; |
107 | |
108 | const ExprSet SubgraphExtractor::(const Expr& expr) { |
109 | VisitExpr(expr); |
110 | ExprSet subgraph; |
111 | if (is_fake_quantized_) { |
112 | for (auto kv : this->visit_counter_) { |
113 | if (auto call_node = GetRef<ObjectRef>(kv.first).as<CallNode>()) { |
114 | if (call_node->op != quantize_op_) { |
115 | subgraph.insert(Downcast<Expr>(GetRef<ObjectRef>(kv.first))); |
116 | } |
117 | } |
118 | } |
119 | } |
120 | return subgraph; |
121 | } |
122 | const AffineTypeMap SubgraphExtractor::() { return affine_types_; } |
123 | void SubgraphExtractor::(const Expr& expr) { |
124 | // When looking for fake quantized subgraphs, we only support data-flow regions of the graph, |
125 | // i.e. call nodes/tuples/constants/etc. If we see anything else (like control flow) we |
126 | // abort the rewrite. |
127 | if (expr.as<CallNode>() == nullptr && expr.as<OpNode>() == nullptr && |
128 | expr.as<TupleNode>() == nullptr && expr.as<TupleGetItemNode>() == nullptr && |
129 | expr.as<ConstantNode>() == nullptr) { |
130 | DLOG(INFO) << "FakeQuantizationToInteger found a non-dataflow op inside" |
131 | << " a fake quantize region, aborting this rewrite" ; |
132 | is_fake_quantized_ = false; |
133 | } else { |
134 | ExprVisitor::VisitExpr(expr); |
135 | } |
136 | } |
137 | |
138 | void SubgraphExtractor::(const CallNode* call_node) { |
139 | const Op test_op = Downcast<Op>(call_node->op); |
140 | if (call_node->op == quantize_op_) { |
141 | const auto* attrs = call_node->attrs.as<qnn::QuantizeAttrs>(); |
142 | ICHECK(attrs != nullptr); |
143 | // Only look at arg0 for quantize |
144 | VisitExpr(call_node->args[0]); |
145 | // Collect type of quantize ops |
146 | affine_types_.Set( |
147 | GetRef<Expr>(call_node), |
148 | TensorAffineType(call_node->args[1], call_node->args[2], attrs->out_dtype, attrs->axis)); |
149 | } else if (call_node->op == dequantize_op_) { |
150 | const auto* attrs = call_node->attrs.as<qnn::DequantizeAttrs>(); |
151 | ICHECK(attrs != nullptr); |
152 | // Collect type of dequantize ops |
153 | affine_types_.Set( |
154 | GetRef<Expr>(call_node), |
155 | TensorAffineType(call_node->args[1], call_node->args[2], |
156 | call_node->args[0]->checked_type().as<TensorTypeNode>()->dtype, |
157 | attrs->axis)); |
158 | } else { |
159 | // run normally on everything else. |
160 | ExprVisitor::VisitExpr_(call_node); |
161 | } |
162 | } |
163 | |
164 | class SubgraphMutator : public ExprMutator { |
165 | public: |
166 | SubgraphMutator(ExprSet subgraph, AffineTypeMap affine_types, bool hard_fail) |
167 | : subgraph_(subgraph), affine_types_(affine_types), hard_fail_(hard_fail) {} |
168 | |
169 | Expr MutateSubgraph(const Expr& expr) { |
170 | if (subgraph_.size() == 0) { |
171 | return expr; |
172 | } |
173 | const CallNode* quantize_node = expr.as<CallNode>(); |
174 | ICHECK(quantize_node); |
175 | ICHECK(quantize_node->op == quantize_op_); |
176 | out_type_ = affine_types_[expr]; |
177 | static auto fqfq = |
178 | Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMFakeQuantizationToInteger" ); |
179 | for (auto node : subgraph_) { |
180 | const Op op = Downcast<Op>(node.as<CallNode>()->op); |
181 | if (!fqfq.count(Downcast<Op>(op))) { |
182 | // Only modify the subgraph if we have translation |
183 | // rules for every op |
184 | if (hard_fail_) { |
185 | LOG(FATAL) << "Found no rewrite rule for " << AsText(op, false) << std::endl; |
186 | } else { |
187 | DLOG(INFO) << "Found no rewrite rule for " << AsText(op, false) << std::endl; |
188 | return expr; |
189 | } |
190 | } |
191 | } |
192 | try { |
193 | return Mutate(expr); |
194 | } catch (std::exception& e) { |
195 | if (hard_fail_) { |
196 | LOG(FATAL) << e.what(); |
197 | } else { |
198 | DLOG(INFO) << "Ran into an error rewriting a subgraph, skipping" << expr << std::endl; |
199 | return expr; |
200 | } |
201 | } |
202 | } |
203 | |
204 | protected: |
205 | Expr VisitExpr_(const CallNode* call_node) { |
206 | Expr out; |
207 | |
208 | static auto fqfq = |
209 | Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMFakeQuantizationToInteger" ); |
210 | Op op = Downcast<Op>(call_node->op); |
211 | if (fqfq.count(op)) { |
212 | Expr expr; |
213 | if (op == dequantize_op_) { |
214 | expr = GetRef<Expr>(call_node); |
215 | } else { |
216 | expr = ExprMutator::VisitExpr_(call_node); |
217 | // Set the current op to the output type, useful if we can't deduce output parameters |
218 | // from input parameters |
219 | affine_types_.Set(expr, out_type_); |
220 | } |
221 | // Call the rewrite |
222 | Array<ObjectRef> vals = fqfq[op](expr, affine_types_); |
223 | // Save the outputs of the rewrite |
224 | ICHECK(vals.size() == 2) |
225 | << "got the wrong number of returned arguments from FTVMFakeQuantizationToInteger for " |
226 | << AsText(op, false); |
227 | out = Downcast<Expr>(vals[0]); |
228 | affine_types_.Set(out, Downcast<AffineType>(vals[1])); |
229 | } else { |
230 | ICHECK(false) << "When rewriting a fake quantized graph, found an invalid node " |
231 | << AsText(GetRef<Expr>(call_node), false); |
232 | } |
233 | return out; |
234 | } |
235 | |
236 | Expr VisitExpr_(const TupleNode* node) { |
237 | Expr expr = ExprMutator::VisitExpr_(node); |
238 | auto new_node = expr.as<TupleNode>(); |
239 | Array<TensorAffineType> types; |
240 | for (Expr field : new_node->fields) { |
241 | ICHECK(affine_types_[field].as<TensorAffineTypeNode>()); |
242 | types.push_back(Downcast<TensorAffineType>(affine_types_[field])); |
243 | } |
244 | affine_types_.Set(expr, TupleAffineType(types)); |
245 | return expr; |
246 | } |
247 | |
248 | Expr VisitExpr_(const TupleGetItemNode* node) { |
249 | Expr expr = ExprMutator::VisitExpr_(node); |
250 | auto tuple_type = affine_types_[expr.as<TupleGetItemNode>()->tuple].as<TupleAffineTypeNode>(); |
251 | affine_types_.Set(expr, tuple_type->types[node->index]); |
252 | return expr; |
253 | } |
254 | |
255 | ExprSet subgraph_; |
256 | AffineTypeMap affine_types_; |
257 | AffineType out_type_; |
258 | const bool hard_fail_; |
259 | const Op quantize_op_ = Op::Get("qnn.quantize" ); |
260 | const Op dequantize_op_ = Op::Get("qnn.dequantize" ); |
261 | }; |
262 | |
263 | class FakeQuantizationRewriter : public MixedModeMutator { |
264 | public: |
265 | explicit FakeQuantizationRewriter(bool hard_fail) : hard_fail_(hard_fail) {} |
266 | |
267 | protected: |
268 | Expr Rewrite_(const CallNode* pre, const Expr& post) override { |
269 | if (const CallNode* call_node = post.as<CallNode>()) { |
270 | if (call_node->op == quantize_op_) { |
271 | SubgraphExtractor ; |
272 | ExprSet subgraph = extractor.GetSubgraph(GetRef<Expr>(pre)); |
273 | AffineTypeMap affine_types = extractor.GetAffineTypes(); |
274 | |
275 | ExprSet post_subgraph; |
276 | AffineTypeMap post_affine_types; |
277 | |
278 | for (auto kv : affine_types) { |
279 | if (pre == kv.first.as<CallNode>()) { |
280 | // we havent memoized the current op yet |
281 | post_affine_types.Set(post, kv.second); |
282 | } else { |
283 | post_affine_types.Set(memo_.at(kv.first), kv.second); |
284 | } |
285 | } |
286 | for (auto expr : subgraph) { |
287 | post_subgraph.insert(memo_[expr]); |
288 | } |
289 | Expr out = |
290 | SubgraphMutator(post_subgraph, post_affine_types, hard_fail_).MutateSubgraph(post); |
291 | return out; |
292 | } |
293 | } |
294 | return post; |
295 | } |
296 | const Op quantize_op_ = Op::Get("qnn.quantize" ); |
297 | const bool hard_fail_; |
298 | }; |
299 | |
/* Checks whether an operation is eligible for conversion by the optional QAT pass.
 * The following conditions must be satisfied:
 * 1. The operation is registered for FTVMFakeQuantizationToInteger;
 * 2. It is a unary operator, or an operator whose TensorAffineType is calculated during
 *    FTVMFakeQuantizationToInteger conversion;
 * 3. It is not one of the "key" operations: requantize, quantize, and dequantize (they are at
 *    the boundaries of regions defined to be quantized).
 */
308 | bool is_op_enabled_for_optional_fq2i(const CallNode* call_node) { |
309 | const Op op = Downcast<Op>(call_node->op); |
310 | static auto fqfq = Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMFakeQuantizationToInteger" ); |
311 | static std::unordered_set<relay::Expr, tvm::ObjectHash, tvm::ObjectEqual> ops = { |
312 | Op::Get("broadcast_to" ), |
313 | Op::Get("clip" ), |
314 | Op::Get("expand_dims" ), |
315 | Op::Get("max" ), |
316 | Op::Get("maximum" ), |
317 | Op::Get("min" ), |
318 | Op::Get("minimum" ), |
319 | Op::Get("nn.avg_pool2d" ), |
320 | Op::Get("nn.batch_flatten" ), |
321 | Op::Get("nn.batch_matmul" ), |
322 | Op::Get("nn.bias_add" ), |
323 | Op::Get("nn.conv2d" ), |
324 | Op::Get("nn.conv2d_transpose" ), |
325 | Op::Get("nn.dense" ), |
326 | Op::Get("nn.depth_to_space" ), |
327 | Op::Get("nn.global_avg_pool2d" ), |
328 | Op::Get("nn.max_pool2d" ), |
329 | Op::Get("nn.pad" ), |
330 | Op::Get("nn.relu" ), |
331 | Op::Get("reshape" ), |
332 | Op::Get("split" ), |
333 | Op::Get("squeeze" ), |
334 | Op::Get("strided_slice" ), |
335 | Op::Get("transpose" )}; |
336 | |
337 | return ops.find(call_node->op) != ops.end() && fqfq.count(Downcast<Op>(op)); |
338 | } |
339 | |
340 | class : public ExprVisitor { |
341 | public: |
342 | const ExprSet (const Expr& expr) { |
343 | expr_call_node_ = expr.as<CallNode>(); |
344 | ICHECK(expr_call_node_ != nullptr); |
345 | ICHECK(is_op_enabled_for_optional_fq2i(expr_call_node_)); |
346 | |
347 | VisitExpr(expr); |
348 | |
349 | ExprSet subgraph; |
350 | if (is_fake_quantized_) { |
351 | for (auto kv : this->visit_counter_) { |
352 | if (auto call_node = GetRef<ObjectRef>(kv.first).as<CallNode>()) { |
353 | if (call_node != expr_call_node_) { |
354 | subgraph.insert(Downcast<Expr>(GetRef<ObjectRef>(kv.first))); |
355 | } |
356 | } |
357 | } |
358 | } |
359 | return subgraph; |
360 | } |
361 | const AffineTypeMap () { return affine_types_; } |
362 | void (const Expr& expr) override { |
363 | // When looking for fake quantized subgraphs, we only support data-flow regions of the graph, |
364 | // i.e. call nodes/tuples/constants/etc. If we see anything else (like control flow) we |
365 | // abort the rewrite. |
366 | if (expr.as<CallNode>() == nullptr && expr.as<OpNode>() == nullptr && |
367 | expr.as<TupleNode>() == nullptr && expr.as<TupleGetItemNode>() == nullptr && |
368 | expr.as<ConstantNode>() == nullptr) { |
369 | DLOG(INFO) << "FakeQuantizationToInteger found a non - dataflow op inside a fake quantize " |
370 | "region, aborting this rewrite" ; |
371 | is_fake_quantized_ = false; |
372 | } else { |
373 | ExprVisitor::VisitExpr(expr); |
374 | } |
375 | } |
376 | |
377 | protected: |
378 | void (const CallNode* call_node) override { |
379 | if (call_node->op == dequantize_op_) { |
380 | const auto* attrs = call_node->attrs.as<qnn::DequantizeAttrs>(); |
381 | ICHECK(attrs != nullptr); |
382 | |
383 | affine_types_.Set( |
384 | GetRef<Expr>(call_node), |
385 | TensorAffineType( |
386 | call_node->args[1], call_node->args[2], |
387 | tvm::relay::transform::InferTypeLocal(call_node->args[0]).as<TensorTypeNode>()->dtype, |
388 | attrs->axis)); |
389 | } else if (call_node == expr_call_node_) { |
390 | for (auto arg : call_node->args) { |
391 | VisitExpr(arg); |
392 | } |
393 | } else { |
394 | // run normally on everything else. |
395 | ExprVisitor::VisitExpr_(call_node); |
396 | } |
397 | } |
398 | |
399 | const Op = Op::Get("qnn.dequantize" ); |
400 | bool = true; |
401 | AffineTypeMap ; |
402 | const CallNode* = nullptr; |
403 | }; |
404 | |
405 | class QATSubgraphMutator : public ExprMutator { |
406 | public: |
407 | QATSubgraphMutator(ExprSet subgraph, AffineTypeMap affine_types, bool hard_fail) |
408 | : subgraph_(subgraph), affine_types_(affine_types), hard_fail_(hard_fail) {} |
409 | |
410 | Expr MutateSubgraph(const Expr& expr) { |
411 | if (subgraph_.size() == 0) { |
412 | return expr; |
413 | } |
414 | |
415 | quantize_node_ = expr.as<CallNode>(); |
416 | ICHECK(quantize_node_); |
417 | ICHECK(is_op_enabled_for_optional_fq2i(quantize_node_)); |
418 | |
419 | for (auto node : subgraph_) { |
420 | const Op op = Downcast<Op>(node.as<CallNode>()->op); |
421 | |
422 | if (node.as<CallNode>()->op != dequantize_op_) { |
423 | if (hard_fail_) { |
424 | LOG(FATAL) << "Not dequantization was found in the input arguments for" |
425 | << AsText(op, false) << std::endl; |
426 | } else { |
427 | DLOG(INFO) << "Not dequantization was found in the input arguments for " |
428 | << AsText(op, false) << std::endl; |
429 | return expr; |
430 | } |
431 | } |
432 | } |
433 | try { |
434 | return Mutate(expr); |
435 | } catch (std::exception& e) { |
436 | if (hard_fail_) { |
437 | throw e; |
438 | } else { |
439 | DLOG(INFO) << "Ran into an error rewriting a subgraph, skipping" << expr << std::endl; |
440 | return expr; |
441 | } |
442 | } |
443 | } |
444 | |
445 | protected: |
446 | Expr VisitExpr_(const CallNode* call_node) { |
447 | Expr out; |
448 | static auto fqfq = |
449 | Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMFakeQuantizationToInteger" ); |
450 | |
451 | Op op = Downcast<Op>(call_node->op); |
452 | if (fqfq.count(op)) { |
453 | Expr expr; |
454 | if (op == dequantize_op_) { |
455 | expr = GetRef<Expr>(call_node); |
456 | } else { |
457 | expr = ExprMutator::VisitExpr_(call_node); |
458 | } |
459 | // Call the rewrite |
460 | Array<ObjectRef> vals = fqfq[op](expr, affine_types_); |
461 | // Save the outputs of the rewrite |
462 | ICHECK(vals.size() == 2) |
463 | << "got the wrong number of returned arguments from FTVMFakeQuantizationToInteger for " |
464 | << AsText(op, false); |
465 | out = Downcast<Expr>(vals[0]); |
466 | |
467 | affine_types_.Set(out, Downcast<AffineType>(vals[1])); |
468 | |
469 | if (call_node == quantize_node_) { |
470 | out = qnn::MakeDequantize(out, vals[1].as<TensorAffineTypeNode>()->scale, |
471 | vals[1].as<TensorAffineTypeNode>()->zero_point, |
472 | vals[1].as<TensorAffineTypeNode>()->axis); |
473 | } |
474 | } else { |
475 | ICHECK(false) << "When rewriting a fake quantized graph, found an invalid node " |
476 | << AsText(GetRef<Expr>(call_node), false); |
477 | } |
478 | return out; |
479 | } |
480 | |
481 | Expr VisitExpr_(const TupleNode* node) { |
482 | Expr expr = ExprMutator::VisitExpr_(node); |
483 | auto new_node = expr.as<TupleNode>(); |
484 | Array<TensorAffineType> types; |
485 | for (Expr field : new_node->fields) { |
486 | ICHECK(affine_types_[field].as<TensorAffineTypeNode>()); |
487 | types.push_back(Downcast<TensorAffineType>(affine_types_[field])); |
488 | } |
489 | affine_types_.Set(expr, TupleAffineType(types)); |
490 | return expr; |
491 | } |
492 | |
493 | Expr VisitExpr_(const TupleGetItemNode* node) { |
494 | Expr expr = ExprMutator::VisitExpr_(node); |
495 | auto tuple_type = affine_types_[expr.as<TupleGetItemNode>()->tuple].as<TupleAffineTypeNode>(); |
496 | affine_types_.Set(expr, tuple_type->types[node->index]); |
497 | return expr; |
498 | } |
499 | |
500 | ExprSet subgraph_; |
501 | AffineTypeMap affine_types_; |
502 | const bool hard_fail_; |
503 | const Op dequantize_op_ = Op::Get("qnn.dequantize" ); |
504 | const CallNode* quantize_node_ = nullptr; |
505 | }; |
506 | |
507 | class QATRewriter : public MixedModeMutator { |
508 | public: |
509 | explicit QATRewriter(bool hard_fail) : hard_fail_(hard_fail) {} |
510 | |
511 | protected: |
512 | Expr Rewrite_(const CallNode* pre, const Expr& post) override { |
513 | if (const CallNode* call_node = post.as<CallNode>()) { |
514 | const Op op = Downcast<Op>(call_node->op); |
515 | if (is_op_enabled_for_optional_fq2i(call_node)) { |
516 | QATSubgraphExtractor ; |
517 | ExprSet subgraph = extractor.GetSubgraph(post); |
518 | AffineTypeMap affine_types = extractor.GetAffineTypes(); |
519 | Expr out = QATSubgraphMutator(subgraph, affine_types, hard_fail_).MutateSubgraph(post); |
520 | return out; |
521 | } |
522 | } |
523 | return post; |
524 | } |
525 | const bool hard_fail_; |
526 | }; |
527 | |
528 | Expr FakeQuantizationToInteger(const Expr& expr, const IRModule& mod, bool hard_fail, |
529 | bool use_qat) { |
530 | auto fq_expr = FakeQuantizationRewriter(hard_fail).Mutate(expr); |
531 | if (use_qat) { |
532 | fq_expr = tvm::relay::InferType(fq_expr); |
533 | fq_expr = QATRewriter(hard_fail).Mutate(fq_expr); |
534 | } |
535 | return fq_expr; |
536 | } |
537 | |
538 | namespace transform { |
539 | |
540 | Pass FakeQuantizationToInteger(bool hard_fail, bool use_qat) { |
541 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
542 | [=](Function f, IRModule m, PassContext pc) { |
543 | return Downcast<Function>(FakeQuantizationToInteger(f, m, hard_fail, use_qat)); |
544 | }; |
545 | return CreateFunctionPass(pass_func, 0, "FakeQuantizationToInteger" , {"InferType" , "DivToMul" }); |
546 | } |
547 | |
548 | TVM_REGISTER_GLOBAL("relay._transform.FakeQuantizationToInteger" ) |
549 | .set_body_typed(FakeQuantizationToInteger); |
550 | |
551 | } // namespace transform |
552 | |
553 | } // namespace relay |
554 | } // namespace tvm |
555 | |