1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file canonicalize_cast.cc
22 * \brief Canonicalize cast expressions to make operator fusion more efficient.
23 */
24#include <tvm/relay/analysis.h>
25#include <tvm/relay/attrs/nn.h>
26#include <tvm/relay/expr_functor.h>
27#include <tvm/relay/transform.h>
28
29#include "pass_utils.h"
30#include "pattern_utils.h"
31
32namespace tvm {
33namespace relay {
34
// This pass finds upcasts that are referenced by multiple elemwise/broadcast operators and creates a
// copy of the cast in each consuming branch, so that, after fusion, the preceding function produces
// output with fewer bits.
38//
39// Consider the following example:
40// \code
// def @main(%x: int8) {
//   %1 = cast(%x, f32)
//   %2 = exp(%1)
//   %3 = log(%1)
//   (%2, %3)
// }
47// \endcode
48//
// We would like to prevent sharing of the cast expression so that operator fusion can produce a
// more efficient result, as shown below.
51// \code
// def @main(%x: int8) {
//   %1 = fn (%p1: i8) {
//     exp(cast(%p1, f32))
//   }
//   %3 = %1(%x)
//   %2 = fn (%p1: i8) {
//     log(cast(%p1, f32))
//   }
//   %4 = %2(%x)
//   (%3, %4)
// }
63// \endcode
64class CastCanonicalizer : public ExprMutator {
65 public:
66 CastCanonicalizer() : cast_op_(Op::Get("cast")) {}
67
68 Expr VisitExpr_(const CallNode* call) {
69 static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
70
71 if (const OpNode* opnode = call->op.as<OpNode>()) {
72 auto pattern = fpattern[GetRef<Op>(opnode)];
73 if (pattern <= kBroadcast) {
74 Array<Expr> call_args = call->args;
75 bool unchanged = true;
76 for (size_t i = 0; i < call_args.size(); ++i) {
77 Expr arg = call_args[i];
78 Expr new_arg = GetNewCallArg(arg);
79 if (!arg.same_as(new_arg)) {
80 call_args.Set(i, new_arg);
81 unchanged = false;
82 }
83 }
84 if (unchanged) {
85 return GetRef<Expr>(call);
86 }
87 return Call(call->op, call_args, call->attrs, call->type_args);
88 }
89 }
90
91 Expr new_expr = ExprMutator::VisitExpr_(call);
92 return new_expr;
93 }
94
95 private:
96 std::unordered_map<const Object*, size_t> ref_counter_;
97 // cast op is frequently checked for equivalence. Therefore, we cache it to
98 // reduce lookup overhead.
99 const Op& cast_op_;
100
101 Expr GetNewCallArg(const Expr& e) {
102 // if e is a upcast and ref count > 1, create an copy; otherwise call the default visitor
103 Expr new_expr = this->VisitExpr(e);
104
105 if (const CallNode* call = e.as<CallNode>()) {
106 if (call->op == cast_op_) {
107 auto attrs = call->attrs.as<CastAttrs>();
108 const auto* from_type = call->args[0]->type_as<TensorTypeNode>();
109 ICHECK(from_type);
110
111 if (from_type->dtype.bits() < attrs->dtype.bits()) {
112 if (++ref_counter_[call] > 1) {
113 const CallNode* new_call = new_expr.as<CallNode>();
114 ICHECK(new_call);
115 ICHECK(new_call->op == cast_op_);
116 return Call(new_call->op, new_call->args, new_call->attrs, new_call->type_args);
117 }
118 }
119 }
120 }
121 return new_expr;
122 }
123};
124
125Expr CanonicalizeCast(const Expr& e) { return CastCanonicalizer().Mutate(e); }
126
127namespace transform {
128
129Pass CanonicalizeCast() {
130 runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
131 [=](Function f, IRModule m, PassContext pc) {
132 return Downcast<Function>(CanonicalizeCast(f));
133 };
134 return CreateFunctionPass(pass_func, 3, "CanonicalizeCast", {"InferType"});
135}
136
137TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeCast").set_body_typed(CanonicalizeCast);
138
139} // namespace transform
140
141} // namespace relay
142} // namespace tvm
143