1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>

#include <cstdint>
#include <cstring>
#include <sstream>
#include <string>
#include <unordered_map>
21 | |
22 | namespace tvm { |
23 | namespace tir { |
24 | |
25 | int32_t DataType2Int(const tvm::DataType& dtype) { |
26 | static_assert(sizeof(DLDataType) == sizeof(int32_t), "Incorrect size of DLDataType" ); |
27 | union { |
28 | DLDataType src; |
29 | int32_t dst; |
30 | } converter; |
31 | converter.src.code = dtype.code(); |
32 | converter.src.bits = dtype.bits(); |
33 | converter.src.lanes = dtype.lanes(); |
34 | return converter.dst; |
35 | } |
36 | |
37 | String Int2DataTypeStr(int32_t dtype) { |
38 | union { |
39 | DLDataType dst; |
40 | int32_t src; |
41 | } converter; |
42 | converter.src = dtype; |
43 | static std::string type_code_tab[] = {"int" , "uint" , "float" , "handle" , "bfloat" }; |
44 | std::ostringstream os; |
45 | os << type_code_tab[converter.dst.code]; |
46 | os << static_cast<int>(converter.dst.bits); |
47 | if (converter.dst.lanes != 1) { |
48 | os << "x" << static_cast<int>(converter.dst.lanes); |
49 | } |
50 | return os.str(); |
51 | } |
52 | |
53 | struct TResult { |
54 | TResult() = default; |
55 | |
56 | void Add(const tvm::DataType& dtype) { data_[DataType2Int(dtype)] += 1; } |
57 | |
58 | TResult operator+=(const TResult& rhs) { |
59 | for (const auto& kv : rhs.data_) { |
60 | data_[kv.first] += kv.second; |
61 | } |
62 | return *this; |
63 | } |
64 | |
65 | TResult operator*=(int64_t rhs) { |
66 | for (auto& kv : data_) { |
67 | kv.second *= rhs; |
68 | } |
69 | return *this; |
70 | } |
71 | |
72 | TResult MaxWith(const TResult& rhs) { |
73 | for (const auto& kv : rhs.data_) { |
74 | double& v = data_[kv.first]; |
75 | if (v < kv.second) { |
76 | v = kv.second; |
77 | } |
78 | } |
79 | return *this; |
80 | } |
81 | |
82 | std::unordered_map<int32_t, double> data_; |
83 | }; |
84 | |
85 | class FlopEstimator : private ExprFunctor<TResult(const PrimExpr& n)>, |
86 | private StmtFunctor<TResult(const Stmt& n)> { |
87 | public: |
88 | TResult VisitExpr(const PrimExpr& expr) override { return ExprFunctor::VisitExpr(expr); } |
89 | TResult VisitStmt(const Stmt& stmt) override { return StmtFunctor::VisitStmt(stmt); } |
90 | |
91 | #define TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(Node) \ |
92 | TResult VisitExpr_(const Node* op) final { \ |
93 | TResult result = VisitExpr(op->a); \ |
94 | result += VisitExpr(op->b); \ |
95 | result.Add(op->dtype); \ |
96 | return result; \ |
97 | } |
98 | TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(AddNode); |
99 | TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(SubNode); |
100 | TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(MulNode); |
101 | TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(DivNode); |
102 | TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(ModNode); |
103 | TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(FloorDivNode); |
104 | TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(FloorModNode); |
105 | TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(MinNode); |
106 | TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(MaxNode); |
107 | #undef TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY |
108 | TResult VisitExpr_(const EQNode* op) override { return TResult(); } |
109 | TResult VisitExpr_(const NENode* op) override { return TResult(); } |
110 | TResult VisitExpr_(const LTNode* op) override { return TResult(); } |
111 | TResult VisitExpr_(const LENode* op) override { return TResult(); } |
112 | TResult VisitExpr_(const GTNode* op) override { return TResult(); } |
113 | TResult VisitExpr_(const GENode* op) override { return TResult(); } |
114 | |
115 | TResult VisitExpr_(const NotNode* op) override { return VisitExpr(op->a); } |
116 | TResult VisitExpr_(const AndNode* op) final { |
117 | TResult result = VisitExpr(op->a); |
118 | result += VisitExpr(op->b); |
119 | return result; |
120 | } |
121 | TResult VisitExpr_(const OrNode* op) final { |
122 | TResult result = VisitExpr(op->a); |
123 | result += VisitExpr(op->b); |
124 | return result; |
125 | } |
126 | |
127 | TResult VisitExpr_(const BufferLoadNode* op) override { return TResult(); } |
128 | TResult VisitStmt_(const BufferStoreNode* store) override { return VisitExpr(store->value); } |
129 | TResult VisitStmt_(const BlockRealizeNode* block) override { |
130 | return VisitStmt(block->block->body); |
131 | } |
132 | TResult VisitStmt_(const BlockNode* block) override { |
133 | TResult result; |
134 | if (block->init.defined()) { |
135 | result += VisitStmt(block->init.value()); |
136 | } |
137 | result += VisitStmt(block->body); |
138 | return result; |
139 | } |
140 | TResult VisitStmt_(const ForNode* loop) override { |
141 | TResult result = VisitStmt(loop->body); |
142 | const auto* int_imm = loop->extent.as<IntImmNode>(); |
143 | ICHECK(int_imm) << "TypeError: Expect the extent of a loop to be IntImm, but gets: " |
144 | << loop->extent->GetTypeKey(); |
145 | result *= int_imm->value; |
146 | return result; |
147 | } |
148 | |
149 | TResult VisitStmt_(const IfThenElseNode* branch) override { |
150 | TResult cond = VisitExpr(branch->condition); |
151 | if (branch->else_case) { |
152 | cond += VisitStmt(branch->then_case).MaxWith(VisitStmt(branch->else_case.value())); |
153 | } else { |
154 | cond += VisitStmt(branch->then_case); |
155 | } |
156 | return cond; |
157 | } |
158 | |
159 | TResult VisitStmt_(const LetStmtNode* let) override { |
160 | TResult value = VisitExpr(let->value); |
161 | value += VisitStmt(let->body); |
162 | return value; |
163 | } |
164 | |
165 | TResult VisitExpr_(const SelectNode* op) override { |
166 | TResult cond = VisitExpr(op->condition); |
167 | cond += VisitExpr(op->true_value).MaxWith(VisitExpr(op->false_value)); |
168 | return cond; |
169 | } |
170 | |
171 | TResult VisitExpr_(const VarNode* op) override { return TResult(); } |
172 | TResult VisitExpr_(const SizeVarNode* op) override { return TResult(); } |
173 | TResult VisitExpr_(const IntImmNode* op) override { return TResult(); } |
174 | TResult VisitExpr_(const FloatImmNode* op) override { return TResult(); } |
175 | TResult VisitExpr_(const CastNode* op) override { return VisitExpr(op->value); } |
176 | TResult VisitStmt_(const AllocateConstNode* op) override { return VisitStmt(op->body); } |
177 | |
178 | TResult VisitStmt_(const SeqStmtNode* seq) override { |
179 | TResult result; |
180 | for (const Stmt& stmt : seq->seq) { |
181 | result += VisitStmt(stmt); |
182 | } |
183 | return result; |
184 | } |
185 | |
186 | TResult VisitExpr_(const CallNode* op) override { |
187 | TResult ret; |
188 | for (const auto& x : op->args) { |
189 | ret += VisitExpr(x); |
190 | } |
191 | return ret; |
192 | } |
193 | }; |
194 | |
195 | double PostprocessResults(const TResult& result) { |
196 | double cnt = 0.0; |
197 | for (const auto& kv : result.data_) { |
198 | cnt += kv.second; |
199 | } |
200 | return cnt; |
201 | } |
202 | |
203 | double EstimateTIRFlops(const Stmt& stmt) { |
204 | FlopEstimator counter; |
205 | return PostprocessResults(counter.VisitStmt(stmt)); |
206 | } |
207 | |
208 | double EstimateTIRFlops(const IRModule& mod) { |
209 | FlopEstimator counter; |
210 | TResult result; |
211 | VisitPrimFuncs(mod, [&result, &counter](const PrimFuncNode* f) { |
212 | result += counter.VisitStmt(f->body); // |
213 | }); |
214 | return PostprocessResults(result); |
215 | } |
216 | |
217 | TVM_REGISTER_GLOBAL("tir.analysis.EstimateTIRFlops" ).set_body_typed([](ObjectRef obj) -> double { |
218 | if (const auto* mod = obj.as<IRModuleNode>()) { |
219 | return EstimateTIRFlops(GetRef<IRModule>(mod)); |
220 | } else if (const auto* stmt = obj.as<StmtNode>()) { |
221 | return EstimateTIRFlops(GetRef<Stmt>(stmt)); |
222 | } else { |
223 | LOG(FATAL) << "TypeError: Expect the input to be either IRModule or Stmt, but gets: " |
224 | << obj->GetTypeKey(); |
225 | throw; |
226 | } |
227 | }); |
228 | |
229 | } // namespace tir |
230 | } // namespace tvm |
231 | |