1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 *
22 * \file realize.cc
23 *
24 * \brief Realizing the simulated graph into real low-precision
25 * graph.
26 */
27
28#include "./realize.h"
29
30#include <tvm/relay/analysis.h>
31#include <tvm/relay/attrs/annotation.h>
32#include <tvm/relay/transform.h>
33
34#include "../op/annotation/annotation.h"
35#include "../qnn/utils.h"
36#include "../transforms/fold_constant.h"
37#include "./quantize.h"
38
39namespace tvm {
40namespace relay {
41namespace quantize {
42
43using namespace relay::transform;
44
45Expr QRealizeIntExprNode::Realize() const {
46 Expr data = this->data;
47 // dequantize
48 data = Cast(data, DataType::Float(32));
49 data = Multiply(data, this->dom_scale);
50 return data;
51}
52
53QRealizeIntExpr::QRealizeIntExpr(Expr data, Expr dom_scale, DataType dtype) {
54 ObjectPtr<QRealizeIntExprNode> n = make_object<QRealizeIntExprNode>();
55 n->data = std::move(data);
56 n->dom_scale = std::move(dom_scale);
57 n->dtype = std::move(dtype);
58 data_ = std::move(n);
59}
60
61inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) {
62 return Call(ref_call->op, args, ref_call->attrs, ref_call->type_args);
63}
64
65/* calculate `data * s1 / s2`, use shift if possible */
66inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype,
67 const Array<IndexExpr>& data_shape) {
68 const QConfig& cfg = QConfig::Current();
69 // here we assume the dtype of data is dtype activation
70 if (s1 == s2) return data;
71
72 float factor = s1 / s2;
73 float shift_factor = std::log2(factor);
74 ICHECK_GT(shift_factor, 0);
75 if (static_cast<int>(shift_factor) == shift_factor) {
76 return LeftShift(data, MakeConstantScalar(dtype, static_cast<int>(shift_factor)));
77 } else if (static_cast<int>(factor) == factor) {
78 return Multiply(data, MakeConstantScalar(dtype, factor));
79 } else {
80 if (cfg->rounding == "UPWARD") {
81 auto [fixed_point_multiplier, shift] = qnn::GetFixedPointMultiplierShift(factor);
82 data = relay::FixedPointMultiply(data, fixed_point_multiplier, shift);
83 } else {
84 data = qnn::FixedPointMultiplyToNearest(data, factor, data_shape);
85 }
86
87 return Cast(data, dtype);
88 }
89}
90
91Expr QuantizeRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
92 const QConfig& cfg = QConfig::Current();
93 // do not handle data type cast
94 const auto param = ref_call->attrs.as<SimulatedQuantizeAttrs>();
95 ICHECK_EQ(param->rounding, "round");
96
97 Expr dom_scale = new_args[1];
98 Expr clip_min = new_args[2];
99 Expr clip_max = new_args[3];
100
101 float dom_scale_imm = GetScalarFromConstant<float>(dom_scale);
102 float clip_min_imm = GetScalarFromConstant<float>(clip_min);
103 float clip_max_imm = GetScalarFromConstant<float>(clip_max);
104
105 // x * idom_scale = y * odom_scale
106 // => y = x * idom_scale / odom_scale
107 if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
108 // int32->int8
109 Expr data = n->data;
110 float idom_scale_imm = GetScalarFromConstant<float>(n->dom_scale);
111 float odom_scale_imm = GetScalarFromConstant<float>(dom_scale);
112 if (idom_scale_imm == odom_scale_imm) {
113 // same domain scale, only clip
114 data = Clip(data, clip_min_imm, clip_max_imm);
115 return QRealizeIntExpr(data, dom_scale, n->dtype);
116 }
117
118 float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm);
119 ICHECK_NE(shift_nbit, 0);
120 if (static_cast<int>(shift_nbit) == shift_nbit) {
121 if (shift_nbit > 0) {
122 // use right shift
123 if (cfg->round_for_shift) {
124 float round_bias = std::pow(2.0, shift_nbit - 1);
125 data = Add(data, MakeConstantScalar(cfg->dtype_activation, static_cast<int>(round_bias)));
126 }
127 data = RightShift(data,
128 MakeConstantScalar(cfg->dtype_activation, static_cast<int>(shift_nbit)));
129 } else {
130 data = LeftShift(data,
131 MakeConstantScalar(cfg->dtype_activation, static_cast<int>(-shift_nbit)));
132 }
133 data = Clip(data, clip_min_imm, clip_max_imm);
134 return QRealizeIntExpr(data, dom_scale, n->dtype);
135 } else {
136 data = Cast(data, DataType::Int(64));
137 if (cfg->rounding == "UPWARD") {
138 auto [fixed_point_multiplier, shift] =
139 qnn::GetFixedPointMultiplierShift(idom_scale_imm / odom_scale_imm);
140 data = relay::FixedPointMultiply(data, fixed_point_multiplier, shift);
141 } else {
142 data = qnn::FixedPointMultiplyToNearest(data, idom_scale_imm / odom_scale_imm,
143 ref_call->type_as<TensorTypeNode>()->shape);
144 }
145 data = Cast(Clip(data, clip_min_imm, clip_max_imm), n->dtype);
146 return QRealizeIntExpr(data, dom_scale, n->dtype);
147 }
148 }
149
150 // quantize from real
151 ICHECK(!new_args[0]->IsInstance<TempExprNode>());
152 Expr data = new_args[0];
153 Expr scaled_data = Multiply(data, MakeConstantScalar(DataType::Float(32), 1 / dom_scale_imm));
154 Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm);
155 return QRealizeIntExpr(round_data, dom_scale, DataType::Float(32));
156}
157
158RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
159 .set_attr<FForwardRewrite>("FQRealizeRewrite", QuantizeRealize);
160
161Expr Conv2dRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
162 const QConfig& cfg = QConfig::Current();
163 ICHECK_EQ(new_args.size(), 2);
164 if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
165 const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
166 const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
167 Expr ldata = lhs->data;
168 if (lhs->dtype != cfg->dtype_input) {
169 ldata = Cast(ldata, cfg->dtype_input);
170 }
171 Expr rdata = Cast(rhs->data, cfg->dtype_weight);
172
173 const auto ref_attrs = ref_call->attrs.as<Conv2DAttrs>();
174 auto attrs = make_object<Conv2DAttrs>();
175 *attrs = *ref_attrs;
176 DataType out_dtype = cfg->dtype_activation;
177 attrs->out_dtype = out_dtype;
178
179 Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args);
180 Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
181 Expr dom_scale = FoldConstantExpr(mul);
182 return QRealizeIntExpr(ret, dom_scale, out_dtype);
183 }
184 ICHECK(!new_args[0]->IsInstance<TempExprNode>() || !new_args[1]->IsInstance<TempExprNode>());
185 return Expr(nullptr);
186}
187
188RELAY_REGISTER_OP("nn.conv2d").set_attr<FForwardRewrite>("FQRealizeRewrite", Conv2dRealize);
189
190Expr Conv1dRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
191 const QConfig& cfg = QConfig::Current();
192 CHECK_EQ(new_args.size(), 2);
193 if (!new_args[0]->IsInstance<TempExprNode>() && !new_args[1]->IsInstance<TempExprNode>()) {
194 return Expr(nullptr);
195 }
196 const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
197 CHECK(lhs);
198 const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
199 CHECK(rhs);
200
201 Expr ldata = lhs->data;
202 if (lhs->dtype != cfg->dtype_input) {
203 ldata = Cast(ldata, cfg->dtype_input);
204 }
205 Expr rdata = Cast(rhs->data, cfg->dtype_weight);
206
207 const auto ref_attrs = ref_call->attrs.as<Conv1DAttrs>();
208 auto attrs = make_object<Conv1DAttrs>();
209 *attrs = *ref_attrs;
210 DataType out_dtype = cfg->dtype_activation;
211 attrs->out_dtype = out_dtype;
212
213 Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args);
214 Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
215 Expr dom_scale = FoldConstantExpr(mul);
216 return QRealizeIntExpr(ret, dom_scale, out_dtype);
217}
218
219RELAY_REGISTER_OP("nn.conv1d").set_attr<FForwardRewrite>("FQRealizeRewrite", Conv1dRealize);
220
221Expr DenseRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
222 const QConfig& cfg = QConfig::Current();
223 ICHECK_EQ(new_args.size(), 2);
224 if (!new_args[0]->IsInstance<TempExprNode>() || !new_args[1]->IsInstance<TempExprNode>()) {
225 return Expr(nullptr);
226 }
227 const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
228 const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
229
230 Expr ldata = lhs->data;
231 if (lhs->dtype != cfg->dtype_input) {
232 ldata = Cast(ldata, cfg->dtype_input);
233 }
234 Expr rdata = Cast(rhs->data, cfg->dtype_weight);
235
236 const auto ref_attrs = ref_call->attrs.as<DenseAttrs>();
237 auto attrs = make_object<DenseAttrs>();
238 *attrs = *ref_attrs;
239 DataType out_dtype = cfg->dtype_activation;
240 attrs->out_dtype = out_dtype;
241
242 Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args);
243 Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
244 Expr dom_scale = FoldConstantExpr(mul);
245 return QRealizeIntExpr(ret, dom_scale, out_dtype);
246}
247
248RELAY_REGISTER_OP("nn.dense").set_attr<FForwardRewrite>("FQRealizeRewrite", DenseRealize);
249
250Expr MulRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
251 const QConfig& cfg = QConfig::Current();
252 ICHECK_EQ(new_args.size(), 2);
253 if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
254 // execute the operation with activation data type.
255 const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
256 const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
257 Expr ldata = lhs->data;
258 Expr rdata = rhs->data;
259
260 DataType dtype = cfg->dtype_activation;
261 if (lhs->dtype != dtype) {
262 ldata = Cast(ldata, dtype);
263 }
264 if (rhs->dtype != dtype) {
265 rdata = Cast(rdata, dtype);
266 }
267
268 Expr ret = ForwardOp(ref_call, {ldata, rdata});
269 Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
270 Expr dom_scale = FoldConstantExpr(mul);
271 return QRealizeIntExpr(ret, dom_scale, dtype);
272 }
273 ICHECK(!new_args[0]->IsInstance<TempExprNode>() || !new_args[1]->IsInstance<TempExprNode>());
274 return Expr(nullptr);
275}
276
277RELAY_REGISTER_OP("multiply").set_attr<FForwardRewrite>("FQRealizeRewrite", MulRealize);
278
279float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) {
280 if (nptrs.size() == 2) {
281 // x = a * s1, y = b * s2
282 // x + y = (a * s1 / s2 + b) * s2, if s1 > s2
283 // = (a + b * s2 / s1) * s1, if s2 > s1
284 float s1 = GetScalarFromConstant<float>(nptrs[0]->dom_scale);
285 float s2 = GetScalarFromConstant<float>(nptrs[1]->dom_scale);
286 return s1 > s2 ? s2 : s1;
287 } else {
288 const QConfig& cfg = QConfig::Current();
289 float scale = cfg->global_scale;
290 return scale / std::pow(2.0, cfg->nbit_activation - 1);
291 }
292}
293
294/* \brief Unify the dom scale of arguments */
295Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args,
296 DataType* dtype_ptr, Expr* scale_ptr,
297 DataType dtype = DataType::Void()) {
298 static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
299 const QConfig& cfg = QConfig::Current();
300
301 std::vector<const QRealizeIntExprNode*> nptrs;
302 Array<Expr> ret;
303 for (auto arg : args) {
304 const auto* nptr = arg.as<QRealizeIntExprNode>();
305 ICHECK(nptr);
306 nptrs.push_back(nptr);
307 ret.push_back(nptr->data);
308 }
309
310 // unify the data type
311 ICHECK_EQ(ref_args.size(), args.size());
312
313 if (dtype.is_void()) {
314 if (ret.size() == 2 && nptrs[1]->dtype == cfg->dtype_input) {
315 dtype = cfg->dtype_input;
316 } else {
317 dtype = cfg->dtype_activation;
318 }
319 }
320
321 for (size_t i = 0; i < ret.size(); ++i) {
322 auto ref_arg = ref_args[i].as<CallNode>();
323 if (nptrs[i]->dtype != dtype) {
324 ret.Set(i, Cast(ret[i], dtype));
325 } else if (ref_arg && ref_arg->op.same_as(simulated_quantize) &&
326 ref_arg->attrs.as<SimulatedQuantizeAttrs>()->kind == kQInput) {
327 auto new_arg = Cast(ret[i], cfg->dtype_input);
328 new_arg = StopFusion(new_arg);
329 ret.Set(i, Cast(new_arg, dtype));
330 }
331 }
332
333 // unify the dom_scale
334 float s = ChooseDomScale(nptrs);
335 Expr dom_scale = MakeConstantScalar(DataType::Float(32), s);
336 for (size_t i = 0; i < ret.size(); ++i) {
337 float cur_s = GetScalarFromConstant<float>(nptrs[i]->dom_scale);
338 ret.Set(i, MulAndDiv(ret[i], cur_s, s, dtype, ref_args[i]->type_as<TensorTypeNode>()->shape));
339 }
340
341 *dtype_ptr = dtype;
342 *scale_ptr = dom_scale;
343 return ret;
344}
345
346Expr AddRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
347 ICHECK_EQ(new_args.size(), 2);
348 if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
349 DataType dtype;
350 Expr dom_scale;
351 // execute the operation with activation data type.
352 const QConfig& cfg = QConfig::Current();
353 Array<Expr> ret_args =
354 UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale, cfg->dtype_activation);
355 for (size_t i = 0; i < ret_args.size(); ++i) {
356 // do not fuse float32 arg
357 if (new_args[i].as<QRealizeIntExprNode>()->dtype == DataType::Float(32)) {
358 ret_args.Set(i, StopFusion(ret_args[i]));
359 }
360 }
361 Expr ret = ForwardOp(ref_call, ret_args);
362 return QRealizeIntExpr(ret, dom_scale, dtype);
363 }
364
365 ICHECK(!new_args[0]->IsInstance<TempExprNode>() && !new_args[1]->IsInstance<TempExprNode>());
366 return Expr(nullptr);
367}
368
369RELAY_REGISTER_OP("add").set_attr<FForwardRewrite>("FQRealizeRewrite", AddRealize);
370
371Expr ClipRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
372 ICHECK_EQ(new_args.size(), 1);
373 if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
374 const auto ref_attrs = ref_call->attrs.as<ClipAttrs>();
375 auto attrs = make_object<ClipAttrs>();
376 double dom_scale = GetScalarFromConstant<float>(n->dom_scale);
377 attrs->a_min = ref_attrs->a_min / dom_scale;
378 attrs->a_max = ref_attrs->a_max / dom_scale;
379
380 Expr ret = Call(ref_call->op, {n->data}, Attrs(attrs), ref_call->type_args);
381 return QRealizeIntExpr(ret, n->dom_scale, n->dtype);
382 }
383 ICHECK(!new_args[0]->IsInstance<TempExprNode>());
384 return Expr(nullptr);
385}
386
387RELAY_REGISTER_OP("clip").set_attr<FForwardRewrite>("FQRealizeRewrite", ClipRealize);
388
389Expr ConcatenateRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
390 ICHECK_EQ(new_args.size(), 1);
391 ICHECK_EQ(ref_call->args.size(), 1);
392
393 const auto* tuple = new_args[0].as<TupleNode>();
394 const auto* ref_tuple = ref_call->args[0].as<TupleNode>();
395 ICHECK(tuple);
396 ICHECK(ref_tuple);
397 const Array<Expr>& arr = tuple->fields;
398 const Array<Expr>& ref_arr = ref_tuple->fields;
399
400 if (arr[0].as<QRealizeIntExprNode>()) {
401 DataType dtype;
402 Expr dom_scale;
403 Array<Expr> ret_args = UnifyDTypeScale(ref_arr, arr, &dtype, &dom_scale);
404 Expr ret = ForwardOp(ref_call, {Tuple(ret_args)});
405 return QRealizeIntExpr(ret, dom_scale, dtype);
406 } else {
407 for (auto arg : new_args) {
408 ICHECK(!arg->IsInstance<TempExprNode>());
409 }
410 return Expr(nullptr);
411 }
412}
413
414RELAY_REGISTER_OP("concatenate").set_attr<FForwardRewrite>("FQRealizeRewrite", ConcatenateRealize);
415
416/* \brief forward the original operator */
417Expr IdentityRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
418 ICHECK_EQ(new_args.size(), 1);
419 if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
420 Expr ret = ForwardOp(ref_call, {n->data});
421 return QRealizeIntExpr(ret, n->dom_scale, n->dtype);
422 }
423 ICHECK(!new_args[0]->IsInstance<TempExprNode>());
424 return Expr(nullptr);
425}
426
427RELAY_REGISTER_OP("nn.relu").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
428
429RELAY_REGISTER_OP("reshape").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
430
431RELAY_REGISTER_OP("strided_slice").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
432
433RELAY_REGISTER_OP("nn.batch_flatten")
434 .set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
435
436RELAY_REGISTER_OP("transpose").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
437
438RELAY_REGISTER_OP("annotation.stop_fusion")
439 .set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
440
441/* \brief for unary operators which requantize its input to dtype_nbit */
442Expr CastDtypeInputRealize(const Call& ref_call, const Array<Expr>& new_args,
443 const ObjectRef& ctx) {
444 const QConfig& cfg = QConfig::Current();
445 ICHECK_EQ(new_args.size(), 1);
446 if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
447 Expr data = Cast(n->data, cfg->dtype_input);
448 Expr ret = ForwardOp(ref_call, {data});
449 return QRealizeIntExpr(ret, n->dom_scale, cfg->dtype_input);
450 }
451 ICHECK(!new_args[0]->IsInstance<TempExprNode>());
452 return Expr(nullptr);
453}
454
455RELAY_REGISTER_OP("nn.max_pool2d")
456 .set_attr<FForwardRewrite>("FQRealizeRewrite", CastDtypeInputRealize);
457
458RELAY_REGISTER_OP("nn.max_pool1d")
459 .set_attr<FForwardRewrite>("FQRealizeRewrite", CastDtypeInputRealize);
460
461Expr AvgPoolRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
462 const QConfig& cfg = QConfig::Current();
463 ICHECK_EQ(new_args.size(), 1);
464 if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
465 Expr data = n->data;
466 if (n->dtype != cfg->dtype_activation) {
467 data = Cast(n->data, cfg->dtype_activation);
468 }
469 Expr ret = ForwardOp(ref_call, {data});
470 return QRealizeIntExpr(ret, n->dom_scale, cfg->dtype_activation);
471 }
472 ICHECK(!new_args[0]->IsInstance<TempExprNode>());
473 return Expr(nullptr);
474}
475
476RELAY_REGISTER_OP("nn.avg_pool2d").set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);
477
478RELAY_REGISTER_OP("nn.global_avg_pool2d")
479 .set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);
480
481Expr CastHintRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
482 const auto param = ref_call->attrs.as<CastHintAttrs>();
483 ICHECK_EQ(new_args.size(), 1);
484 if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
485 Expr ret = Cast(n->data, param->dtype);
486 return QRealizeIntExpr(ret, n->dom_scale, param->dtype);
487 }
488 ICHECK(!new_args[0]->IsInstance<TempExprNode>());
489 return Expr(nullptr);
490}
491
492RELAY_REGISTER_OP("annotation.cast_hint")
493 .set_attr<FForwardRewrite>("FQRealizeRewrite", CastHintRealize);
494
495Expr BatchMatmulRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
496 const QConfig& cfg = QConfig::Current();
497 ICHECK_EQ(new_args.size(), 2);
498 if (!new_args[0]->IsInstance<TempExprNode>() || !new_args[1]->IsInstance<TempExprNode>()) {
499 return Expr(nullptr);
500 }
501 const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
502 const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
503
504 Expr ldata = lhs->data;
505 Expr rdata = rhs->data;
506 DataType dtype_input = cfg->dtype_input;
507 DataType dtype_weight = cfg->dtype_weight;
508
509 if (lhs->dtype != dtype_input) {
510 ldata = Cast(ldata, dtype_input);
511 }
512 if (rhs->dtype != dtype_weight) {
513 rdata = Cast(rdata, dtype_weight);
514 }
515
516 const auto ref_attrs = ref_call->attrs.as<BatchMatmulAttrs>();
517 auto attrs = make_object<BatchMatmulAttrs>();
518 *attrs = *ref_attrs;
519 DataType out_dtype = cfg->dtype_activation;
520 attrs->out_dtype = out_dtype;
521
522 Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args);
523 Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
524 Expr dom_scale = FoldConstantExpr(mul);
525 return QRealizeIntExpr(ret, dom_scale, out_dtype);
526}
527
528RELAY_REGISTER_OP("nn.batch_matmul")
529 .set_attr<FForwardRewrite>("FQRealizeRewrite", BatchMatmulRealize);
530
531Pass QuantizeRealizePass() {
532 runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
533 [=](Function f, IRModule m, PassContext pc) {
534 return Downcast<Function>(ForwardRewrite(f, "FQRealizeRewrite", nullptr, nullptr));
535 };
536 return CreateFunctionPass(pass_func, 1, "QuantizeRealize", {});
537}
538
539TVM_REGISTER_GLOBAL("relay._quantize.QuantizeRealize").set_body_typed(QuantizeRealizePass);
540
541} // namespace quantize
542} // namespace relay
543} // namespace tvm
544