/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file simplify_inference.cc
 * \brief Simplify the computation graph for inference: unpack batch_norm,
 *        layer_norm, group_norm, instance_norm and l2_normalize into basic
 *        arithmetic operations, and remove dropout.
 */
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/transform.h>

#include "pattern_utils.h"

namespace tvm {
namespace relay {

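/*
 * Unpack batch_norm into its inference-time arithmetic. Since the moving mean
 * and variance are constants at inference time, the whole op folds into one
 * multiply and one add:
 *   out = (data - moving_mean) / sqrt(moving_var + eps) * gamma + beta
 *       = data * scale + shift
 * with
 *   scale = gamma / sqrt(moving_var + eps)
 *   shift = beta - moving_mean * scale
 * gamma and beta participate only when the scale/center flags are set.
 */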
Expr BatchNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Expr moving_mean,
                            Expr moving_var, Type tdata) {
  auto ttype = tdata.as<TensorTypeNode>();
  ICHECK(ttype);
  const auto param = attrs.as<BatchNormAttrs>();
  ICHECK(param);
  Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon));
  Expr var_add_eps = Add(moving_var, epsilon);
  Expr sqrt_var = Sqrt(var_add_eps);
  Expr scale = Divide(MakeConstantScalar(ttype->dtype, 1.0f), sqrt_var);

  if (param->scale) {
    scale = Multiply(scale, gamma);
  }
  Expr neg_mean = Negative(moving_mean);
  Expr shift = Multiply(neg_mean, scale);
  if (param->center) {
    shift = Add(shift, beta);
  }

  int ndim = static_cast<int>(ttype->shape.size());
  int axis = (param->axis < 0) ? param->axis + ndim : param->axis;
  scale = ExpandBiasToMatchAxis(scale, ndim, {axis});
  shift = ExpandBiasToMatchAxis(shift, ndim, {axis});

  Expr out = Multiply(data, scale);
  out = Add(out, shift);
  return out;
}

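/*
 * Unpack group_norm: split the channel axis into (num_groups, C/num_groups),
 * normalize each group to zero mean and unit variance, reshape back, and then
 * apply the optional per-channel gamma/beta.
 */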
Expr GroupNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Type tdata) {
  auto ttype = tdata.as<TensorTypeNode>();
  ICHECK(ttype);
  const auto param = attrs.as<GroupNormAttrs>();
  ICHECK(param);

  int ndim = static_cast<int>(ttype->shape.size());
  int axis = (param->axis < 0) ? param->axis + ndim : param->axis;
  Array<Integer> reduced_axes;
  Array<Integer> new_shape;
  Array<Integer> old_shape;

  int num_groups = param->num_groups;
  int channel = ttype->shape[axis].as<IntImmNode>()->value;

  // old_shape    = (N, C, H, W)
  // new_shape    = (N, num_groups, C/num_groups, H, W)
  // reduced_axes = axes of (C/num_groups, H, W)
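  //
  // For example (hypothetical shapes): with data of shape (2, 6, 4, 4),
  // axis = 1 and num_groups = 3:
  //   new_shape    = (2, 3, 2, 4, 4)
  //   reduced_axes = {2, 3, 4}
  // so each of the 3 groups is normalized over its 2 * 4 * 4 elements.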
  for (int i = 0; i < ndim; ++i) {
    auto val = ttype->shape[i].as<IntImmNode>()->value;

    // Save the old shape to reshape later
    old_shape.push_back(val);
    if (i == axis) {
      new_shape.push_back(num_groups);
      new_shape.push_back(channel / num_groups);
      reduced_axes.push_back(i + 1);
      continue;
    }
    if (i >= axis) {
      reduced_axes.push_back(i + 1);
    }
    new_shape.push_back(val);
  }

  data = Reshape(data, new_shape);

  Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon));
  Expr mean = Mean(data, reduced_axes, true, false);
  Expr var = Variance(data, mean, reduced_axes, true, false);
  Expr denom = Sqrt(Add(var, epsilon));
  Expr out = Divide(Subtract(data, mean), denom);

  out = Reshape(out, old_shape);

  if (param->scale) {
    out = Multiply(out, ExpandBiasToMatchAxis(gamma, ndim, {axis}));
  }
  if (param->center) {
    out = Add(out, ExpandBiasToMatchAxis(beta, ndim, {axis}));
  }

  return out;
}

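/*
 * Unpack layer_norm: normalize over the given axis to zero mean and unit
 * variance, then apply the optional gamma/beta:
 *   out = (data - mean) / sqrt(var + eps) * gamma + beta
 */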
Expr LayerNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Type tdata) {
  auto ttype = tdata.as<TensorTypeNode>();
  ICHECK(ttype);
  const auto param = attrs.as<LayerNormAttrs>();
  ICHECK(param);

  Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon));
  Expr mean = Mean(data, {param->axis}, true, false);
  Expr var = Variance(data, mean, {param->axis}, true, false);
  Expr denom = Sqrt(Add(var, epsilon));
  Expr out = Divide(Subtract(data, mean), denom);

  int ndim = static_cast<int>(ttype->shape.size());
  int axis = (param->axis < 0) ? param->axis + ndim : param->axis;
  if (param->scale) {
    out = Multiply(out, ExpandBiasToMatchAxis(gamma, ndim, {axis}));
  }
  if (param->center) {
    out = Add(out, ExpandBiasToMatchAxis(beta, ndim, {axis}));
  }
  return out;
}

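/*
 * Unpack instance_norm: normalize each sample independently over all axes
 * except the batch axis (0) and the channel axis, then apply the optional
 * per-channel gamma/beta.
 */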
Expr InstanceNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Type tdata) {
  auto ttype = tdata.as<TensorTypeNode>();
  ICHECK(ttype);
  const auto param = attrs.as<InstanceNormAttrs>();
  ICHECK(param);

  int ndim = static_cast<int>(ttype->shape.size());
  int axis = (param->axis < 0) ? param->axis + ndim : param->axis;
  Array<Integer> reduced_axes;
  for (int i = 1; i < ndim; ++i) {
    if (i != axis) reduced_axes.push_back(i);
  }

  Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast<float>(param->epsilon));
  Expr mean = Mean(data, reduced_axes, true, false);
  Expr var = Variance(data, mean, reduced_axes, true, false);
  Expr denom = Sqrt(Add(var, epsilon));
  Expr out = Divide(Subtract(data, mean), denom);

  if (param->scale) {
    out = Multiply(out, ExpandBiasToMatchAxis(gamma, ndim, {axis}));
  }
  if (param->center) {
    out = Add(out, ExpandBiasToMatchAxis(beta, ndim, {axis}));
  }
  return out;
}

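/*
 * Unpack l2_normalize:
 *   out = data / sqrt(max(sum(data * data, axis), eps))
 */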
Expr L2NormToInferUnpack(const Attrs attrs, Expr data) {
  const auto param = attrs.as<L2NormalizeAttrs>();
  ICHECK(param);

  Expr epsilon = MakeConstantScalar(DataType::Float(32), static_cast<float>(param->eps));

  Expr sqr = Multiply(data, data);
  Expr sum = Maximum(Sum(sqr, param->axis, true, false), epsilon);
  Expr sqrt = Sqrt(sum);
  return Divide(data, sqrt);
}

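/*
 * Mutator that applies the unpack rules above. batch_norm and dropout produce
 * tuples, so they are rewritten at the TupleGetItem that consumes them; the
 * single-output norms are rewritten at the call site directly.
 */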
class InferenceSimplifier : public MixedModeMutator {
 public:
  InferenceSimplifier()
      : batch_norm_op_(Op::Get("nn.batch_norm")),
        dropout_op_(Op::Get("nn.dropout")),
        instance_norm_op_(Op::Get("nn.instance_norm")),
        layer_norm_op_(Op::Get("nn.layer_norm")),
        group_norm_op_(Op::Get("nn.group_norm")),
        l2_norm_op_(Op::Get("nn.l2_normalize")) {}

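  // batch_norm and dropout are only simplified when their primary output
  // (tuple index 0) is taken; other tuple indices are left untouched.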
  Expr Rewrite_(const TupleGetItemNode* n, const Expr& new_e) final {
    const auto* new_n = new_e.as<TupleGetItemNode>();
    if (new_n->index != 0) {
      return new_e;
    }
    if (const auto* call = new_n->tuple.as<CallNode>()) {
      if (call->op == batch_norm_op_) {
        return BatchNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
                                      call->args[3], call->args[4], ty_map_.at(call->args[0]));
      } else if (call->op == dropout_op_) {
        return call->args[0];
      }
    }
    return new_e;
  }

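  // For batch_norm, record the input type here so it is still available when
  // the surrounding TupleGetItem is rewritten; the single-output norms are
  // unpacked immediately.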
  Expr Rewrite_(const CallNode* n, const Expr& new_n) final {
    if (n->op == batch_norm_op_) {
      ty_map_[new_n.as<CallNode>()->args[0]] = n->args[0]->checked_type();
    } else if (n->op == layer_norm_op_) {
      const auto* call = new_n.as<CallNode>();
      return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
                                    n->args[0]->checked_type());
    } else if (n->op == group_norm_op_) {
      const auto* call = new_n.as<CallNode>();
      return GroupNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
                                    n->args[0]->checked_type());
    } else if (n->op == instance_norm_op_) {
      const auto* call = new_n.as<CallNode>();
      return InstanceNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
                                       n->args[0]->checked_type());
    } else if (n->op == l2_norm_op_) {
      const auto* call = new_n.as<CallNode>();
      return L2NormToInferUnpack(call->attrs, call->args[0]);
    }
    return new_n;
  }

 private:
  // Cache the following ops. They are used repeatedly in the pass for
  // operator equivalence checking, so caching avoids repeated registry
  // lookups.
  const Op& batch_norm_op_;
  const Op& dropout_op_;
  const Op& instance_norm_op_;
  const Op& layer_norm_op_;
  const Op& group_norm_op_;
  const Op& l2_norm_op_;
  // Maps a rewritten batch_norm's data argument to its original checked type.
  std::unordered_map<Expr, Type, ObjectPtrHash, ObjectPtrEqual> ty_map_;
};

Expr SimplifyInference(const Expr& e) { return InferenceSimplifier().Mutate(e); }

namespace transform {

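/*
 * Create the SimplifyInference pass. It requires type information, hence the
 * dependency on InferType. A minimal usage sketch (assuming `mod` is an
 * already-constructed IRModule):
 *
 *   IRModule mod = ...;
 *   mod = relay::transform::SimplifyInference()(mod);
 */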
Pass SimplifyInference() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        return Downcast<Function>(SimplifyInference(f));
      };
  return CreateFunctionPass(pass_func, 0, "SimplifyInference", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.SimplifyInference").set_body_typed(SimplifyInference);

}  // namespace transform

}  // namespace relay
}  // namespace tvm