1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file touch_extractor.cc |
22 | * \brief Extract feature of touch pattern of axes in lowered IR |
23 | */ |
24 | |
25 | #include "touch_extractor.h" |
26 | |
27 | #include <algorithm> |
28 | #include <cmath> |
29 | #include <set> |
30 | #include <unordered_map> |
31 | |
32 | namespace tvm { |
33 | namespace autotvm { |
34 | |
35 | int ParallelLevel(AnnotationType ann) { |
36 | switch (ann) { |
37 | case kBlockX: |
38 | case kBlockY: |
39 | case kBlockZ: |
40 | return 2; |
41 | case kThreadX: |
42 | case kThreadY: |
43 | case kThreadZ: |
44 | case kParallel: |
45 | return 1; |
46 | default: |
47 | return 0; |
48 | } |
49 | } |
50 | |
51 | // get touch pattern from index expression |
52 | class IndexParser : public ExprVisitor { |
53 | public: |
54 | void Parse(PrimExpr expr) { |
55 | pattern_map.clear(); |
56 | this->VisitExpr(expr); |
57 | } |
58 | |
59 | void VisitExpr_(const VarNode* op) final { |
60 | // TODO(lmzheng): handle more index types (multiple occurrence) |
61 | if (pattern_map.count(op) == 0) { |
62 | pattern_map[op] = TouchPattern(); |
63 | pattern_map[op].stride = next_stride_; |
64 | next_stride_ = 1; |
65 | } |
66 | } |
67 | |
68 | void VisitExpr_(const MulNode* op) final { |
69 | if (op->a.as<VarNode>()) { |
70 | if (const auto stride = op->b.as<IntImmNode>()) { |
71 | next_stride_ = stride->value; |
72 | } |
73 | } |
74 | ExprVisitor::VisitExpr_(op); |
75 | } |
76 | |
77 | std::unordered_map<const VarNode*, TouchPattern> pattern_map; |
78 | |
79 | private: |
80 | int64_t next_stride_ = 1; |
81 | }; |
82 | |
83 | // extract iter vars and their touch pattern from ir |
84 | bool TouchExtractor::(Var var, int64_t length, AnnotationType ann_type) { |
85 | // do not insert duplicated occurrences of virtual thread |
86 | if (ann_type == kVirtualThread && itervar_map.count(var) != 0) { |
87 | skip_stack_size_.push_back(itervar_stack_.size()); |
88 | return true; |
89 | } else { |
90 | itervar_stack_.push_back(var); |
91 | topdown_product_ *= length; |
92 | |
93 | if (itervar_map.count(var) != 0) { |
94 | // find two duplicated axes |
95 | // these happens when we create tvm.thread_axis("threadIdx.x") once and |
96 | // bind it twice. Here we treat them as two axes |
97 | // so we create a snapshot for the old one and freeze it |
98 | Var old = Var(var.get()->name_hint); |
99 | itervar_map.insert({old, itervar_map[var]}); |
100 | itervar_map.erase(var); |
101 | } |
102 | |
103 | itervar_map.insert( |
104 | {var, ItervarFeature(var, length, static_cast<int>(itervar_stack_.size()), ann_type, |
105 | topdown_product_, static_cast<int>(itervar_counter_++))}); |
106 | } |
107 | |
108 | return true; |
109 | } |
110 | |
111 | void TouchExtractor::() { |
112 | if (!skip_stack_size_.empty() && skip_stack_size_.back() == itervar_stack_.size()) { |
113 | skip_stack_size_.pop_back(); |
114 | return; |
115 | } |
116 | Var var = itervar_stack_.back(); |
117 | |
118 | // update count and reuse ratio for upper iter vars (includes self) |
119 | for (auto kv : itervar_map[var].touch_feature) { |
120 | if (kv.second.stride != 0) { // multiply count |
121 | for (auto stack_var : itervar_stack_) { |
122 | auto touch_pattern = itervar_map[stack_var].touch_feature.find(kv.first); |
123 | ICHECK(touch_pattern != itervar_map[stack_var].touch_feature.end()); |
124 | touch_pattern->second.count *= itervar_map[var].length; |
125 | } |
126 | } else { // multiply reuse ratio |
127 | for (auto stack_var : itervar_stack_) { |
128 | auto touch_pattern = itervar_map[stack_var].touch_feature.find(kv.first); |
129 | ICHECK(touch_pattern != itervar_map[stack_var].touch_feature.end()); |
130 | touch_pattern->second.reuse *= itervar_map[var].length; |
131 | } |
132 | } |
133 | } |
134 | itervar_stack_.pop_back(); |
135 | |
136 | int64_t length = itervar_map[var].length; |
137 | if (length != 0) topdown_product_ /= length; |
138 | int64_t bottomup_product = -1; |
139 | for (auto kv : itervar_map[var].touch_feature) { |
140 | bottomup_product = std::max(bottomup_product, kv.second.count * kv.second.reuse); |
141 | } |
142 | |
143 | itervar_map[var].bottomup_product = bottomup_product; |
144 | |
145 | // push base to upper parallel axis |
146 | int para_level = ParallelLevel(itervar_map[var].ann); |
147 | // if is the separate line of parallel level, push the base to upper parallel level |
148 | if (!itervar_stack_.empty() && |
149 | ParallelLevel(itervar_map[itervar_stack_.back()].ann) == para_level + 1) { |
150 | for (auto kv : itervar_map[var].touch_feature) { |
151 | for (auto stack_var : itervar_stack_) { |
152 | if (ParallelLevel(itervar_map[stack_var].ann) == para_level + 1) { |
153 | auto touch_pattern = itervar_map[stack_var].touch_feature.find(kv.first); |
154 | ICHECK(touch_pattern != itervar_map[stack_var].touch_feature.end()); |
155 | touch_pattern->second.thread_reuse = -kv.second.reuse; |
156 | touch_pattern->second.thread_count = -kv.second.count; |
157 | // NOTE: use minus as a flag to denote it is a base, |
158 | // indicating it is not the final value |
159 | } |
160 | } |
161 | } |
162 | } |
163 | |
164 | for (auto kv : itervar_map[var].touch_feature) { |
165 | if (kv.second.thread_count < 0) { |
166 | itervar_map[var].touch_feature[kv.first].thread_count = |
167 | kv.second.count / (-kv.second.thread_count); |
168 | itervar_map[var].touch_feature[kv.first].thread_reuse = |
169 | kv.second.reuse / (-kv.second.thread_reuse); |
170 | } |
171 | } |
172 | } |
173 | |
174 | void TouchExtractor::(Var buffer_var, PrimExpr index) { |
175 | std::string name = buffer_var.get()->name_hint; |
176 | TouchedBuffer buf = name + "_" + std::to_string(buffer_counter_[name]++); |
177 | |
178 | // extract touch pattern from index |
179 | IndexParser parser; |
180 | parser.Parse(index); |
181 | |
182 | // push up mem access info |
183 | for (auto var : itervar_stack_) { |
184 | auto x = parser.pattern_map.find(var.get()); |
185 | if (x != parser.pattern_map.end()) { |
186 | itervar_map[var].touch_feature[buf] = x->second; |
187 | } else { |
188 | itervar_map[var].touch_feature[buf] = TouchPattern(); |
189 | } |
190 | } |
191 | } |
192 | |
// No-op: the body is empty.
// NOTE(review): the member-function name after `TouchExtractor::` is absent in this copy of
// the file; presumably this is the exit counterpart of the memory-access handler above.
// Confirm against the class declaration in touch_extractor.h.
void TouchExtractor::() {}
194 | |
195 | /*! |
196 | * \brief Get axis-based feature for all axes |
197 | * \param stmt The statement to be extracted |
198 | * \param bool Whether take log for numerical feature |
199 | * \param ret_feature The buffer where the return value is stored |
200 | * |
201 | * \note The format of return value is |
202 | * (( |
203 | * ('_itervar_', var), |
204 | * ('_attr_', length, nest_level, topdown, bottomup, one_hot_annotation), |
205 | * ('_arith_', add_ct, mul_ct, div_ct), |
206 | * ('data_vec_0', stride, mod, count, reuse, thread_count, thread_reuse), |
207 | * ('conv_0', stride, mod, count, reuse, thread_count, thread_reuse), |
208 | * ), |
209 | * ( |
210 | * ('_itervar_', var2), |
211 | * ('_attr_', length, nest_level, one_hot_annotation), |
212 | * ('_arith_', add_ct, mul_ct, div_ct), |
213 | * ('kernel_vec_0', stride, mod, count, reuse, thread_count, thread_reuse), |
214 | * ('conv_1', stride, mod, count, reuse, thread_count, thread_reuse), |
215 | * )) |
216 | * |
217 | * Itervars are sorted according to their first occurrence position in IR. |
218 | * Buffers touched by an itervar are sorted by their unique names. |
219 | * |
220 | * \note If you want to flatten these features as the input of your model, |
221 | * You can use the faster one GetItervarFeatureFlatten below. |
222 | */ |
223 | void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr>>>* ret_feature) { |
224 | // extract |
225 | TouchExtractor touch_analyzer; |
226 | touch_analyzer.Analyze(stmt); |
227 | |
228 | // sort according to order |
229 | std::vector<Var> vars; |
230 | for (auto kv : touch_analyzer.itervar_map) { |
231 | vars.push_back(kv.first); |
232 | } |
233 | std::sort(vars.begin(), vars.end(), [&](const Var& lhs, const Var& rhs) -> bool { |
234 | return touch_analyzer.itervar_map[lhs].order < touch_analyzer.itervar_map[rhs].order; |
235 | }); |
236 | |
237 | // whether take log for numerical feature |
238 | std::function<double(int64_t)> trans; |
239 | if (take_log) { |
240 | trans = [](int64_t x) { |
241 | if (x < 0) return -std::log(-x + 1) / std::log(2); |
242 | x = x + 1; |
243 | return std::log(x) / std::log(2); |
244 | }; |
245 | } else { |
246 | trans = [](int64_t x) { return x; }; |
247 | } |
248 | |
249 | // serialize for front end |
250 | for (auto var : vars) { |
251 | Array<Array<PrimExpr>> feature_row; |
252 | ItervarFeature& fea = touch_analyzer.itervar_map[var]; |
253 | feature_row.push_back(Array<PrimExpr>{tvm::tir::StringImm("_itervar_" ), var}); |
254 | |
255 | Array<PrimExpr> attr{ |
256 | tvm::tir::StringImm("_attr_" ), |
257 | FloatImm(DataType::Float(32), trans(fea.length)), |
258 | IntImm(DataType::Int(32), fea.nest_level), |
259 | FloatImm(DataType::Float(32), trans(fea.topdown_product)), |
260 | FloatImm(DataType::Float(32), trans(fea.bottomup_product)), |
261 | }; |
262 | // one hot annotation |
263 | for (int i = 0; i < kNum; i++) { |
264 | attr.push_back(i == fea.ann); |
265 | } |
266 | feature_row.push_back(attr); |
267 | |
268 | // arithmetic |
269 | feature_row.push_back(Array<PrimExpr>{ |
270 | tvm::tir::StringImm("_arith_" ), |
271 | FloatImm(DataType::Float(32), trans(fea.add_ct)), |
272 | FloatImm(DataType::Float(32), trans(fea.mul_ct)), |
273 | FloatImm(DataType::Float(32), trans(fea.div_ct)), |
274 | }); |
275 | |
276 | // touch map |
277 | std::vector<TouchedBuffer> bufs; |
278 | for (auto kv : fea.touch_feature) { |
279 | bufs.push_back(kv.first); |
280 | } |
281 | std::sort(bufs.begin(), bufs.end()); |
282 | for (auto k : bufs) { |
283 | TouchPattern& v = fea.touch_feature[k]; |
284 | feature_row.push_back(Array<PrimExpr>{ |
285 | tvm::tir::StringImm(k), |
286 | FloatImm(DataType::Float(32), trans(v.stride)), |
287 | FloatImm(DataType::Float(32), trans(v.mod)), |
288 | FloatImm(DataType::Float(32), trans(v.count)), |
289 | FloatImm(DataType::Float(32), trans(v.reuse)), |
290 | FloatImm(DataType::Float(32), trans(v.thread_count)), |
291 | FloatImm(DataType::Float(32), trans(v.thread_reuse)), |
292 | }); |
293 | } |
294 | |
295 | ret_feature->push_back(feature_row); |
296 | } |
297 | } |
298 | |
299 | /*! |
300 | * \brief Get axis-based feature for all axes and flatten them into a one-dimensional vector. |
301 | * \param stmt The statement to be extracted |
302 | * \param bool Whether take log for numerical feature |
303 | * \param ret_feature The buffer where the return value is stored |
304 | * |
305 | * \note See GetItervarFeature for more details about the return value. |
306 | * This is an optimized version of GetItervarFeature + Flatten. This runs much faster. |
307 | */ |
308 | void GetItervarFeatureFlatten(Stmt stmt, bool take_log, std::vector<float>* ret_feature) { |
309 | // extract touch feature |
310 | TouchExtractor touch_analyzer; |
311 | touch_analyzer.Analyze(stmt); |
312 | |
313 | // sort according to order |
314 | std::vector<Var> vars; |
315 | for (auto kv : touch_analyzer.itervar_map) { |
316 | vars.push_back(kv.first); |
317 | } |
318 | std::sort(vars.begin(), vars.end(), [&](const Var& lhs, const Var& rhs) -> bool { |
319 | return touch_analyzer.itervar_map[lhs].order < touch_analyzer.itervar_map[rhs].order; |
320 | }); |
321 | |
322 | // whether take log for numerical feature |
323 | std::function<float(int64_t)> trans; |
324 | if (take_log) { |
325 | trans = [](int64_t x) { |
326 | if (x < 0) return -std::log(-x + 1) / std::log(2); |
327 | x = x + 1; |
328 | return std::log(x) / std::log(2); |
329 | }; |
330 | } else { |
331 | trans = [](int64_t x) { return x; }; |
332 | } |
333 | |
334 | // serialize for front end |
335 | for (auto var : vars) { |
336 | ItervarFeature& fea = touch_analyzer.itervar_map[var]; |
337 | |
338 | ret_feature->push_back(trans(fea.length)); |
339 | ret_feature->push_back(fea.nest_level); |
340 | ret_feature->push_back(trans(fea.topdown_product)); |
341 | ret_feature->push_back(trans(fea.bottomup_product)); |
342 | |
343 | // one hot annotation |
344 | for (int i = 0; i < kNum; i++) { |
345 | ret_feature->push_back(i == fea.ann); |
346 | } |
347 | |
348 | // arithmetic |
349 | ret_feature->push_back(trans(fea.add_ct)); |
350 | ret_feature->push_back(trans(fea.mul_ct)); |
351 | ret_feature->push_back(trans(fea.div_ct)); |
352 | |
353 | // touch map |
354 | std::vector<TouchedBuffer> bufs; |
355 | for (auto kv : fea.touch_feature) { |
356 | bufs.push_back(kv.first); |
357 | } |
358 | std::sort(bufs.begin(), bufs.end()); |
359 | for (auto k : bufs) { |
360 | TouchPattern& v = fea.touch_feature[k]; |
361 | ret_feature->push_back(trans(v.stride)); |
362 | ret_feature->push_back(trans(v.mod)); |
363 | ret_feature->push_back(trans(v.count)); |
364 | ret_feature->push_back(trans(v.reuse)); |
365 | ret_feature->push_back(trans(v.thread_count)); |
366 | ret_feature->push_back(trans(v.thread_reuse)); |
367 | } |
368 | } |
369 | } |
370 | |
371 | /*! |
372 | * \brief Get curve sample feature (relation feature) and flatten them into a one-dimensional |
373 | * vector. \param stmt The statement to be extracted \param sample_n The number of points used for |
374 | * sampling a curve (along one dimension) \param ret_feature The buffer where the return value is |
375 | * stored |
376 | */ |
377 | void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector<float>* ret_feature) { |
378 | // extract touch feature |
379 | TouchExtractor touch_ext; |
380 | touch_ext.Analyze(stmt); |
381 | |
382 | // sort according to order |
383 | std::vector<Var> vars; |
384 | for (auto kv : touch_ext.itervar_map) { |
385 | vars.push_back(kv.first); |
386 | } |
387 | std::sort(vars.begin(), vars.end(), [&](const Var& lhs, const Var& rhs) -> bool { |
388 | return touch_ext.itervar_map[lhs].order < touch_ext.itervar_map[rhs].order; |
389 | }); |
390 | |
391 | int max_depth = 0; |
392 | std::map<TouchedBuffer, std::vector<double>> reuse_curve; |
393 | std::map<TouchedBuffer, std::vector<double>> count_curve; |
394 | std::map<TouchedBuffer, std::vector<double>> topdown_curve; |
395 | std::map<TouchedBuffer, std::vector<double>> bottomup_curve; |
396 | std::set<TouchedBuffer> innermost_buffers; |
397 | std::set<std::string> added; |
398 | |
399 | // find maximum depth of loop nest |
400 | for (auto var : vars) { |
401 | ItervarFeature& fea = touch_ext.itervar_map[var]; |
402 | max_depth = std::max(max_depth, fea.nest_level); |
403 | } |
404 | |
405 | // mark inner most buffer |
406 | for (auto iter = vars.rbegin(); iter != vars.rend(); iter++) { |
407 | auto var = *iter; |
408 | ItervarFeature& fea = touch_ext.itervar_map[var]; |
409 | if (fea.nest_level == max_depth) { |
410 | for (auto kv : fea.touch_feature) { |
411 | // delete buffer no (e.g. 'A_0' -> 'A', 'A_1' -> 'A') |
412 | std::string raw_name = kv.first.substr(0, kv.first.rfind("_" )); |
413 | |
414 | // delete memory scope (e.g. 'A.local' -> 'A', 'A.shared' -> 'A') |
415 | size_t pos = raw_name.find("." ); |
416 | if (pos < kv.first.size()) raw_name = raw_name.substr(0, pos); |
417 | |
418 | // If there are multiple innermost buffers that are derived from a same raw buffer |
419 | // We only record the last occurrence (note the `iter` is in reverse order) |
420 | // e.g. `A.local`, `A.shared` are derived from `A`, if they all occurred at the inner most |
421 | // level, we will only record the last occurrence, |
422 | if (added.find(raw_name) == added.end()) { |
423 | innermost_buffers.insert(kv.first); |
424 | added.insert(raw_name); |
425 | } |
426 | } |
427 | } |
428 | } |
429 | |
430 | // pad the first point (zero) for all curves |
431 | for (auto buf : innermost_buffers) { |
432 | reuse_curve[buf].push_back(0); |
433 | count_curve[buf].push_back(0); |
434 | topdown_curve[buf].push_back(0); |
435 | bottomup_curve[buf].push_back(0); |
436 | } |
437 | |
438 | // extract curves |
439 | for (auto var : vars) { |
440 | ItervarFeature& fea = touch_ext.itervar_map[var]; |
441 | for (auto kv : fea.touch_feature) { |
442 | if (innermost_buffers.find(kv.first) != innermost_buffers.end()) { |
443 | reuse_curve[kv.first].emplace_back(std::log(kv.second.reuse) / std::log(2)); |
444 | count_curve[kv.first].emplace_back(std::log(kv.second.count) / std::log(2)); |
445 | topdown_curve[kv.first].emplace_back(std::log(fea.topdown_product) / std::log(2)); |
446 | bottomup_curve[kv.first].emplace_back(std::log(fea.bottomup_product) / std::log(2)); |
447 | } |
448 | } |
449 | } |
450 | |
451 | // sample relation in the curve |
452 | auto sample_curve = [&](const std::vector<double>& x, const std::vector<double>& y, |
453 | double weight) { |
454 | for (int i = 0; i < sample_n; i++) { |
455 | double xx = i * weight; |
456 | for (int j = static_cast<int>(x.size()) - 1; j >= 0; j--) { |
457 | if (xx > x[j] - 1e-6) { |
458 | ret_feature->emplace_back(y[j]); |
459 | ret_feature->emplace_back(xx - x[j]); |
460 | break; |
461 | } |
462 | } |
463 | } |
464 | }; |
465 | |
466 | // serialize to frontend |
467 | for (auto k : innermost_buffers) { |
468 | std::vector<double>& count = count_curve[k]; |
469 | std::vector<double>& reuse = reuse_curve[k]; |
470 | std::vector<double>& top_down = topdown_curve[k]; |
471 | |
472 | std::sort(count.begin(), count.end()); |
473 | std::sort(reuse.begin(), reuse.end()); |
474 | std::sort(top_down.begin(), top_down.end()); |
475 | |
476 | sample_curve(count, reuse, 1); |
477 | sample_curve(reuse, count, 1); |
478 | sample_curve(count, top_down, 1); |
479 | sample_curve(top_down, count, 1); |
480 | } |
481 | } |
482 | |
483 | // register API for front end |
484 | TVM_REGISTER_GLOBAL("autotvm.feature.GetItervarFeature" ) |
485 | .set_body([](TVMArgs args, TVMRetValue* ret) { |
486 | Stmt stmt = args[0]; |
487 | bool take_log = args[1]; |
488 | Array<Array<Array<PrimExpr>>> ret_feature; |
489 | |
490 | GetItervarFeature(stmt, take_log, &ret_feature); |
491 | |
492 | *ret = ret_feature; |
493 | }); |
494 | |
495 | TVM_REGISTER_GLOBAL("autotvm.feature.GetItervarFeatureFlatten" ) |
496 | .set_body([](TVMArgs args, TVMRetValue* ret) { |
497 | Stmt stmt = args[0]; |
498 | bool take_log = args[1]; |
499 | std::vector<float> ret_feature; |
500 | |
501 | GetItervarFeatureFlatten(stmt, take_log, &ret_feature); |
502 | |
503 | TVMByteArray arr; |
504 | arr.size = sizeof(float) * ret_feature.size(); |
505 | arr.data = reinterpret_cast<char*>(ret_feature.data()); |
506 | *ret = arr; |
507 | }); |
508 | |
509 | TVM_REGISTER_GLOBAL("autotvm.feature.GetCurveSampleFeatureFlatten" ) |
510 | .set_body([](TVMArgs args, TVMRetValue* ret) { |
511 | Stmt stmt = args[0]; |
512 | int sample_n = args[1]; |
513 | std::vector<float> ret_feature; |
514 | |
515 | GetCurveSampleFeatureFlatten(stmt, sample_n, &ret_feature); |
516 | |
517 | TVMByteArray arr; |
518 | arr.size = sizeof(float) * ret_feature.size(); |
519 | arr.data = reinterpret_cast<char*>(ret_feature.data()); |
520 | *ret = arr; |
521 | }); |
522 | |
523 | } // namespace autotvm |
524 | } // namespace tvm |
525 | |