1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file feature.cc
22 * \brief Detect features used in Expr/Module
23 */
24#include <tvm/ir/module.h>
25#include <tvm/relay/analysis.h>
26#include <tvm/relay/expr.h>
27#include <tvm/relay/expr_functor.h>
28#include <tvm/relay/feature.h>
29
30#include "../transforms/pass_utils.h"
31
32namespace tvm {
33namespace relay {
34
35FeatureSet DetectFeature(const Expr& expr) {
36 if (!expr.defined()) {
37 return FeatureSet::No();
38 }
39 struct FeatureDetector : ExprVisitor {
40 std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> visited_;
41 FeatureSet fs = FeatureSet::No();
42
43 void VisitExpr(const Expr& expr) final {
44 if (visited_.count(expr) == 0) {
45 visited_.insert(expr);
46 ExprVisitor::VisitExpr(expr);
47 } else {
48 if (!IsAtomic(expr)) {
49 fs += fGraph;
50 }
51 }
52 }
53#define DETECT_CONSTRUCT(CONSTRUCT_NAME, STMT) \
54 void VisitExpr_(const CONSTRUCT_NAME##Node* op) final { STMT fs += f##CONSTRUCT_NAME; }
55#define DETECT_DEFAULT_CONSTRUCT(CONSTRUCT_NAME) \
56 DETECT_CONSTRUCT(CONSTRUCT_NAME, { ExprVisitor::VisitExpr_(op); })
57 DETECT_DEFAULT_CONSTRUCT(Var)
58 DETECT_DEFAULT_CONSTRUCT(GlobalVar)
59 DETECT_DEFAULT_CONSTRUCT(Constant)
60 DETECT_DEFAULT_CONSTRUCT(Tuple)
61 DETECT_DEFAULT_CONSTRUCT(TupleGetItem)
62 DETECT_CONSTRUCT(Function, {
63 if (!op->HasNonzeroAttr(attr::kPrimitive)) {
64 ExprVisitor::VisitExpr_(op);
65 }
66 })
67 DETECT_DEFAULT_CONSTRUCT(Op)
68 DETECT_DEFAULT_CONSTRUCT(Call)
69 DETECT_CONSTRUCT(Let, {
70 for (const Var& v : FreeVars(op->value)) {
71 if (op->var == v) {
72 fs += fLetRec;
73 }
74 }
75 ExprVisitor::VisitExpr_(op);
76 })
77 DETECT_DEFAULT_CONSTRUCT(If)
78 DETECT_DEFAULT_CONSTRUCT(RefCreate)
79 DETECT_DEFAULT_CONSTRUCT(RefRead)
80 DETECT_DEFAULT_CONSTRUCT(RefWrite)
81 DETECT_DEFAULT_CONSTRUCT(Constructor)
82 DETECT_DEFAULT_CONSTRUCT(Match)
83#undef DETECT_DEFAULT_CONSTRUCT
84 } fd;
85 fd(expr);
86 return fd.fs;
87}
88
89std::string FeatureSet::ToString() const {
90 std::string ret;
91 ret += "[";
92 size_t detected = 0;
93#define DETECT_FEATURE(FEATURE_NAME) \
94 ++detected; \
95 if (bs_[FEATURE_NAME]) { \
96 ret += #FEATURE_NAME; \
97 ret += ", "; \
98 }
99 DETECT_FEATURE(fVar);
100 DETECT_FEATURE(fGlobalVar);
101 DETECT_FEATURE(fConstant);
102 DETECT_FEATURE(fTuple);
103 DETECT_FEATURE(fTupleGetItem);
104 DETECT_FEATURE(fFunction);
105 DETECT_FEATURE(fOp);
106 DETECT_FEATURE(fCall);
107 DETECT_FEATURE(fLet);
108 DETECT_FEATURE(fIf);
109 DETECT_FEATURE(fRefCreate);
110 DETECT_FEATURE(fRefRead);
111 DETECT_FEATURE(fRefWrite);
112 DETECT_FEATURE(fConstructor);
113 DETECT_FEATURE(fMatch);
114 DETECT_FEATURE(fGraph);
115 DETECT_FEATURE(fLetRec);
116#undef DETECT_FEATURE
117 ICHECK(detected == feature_count) << "some feature not printed";
118 ret += "]";
119 return ret;
120}
121
122FeatureSet DetectFeature(const IRModule& mod) {
123 FeatureSet fs = FeatureSet::No();
124 for (const auto& f : mod->functions) {
125 fs += DetectFeature(f.second);
126 }
127 return fs;
128}
129
130Array<Integer> PyDetectFeature(const Expr& expr, const Optional<IRModule>& mod) {
131 FeatureSet fs = DetectFeature(expr);
132 if (mod.defined()) {
133 fs = fs + DetectFeature(mod.value());
134 }
135 return static_cast<Array<Integer>>(fs);
136}
137
138TVM_REGISTER_GLOBAL("relay.analysis.detect_feature").set_body_typed(PyDetectFeature);
139
140void CheckFeature(const Expr& expr, const FeatureSet& fs) {
141 auto dfs = DetectFeature(expr);
142 ICHECK(dfs.is_subset_of(fs)) << AsText(expr, false)
143 << "\nhas unsupported feature: " << (dfs - fs).ToString();
144}
145
146void CheckFeature(const IRModule& mod, const FeatureSet& fs) {
147 for (const auto& f : mod->functions) {
148 CheckFeature(f.second, fs);
149 }
150}
151
152} // namespace relay
153} // namespace tvm
154