1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * |
22 | * \file util.cc |
23 | * |
24 | * \brief Utility functions for Relay. |
25 | */ |
26 | #include <tvm/ir/type_functor.h> |
27 | #include <tvm/relay/analysis.h> |
28 | #include <tvm/relay/attrs/algorithm.h> |
29 | #include <tvm/relay/expr_functor.h> |
30 | #include <tvm/relay/op.h> |
31 | #include <tvm/relay/op_attr_types.h> |
32 | #include <tvm/relay/pattern_functor.h> |
33 | |
34 | #include "../transforms/pass_utils.h" |
35 | |
36 | namespace tvm { |
37 | namespace relay { |
38 | |
39 | template <typename T> |
40 | struct InsertionSet { |
41 | std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual> set; |
42 | std::vector<T> data; |
43 | void Insert(const T& t) { |
44 | if (set.count(t) == 0) { |
45 | set.insert(t); |
46 | data.push_back(t); |
47 | } |
48 | } |
49 | }; |
50 | |
51 | class TypeVarTVisitor : public TypeVisitor { |
52 | public: |
53 | TypeVarTVisitor(InsertionSet<TypeVar>* type_vars, InsertionSet<TypeVar>* bound_type_vars) |
54 | : type_vars_(type_vars), bound_type_vars_(bound_type_vars) {} |
55 | |
56 | void VisitType_(const TypeVarNode* tp) final { |
57 | TypeVar var = GetRef<TypeVar>(tp); |
58 | type_vars_->Insert(var); |
59 | } |
60 | |
61 | void VisitType_(const FuncTypeNode* f) final { |
62 | for (auto type_param : f->type_params) { |
63 | type_vars_->Insert(type_param); |
64 | bound_type_vars_->Insert(type_param); |
65 | } |
66 | TypeVisitor::VisitType_(f); |
67 | } |
68 | |
69 | private: |
70 | InsertionSet<TypeVar>* type_vars_; |
71 | InsertionSet<TypeVar>* bound_type_vars_; |
72 | }; |
73 | |
74 | class TypeVarEVisitor : private MixedModeVisitor { |
75 | public: |
76 | explicit TypeVarEVisitor(const IRModule& mod) : mod_(mod) {} |
77 | |
78 | Array<TypeVar> CollectFree() { |
79 | Array<TypeVar> ret; |
80 | for (const auto& v : type_vars_.data) { |
81 | if (bound_type_vars_.set.count(v) == 0) { |
82 | ret.push_back(v); |
83 | } |
84 | } |
85 | return ret; |
86 | } |
87 | |
88 | Array<TypeVar> CollectBound() { |
89 | Array<TypeVar> ret; |
90 | for (const auto& v : bound_type_vars_.data) { |
91 | ret.push_back(v); |
92 | } |
93 | return ret; |
94 | } |
95 | |
96 | Array<TypeVar> CollectAll() { |
97 | Array<TypeVar> ret; |
98 | for (const auto& v : type_vars_.data) { |
99 | ret.push_back(v); |
100 | } |
101 | return ret; |
102 | } |
103 | |
104 | Array<TypeVar> Free(const Expr& expr) { |
105 | VisitExpr(expr); |
106 | return CollectFree(); |
107 | } |
108 | |
109 | Array<TypeVar> Free(const Type& type) { |
110 | VisitType(type); |
111 | return CollectFree(); |
112 | } |
113 | |
114 | Array<TypeVar> Bound(const Expr& expr) { |
115 | VisitExpr(expr); |
116 | return CollectBound(); |
117 | } |
118 | |
119 | Array<TypeVar> Bound(const Type& type) { |
120 | VisitType(type); |
121 | return CollectBound(); |
122 | } |
123 | |
124 | Array<TypeVar> All(const Expr& expr) { |
125 | VisitExpr(expr); |
126 | return CollectAll(); |
127 | } |
128 | |
129 | Array<TypeVar> All(const Type& type) { |
130 | VisitType(type); |
131 | return CollectAll(); |
132 | } |
133 | |
134 | using MixedModeVisitor::VisitExpr_; |
135 | |
136 | void VisitExpr_(const FunctionNode* f) final { |
137 | for (const auto& tp : f->type_params) { |
138 | type_vars_.Insert(tp); |
139 | bound_type_vars_.Insert(tp); |
140 | } |
141 | ExprVisitor::VisitExpr_(f); |
142 | } |
143 | |
144 | void VisitExpr_(const LetNode* op) final { |
145 | auto pre_visit = [this](const LetNode* op) { |
146 | this->VisitExpr(op->var); |
147 | this->VisitExpr(op->value); |
148 | }; |
149 | auto post_visit = [this](const LetNode* op) { |
150 | this->VisitExpr(op->body); |
151 | this->visit_counter_[op] += 1; |
152 | }; |
153 | ExpandANormalForm(op, pre_visit, post_visit); |
154 | } |
155 | |
156 | void VisitExpr_(const ConstructorNode* cn) final { |
157 | // for constructors, type vars will be bound in the module |
158 | auto data = mod_->LookupTypeDef(cn->belong_to); |
159 | for (const auto& tv : data->type_vars) { |
160 | type_vars_.Insert(tv); |
161 | bound_type_vars_.Insert(tv); |
162 | } |
163 | ExprVisitor::VisitExpr_(cn); |
164 | } |
165 | |
166 | void VisitType(const Type& t) final { |
167 | TypeVarTVisitor(&type_vars_, &bound_type_vars_).VisitType(t); |
168 | } |
169 | |
170 | private: |
171 | InsertionSet<TypeVar> type_vars_; |
172 | InsertionSet<TypeVar> bound_type_vars_; |
173 | const IRModule& mod_; |
174 | }; |
175 | |
176 | class VarVisitor : protected MixedModeVisitor, protected PatternVisitor { |
177 | public: |
178 | Array<Var> Free(const Expr& expr) { |
179 | this->VisitExpr(expr); |
180 | Array<Var> ret; |
181 | for (const auto& v : vars_.data) { |
182 | if (bound_vars_.set.count(v) == 0) { |
183 | ret.push_back(v); |
184 | } |
185 | } |
186 | return ret; |
187 | } |
188 | |
189 | Array<Var> Collect() { |
190 | Array<Var> ret; |
191 | for (const auto& v : bound_vars_.data) { |
192 | ret.push_back(v); |
193 | } |
194 | return ret; |
195 | } |
196 | |
197 | Array<Var> Bound(const Expr& expr) { |
198 | this->VisitExpr(expr); |
199 | return Collect(); |
200 | } |
201 | |
202 | Array<Var> Bound(const Pattern& pat) { |
203 | this->VisitPattern(pat); |
204 | return Collect(); |
205 | } |
206 | |
207 | Array<Var> All(const Expr& expr) { |
208 | this->VisitExpr(expr); |
209 | Array<Var> ret; |
210 | for (const auto& v : vars_.data) { |
211 | ret.push_back(v); |
212 | } |
213 | return ret; |
214 | } |
215 | |
216 | void MarkBounded(const Var& v) { |
217 | bound_vars_.Insert(v); |
218 | vars_.Insert(v); |
219 | } |
220 | |
221 | using MixedModeVisitor::VisitExpr_; |
222 | |
223 | void VisitExpr_(const VarNode* var) final { vars_.Insert(GetRef<Var>(var)); } |
224 | |
225 | void VisitExpr_(const FunctionNode* op) final { |
226 | for (const auto& param : op->params) { |
227 | MarkBounded(param); |
228 | } |
229 | VisitExpr(op->body); |
230 | } |
231 | |
232 | void VisitExpr_(const LetNode* op) final { |
233 | Expr let = GetRef<Let>(op); |
234 | while (auto let_node = let.as<LetNode>()) { |
235 | MarkBounded(let_node->var); |
236 | VisitExpr(let_node->value); |
237 | let = let_node->body; |
238 | } |
239 | VisitExpr(let); |
240 | } |
241 | |
242 | void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); } |
243 | |
244 | void VisitPattern_(const PatternVarNode* op) final { MarkBounded(op->var); } |
245 | |
246 | private: |
247 | InsertionSet<Var> vars_; |
248 | InsertionSet<Var> bound_vars_; |
249 | }; |
250 | |
251 | tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const IRModule& mod) { |
252 | return TypeVarEVisitor(mod).Free(expr); |
253 | } |
254 | |
255 | tvm::Array<TypeVar> FreeTypeVars(const Type& type, const IRModule& mod) { |
256 | return TypeVarEVisitor(mod).Free(type); |
257 | } |
258 | |
259 | tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const IRModule& mod) { |
260 | return TypeVarEVisitor(mod).Bound(expr); |
261 | } |
262 | |
263 | tvm::Array<TypeVar> BoundTypeVars(const Type& type, const IRModule& mod) { |
264 | return TypeVarEVisitor(mod).Bound(type); |
265 | } |
266 | |
267 | tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const IRModule& mod) { |
268 | return TypeVarEVisitor(mod).All(expr); |
269 | } |
270 | |
271 | tvm::Array<TypeVar> AllTypeVars(const Type& type, const IRModule& mod) { |
272 | return TypeVarEVisitor(mod).All(type); |
273 | } |
274 | |
275 | tvm::Array<Var> FreeVars(const Expr& expr) { return VarVisitor().Free(expr); } |
276 | |
277 | tvm::Array<Var> BoundVars(const Expr& expr) { return VarVisitor().Bound(expr); } |
278 | |
279 | tvm::Array<Var> BoundVars(const Pattern& pat) { return VarVisitor().Bound(pat); } |
280 | |
281 | tvm::Array<Var> AllVars(const Expr& expr) { return VarVisitor().All(expr); } |
282 | |
283 | TVM_REGISTER_GLOBAL("relay.analysis.free_vars" ).set_body_typed(FreeVars); |
284 | |
285 | TVM_REGISTER_GLOBAL("relay.analysis.bound_vars" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
286 | ObjectRef x = args[0]; |
287 | if (x.as<ExprNode>()) { |
288 | *ret = BoundVars(Downcast<Expr>(x)); |
289 | } else { |
290 | *ret = BoundVars(Downcast<Pattern>(x)); |
291 | } |
292 | }); |
293 | |
294 | TVM_REGISTER_GLOBAL("relay.analysis.all_vars" ).set_body_typed(AllVars); |
295 | |
296 | TVM_REGISTER_GLOBAL("relay.analysis.free_type_vars" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
297 | ObjectRef x = args[0]; |
298 | IRModule mod = args[1]; |
299 | if (x.as<TypeNode>()) { |
300 | *ret = FreeTypeVars(Downcast<Type>(x), mod); |
301 | } else { |
302 | *ret = FreeTypeVars(Downcast<Expr>(x), mod); |
303 | } |
304 | }); |
305 | |
306 | TVM_REGISTER_GLOBAL("relay.analysis.bound_type_vars" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
307 | ObjectRef x = args[0]; |
308 | IRModule mod = args[1]; |
309 | if (x.as<TypeNode>()) { |
310 | *ret = BoundTypeVars(Downcast<Type>(x), mod); |
311 | } else { |
312 | *ret = BoundTypeVars(Downcast<Expr>(x), mod); |
313 | } |
314 | }); |
315 | |
316 | TVM_REGISTER_GLOBAL("relay.analysis.all_type_vars" ).set_body([](TVMArgs args, TVMRetValue* ret) { |
317 | ObjectRef x = args[0]; |
318 | IRModule mod = args[1]; |
319 | if (x.as<TypeNode>()) { |
320 | *ret = AllTypeVars(Downcast<Type>(x), mod); |
321 | } else { |
322 | *ret = AllTypeVars(Downcast<Expr>(x), mod); |
323 | } |
324 | }); |
325 | |
326 | class DtypeCollector : protected ExprVisitor, protected TypeVisitor { |
327 | public: |
328 | void VisitExpr(const Expr& expr) final { |
329 | if (expr->checked_type_.defined()) { |
330 | TypeVisitor::VisitType(expr->checked_type()); |
331 | } |
332 | ExprVisitor::VisitExpr(expr); |
333 | } |
334 | |
335 | void VisitType_(const TensorTypeNode* op) final { dtypes_.insert(DLDataType2String(op->dtype)); } |
336 | |
337 | Array<String> All(const Expr& expr) { |
338 | VisitExpr(expr); |
339 | |
340 | Array<String> res; |
341 | for (const auto& dtype : dtypes_) { |
342 | res.push_back(String(dtype)); |
343 | } |
344 | return res; |
345 | } |
346 | |
347 | private: |
348 | std::unordered_set<std::string> dtypes_; |
349 | }; |
350 | |
351 | tvm::Array<String> AllDtypes(const Expr& expr) { return DtypeCollector().All(expr); } |
352 | |
353 | TVM_REGISTER_GLOBAL("relay.analysis.all_dtypes" ).set_body_typed(AllDtypes); |
354 | |
355 | /*! |
356 | * \brief Get reference counter of each internal ExprNode in body. |
357 | * \param body The body expression. |
358 | * \return The reference count mapping. |
359 | */ |
360 | std::unordered_map<const Object*, size_t> GetExprRefCount(const Expr& body) { |
361 | class ExprRefCounter : private MixedModeVisitor { |
362 | public: |
363 | std::unordered_map<const Object*, size_t> Get(const Expr& body) { |
364 | this->VisitExpr(body); |
365 | return std::move(this->visit_counter_); |
366 | } |
367 | }; |
368 | return ExprRefCounter().Get(body); |
369 | } |
370 | |
371 | template <typename T> |
372 | bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) { |
373 | ICHECK_EQ(tensor->device.device_type, kDLCPU); |
374 | ICHECK(tensor->strides == nullptr); |
375 | ICHECK_EQ(tensor->byte_offset, 0); |
376 | const T* data = static_cast<const T*>(tensor->data); |
377 | int64_t num_elems = 1; |
378 | for (int i = 0; i < tensor->ndim; ++i) { |
379 | num_elems *= tensor->shape[i]; |
380 | } |
381 | |
382 | for (int64_t i = 0; i < num_elems; i++) { |
383 | if (*data < value) { |
384 | return false; |
385 | } |
386 | data++; |
387 | } |
388 | return true; |
389 | } |
390 | |
391 | bool IsAllPositiveConstant(const Expr& expr) { |
392 | // Cache the operators that are checked recursively to reduce lookup overhead. |
393 | static const auto& expand_dims_op = Op::Get("expand_dims" ); |
394 | static const auto& reshape_op = Op::Get("reshape" ); |
395 | static const auto& transpose_op = Op::Get("transpose" ); |
396 | static const auto& squeeze_op = Op::Get("squeeze" ); |
397 | static const auto& repeat_op = Op::Get("repeat" ); |
398 | |
399 | // peel through a few common transform ops. |
400 | if (const auto* constant = expr.as<ConstantNode>()) { |
401 | const auto& tensor = constant->data; |
402 | const auto& dtype = tensor->dtype; |
403 | if (dtype.lanes != 1) { |
404 | return false; |
405 | } else if (dtype.code == kDLFloat && dtype.bits == 32) { |
406 | return IsNDArrayAllGreaterEqual<float>(tensor, 0); |
407 | } else if (dtype.code == kDLFloat && dtype.bits == 64) { |
408 | return IsNDArrayAllGreaterEqual<double>(tensor, 0); |
409 | } else if (dtype.code == kDLInt && dtype.bits == 8) { |
410 | return IsNDArrayAllGreaterEqual<int8_t>(tensor, 0); |
411 | } else if (dtype.code == kDLInt && dtype.bits == 32) { |
412 | return IsNDArrayAllGreaterEqual<int32_t>(tensor, 0); |
413 | } else if (dtype.code == kDLUInt && dtype.bits == 8) { |
414 | return IsNDArrayAllGreaterEqual<uint8_t>(tensor, 0); |
415 | } else if (dtype.code == kDLUInt && dtype.bits == 32) { |
416 | return IsNDArrayAllGreaterEqual<uint32_t>(tensor, 0); |
417 | } else { |
418 | return false; |
419 | } |
420 | } else if (const auto* op = expr.as<CallNode>()) { |
421 | // tail recursion. |
422 | if (op->op == expand_dims_op || op->op == reshape_op || op->op == transpose_op || |
423 | op->op == squeeze_op || op->op == repeat_op) { |
424 | return IsAllPositiveConstant(op->args[0]); |
425 | } else { |
426 | return false; |
427 | } |
428 | } else { |
429 | return false; |
430 | } |
431 | } |
432 | |
433 | Type TypeSubst(const Type& type, const TypeVar& tvar, const Type& subst) { |
434 | return TypeSubst(type, tvm::Map<TypeVar, Type>({{tvar, subst}})); |
435 | } |
436 | |
437 | Expr TypeSubst(const Expr& expr, const TypeVar& tvar, const Type& subst) { |
438 | return TypeSubst(expr, tvm::Map<TypeVar, Type>({{tvar, subst}})); |
439 | } |
440 | |
441 | Type TypeSubst(const Type& type, const tvm::Map<TypeVar, Type>& subst_map) { |
442 | return Bind(type, subst_map); |
443 | } |
444 | |
445 | Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map) { |
446 | class TypeSubstMutator : public ExprMutator, public PatternMutator { |
447 | public: |
448 | explicit TypeSubstMutator(const tvm::Map<TypeVar, Type>& subst_map) : subst_map_(subst_map) {} |
449 | Type VisitType(const Type& t) final { return TypeSubst(t, subst_map_); } |
450 | Var VisitVar(const Var& v) final { return Downcast<Var>(VisitExpr(v)); } |
451 | |
452 | Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); } |
453 | |
454 | Clause VisitClause(const Clause& c) final { |
455 | Pattern pat = VisitPattern(c->lhs); |
456 | return Clause(pat, VisitExpr(c->rhs)); |
457 | } |
458 | |
459 | private: |
460 | const tvm::Map<TypeVar, Type>& subst_map_; |
461 | }; |
462 | ICHECK(WellFormed(expr)); |
463 | auto ret = TypeSubstMutator(subst_map).VisitExpr(expr); |
464 | ICHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size()); |
465 | ICHECK(WellFormed(ret)); |
466 | return ret; |
467 | } |
468 | |
469 | struct IsDynamicVisitor : public TypeVisitor { |
470 | bool is_dyn{false}; |
471 | void VisitType_(const TensorTypeNode* tt) { |
472 | for (auto dim : tt->shape) { |
473 | if (dim.as<tir::IntImmNode>() == nullptr) { |
474 | is_dyn = true; |
475 | break; |
476 | } |
477 | } |
478 | } |
479 | }; |
480 | |
481 | bool IsDynamic(const Type& ty) { |
482 | IsDynamicVisitor v; |
483 | v.VisitType(ty); |
484 | return v.is_dyn; |
485 | } |
486 | |
487 | TVM_REGISTER_GLOBAL("relay.ir.IsDynamic" ).set_body_typed(IsDynamic); |
488 | |
489 | bool IsDataDependent(const CallNode* call) { |
490 | static auto tshape_data_dependent = Op::GetAttrMap<TShapeDataDependent>("TShapeDataDependent" ); |
491 | Op op = Downcast<Op>(call->op); |
492 | |
493 | if (!tshape_data_dependent.count(op)) { |
494 | return false; |
495 | } |
496 | |
497 | if (op->name == "strided_slice" ) { |
498 | if (const auto* attrs = call->attrs.as<StridedSliceAttrs>()) { |
499 | if (attrs->begin && attrs->end && attrs->strides) { |
500 | // not data dependent if begin, end and strides exist |
501 | return false; |
502 | } |
503 | } |
504 | } |
505 | |
506 | for (auto req : tshape_data_dependent[op]) { |
507 | if (req->value != 0) return true; |
508 | } |
509 | return false; |
510 | } |
511 | } // namespace relay |
512 | } // namespace tvm |
513 | |