1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 *
22 * \file util.cc
23 *
24 * \brief Utility functions for Relay.
25 */
#include <tvm/ir/type_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/algorithm.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/pattern_functor.h>

#include <algorithm>
#include <functional>
#include <numeric>

#include "../transforms/pass_utils.h"
35
36namespace tvm {
37namespace relay {
38
39template <typename T>
40struct InsertionSet {
41 std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual> set;
42 std::vector<T> data;
43 void Insert(const T& t) {
44 if (set.count(t) == 0) {
45 set.insert(t);
46 data.push_back(t);
47 }
48 }
49};
50
51class TypeVarTVisitor : public TypeVisitor {
52 public:
53 TypeVarTVisitor(InsertionSet<TypeVar>* type_vars, InsertionSet<TypeVar>* bound_type_vars)
54 : type_vars_(type_vars), bound_type_vars_(bound_type_vars) {}
55
56 void VisitType_(const TypeVarNode* tp) final {
57 TypeVar var = GetRef<TypeVar>(tp);
58 type_vars_->Insert(var);
59 }
60
61 void VisitType_(const FuncTypeNode* f) final {
62 for (auto type_param : f->type_params) {
63 type_vars_->Insert(type_param);
64 bound_type_vars_->Insert(type_param);
65 }
66 TypeVisitor::VisitType_(f);
67 }
68
69 private:
70 InsertionSet<TypeVar>* type_vars_;
71 InsertionSet<TypeVar>* bound_type_vars_;
72};
73
/*!
 * \brief Expression visitor that collects the TypeVars reachable from an
 * expression or type, classifying each as bound or free.
 *
 * Binders recognized: a function's type params, a function type's type
 * params (via TypeVarTVisitor), and the type vars of the type definition a
 * constructor belongs to (looked up in the module).
 */
class TypeVarEVisitor : private MixedModeVisitor {
 public:
  explicit TypeVarEVisitor(const IRModule& mod) : mod_(mod) {}

  /*! \brief Collected vars never seen as bound, in first-seen order. */
  Array<TypeVar> CollectFree() {
    Array<TypeVar> ret;
    for (const auto& v : type_vars_.data) {
      if (bound_type_vars_.set.count(v) == 0) {
        ret.push_back(v);
      }
    }
    return ret;
  }

  /*! \brief Collected bound vars, in first-seen order. */
  Array<TypeVar> CollectBound() {
    Array<TypeVar> ret;
    for (const auto& v : bound_type_vars_.data) {
      ret.push_back(v);
    }
    return ret;
  }

  /*! \brief All collected vars, bound or free, in first-seen order. */
  Array<TypeVar> CollectAll() {
    Array<TypeVar> ret;
    for (const auto& v : type_vars_.data) {
      ret.push_back(v);
    }
    return ret;
  }

  /*! \brief Free type vars of an expression. */
  Array<TypeVar> Free(const Expr& expr) {
    VisitExpr(expr);
    return CollectFree();
  }

  /*! \brief Free type vars of a type. */
  Array<TypeVar> Free(const Type& type) {
    VisitType(type);
    return CollectFree();
  }

  /*! \brief Bound type vars of an expression. */
  Array<TypeVar> Bound(const Expr& expr) {
    VisitExpr(expr);
    return CollectBound();
  }

  /*! \brief Bound type vars of a type. */
  Array<TypeVar> Bound(const Type& type) {
    VisitType(type);
    return CollectBound();
  }

  /*! \brief All type vars of an expression. */
  Array<TypeVar> All(const Expr& expr) {
    VisitExpr(expr);
    return CollectAll();
  }

  /*! \brief All type vars of a type. */
  Array<TypeVar> All(const Type& type) {
    VisitType(type);
    return CollectAll();
  }

  using MixedModeVisitor::VisitExpr_;

  void VisitExpr_(const FunctionNode* f) final {
    // A function's type params are both occurrences and binders.
    for (const auto& tp : f->type_params) {
      type_vars_.Insert(tp);
      bound_type_vars_.Insert(tp);
    }
    ExprVisitor::VisitExpr_(f);
  }

  void VisitExpr_(const LetNode* op) final {
    // Expand a chain of nested lets iteratively to avoid deep recursion;
    // the visit_counter_ bump marks each let as visited so MixedModeVisitor
    // does not traverse it again.
    auto pre_visit = [this](const LetNode* op) {
      this->VisitExpr(op->var);
      this->VisitExpr(op->value);
    };
    auto post_visit = [this](const LetNode* op) {
      this->VisitExpr(op->body);
      this->visit_counter_[op] += 1;
    };
    ExpandANormalForm(op, pre_visit, post_visit);
  }

  void VisitExpr_(const ConstructorNode* cn) final {
    // for constructors, type vars will be bound in the module
    auto data = mod_->LookupTypeDef(cn->belong_to);
    for (const auto& tv : data->type_vars) {
      type_vars_.Insert(tv);
      bound_type_vars_.Insert(tv);
    }
    ExprVisitor::VisitExpr_(cn);
  }

  void VisitType(const Type& t) final {
    // Delegate type traversal so both insertion sets are updated by the
    // same logic used for standalone types.
    TypeVarTVisitor(&type_vars_, &bound_type_vars_).VisitType(t);
  }

 private:
  InsertionSet<TypeVar> type_vars_;
  InsertionSet<TypeVar> bound_type_vars_;
  // NOTE(review): holds a reference — the visitor must not outlive `mod`.
  const IRModule& mod_;
};
175
/*!
 * \brief Expression/pattern visitor that collects term-level Vars,
 * classifying each as bound (by a function param, a let, or a match
 * pattern) or free.
 */
class VarVisitor : protected MixedModeVisitor, protected PatternVisitor {
 public:
  /*! \brief Vars occurring in \p expr without an enclosing binder. */
  Array<Var> Free(const Expr& expr) {
    this->VisitExpr(expr);
    Array<Var> ret;
    for (const auto& v : vars_.data) {
      if (bound_vars_.set.count(v) == 0) {
        ret.push_back(v);
      }
    }
    return ret;
  }

  /*! \brief Bound vars collected so far, in first-seen order. */
  Array<Var> Collect() {
    Array<Var> ret;
    for (const auto& v : bound_vars_.data) {
      ret.push_back(v);
    }
    return ret;
  }

  /*! \brief Vars bound anywhere inside \p expr. */
  Array<Var> Bound(const Expr& expr) {
    this->VisitExpr(expr);
    return Collect();
  }

  /*! \brief Vars bound by pattern \p pat. */
  Array<Var> Bound(const Pattern& pat) {
    this->VisitPattern(pat);
    return Collect();
  }

  /*! \brief All vars in \p expr, bound or free, in first-seen order. */
  Array<Var> All(const Expr& expr) {
    this->VisitExpr(expr);
    Array<Var> ret;
    for (const auto& v : vars_.data) {
      ret.push_back(v);
    }
    return ret;
  }

  /*! \brief Record \p v both as seen and as bound. */
  void MarkBounded(const Var& v) {
    bound_vars_.Insert(v);
    vars_.Insert(v);
  }

  using MixedModeVisitor::VisitExpr_;

  void VisitExpr_(const VarNode* var) final { vars_.Insert(GetRef<Var>(var)); }

  void VisitExpr_(const FunctionNode* op) final {
    // Params are binders for the function body.
    for (const auto& param : op->params) {
      MarkBounded(param);
    }
    VisitExpr(op->body);
  }

  void VisitExpr_(const LetNode* op) final {
    // Walk a chain of nested lets iteratively to avoid deep recursion on
    // long let sequences.
    Expr let = GetRef<Let>(op);
    while (auto let_node = let.as<LetNode>()) {
      MarkBounded(let_node->var);
      VisitExpr(let_node->value);
      let = let_node->body;
    }
    VisitExpr(let);
  }

  void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); }

  void VisitPattern_(const PatternVarNode* op) final { MarkBounded(op->var); }

 private:
  InsertionSet<Var> vars_;
  InsertionSet<Var> bound_vars_;
};
250
251tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const IRModule& mod) {
252 return TypeVarEVisitor(mod).Free(expr);
253}
254
255tvm::Array<TypeVar> FreeTypeVars(const Type& type, const IRModule& mod) {
256 return TypeVarEVisitor(mod).Free(type);
257}
258
259tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const IRModule& mod) {
260 return TypeVarEVisitor(mod).Bound(expr);
261}
262
263tvm::Array<TypeVar> BoundTypeVars(const Type& type, const IRModule& mod) {
264 return TypeVarEVisitor(mod).Bound(type);
265}
266
267tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const IRModule& mod) {
268 return TypeVarEVisitor(mod).All(expr);
269}
270
271tvm::Array<TypeVar> AllTypeVars(const Type& type, const IRModule& mod) {
272 return TypeVarEVisitor(mod).All(type);
273}
274
275tvm::Array<Var> FreeVars(const Expr& expr) { return VarVisitor().Free(expr); }
276
277tvm::Array<Var> BoundVars(const Expr& expr) { return VarVisitor().Bound(expr); }
278
279tvm::Array<Var> BoundVars(const Pattern& pat) { return VarVisitor().Bound(pat); }
280
281tvm::Array<Var> AllVars(const Expr& expr) { return VarVisitor().All(expr); }
282
// Python FFI bindings for the var / type-var analyses defined above.
TVM_REGISTER_GLOBAL("relay.analysis.free_vars").set_body_typed(FreeVars);

TVM_REGISTER_GLOBAL("relay.analysis.bound_vars").set_body([](TVMArgs args, TVMRetValue* ret) {
  ObjectRef x = args[0];
  // Dispatch on the runtime type: callers may pass an Expr or a Pattern.
  if (x.as<ExprNode>()) {
    *ret = BoundVars(Downcast<Expr>(x));
  } else {
    *ret = BoundVars(Downcast<Pattern>(x));
  }
});

TVM_REGISTER_GLOBAL("relay.analysis.all_vars").set_body_typed(AllVars);

TVM_REGISTER_GLOBAL("relay.analysis.free_type_vars").set_body([](TVMArgs args, TVMRetValue* ret) {
  ObjectRef x = args[0];
  IRModule mod = args[1];
  // Accept either a Type or an Expr as the first argument.
  if (x.as<TypeNode>()) {
    *ret = FreeTypeVars(Downcast<Type>(x), mod);
  } else {
    *ret = FreeTypeVars(Downcast<Expr>(x), mod);
  }
});

TVM_REGISTER_GLOBAL("relay.analysis.bound_type_vars").set_body([](TVMArgs args, TVMRetValue* ret) {
  ObjectRef x = args[0];
  IRModule mod = args[1];
  // Accept either a Type or an Expr as the first argument.
  if (x.as<TypeNode>()) {
    *ret = BoundTypeVars(Downcast<Type>(x), mod);
  } else {
    *ret = BoundTypeVars(Downcast<Expr>(x), mod);
  }
});

TVM_REGISTER_GLOBAL("relay.analysis.all_type_vars").set_body([](TVMArgs args, TVMRetValue* ret) {
  ObjectRef x = args[0];
  IRModule mod = args[1];
  // Accept either a Type or an Expr as the first argument.
  if (x.as<TypeNode>()) {
    *ret = AllTypeVars(Downcast<Type>(x), mod);
  } else {
    *ret = AllTypeVars(Downcast<Expr>(x), mod);
  }
});
325
326class DtypeCollector : protected ExprVisitor, protected TypeVisitor {
327 public:
328 void VisitExpr(const Expr& expr) final {
329 if (expr->checked_type_.defined()) {
330 TypeVisitor::VisitType(expr->checked_type());
331 }
332 ExprVisitor::VisitExpr(expr);
333 }
334
335 void VisitType_(const TensorTypeNode* op) final { dtypes_.insert(DLDataType2String(op->dtype)); }
336
337 Array<String> All(const Expr& expr) {
338 VisitExpr(expr);
339
340 Array<String> res;
341 for (const auto& dtype : dtypes_) {
342 res.push_back(String(dtype));
343 }
344 return res;
345 }
346
347 private:
348 std::unordered_set<std::string> dtypes_;
349};
350
351tvm::Array<String> AllDtypes(const Expr& expr) { return DtypeCollector().All(expr); }
352
353TVM_REGISTER_GLOBAL("relay.analysis.all_dtypes").set_body_typed(AllDtypes);
354
355/*!
356 * \brief Get reference counter of each internal ExprNode in body.
357 * \param body The body expression.
358 * \return The reference count mapping.
359 */
360std::unordered_map<const Object*, size_t> GetExprRefCount(const Expr& body) {
361 class ExprRefCounter : private MixedModeVisitor {
362 public:
363 std::unordered_map<const Object*, size_t> Get(const Expr& body) {
364 this->VisitExpr(body);
365 return std::move(this->visit_counter_);
366 }
367 };
368 return ExprRefCounter().Get(body);
369}
370
371template <typename T>
372bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) {
373 ICHECK_EQ(tensor->device.device_type, kDLCPU);
374 ICHECK(tensor->strides == nullptr);
375 ICHECK_EQ(tensor->byte_offset, 0);
376 const T* data = static_cast<const T*>(tensor->data);
377 int64_t num_elems = 1;
378 for (int i = 0; i < tensor->ndim; ++i) {
379 num_elems *= tensor->shape[i];
380 }
381
382 for (int64_t i = 0; i < num_elems; i++) {
383 if (*data < value) {
384 return false;
385 }
386 data++;
387 }
388 return true;
389}
390
/*!
 * \brief Check whether \p expr is a constant tensor with all elements >= 0,
 * possibly wrapped in value-preserving transform ops.
 *
 * NOTE(review): despite the name, this tests non-negativity (>= 0), not
 * strict positivity — see the IsNDArrayAllGreaterEqual calls with value 0.
 * Callers appear to rely on this; confirm before renaming or tightening.
 *
 * \param expr The expression to check.
 * \return True iff expr is a supported-dtype constant (optionally wrapped in
 *         expand_dims/reshape/transpose/squeeze/repeat) whose elements are
 *         all >= 0.
 */
bool IsAllPositiveConstant(const Expr& expr) {
  // Cache the operators that are checked recursively to reduce lookup overhead.
  static const auto& expand_dims_op = Op::Get("expand_dims");
  static const auto& reshape_op = Op::Get("reshape");
  static const auto& transpose_op = Op::Get("transpose");
  static const auto& squeeze_op = Op::Get("squeeze");
  static const auto& repeat_op = Op::Get("repeat");

  // peel through a few common transform ops.
  if (const auto* constant = expr.as<ConstantNode>()) {
    const auto& tensor = constant->data;
    const auto& dtype = tensor->dtype;
    // Only scalar (non-vectorized) lanes and a fixed set of dtypes are
    // supported; anything else conservatively fails.
    if (dtype.lanes != 1) {
      return false;
    } else if (dtype.code == kDLFloat && dtype.bits == 32) {
      return IsNDArrayAllGreaterEqual<float>(tensor, 0);
    } else if (dtype.code == kDLFloat && dtype.bits == 64) {
      return IsNDArrayAllGreaterEqual<double>(tensor, 0);
    } else if (dtype.code == kDLInt && dtype.bits == 8) {
      return IsNDArrayAllGreaterEqual<int8_t>(tensor, 0);
    } else if (dtype.code == kDLInt && dtype.bits == 32) {
      return IsNDArrayAllGreaterEqual<int32_t>(tensor, 0);
    } else if (dtype.code == kDLUInt && dtype.bits == 8) {
      return IsNDArrayAllGreaterEqual<uint8_t>(tensor, 0);
    } else if (dtype.code == kDLUInt && dtype.bits == 32) {
      return IsNDArrayAllGreaterEqual<uint32_t>(tensor, 0);
    } else {
      return false;
    }
  } else if (const auto* op = expr.as<CallNode>()) {
    // tail recursion.
    // These ops rearrange or duplicate elements without changing values,
    // so the check can recurse into their data argument.
    if (op->op == expand_dims_op || op->op == reshape_op || op->op == transpose_op ||
        op->op == squeeze_op || op->op == repeat_op) {
      return IsAllPositiveConstant(op->args[0]);
    } else {
      return false;
    }
  } else {
    return false;
  }
}
432
433Type TypeSubst(const Type& type, const TypeVar& tvar, const Type& subst) {
434 return TypeSubst(type, tvm::Map<TypeVar, Type>({{tvar, subst}}));
435}
436
437Expr TypeSubst(const Expr& expr, const TypeVar& tvar, const Type& subst) {
438 return TypeSubst(expr, tvm::Map<TypeVar, Type>({{tvar, subst}}));
439}
440
// Substitute type vars in `type` according to `subst_map` by delegating to Bind.
Type TypeSubst(const Type& type, const tvm::Map<TypeVar, Type>& subst_map) {
  return Bind(type, subst_map);
}
444
/*!
 * \brief Substitute type vars throughout an expression — in type
 * annotations, function type params, and match patterns — per \p subst_map.
 */
Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map) {
  class TypeSubstMutator : public ExprMutator, public PatternMutator {
   public:
    explicit TypeSubstMutator(const tvm::Map<TypeVar, Type>& subst_map) : subst_map_(subst_map) {}
    // Rewrite every type encountered during expression mutation.
    Type VisitType(const Type& t) final { return TypeSubst(t, subst_map_); }
    // Route pattern vars through the expression mutator so their type
    // annotations are substituted too.
    Var VisitVar(const Var& v) final { return Downcast<Var>(VisitExpr(v)); }

    Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); }

    Clause VisitClause(const Clause& c) final {
      Pattern pat = VisitPattern(c->lhs);
      return Clause(pat, VisitExpr(c->rhs));
    }

   private:
    // NOTE(review): stores a reference; safe here because the mutator does
    // not outlive the enclosing call.
    const tvm::Map<TypeVar, Type>& subst_map_;
  };
  // Sanity checks: substitution must preserve well-formedness and the
  // number of free term-level variables.
  ICHECK(WellFormed(expr));
  auto ret = TypeSubstMutator(subst_map).VisitExpr(expr);
  ICHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size());
  ICHECK(WellFormed(ret));
  return ret;
}
468
469struct IsDynamicVisitor : public TypeVisitor {
470 bool is_dyn{false};
471 void VisitType_(const TensorTypeNode* tt) {
472 for (auto dim : tt->shape) {
473 if (dim.as<tir::IntImmNode>() == nullptr) {
474 is_dyn = true;
475 break;
476 }
477 }
478 }
479};
480
481bool IsDynamic(const Type& ty) {
482 IsDynamicVisitor v;
483 v.VisitType(ty);
484 return v.is_dyn;
485}
486
487TVM_REGISTER_GLOBAL("relay.ir.IsDynamic").set_body_typed(IsDynamic);
488
/*!
 * \brief Whether the shape function of \p call depends on input data values
 * (as opposed to only input shapes), per the TShapeDataDependent op attr.
 * \param call The call to inspect; its op must be an Op node.
 * \return True iff any input is marked data-dependent.
 */
bool IsDataDependent(const CallNode* call) {
  static auto tshape_data_dependent = Op::GetAttrMap<TShapeDataDependent>("TShapeDataDependent");
  Op op = Downcast<Op>(call->op);

  // Ops without the attribute are never data dependent.
  if (!tshape_data_dependent.count(op)) {
    return false;
  }

  // strided_slice is special-cased: when begin/end/strides are all supplied
  // as attributes, the output shape no longer depends on input data.
  if (op->name == "strided_slice") {
    if (const auto* attrs = call->attrs.as<StridedSliceAttrs>()) {
      if (attrs->begin && attrs->end && attrs->strides) {
        // not data dependent if begin, end and strides exist
        return false;
      }
    }
  }

  // Data dependent if any input is marked non-zero in the attr map.
  for (auto req : tshape_data_dependent[op]) {
    if (req->value != 0) return true;
  }
  return false;
}
511} // namespace relay
512} // namespace tvm
513