1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 *
22 * \file dynamic_to_static.cc
23 * \brief Rewrite Dynamic Operations to Static operations where possible
24 */
25#include <tvm/relay/attrs/algorithm.h>
26#include <tvm/relay/attrs/image.h>
27#include <tvm/relay/expr_functor.h>
28#include <tvm/relay/transform.h>
29
30#include "pattern_utils.h"
31
32namespace tvm {
33namespace relay {
34
35class DynamicToStaticMutator : public MixedModeMutator {
36 public:
37 DynamicToStaticMutator(IRModule mod, Function func) : mod_(mod), func_(func) {
38 op_map_ = {
39 {Op::Get("dyn.reshape"),
40 [this](const CallNode* call_node) {
41 auto args = PrepareArgs(call_node);
42 if (const ConstantNode* shape = args[1].as<ConstantNode>()) {
43 ICHECK_EQ(shape->data->ndim, 1);
44 return MakeReshape(call_node->args[0], ToVector(shape->data));
45 }
46 return Expr(nullptr);
47 }},
48 {Op::Get("dyn.squeeze"),
49 [this](const CallNode* call_node) {
50 auto args = PrepareArgs(call_node);
51 if (const ConstantNode* axis = args[1].as<ConstantNode>()) {
52 ICHECK_EQ(axis->data->ndim, 1);
53 return MakeSqueeze(call_node->args[0], ToVector(axis->data));
54 }
55 return Expr(nullptr);
56 }},
57 {Op::Get("dyn.tile"),
58 [this](const CallNode* call_node) {
59 auto args = PrepareArgs(call_node);
60 if (const ConstantNode* reps = args[1].as<ConstantNode>()) {
61 ICHECK_EQ(reps->data->ndim, 1);
62 return MakeTile(call_node->args[0], ToVector(reps->data));
63 }
64 return Expr(nullptr);
65 }},
66 {Op::Get("dyn.topk"),
67 [this](const CallNode* call_node) {
68 auto args = PrepareArgs(call_node);
69 if (const ConstantNode* k = args[1].as<ConstantNode>()) {
70 const TopKAttrs* param = call_node->attrs.as<TopKAttrs>();
71 ICHECK(param);
72 return MakeTopK(call_node->args[0], static_cast<int>(ToScalar(k->data, 0)),
73 param->axis, param->ret_type, param->is_ascend, param->dtype);
74 }
75 return Expr(nullptr);
76 }},
77 {Op::Get("dyn.broadcast_to"),
78 [this](const CallNode* call_node) {
79 auto args = PrepareArgs(call_node);
80 if (const ConstantNode* shape = args[1].as<ConstantNode>()) {
81 ICHECK_EQ(shape->data->ndim, 1);
82 return MakeBroadCastTo(call_node->args[0], ToVector(shape->data));
83 }
84 return Expr(nullptr);
85 }},
86 {Op::Get("dyn.zeros"),
87 [this](const CallNode* call_node) {
88 auto args = PrepareArgs(call_node);
89 if (const ConstantNode* shape = args[0].as<ConstantNode>()) {
90 const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
91 ICHECK(param);
92 return MakeZeros(ToVector(shape->data), param->dtype);
93 }
94 return Expr(nullptr);
95 }},
96 {Op::Get("dyn.ones"),
97 [this](const CallNode* call_node) {
98 auto args = PrepareArgs(call_node);
99 if (const ConstantNode* shape = args[0].as<ConstantNode>()) {
100 const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
101 ICHECK(param);
102 return MakeOnes(ToVector(shape->data), param->dtype);
103 }
104 return Expr(nullptr);
105 }},
106 {Op::Get("dyn.one_hot"),
107 [this](const CallNode* call_node) {
108 auto args = PrepareArgs(call_node);
109 if (const ConstantNode* depth = args[3].as<ConstantNode>()) {
110 const OneHotAttrs* param = call_node->attrs.as<OneHotAttrs>();
111 ICHECK(param);
112 return MakeOneHot(call_node->args[0], call_node->args[1], call_node->args[2],
113 static_cast<int>(ToScalar(depth->data, 0)), param->axis,
114 param->dtype);
115 }
116 return Expr(nullptr);
117 }},
118 {Op::Get("dyn.image.resize2d"),
119 [this](const CallNode* call_node) {
120 auto args = PrepareArgs(call_node);
121 if (const ConstantNode* size = args[1].as<ConstantNode>()) {
122 if (const ConstantNode* roi = args[2].as<ConstantNode>()) {
123 const Resize2DAttrs* param = call_node->attrs.as<Resize2DAttrs>();
124 ICHECK(param);
125 auto size_int = ToVector(size->data);
126 Array<PrimExpr> size_prim;
127 for (size_t i = 0; i < size_int.size(); ++i) {
128 size_prim.push_back(size_int[i]);
129 }
130 auto roi_vec = ToFloatVector(roi->data);
131 Array<FloatImm> roi_prim;
132 for (size_t i = 0; i < roi_vec.size(); ++i) {
133 roi_prim.push_back(roi_vec[i]);
134 }
135 return MakeResize2D(call_node->args[0], size_prim, roi_prim, param->layout,
136 param->method, param->coordinate_transformation_mode,
137 param->rounding_method, param->cubic_alpha, param->cubic_exclude,
138 param->extrapolation_value, param->out_dtype);
139 }
140 }
141 return Expr(nullptr);
142 }},
143 {Op::Get("dyn.full"),
144 [this](const CallNode* call_node) {
145 auto args = PrepareArgs(call_node);
146 if (const ConstantNode* shape = args[1].as<ConstantNode>()) {
147 ICHECK_EQ(shape->data->ndim, 1);
148 const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
149 ICHECK(param);
150 return MakeFull(call_node->args[0], ToVector(shape->data), param->dtype);
151 }
152 return Expr(nullptr);
153 }},
154 {Op::Get("dyn.nn.upsampling"),
155 [this](const CallNode* call_node) {
156 auto args = PrepareArgs(call_node);
157 const ConstantNode* scale_h = args[1].as<ConstantNode>();
158 const ConstantNode* scale_w = args[2].as<ConstantNode>();
159 if (scale_h && scale_w) {
160 ICHECK_EQ(scale_h->data->ndim, 0);
161 ICHECK_EQ(scale_w->data->ndim, 0);
162 const UpSamplingAttrs* param = call_node->attrs.as<UpSamplingAttrs>();
163 ICHECK(param);
164 return MakeUpSampling(call_node->args[0], ToScalar(scale_h->data),
165 ToScalar(scale_w->data), param->layout, param->method,
166 param->align_corners);
167 }
168 return Expr(nullptr);
169 }},
170 {Op::Get("dyn.nn.upsampling3d"),
171 [this](const CallNode* call_node) {
172 auto args = PrepareArgs(call_node);
173 const ConstantNode* scale_d = args[1].as<ConstantNode>();
174 const ConstantNode* scale_h = args[2].as<ConstantNode>();
175 const ConstantNode* scale_w = args[3].as<ConstantNode>();
176 if (scale_d && scale_h && scale_w) {
177 ICHECK_EQ(scale_d->data->ndim, 0);
178 ICHECK_EQ(scale_h->data->ndim, 0);
179 ICHECK_EQ(scale_w->data->ndim, 0);
180 const UpSampling3DAttrs* param = call_node->attrs.as<UpSampling3DAttrs>();
181 ICHECK(param);
182 return MakeUpSampling3D(call_node->args[0], ToScalar(scale_d->data),
183 ToScalar(scale_h->data), ToScalar(scale_w->data),
184 param->layout, param->method,
185 param->coordinate_transformation_mode);
186 }
187 return Expr(nullptr);
188 }},
189 {Op::Get("dyn.nn.pad"),
190 [this](const CallNode* call_node) {
191 auto args = PrepareArgs(call_node);
192 const ConstantNode* pad_width = args[1].as<ConstantNode>();
193 const ConstantNode* pad_fill = args[2].as<ConstantNode>();
194 if (pad_width && pad_fill) {
195 ICHECK_EQ(pad_fill->data->ndim, 0); // pad_val is 1d
196 ICHECK_EQ(pad_width->data->ndim, 2); // pad_width is 2d
197
198 const PadAttrs* param = call_node->attrs.as<PadAttrs>();
199 ICHECK(param);
200
201 Expr pad_value = args[2];
202 return MakePad(call_node->args[0], ToMatrix(pad_width->data), pad_value,
203 param->pad_mode);
204 }
205 return Expr(nullptr);
206 }},
207 {Op::Get("dyn.strided_slice"),
208 [this](const CallNode* call_node) {
209 auto args = PrepareArgs(call_node);
210 const ConstantNode* begin = args[1].as<ConstantNode>();
211 const ConstantNode* end = args[2].as<ConstantNode>();
212 const ConstantNode* stride = args[3].as<ConstantNode>();
213 if (begin && end && stride) {
214 ICHECK_EQ(begin->data->ndim, 1);
215 ICHECK_EQ(end->data->ndim, 1);
216 ICHECK_EQ(stride->data->ndim, 1);
217 const StridedSliceAttrs* param = call_node->attrs.as<StridedSliceAttrs>();
218 ICHECK(param);
219 return MakeStridedSlice(call_node->args[0], ToVector(begin->data), ToVector(end->data),
220 ToVector(stride->data), param->slice_mode);
221 }
222 return Expr(nullptr);
223 }},
224 {Op::Get("dyn.sparse_to_dense"),
225 [this](const CallNode* call_node) {
226 auto args = PrepareArgs(call_node);
227 const ConstantNode* output_shape = args[3].as<ConstantNode>();
228 if (output_shape) {
229 ICHECK_EQ(output_shape->data->ndim, 1);
230 return MakeSparseToDense(call_node->args[0], ToVector(output_shape->data),
231 call_node->args[1], call_node->args[2]);
232 }
233 return Expr(nullptr);
234 }},
235 };
236 Map<BaseFunc, GlobalVar> vars;
237 for (auto kv : mod_->functions) {
238 vars.Set(kv.second, kv.first);
239 }
240 gv_ = vars[func_];
241 }
242
243 Expr GetCurExpr(const Expr& original_expr) {
244 if (original_expr.as<FunctionNode>()) {
245 return mod_->Lookup(gv_);
246 } else {
247 return mod_->Lookup(gv_).as<FunctionNode>()->body;
248 }
249 }
250
251 Expr PrepareInput(const Expr& expr) {
252 BaseFunc func;
253 if (auto* func_node = expr.as<BaseFuncNode>()) {
254 func = GetRef<BaseFunc>(func_node);
255 } else {
256 func =
257 relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod_), {});
258 }
259 mod_->Update(gv_, func);
260
261 mod_ = transform::FoldConstant()(mod_);
262 transform::InferTypeLocal(GetCurExpr(expr));
263 mod_ = transform::FoldConstant()(mod_);
264 transform::InferTypeLocal(GetCurExpr(expr));
265
266 Expr out;
267 if (expr.as<FunctionNode>()) {
268 out = mod_->Lookup(gv_);
269 } else {
270 out = mod_->Lookup(gv_).as<FunctionNode>()->body;
271 }
272 return out;
273 }
274
275 std::vector<Expr> PrepareArgs(const CallNode* call_node) {
276 std::vector<Expr> args;
277 for (auto arg : call_node->args) {
278 if (arg.as<ConstantNode>()) {
279 args.emplace_back(arg);
280 } else {
281 args.emplace_back(PrepareInput(arg));
282 }
283 }
284 return args;
285 }
286
287 private:
288 Expr Rewrite_(const CallNode* pre, const Expr& post) override {
289 if (const CallNode* call_node = post.as<CallNode>()) {
290 if (op_map_.count(call_node->op)) {
291 auto out = op_map_[call_node->op](call_node);
292 if (out.defined()) {
293 return out;
294 }
295 }
296 }
297 return post;
298 }
299
300 Expr DispatchVisitExpr(const Expr& expr) override {
301 auto post = MixedModeMutator::DispatchVisitExpr(expr);
302 if (auto op = post.as<FunctionNode>()) {
303 return Function(op->params, op->body, NullValue<Type>(), op->type_params, op->attrs);
304 }
305 return post;
306 }
307
308 std::unordered_map<Expr, std::function<Expr(const CallNode*)>, ObjectPtrHash, ObjectPtrEqual>
309 op_map_;
310 IRModule mod_;
311 Function func_;
312 GlobalVar gv_;
313};
314
315Expr DynamicToStatic(Function f, IRModule m) {
316 DynamicToStaticMutator mutator(m, f);
317 Expr expr = mutator.Mutate(f);
318 Expr out = mutator.PrepareInput(expr);
319 return out;
320}
321
322namespace transform {
323
324Pass DynamicToStatic() {
325 runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
326 [=](Function f, IRModule m, PassContext pc) {
327 return Downcast<Function>(DynamicToStatic(f, m));
328 };
329 return CreateFunctionPass(pass_func, 2, "DynamicToStatic", {});
330}
331
332TVM_REGISTER_GLOBAL("relay._transform.DynamicToStatic").set_body_typed([]() {
333 return DynamicToStatic();
334});
335
336} // namespace transform
337} // namespace relay
338} // namespace tvm
339