1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * |
22 | * \file realize.cc |
23 | * |
 * \brief Realize the simulated-quantization graph as a real
 * low-precision (integer) graph.
26 | */ |
27 | |
28 | #include "./realize.h" |
29 | |
30 | #include <tvm/relay/analysis.h> |
31 | #include <tvm/relay/attrs/annotation.h> |
32 | #include <tvm/relay/transform.h> |
33 | |
34 | #include "../op/annotation/annotation.h" |
35 | #include "../qnn/utils.h" |
36 | #include "../transforms/fold_constant.h" |
37 | #include "./quantize.h" |
38 | |
39 | namespace tvm { |
40 | namespace relay { |
41 | namespace quantize { |
42 | |
43 | using namespace relay::transform; |
44 | |
45 | Expr QRealizeIntExprNode::Realize() const { |
46 | Expr data = this->data; |
47 | // dequantize |
48 | data = Cast(data, DataType::Float(32)); |
49 | data = Multiply(data, this->dom_scale); |
50 | return data; |
51 | } |
52 | |
53 | QRealizeIntExpr::QRealizeIntExpr(Expr data, Expr dom_scale, DataType dtype) { |
54 | ObjectPtr<QRealizeIntExprNode> n = make_object<QRealizeIntExprNode>(); |
55 | n->data = std::move(data); |
56 | n->dom_scale = std::move(dom_scale); |
57 | n->dtype = std::move(dtype); |
58 | data_ = std::move(n); |
59 | } |
60 | |
61 | inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) { |
62 | return Call(ref_call->op, args, ref_call->attrs, ref_call->type_args); |
63 | } |
64 | |
65 | /* calculate `data * s1 / s2`, use shift if possible */ |
66 | inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype, |
67 | const Array<IndexExpr>& data_shape) { |
68 | const QConfig& cfg = QConfig::Current(); |
69 | // here we assume the dtype of data is dtype activation |
70 | if (s1 == s2) return data; |
71 | |
72 | float factor = s1 / s2; |
73 | float shift_factor = std::log2(factor); |
74 | ICHECK_GT(shift_factor, 0); |
75 | if (static_cast<int>(shift_factor) == shift_factor) { |
76 | return LeftShift(data, MakeConstantScalar(dtype, static_cast<int>(shift_factor))); |
77 | } else if (static_cast<int>(factor) == factor) { |
78 | return Multiply(data, MakeConstantScalar(dtype, factor)); |
79 | } else { |
80 | if (cfg->rounding == "UPWARD" ) { |
81 | auto [fixed_point_multiplier, shift] = qnn::GetFixedPointMultiplierShift(factor); |
82 | data = relay::FixedPointMultiply(data, fixed_point_multiplier, shift); |
83 | } else { |
84 | data = qnn::FixedPointMultiplyToNearest(data, factor, data_shape); |
85 | } |
86 | |
87 | return Cast(data, dtype); |
88 | } |
89 | } |
90 | |
/*!
 * \brief Realize a `relay.op.annotation.simulated_quantize` call with integer ops.
 *
 * new_args are read as (data, dom_scale, clip_min, clip_max); the last three
 * must be scalar constants (they are read with GetScalarFromConstant<float>).
 *
 * - If data is already a QRealizeIntExpr with domain scale idom, it is
 *   requantized to the output domain scale odom (= dom_scale):
 *     * idom == odom: only clip.
 *     * odom / idom is a power of two: right shift (with an optional rounding
 *       bias when cfg->round_for_shift is set) or left shift, then clip.
 *     * otherwise: cast to int64, fixed-point multiply by idom / odom
 *       (UPWARD or to-nearest rounding per cfg->rounding), clip, and cast back
 *       to the input's dtype.
 *   The result keeps the input's dtype.
 * - Otherwise data must not be a TempExpr (a real-valued tensor): it is
 *   multiplied by 1 / dom_scale, rounded, clipped, and tagged as float32.
 *
 * \return A QRealizeIntExpr carrying the realized data and dom_scale.
 */
Expr QuantizeRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  // do not handle data type cast
  const auto param = ref_call->attrs.as<SimulatedQuantizeAttrs>();
  // Only the "round" rounding mode of simulated_quantize is supported here.
  ICHECK_EQ(param->rounding, "round");

  Expr dom_scale = new_args[1];
  Expr clip_min = new_args[2];
  Expr clip_max = new_args[3];

  float dom_scale_imm = GetScalarFromConstant<float>(dom_scale);
  float clip_min_imm = GetScalarFromConstant<float>(clip_min);
  float clip_max_imm = GetScalarFromConstant<float>(clip_max);

  // x * idom_scale = y * odom_scale
  // => y = x * idom_scale / odom_scale
  if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
    // int32->int8
    Expr data = n->data;
    float idom_scale_imm = GetScalarFromConstant<float>(n->dom_scale);
    float odom_scale_imm = GetScalarFromConstant<float>(dom_scale);
    if (idom_scale_imm == odom_scale_imm) {
      // same domain scale, only clip
      data = Clip(data, clip_min_imm, clip_max_imm);
      return QRealizeIntExpr(data, dom_scale, n->dtype);
    }

    // y = x / 2^shift_nbit; an integral shift_nbit allows a pure bit shift.
    float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm);
    ICHECK_NE(shift_nbit, 0);
    if (static_cast<int>(shift_nbit) == shift_nbit) {
      if (shift_nbit > 0) {
        // use right shift
        if (cfg->round_for_shift) {
          // Add 2^(shift_nbit - 1), half the divisor, so the shift rounds to
          // nearest rather than rounding down.
          float round_bias = std::pow(2.0, shift_nbit - 1);
          data = Add(data, MakeConstantScalar(cfg->dtype_activation, static_cast<int>(round_bias)));
        }
        data = RightShift(data,
                          MakeConstantScalar(cfg->dtype_activation, static_cast<int>(shift_nbit)));
      } else {
        // odom < idom: scale the values up with a left shift.
        data = LeftShift(data,
                         MakeConstantScalar(cfg->dtype_activation, static_cast<int>(-shift_nbit)));
      }
      data = Clip(data, clip_min_imm, clip_max_imm);
      return QRealizeIntExpr(data, dom_scale, n->dtype);
    } else {
      // Non power-of-two ratio: fixed-point multiply by idom / odom.
      // NOTE(review): the int64 cast presumably gives the multiply headroom — confirm.
      data = Cast(data, DataType::Int(64));
      if (cfg->rounding == "UPWARD") {
        auto [fixed_point_multiplier, shift] =
            qnn::GetFixedPointMultiplierShift(idom_scale_imm / odom_scale_imm);
        data = relay::FixedPointMultiply(data, fixed_point_multiplier, shift);
      } else {
        data = qnn::FixedPointMultiplyToNearest(data, idom_scale_imm / odom_scale_imm,
                                                ref_call->type_as<TensorTypeNode>()->shape);
      }
      data = Cast(Clip(data, clip_min_imm, clip_max_imm), n->dtype);
      return QRealizeIntExpr(data, dom_scale, n->dtype);
    }
  }

  // quantize from real
  ICHECK(!new_args[0]->IsInstance<TempExprNode>());
  Expr data = new_args[0];
  // round(clip(x / dom_scale)); the value stays float32 but is integer-valued.
  Expr scaled_data = Multiply(data, MakeConstantScalar(DataType::Float(32), 1 / dom_scale_imm));
  Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm);
  return QRealizeIntExpr(round_data, dom_scale, DataType::Float(32));
}

RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
    .set_attr<FForwardRewrite>("FQRealizeRewrite", QuantizeRealize);
160 | |
161 | Expr Conv2dRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) { |
162 | const QConfig& cfg = QConfig::Current(); |
163 | ICHECK_EQ(new_args.size(), 2); |
164 | if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) { |
165 | const auto* lhs = new_args[0].as<QRealizeIntExprNode>(); |
166 | const auto* rhs = new_args[1].as<QRealizeIntExprNode>(); |
167 | Expr ldata = lhs->data; |
168 | if (lhs->dtype != cfg->dtype_input) { |
169 | ldata = Cast(ldata, cfg->dtype_input); |
170 | } |
171 | Expr rdata = Cast(rhs->data, cfg->dtype_weight); |
172 | |
173 | const auto ref_attrs = ref_call->attrs.as<Conv2DAttrs>(); |
174 | auto attrs = make_object<Conv2DAttrs>(); |
175 | *attrs = *ref_attrs; |
176 | DataType out_dtype = cfg->dtype_activation; |
177 | attrs->out_dtype = out_dtype; |
178 | |
179 | Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args); |
180 | Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale); |
181 | Expr dom_scale = FoldConstantExpr(mul); |
182 | return QRealizeIntExpr(ret, dom_scale, out_dtype); |
183 | } |
184 | ICHECK(!new_args[0]->IsInstance<TempExprNode>() || !new_args[1]->IsInstance<TempExprNode>()); |
185 | return Expr(nullptr); |
186 | } |
187 | |
188 | RELAY_REGISTER_OP("nn.conv2d" ).set_attr<FForwardRewrite>("FQRealizeRewrite" , Conv2dRealize); |
189 | |
190 | Expr Conv1dRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) { |
191 | const QConfig& cfg = QConfig::Current(); |
192 | CHECK_EQ(new_args.size(), 2); |
193 | if (!new_args[0]->IsInstance<TempExprNode>() && !new_args[1]->IsInstance<TempExprNode>()) { |
194 | return Expr(nullptr); |
195 | } |
196 | const auto* lhs = new_args[0].as<QRealizeIntExprNode>(); |
197 | CHECK(lhs); |
198 | const auto* rhs = new_args[1].as<QRealizeIntExprNode>(); |
199 | CHECK(rhs); |
200 | |
201 | Expr ldata = lhs->data; |
202 | if (lhs->dtype != cfg->dtype_input) { |
203 | ldata = Cast(ldata, cfg->dtype_input); |
204 | } |
205 | Expr rdata = Cast(rhs->data, cfg->dtype_weight); |
206 | |
207 | const auto ref_attrs = ref_call->attrs.as<Conv1DAttrs>(); |
208 | auto attrs = make_object<Conv1DAttrs>(); |
209 | *attrs = *ref_attrs; |
210 | DataType out_dtype = cfg->dtype_activation; |
211 | attrs->out_dtype = out_dtype; |
212 | |
213 | Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args); |
214 | Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale); |
215 | Expr dom_scale = FoldConstantExpr(mul); |
216 | return QRealizeIntExpr(ret, dom_scale, out_dtype); |
217 | } |
218 | |
219 | RELAY_REGISTER_OP("nn.conv1d" ).set_attr<FForwardRewrite>("FQRealizeRewrite" , Conv1dRealize); |
220 | |
221 | Expr DenseRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) { |
222 | const QConfig& cfg = QConfig::Current(); |
223 | ICHECK_EQ(new_args.size(), 2); |
224 | if (!new_args[0]->IsInstance<TempExprNode>() || !new_args[1]->IsInstance<TempExprNode>()) { |
225 | return Expr(nullptr); |
226 | } |
227 | const auto* lhs = new_args[0].as<QRealizeIntExprNode>(); |
228 | const auto* rhs = new_args[1].as<QRealizeIntExprNode>(); |
229 | |
230 | Expr ldata = lhs->data; |
231 | if (lhs->dtype != cfg->dtype_input) { |
232 | ldata = Cast(ldata, cfg->dtype_input); |
233 | } |
234 | Expr rdata = Cast(rhs->data, cfg->dtype_weight); |
235 | |
236 | const auto ref_attrs = ref_call->attrs.as<DenseAttrs>(); |
237 | auto attrs = make_object<DenseAttrs>(); |
238 | *attrs = *ref_attrs; |
239 | DataType out_dtype = cfg->dtype_activation; |
240 | attrs->out_dtype = out_dtype; |
241 | |
242 | Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args); |
243 | Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale); |
244 | Expr dom_scale = FoldConstantExpr(mul); |
245 | return QRealizeIntExpr(ret, dom_scale, out_dtype); |
246 | } |
247 | |
248 | RELAY_REGISTER_OP("nn.dense" ).set_attr<FForwardRewrite>("FQRealizeRewrite" , DenseRealize); |
249 | |
250 | Expr MulRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) { |
251 | const QConfig& cfg = QConfig::Current(); |
252 | ICHECK_EQ(new_args.size(), 2); |
253 | if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) { |
254 | // execute the operation with activation data type. |
255 | const auto* lhs = new_args[0].as<QRealizeIntExprNode>(); |
256 | const auto* rhs = new_args[1].as<QRealizeIntExprNode>(); |
257 | Expr ldata = lhs->data; |
258 | Expr rdata = rhs->data; |
259 | |
260 | DataType dtype = cfg->dtype_activation; |
261 | if (lhs->dtype != dtype) { |
262 | ldata = Cast(ldata, dtype); |
263 | } |
264 | if (rhs->dtype != dtype) { |
265 | rdata = Cast(rdata, dtype); |
266 | } |
267 | |
268 | Expr ret = ForwardOp(ref_call, {ldata, rdata}); |
269 | Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale); |
270 | Expr dom_scale = FoldConstantExpr(mul); |
271 | return QRealizeIntExpr(ret, dom_scale, dtype); |
272 | } |
273 | ICHECK(!new_args[0]->IsInstance<TempExprNode>() || !new_args[1]->IsInstance<TempExprNode>()); |
274 | return Expr(nullptr); |
275 | } |
276 | |
277 | RELAY_REGISTER_OP("multiply" ).set_attr<FForwardRewrite>("FQRealizeRewrite" , MulRealize); |
278 | |
279 | float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) { |
280 | if (nptrs.size() == 2) { |
281 | // x = a * s1, y = b * s2 |
282 | // x + y = (a * s1 / s2 + b) * s2, if s1 > s2 |
283 | // = (a + b * s2 / s1) * s1, if s2 > s1 |
284 | float s1 = GetScalarFromConstant<float>(nptrs[0]->dom_scale); |
285 | float s2 = GetScalarFromConstant<float>(nptrs[1]->dom_scale); |
286 | return s1 > s2 ? s2 : s1; |
287 | } else { |
288 | const QConfig& cfg = QConfig::Current(); |
289 | float scale = cfg->global_scale; |
290 | return scale / std::pow(2.0, cfg->nbit_activation - 1); |
291 | } |
292 | } |
293 | |
/*!
 * \brief Unify the data type and domain scale of a list of quantized arguments.
 *
 * \param ref_args  The original (pre-rewrite) arguments. They are used to find
 *                  simulated_quantize inputs of kind kQInput and to read each
 *                  argument's shape for MulAndDiv.
 * \param args      The rewritten arguments; each must be a QRealizeIntExpr.
 * \param dtype_ptr Output: the unified data type.
 * \param scale_ptr Output: the unified domain scale, a float32 scalar constant.
 * \param dtype     Target data type. If void, cfg->dtype_input is used when
 *                  there are exactly two args and the second is already in
 *                  dtype_input; otherwise cfg->dtype_activation is used.
 * \return The argument data, cast to the unified dtype and rescaled to the
 *         unified domain scale.
 */
Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args,
                            DataType* dtype_ptr, Expr* scale_ptr,
                            DataType dtype = DataType::Void()) {
  static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
  const QConfig& cfg = QConfig::Current();

  // Collect the QRealizeIntExpr nodes and start from their raw integer data.
  std::vector<const QRealizeIntExprNode*> nptrs;
  Array<Expr> ret;
  for (auto arg : args) {
    const auto* nptr = arg.as<QRealizeIntExprNode>();
    ICHECK(nptr);
    nptrs.push_back(nptr);
    ret.push_back(nptr->data);
  }

  // unify the data type
  ICHECK_EQ(ref_args.size(), args.size());

  if (dtype.is_void()) {
    if (ret.size() == 2 && nptrs[1]->dtype == cfg->dtype_input) {
      dtype = cfg->dtype_input;
    } else {
      dtype = cfg->dtype_activation;
    }
  }

  for (size_t i = 0; i < ret.size(); ++i) {
    auto ref_arg = ref_args[i].as<CallNode>();
    if (nptrs[i]->dtype != dtype) {
      ret.Set(i, Cast(ret[i], dtype));
    } else if (ref_arg && ref_arg->op.same_as(simulated_quantize) &&
               ref_arg->attrs.as<SimulatedQuantizeAttrs>()->kind == kQInput) {
      // Already in the target dtype but originally a kQInput quantize: route it
      // through dtype_input behind a StopFusion and cast back to dtype.
      // NOTE(review): presumably this materializes the narrow input value at a
      // fusion boundary — confirm intent.
      auto new_arg = Cast(ret[i], cfg->dtype_input);
      new_arg = StopFusion(new_arg);
      ret.Set(i, Cast(new_arg, dtype));
    }
  }

  // unify the dom_scale
  float s = ChooseDomScale(nptrs);
  Expr dom_scale = MakeConstantScalar(DataType::Float(32), s);
  for (size_t i = 0; i < ret.size(); ++i) {
    // Rescale each argument from its own scale cur_s to the common scale s.
    float cur_s = GetScalarFromConstant<float>(nptrs[i]->dom_scale);
    ret.Set(i, MulAndDiv(ret[i], cur_s, s, dtype, ref_args[i]->type_as<TensorTypeNode>()->shape));
  }

  *dtype_ptr = dtype;
  *scale_ptr = dom_scale;
  return ret;
}
345 | |
346 | Expr AddRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) { |
347 | ICHECK_EQ(new_args.size(), 2); |
348 | if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) { |
349 | DataType dtype; |
350 | Expr dom_scale; |
351 | // execute the operation with activation data type. |
352 | const QConfig& cfg = QConfig::Current(); |
353 | Array<Expr> ret_args = |
354 | UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale, cfg->dtype_activation); |
355 | for (size_t i = 0; i < ret_args.size(); ++i) { |
356 | // do not fuse float32 arg |
357 | if (new_args[i].as<QRealizeIntExprNode>()->dtype == DataType::Float(32)) { |
358 | ret_args.Set(i, StopFusion(ret_args[i])); |
359 | } |
360 | } |
361 | Expr ret = ForwardOp(ref_call, ret_args); |
362 | return QRealizeIntExpr(ret, dom_scale, dtype); |
363 | } |
364 | |
365 | ICHECK(!new_args[0]->IsInstance<TempExprNode>() && !new_args[1]->IsInstance<TempExprNode>()); |
366 | return Expr(nullptr); |
367 | } |
368 | |
369 | RELAY_REGISTER_OP("add" ).set_attr<FForwardRewrite>("FQRealizeRewrite" , AddRealize); |
370 | |
371 | Expr ClipRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) { |
372 | ICHECK_EQ(new_args.size(), 1); |
373 | if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) { |
374 | const auto ref_attrs = ref_call->attrs.as<ClipAttrs>(); |
375 | auto attrs = make_object<ClipAttrs>(); |
376 | double dom_scale = GetScalarFromConstant<float>(n->dom_scale); |
377 | attrs->a_min = ref_attrs->a_min / dom_scale; |
378 | attrs->a_max = ref_attrs->a_max / dom_scale; |
379 | |
380 | Expr ret = Call(ref_call->op, {n->data}, Attrs(attrs), ref_call->type_args); |
381 | return QRealizeIntExpr(ret, n->dom_scale, n->dtype); |
382 | } |
383 | ICHECK(!new_args[0]->IsInstance<TempExprNode>()); |
384 | return Expr(nullptr); |
385 | } |
386 | |
387 | RELAY_REGISTER_OP("clip" ).set_attr<FForwardRewrite>("FQRealizeRewrite" , ClipRealize); |
388 | |
389 | Expr ConcatenateRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) { |
390 | ICHECK_EQ(new_args.size(), 1); |
391 | ICHECK_EQ(ref_call->args.size(), 1); |
392 | |
393 | const auto* tuple = new_args[0].as<TupleNode>(); |
394 | const auto* ref_tuple = ref_call->args[0].as<TupleNode>(); |
395 | ICHECK(tuple); |
396 | ICHECK(ref_tuple); |
397 | const Array<Expr>& arr = tuple->fields; |
398 | const Array<Expr>& ref_arr = ref_tuple->fields; |
399 | |
400 | if (arr[0].as<QRealizeIntExprNode>()) { |
401 | DataType dtype; |
402 | Expr dom_scale; |
403 | Array<Expr> ret_args = UnifyDTypeScale(ref_arr, arr, &dtype, &dom_scale); |
404 | Expr ret = ForwardOp(ref_call, {Tuple(ret_args)}); |
405 | return QRealizeIntExpr(ret, dom_scale, dtype); |
406 | } else { |
407 | for (auto arg : new_args) { |
408 | ICHECK(!arg->IsInstance<TempExprNode>()); |
409 | } |
410 | return Expr(nullptr); |
411 | } |
412 | } |
413 | |
414 | RELAY_REGISTER_OP("concatenate" ).set_attr<FForwardRewrite>("FQRealizeRewrite" , ConcatenateRealize); |
415 | |
416 | /* \brief forward the original operator */ |
417 | Expr IdentityRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) { |
418 | ICHECK_EQ(new_args.size(), 1); |
419 | if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) { |
420 | Expr ret = ForwardOp(ref_call, {n->data}); |
421 | return QRealizeIntExpr(ret, n->dom_scale, n->dtype); |
422 | } |
423 | ICHECK(!new_args[0]->IsInstance<TempExprNode>()); |
424 | return Expr(nullptr); |
425 | } |
426 | |
427 | RELAY_REGISTER_OP("nn.relu" ).set_attr<FForwardRewrite>("FQRealizeRewrite" , IdentityRealize); |
428 | |
429 | RELAY_REGISTER_OP("reshape" ).set_attr<FForwardRewrite>("FQRealizeRewrite" , IdentityRealize); |
430 | |
431 | RELAY_REGISTER_OP("strided_slice" ).set_attr<FForwardRewrite>("FQRealizeRewrite" , IdentityRealize); |
432 | |
433 | RELAY_REGISTER_OP("nn.batch_flatten" ) |
434 | .set_attr<FForwardRewrite>("FQRealizeRewrite" , IdentityRealize); |
435 | |
436 | RELAY_REGISTER_OP("transpose" ).set_attr<FForwardRewrite>("FQRealizeRewrite" , IdentityRealize); |
437 | |
438 | RELAY_REGISTER_OP("annotation.stop_fusion" ) |
439 | .set_attr<FForwardRewrite>("FQRealizeRewrite" , IdentityRealize); |
440 | |
441 | /* \brief for unary operators which requantize its input to dtype_nbit */ |
442 | Expr CastDtypeInputRealize(const Call& ref_call, const Array<Expr>& new_args, |
443 | const ObjectRef& ctx) { |
444 | const QConfig& cfg = QConfig::Current(); |
445 | ICHECK_EQ(new_args.size(), 1); |
446 | if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) { |
447 | Expr data = Cast(n->data, cfg->dtype_input); |
448 | Expr ret = ForwardOp(ref_call, {data}); |
449 | return QRealizeIntExpr(ret, n->dom_scale, cfg->dtype_input); |
450 | } |
451 | ICHECK(!new_args[0]->IsInstance<TempExprNode>()); |
452 | return Expr(nullptr); |
453 | } |
454 | |
455 | RELAY_REGISTER_OP("nn.max_pool2d" ) |
456 | .set_attr<FForwardRewrite>("FQRealizeRewrite" , CastDtypeInputRealize); |
457 | |
458 | RELAY_REGISTER_OP("nn.max_pool1d" ) |
459 | .set_attr<FForwardRewrite>("FQRealizeRewrite" , CastDtypeInputRealize); |
460 | |
461 | Expr AvgPoolRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) { |
462 | const QConfig& cfg = QConfig::Current(); |
463 | ICHECK_EQ(new_args.size(), 1); |
464 | if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) { |
465 | Expr data = n->data; |
466 | if (n->dtype != cfg->dtype_activation) { |
467 | data = Cast(n->data, cfg->dtype_activation); |
468 | } |
469 | Expr ret = ForwardOp(ref_call, {data}); |
470 | return QRealizeIntExpr(ret, n->dom_scale, cfg->dtype_activation); |
471 | } |
472 | ICHECK(!new_args[0]->IsInstance<TempExprNode>()); |
473 | return Expr(nullptr); |
474 | } |
475 | |
476 | RELAY_REGISTER_OP("nn.avg_pool2d" ).set_attr<FForwardRewrite>("FQRealizeRewrite" , AvgPoolRealize); |
477 | |
478 | RELAY_REGISTER_OP("nn.global_avg_pool2d" ) |
479 | .set_attr<FForwardRewrite>("FQRealizeRewrite" , AvgPoolRealize); |
480 | |
481 | Expr CastHintRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) { |
482 | const auto param = ref_call->attrs.as<CastHintAttrs>(); |
483 | ICHECK_EQ(new_args.size(), 1); |
484 | if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) { |
485 | Expr ret = Cast(n->data, param->dtype); |
486 | return QRealizeIntExpr(ret, n->dom_scale, param->dtype); |
487 | } |
488 | ICHECK(!new_args[0]->IsInstance<TempExprNode>()); |
489 | return Expr(nullptr); |
490 | } |
491 | |
492 | RELAY_REGISTER_OP("annotation.cast_hint" ) |
493 | .set_attr<FForwardRewrite>("FQRealizeRewrite" , CastHintRealize); |
494 | |
495 | Expr BatchMatmulRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectRef& ctx) { |
496 | const QConfig& cfg = QConfig::Current(); |
497 | ICHECK_EQ(new_args.size(), 2); |
498 | if (!new_args[0]->IsInstance<TempExprNode>() || !new_args[1]->IsInstance<TempExprNode>()) { |
499 | return Expr(nullptr); |
500 | } |
501 | const auto* lhs = new_args[0].as<QRealizeIntExprNode>(); |
502 | const auto* rhs = new_args[1].as<QRealizeIntExprNode>(); |
503 | |
504 | Expr ldata = lhs->data; |
505 | Expr rdata = rhs->data; |
506 | DataType dtype_input = cfg->dtype_input; |
507 | DataType dtype_weight = cfg->dtype_weight; |
508 | |
509 | if (lhs->dtype != dtype_input) { |
510 | ldata = Cast(ldata, dtype_input); |
511 | } |
512 | if (rhs->dtype != dtype_weight) { |
513 | rdata = Cast(rdata, dtype_weight); |
514 | } |
515 | |
516 | const auto ref_attrs = ref_call->attrs.as<BatchMatmulAttrs>(); |
517 | auto attrs = make_object<BatchMatmulAttrs>(); |
518 | *attrs = *ref_attrs; |
519 | DataType out_dtype = cfg->dtype_activation; |
520 | attrs->out_dtype = out_dtype; |
521 | |
522 | Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args); |
523 | Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale); |
524 | Expr dom_scale = FoldConstantExpr(mul); |
525 | return QRealizeIntExpr(ret, dom_scale, out_dtype); |
526 | } |
527 | |
528 | RELAY_REGISTER_OP("nn.batch_matmul" ) |
529 | .set_attr<FForwardRewrite>("FQRealizeRewrite" , BatchMatmulRealize); |
530 | |
531 | Pass QuantizeRealizePass() { |
532 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
533 | [=](Function f, IRModule m, PassContext pc) { |
534 | return Downcast<Function>(ForwardRewrite(f, "FQRealizeRewrite" , nullptr, nullptr)); |
535 | }; |
536 | return CreateFunctionPass(pass_func, 1, "QuantizeRealize" , {}); |
537 | } |
538 | |
539 | TVM_REGISTER_GLOBAL("relay._quantize.QuantizeRealize" ).set_body_typed(QuantizeRealizePass); |
540 | |
541 | } // namespace quantize |
542 | } // namespace relay |
543 | } // namespace tvm |
544 | |