1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>

#include <cstring>
#include <sstream>
#include <string>
21
22namespace tvm {
23namespace tir {
24
25int32_t DataType2Int(const tvm::DataType& dtype) {
26 static_assert(sizeof(DLDataType) == sizeof(int32_t), "Incorrect size of DLDataType");
27 union {
28 DLDataType src;
29 int32_t dst;
30 } converter;
31 converter.src.code = dtype.code();
32 converter.src.bits = dtype.bits();
33 converter.src.lanes = dtype.lanes();
34 return converter.dst;
35}
36
37String Int2DataTypeStr(int32_t dtype) {
38 union {
39 DLDataType dst;
40 int32_t src;
41 } converter;
42 converter.src = dtype;
43 static std::string type_code_tab[] = {"int", "uint", "float", "handle", "bfloat"};
44 std::ostringstream os;
45 os << type_code_tab[converter.dst.code];
46 os << static_cast<int>(converter.dst.bits);
47 if (converter.dst.lanes != 1) {
48 os << "x" << static_cast<int>(converter.dst.lanes);
49 }
50 return os.str();
51}
52
53struct TResult {
54 TResult() = default;
55
56 void Add(const tvm::DataType& dtype) { data_[DataType2Int(dtype)] += 1; }
57
58 TResult operator+=(const TResult& rhs) {
59 for (const auto& kv : rhs.data_) {
60 data_[kv.first] += kv.second;
61 }
62 return *this;
63 }
64
65 TResult operator*=(int64_t rhs) {
66 for (auto& kv : data_) {
67 kv.second *= rhs;
68 }
69 return *this;
70 }
71
72 TResult MaxWith(const TResult& rhs) {
73 for (const auto& kv : rhs.data_) {
74 double& v = data_[kv.first];
75 if (v < kv.second) {
76 v = kv.second;
77 }
78 }
79 return *this;
80 }
81
82 std::unordered_map<int32_t, double> data_;
83};
84
85class FlopEstimator : private ExprFunctor<TResult(const PrimExpr& n)>,
86 private StmtFunctor<TResult(const Stmt& n)> {
87 public:
88 TResult VisitExpr(const PrimExpr& expr) override { return ExprFunctor::VisitExpr(expr); }
89 TResult VisitStmt(const Stmt& stmt) override { return StmtFunctor::VisitStmt(stmt); }
90
91#define TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(Node) \
92 TResult VisitExpr_(const Node* op) final { \
93 TResult result = VisitExpr(op->a); \
94 result += VisitExpr(op->b); \
95 result.Add(op->dtype); \
96 return result; \
97 }
98 TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(AddNode);
99 TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(SubNode);
100 TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(MulNode);
101 TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(DivNode);
102 TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(ModNode);
103 TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(FloorDivNode);
104 TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(FloorModNode);
105 TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(MinNode);
106 TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(MaxNode);
107#undef TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY
108 TResult VisitExpr_(const EQNode* op) override { return TResult(); }
109 TResult VisitExpr_(const NENode* op) override { return TResult(); }
110 TResult VisitExpr_(const LTNode* op) override { return TResult(); }
111 TResult VisitExpr_(const LENode* op) override { return TResult(); }
112 TResult VisitExpr_(const GTNode* op) override { return TResult(); }
113 TResult VisitExpr_(const GENode* op) override { return TResult(); }
114
115 TResult VisitExpr_(const NotNode* op) override { return VisitExpr(op->a); }
116 TResult VisitExpr_(const AndNode* op) final {
117 TResult result = VisitExpr(op->a);
118 result += VisitExpr(op->b);
119 return result;
120 }
121 TResult VisitExpr_(const OrNode* op) final {
122 TResult result = VisitExpr(op->a);
123 result += VisitExpr(op->b);
124 return result;
125 }
126
127 TResult VisitExpr_(const BufferLoadNode* op) override { return TResult(); }
128 TResult VisitStmt_(const BufferStoreNode* store) override { return VisitExpr(store->value); }
129 TResult VisitStmt_(const BlockRealizeNode* block) override {
130 return VisitStmt(block->block->body);
131 }
132 TResult VisitStmt_(const BlockNode* block) override {
133 TResult result;
134 if (block->init.defined()) {
135 result += VisitStmt(block->init.value());
136 }
137 result += VisitStmt(block->body);
138 return result;
139 }
140 TResult VisitStmt_(const ForNode* loop) override {
141 TResult result = VisitStmt(loop->body);
142 const auto* int_imm = loop->extent.as<IntImmNode>();
143 ICHECK(int_imm) << "TypeError: Expect the extent of a loop to be IntImm, but gets: "
144 << loop->extent->GetTypeKey();
145 result *= int_imm->value;
146 return result;
147 }
148
149 TResult VisitStmt_(const IfThenElseNode* branch) override {
150 TResult cond = VisitExpr(branch->condition);
151 if (branch->else_case) {
152 cond += VisitStmt(branch->then_case).MaxWith(VisitStmt(branch->else_case.value()));
153 } else {
154 cond += VisitStmt(branch->then_case);
155 }
156 return cond;
157 }
158
159 TResult VisitStmt_(const LetStmtNode* let) override {
160 TResult value = VisitExpr(let->value);
161 value += VisitStmt(let->body);
162 return value;
163 }
164
165 TResult VisitExpr_(const SelectNode* op) override {
166 TResult cond = VisitExpr(op->condition);
167 cond += VisitExpr(op->true_value).MaxWith(VisitExpr(op->false_value));
168 return cond;
169 }
170
171 TResult VisitExpr_(const VarNode* op) override { return TResult(); }
172 TResult VisitExpr_(const SizeVarNode* op) override { return TResult(); }
173 TResult VisitExpr_(const IntImmNode* op) override { return TResult(); }
174 TResult VisitExpr_(const FloatImmNode* op) override { return TResult(); }
175 TResult VisitExpr_(const CastNode* op) override { return VisitExpr(op->value); }
176 TResult VisitStmt_(const AllocateConstNode* op) override { return VisitStmt(op->body); }
177
178 TResult VisitStmt_(const SeqStmtNode* seq) override {
179 TResult result;
180 for (const Stmt& stmt : seq->seq) {
181 result += VisitStmt(stmt);
182 }
183 return result;
184 }
185
186 TResult VisitExpr_(const CallNode* op) override {
187 TResult ret;
188 for (const auto& x : op->args) {
189 ret += VisitExpr(x);
190 }
191 return ret;
192 }
193};
194
195double PostprocessResults(const TResult& result) {
196 double cnt = 0.0;
197 for (const auto& kv : result.data_) {
198 cnt += kv.second;
199 }
200 return cnt;
201}
202
203double EstimateTIRFlops(const Stmt& stmt) {
204 FlopEstimator counter;
205 return PostprocessResults(counter.VisitStmt(stmt));
206}
207
208double EstimateTIRFlops(const IRModule& mod) {
209 FlopEstimator counter;
210 TResult result;
211 VisitPrimFuncs(mod, [&result, &counter](const PrimFuncNode* f) {
212 result += counter.VisitStmt(f->body); //
213 });
214 return PostprocessResults(result);
215}
216
217TVM_REGISTER_GLOBAL("tir.analysis.EstimateTIRFlops").set_body_typed([](ObjectRef obj) -> double {
218 if (const auto* mod = obj.as<IRModuleNode>()) {
219 return EstimateTIRFlops(GetRef<IRModule>(mod));
220 } else if (const auto* stmt = obj.as<StmtNode>()) {
221 return EstimateTIRFlops(GetRef<Stmt>(stmt));
222 } else {
223 LOG(FATAL) << "TypeError: Expect the input to be either IRModule or Stmt, but gets: "
224 << obj->GetTypeKey();
225 throw;
226 }
227});
228
229} // namespace tir
230} // namespace tvm
231