1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file feature.cc |
22 | * \brief Detect features used in Expr/Module |
23 | */ |
24 | #include <tvm/ir/module.h> |
25 | #include <tvm/relay/analysis.h> |
26 | #include <tvm/relay/expr.h> |
27 | #include <tvm/relay/expr_functor.h> |
28 | #include <tvm/relay/feature.h> |
29 | |
30 | #include "../transforms/pass_utils.h" |
31 | |
32 | namespace tvm { |
33 | namespace relay { |
34 | |
35 | FeatureSet DetectFeature(const Expr& expr) { |
36 | if (!expr.defined()) { |
37 | return FeatureSet::No(); |
38 | } |
39 | struct FeatureDetector : ExprVisitor { |
40 | std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> visited_; |
41 | FeatureSet fs = FeatureSet::No(); |
42 | |
43 | void VisitExpr(const Expr& expr) final { |
44 | if (visited_.count(expr) == 0) { |
45 | visited_.insert(expr); |
46 | ExprVisitor::VisitExpr(expr); |
47 | } else { |
48 | if (!IsAtomic(expr)) { |
49 | fs += fGraph; |
50 | } |
51 | } |
52 | } |
53 | #define DETECT_CONSTRUCT(CONSTRUCT_NAME, STMT) \ |
54 | void VisitExpr_(const CONSTRUCT_NAME##Node* op) final { STMT fs += f##CONSTRUCT_NAME; } |
55 | #define DETECT_DEFAULT_CONSTRUCT(CONSTRUCT_NAME) \ |
56 | DETECT_CONSTRUCT(CONSTRUCT_NAME, { ExprVisitor::VisitExpr_(op); }) |
57 | DETECT_DEFAULT_CONSTRUCT(Var) |
58 | DETECT_DEFAULT_CONSTRUCT(GlobalVar) |
59 | DETECT_DEFAULT_CONSTRUCT(Constant) |
60 | DETECT_DEFAULT_CONSTRUCT(Tuple) |
61 | DETECT_DEFAULT_CONSTRUCT(TupleGetItem) |
62 | DETECT_CONSTRUCT(Function, { |
63 | if (!op->HasNonzeroAttr(attr::kPrimitive)) { |
64 | ExprVisitor::VisitExpr_(op); |
65 | } |
66 | }) |
67 | DETECT_DEFAULT_CONSTRUCT(Op) |
68 | DETECT_DEFAULT_CONSTRUCT(Call) |
69 | DETECT_CONSTRUCT(Let, { |
70 | for (const Var& v : FreeVars(op->value)) { |
71 | if (op->var == v) { |
72 | fs += fLetRec; |
73 | } |
74 | } |
75 | ExprVisitor::VisitExpr_(op); |
76 | }) |
77 | DETECT_DEFAULT_CONSTRUCT(If) |
78 | DETECT_DEFAULT_CONSTRUCT(RefCreate) |
79 | DETECT_DEFAULT_CONSTRUCT(RefRead) |
80 | DETECT_DEFAULT_CONSTRUCT(RefWrite) |
81 | DETECT_DEFAULT_CONSTRUCT(Constructor) |
82 | DETECT_DEFAULT_CONSTRUCT(Match) |
83 | #undef DETECT_DEFAULT_CONSTRUCT |
84 | } fd; |
85 | fd(expr); |
86 | return fd.fs; |
87 | } |
88 | |
89 | std::string FeatureSet::ToString() const { |
90 | std::string ret; |
91 | ret += "[" ; |
92 | size_t detected = 0; |
93 | #define DETECT_FEATURE(FEATURE_NAME) \ |
94 | ++detected; \ |
95 | if (bs_[FEATURE_NAME]) { \ |
96 | ret += #FEATURE_NAME; \ |
97 | ret += ", "; \ |
98 | } |
99 | DETECT_FEATURE(fVar); |
100 | DETECT_FEATURE(fGlobalVar); |
101 | DETECT_FEATURE(fConstant); |
102 | DETECT_FEATURE(fTuple); |
103 | DETECT_FEATURE(fTupleGetItem); |
104 | DETECT_FEATURE(fFunction); |
105 | DETECT_FEATURE(fOp); |
106 | DETECT_FEATURE(fCall); |
107 | DETECT_FEATURE(fLet); |
108 | DETECT_FEATURE(fIf); |
109 | DETECT_FEATURE(fRefCreate); |
110 | DETECT_FEATURE(fRefRead); |
111 | DETECT_FEATURE(fRefWrite); |
112 | DETECT_FEATURE(fConstructor); |
113 | DETECT_FEATURE(fMatch); |
114 | DETECT_FEATURE(fGraph); |
115 | DETECT_FEATURE(fLetRec); |
116 | #undef DETECT_FEATURE |
117 | ICHECK(detected == feature_count) << "some feature not printed" ; |
118 | ret += "]" ; |
119 | return ret; |
120 | } |
121 | |
122 | FeatureSet DetectFeature(const IRModule& mod) { |
123 | FeatureSet fs = FeatureSet::No(); |
124 | for (const auto& f : mod->functions) { |
125 | fs += DetectFeature(f.second); |
126 | } |
127 | return fs; |
128 | } |
129 | |
130 | Array<Integer> PyDetectFeature(const Expr& expr, const Optional<IRModule>& mod) { |
131 | FeatureSet fs = DetectFeature(expr); |
132 | if (mod.defined()) { |
133 | fs = fs + DetectFeature(mod.value()); |
134 | } |
135 | return static_cast<Array<Integer>>(fs); |
136 | } |
137 | |
138 | TVM_REGISTER_GLOBAL("relay.analysis.detect_feature" ).set_body_typed(PyDetectFeature); |
139 | |
140 | void CheckFeature(const Expr& expr, const FeatureSet& fs) { |
141 | auto dfs = DetectFeature(expr); |
142 | ICHECK(dfs.is_subset_of(fs)) << AsText(expr, false) |
143 | << "\nhas unsupported feature: " << (dfs - fs).ToString(); |
144 | } |
145 | |
146 | void CheckFeature(const IRModule& mod, const FeatureSet& fs) { |
147 | for (const auto& f : mod->functions) { |
148 | CheckFeature(f.second, fs); |
149 | } |
150 | } |
151 | |
152 | } // namespace relay |
153 | } // namespace tvm |
154 | |