1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * |
22 | * \file dynamic_to_static.cc |
23 | * \brief Rewrite Dynamic Operations to Static operations where possible |
24 | */ |
25 | #include <tvm/relay/attrs/algorithm.h> |
26 | #include <tvm/relay/attrs/image.h> |
27 | #include <tvm/relay/expr_functor.h> |
28 | #include <tvm/relay/transform.h> |
29 | |
30 | #include "pattern_utils.h" |
31 | |
32 | namespace tvm { |
33 | namespace relay { |
34 | |
35 | class DynamicToStaticMutator : public MixedModeMutator { |
36 | public: |
37 | DynamicToStaticMutator(IRModule mod, Function func) : mod_(mod), func_(func) { |
38 | op_map_ = { |
39 | {Op::Get("dyn.reshape" ), |
40 | [this](const CallNode* call_node) { |
41 | auto args = PrepareArgs(call_node); |
42 | if (const ConstantNode* shape = args[1].as<ConstantNode>()) { |
43 | ICHECK_EQ(shape->data->ndim, 1); |
44 | return MakeReshape(call_node->args[0], ToVector(shape->data)); |
45 | } |
46 | return Expr(nullptr); |
47 | }}, |
48 | {Op::Get("dyn.squeeze" ), |
49 | [this](const CallNode* call_node) { |
50 | auto args = PrepareArgs(call_node); |
51 | if (const ConstantNode* axis = args[1].as<ConstantNode>()) { |
52 | ICHECK_EQ(axis->data->ndim, 1); |
53 | return MakeSqueeze(call_node->args[0], ToVector(axis->data)); |
54 | } |
55 | return Expr(nullptr); |
56 | }}, |
57 | {Op::Get("dyn.tile" ), |
58 | [this](const CallNode* call_node) { |
59 | auto args = PrepareArgs(call_node); |
60 | if (const ConstantNode* reps = args[1].as<ConstantNode>()) { |
61 | ICHECK_EQ(reps->data->ndim, 1); |
62 | return MakeTile(call_node->args[0], ToVector(reps->data)); |
63 | } |
64 | return Expr(nullptr); |
65 | }}, |
66 | {Op::Get("dyn.topk" ), |
67 | [this](const CallNode* call_node) { |
68 | auto args = PrepareArgs(call_node); |
69 | if (const ConstantNode* k = args[1].as<ConstantNode>()) { |
70 | const TopKAttrs* param = call_node->attrs.as<TopKAttrs>(); |
71 | ICHECK(param); |
72 | return MakeTopK(call_node->args[0], static_cast<int>(ToScalar(k->data, 0)), |
73 | param->axis, param->ret_type, param->is_ascend, param->dtype); |
74 | } |
75 | return Expr(nullptr); |
76 | }}, |
77 | {Op::Get("dyn.broadcast_to" ), |
78 | [this](const CallNode* call_node) { |
79 | auto args = PrepareArgs(call_node); |
80 | if (const ConstantNode* shape = args[1].as<ConstantNode>()) { |
81 | ICHECK_EQ(shape->data->ndim, 1); |
82 | return MakeBroadCastTo(call_node->args[0], ToVector(shape->data)); |
83 | } |
84 | return Expr(nullptr); |
85 | }}, |
86 | {Op::Get("dyn.zeros" ), |
87 | [this](const CallNode* call_node) { |
88 | auto args = PrepareArgs(call_node); |
89 | if (const ConstantNode* shape = args[0].as<ConstantNode>()) { |
90 | const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>(); |
91 | ICHECK(param); |
92 | return MakeZeros(ToVector(shape->data), param->dtype); |
93 | } |
94 | return Expr(nullptr); |
95 | }}, |
96 | {Op::Get("dyn.ones" ), |
97 | [this](const CallNode* call_node) { |
98 | auto args = PrepareArgs(call_node); |
99 | if (const ConstantNode* shape = args[0].as<ConstantNode>()) { |
100 | const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>(); |
101 | ICHECK(param); |
102 | return MakeOnes(ToVector(shape->data), param->dtype); |
103 | } |
104 | return Expr(nullptr); |
105 | }}, |
106 | {Op::Get("dyn.one_hot" ), |
107 | [this](const CallNode* call_node) { |
108 | auto args = PrepareArgs(call_node); |
109 | if (const ConstantNode* depth = args[3].as<ConstantNode>()) { |
110 | const OneHotAttrs* param = call_node->attrs.as<OneHotAttrs>(); |
111 | ICHECK(param); |
112 | return MakeOneHot(call_node->args[0], call_node->args[1], call_node->args[2], |
113 | static_cast<int>(ToScalar(depth->data, 0)), param->axis, |
114 | param->dtype); |
115 | } |
116 | return Expr(nullptr); |
117 | }}, |
118 | {Op::Get("dyn.image.resize2d" ), |
119 | [this](const CallNode* call_node) { |
120 | auto args = PrepareArgs(call_node); |
121 | if (const ConstantNode* size = args[1].as<ConstantNode>()) { |
122 | if (const ConstantNode* roi = args[2].as<ConstantNode>()) { |
123 | const Resize2DAttrs* param = call_node->attrs.as<Resize2DAttrs>(); |
124 | ICHECK(param); |
125 | auto size_int = ToVector(size->data); |
126 | Array<PrimExpr> size_prim; |
127 | for (size_t i = 0; i < size_int.size(); ++i) { |
128 | size_prim.push_back(size_int[i]); |
129 | } |
130 | auto roi_vec = ToFloatVector(roi->data); |
131 | Array<FloatImm> roi_prim; |
132 | for (size_t i = 0; i < roi_vec.size(); ++i) { |
133 | roi_prim.push_back(roi_vec[i]); |
134 | } |
135 | return MakeResize2D(call_node->args[0], size_prim, roi_prim, param->layout, |
136 | param->method, param->coordinate_transformation_mode, |
137 | param->rounding_method, param->cubic_alpha, param->cubic_exclude, |
138 | param->extrapolation_value, param->out_dtype); |
139 | } |
140 | } |
141 | return Expr(nullptr); |
142 | }}, |
143 | {Op::Get("dyn.full" ), |
144 | [this](const CallNode* call_node) { |
145 | auto args = PrepareArgs(call_node); |
146 | if (const ConstantNode* shape = args[1].as<ConstantNode>()) { |
147 | ICHECK_EQ(shape->data->ndim, 1); |
148 | const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>(); |
149 | ICHECK(param); |
150 | return MakeFull(call_node->args[0], ToVector(shape->data), param->dtype); |
151 | } |
152 | return Expr(nullptr); |
153 | }}, |
154 | {Op::Get("dyn.nn.upsampling" ), |
155 | [this](const CallNode* call_node) { |
156 | auto args = PrepareArgs(call_node); |
157 | const ConstantNode* scale_h = args[1].as<ConstantNode>(); |
158 | const ConstantNode* scale_w = args[2].as<ConstantNode>(); |
159 | if (scale_h && scale_w) { |
160 | ICHECK_EQ(scale_h->data->ndim, 0); |
161 | ICHECK_EQ(scale_w->data->ndim, 0); |
162 | const UpSamplingAttrs* param = call_node->attrs.as<UpSamplingAttrs>(); |
163 | ICHECK(param); |
164 | return MakeUpSampling(call_node->args[0], ToScalar(scale_h->data), |
165 | ToScalar(scale_w->data), param->layout, param->method, |
166 | param->align_corners); |
167 | } |
168 | return Expr(nullptr); |
169 | }}, |
170 | {Op::Get("dyn.nn.upsampling3d" ), |
171 | [this](const CallNode* call_node) { |
172 | auto args = PrepareArgs(call_node); |
173 | const ConstantNode* scale_d = args[1].as<ConstantNode>(); |
174 | const ConstantNode* scale_h = args[2].as<ConstantNode>(); |
175 | const ConstantNode* scale_w = args[3].as<ConstantNode>(); |
176 | if (scale_d && scale_h && scale_w) { |
177 | ICHECK_EQ(scale_d->data->ndim, 0); |
178 | ICHECK_EQ(scale_h->data->ndim, 0); |
179 | ICHECK_EQ(scale_w->data->ndim, 0); |
180 | const UpSampling3DAttrs* param = call_node->attrs.as<UpSampling3DAttrs>(); |
181 | ICHECK(param); |
182 | return MakeUpSampling3D(call_node->args[0], ToScalar(scale_d->data), |
183 | ToScalar(scale_h->data), ToScalar(scale_w->data), |
184 | param->layout, param->method, |
185 | param->coordinate_transformation_mode); |
186 | } |
187 | return Expr(nullptr); |
188 | }}, |
189 | {Op::Get("dyn.nn.pad" ), |
190 | [this](const CallNode* call_node) { |
191 | auto args = PrepareArgs(call_node); |
192 | const ConstantNode* pad_width = args[1].as<ConstantNode>(); |
193 | const ConstantNode* pad_fill = args[2].as<ConstantNode>(); |
194 | if (pad_width && pad_fill) { |
195 | ICHECK_EQ(pad_fill->data->ndim, 0); // pad_val is 1d |
196 | ICHECK_EQ(pad_width->data->ndim, 2); // pad_width is 2d |
197 | |
198 | const PadAttrs* param = call_node->attrs.as<PadAttrs>(); |
199 | ICHECK(param); |
200 | |
201 | Expr pad_value = args[2]; |
202 | return MakePad(call_node->args[0], ToMatrix(pad_width->data), pad_value, |
203 | param->pad_mode); |
204 | } |
205 | return Expr(nullptr); |
206 | }}, |
207 | {Op::Get("dyn.strided_slice" ), |
208 | [this](const CallNode* call_node) { |
209 | auto args = PrepareArgs(call_node); |
210 | const ConstantNode* begin = args[1].as<ConstantNode>(); |
211 | const ConstantNode* end = args[2].as<ConstantNode>(); |
212 | const ConstantNode* stride = args[3].as<ConstantNode>(); |
213 | if (begin && end && stride) { |
214 | ICHECK_EQ(begin->data->ndim, 1); |
215 | ICHECK_EQ(end->data->ndim, 1); |
216 | ICHECK_EQ(stride->data->ndim, 1); |
217 | const StridedSliceAttrs* param = call_node->attrs.as<StridedSliceAttrs>(); |
218 | ICHECK(param); |
219 | return MakeStridedSlice(call_node->args[0], ToVector(begin->data), ToVector(end->data), |
220 | ToVector(stride->data), param->slice_mode); |
221 | } |
222 | return Expr(nullptr); |
223 | }}, |
224 | {Op::Get("dyn.sparse_to_dense" ), |
225 | [this](const CallNode* call_node) { |
226 | auto args = PrepareArgs(call_node); |
227 | const ConstantNode* output_shape = args[3].as<ConstantNode>(); |
228 | if (output_shape) { |
229 | ICHECK_EQ(output_shape->data->ndim, 1); |
230 | return MakeSparseToDense(call_node->args[0], ToVector(output_shape->data), |
231 | call_node->args[1], call_node->args[2]); |
232 | } |
233 | return Expr(nullptr); |
234 | }}, |
235 | }; |
236 | Map<BaseFunc, GlobalVar> vars; |
237 | for (auto kv : mod_->functions) { |
238 | vars.Set(kv.second, kv.first); |
239 | } |
240 | gv_ = vars[func_]; |
241 | } |
242 | |
243 | Expr GetCurExpr(const Expr& original_expr) { |
244 | if (original_expr.as<FunctionNode>()) { |
245 | return mod_->Lookup(gv_); |
246 | } else { |
247 | return mod_->Lookup(gv_).as<FunctionNode>()->body; |
248 | } |
249 | } |
250 | |
251 | Expr PrepareInput(const Expr& expr) { |
252 | BaseFunc func; |
253 | if (auto* func_node = expr.as<BaseFuncNode>()) { |
254 | func = GetRef<BaseFunc>(func_node); |
255 | } else { |
256 | func = |
257 | relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod_), {}); |
258 | } |
259 | mod_->Update(gv_, func); |
260 | |
261 | mod_ = transform::FoldConstant()(mod_); |
262 | transform::InferTypeLocal(GetCurExpr(expr)); |
263 | mod_ = transform::FoldConstant()(mod_); |
264 | transform::InferTypeLocal(GetCurExpr(expr)); |
265 | |
266 | Expr out; |
267 | if (expr.as<FunctionNode>()) { |
268 | out = mod_->Lookup(gv_); |
269 | } else { |
270 | out = mod_->Lookup(gv_).as<FunctionNode>()->body; |
271 | } |
272 | return out; |
273 | } |
274 | |
275 | std::vector<Expr> PrepareArgs(const CallNode* call_node) { |
276 | std::vector<Expr> args; |
277 | for (auto arg : call_node->args) { |
278 | if (arg.as<ConstantNode>()) { |
279 | args.emplace_back(arg); |
280 | } else { |
281 | args.emplace_back(PrepareInput(arg)); |
282 | } |
283 | } |
284 | return args; |
285 | } |
286 | |
287 | private: |
288 | Expr Rewrite_(const CallNode* pre, const Expr& post) override { |
289 | if (const CallNode* call_node = post.as<CallNode>()) { |
290 | if (op_map_.count(call_node->op)) { |
291 | auto out = op_map_[call_node->op](call_node); |
292 | if (out.defined()) { |
293 | return out; |
294 | } |
295 | } |
296 | } |
297 | return post; |
298 | } |
299 | |
300 | Expr DispatchVisitExpr(const Expr& expr) override { |
301 | auto post = MixedModeMutator::DispatchVisitExpr(expr); |
302 | if (auto op = post.as<FunctionNode>()) { |
303 | return Function(op->params, op->body, NullValue<Type>(), op->type_params, op->attrs); |
304 | } |
305 | return post; |
306 | } |
307 | |
308 | std::unordered_map<Expr, std::function<Expr(const CallNode*)>, ObjectPtrHash, ObjectPtrEqual> |
309 | op_map_; |
310 | IRModule mod_; |
311 | Function func_; |
312 | GlobalVar gv_; |
313 | }; |
314 | |
315 | Expr DynamicToStatic(Function f, IRModule m) { |
316 | DynamicToStaticMutator mutator(m, f); |
317 | Expr expr = mutator.Mutate(f); |
318 | Expr out = mutator.PrepareInput(expr); |
319 | return out; |
320 | } |
321 | |
322 | namespace transform { |
323 | |
324 | Pass DynamicToStatic() { |
325 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
326 | [=](Function f, IRModule m, PassContext pc) { |
327 | return Downcast<Function>(DynamicToStatic(f, m)); |
328 | }; |
329 | return CreateFunctionPass(pass_func, 2, "DynamicToStatic" , {}); |
330 | } |
331 | |
332 | TVM_REGISTER_GLOBAL("relay._transform.DynamicToStatic" ).set_body_typed([]() { |
333 | return DynamicToStatic(); |
334 | }); |
335 | |
336 | } // namespace transform |
337 | } // namespace relay |
338 | } // namespace tvm |
339 | |