/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/transforms/simplify_expr.cc
 * \brief A pass for simplifying the Relay expression.
 */

#include "simplify_expr.h"

#include <tvm/relay/dataflow_matcher.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/logging.h>

#include <algorithm>
#include <limits>
#include <memory>
#include <string>
#include <utility>

#include "../op/tensor/transform.h"
#include "fold_constant.h"
#include "pattern_utils.h"

namespace tvm {
namespace relay {

/*!
 * \brief SimplifyReshape matches the pattern of consecutive reshape or reverse_reshape ops,
 * and merges them into one reshape op.
 */
class SimplifyReshape : public DFPatternRewrite {
 public:
  SimplifyReshape() {
    x_ = IsWildcard();
    auto reshape1 = IsOp("reshape") || IsOp("contrib_reverse_reshape");
    auto reshape2 = IsOp("reshape") || IsOp("contrib_reverse_reshape");
    pattern_ = reshape1({reshape2({x_})});
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    auto x = node_map[x_][0];
    bool const_shape = true;
    Array<Integer> newshape;
    for (auto dim : Downcast<TensorType>(pre->checked_type())->shape) {
      if (dim.as<IntImmNode>() == nullptr) {
        const_shape = false;
        break;
      }
      newshape.push_back(Downcast<Integer>(dim));
    }
    if (const_shape) {
      return MakeReshape(x, newshape);
    }
    return post;
  }

 private:
  /*! \brief Pattern input */
  DFPattern x_;
};

/*!
 * \brief SimplifySameCast matches the pattern of casting data to the same dtype.
 */
class SimplifySameCast : public DFPatternRewrite {
 public:
  SimplifySameCast() {
    data_pat_ = IsWildcard();
    like_pat_ = IsWildcard();
    pattern_ = IsOp("cast_like")({data_pat_, like_pat_}) || IsOp("cast")({data_pat_});
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    const CallNode* call = pre.as<CallNode>();
    const TensorTypeNode* data_ty = call->args[0]->checked_type().as<TensorTypeNode>();
    const TensorTypeNode* like_ty = pre->checked_type().as<TensorTypeNode>();
    if (like_ty->dtype == data_ty->dtype) {
      return node_map[data_pat_][0];
    }
    return post;
  }

 protected:
  DFPattern data_pat_;
  DFPattern like_pat_;
};

/*!
 * \brief SimplifyConsecutiveCast matches the pattern of consecutive cast/cast_like ops.
 */
class SimplifyConsecutiveCast : public DFPatternRewrite {
 public:
  SimplifyConsecutiveCast() {
    data_ = IsWildcard();
    cast1_ = IsOp("cast_like")({data_, IsWildcard()}) || IsOp("cast")({data_});
    pattern_ = IsOp("cast_like")({cast1_, IsWildcard()}) || IsOp("cast")({cast1_});
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    auto data = node_map[data_][0];
    auto cast1 = Downcast<Call>(node_map[cast1_][0]);
    auto data_type = Downcast<TensorType>(data->checked_type());
    DataType cast1_dtype = Downcast<TensorType>(cast1->checked_type())->dtype;

    if (!IsWidenCast(data_type->dtype, cast1_dtype)) {
      // Cannot remove the narrowing cast
      return post;
    }

    const CallNode* cast2 = post.as<CallNode>();
    DataType cast2_dtype = Downcast<TensorType>(cast2->checked_type())->dtype;
    auto expr = MakeCast(data, cast2_dtype);

    // We need to set the checked type as it may be needed in the next callback
    expr->checked_type_ = TensorType(data_type->shape, cast2_dtype);
    return expr;
  }

  /*! \brief Return whether casting from origin to cast results in the same or more precision. */
  bool IsWidenCast(DataType origin, DataType cast) const {
    if (origin.code() == cast.code() && origin.bits() <= cast.bits()) {
      return true;
    }
    if (origin.code() == DataType::kBFloat || cast.code() == DataType::kBFloat) {
      // BFloat casts cannot be omitted
      return false;
    }
    if (origin.code() < cast.code() && origin.bits() <= cast.bits()) {
      // Loosely follow a hierarchy of datatypes:
      // e.g. int --> uint --> float has an increasing range of representable numbers
      return true;
    }
    return false;
  }

 protected:
  DFPattern data_;
  DFPattern cast1_;
};

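/*!
 * \brief Check whether the closed interval [min_value, max_value] exactly matches the
 * representable range of the given integer or floating-point data type.
 */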
bool CheckDataTypeMaxMinValue(DataType dtype, double min_value, double max_value) {
  if (dtype.is_int() || dtype.is_uint()) {
    double ubound = static_cast<double>(Downcast<IntImm>(tvm::max_value(dtype))->value);
    double lbound = static_cast<double>(Downcast<IntImm>(tvm::min_value(dtype))->value);
    return ubound == max_value && lbound == min_value;
  } else if (dtype.is_float()) {
    double ubound = Downcast<FloatImm>(tvm::max_value(dtype))->value;
    double lbound = Downcast<FloatImm>(tvm::min_value(dtype))->value;
    return ubound == max_value && lbound == min_value;
  }

  return false;
}

/*!
 * \brief SimplifyClipAndConsecutiveCast matches the pattern clip->cast->cast and removes
 * redundant casts.
 * The analysis of "redundancy" is based on the clip min/max values and the min/max values of the
 * casted data type.
 */
class SimplifyClipAndConsecutiveCast : public DFPatternRewrite {
 public:
  SimplifyClipAndConsecutiveCast() {
    clip_ = IsOp("clip")({IsWildcard()});
    cast1_ = IsOp("cast")({clip_});
    pattern_ = IsOp("cast")({cast1_});
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    auto clip = Downcast<Call>(node_map[clip_][0]);
    const CallNode* clip_node = clip.as<CallNode>();
    const ClipAttrs* clip_attrs = clip_node->attrs.as<ClipAttrs>();
    DataType clip_dtype = Downcast<TensorType>(clip->checked_type())->dtype;

    auto cast1 = Downcast<Call>(node_map[cast1_][0]);
    DataType cast1_dtype = Downcast<TensorType>(cast1->checked_type())->dtype;

    auto cast2 = Downcast<Call>(post);
    DataType cast2_dtype = Downcast<TensorType>(cast2->checked_type())->dtype;

    if (clip_dtype == cast2_dtype &&
        CheckDataTypeMaxMinValue(cast1_dtype, clip_attrs->a_min, clip_attrs->a_max)) {
      // Case 1:
      // The data type of Clip == the target data type of the second Cast, and the min/max values
      // of Clip == the min/max values of the first Cast's target data type. In this case both
      // Cast ops can be removed.
      // Example:
      //   %0 == [type=int32]
      //   %1 = clip(%0, a_min=0f, a_max=255f) [type=int32]
      //   %2 = cast(%1, dtype="uint8") [type=uint8]
      //   %3 = cast(%2, dtype="int32") [type=int32]
      //
      // Optimized to (both casts can be removed):
      //   %1 = clip(%0, a_min=0f, a_max=255f) [type=int32]
      return node_map[clip_][0];
    }
    return post;
  }

 protected:
  DFPattern clip_, cast1_;
};

/*!
 * \brief SimplifyCastClip matches the pattern cast->clip and removes the redundant Clip based on
 * the Clip min/max values and the min/max values of the Cast target data type.
 *
 * Example:
 *   %1 = cast(%0, dtype="uint8") [type=uint8]
 *   %2 = clip(%1, a_min=0f, a_max=255f) [type=uint8]
 *
 * Optimized to (remove Clip):
 *   %1 = cast(%0, dtype="uint8") [type=uint8]
 */
class SimplifyCastClip : public DFPatternRewrite {
 public:
  SimplifyCastClip() {
    cast_ = IsOp("cast")({IsWildcard()});
    pattern_ = IsOp("clip")({cast_});
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    auto cast = Downcast<Call>(node_map[cast_][0]);
    DataType cast_dtype = Downcast<TensorType>(cast->checked_type())->dtype;

    auto clip = Downcast<Call>(post);
    const CallNode* clip_node = clip.as<CallNode>();
    const ClipAttrs* clip_attrs = clip_node->attrs.as<ClipAttrs>();

    if (CheckDataTypeMaxMinValue(cast_dtype, clip_attrs->a_min, clip_attrs->a_max)) {
      return node_map[cast_][0];
    }
    return post;
  }

 protected:
  DFPattern clip_, cast_;
};

/*!
 * \brief SimplifyTranspose matches the pattern of consecutive transpose ops,
 * and merges or cancels them.
 */
class SimplifyTranspose : public DFPatternRewrite {
 public:
  SimplifyTranspose() {
    x_ = IsWildcard();
    auto trans1 = IsOp("transpose") || IsOp("layout_transform");
    auto trans2 = IsOp("transpose") || IsOp("layout_transform");
    pattern_ = trans1({trans2({x_})});
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    auto x = node_map[x_][0];

    Call trans_call = Downcast<Call>(post);

    // Try to fuse any rank-changing layout transformations
    if (auto layout_trans = FoldRankChangingLayoutTrans(x, trans_call)) {
      if (auto attr = layout_trans.value()->attrs.as<LayoutTransformAttrs>()) {
        // Prune any trivial layout transformation
        if (attr->src_layout == attr->dst_layout) {
          return x;
        }
      }
      return layout_trans.value();
    }

    // Initialize axes
    int ndim = Downcast<TensorType>(pre->checked_type())->shape.size();
    Array<Integer> axes;
    for (int i = 0; i < ndim; ++i) {
      axes.push_back(i);
    }

    // Collect the axis changes from the matched pattern, i.e. the two consecutive transposes.
    std::vector<std::vector<int>> interm_axes;
    interm_axes.push_back(GetTransposeAxisOrder(trans_call, ndim));
    trans_call = Downcast<Call>(trans_call->args[0]);
    interm_axes.push_back(GetTransposeAxisOrder(trans_call, ndim));

    // Calculate the final axes in reverse order (from root to output)
    auto it = interm_axes.rbegin();
    while (it != interm_axes.rend()) {
      auto interm = *it;

      Array<Integer> new_axes;
      for (int i = 0; i < ndim; ++i) {
        new_axes.push_back(axes[interm[i]]);
      }
      axes = new_axes;
      it++;
    }

    // Check if the transpose is still required
    bool need_transpose = false;
    for (int i = 0; i < ndim; ++i) {
      if (axes[i] != i) {
        need_transpose = true;
        break;
      }
    }

    if (need_transpose) {
      return MakeTranspose(x, axes);
    }
    return x;
  }

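  /*!
   * \brief Permute the characters of a named layout string according to the given axis order,
   * e.g. "NCHW" with order {0, 2, 3, 1} becomes "NHWC".
   */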
  String PermuteLayout(const String& layout, std::vector<int> axes_order) const {
    std::string new_layout{};
    std::string old_layout{layout};
    ICHECK_EQ(axes_order.size(), layout.size())
        << "Number of axes must match the number of named axes in the layout to permute: length("
        << old_layout << ") != " << axes_order.size();
    std::stringstream order;
    for (auto axis : axes_order) {
      new_layout += old_layout[axis];
      order << axis << ", ";
    }
    DLOG(INFO) << "Using transpose axes order {" << order.str()
               << "} to permute layout: " << old_layout << " to " << new_layout;
    return new_layout;
  }

  struct RankChangingLayoutDescriptor {
    Layout src_layout;
    Layout dst_layout;
    // Either a rank-changing layout transform or a transpose
    Call other_transform;
  };

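  /*!
   * \brief Inspect the matched pair of ops and, when either of them is a rank-changing
   * layout_transform, return a descriptor holding its source/destination layouts together with
   * the other transform call. Returns nullptr when neither op changes rank.
   */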
  std::unique_ptr<RankChangingLayoutDescriptor> GetRankChangeDescriptor(const Call& call) const {
    std::unique_ptr<RankChangingLayoutDescriptor> desc{nullptr};
    if (auto attr = call->attrs.as<LayoutTransformAttrs>()) {
      if (attr->src_layout.length() != attr->dst_layout.length()) {
        desc = std::make_unique<RankChangingLayoutDescriptor>();
        desc->src_layout = Layout(attr->src_layout);
        desc->dst_layout = Layout(attr->dst_layout);
        desc->other_transform = Downcast<Call>(call->args[0]);
      }
    }
    if (auto attr = Downcast<Call>(call->args[0])->attrs.as<LayoutTransformAttrs>()) {
      if (attr->src_layout.length() != attr->dst_layout.length()) {
        if (!desc) {
          desc = std::make_unique<RankChangingLayoutDescriptor>();
          desc->src_layout = Layout(attr->src_layout);
          desc->dst_layout = Layout(attr->dst_layout);
          desc->other_transform = call;
        } else {
          ICHECK(desc->src_layout->name == attr->dst_layout)
              << "Back-to-back layout transforms must have the same intermediate layout: "
              << desc->src_layout->name << " != " << attr->dst_layout;
          desc->src_layout = Layout(attr->src_layout);
        }
      }
    }
    return desc;
  }

  /*
   * \brief Fuse the call and its argument into a single layout_transform operator
   * when either the call or its argument is a rank-changing layout_transform, e.g.,
   *
   * Simplify
   *
   *   [N, H, W, C] -> Transpose -> [N, C, H, W] -> LayoutTrans -> [N, C, H, W, 4c]
   *
   * to,
   *
   *   [N, H, W, C] -> LayoutTrans -> [N, C, H, W, 4c].
   *
   * \param data The input expression to the matched pattern
   * \param call The pattern root; the second of two consecutive Transpose/LayoutTransform ops
   */
  Optional<Call> FoldRankChangingLayoutTrans(const Expr& data, const Call& call) const {
    // Check to see if either the first or second call in the matched pattern
    // is a rank-changing layout transform. If so, return a descriptor containing
    // the layouts and any additional transpose or layout transform op.
    auto desc = GetRankChangeDescriptor(call);
    if (desc == nullptr) {
      // No rank-changing layout transform
      return Optional<Call>{nullptr};
    }

    Optional<Expr> output_layout_trans;
    // Fuse a rank-increasing layout transform and a preceding transpose
    if (desc->src_layout->axes.size() < desc->dst_layout->axes.size()) {
      auto axes = GetTransposeAxisOrder(desc->other_transform, desc->src_layout->axes.size());
      // Calculate the reverse axis order and apply it to the source layout
      std::vector<int> inverse(axes.size());
      for (size_t i = 0; i < axes.size(); i++) {
        inverse[axes[i]] = i;
      }
      String new_layout = PermuteLayout(desc->src_layout->name, inverse);
      output_layout_trans = MakeLayoutTransform(data, new_layout, desc->dst_layout->name);
      // Fuse a rank-decreasing layout transform followed by a transpose
    } else if (desc->src_layout->axes.size() > desc->dst_layout->axes.size()) {
      auto axes = GetTransposeAxisOrder(desc->other_transform, desc->dst_layout->axes.size());
      String new_layout = PermuteLayout(desc->dst_layout->name, axes);
      output_layout_trans = MakeLayoutTransform(data, desc->src_layout->name, new_layout);
      // Fuse two back-to-back layout transformations which change rank
    } else if (desc->other_transform->attrs.as<LayoutTransformAttrs>()) {
      output_layout_trans =
          MakeLayoutTransform(data, desc->src_layout->name, desc->dst_layout->name);
    }
    return Downcast<Call>(output_layout_trans);
  }

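  /*!
   * \brief Get the axis permutation performed by a transpose or a rank-preserving
   * layout_transform call as a vector of non-negative axis indices.
   */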
  std::vector<int> GetTransposeAxisOrder(const Call& call, int ndim) const {
    std::vector<int> attr_axes;
    if (auto attr = call->attrs.as<TransposeAttrs>()) {
      if (attr->axes.defined()) {
        for (int i = 0; i < ndim; ++i) {
          int64_t axis = attr->axes[i].IntValue();
          axis += (axis < 0) ? ndim : 0;
          attr_axes.push_back(axis);
        }
      } else {
        // Empty axes means reverse
        for (int i = ndim - 1; i >= 0; --i) {
          attr_axes.push_back(i);
        }
      }
    } else if (auto attr = call->attrs.as<LayoutTransformAttrs>()) {
      Layout src_layout(attr->src_layout);
      Layout dst_layout(attr->dst_layout);
      for (int i = 0; i < ndim; ++i) {
        attr_axes.push_back(src_layout.IndexOf(dst_layout[i]));
      }
    } else {
      CHECK(false) << "Expected transpose or layout_transform, but got "
                   << Downcast<Op>(call->op)->name;
    }
    return std::move(attr_axes);
  }

 private:
  /*! \brief Pattern input */
  DFPattern x_;
};

/*!
 * \brief FullElementwise finds full-like ops followed by broadcasting ops, and eliminates
 * the full op by directly passing the fill value into the broadcasting op.
 */
class FullElementwise : public DFPatternRewrite {
 public:
  FullElementwise() {
    x_ = IsWildcard();
    data_ = IsWildcard();
    value_ = IsConstant();

    full_ = IsOp("full")({value_}) || IsOp("full_like")({data_, value_});
    ones_ = IsOp("ones")({}) || IsOp("ones_like")({data_});
    zeros_ = IsOp("zeros")({}) || IsOp("zeros_like")({data_});

    Map<String, ObjectRef> attrs;
    attrs.Set("TOpPattern", Integer(static_cast<int>(kBroadcast)));
    DFPattern op = IsWildcard().HasAttr(attrs);
    DFPattern full = full_ || ones_ || zeros_;
    pattern_ = op({full, x_}) || op({x_, full});
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    const CallNode* call = pre.as<CallNode>();
    ICHECK(call);
    Type pre_type = pre->checked_type_;
    ICHECK(pre_type.as<TensorTypeNode>());
    auto dtype = pre_type.as<TensorTypeNode>()->dtype;
    auto x = node_map[x_][0];
    bool is_left = post.as<CallNode>()->args[1] == x;
    Type x_type;
    if (is_left) {
      x_type = call->args[1]->checked_type_;
    } else {
      x_type = call->args[0]->checked_type_;
    }

    if (StructuralEqual()(x_type, pre_type)) {
      Expr value;
      if (node_map.count(full_)) {
        value = node_map[value_][0];
        ICHECK(IsConstScalar(value));
      } else if (node_map.count(ones_)) {
        value = MakeConstantScalar(dtype, 1);
      } else if (node_map.count(zeros_)) {
        value = MakeConstantScalar(dtype, 0);
      } else {
        ICHECK(false) << "Didn't find a full op while matching full + elementwise";
      }
      if (is_left) {
        return Call(call->op, {value, x}, call->attrs, call->type_args, call->span);
      } else {
        return Call(call->op, {x, value}, call->attrs, call->type_args, call->span);
      }
    }
    return post;
  }

 private:
  /*! \brief binary argument */
  DFPattern x_;
  /*! \brief data ops get shape from */
  DFPattern data_;
  /*! \brief constant input */
  DFPattern value_;
  /*! \brief full op */
  DFPattern full_;
  /*! \brief ones op */
  DFPattern ones_;
  /*! \brief zeros op */
  DFPattern zeros_;
};

/*!
 * \brief Converts `*_like` operators to their explicit shape equivalent (e.g. `zeros_like(x, y)`
 * to `zeros(x, y.shape)`), when the target shape is concrete. This removes unnecessary
 * dependencies and can enable more opportunities for operator fusion.
 */
class ConcretizeLikeRewrite : public DFPatternRewrite {
 public:
  explicit ConcretizeLikeRewrite(const Op& op) {
    ICHECK(op->num_inputs == 1 || op->num_inputs == 2)
        << "ConcretizeLike does not handle operators that aren't unary or binary, got: " << op;
    like_pat_ = IsWildcard();
    data_pat_ = IsWildcard();
    if (op->num_inputs == 1) {
      pattern_ = IsExpr(op)({like_pat_});
    } else {
      pattern_ = IsExpr(op)({data_pat_, like_pat_});
    }
  }

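  /*!
   * \brief Hook for subclasses to veto a rewrite. The default implementation only requires the
   * matched call to have a tensor type.
   */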
  virtual bool Check(const Expr& pre, const Expr& post,
                     const Map<DFPattern, Array<Expr>>& node_map) const {
    const CallNode* call_node = pre.as<CallNode>();
    ICHECK(call_node);

    if (!call_node->checked_type().as<TensorTypeNode>()) {
      return false;
    }

    return true;
  }

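  /*! \brief Build the concrete (non-`*_like`) call given the statically known shape and dtype. */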
  virtual Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
                          DataType dtype) const = 0;

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    if (!Check(pre, post, node_map)) {
      return post;
    }

    const TensorTypeNode* like_ty = pre->checked_type().as<TensorTypeNode>();
    Array<Integer> cshape;
    for (const auto& dim : like_ty->shape) {
      if (const auto* imm = dim.as<IntImmNode>()) {
        cshape.push_back(Integer(GetRef<IntImm>(imm)));
      } else {
        // shape is not static, don't concretize
        return post;
      }
    }

    return Concretize(node_map, cshape, like_ty->dtype);
  }

 protected:
  DFPattern data_pat_;
  DFPattern like_pat_;
};

class ConcretizeZerosLikeRewrite : public ConcretizeLikeRewrite {
 public:
  ConcretizeZerosLikeRewrite() : ConcretizeLikeRewrite(Op::Get("zeros_like")) {}

  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
                  DataType dtype) const override {
    return MakeZeros(shape, dtype);
  }
};

class ConcretizeOnesLikeRewrite : public ConcretizeLikeRewrite {
 public:
  ConcretizeOnesLikeRewrite() : ConcretizeLikeRewrite(Op::Get("ones_like")) {}

  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
                  DataType dtype) const override {
    return MakeOnes(shape, dtype);
  }
};

class ConcretizeFullLikeRewrite : public ConcretizeLikeRewrite {
 public:
  ConcretizeFullLikeRewrite() : ConcretizeLikeRewrite(Op::Get("full_like")) {}

  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
                  DataType dtype) const override {
    // `like_pat_` here is `fill_value`
    return MakeFull(node_map[like_pat_][0], shape, dtype);
  }
};

class ConcretizeReshapeLikeRewrite : public ConcretizeLikeRewrite {
 public:
  ConcretizeReshapeLikeRewrite() : ConcretizeLikeRewrite(Op::Get("reshape_like")) {}

  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
                  DataType dtype) const override {
    return MakeReshape(node_map[data_pat_][0], shape);
  }
};

class ConcretizeCollapseSumLikeRewrite : public ConcretizeLikeRewrite {
 public:
  ConcretizeCollapseSumLikeRewrite() : ConcretizeLikeRewrite(Op::Get("collapse_sum_like")) {}

  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
                  DataType dtype) const override {
    ICHECK_LE(shape.size(), std::numeric_limits<int64_t>::max());
    static const Op& op = Op::Get("collapse_sum_to");
    auto attrs = make_object<InitOpAttrs>();
    attrs->shape = shape;
    std::vector<int64_t> s;
    std::transform(shape.begin(), shape.end(), std::back_inserter(s),
                   [](Integer i) { return i.IntValue(); });
    auto cshape = MakeConstantTensor(DataType::Int(32), {static_cast<int64_t>(shape.size())}, s);
    return Call(op, {node_map[data_pat_][0], cshape}, Attrs(attrs));
  }
};

class ConcretizeBroadcastToLikeRewrite : public ConcretizeLikeRewrite {
 public:
  ConcretizeBroadcastToLikeRewrite() : ConcretizeLikeRewrite(Op::Get("broadcast_to_like")) {}

  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
                  DataType dtype) const override {
    return MakeBroadCastTo(node_map[data_pat_][0], shape);
  }
};

/*!
 * \brief Converts the cast_like operator to cast. This does not inherit from
 * ConcretizeLikeRewrite because it can concretize even when the shape is not static.
 */
class ConcretizeCastLikeRewrite : public DFPatternRewrite {
 public:
  ConcretizeCastLikeRewrite() {
    data_pat_ = IsWildcard();
    like_pat_ = IsWildcard();
    pattern_ = IsOp("cast_like")({data_pat_, like_pat_});
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    const CallNode* call_node = pre.as<CallNode>();
    ICHECK(call_node);

    if (!call_node->checked_type().as<TensorTypeNode>()) {
      return post;
    }

    const TensorTypeNode* like_ty = pre->checked_type().as<TensorTypeNode>();
    return MakeCast(node_map[data_pat_][0], like_ty->dtype);
  }

 protected:
  DFPattern data_pat_;
  DFPattern like_pat_;
};

/*! \brief Eliminates expressions that are equivalent to identity. */
class EliminateIdentityRewrite : public DFPatternRewrite {
 public:
  EliminateIdentityRewrite() {
    x_ = IsWildcard();
    const_ = IsConstant();

    DFPattern add_op = IsOp("add");
    DFPattern mul_op = IsOp("multiply");
    DFPattern zeros_expr = IsOp("zeros")({}) || IsOp("zeros_like")({IsWildcard()}) || const_;
    DFPattern ones_expr = IsOp("ones")({}) || IsOp("ones_like")({IsWildcard()}) || const_;

    // add and multiply are commutative so we don't need another pattern for reversed args
    DFPattern add_id = add_op({x_, zeros_expr});
    DFPattern mul_id = mul_op({x_, ones_expr});

    DFPattern sub_id = IsOp("subtract")({x_, zeros_expr});
    DFPattern div_id = IsOp("divide")({x_, ones_expr});

    pattern_ = add_id || mul_id || sub_id || div_id;
  }

  bool CheckConstant(const OpNode* op, const ConstantNode* constant) const {
    if (!IsScalar(GetRef<Expr>(constant))) {
      return false;
    }
    auto value = TryToScalar(constant->data, 0);
    if (!value) {
      // unsupported dtype
      return false;
    }
    if (op->name == "add" || op->name == "subtract") {
      return value.value() == 0.0;
    } else if (op->name == "multiply" || op->name == "divide") {
      return value.value() == 1.0;
    }
    return false;
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    const CallNode* call = pre.as<CallNode>();
    ICHECK(call);
    Type pre_type = pre->checked_type_;
    ICHECK(pre_type.as<TensorTypeNode>());
    auto x = node_map[x_][0];
    bool is_left = post.as<CallNode>()->args[1] == x;
    Type x_type;
    if (is_left) {
      x_type = call->args[1]->checked_type_;
    } else {
      x_type = call->args[0]->checked_type_;
    }

    if (node_map.count(const_)) {
      // the other argument is a Constant in this case
      const ConstantNode* constant = node_map[const_][0].as<ConstantNode>();
      const OpNode* op = call->op.as<OpNode>();
      ICHECK(constant);
      ICHECK(op);
      if (!CheckConstant(op, constant)) {
        return post;
      }
    }

    if (StructuralEqual()(x_type, pre_type)) {
      return x;
    }

    return post;
  }

 private:
  DFPattern x_;
  DFPattern const_;
};

/*! \brief Switch adjacent add-mul with constants to mul-add, since the mul-add pattern is more
 * friendly to FoldScaleAxis.
 */
class SwitchAddMultiply : public DFPatternRewrite {
 public:
  SwitchAddMultiply() {
    x_ = IsWildcard();
    c1_ = IsConstant();
    c2_ = IsConstant();
    pattern_ = (x_ + c1_) * c2_;
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    auto x = node_map[x_][0];
    auto c1 = node_map[c1_][0];
    auto c2 = node_map[c2_][0];

    if (x.as<ConstantNode>()) {
      return post;
    }

    Expr const_expr = Call(Op::Get("multiply"), {c1, c2});
    Expr const_val = transform::FoldConstantExpr(const_expr);

    return Call(Op::Get("add"), {Call(Op::Get("multiply"), {x, c2}), const_val});
  }

 private:
  DFPattern x_;
  DFPattern c1_;
  DFPattern c2_;
};

/*! \brief Simplify two adjacent multiply or add ops with constants to enable further constant
 * folding. The pattern matching supports commutativity.
 */
class SimplifyAdjacentMultiplyOrAdd : public DFPatternRewrite {
 public:
  SimplifyAdjacentMultiplyOrAdd() {
    x_ = IsWildcard();
    c1_ = IsConstant();
    c2_ = IsConstant();
    pattern_ = (x_ * c1_ * c2_) || (x_ + c1_ + c2_);
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    const CallNode* call = pre.as<CallNode>();
    auto x = node_map[x_][0];
    auto c1 = node_map[c1_][0];
    auto c2 = node_map[c2_][0];

    if (x.as<ConstantNode>()) {
      return post;
    }

    Expr const_expr = Call(call->op, {c1, c2});
    Expr const_val = transform::FoldConstantExpr(const_expr);

    return Call(call->op, {x, const_val});
  }

 private:
  DFPattern x_;
  DFPattern c1_;
  DFPattern c2_;
};

/*! \brief Simplifying x+x to x*2 */
class SimplifyAdd : public DFPatternRewrite {
 public:
  SimplifyAdd() {
    x_ = IsWildcard();
    y_ = IsWildcard();
    pattern_ = IsOp("add")({x_, y_});
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    Type pre_type = pre->checked_type_;
    auto dtype = pre_type.as<TensorTypeNode>()->dtype;
    auto x = node_map[x_][0];
    auto y = node_map[y_][0];
    auto data_type = Downcast<TensorType>(x->checked_type());

    if (x == y) {
      Expr value = MakeConstantScalar(dtype, 2);
      return InferType(Call(Op::Get("multiply"), {x, value}));
    }
    return post;
  }

 private:
  /*! \brief Pattern input */
  DFPattern x_;
  DFPattern y_;
};

/*! \brief Simplifying x/sqrt(y) to x*rsqrt(y) */
class SimplifyRSqrt : public DFPatternRewrite {
 public:
  SimplifyRSqrt() {
    x_ = IsWildcard();
    numerator_ = IsWildcard();
    auto sqrt = IsOp("sqrt");
    pattern_ = IsOp("divide")({numerator_, sqrt({x_})});
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    static const Op& op = Op::Get("rsqrt");
    auto x = node_map[x_][0];
    auto numerator = node_map[numerator_][0];
    return Call(Op::Get("multiply"), {numerator, Call(op, {x})});
  }

 private:
  /*! \brief Pattern input */
  DFPattern x_;
  DFPattern numerator_;
};

/*! \brief Base class for simplifying dequantize followed by arg ops */
class SimplifyDQArgFunc : public DFPatternRewrite {
 public:
  explicit SimplifyDQArgFunc(std::string op) : op_(op) {
    x_ = IsWildcard();
    dq_ = IsOp("qnn.dequantize")({x_, IsWildcard(), IsWildcard()});
    pattern_ = IsOp(op_)({dq_});
  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const override {
    const CallNode* call = pre.as<CallNode>();
    ICHECK(call);
    auto x = node_map[x_][0];
    return Call(Op::Get(op_), {x}, call->attrs);
  }

 protected:
  /*! \brief Pattern input */
  DFPattern x_;
  /*! \brief dequantize op */
  DFPattern dq_;
  /*! \brief Name of op to simplify */
  String op_;
};

/*! \brief Simplify dequantize followed by argmax */
class SimplifyDQArgMax : public SimplifyDQArgFunc {
 public:
  SimplifyDQArgMax() : SimplifyDQArgFunc("argmax") {}
};

/*! \brief Simplify dequantize followed by argmin */
class SimplifyDQArgMin : public SimplifyDQArgFunc {
 public:
  SimplifyDQArgMin() : SimplifyDQArgFunc("argmin") {}
};

/*! \brief Simplify dequantize followed by argsort */
class SimplifyDQArgSort : public SimplifyDQArgFunc {
 public:
  SimplifyDQArgSort() : SimplifyDQArgFunc("argsort") {}
};

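/*!
 * \brief Apply the expression-simplification rewrites above to \p expr, repeating them in the
 * given order until a fixed point is reached.
 */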
Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
  // the rewrites will be applied in the given order, and repeated until fixed point
  DFPatternRewriteComposer composer;
  composer.AddRewrite<ConcretizeZerosLikeRewrite>();
  composer.AddRewrite<ConcretizeOnesLikeRewrite>();
  composer.AddRewrite<ConcretizeFullLikeRewrite>();
  composer.AddRewrite<ConcretizeReshapeLikeRewrite>();
  composer.AddRewrite<ConcretizeCollapseSumLikeRewrite>();
  composer.AddRewrite<ConcretizeBroadcastToLikeRewrite>();
  composer.AddRewrite<ConcretizeCastLikeRewrite>();
  composer.AddRewrite<SimplifyAdd>();
  composer.AddRewrite<SimplifyRSqrt>();
  composer.AddRewrite<EliminateIdentityRewrite>();
  composer.AddRewrite<SimplifyReshape>();
  composer.AddRewrite<SimplifyTranspose>();
  composer.AddRewrite<SimplifySameCast>();
  composer.AddRewrite<SimplifyConsecutiveCast>();
  composer.AddRewrite<FullElementwise>();
  composer.AddRewrite<SwitchAddMultiply>();
  composer.AddRewrite<SimplifyAdjacentMultiplyOrAdd>();
  composer.AddRewrite<SimplifyDQArgMax>();
  composer.AddRewrite<SimplifyDQArgMin>();
  composer.AddRewrite<SimplifyDQArgSort>();
  composer.AddRewrite<SimplifyClipAndConsecutiveCast>();
  composer.AddRewrite<SimplifyCastClip>();
  return RewritePatterns(composer.MakeCallbacks(), expr, mod);
}

namespace transform {

Pass SimplifyExpr() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        return Downcast<Function>(SimplifyExpr(f, m));
      };
  return CreateFunctionPass(pass_func, 0, "SimplifyExpr", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.SimplifyExpr").set_body_typed(SimplifyExpr);

}  // namespace transform

}  // namespace relay
}  // namespace tvm