/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/transforms/simplify_expr.cc
 * \brief A pass for simplifying the Relay expression.
 */

#include "simplify_expr.h"

#include <tvm/relay/dataflow_matcher.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/logging.h>

#include <algorithm>
#include <limits>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "../op/tensor/transform.h"
#include "fold_constant.h"
#include "pattern_utils.h"

namespace tvm {
namespace relay {

/*!
 * \brief SimplifyReshape matches a pair of consecutive reshape or reverse_reshape ops,
 * and merges them into a single reshape op.
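 *
 * For example (shapes below are illustrative):
 *   %1 = reshape(%x, newshape=[2, 16]);
 *   %2 = reshape(%1, newshape=[4, 8]);
 * is rewritten to
 *   %2 = reshape(%x, newshape=[4, 8]);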
 */
class SimplifyReshape : public DFPatternRewrite {
 public:
  SimplifyReshape() {
    x_ = IsWildcard();
    auto reshape1 = IsOp("reshape") || IsOp("contrib_reverse_reshape");
    auto reshape2 = IsOp("reshape") || IsOp("contrib_reverse_reshape");
    pattern_ = reshape1({reshape2({x_})});
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    auto x = node_map[x_][0];
    bool const_shape = true;
    Array<Integer> newshape;
    for (auto dim : Downcast<TensorType>(pre->checked_type())->shape) {
      if (dim.as<IntImmNode>() == nullptr) {
        const_shape = false;
        break;
      }
      newshape.push_back(Downcast<Integer>(dim));
    }
    if (const_shape) {
      return MakeReshape(x, newshape);
    }
    return post;
  }

 private:
  /*! \brief Pattern input */
  DFPattern x_;
};

/*!
 * \brief SimplifySameCast matches casts of data to the dtype it already has (a no-op cast).
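 *
 * For example (illustrative): cast(%x, dtype="int32"), where %x is already int32,
 * is replaced by %x itself.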
 */
class SimplifySameCast : public DFPatternRewrite {
 public:
  SimplifySameCast() {
    data_pat_ = IsWildcard();
    like_pat_ = IsWildcard();
    pattern_ = IsOp("cast_like")({data_pat_, like_pat_}) || IsOp("cast")({data_pat_});
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    const CallNode* call = pre.as<CallNode>();
    const TensorTypeNode* data_ty = call->args[0]->checked_type().as<TensorTypeNode>();
    const TensorTypeNode* like_ty = pre->checked_type().as<TensorTypeNode>();
    if (like_ty->dtype == data_ty->dtype) {
      return node_map[data_pat_][0];
    }
    return post;
  }

 protected:
  DFPattern data_pat_;
  DFPattern like_pat_;
};

/*!
 * \brief SimplifyConsecutiveCast matches consecutive cast/cast_like ops and collapses them
 * into a single cast when the first cast does not lose precision.
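 *
 * For example (illustrative), with %x of dtype int8:
 *   %1 = cast(%x, dtype="int32");
 *   %2 = cast(%1, dtype="int64");
 * becomes
 *   %2 = cast(%x, dtype="int64");
 * The first cast is widening, so it can be removed safely.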
 */
class SimplifyConsecutiveCast : public DFPatternRewrite {
 public:
  SimplifyConsecutiveCast() {
    data_ = IsWildcard();
    cast1_ = IsOp("cast_like")({data_, IsWildcard()}) || IsOp("cast")({data_});
    pattern_ = IsOp("cast_like")({cast1_, IsWildcard()}) || IsOp("cast")({cast1_});
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    auto data = node_map[data_][0];
    auto cast1 = Downcast<Call>(node_map[cast1_][0]);
    auto data_type = Downcast<TensorType>(data->checked_type());
    DataType cast1_dtype = Downcast<TensorType>(cast1->checked_type())->dtype;

    if (!IsWidenCast(data_type->dtype, cast1_dtype)) {
      // Cannot remove the narrowing cast
      return post;
    }

    const CallNode* cast2 = post.as<CallNode>();
    DataType cast2_dtype = Downcast<TensorType>(cast2->checked_type())->dtype;
    auto expr = MakeCast(data, cast2_dtype);

    // We need to set the checked type as it may be needed in the next callback
    expr->checked_type_ = TensorType(data_type->shape, cast2_dtype);
    return expr;
  }

  /*! \brief Return whether casting from origin to cast keeps the same or more precision. */
  bool IsWidenCast(DataType origin, DataType cast) const {
    if (origin.code() == cast.code() && origin.bits() <= cast.bits()) {
      return true;
    }
    if (origin.code() == DataType::kBFloat || cast.code() == DataType::kBFloat) {
      // A cast involving bfloat16 cannot be omitted
      return false;
    }
    if (origin.code() < cast.code() && origin.bits() <= cast.bits()) {
      // Loosely rely on a hierarchy of data types:
      // e.g. int --> uint --> float has an increasing range of representable numbers
      return true;
    }
    return false;
  }

 protected:
  DFPattern data_;
  DFPattern cast1_;
};

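/*!
 * \brief Check whether [min_value, max_value] exactly matches the representable range of the
 * given integer or floating point dtype. Returns false for all other dtypes.
 */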
bool CheckDataTypeMaxMinValue(DataType dtype, double min_value, double max_value) {
  if (dtype.is_int() || dtype.is_uint()) {
    double ubound = static_cast<double>(Downcast<IntImm>(tvm::max_value(dtype))->value);
    double lbound = static_cast<double>(Downcast<IntImm>(tvm::min_value(dtype))->value);
    return ubound == max_value && lbound == min_value;
  } else if (dtype.is_float()) {
    double ubound = Downcast<FloatImm>(tvm::max_value(dtype))->value;
    double lbound = Downcast<FloatImm>(tvm::min_value(dtype))->value;
    return ubound == max_value && lbound == min_value;
  }

  return false;
}

/*!
 * \brief SimplifyClipAndConsecutiveCast matches the pattern clip -> cast -> cast and removes
 * redundant casts.
 * Redundancy is determined from the clip min/max values and the min/max values of the
 * intermediate cast's target data type.
 */
class SimplifyClipAndConsecutiveCast : public DFPatternRewrite {
 public:
  SimplifyClipAndConsecutiveCast() {
    clip_ = IsOp("clip")({IsWildcard()});
    cast1_ = IsOp("cast")({clip_});
    pattern_ = IsOp("cast")({cast1_});
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    auto clip = Downcast<Call>(node_map[clip_][0]);
    const CallNode* clip_node = clip.as<CallNode>();
    const ClipAttrs* clip_attrs = clip_node->attrs.as<ClipAttrs>();
    DataType clip_dtype = Downcast<TensorType>(clip->checked_type())->dtype;

    auto cast1 = Downcast<Call>(node_map[cast1_][0]);
    DataType cast1_dtype = Downcast<TensorType>(cast1->checked_type())->dtype;

    auto cast2 = Downcast<Call>(post);
    DataType cast2_dtype = Downcast<TensorType>(cast2->checked_type())->dtype;

    if (clip_dtype == cast2_dtype &&
        CheckDataTypeMaxMinValue(cast1_dtype, clip_attrs->a_min, clip_attrs->a_max)) {
      // Case 1:
      // The data type of Clip matches the target data type of the second Cast, and the min/max
      // values of Clip match the min/max values of the first Cast's target data type. In this
      // case both Cast ops can be removed.
      // Example:
      //   %0 == [type=int32]
      //   %1 = clip(%0, a_min=0f, a_max=255f) [type=int32]
      //   %2 = cast(%1, dtype="uint8") [type=uint8]
      //   %3 = cast(%2, dtype="int32") [type=int32]
      //
      // Optimized to (both casts are removed):
      //   %1 = clip(%0, a_min=0f, a_max=255f) [type=int32]
      return node_map[clip_][0];
    }
    return post;
  }

 protected:
  DFPattern clip_, cast1_;
};

/*!
 * \brief SimplifyCastClip matches the pattern cast -> clip and removes the redundant Clip when
 * its min/max values match the min/max values of the Cast target data type.
 *
 * Example:
 *   %1 = cast(%0, dtype="uint8") [type=uint8]
 *   %2 = clip(%1, a_min=0f, a_max=255f) [type=uint8]
 *
 * Optimized to (the Clip is removed):
 *   %1 = cast(%0, dtype="uint8") [type=uint8]
 */
class SimplifyCastClip : public DFPatternRewrite {
 public:
  SimplifyCastClip() {
    cast_ = IsOp("cast")({IsWildcard()});
    pattern_ = IsOp("clip")({cast_});
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    auto cast = Downcast<Call>(node_map[cast_][0]);
    DataType cast_dtype = Downcast<TensorType>(cast->checked_type())->dtype;

    auto clip = Downcast<Call>(post);
    const CallNode* clip_node = clip.as<CallNode>();
    const ClipAttrs* clip_attrs = clip_node->attrs.as<ClipAttrs>();

    if (CheckDataTypeMaxMinValue(cast_dtype, clip_attrs->a_min, clip_attrs->a_max)) {
      return node_map[cast_][0];
    }
    return post;
  }

 protected:
  DFPattern clip_, cast_;
};

/*!
 * \brief SimplifyTranspose matches the pattern of consecutive transpose or layout_transform ops,
 * and merges or cancels them.
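 *
 * For example (illustrative), with %x of shape [N, C, H, W]:
 *   %1 = transpose(%x, axes=[0, 2, 3, 1]);
 *   %2 = transpose(%1, axes=[0, 3, 1, 2]);
 * The composed permutation is the identity, so the result is simply %x.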
 */
class SimplifyTranspose : public DFPatternRewrite {
 public:
  SimplifyTranspose() {
    x_ = IsWildcard();
    auto trans1 = IsOp("transpose") || IsOp("layout_transform");
    auto trans2 = IsOp("transpose") || IsOp("layout_transform");
    pattern_ = trans1({trans2({x_})});
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    auto x = node_map[x_][0];

    Call trans_call = Downcast<Call>(post);

    // Try to fuse any rank changing layout transformations
    if (auto layout_trans = FoldRankChangingLayoutTrans(x, trans_call)) {
      if (auto attr = layout_trans.value()->attrs.as<LayoutTransformAttrs>()) {
        // Prune any trivial layout transformation
        if (attr->src_layout == attr->dst_layout) {
          return x;
        }
      }
      return layout_trans.value();
    }

    // Initialize axes
    int ndim = Downcast<TensorType>(pre->checked_type())->shape.size();
    Array<Integer> axes;
    for (int i = 0; i < ndim; ++i) {
      axes.push_back(i);
    }

    // Collect axes changes from the matched pattern, including two consecutive transposes.
    std::vector<std::vector<int>> interm_axes;
    interm_axes.push_back(GetTransposeAxisOrder(trans_call, ndim));
    trans_call = Downcast<Call>(trans_call->args[0]);
    interm_axes.push_back(GetTransposeAxisOrder(trans_call, ndim));

    // Calculate the final axes in reverse order (from root to output)
    auto it = interm_axes.rbegin();
    while (it != interm_axes.rend()) {
      auto interm = *it;

      Array<Integer> new_axes;
      for (int i = 0; i < ndim; ++i) {
        new_axes.push_back(axes[interm[i]]);
      }
      axes = new_axes;
      it++;
    }

    // Check if the transpose is still required
    bool need_transpose = false;
    for (int i = 0; i < ndim; ++i) {
      if (axes[i] != i) {
        need_transpose = true;
        break;
      }
    }

    if (need_transpose) {
      return MakeTranspose(x, axes);
    }
    return x;
  }

  String PermuteLayout(const String& layout, std::vector<int> axes_order) const {
    std::string new_layout{};
    std::string old_layout{layout};
    ICHECK_EQ(axes_order.size(), layout.size())
        << "Number of axes must match the number of named axes in the layout to permute: length("
        << old_layout << ") != " << axes_order.size();
    std::stringstream order;
    for (auto axis : axes_order) {
      new_layout += old_layout[axis];
      order << axis << ", ";
    }
    DLOG(INFO) << "Using transpose axes order {" << order.str()
               << "} to permute layout: " << old_layout << " to " << new_layout;
    return new_layout;
  }

  struct RankChangingLayoutDescriptor {
    Layout src_layout;
    Layout dst_layout;
    // Either a rank changing layout transform or a transpose
    Call other_transform;
  };

  std::unique_ptr<RankChangingLayoutDescriptor> GetRankChangeDescriptor(const Call& call) const {
    std::unique_ptr<RankChangingLayoutDescriptor> desc{nullptr};
    if (auto attr = call->attrs.as<LayoutTransformAttrs>()) {
      if (attr->src_layout.length() != attr->dst_layout.length()) {
        desc = std::make_unique<RankChangingLayoutDescriptor>();
        desc->src_layout = Layout(attr->src_layout);
        desc->dst_layout = Layout(attr->dst_layout);
        desc->other_transform = Downcast<Call>(call->args[0]);
      }
    }
    if (auto attr = Downcast<Call>(call->args[0])->attrs.as<LayoutTransformAttrs>()) {
      if (attr->src_layout.length() != attr->dst_layout.length()) {
        if (!desc) {
          desc = std::make_unique<RankChangingLayoutDescriptor>();
          desc->src_layout = Layout(attr->src_layout);
          desc->dst_layout = Layout(attr->dst_layout);
          desc->other_transform = call;
        } else {
          ICHECK(desc->src_layout->name == attr->dst_layout)
              << "Back-to-back layout transforms must have the same intermediate layout: "
              << desc->src_layout->name << " != " << attr->dst_layout;
          desc->src_layout = Layout(attr->src_layout);
        }
      }
    }
    return desc;
  }

  /*
   * \brief Fuse a call and its argument into a single layout_transform operator
   * when either the call or its argument is a rank-changing layout_transform, e.g.,
   *
   * Simplify
   *
   *   [N, H, W, C] -> Transpose -> [N, C, H, W] -> LayoutTrans -> [N, C, H, W, 4c]
   *
   * to
   *
   *   [N, H, W, C] -> LayoutTrans -> [N, C, H, W, 4c].
   *
   * \param data The input expression to the matched pattern
   * \param call The pattern root; the second of two consecutive Transpose/LayoutTransform ops
   */
  Optional<Call> FoldRankChangingLayoutTrans(const Expr& data, const Call& call) const {
    // Check to see if either the first or second call in the matched pattern
    // is a rank changing layout transform. If so, return a descriptor containing
    // the layouts and any additional transpose or layout transform op.
    auto desc = GetRankChangeDescriptor(call);
    if (desc == nullptr) {
      // No rank changing layout transform
      return Optional<Call>{nullptr};
    }

    Optional<Expr> output_layout_trans;
    // Fuse a rank increasing layout transform and a preceding transpose
    if (desc->src_layout->axes.size() < desc->dst_layout->axes.size()) {
      auto axes = GetTransposeAxisOrder(desc->other_transform, desc->src_layout->axes.size());
      // Calculate the reverse axis order and apply it to the source layout
      std::vector<int> inverse(axes.size());
      for (size_t i = 0; i < axes.size(); i++) {
        inverse[axes[i]] = i;
      }
      String new_layout = PermuteLayout(desc->src_layout->name, inverse);
      output_layout_trans = MakeLayoutTransform(data, new_layout, desc->dst_layout->name);
      // Fuse a rank decreasing layout transform followed by a transpose
    } else if (desc->src_layout->axes.size() > desc->dst_layout->axes.size()) {
      auto axes = GetTransposeAxisOrder(desc->other_transform, desc->dst_layout->axes.size());
      String new_layout = PermuteLayout(desc->dst_layout->name, axes);
      output_layout_trans = MakeLayoutTransform(data, desc->src_layout->name, new_layout);
      // Fuse two back-to-back layout transformations which change rank
    } else if (desc->other_transform->attrs.as<LayoutTransformAttrs>()) {
      output_layout_trans =
          MakeLayoutTransform(data, desc->src_layout->name, desc->dst_layout->name);
    }
    return Downcast<Call>(output_layout_trans);
  }

  std::vector<int> GetTransposeAxisOrder(const Call& call, int ndim) const {
    std::vector<int> attr_axes;
    if (auto attr = call->attrs.as<TransposeAttrs>()) {
      if (attr->axes.defined()) {
        for (int i = 0; i < ndim; ++i) {
          int64_t axis = attr->axes[i].IntValue();
          axis += (axis < 0) ? ndim : 0;
          attr_axes.push_back(axis);
        }
      } else {
        // Empty axes means reverse
        for (int i = ndim - 1; i >= 0; --i) {
          attr_axes.push_back(i);
        }
      }
    } else if (auto attr = call->attrs.as<LayoutTransformAttrs>()) {
      Layout src_layout(attr->src_layout);
      Layout dst_layout(attr->dst_layout);
      for (int i = 0; i < ndim; ++i) {
        attr_axes.push_back(src_layout.IndexOf(dst_layout[i]));
      }
    } else {
      CHECK(false) << "Expected transpose or layout_transform, but got "
                   << Downcast<Op>(call->op)->name;
    }
    return std::move(attr_axes);
  }

 private:
  /*! \brief Pattern input */
  DFPattern x_;
};

/*!
 * \brief FullElementwise finds full-like ops followed by broadcasting ops, and eliminates
 * the full op by directly passing the fill value into the broadcasting op.
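 *
 * For example (illustrative), where %x already has the same type as the result:
 *   %1 = ones_like(%y);
 *   %2 = multiply(%x, %1);
 * becomes
 *   %2 = multiply(%x, 1f);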
 */
class FullElementwise : public DFPatternRewrite {
 public:
  FullElementwise() {
    x_ = IsWildcard();
    data_ = IsWildcard();
    value_ = IsConstant();

    full_ = IsOp("full")({value_}) || IsOp("full_like")({data_, value_});
    ones_ = IsOp("ones")({}) || IsOp("ones_like")({data_});
    zeros_ = IsOp("zeros")({}) || IsOp("zeros_like")({data_});

    Map<String, ObjectRef> attrs;
    attrs.Set("TOpPattern", Integer(static_cast<int>(kBroadcast)));
    DFPattern op = IsWildcard().HasAttr(attrs);
    DFPattern full = full_ || ones_ || zeros_;
    pattern_ = op({full, x_}) || op({x_, full});
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    const CallNode* call = pre.as<CallNode>();
    ICHECK(call);
    Type pre_type = pre->checked_type_;
    ICHECK(pre_type.as<TensorTypeNode>());
    auto dtype = pre_type.as<TensorTypeNode>()->dtype;
    auto x = node_map[x_][0];
    bool is_left = post.as<CallNode>()->args[1] == x;
    Type x_type;
    if (is_left) {
      x_type = call->args[1]->checked_type_;
    } else {
      x_type = call->args[0]->checked_type_;
    }

    if (StructuralEqual()(x_type, pre_type)) {
      Expr value;
      if (node_map.count(full_)) {
        value = node_map[value_][0];
        ICHECK(IsConstScalar(value));
      } else if (node_map.count(ones_)) {
        value = MakeConstantScalar(dtype, 1);
      } else if (node_map.count(zeros_)) {
        value = MakeConstantScalar(dtype, 0);
      } else {
        ICHECK(false) << "Didn't find a full op while matching full + elementwise";
      }
      if (is_left) {
        return Call(call->op, {value, x}, call->attrs, call->type_args, call->span);
      } else {
        return Call(call->op, {x, value}, call->attrs, call->type_args, call->span);
      }
    }
    return post;
  }

 private:
  /*! \brief binary argument */
  DFPattern x_;
  /*! \brief data ops get shape from */
  DFPattern data_;
  /*! \brief constant input */
  DFPattern value_;
  /*! \brief full op */
  DFPattern full_;
  /*! \brief ones op */
  DFPattern ones_;
  /*! \brief zeros op */
  DFPattern zeros_;
};

/*!
 * \brief Converts `*_like` operators to their explicit shape equivalent (e.g. `zeros_like(x)`
 * becomes `zeros(x.shape, x.dtype)`), when the target shape is concrete. This removes unnecessary
 * dependencies and can enable more opportunities for operator fusion.
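 *
 * For example (illustrative), with %x : Tensor[(2, 4), float32]:
 *   zeros_like(%x)  becomes  zeros(shape=[2, 4], dtype="float32")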
 */
class ConcretizeLikeRewrite : public DFPatternRewrite {
 public:
  explicit ConcretizeLikeRewrite(const Op& op) {
    ICHECK(op->num_inputs == 1 || op->num_inputs == 2)
        << "ConcretizeLike does not handle operators that aren't unary or binary, got: " << op;
    like_pat_ = IsWildcard();
    data_pat_ = IsWildcard();
    if (op->num_inputs == 1) {
      pattern_ = IsExpr(op)({like_pat_});
    } else {
      pattern_ = IsExpr(op)({data_pat_, like_pat_});
    }
  }

  virtual bool Check(const Expr& pre, const Expr& post,
                     const Map<DFPattern, Array<Expr>>& node_map) const {
    const CallNode* call_node = pre.as<CallNode>();
    ICHECK(call_node);

    if (!call_node->checked_type().as<TensorTypeNode>()) {
      return false;
    }

    return true;
  }

  virtual Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
                          DataType dtype) const = 0;

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    if (!Check(pre, post, node_map)) {
      return post;
    }

    const TensorTypeNode* like_ty = pre->checked_type().as<TensorTypeNode>();
    Array<Integer> cshape;
    for (const auto& dim : like_ty->shape) {
      if (const auto* imm = dim.as<IntImmNode>()) {
        cshape.push_back(Integer(GetRef<IntImm>(imm)));
      } else {
        // shape is not static, don't concretize
        return post;
      }
    }

    return Concretize(node_map, cshape, like_ty->dtype);
  }

 protected:
  DFPattern data_pat_;
  DFPattern like_pat_;
};

class ConcretizeZerosLikeRewrite : public ConcretizeLikeRewrite {
 public:
  ConcretizeZerosLikeRewrite() : ConcretizeLikeRewrite(Op::Get("zeros_like")) {}

  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
                  DataType dtype) const override {
    return MakeZeros(shape, dtype);
  }
};

class ConcretizeOnesLikeRewrite : public ConcretizeLikeRewrite {
 public:
  ConcretizeOnesLikeRewrite() : ConcretizeLikeRewrite(Op::Get("ones_like")) {}

  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
                  DataType dtype) const override {
    return MakeOnes(shape, dtype);
  }
};

class ConcretizeFullLikeRewrite : public ConcretizeLikeRewrite {
 public:
  ConcretizeFullLikeRewrite() : ConcretizeLikeRewrite(Op::Get("full_like")) {}

  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
                  DataType dtype) const override {
    // `like_pat_` here is `fill_value`
    return MakeFull(node_map[like_pat_][0], shape, dtype);
  }
};

class ConcretizeReshapeLikeRewrite : public ConcretizeLikeRewrite {
 public:
  ConcretizeReshapeLikeRewrite() : ConcretizeLikeRewrite(Op::Get("reshape_like")) {}

  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
                  DataType dtype) const override {
    return MakeReshape(node_map[data_pat_][0], shape);
  }
};

class ConcretizeCollapseSumLikeRewrite : public ConcretizeLikeRewrite {
 public:
  ConcretizeCollapseSumLikeRewrite() : ConcretizeLikeRewrite(Op::Get("collapse_sum_like")) {}

  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
                  DataType dtype) const override {
    ICHECK_LE(shape.size(), std::numeric_limits<int64_t>::max());
    static const Op& op = Op::Get("collapse_sum_to");
    auto attrs = make_object<InitOpAttrs>();
    attrs->shape = shape;
    std::vector<int64_t> s;
    std::transform(shape.begin(), shape.end(), std::back_inserter(s),
                   [](Integer i) { return i.IntValue(); });
    auto cshape = MakeConstantTensor(DataType::Int(32), {static_cast<int64_t>(shape.size())}, s);
    return Call(op, {node_map[data_pat_][0], cshape}, Attrs(attrs));
  }
};

class ConcretizeBroadcastToLikeRewrite : public ConcretizeLikeRewrite {
 public:
  ConcretizeBroadcastToLikeRewrite() : ConcretizeLikeRewrite(Op::Get("broadcast_to_like")) {}

  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
                  DataType dtype) const override {
    return MakeBroadCastTo(node_map[data_pat_][0], shape);
  }
};

/*!
 * \brief Converts the cast_like operator to cast. This does not inherit from ConcretizeLikeRewrite
 * because the rewrite applies even when the shape is not static; only the dtype is needed.
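 *
 * For example (illustrative), with %y of dtype float16:
 *   cast_like(%x, %y)  becomes  cast(%x, dtype="float16")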
 */
class ConcretizeCastLikeRewrite : public DFPatternRewrite {
 public:
  ConcretizeCastLikeRewrite() {
    data_pat_ = IsWildcard();
    like_pat_ = IsWildcard();
    pattern_ = IsOp("cast_like")({data_pat_, like_pat_});
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    const CallNode* call_node = pre.as<CallNode>();
    ICHECK(call_node);

    if (!call_node->checked_type().as<TensorTypeNode>()) {
      return post;
    }

    const TensorTypeNode* like_ty = pre->checked_type().as<TensorTypeNode>();
    return MakeCast(node_map[data_pat_][0], like_ty->dtype);
  }

 protected:
  DFPattern data_pat_;
  DFPattern like_pat_;
};

/*! \brief Eliminates expressions that are equivalent to identity. */
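// For example (illustrative): add(%x, 0f), subtract(%x, zeros_like(%x)) and multiply(%x, 1f)
// all simplify to %x when the result type matches the type of %x.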
class EliminateIdentityRewrite : public DFPatternRewrite {
 public:
  EliminateIdentityRewrite() {
    x_ = IsWildcard();
    const_ = IsConstant();

    DFPattern add_op = IsOp("add");
    DFPattern mul_op = IsOp("multiply");
    DFPattern zeros_expr = IsOp("zeros")({}) || IsOp("zeros_like")({IsWildcard()}) || const_;
    DFPattern ones_expr = IsOp("ones")({}) || IsOp("ones_like")({IsWildcard()}) || const_;

    // add and multiply are commutative so we don't need another pattern for reversed args
    DFPattern add_id = add_op({x_, zeros_expr});
    DFPattern mul_id = mul_op({x_, ones_expr});

    DFPattern sub_id = IsOp("subtract")({x_, zeros_expr});
    DFPattern div_id = IsOp("divide")({x_, ones_expr});

    pattern_ = add_id || mul_id || sub_id || div_id;
  }

  bool CheckConstant(const OpNode* op, const ConstantNode* constant) const {
    if (!IsScalar(GetRef<Expr>(constant))) {
      return false;
    }
    auto value = TryToScalar(constant->data, 0);
    if (!value) {
      // unsupported dtype
      return false;
    }
    if (op->name == "add" || op->name == "subtract") {
      return value.value() == 0.0;
    } else if (op->name == "multiply" || op->name == "divide") {
      return value.value() == 1.0;
    }
    return false;
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    const CallNode* call = pre.as<CallNode>();
    ICHECK(call);
    Type pre_type = pre->checked_type_;
    ICHECK(pre_type.as<TensorTypeNode>());
    auto x = node_map[x_][0];
    bool is_left = post.as<CallNode>()->args[1] == x;
    Type x_type;
    if (is_left) {
      x_type = call->args[1]->checked_type_;
    } else {
      x_type = call->args[0]->checked_type_;
    }

    if (node_map.count(const_)) {
      // the other argument is a Constant in this case
      const ConstantNode* constant = node_map[const_][0].as<ConstantNode>();
      const OpNode* op = call->op.as<OpNode>();
      ICHECK(constant);
      ICHECK(op);
      if (!CheckConstant(op, constant)) {
        return post;
      }
    }

    if (StructuralEqual()(x_type, pre_type)) {
      return x;
    }

    return post;
  }

 private:
  DFPattern x_;
  DFPattern const_;
};

/*! \brief Switch adjacent add-mul with constants to mul-add,
 * since the mul-add pattern is friendlier to FoldScaleAxis.
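 *
 * For example (illustrative), with constants %c1 and %c2:
 *   multiply(add(%x, %c1), %c2)  becomes  add(multiply(%x, %c2), %c1 * %c2)
 * where %c1 * %c2 is folded to a constant.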
 */
class SwitchAddMultiply : public DFPatternRewrite {
 public:
  SwitchAddMultiply() {
    x_ = IsWildcard();
    c1_ = IsConstant();
    c2_ = IsConstant();
    pattern_ = (x_ + c1_) * c2_;
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    auto x = node_map[x_][0];
    auto c1 = node_map[c1_][0];
    auto c2 = node_map[c2_][0];

    if (x.as<ConstantNode>()) {
      return post;
    }

    Expr const_expr = Call(Op::Get("multiply"), {c1, c2});
    Expr const_val = transform::FoldConstantExpr(const_expr);

    return Call(Op::Get("add"), {Call(Op::Get("multiply"), {x, c2}), const_val});
  }

 private:
  DFPattern x_;
  DFPattern c1_;
  DFPattern c2_;
};

/*! \brief Simplify two adjacent multiplies or adds with constants to enable further constant
 * folding. The pattern matching exploits commutativity.
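 *
 * For example (illustrative), with constants %c1 and %c2:
 *   add(add(%x, %c1), %c2)  becomes  add(%x, %c1 + %c2)
 * where %c1 + %c2 is folded to a constant.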
 */
class SimplifyAdjacentMultiplyOrAdd : public DFPatternRewrite {
 public:
  SimplifyAdjacentMultiplyOrAdd() {
    x_ = IsWildcard();
    c1_ = IsConstant();
    c2_ = IsConstant();
    pattern_ = (x_ * c1_ * c2_) || (x_ + c1_ + c2_);
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    const CallNode* call = pre.as<CallNode>();
    auto x = node_map[x_][0];
    auto c1 = node_map[c1_][0];
    auto c2 = node_map[c2_][0];

    if (x.as<ConstantNode>()) {
      return post;
    }

    Expr const_expr = Call(call->op, {c1, c2});
    Expr const_val = transform::FoldConstantExpr(const_expr);

    return Call(call->op, {x, const_val});
  }

 private:
  DFPattern x_;
  DFPattern c1_;
  DFPattern c2_;
};

/*! \brief Simplify x + x to x * 2 */
class SimplifyAdd : public DFPatternRewrite {
 public:
  SimplifyAdd() {
    x_ = IsWildcard();
    y_ = IsWildcard();
    pattern_ = IsOp("add")({x_, y_});
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    Type pre_type = pre->checked_type_;
    auto dtype = pre_type.as<TensorTypeNode>()->dtype;
    auto x = node_map[x_][0];
    auto y = node_map[y_][0];
    auto data_type = Downcast<TensorType>(x->checked_type());

    if (x == y) {
      Expr value;
      value = MakeConstantScalar(dtype, 2);
      return InferType(Call(Op::Get("multiply"), {x, value}));
    }
    return post;
  }

 private:
  /*! \brief Pattern input */
  DFPattern x_;
  DFPattern y_;
};

/*! \brief Simplify x / sqrt(y) to x * rsqrt(y) */
class SimplifyRSqrt : public DFPatternRewrite {
 public:
  SimplifyRSqrt() {
    x_ = IsWildcard();
    numerator_ = IsWildcard();
    auto sqrt = IsOp("sqrt");
    pattern_ = IsOp("divide")({numerator_, sqrt({x_})});
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    static const Op& op = Op::Get("rsqrt");
    auto x = node_map[x_][0];
    auto numerator = node_map[numerator_][0];
    return Call(Op::Get("multiply"), {numerator, Call(op, {x})});
  }

 private:
  /*! \brief Pattern input */
  DFPattern x_;
  DFPattern numerator_;
};

/*! \brief Base class for simplifying a dequantize followed by an arg op (argmax/argmin/argsort) */
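// For example (illustrative): argmax(qnn.dequantize(%x, %scale, %zero_point), axis=[1]) becomes
// argmax(%x, axis=[1]). Dequantization is monotonic, so it does not change which index holds the
// extremum or the sort order.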
class SimplifyDQArgFunc : public DFPatternRewrite {
 public:
  explicit SimplifyDQArgFunc(std::string op) : op_(op) {
    x_ = IsWildcard();
    dq_ = IsOp("qnn.dequantize")({x_, IsWildcard(), IsWildcard()});
    pattern_ = IsOp(op_)({dq_});
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    const CallNode* call = pre.as<CallNode>();
    ICHECK(call);
    auto x = node_map[x_][0];
    return Call(Op::Get(op_), {x}, call->attrs);
  }

 protected:
  /*! \brief Pattern input */
  DFPattern x_;
  /*! \brief dequantize op */
  DFPattern dq_;
  /*! \brief Name of op to simplify */
  String op_;
};

/*! \brief Simplify dequantize followed by argmax */
class SimplifyDQArgMax : public SimplifyDQArgFunc {
 public:
  SimplifyDQArgMax() : SimplifyDQArgFunc("argmax") {}
};

/*! \brief Simplify dequantize followed by argmin */
class SimplifyDQArgMin : public SimplifyDQArgFunc {
 public:
  SimplifyDQArgMin() : SimplifyDQArgFunc("argmin") {}
};

/*! \brief Simplify dequantize followed by argsort */
class SimplifyDQArgSort : public SimplifyDQArgFunc {
 public:
  SimplifyDQArgSort() : SimplifyDQArgFunc("argsort") {}
};

Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
  // The rewrites are applied in the given order and repeated until a fixed point is reached.
  DFPatternRewriteComposer composer;
  composer.AddRewrite<ConcretizeZerosLikeRewrite>();
  composer.AddRewrite<ConcretizeOnesLikeRewrite>();
  composer.AddRewrite<ConcretizeFullLikeRewrite>();
  composer.AddRewrite<ConcretizeReshapeLikeRewrite>();
  composer.AddRewrite<ConcretizeCollapseSumLikeRewrite>();
  composer.AddRewrite<ConcretizeBroadcastToLikeRewrite>();
  composer.AddRewrite<ConcretizeCastLikeRewrite>();
  composer.AddRewrite<SimplifyAdd>();
  composer.AddRewrite<SimplifyRSqrt>();
  composer.AddRewrite<EliminateIdentityRewrite>();
  composer.AddRewrite<SimplifyReshape>();
  composer.AddRewrite<SimplifyTranspose>();
  composer.AddRewrite<SimplifySameCast>();
  composer.AddRewrite<SimplifyConsecutiveCast>();
  composer.AddRewrite<FullElementwise>();
  composer.AddRewrite<SwitchAddMultiply>();
  composer.AddRewrite<SimplifyAdjacentMultiplyOrAdd>();
  composer.AddRewrite<SimplifyDQArgMax>();
  composer.AddRewrite<SimplifyDQArgMin>();
  composer.AddRewrite<SimplifyDQArgSort>();
  composer.AddRewrite<SimplifyClipAndConsecutiveCast>();
  composer.AddRewrite<SimplifyCastClip>();
  return RewritePatterns(composer.MakeCallbacks(), expr, mod);
}

namespace transform {

Pass SimplifyExpr() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        return Downcast<Function>(SimplifyExpr(f, m));
      };
  return CreateFunctionPass(pass_func, 0, "SimplifyExpr", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.SimplifyExpr").set_body_typed(SimplifyExpr);

}  // namespace transform

}  // namespace relay
}  // namespace tvm