1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file touch_extractor.cc
22 * \brief Extract feature of touch pattern of axes in lowered IR
23 */
24
25#include "touch_extractor.h"
26
27#include <algorithm>
28#include <cmath>
29#include <set>
30#include <unordered_map>
31
32namespace tvm {
33namespace autotvm {
34
35int ParallelLevel(AnnotationType ann) {
36 switch (ann) {
37 case kBlockX:
38 case kBlockY:
39 case kBlockZ:
40 return 2;
41 case kThreadX:
42 case kThreadY:
43 case kThreadZ:
44 case kParallel:
45 return 1;
46 default:
47 return 0;
48 }
49}
50
51// get touch pattern from index expression
52class IndexParser : public ExprVisitor {
53 public:
54 void Parse(PrimExpr expr) {
55 pattern_map.clear();
56 this->VisitExpr(expr);
57 }
58
59 void VisitExpr_(const VarNode* op) final {
60 // TODO(lmzheng): handle more index types (multiple occurrence)
61 if (pattern_map.count(op) == 0) {
62 pattern_map[op] = TouchPattern();
63 pattern_map[op].stride = next_stride_;
64 next_stride_ = 1;
65 }
66 }
67
68 void VisitExpr_(const MulNode* op) final {
69 if (op->a.as<VarNode>()) {
70 if (const auto stride = op->b.as<IntImmNode>()) {
71 next_stride_ = stride->value;
72 }
73 }
74 ExprVisitor::VisitExpr_(op);
75 }
76
77 std::unordered_map<const VarNode*, TouchPattern> pattern_map;
78
79 private:
80 int64_t next_stride_ = 1;
81};
82
83// extract iter vars and their touch pattern from ir
84bool TouchExtractor::EnterItervar_(Var var, int64_t length, AnnotationType ann_type) {
85 // do not insert duplicated occurrences of virtual thread
86 if (ann_type == kVirtualThread && itervar_map.count(var) != 0) {
87 skip_stack_size_.push_back(itervar_stack_.size());
88 return true;
89 } else {
90 itervar_stack_.push_back(var);
91 topdown_product_ *= length;
92
93 if (itervar_map.count(var) != 0) {
94 // find two duplicated axes
95 // these happens when we create tvm.thread_axis("threadIdx.x") once and
96 // bind it twice. Here we treat them as two axes
97 // so we create a snapshot for the old one and freeze it
98 Var old = Var(var.get()->name_hint);
99 itervar_map.insert({old, itervar_map[var]});
100 itervar_map.erase(var);
101 }
102
103 itervar_map.insert(
104 {var, ItervarFeature(var, length, static_cast<int>(itervar_stack_.size()), ann_type,
105 topdown_product_, static_cast<int>(itervar_counter_++))});
106 }
107
108 return true;
109}
110
111void TouchExtractor::ExitItervar_() {
112 if (!skip_stack_size_.empty() && skip_stack_size_.back() == itervar_stack_.size()) {
113 skip_stack_size_.pop_back();
114 return;
115 }
116 Var var = itervar_stack_.back();
117
118 // update count and reuse ratio for upper iter vars (includes self)
119 for (auto kv : itervar_map[var].touch_feature) {
120 if (kv.second.stride != 0) { // multiply count
121 for (auto stack_var : itervar_stack_) {
122 auto touch_pattern = itervar_map[stack_var].touch_feature.find(kv.first);
123 ICHECK(touch_pattern != itervar_map[stack_var].touch_feature.end());
124 touch_pattern->second.count *= itervar_map[var].length;
125 }
126 } else { // multiply reuse ratio
127 for (auto stack_var : itervar_stack_) {
128 auto touch_pattern = itervar_map[stack_var].touch_feature.find(kv.first);
129 ICHECK(touch_pattern != itervar_map[stack_var].touch_feature.end());
130 touch_pattern->second.reuse *= itervar_map[var].length;
131 }
132 }
133 }
134 itervar_stack_.pop_back();
135
136 int64_t length = itervar_map[var].length;
137 if (length != 0) topdown_product_ /= length;
138 int64_t bottomup_product = -1;
139 for (auto kv : itervar_map[var].touch_feature) {
140 bottomup_product = std::max(bottomup_product, kv.second.count * kv.second.reuse);
141 }
142
143 itervar_map[var].bottomup_product = bottomup_product;
144
145 // push base to upper parallel axis
146 int para_level = ParallelLevel(itervar_map[var].ann);
147 // if is the separate line of parallel level, push the base to upper parallel level
148 if (!itervar_stack_.empty() &&
149 ParallelLevel(itervar_map[itervar_stack_.back()].ann) == para_level + 1) {
150 for (auto kv : itervar_map[var].touch_feature) {
151 for (auto stack_var : itervar_stack_) {
152 if (ParallelLevel(itervar_map[stack_var].ann) == para_level + 1) {
153 auto touch_pattern = itervar_map[stack_var].touch_feature.find(kv.first);
154 ICHECK(touch_pattern != itervar_map[stack_var].touch_feature.end());
155 touch_pattern->second.thread_reuse = -kv.second.reuse;
156 touch_pattern->second.thread_count = -kv.second.count;
157 // NOTE: use minus as a flag to denote it is a base,
158 // indicating it is not the final value
159 }
160 }
161 }
162 }
163
164 for (auto kv : itervar_map[var].touch_feature) {
165 if (kv.second.thread_count < 0) {
166 itervar_map[var].touch_feature[kv.first].thread_count =
167 kv.second.count / (-kv.second.thread_count);
168 itervar_map[var].touch_feature[kv.first].thread_reuse =
169 kv.second.reuse / (-kv.second.thread_reuse);
170 }
171 }
172}
173
174void TouchExtractor::EnterMem_(Var buffer_var, PrimExpr index) {
175 std::string name = buffer_var.get()->name_hint;
176 TouchedBuffer buf = name + "_" + std::to_string(buffer_counter_[name]++);
177
178 // extract touch pattern from index
179 IndexParser parser;
180 parser.Parse(index);
181
182 // push up mem access info
183 for (auto var : itervar_stack_) {
184 auto x = parser.pattern_map.find(var.get());
185 if (x != parser.pattern_map.end()) {
186 itervar_map[var].touch_feature[buf] = x->second;
187 } else {
188 itervar_map[var].touch_feature[buf] = TouchPattern();
189 }
190 }
191}
192
193void TouchExtractor::ExitMem_() {}
194
195/*!
196 * \brief Get axis-based feature for all axes
197 * \param stmt The statement to be extracted
198 * \param bool Whether take log for numerical feature
199 * \param ret_feature The buffer where the return value is stored
200 *
201 * \note The format of return value is
202 * ((
203 * ('_itervar_', var),
204 * ('_attr_', length, nest_level, topdown, bottomup, one_hot_annotation),
205 * ('_arith_', add_ct, mul_ct, div_ct),
206 * ('data_vec_0', stride, mod, count, reuse, thread_count, thread_reuse),
207 * ('conv_0', stride, mod, count, reuse, thread_count, thread_reuse),
208 * ),
209 * (
210 * ('_itervar_', var2),
211 * ('_attr_', length, nest_level, one_hot_annotation),
212 * ('_arith_', add_ct, mul_ct, div_ct),
213 * ('kernel_vec_0', stride, mod, count, reuse, thread_count, thread_reuse),
214 * ('conv_1', stride, mod, count, reuse, thread_count, thread_reuse),
215 * ))
216 *
217 * Itervars are sorted according to their first occurrence position in IR.
218 * Buffers touched by an itervar are sorted by their unique names.
219 *
220 * \note If you want to flatten these features as the input of your model,
221 * You can use the faster one GetItervarFeatureFlatten below.
222 */
223void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr>>>* ret_feature) {
224 // extract
225 TouchExtractor touch_analyzer;
226 touch_analyzer.Analyze(stmt);
227
228 // sort according to order
229 std::vector<Var> vars;
230 for (auto kv : touch_analyzer.itervar_map) {
231 vars.push_back(kv.first);
232 }
233 std::sort(vars.begin(), vars.end(), [&](const Var& lhs, const Var& rhs) -> bool {
234 return touch_analyzer.itervar_map[lhs].order < touch_analyzer.itervar_map[rhs].order;
235 });
236
237 // whether take log for numerical feature
238 std::function<double(int64_t)> trans;
239 if (take_log) {
240 trans = [](int64_t x) {
241 if (x < 0) return -std::log(-x + 1) / std::log(2);
242 x = x + 1;
243 return std::log(x) / std::log(2);
244 };
245 } else {
246 trans = [](int64_t x) { return x; };
247 }
248
249 // serialize for front end
250 for (auto var : vars) {
251 Array<Array<PrimExpr>> feature_row;
252 ItervarFeature& fea = touch_analyzer.itervar_map[var];
253 feature_row.push_back(Array<PrimExpr>{tvm::tir::StringImm("_itervar_"), var});
254
255 Array<PrimExpr> attr{
256 tvm::tir::StringImm("_attr_"),
257 FloatImm(DataType::Float(32), trans(fea.length)),
258 IntImm(DataType::Int(32), fea.nest_level),
259 FloatImm(DataType::Float(32), trans(fea.topdown_product)),
260 FloatImm(DataType::Float(32), trans(fea.bottomup_product)),
261 };
262 // one hot annotation
263 for (int i = 0; i < kNum; i++) {
264 attr.push_back(i == fea.ann);
265 }
266 feature_row.push_back(attr);
267
268 // arithmetic
269 feature_row.push_back(Array<PrimExpr>{
270 tvm::tir::StringImm("_arith_"),
271 FloatImm(DataType::Float(32), trans(fea.add_ct)),
272 FloatImm(DataType::Float(32), trans(fea.mul_ct)),
273 FloatImm(DataType::Float(32), trans(fea.div_ct)),
274 });
275
276 // touch map
277 std::vector<TouchedBuffer> bufs;
278 for (auto kv : fea.touch_feature) {
279 bufs.push_back(kv.first);
280 }
281 std::sort(bufs.begin(), bufs.end());
282 for (auto k : bufs) {
283 TouchPattern& v = fea.touch_feature[k];
284 feature_row.push_back(Array<PrimExpr>{
285 tvm::tir::StringImm(k),
286 FloatImm(DataType::Float(32), trans(v.stride)),
287 FloatImm(DataType::Float(32), trans(v.mod)),
288 FloatImm(DataType::Float(32), trans(v.count)),
289 FloatImm(DataType::Float(32), trans(v.reuse)),
290 FloatImm(DataType::Float(32), trans(v.thread_count)),
291 FloatImm(DataType::Float(32), trans(v.thread_reuse)),
292 });
293 }
294
295 ret_feature->push_back(feature_row);
296 }
297}
298
299/*!
300 * \brief Get axis-based feature for all axes and flatten them into a one-dimensional vector.
301 * \param stmt The statement to be extracted
302 * \param bool Whether take log for numerical feature
303 * \param ret_feature The buffer where the return value is stored
304 *
305 * \note See GetItervarFeature for more details about the return value.
306 * This is an optimized version of GetItervarFeature + Flatten. This runs much faster.
307 */
308void GetItervarFeatureFlatten(Stmt stmt, bool take_log, std::vector<float>* ret_feature) {
309 // extract touch feature
310 TouchExtractor touch_analyzer;
311 touch_analyzer.Analyze(stmt);
312
313 // sort according to order
314 std::vector<Var> vars;
315 for (auto kv : touch_analyzer.itervar_map) {
316 vars.push_back(kv.first);
317 }
318 std::sort(vars.begin(), vars.end(), [&](const Var& lhs, const Var& rhs) -> bool {
319 return touch_analyzer.itervar_map[lhs].order < touch_analyzer.itervar_map[rhs].order;
320 });
321
322 // whether take log for numerical feature
323 std::function<float(int64_t)> trans;
324 if (take_log) {
325 trans = [](int64_t x) {
326 if (x < 0) return -std::log(-x + 1) / std::log(2);
327 x = x + 1;
328 return std::log(x) / std::log(2);
329 };
330 } else {
331 trans = [](int64_t x) { return x; };
332 }
333
334 // serialize for front end
335 for (auto var : vars) {
336 ItervarFeature& fea = touch_analyzer.itervar_map[var];
337
338 ret_feature->push_back(trans(fea.length));
339 ret_feature->push_back(fea.nest_level);
340 ret_feature->push_back(trans(fea.topdown_product));
341 ret_feature->push_back(trans(fea.bottomup_product));
342
343 // one hot annotation
344 for (int i = 0; i < kNum; i++) {
345 ret_feature->push_back(i == fea.ann);
346 }
347
348 // arithmetic
349 ret_feature->push_back(trans(fea.add_ct));
350 ret_feature->push_back(trans(fea.mul_ct));
351 ret_feature->push_back(trans(fea.div_ct));
352
353 // touch map
354 std::vector<TouchedBuffer> bufs;
355 for (auto kv : fea.touch_feature) {
356 bufs.push_back(kv.first);
357 }
358 std::sort(bufs.begin(), bufs.end());
359 for (auto k : bufs) {
360 TouchPattern& v = fea.touch_feature[k];
361 ret_feature->push_back(trans(v.stride));
362 ret_feature->push_back(trans(v.mod));
363 ret_feature->push_back(trans(v.count));
364 ret_feature->push_back(trans(v.reuse));
365 ret_feature->push_back(trans(v.thread_count));
366 ret_feature->push_back(trans(v.thread_reuse));
367 }
368 }
369}
370
371/*!
372 * \brief Get curve sample feature (relation feature) and flatten them into a one-dimensional
373 * vector. \param stmt The statement to be extracted \param sample_n The number of points used for
374 * sampling a curve (along one dimension) \param ret_feature The buffer where the return value is
375 * stored
376 */
377void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector<float>* ret_feature) {
378 // extract touch feature
379 TouchExtractor touch_ext;
380 touch_ext.Analyze(stmt);
381
382 // sort according to order
383 std::vector<Var> vars;
384 for (auto kv : touch_ext.itervar_map) {
385 vars.push_back(kv.first);
386 }
387 std::sort(vars.begin(), vars.end(), [&](const Var& lhs, const Var& rhs) -> bool {
388 return touch_ext.itervar_map[lhs].order < touch_ext.itervar_map[rhs].order;
389 });
390
391 int max_depth = 0;
392 std::map<TouchedBuffer, std::vector<double>> reuse_curve;
393 std::map<TouchedBuffer, std::vector<double>> count_curve;
394 std::map<TouchedBuffer, std::vector<double>> topdown_curve;
395 std::map<TouchedBuffer, std::vector<double>> bottomup_curve;
396 std::set<TouchedBuffer> innermost_buffers;
397 std::set<std::string> added;
398
399 // find maximum depth of loop nest
400 for (auto var : vars) {
401 ItervarFeature& fea = touch_ext.itervar_map[var];
402 max_depth = std::max(max_depth, fea.nest_level);
403 }
404
405 // mark inner most buffer
406 for (auto iter = vars.rbegin(); iter != vars.rend(); iter++) {
407 auto var = *iter;
408 ItervarFeature& fea = touch_ext.itervar_map[var];
409 if (fea.nest_level == max_depth) {
410 for (auto kv : fea.touch_feature) {
411 // delete buffer no (e.g. 'A_0' -> 'A', 'A_1' -> 'A')
412 std::string raw_name = kv.first.substr(0, kv.first.rfind("_"));
413
414 // delete memory scope (e.g. 'A.local' -> 'A', 'A.shared' -> 'A')
415 size_t pos = raw_name.find(".");
416 if (pos < kv.first.size()) raw_name = raw_name.substr(0, pos);
417
418 // If there are multiple innermost buffers that are derived from a same raw buffer
419 // We only record the last occurrence (note the `iter` is in reverse order)
420 // e.g. `A.local`, `A.shared` are derived from `A`, if they all occurred at the inner most
421 // level, we will only record the last occurrence,
422 if (added.find(raw_name) == added.end()) {
423 innermost_buffers.insert(kv.first);
424 added.insert(raw_name);
425 }
426 }
427 }
428 }
429
430 // pad the first point (zero) for all curves
431 for (auto buf : innermost_buffers) {
432 reuse_curve[buf].push_back(0);
433 count_curve[buf].push_back(0);
434 topdown_curve[buf].push_back(0);
435 bottomup_curve[buf].push_back(0);
436 }
437
438 // extract curves
439 for (auto var : vars) {
440 ItervarFeature& fea = touch_ext.itervar_map[var];
441 for (auto kv : fea.touch_feature) {
442 if (innermost_buffers.find(kv.first) != innermost_buffers.end()) {
443 reuse_curve[kv.first].emplace_back(std::log(kv.second.reuse) / std::log(2));
444 count_curve[kv.first].emplace_back(std::log(kv.second.count) / std::log(2));
445 topdown_curve[kv.first].emplace_back(std::log(fea.topdown_product) / std::log(2));
446 bottomup_curve[kv.first].emplace_back(std::log(fea.bottomup_product) / std::log(2));
447 }
448 }
449 }
450
451 // sample relation in the curve
452 auto sample_curve = [&](const std::vector<double>& x, const std::vector<double>& y,
453 double weight) {
454 for (int i = 0; i < sample_n; i++) {
455 double xx = i * weight;
456 for (int j = static_cast<int>(x.size()) - 1; j >= 0; j--) {
457 if (xx > x[j] - 1e-6) {
458 ret_feature->emplace_back(y[j]);
459 ret_feature->emplace_back(xx - x[j]);
460 break;
461 }
462 }
463 }
464 };
465
466 // serialize to frontend
467 for (auto k : innermost_buffers) {
468 std::vector<double>& count = count_curve[k];
469 std::vector<double>& reuse = reuse_curve[k];
470 std::vector<double>& top_down = topdown_curve[k];
471
472 std::sort(count.begin(), count.end());
473 std::sort(reuse.begin(), reuse.end());
474 std::sort(top_down.begin(), top_down.end());
475
476 sample_curve(count, reuse, 1);
477 sample_curve(reuse, count, 1);
478 sample_curve(count, top_down, 1);
479 sample_curve(top_down, count, 1);
480 }
481}
482
483// register API for front end
484TVM_REGISTER_GLOBAL("autotvm.feature.GetItervarFeature")
485 .set_body([](TVMArgs args, TVMRetValue* ret) {
486 Stmt stmt = args[0];
487 bool take_log = args[1];
488 Array<Array<Array<PrimExpr>>> ret_feature;
489
490 GetItervarFeature(stmt, take_log, &ret_feature);
491
492 *ret = ret_feature;
493 });
494
495TVM_REGISTER_GLOBAL("autotvm.feature.GetItervarFeatureFlatten")
496 .set_body([](TVMArgs args, TVMRetValue* ret) {
497 Stmt stmt = args[0];
498 bool take_log = args[1];
499 std::vector<float> ret_feature;
500
501 GetItervarFeatureFlatten(stmt, take_log, &ret_feature);
502
503 TVMByteArray arr;
504 arr.size = sizeof(float) * ret_feature.size();
505 arr.data = reinterpret_cast<char*>(ret_feature.data());
506 *ret = arr;
507 });
508
509TVM_REGISTER_GLOBAL("autotvm.feature.GetCurveSampleFeatureFlatten")
510 .set_body([](TVMArgs args, TVMRetValue* ret) {
511 Stmt stmt = args[0];
512 int sample_n = args[1];
513 std::vector<float> ret_feature;
514
515 GetCurveSampleFeatureFlatten(stmt, sample_n, &ret_feature);
516
517 TVMByteArray arr;
518 arr.size = sizeof(float) * ret_feature.size();
519 arr.data = reinterpret_cast<char*>(ret_feature.data());
520 *ret = arr;
521 });
522
523} // namespace autotvm
524} // namespace tvm
525