1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file canonicalize_cast.cc |
22 | * \brief Canonicalize cast expressions to make operator fusion more efficient. |
23 | */ |
24 | #include <tvm/relay/analysis.h> |
25 | #include <tvm/relay/attrs/nn.h> |
26 | #include <tvm/relay/expr_functor.h> |
27 | #include <tvm/relay/transform.h> |
28 | |
29 | #include "pass_utils.h" |
30 | #include "pattern_utils.h" |
31 | |
32 | namespace tvm { |
33 | namespace relay { |
34 | |
// This pass finds upcasts that are referenced by multiple elemwise/broadcast operators and creates
// a copy of the cast in each branch, so that after fusion the preceding function produces an
// output with fewer bits.
38 | // |
39 | // Consider the following example: |
40 | // \code |
41 | // def @main(x: int8) { |
42 | // %1 = cast(%x, f32) |
43 | // %2 = exp(%1) |
44 | // %3 = log(%1) |
//   (%2, %3)
46 | // } |
47 | // \endcode |
48 | // |
49 | // We would like to prevent sharing of the cast expression such that operator fusion can produce |
50 | // more efficient result as below. |
51 | // \code |
52 | // def @main(x: int8) { |
53 | // %1 = fn (%p1: i8) { |
//     exp(cast(%p1, f32))
55 | // } |
56 | // %3 = %1(%x) |
57 | // %2 = fn (%p1: i8) { |
//     log(cast(%p1, f32))
59 | // } |
60 | // %4 = %2(%x) |
//   (%3, %4)
62 | // } |
63 | // \endcode |
64 | class CastCanonicalizer : public ExprMutator { |
65 | public: |
66 | CastCanonicalizer() : cast_op_(Op::Get("cast" )) {} |
67 | |
68 | Expr VisitExpr_(const CallNode* call) { |
69 | static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern" ); |
70 | |
71 | if (const OpNode* opnode = call->op.as<OpNode>()) { |
72 | auto pattern = fpattern[GetRef<Op>(opnode)]; |
73 | if (pattern <= kBroadcast) { |
74 | Array<Expr> call_args = call->args; |
75 | bool unchanged = true; |
76 | for (size_t i = 0; i < call_args.size(); ++i) { |
77 | Expr arg = call_args[i]; |
78 | Expr new_arg = GetNewCallArg(arg); |
79 | if (!arg.same_as(new_arg)) { |
80 | call_args.Set(i, new_arg); |
81 | unchanged = false; |
82 | } |
83 | } |
84 | if (unchanged) { |
85 | return GetRef<Expr>(call); |
86 | } |
87 | return Call(call->op, call_args, call->attrs, call->type_args); |
88 | } |
89 | } |
90 | |
91 | Expr new_expr = ExprMutator::VisitExpr_(call); |
92 | return new_expr; |
93 | } |
94 | |
95 | private: |
96 | std::unordered_map<const Object*, size_t> ref_counter_; |
97 | // cast op is frequently checked for equivalence. Therefore, we cache it to |
98 | // reduce lookup overhead. |
99 | const Op& cast_op_; |
100 | |
101 | Expr GetNewCallArg(const Expr& e) { |
102 | // if e is a upcast and ref count > 1, create an copy; otherwise call the default visitor |
103 | Expr new_expr = this->VisitExpr(e); |
104 | |
105 | if (const CallNode* call = e.as<CallNode>()) { |
106 | if (call->op == cast_op_) { |
107 | auto attrs = call->attrs.as<CastAttrs>(); |
108 | const auto* from_type = call->args[0]->type_as<TensorTypeNode>(); |
109 | ICHECK(from_type); |
110 | |
111 | if (from_type->dtype.bits() < attrs->dtype.bits()) { |
112 | if (++ref_counter_[call] > 1) { |
113 | const CallNode* new_call = new_expr.as<CallNode>(); |
114 | ICHECK(new_call); |
115 | ICHECK(new_call->op == cast_op_); |
116 | return Call(new_call->op, new_call->args, new_call->attrs, new_call->type_args); |
117 | } |
118 | } |
119 | } |
120 | } |
121 | return new_expr; |
122 | } |
123 | }; |
124 | |
125 | Expr CanonicalizeCast(const Expr& e) { return CastCanonicalizer().Mutate(e); } |
126 | |
127 | namespace transform { |
128 | |
129 | Pass CanonicalizeCast() { |
130 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
131 | [=](Function f, IRModule m, PassContext pc) { |
132 | return Downcast<Function>(CanonicalizeCast(f)); |
133 | }; |
134 | return CreateFunctionPass(pass_func, 3, "CanonicalizeCast" , {"InferType" }); |
135 | } |
136 | |
137 | TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeCast" ).set_body_typed(CanonicalizeCast); |
138 | |
139 | } // namespace transform |
140 | |
141 | } // namespace relay |
142 | } // namespace tvm |
143 | |