1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file simplify_inference.cc |
22 | */ |
23 | #include <tvm/relay/analysis.h> |
24 | #include <tvm/relay/attrs/nn.h> |
25 | #include <tvm/relay/expr_functor.h> |
26 | #include <tvm/relay/op.h> |
27 | #include <tvm/relay/transform.h> |
28 | |
29 | #include "pattern_utils.h" |
30 | |
31 | namespace tvm { |
32 | namespace relay { |
33 | |
34 | Expr BatchNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Expr moving_mean, |
35 | Expr moving_var, Type tdata) { |
36 | auto ttype = tdata.as<TensorTypeNode>(); |
37 | ICHECK(ttype); |
38 | const auto param = attrs.as<BatchNormAttrs>(); |
39 | Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon)); |
40 | Expr var_add_eps = Add(moving_var, epsilon); |
41 | Expr sqrt_var = Sqrt(var_add_eps); |
42 | Expr scale = Divide(MakeConstantScalar(ttype->dtype, 1.0f), sqrt_var); |
43 | |
44 | if (param->scale) { |
45 | scale = Multiply(scale, gamma); |
46 | } |
47 | Expr neg_mean = Negative(moving_mean); |
48 | Expr shift = Multiply(neg_mean, scale); |
49 | if (param->center) { |
50 | shift = Add(shift, beta); |
51 | } |
52 | |
53 | auto ndim = ttype->shape.size(); |
54 | int axis = (param->axis < 0) ? param->axis + ndim : param->axis; |
55 | scale = ExpandBiasToMatchAxis(scale, ndim, {axis}); |
56 | shift = ExpandBiasToMatchAxis(shift, ndim, {axis}); |
57 | |
58 | Expr out = Multiply(data, scale); |
59 | out = Add(out, shift); |
60 | return out; |
61 | } |
62 | |
/*!
 * \brief Unpack nn.group_norm: reshape the channel axis into
 *        (num_groups, C/num_groups), normalize over each group together with
 *        all trailing spatial axes, then reshape back and apply gamma/beta.
 * \param attrs GroupNormAttrs (num_groups, axis, epsilon, center, scale).
 * \param data  The input tensor.
 * \param gamma Per-channel scale (applied only when param->scale).
 * \param beta  Per-channel shift (applied only when param->center).
 * \param tdata Checked type of `data`; must be a TensorType.
 * \return The unpacked expression.
 * \note NOTE(review): assumes every dimension of `data` is a static IntImm
 *       (the `.as<IntImmNode>()->value` calls below would null-deref on a
 *       dynamic shape) — confirm callers only reach this with static shapes.
 */
Expr GroupNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Type tdata) {
  auto ttype = tdata.as<TensorTypeNode>();
  ICHECK(ttype);
  const auto param = attrs.as<GroupNormAttrs>();
  ICHECK(param);

  int ndim = ttype->shape.size();
  // Normalize a negative axis into [0, ndim).
  int axis = (param->axis < 0) ? param->axis + ndim : param->axis;
  Array<Integer> reduced_axes;  // axes of the reshaped tensor to reduce over
  Array<Integer> new_shape;     // shape with the channel axis split into groups
  Array<Integer> old_shape;     // original shape, used to reshape back at the end

  int num_groups = param->num_groups;
  int channel = ttype->shape[axis].as<IntImmNode>()->value;

  // old_shape = N, C, H, W
  // new shape = N, num_groups, C/num_groups, H, W
  // reduce_axes = axis of (C/num_groups, H, W)
  for (int i = 0; i < ndim; ++i) {
    auto val = ttype->shape[i].as<IntImmNode>()->value;

    // Save the old shape to reshape later
    old_shape.push_back(val);
    if (i == axis) {
      // Split C into (num_groups, C/num_groups); only the second half
      // (i + 1 in the reshaped tensor) is reduced over.
      new_shape.push_back(num_groups);
      new_shape.push_back(channel / num_groups);
      reduced_axes.push_back(i + 1);
      continue;
    }
    if (i >= axis) {
      // Axes after the split are shifted right by one in the new shape.
      reduced_axes.push_back(i + 1);
    }
    new_shape.push_back(val);
  }

  data = Reshape(data, new_shape);

  // Standard normalization over the group axes:
  // (x - E[x]) / sqrt(Var[x] + eps), with keepdims so shapes still broadcast.
  Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon));
  Expr mean = Mean(data, {reduced_axes}, true, false);
  Expr var = Variance(data, mean, {reduced_axes}, true, false);
  Expr denom = Sqrt(Add(var, epsilon));
  Expr out = Divide(Subtract(data, mean), denom);

  // Restore the original layout before applying the per-channel affine terms.
  out = Reshape(out, old_shape);

  if (param->scale) {
    out = Multiply(out, ExpandBiasToMatchAxis(gamma, ndim, {axis}));
  }
  if (param->center) {
    out = Add(out, ExpandBiasToMatchAxis(beta, ndim, {axis}));
  }

  return out;
}
117 | |
118 | Expr LayerNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Type tdata) { |
119 | auto ttype = tdata.as<TensorTypeNode>(); |
120 | ICHECK(ttype); |
121 | const auto param = attrs.as<LayerNormAttrs>(); |
122 | ICHECK(param); |
123 | |
124 | Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon)); |
125 | Expr mean = Mean(data, {param->axis}, true, false); |
126 | Expr var = Variance(data, mean, {param->axis}, true, false); |
127 | Expr denom = Sqrt(Add(var, epsilon)); |
128 | Expr out = Divide(Subtract(data, mean), denom); |
129 | |
130 | size_t ndim = ttype->shape.size(); |
131 | int axis = (param->axis < 0) ? param->axis + ndim : param->axis; |
132 | if (param->scale) { |
133 | out = Multiply(out, ExpandBiasToMatchAxis(gamma, ndim, {axis})); |
134 | } |
135 | if (param->center) { |
136 | out = Add(out, ExpandBiasToMatchAxis(beta, ndim, {axis})); |
137 | } |
138 | return out; |
139 | } |
140 | |
141 | Expr InstanceNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Type tdata) { |
142 | auto ttype = tdata.as<TensorTypeNode>(); |
143 | ICHECK(ttype); |
144 | const auto param = attrs.as<InstanceNormAttrs>(); |
145 | ICHECK(param); |
146 | |
147 | int ndim = ttype->shape.size(); |
148 | int axis = (param->axis < 0) ? param->axis + ndim : param->axis; |
149 | Array<Integer> reduced_axes; |
150 | for (int i = 1; i < ndim; ++i) { |
151 | if (i != axis) reduced_axes.push_back(i); |
152 | } |
153 | |
154 | Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon)); |
155 | Expr mean = Mean(data, reduced_axes, true, false); |
156 | Expr var = Variance(data, mean, reduced_axes, true, false); |
157 | Expr denom = Sqrt(Add(var, epsilon)); |
158 | Expr out = Divide(Subtract(data, mean), denom); |
159 | |
160 | if (param->scale) { |
161 | out = Multiply(out, ExpandBiasToMatchAxis(gamma, ndim, {axis})); |
162 | } |
163 | if (param->center) { |
164 | out = Add(out, ExpandBiasToMatchAxis(beta, ndim, {axis})); |
165 | } |
166 | return out; |
167 | } |
168 | |
169 | Expr L2NormToInferUnpack(const Attrs attrs, Expr data) { |
170 | const auto param = attrs.as<L2NormalizeAttrs>(); |
171 | ICHECK(param); |
172 | |
173 | Expr epsilon = MakeConstantScalar(DataType::Float(32), static_cast<float>(param->eps)); |
174 | |
175 | Expr sqr = Multiply(data, data); |
176 | Expr sum = Maximum(Sum(sqr, param->axis, true, false), epsilon); |
177 | Expr sqrt = Sqrt(sum); |
178 | return Divide(data, sqrt); |
179 | } |
180 | |
/*!
 * \brief Mutator that rewrites training-time operators into their
 *        inference-time equivalents.
 *
 * batch_norm is handled in two steps: the CallNode hook records the input's
 * checked type in ty_map_, and the TupleGetItemNode hook (batch_norm returns
 * a tuple; only index 0 is the data output) performs the actual unpacking.
 * dropout is elided to its input. The *_norm / l2_normalize ops are rewritten
 * directly at their call sites.
 */
class InferenceSimplifier : public MixedModeMutator {
 public:
  InferenceSimplifier()
      : batch_norm_op_(Op::Get("nn.batch_norm")),
        dropout_op_(Op::Get("nn.dropout")),
        instance_norm_op_(Op::Get("nn.instance_norm")),
        layer_norm_op_(Op::Get("nn.layer_norm")),
        group_norm_op_(Op::Get("nn.group_norm")),
        l2_norm_op_(Op::Get("nn.l2_normalize")) {}

  Expr Rewrite_(const TupleGetItemNode* n, const Expr& new_e) final {
    const auto* new_n = new_e.as<TupleGetItemNode>();
    // Only the first tuple field (the data output) is rewritten; other
    // fields (e.g. batch_norm's updated mean/var) are left as-is.
    if (new_n->index != 0) {
      return new_e;
    }
    if (const auto* call = new_n->tuple.as<CallNode>()) {
      if (call->op == batch_norm_op_) {
        // ty_map_ was populated by the CallNode hook below, which runs
        // first because MixedModeMutator rewrites in post-order.
        return BatchNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
                                      call->args[3], call->args[4], ty_map_.at(call->args[0]));
      } else if (call->op == dropout_op_) {
        // dropout is identity at inference time; drop it entirely.
        return call->args[0];
      }
    }
    return new_e;
  }

  Expr Rewrite_(const CallNode* n, const Expr& new_n) {
    if (n->op == batch_norm_op_) {
      // Remember the (pre-rewrite) checked type of the batch_norm input so
      // the TupleGetItemNode hook can unpack it later; the rewritten call's
      // arg is the map key because that is what the hook will see.
      ty_map_[new_n.as<CallNode>()->args[0]] = n->args[0]->checked_type();
    } else if (n->op == layer_norm_op_) {
      const auto* call = new_n.as<CallNode>();
      return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
                                    n->args[0]->checked_type());
    } else if (n->op == group_norm_op_) {
      const auto* call = new_n.as<CallNode>();
      return GroupNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
                                    n->args[0]->checked_type());
    } else if (n->op == instance_norm_op_) {
      const auto* call = new_n.as<CallNode>();
      return InstanceNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
                                       n->args[0]->checked_type());
    } else if (n->op == l2_norm_op_) {
      const auto* call = new_n.as<CallNode>();
      return L2NormToInferUnpack(call->attrs, call->args[0]);
    }
    return new_n;
  }

 private:
  // Cache the following ops. They will be used in the passes repeatedly for
  // operator equivalence checking so that the registry lookup overhead can be
  // reduced.
  const Op& batch_norm_op_;
  const Op& dropout_op_;
  const Op& instance_norm_op_;
  const Op& layer_norm_op_;
  const Op& group_norm_op_;
  const Op& l2_norm_op_;
  // Maps a rewritten batch_norm's data argument to its original checked type
  // (filled by the CallNode hook, consumed by the TupleGetItemNode hook).
  std::unordered_map<Expr, Type, ObjectPtrHash, ObjectPtrEqual> ty_map_;
};
241 | |
242 | Expr SimplifyInference(const Expr& e) { return InferenceSimplifier().Mutate(e); } |
243 | |
244 | namespace transform { |
245 | |
246 | Pass SimplifyInference() { |
247 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
248 | [=](Function f, IRModule m, PassContext pc) { |
249 | return Downcast<Function>(SimplifyInference(f)); |
250 | }; |
251 | return CreateFunctionPass(pass_func, 0, "SimplifyInference" , {"InferType" }); |
252 | } |
253 | |
// Expose the pass constructor to the FFI (e.g. the Python frontend).
TVM_REGISTER_GLOBAL("relay._transform.SimplifyInference").set_body_typed(SimplifyInference);
255 | |
256 | } // namespace transform |
257 | |
258 | } // namespace relay |
259 | } // namespace tvm |
260 | |