/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file auto_scheduler/feature.cc
 * \brief Feature extraction for the cost model
 */

#include <tvm/arith/analyzer.h>
#include <tvm/auto_scheduler/feature.h>
#include <tvm/auto_scheduler/measure.h>
#include <tvm/auto_scheduler/measure_record.h>
#include <tvm/driver/driver_api.h>
#include <tvm/ir/global_var_supply.h>
#include <tvm/runtime/registry.h>
#include <tvm/support/parallel_for.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <numeric>
#include <unordered_map>
#include <vector>

#include "search_policy/utils.h"
#include "utils.h"

namespace tvm {
namespace auto_scheduler {

using namespace tvm::tir;
using arith::Analyzer;
using arith::ConstIntBound;

template <class T>
using BufferMap = std::unordered_map<Var, T, ObjectHash, ObjectEqual>;

// The number of samples to extract for arithmetic intensity curves
static const int ARITH_INTENSITY_CURVE_SAMPLE_N = 10;

// Annotation position encoding
enum class AnnotationPosType : int {
  kPosNone = 0,           // Does not have this kind of annotation
  kPosInnerSpatial = 1,   // The annotated iterator is the innermost spatial iterator
  kPosMiddleSpatial = 2,  // The annotated iterator is a middle spatial iterator
  kPosOuterSpatial = 3,   // The annotated iterator is the outermost spatial iterator
  kPosInnerReduce = 4,    // The annotated iterator is the innermost reduce iterator
  kPosMiddleReduce = 5,   // The annotated iterator is a middle reduce iterator
  kPosOuterReduce = 6,    // The annotated iterator is the outermost reduce iterator
  kPosMixed = 7           // The annotated iterator is a mixed space and reduce iterator
};

// Buffer access type
enum class BufferAccessType : int { kRead = 0, kWrite = 1, kReadWrite = 2, kUnknownRW = 3 };

// Accesses to a buffer
struct BufferAccess {
  // The type of the access (read, write, ...)
  BufferAccessType acc_type{BufferAccessType::kUnknownRW};
  // Use a two-dimensional array to store multiple multi-dimensional accesses.
  // The innermost vector stores the multi-dimensional indices of one access.
  std::vector<std::vector<PrimExpr>> indices;
};

// Data reuse type
enum class ReuseType : int { kLoopMultipleRead = 0, kSerialMultipleReadWrite = 1, kNoReuse = 2 };

// Feature for an access of a buffer
struct BufferAccessFeature {
  std::string buffer_name;        // The name of the buffer
  BufferAccessType acc_type;      // The type of the access
  float bytes;                    // The touched memory in bytes
  float unique_bytes;             // The touched unique memory in bytes
  float lines;                    // The number of touched cache lines
  float unique_lines;             // The number of touched unique cache lines
  ReuseType reuse_type;           // The type of data reuse
  float reuse_dis_iter;           // The reuse distance in number of iterations
  float reuse_dis_bytes;          // The reuse distance in total touched bytes
  float reuse_ct;                 // The number of reuses
  float bytes_d_reuse_ct;         // bytes / reuse_ct
  float unique_bytes_d_reuse_ct;  // unique_bytes / reuse_ct
  float lines_d_reuse_ct;         // lines / reuse_ct
  float unique_lines_d_reuse_ct;  // unique_lines / reuse_ct
  float stride;                   // The stride of the access
};

// Feature set of a BufferStore statement
struct FeatureSet {
  // Group 1: Computation related features
  float float_mad;                  // The number of float MAD (multiply-add) ops
  float float_addsub;               // The number of float add and sub ops
  float float_mul;                  // The number of float multiply ops
  float float_divmod;               // The number of float div and mod ops
  float float_cmp;                  // The number of float comparison ops
  float float_math_func;            // The number of float math func calls
  float float_other_func;           // The number of other float func calls
  float int_mad;                    // The number of integer MAD (multiply-add) ops
  float int_addsub;                 // The number of integer add and sub ops
  float int_mul;                    // The number of integer multiply ops
  float int_divmod;                 // The number of integer div and mod ops
  float int_cmp;                    // The number of integer comparison ops
  float int_math_func;              // The number of integer math func calls
  float int_other_func;             // The number of other integer func calls
  float bool_op;                    // The number of bool ops
  float select_op;                  // The number of select ops
  float vec_num;                    // The number of vectorized iterators
  float vec_prod;                   // The product of the lengths of vectorized iterators
  float vec_len;                    // The length of the innermost vectorized iterator
  AnnotationPosType vec_type;       // The type of vectorization position
  float unroll_num;                 // The number of unrolled iterators
  float unroll_prod;                // The product of the lengths of unrolled iterators
  float unroll_len;                 // The length of the innermost unrolled iterator
  AnnotationPosType unroll_type;    // The type of unroll position
  float parallel_num;               // The number of parallelized iterators
  float parallel_prod;              // The product of the lengths of parallelized iterators
  float parallel_len;               // The length of the innermost parallelized iterator
  AnnotationPosType parallel_type;  // The type of parallel position
  float is_gpu;                     // Whether it is a GPU task
  float blockIdx_x_len;             // The length of blockIdx.x
  float blockIdx_y_len;             // The length of blockIdx.y
  float blockIdx_z_len;             // The length of blockIdx.z
  float threadIdx_x_len;            // The length of threadIdx.x
  float threadIdx_y_len;            // The length of threadIdx.y
  float threadIdx_z_len;            // The length of threadIdx.z
  float vthread_len;                // The length of virtual thread

  // Group 2: Buffer access related features (per buffer)
  std::vector<BufferAccessFeature> access_feas;

  // Group 3: Arithmetic intensity related features
  float arith_intensity_curve[ARITH_INTENSITY_CURVE_SAMPLE_N];  // points sampled from the
                                                                // arithmetic intensity curve

  // Group 4: Allocation related features
  float alloc_size;        // The size of the allocated buffer in bytes
  float alloc_outer_prod;  // The product of lengths of loops outside the scope of the allocation
  float alloc_inner_prod;  // The product of lengths of loops inside the scope of the allocation
  float alloc_prod;        // alloc_outer_prod * alloc_inner_prod

  // Group 5: Outer scope related features
  float outer_prod;            // The product of lengths of outer loops
  float num_loops;             // The number of outer loops
  float auto_unroll_max_step;  // The value of pragma "auto_unroll_max_step"
};

// Return whether a var is in an expr
bool VarInExpr(const Var& var, const PrimExpr& expr) {
  bool find = false;

  PostOrderVisit(expr, [&find, &var](const ObjectRef& node) {
    if (find) {
      return;
    }

    if (const VarNode* op = node.as<VarNode>()) {
      if (op == var.get()) {
        find = true;
      }
    }
  });

  return find;
}
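
// Illustrative usage (with assumed loop vars i and j, not taken from this file):
//   VarInExpr(i, i * 4 + j)  -> true
//   VarInExpr(i, j + 1)      -> false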
186
187// Get position encoding for annotation
188AnnotationPosType GetAnnotationPosEncoding(const Var& var, const Array<PrimExpr>& spatial_args,
189 const Array<IterVar>& axis,
190 const Array<IterVar>& reduce_axis) {
191 // Try to match spatial args first
192 size_t find_i = 0;
193 size_t find_ct = 0;
194 for (size_t i = 0; i < spatial_args.size(); ++i) {
195 if (VarInExpr(var, spatial_args[i])) {
196 find_i = i;
197 find_ct += 1;
198 }
199 }
200
201 if (find_ct == 0) {
202 // If it is not found in spacial args, then it is a reduce iterator.
203 // Use name to match
204 const std::string& var_name = var->name_hint;
205 for (size_t i = 0; i < reduce_axis.size(); ++i) {
206 if (var_name.find(reduce_axis[i]->var->name_hint) != std::string::npos) {
207 find_i = i;
208 find_ct++;
209 }
210 }
211 if (find_ct >= 1) {
212 if (find_i == 0) {
213 return AnnotationPosType::kPosInnerReduce;
214 } else if (find_i == reduce_axis.size() - 1) {
215 return AnnotationPosType::kPosOuterReduce;
216 } else {
217 return AnnotationPosType::kPosMiddleReduce;
218 }
219 } else {
220 // If the axis is not found in both spatial args and reduce axis,
221 // then this stage must compute_at somewhere under this axis and this axis is simplified out
222 // We assume it is an outer spatial
223 return AnnotationPosType::kPosOuterSpatial;
224 }
225 } else if (find_ct == 1) {
226 if (find_i == spatial_args.size() - 1) {
227 return AnnotationPosType::kPosInnerSpatial;
228 } else if (find_i == 0) {
229 return AnnotationPosType::kPosOuterSpatial;
230 } else {
231 return AnnotationPosType::kPosMiddleSpatial;
232 }
233 } else {
234 return AnnotationPosType::kPosMixed;
235 }
236}
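
// Illustrative example (hypothetical compute, not taken from this file):
// for an output C(i, j) with spatial_args = (i, j), an annotation on j
// matches the last spatial arg and is encoded as kPosInnerSpatial, while an
// annotation on i matches the first one and is encoded as kPosOuterSpatial.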

// Return the maximum extent of a for loop
int64_t GetLoopExtent(const ForNode* node, const Analyzer& ana) {
  int64_t bound = ana.const_int_bound(node->extent)->max_value;
  if (bound == ConstIntBound::kPosInf) {
    return 1;  // Analyzer could not determine a valid bound, use 1 instead.
  } else {
    return bound;
  }
}

// Count math ops in an expr
class MathOpCounter : public StmtExprVisitor {
 public:
#define VisitBinary(Type, float_ct, int_ct)                        \
  void VisitExpr_(const Type* op) final {                          \
    if (op->a.dtype().is_float() || op->a.dtype().is_bfloat16()) { \
      float_ct += op->a.dtype().lanes();                           \
    } else {                                                       \
      int_ct += op->a.dtype().lanes();                             \
    }                                                              \
    StmtExprVisitor::VisitExpr_(op);                               \
  }

  VisitBinary(AddNode, float_addsub, int_addsub);
  VisitBinary(SubNode, float_addsub, int_addsub);
  VisitBinary(MulNode, float_mul, int_mul);
  VisitBinary(DivNode, float_divmod, int_divmod);
  VisitBinary(ModNode, float_divmod, int_divmod);
  VisitBinary(FloorDivNode, float_divmod, int_divmod);
  VisitBinary(FloorModNode, float_divmod, int_divmod);
  VisitBinary(MaxNode, float_cmp, int_cmp);
  VisitBinary(MinNode, float_cmp, int_cmp);
  VisitBinary(EQNode, float_cmp, int_cmp);
  VisitBinary(NENode, float_cmp, int_cmp);
  VisitBinary(LTNode, float_cmp, int_cmp);
  VisitBinary(LENode, float_cmp, int_cmp);
  VisitBinary(GTNode, float_cmp, int_cmp);
  VisitBinary(GENode, float_cmp, int_cmp);

#undef VisitBinary

  void VisitExpr_(const AndNode* op) final {
    bool_op++;
    StmtExprVisitor::VisitExpr_(op);
  }
  void VisitExpr_(const OrNode* op) final {
    bool_op++;
    StmtExprVisitor::VisitExpr_(op);
  }
  void VisitExpr_(const NotNode* op) final {
    bool_op++;
    StmtExprVisitor::VisitExpr_(op);
  }
  void VisitExpr_(const SelectNode* op) final {
    select_op++;
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitExpr_(const CallNode* op) final {
    auto* pop = op->op.as<OpNode>();
    ICHECK(pop != nullptr);
    auto effect_kind = op_call_effect_[GetRef<Op>(pop)];
    bool is_pure =
        effect_kind == CallEffectKind::kPure || effect_kind == CallEffectKind::kExprAnnotation;

    if (is_pure) {
      if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
        float_math_func++;
      } else {
        int_math_func++;
      }
    } else {
      if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
        float_other_func++;
      } else {
        int_other_func++;
      }
    }
    StmtExprVisitor::VisitExpr_(op);
  }

  // TODO(merrymercy): Detect MAD (multiply-add)
  size_t float_mad{0};         // The number of float MAD (multiply-add) ops
  size_t float_addsub{0};      // The number of float add and sub ops
  size_t float_mul{0};         // The number of float multiply ops
  size_t float_divmod{0};      // The number of float div and mod ops
  size_t float_cmp{0};         // The number of float comparison ops
  size_t float_math_func{0};   // The number of float math func calls
  size_t float_other_func{0};  // The number of other float func calls
  size_t int_mad{0};           // The number of integer MAD (multiply-add) ops
  size_t int_addsub{0};        // The number of integer add and sub ops
  size_t int_mul{0};           // The number of integer multiply ops
  size_t int_divmod{0};        // The number of integer div and mod ops
  size_t int_cmp{0};           // The number of integer comparison ops
  size_t int_math_func{0};     // The number of integer math func calls
  size_t int_other_func{0};    // The number of other integer func calls
  size_t bool_op{0};           // The number of bool ops
  size_t select_op{0};         // The number of select ops

  OpAttrMap<TCallEffectKind> op_call_effect_ = Op::GetAttrMap<TCallEffectKind>("TCallEffectKind");
};
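
// Illustrative usage (hypothetical float32 expression, not taken from this file):
//   MathOpCounter counter;
//   counter(A[i] * x + y);
// would give counter.float_mul == 1 and counter.float_addsub == 1; each
// count is scaled by the number of vector lanes of the operand dtype.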

// Extract all buffer accesses in an expr
class BufferAccessExtractor : public StmtExprVisitor {
 public:
  void ExtractReads(const PrimExpr& expr) { this->VisitExpr(expr); }

  void InsertAccess(const Var& buf, BufferAccessType acc_type, const Array<PrimExpr>& indices) {
    BufferAccess& acc = buf_accesses[buf];
    acc.acc_type = acc_type;
    acc.indices.push_back(std::vector<PrimExpr>(indices.begin(), indices.end()));
  }

  void VisitExpr_(const BufferLoadNode* op) final {
    AddAccess(op->buffer->data, op->indices);
    StmtExprVisitor::VisitExpr_(op);
  }

  void AddAccess(const Var& buffer, const Array<PrimExpr>& indices) {
    BufferAccess& acc = buf_accesses[buffer];
    switch (acc.acc_type) {
      case BufferAccessType::kRead:
        break;
      case BufferAccessType::kWrite:
        acc.acc_type = BufferAccessType::kReadWrite;
        break;
      case BufferAccessType::kReadWrite:
        break;
      case BufferAccessType::kUnknownRW:
      default:
        acc.acc_type = BufferAccessType::kRead;
        break;
    }

    if (acc.acc_type != BufferAccessType::kReadWrite) {
      // If a buffer is both read and written, it must be an update in the TVM DSL,
      // so the indices should be the same and we can skip appending them.
      // Otherwise, record the indices of this access.
      buf_accesses[buffer].indices.push_back(
          std::vector<PrimExpr>(indices.begin(), indices.end()));
    }
  }

  BufferMap<BufferAccess> buf_accesses;
};

// Compute the coefficient of a loop iterator in an expression
// Note: we use an approximation strategy to find the coefficient.
// Hopefully, it is faster than DetectLinearEquation and can handle more cases (e.g., non-linear)
class CoefficientExtractor : public StmtExprVisitor {
 public:
  void VisitExpr_(const MulNode* node) final {
    StmtExprVisitor::VisitExpr_(node);
    if (visited_var) {
      if (!visited_add) {
        if (auto a = node->a.as<IntImmNode>()) {
          visited_mul = true;
          stride = a->value;
        } else if (auto b = node->b.as<IntImmNode>()) {
          visited_mul = true;
          stride = b->value;
        }
      }
    }
  }

  void VisitExpr_(const AddNode* node) final {
    StmtExprVisitor::VisitExpr_(node);
    if (visited_var) {
      if (!visited_mul) {
        visited_add = true;
        stride = 1;
      }
    }
  }

  void VisitExpr_(const VarNode* node) final {
    if (node == var_) {
      visited_var = true;
      // This is a magic default stride in case our approximation strategy fails
      stride = 2;
    }
  }

  int ExtractCoefficient(const PrimExpr& expr, const VarNode* var) {
    visited_var = visited_mul = visited_add = false;
    var_ = var;

    this->VisitExpr(expr);

    if (visited_var && !visited_mul && !visited_add) {
      return 1;
    } else {
      return stride;
    }
  }

  bool visited_var{false};
  bool visited_mul{false};
  bool visited_add{false};
  int stride{0};

 private:
  const VarNode* var_{nullptr};
};
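
// Illustrative results (assuming integer loop vars i and j and a fresh extractor):
//   ExtractCoefficient(i * 4 + j, i) -> 4  (constant multiplier found)
//   ExtractCoefficient(i + j, i)     -> 1  (plain additive term)
// If `i` does not appear at all, visited_var stays false and callers
// (e.g. ComputeStride below) must check it before trusting the result.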

// Compute stride for the accesses to a buffer
int64_t ComputeStride(const std::vector<std::vector<PrimExpr>>& indices,
                      const std::vector<int>& shape, const VarNode* stride_var) {
  // Use a stride of 1 for 0-dimensional buffers. 0-dim buffers have a single
  // index access, so we have to check here.
  if (shape.size() == 0) {
    return 1;
  }
  int64_t min_stride = std::numeric_limits<int64_t>::max();
  bool find = false;
  CoefficientExtractor extractor;

  for (const auto& index : indices) {
    int64_t shape_stride = 1;
    for (int i = static_cast<int>(index.size()) - 1; i >= 0; i--) {
      int coefficient = extractor.ExtractCoefficient(index[i], stride_var);
      if (extractor.visited_var) {
        find = true;
        min_stride = std::min(min_stride, std::abs(coefficient) * shape_stride);
        break;
      }
      shape_stride *= shape[i];
    }
  }

  return find ? min_stride : 0;
}
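
// Illustrative example (hypothetical 2-D access, not taken from this file):
// for indices {{i, j}} into a buffer of shape {64, 32}, the stride of j is 1
// (innermost dimension), while the stride of i is 32 because one step of i
// skips a whole row of 32 elements.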

// Compute the touched region (per-dimension extents) for accesses to a buffer
void ComputeRegion(const std::vector<std::vector<PrimExpr>>& indices, arith::Analyzer* ana,
                   std::vector<int>* region) {
  region->clear();

  if (indices.empty()) {
    return;
  }

  region->reserve(indices[0].size());

  if (indices.size() == 1) {
    for (const auto& index : indices[0]) {
      ConstIntBound bound = ana->const_int_bound(index);
      region->push_back(bound->max_value - bound->min_value + 1);
    }
  } else {
    // future(lmzheng): implement a more accurate IntSet?
    for (size_t i = 0; i < indices[0].size(); ++i) {
      int64_t minimum = ConstIntBound::kPosInf, maximum = ConstIntBound::kNegInf;
      for (size_t j = 0; j < indices.size(); ++j) {
        ConstIntBound bound = ana->const_int_bound(indices[j][i]);

        minimum = std::min(minimum, bound->min_value);
        maximum = std::max(maximum, bound->max_value);
      }
      region->push_back(maximum - minimum + 1);
    }
  }
}
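
// Illustrative example (assuming i is bound to [0, 7] and j to [0, 15]):
// for the two accesses {{i, j}, {i, j + 1}}, each output entry is
// max - min + 1 over all accesses of that dimension, i.e. region = {8, 17}.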

// Compute reuse distance and reuse count for accesses to a buffer
// return values: reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct
std::tuple<ReuseType, float, float, float> ComputeReuse(
    const Var& buf, const std::vector<std::vector<PrimExpr>>& indices,
    const std::vector<const ForNode*>& for_loop_stack,
    const std::unordered_map<const ForNode*,
                             BufferMap<std::vector<std::tuple<BufferAccessType, int64_t, int>>>>&
        for_touch_regions,
    const Analyzer& ana) {
  float reuse_dis_iter = 1.0f;
  float reuse_dis_bytes = -1.0f;

  for (int i = static_cast<int>(for_loop_stack.size()) - 1; i >= 0; --i) {
    const ForNode* cur_for = for_loop_stack[i];
    bool find = false;

    for (size_t j = 0; j < indices.size(); j++) {
      for (size_t k = 0; k < indices[j].size(); k++) {
        if (VarInExpr(cur_for->loop_var, indices[j][k])) {
          find = true;
          break;
        }
      }
      if (find) {
        break;
      }
    }

    int64_t extent = GetLoopExtent(for_loop_stack[i], ana);
    if (find) {
      // accumulate/update reuse distance
      reuse_dis_iter *= extent;
      reuse_dis_bytes = 0.0f;
      for (const auto& iter : for_touch_regions.at(cur_for)) {
        for (const auto& access : iter.second) {
          reuse_dis_bytes += std::get<1>(access) * std::get<2>(access);
        }
      }
    } else {
      // Have LoopMultipleRead reuse
      if (reuse_dis_bytes < 0) {
        // For the reuse in the innermost axis, the above code won't be executed.
        // So we compute the bytes here.
        reuse_dis_bytes = 0.0f;
        for (const auto& iter : for_touch_regions.at(cur_for)) {
          for (const auto& access : iter.second) {
            reuse_dis_bytes += 1 * std::get<2>(access);
          }
        }
      }
      return std::make_tuple(ReuseType::kLoopMultipleRead, reuse_dis_iter, reuse_dis_bytes, extent);
    }

    const BufferMap<std::vector<std::tuple<BufferAccessType, int64_t, int>>>& buffer_map =
        for_touch_regions.at(cur_for);

    int serial_reuse = static_cast<int>(buffer_map.at(buf).size()) - 1;
    if (serial_reuse > 0) {
      int64_t extent = GetLoopExtent(cur_for, ana);

      // Have SerialMultipleReadWrite reuse
      reuse_dis_iter = std::numeric_limits<float>::max();
      for (const auto& acc_info : buffer_map.at(buf)) {
        reuse_dis_iter = std::min(reuse_dis_iter, static_cast<float>(std::get<1>(acc_info)));
      }

      reuse_dis_bytes = 0.0f;
      for (const auto& iter : for_touch_regions.at(cur_for)) {
        for (const auto& access : iter.second) {
          reuse_dis_bytes += std::get<1>(access) * std::get<2>(access);
        }
      }

      return std::make_tuple(ReuseType::kSerialMultipleReadWrite, reuse_dis_iter / extent,
                             reuse_dis_bytes / extent, serial_reuse);
    }
  }

  return std::make_tuple(ReuseType::kNoReuse, 0, 0, 0);
}
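
// Illustrative example (hypothetical loop nest, not taken from this file):
// for an access A[i, k] under loops i, j, k (where j does not index A), the
// innermost loop k indexes A, so reuse_dis_iter accumulates extent(k); the
// next loop j does not, which triggers kLoopMultipleRead with
// reuse_ct = extent(j): the same region of A is re-read on every iteration of j.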

// Extract features for every BufferStore statement
//
// This visitor assumes that loop bounds do not depend on data or on parent
// loop bounds. For example, `for i in .. { for j in range(i, ..) }` would
// result in inaccurate features. This visitor also does not take conditionals
// into consideration when creating features: every branch of a conditional
// is assumed to execute.
class PerStoreFeatureExtractor : public StmtExprVisitor {
 public:
  explicit PerStoreFeatureExtractor(int cache_line_size, const Map<Var, Buffer>& existing_buffers)
      : cache_line_size_(cache_line_size) {
    for (const auto& buffer : existing_buffers) {
      buffer_shapes[buffer.first] = buffer.second->shape;
      buffer_dtypes[buffer.first] = buffer.second->dtype;
      // Also add a reference from the buffer's internal variable. This is
      // usually how buffers are referenced within the body of a PrimFunc.
      buffer_shapes[buffer.second->data] = buffer.second->shape;
      buffer_dtypes[buffer.second->data] = buffer.second->dtype;
    }
  }

  void VisitStmt_(const AttrStmtNode* node) final {
    if (node->attr_key == tir::attr::thread_extent || node->attr_key == tir::attr::virtual_thread) {
      const Var& var = node->node.as<IterVarNode>()->var;
      int extent = GetIntImm(node->value);

      int* plen = nullptr;

      const std::string& name = var.get()->name_hint;
      if (node->attr_key == tir::attr::thread_extent) {
        if (name == "blockIdx.x") {
          plen = &blockIdx_x_len_;
        } else if (name == "blockIdx.y") {
          plen = &block_idx_y_len_;
        } else if (name == "blockIdx.z") {
          plen = &block_idx_z_len_;
        } else if (name == "threadIdx.x") {
          plen = &threadIdx_x_len_;
        } else if (name == "threadIdx.y") {
          plen = &thread_idx_y_len_;
        } else if (name == "threadIdx.z") {
          plen = &thread_idx_z_len_;
        } else {
          LOG(FATAL) << "invalid thread itervar " << name;
        }
      } else {
        plen = &vthread_len_;
      }

      int extent_before = *plen;
      if (node->attr_key == tir::attr::thread_extent) {
        *plen = extent;
      } else {
        *plen *= extent;
      }

      is_gpu_ = true;

      // make a fake for node for blockIdx.x or threadIdx.x
      Stmt fake_for_node = For(var, 0, extent, ForKind::kParallel, node->body);

      outer_loop_prod_ *= extent;
      for_loop_stack_.push_back(fake_for_node.as<ForNode>());
      variable_definition_stack_.push_back({});
      StmtExprVisitor::VisitStmt_(node);
      variable_definition_stack_.pop_back();
      for_loop_stack_.pop_back();
      outer_loop_prod_ /= extent;

      *plen = extent_before;
    } else if (node->attr_key == "pragma_auto_unroll_max_step") {
      int value = GetIntImm(node->value);

      int16_t old_value = cur_auto_unroll_max_step_;
      cur_auto_unroll_max_step_ = value;
      StmtExprVisitor::VisitStmt_(node);
      cur_auto_unroll_max_step_ = old_value;
    } else {
      StmtExprVisitor::VisitStmt_(node);
    }
  }

  void VisitStmt_(const ForNode* node) final {
    ana_.Bind(node->loop_var, Range::FromMinExtent(node->min, node->extent));
    int64_t loop_extent = GetLoopExtent(node, ana_);

    if (node->kind == ForKind::kVectorized) {
      vec_for_stack_.push_back(node);
    } else if (node->kind == ForKind::kUnrolled) {
      unroll_for_stack_.push_back(node);
    } else if (node->kind == ForKind::kParallel) {
      parallel_for_stack_.push_back(node);
    }

    outer_loop_prod_ *= loop_extent;
    for_loop_stack_.push_back(node);
    variable_definition_stack_.push_back({});
    StmtExprVisitor::VisitStmt_(node);
    variable_definition_stack_.pop_back();
    for_loop_stack_.pop_back();
    outer_loop_prod_ /= loop_extent;

    if (node->kind == ForKind::kVectorized) {
      vec_for_stack_.pop_back();
    } else if (node->kind == ForKind::kUnrolled) {
      unroll_for_stack_.pop_back();
    } else if (node->kind == ForKind::kParallel) {
      parallel_for_stack_.pop_back();
    }
  }

  void VisitExpr_(const BufferLoadNode* node) final {
    // Store buffer shape/dtype. It may already be stored.
    buffer_shapes[node->buffer->data] = node->buffer->shape;
    buffer_dtypes[node->buffer->data] = node->buffer->dtype;
    StmtExprVisitor::VisitExpr_(node);
  }

  void VisitStmt_(const BufferStoreNode* node) final {
    // Store buffer shape/dtype. It may already be stored.
    buffer_shapes[node->buffer->data] = node->buffer->shape;
    buffer_dtypes[node->buffer->data] = node->buffer->dtype;

    MathOpCounter math_op_counter;
    math_op_counter(node->value);
    std::vector<float> mem_bytes_list;
    std::vector<float> compute_ops_list;
    double cur_compute_ops;

    // Group 1: Computation related features
    ExtractComputationFeature(node->buffer->data, node->indices, math_op_counter);

    // Group 2: Buffer access related features (per buffer)
    ExtractBufferAccessFeature(node->buffer->data, node->indices, node->value, math_op_counter,
                               &cur_compute_ops, &compute_ops_list, &mem_bytes_list);

    // Group 3: Arithmetic intensity related features
    ExtractArithmeticIntensityFeature(node->buffer->data, cur_compute_ops, compute_ops_list,
                                      mem_bytes_list);

    // Group 5: Outer scope related features
    ExtractOuterScopeFeature(node->buffer->data);
  }

  void VisitStmt_(const BufferRealizeNode* node) final {
    // Store buffer shape/dtype. It may already be stored.
    buffer_shapes[node->buffer->data] = node->buffer->shape;
    buffer_dtypes[node->buffer->data] = node->buffer->dtype;
    StmtExprVisitor::VisitStmt_(node);

    // Group 4: Allocation related features
    ExtractAllocationFeature(node);
  }

  void VisitStmt_(const AllocateNode* node) final {
    buffer_dtypes[node->buffer_var] = node->dtype;
    buffer_shapes[node->buffer_var] = node->extents;
    StmtExprVisitor::VisitStmt_(node);

    // Group 4: Allocation related features
    ExtractAllocationFeature(node);
  }

  void VisitStmt_(const LetStmtNode* node) final {
    // TODO(tkonolige): add arithmetic counts from this statement to counts of inner stores.
    ana_.Bind(node->var, node->value);
    ICHECK(variable_definition_stack_.size() > 0)
        << "Variable definition outside of a for loop is not handled by feature extraction";
    variable_definition_stack_.back().push_back(std::make_tuple(node->var, node->value));
    StmtExprVisitor::VisitStmt_(node);
  }

  // Extract computation related features (group 1)
  void ExtractComputationFeature(const Var& buffer, const Array<PrimExpr>& indices,
                                 const MathOpCounter& math_op_counter) {
    FeatureSet& fea = buffer_features[buffer];

    // Computation related features
    fea.float_mad += outer_loop_prod_ * math_op_counter.float_mad;
    fea.float_addsub += outer_loop_prod_ * math_op_counter.float_addsub;
    fea.float_mul += outer_loop_prod_ * math_op_counter.float_mul;
    fea.float_divmod += outer_loop_prod_ * math_op_counter.float_divmod;
    fea.float_cmp += outer_loop_prod_ * math_op_counter.float_cmp;
    fea.float_math_func += outer_loop_prod_ * math_op_counter.float_math_func;
    fea.float_other_func += outer_loop_prod_ * math_op_counter.float_other_func;
    fea.int_mad += outer_loop_prod_ * math_op_counter.int_mad;
    fea.int_addsub += outer_loop_prod_ * math_op_counter.int_addsub;
    fea.int_mul += outer_loop_prod_ * math_op_counter.int_mul;
    fea.int_divmod += outer_loop_prod_ * math_op_counter.int_divmod;
    fea.int_math_func += outer_loop_prod_ * math_op_counter.int_math_func;
    fea.int_cmp += outer_loop_prod_ * math_op_counter.int_cmp;
    fea.int_other_func += outer_loop_prod_ * math_op_counter.int_other_func;
    fea.bool_op += outer_loop_prod_ * math_op_counter.bool_op;
    fea.select_op += outer_loop_prod_ * math_op_counter.select_op;

    fea.vec_len = fea.unroll_len = fea.parallel_len = 0.0f;
    fea.vec_type = fea.unroll_type = fea.parallel_type = AnnotationPosType::kPosNone;

    fea.vec_num = vec_for_stack_.size();
    if (!vec_for_stack_.empty()) {
      fea.vec_len = GetLoopExtent(vec_for_stack_.back(), ana_);
      fea.vec_prod = 1.0;
      for (const ForNode* pfor : vec_for_stack_) {
        fea.vec_prod *= GetLoopExtent(pfor, ana_);
      }
      fea.vec_type = AnnotationPosType::kPosMixed;
      // TODO(merrymercy): this feature requires operation (tvm.compute) information
      // GetAnnotationPosEncoding(vec_for_stack_.back()->loop_var,
      // node->args, pcompute->axis, pcompute->reduce_axis);
    }

    fea.unroll_num = unroll_for_stack_.size();
    if (!unroll_for_stack_.empty()) {
      fea.unroll_len = GetLoopExtent(unroll_for_stack_.back(), ana_);
      fea.unroll_prod = 1.0;
      for (const ForNode* pfor : unroll_for_stack_) {
        fea.unroll_prod *= GetLoopExtent(pfor, ana_);
      }
      fea.unroll_type = AnnotationPosType::kPosMixed;
      // GetAnnotationPosEncoding(unroll_for_stack_.back()->loop_var,
      // node->args, pcompute->axis, pcompute->reduce_axis);
    }

    fea.parallel_num = parallel_for_stack_.size();
    if (!parallel_for_stack_.empty()) {
      fea.parallel_len = GetLoopExtent(parallel_for_stack_.back(), ana_);
      fea.parallel_prod = 1.0;
      for (const ForNode* pfor : parallel_for_stack_) {
        fea.parallel_prod *= GetLoopExtent(pfor, ana_);
      }
      fea.parallel_type = AnnotationPosType::kPosMixed;
      // GetAnnotationPosEncoding(parallel_for_stack_.back()->loop_var,
      // node->args, pcompute->axis, pcompute->reduce_axis);
    }

    // GPU threads
    fea.is_gpu = is_gpu_;
    fea.blockIdx_x_len = blockIdx_x_len_;
    fea.blockIdx_y_len = block_idx_y_len_;
    fea.blockIdx_z_len = block_idx_z_len_;
    fea.threadIdx_x_len = threadIdx_x_len_;
    fea.threadIdx_y_len = thread_idx_y_len_;
    fea.threadIdx_z_len = thread_idx_z_len_;
    fea.vthread_len = vthread_len_;
  }

  // Extract buffer access related features (group 2)
  void ExtractBufferAccessFeature(const Var& buffer, const Array<PrimExpr>& indices,
                                  const PrimExpr& value, const MathOpCounter& math_op_counter,
                                  double* cur_compute_ops, std::vector<float>* compute_ops_list,
                                  std::vector<float>* mem_bytes_list) {
    FeatureSet& fea = buffer_features[buffer];

    // Extract all buffer accesses
    std::vector<BufferAccessFeature> acc_feas;
    BufferAccessExtractor buf_extractor;
    buf_extractor.InsertAccess(buffer, BufferAccessType::kWrite, indices);
    buf_extractor.ExtractReads(value);

    mem_bytes_list->reserve(for_loop_stack_.size());
    compute_ops_list->reserve(for_loop_stack_.size());

    *cur_compute_ops = math_op_counter.float_mad + math_op_counter.float_addsub +
                       math_op_counter.float_mul + math_op_counter.float_divmod +
                       math_op_counter.float_cmp + math_op_counter.float_math_func +
                       math_op_counter.float_other_func;

    ICHECK_EQ(for_loop_stack_.size(), variable_definition_stack_.size())
        << "variable_definition_stack_ should mirror for_loop_stack_ in size";
    std::vector<int> tmp_region;
    for (int i = static_cast<int>(for_loop_stack_.size()) - 1; i >= 0; i--) {
      const ForNode* p_for = for_loop_stack_[i];

      // Construct a local analyzer context which contains the definitions (for
      // and let) from the innermost loop up to and including loop `i`. Loop
      // variables defined in loops more outer than `i` are bound to an extent
      // of 1 so that we can get per-loop-iteration features. Note that we add
      // these definitions from outermost to innermost because inner
      // definitions may depend on outer ones.
      Analyzer local_analyzer;
      for (int j = 0; j < i; j++) {
        local_analyzer.Bind(for_loop_stack_.at(j)->loop_var,
                            Range::FromMinExtent(for_loop_stack_.at(j)->min, 1));
      }
      for (int j = i; j < static_cast<int>(for_loop_stack_.size()); j++) {
        local_analyzer.Bind(
            for_loop_stack_.at(j)->loop_var,
            Range::FromMinExtent(for_loop_stack_.at(j)->min, for_loop_stack_.at(j)->extent));
        for (auto definition : variable_definition_stack_.at(j)) {
          local_analyzer.Bind(std::get<0>(definition), std::get<1>(definition));
        }
      }

      // Note that we overwrite here: if there are multiple BufferStoreNodes,
      // the last one overwrites the first few.
      // e.g. The update part in gemm will overwrite the init part.
      BufferMap<std::vector<std::tuple<BufferAccessType, int64_t, int>>>& buffer_regions_map =
          for_touch_regions_[p_for];

      int64_t mem_bytes = 0;
      for (const auto& x : buf_extractor.buf_accesses) {
        const Var& t = x.first;
        const BufferAccess& acc = x.second;

        ComputeRegion(acc.indices, &local_analyzer, &tmp_region);
        int64_t touched_size = ElementProduct(tmp_region);
        touched_size = std::max<int64_t>(0, touched_size);
        buffer_regions_map[t].push_back(
            std::make_tuple(acc.acc_type, touched_size, buffer_dtypes.at(t).bytes()));
        mem_bytes += touched_size * buffer_dtypes.at(t).bytes();
      }

      mem_bytes_list->push_back(mem_bytes);
      *cur_compute_ops *= GetLoopExtent(for_loop_stack_[i], local_analyzer);
      compute_ops_list->push_back(*cur_compute_ops);
    }

    // Buffer access related features (per buffer)
    for (const auto& x : buf_extractor.buf_accesses) {
      const Var& t = x.first;
      const BufferAccess& acc = x.second;

      std::vector<int> int_shape;
      for (const auto& dim : buffer_shapes.at(t)) {
        int_shape.push_back(GetIntImm(dim));
      }

      size_t ele_bytes = buffer_dtypes.at(t).bytes();

      // calculate bytes
      float bytes = outer_loop_prod_ * ele_bytes;
      float unique_bytes;

      // calculate cache lines
      int64_t stride;
      float lines;
      float unique_lines;

      if (for_loop_stack_.empty()) {
        unique_bytes = ele_bytes;
        stride = 0;
        lines = 1.0f;
        unique_lines = 1.0f;
      } else {
        unique_bytes =
            static_cast<float>(
                std::get<1>(for_touch_regions_[for_loop_stack_.front()][t].front())) *
            ele_bytes;

        stride = 0;
        int64_t reduce_ratio = 1;

        int i;
        for (i = static_cast<int>(for_loop_stack_.size()) - 1; i >= 0; i--) {
          stride = ComputeStride(acc.indices, int_shape, for_loop_stack_[i]->loop_var.get());
          if (stride != 0) {
            break;
          }
          reduce_ratio *= GetLoopExtent(for_loop_stack_.back(), ana_);
        }

        lines = outer_loop_prod_ / reduce_ratio *
                std::min(1.0f, 1.0f * stride * ele_bytes / cache_line_size_);
        lines = std::max(lines, 1.0f);

        // convert `stride` back to the stride of the innermost iterator
        stride = (i == static_cast<int>(for_loop_stack_.size()) - 1 ? stride : 0);

        float n_continuous = ele_bytes;
        for (int i = std::min(static_cast<int>(tmp_region.size()) - 1,
                              static_cast<int>(int_shape.size()) - 1);
             i >= 0; i--) {
          if (tmp_region[i] == int_shape[i]) {
            n_continuous *= tmp_region[i];
            break;
          }
        }
        unique_lines = unique_bytes / std::min(n_continuous, static_cast<float>(cache_line_size_));
        unique_lines = std::max(unique_lines, 1.0f);
      }

      auto [reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct] =
          ComputeReuse(t, acc.indices, for_loop_stack_, for_touch_regions_, ana_);

      acc_feas.emplace_back();
      BufferAccessFeature& acc_fea = acc_feas.back();

      // TODO(tkonolige): save buffer names and use those instead?
      acc_fea.buffer_name = t->name_hint;
      acc_fea.acc_type = acc.acc_type;
      acc_fea.stride = stride;
      acc_fea.bytes = bytes;
      acc_fea.unique_bytes = unique_bytes;
      acc_fea.lines = lines;
      acc_fea.unique_lines = unique_lines;
      acc_fea.reuse_type = reuse_type;
      acc_fea.reuse_dis_iter = reuse_dis_iter;
      acc_fea.reuse_dis_bytes = reuse_dis_bytes;
      acc_fea.reuse_ct = reuse_ct;
      if (acc_fea.reuse_ct > 0.5) {
        acc_fea.bytes_d_reuse_ct = bytes / reuse_ct;
        acc_fea.unique_bytes_d_reuse_ct = unique_bytes / reuse_ct;
        acc_fea.lines_d_reuse_ct = lines / reuse_ct;
        acc_fea.unique_lines_d_reuse_ct = unique_lines / reuse_ct;
      } else {
        // no reuse, multiply by a magic number '2'
        acc_fea.bytes_d_reuse_ct = bytes * 2;
        acc_fea.unique_bytes_d_reuse_ct = unique_bytes * 2;
        acc_fea.lines_d_reuse_ct = lines * 2;
        acc_fea.unique_lines_d_reuse_ct = unique_lines * 2;
      }
    }

    fea.access_feas = acc_feas;
  }

  // Extract arithmetic intensity related features (group 3)
  void ExtractArithmeticIntensityFeature(const Var& buffer, double cur_compute_ops,
                                         const std::vector<float>& compute_ops_list,
                                         const std::vector<float>& mem_bytes_list) {
    FeatureSet& fea = buffer_features[buffer];

    // Compute the arithmetic intensity curve (y axis: arithmetic intensity, x axis: flops).
    // We use piecewise linear interpolation to fit this curve.
    int pt = 0;
    if (cur_compute_ops <= 0 || compute_ops_list.empty()) {
      std::fill(fea.arith_intensity_curve,
                fea.arith_intensity_curve + ARITH_INTENSITY_CURVE_SAMPLE_N, 0.0);
    } else {
      for (size_t i = 0; i < ARITH_INTENSITY_CURVE_SAMPLE_N; ++i) {
        float cur_compute_ops = compute_ops_list.back() * (i + 1) / ARITH_INTENSITY_CURVE_SAMPLE_N;
        while (compute_ops_list[pt] < cur_compute_ops - 1e-4) {
          pt++;
        }
        ICHECK_LT(pt, compute_ops_list.size());

        float value;
        if (pt == 0) {
          value = compute_ops_list[pt] / mem_bytes_list[pt];
        } else {
          float base = compute_ops_list[pt - 1] / mem_bytes_list[pt - 1];
          float slope = (compute_ops_list[pt] / mem_bytes_list[pt] -
                         compute_ops_list[pt - 1] / mem_bytes_list[pt - 1]) /
                        (compute_ops_list[pt] - compute_ops_list[pt - 1]);
          value = base + slope * (cur_compute_ops - compute_ops_list[pt - 1]);
        }
        fea.arith_intensity_curve[i] = value;
      }
    }
  }

  // Extract allocation related features (group 4)
  void ExtractAllocationFeature(const BufferRealizeNode* node) {
    FeatureSet& fea = buffer_features[node->buffer->data];

    float allocation_size = 1.0f;
    for (const auto& x : node->bounds) {
      allocation_size *= GetIntImm(x->extent);
    }
    // allocation feature
    fea.alloc_size = allocation_size * node->buffer->dtype.bytes();
    fea.alloc_prod = allocation_size * outer_loop_prod_;
    fea.alloc_outer_prod = outer_loop_prod_;
    fea.alloc_inner_prod = fea.outer_prod / outer_loop_prod_;
  }

  void ExtractAllocationFeature(const AllocateNode* node) {
    FeatureSet& fea = buffer_features[node->buffer_var];

    float allocation_size = 1.0f;
    for (const auto& x : node->extents) {
      // TODO(tkonolige): will not handle dynamic shape
      allocation_size *= GetIntImm(x);
    }
    // allocation feature
    fea.alloc_size = allocation_size * node->dtype.bytes();
    fea.alloc_prod = allocation_size * outer_loop_prod_;
    fea.alloc_outer_prod = outer_loop_prod_;
    fea.alloc_inner_prod = fea.outer_prod / outer_loop_prod_;
  }
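
// Illustrative example (hypothetical allocation): a float32 buffer realized
// with bounds 16 x 16 gives alloc_size = 16 * 16 * 4 = 1024 bytes;
// alloc_outer_prod is the trip count of the loops enclosing the allocation,
// and alloc_inner_prod is the trip count of the loops inside its scope.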

  // Extract outer scope related features (group 5)
  void ExtractOuterScopeFeature(const Var& buffer) {
    FeatureSet& fea = buffer_features[buffer];

    fea.outer_prod = outer_loop_prod_;
    fea.num_loops = for_loop_stack_.size();
    fea.auto_unroll_max_step = cur_auto_unroll_max_step_;
  }

  // Stores FeatureSet for every buffer
  BufferMap<FeatureSet> buffer_features;

 private:
  // The shared arithmetic analyzer
  Analyzer ana_;

  // The product of the extents of outer loops
  float outer_loop_prod_ = 1.0f;

  // The stacks to store parent loops during DFS
  std::vector<const ForNode*> for_loop_stack_;
  std::vector<const ForNode*> parallel_for_stack_;
  std::vector<const ForNode*> vec_for_stack_;
  std::vector<const ForNode*> unroll_for_stack_;
  std::vector<std::vector<std::tuple<Var, PrimExpr>>> variable_definition_stack_;

  // GPU-related features
  bool is_gpu_{false};
  int blockIdx_x_len_{1};
  int block_idx_y_len_{1};
  int block_idx_z_len_{1};
  int threadIdx_x_len_{1};
  int thread_idx_y_len_{1};
  int thread_idx_z_len_{1};
  int vthread_len_{1};
  int16_t cur_auto_unroll_max_step_{0};

  // Store the touch region information for all for loops. The format of this nested map:
  // For a loop, for all its touched buffers, for all different accesses to the buffers,
  // store (access type, number of touched elements, number of bytes of a single element).
  std::unordered_map<const ForNode*,
                     BufferMap<std::vector<std::tuple<BufferAccessType, int64_t, int>>>>
      for_touch_regions_;

  // The default cache line size in bytes
  const int cache_line_size_ = 64;

  // Storage of buffer shape and dtype information. Needed because Load/Store
  // nodes do not contain this information.
  BufferMap<Array<PrimExpr>> buffer_shapes;
  BufferMap<DataType> buffer_dtypes;
};

// shifted log to incorporate the property that log2p(0) = 0
inline float log2p(float x) { return x < 0 ? -std::log2(-x + 1) : std::log2(x + 1); }
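// e.g. log2p(0) = 0, log2p(3) = 2, log2p(-3) = -2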

void GetPerStoreFeature(const PrimFunc& func, int cache_line_size, int max_n_bufs,
                        std::vector<float>* ret, bool log_scale) {
  PerStoreFeatureExtractor extractor(cache_line_size, func->buffer_map);
  extractor(func->body);

  auto slog = log_scale ? log2p : [](float x) { return x; };

  ret->push_back(extractor.buffer_features.size());

  for (const auto& x : extractor.buffer_features) {
    const FeatureSet& fea_set = x.second;

    /***** Group 1: Computation related features *****/
    ret->push_back(slog(fea_set.float_mad));
    ret->push_back(slog(fea_set.float_addsub));
    ret->push_back(slog(fea_set.float_mul));
    ret->push_back(slog(fea_set.float_divmod));
    ret->push_back(slog(fea_set.float_cmp));
    ret->push_back(slog(fea_set.float_math_func));
    ret->push_back(slog(fea_set.float_other_func));
    ret->push_back(slog(fea_set.int_mad));
    ret->push_back(slog(fea_set.int_addsub));
    ret->push_back(slog(fea_set.int_mul));
    ret->push_back(slog(fea_set.int_divmod));
    ret->push_back(slog(fea_set.int_cmp));
    ret->push_back(slog(fea_set.int_math_func));
    ret->push_back(slog(fea_set.int_other_func));
    ret->push_back(slog(fea_set.bool_op));
    ret->push_back(slog(fea_set.select_op));

    ret->push_back(slog(fea_set.vec_num));
    ret->push_back(slog(fea_set.vec_prod));
    ret->push_back(slog(fea_set.vec_len));
    for (int i = 0; i <= static_cast<int>(AnnotationPosType::kPosMixed); i++) {
      ret->push_back(i == static_cast<int>(fea_set.vec_type));
    }

    ret->push_back(slog(fea_set.unroll_num));
    ret->push_back(slog(fea_set.unroll_prod));
    ret->push_back(slog(fea_set.unroll_len));
    for (int i = 0; i <= static_cast<int>(AnnotationPosType::kPosMixed); i++) {
      ret->push_back(i == static_cast<int>(fea_set.unroll_type));
    }

    ret->push_back(slog(fea_set.parallel_num));
    ret->push_back(slog(fea_set.parallel_prod));
    ret->push_back(slog(fea_set.parallel_len));
    for (int i = 0; i <= static_cast<int>(AnnotationPosType::kPosMixed); i++) {
      ret->push_back(i == static_cast<int>(fea_set.parallel_type));
    }

    ret->push_back(fea_set.is_gpu);
    ret->push_back(slog(fea_set.blockIdx_x_len));
    ret->push_back(slog(fea_set.blockIdx_y_len));
    ret->push_back(slog(fea_set.blockIdx_z_len));
    ret->push_back(slog(fea_set.threadIdx_x_len));
    ret->push_back(slog(fea_set.threadIdx_y_len));
    ret->push_back(slog(fea_set.threadIdx_z_len));
    ret->push_back(slog(fea_set.vthread_len));

    /***** Group 2: Buffer access related features *****/
    // sort according to the pair (lines, bytes)
    std::vector<std::pair<float, float>> buf_order_key;
    for (const auto& acc_fea : fea_set.access_feas) {
      buf_order_key.emplace_back(acc_fea.lines, acc_fea.bytes);
    }
    std::vector<int> buf_order(buf_order_key.size());
    std::iota(buf_order.begin(), buf_order.end(), 0);

    auto cmp = [&buf_order_key](int l, int r) {
      return buf_order_key[l].first > buf_order_key[r].first ||
             (buf_order_key[l].first == buf_order_key[r].first &&
              buf_order_key[l].second > buf_order_key[r].second);
    };
    std::sort(buf_order.begin(), buf_order.end(), cmp);
    int n_bufs = std::min(max_n_bufs, static_cast<int>(buf_order.size()));
    buf_order.resize(n_bufs);

    for (int idx : buf_order) {
      const auto& acc_fea = fea_set.access_feas[idx];
      for (int j = 0; j <= static_cast<int>(BufferAccessType::kReadWrite); ++j) {
        ret->push_back(j == static_cast<int>(acc_fea.acc_type));
      }
      ret->push_back(slog(acc_fea.bytes));
      ret->push_back(slog(acc_fea.unique_bytes));
      ret->push_back(slog(acc_fea.lines));
      ret->push_back(slog(acc_fea.unique_lines));
      for (int j = 0; j <= static_cast<int>(ReuseType::kNoReuse); ++j) {
        ret->push_back(j == static_cast<int>(acc_fea.reuse_type));
      }
      ret->push_back(slog(acc_fea.reuse_dis_iter));
      ret->push_back(slog(acc_fea.reuse_dis_bytes));
      ret->push_back(slog(acc_fea.reuse_ct));
      ret->push_back(slog(acc_fea.bytes_d_reuse_ct));
      ret->push_back(slog(acc_fea.unique_bytes_d_reuse_ct));
      ret->push_back(slog(acc_fea.lines_d_reuse_ct));
      ret->push_back(slog(acc_fea.unique_lines_d_reuse_ct));
      ret->push_back(slog(acc_fea.stride));
    }
    // fill padding
    for (int i = 0; i < max_n_bufs - n_bufs; ++i) {
      for (int j = 0; j <= static_cast<int>(BufferAccessType::kReadWrite); ++j) {  // 3
        ret->push_back(0.0f);
      }
      ret->push_back(0.0f);
      ret->push_back(0.0f);
      ret->push_back(0.0f);
      ret->push_back(0.0f);
      for (int j = 0; j <= static_cast<int>(ReuseType::kNoReuse); ++j) {  // 3
        ret->push_back(0.0f);
      }
      ret->push_back(0.0f);
      ret->push_back(0.0f);
      ret->push_back(0.0f);
      ret->push_back(0.0f);
      ret->push_back(0.0f);
      ret->push_back(0.0f);
      ret->push_back(0.0f);
      ret->push_back(0.0f);
    }

    /***** Group 3: Arithmetic intensity related features *****/
    for (size_t i = 0; i < ARITH_INTENSITY_CURVE_SAMPLE_N; ++i) {
      ret->push_back(slog(fea_set.arith_intensity_curve[i]));
    }

    /***** Group 4: Allocation related features *****/
    ret->push_back(slog(fea_set.alloc_size));
    ret->push_back(slog(fea_set.alloc_prod));
    ret->push_back(slog(fea_set.alloc_outer_prod));
    ret->push_back(slog(fea_set.alloc_inner_prod));

    /***** Group 5: Outer scope related features *****/
    ret->push_back(slog(fea_set.outer_prod));
    ret->push_back(slog(fea_set.num_loops));
    ret->push_back(slog(fea_set.auto_unroll_max_step));
  }
}
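
// The resulting layout is [n, features(set_1), ..., features(set_n)], where n
// is the number of extracted feature sets and each set is a fixed-length
// block whose entries are named by GetPerStoreFeatureName below.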

void GetPerStoreFeatureName(int max_n_bufs, std::vector<std::string>* ret) {
  /***** Group 1: Computation related features *****/
  ret->push_back("float_mad");
  ret->push_back("float_addsub");
  ret->push_back("float_mul");
  ret->push_back("float_divmod");
  ret->push_back("float_cmp");
  ret->push_back("float_mathfunc");
  ret->push_back("float_otherfunc");
  ret->push_back("int_mad");
  ret->push_back("int_addsub");
  ret->push_back("int_mul");
  ret->push_back("int_divmod");
  ret->push_back("int_cmp");
  ret->push_back("int_mathfunc");
  ret->push_back("int_otherfunc");
  ret->push_back("bool_op");
  ret->push_back("select_op");
  ret->push_back("vec_num");
  ret->push_back("vec_prod");
  ret->push_back("vec_len");
  ret->push_back("vec_type.kPosNone");
  ret->push_back("vec_type.kPosInnerSpatial");
  ret->push_back("vec_type.kPosMiddleSpatial");
  ret->push_back("vec_type.kPosOuterSpatial");
  ret->push_back("vec_type.kPosInnerReduce");
  ret->push_back("vec_type.kPosMiddleReduce");
  ret->push_back("vec_type.kPosOuterReduce");
  ret->push_back("vec_type.kPosMixed");
  ret->push_back("unroll_num");
  ret->push_back("unroll_prod");
  ret->push_back("unroll_len");
  ret->push_back("unroll_type.kPosNone");
  ret->push_back("unroll_type.kPosInnerSpatial");
  ret->push_back("unroll_type.kPosMiddleSpatial");
  ret->push_back("unroll_type.kPosOuterSpatial");
  ret->push_back("unroll_type.kPosInnerReduce");
  ret->push_back("unroll_type.kPosMiddleReduce");
  ret->push_back("unroll_type.kPosOuterReduce");
  ret->push_back("unroll_type.kPosMixed");
  ret->push_back("parallel_num");
  ret->push_back("parallel_prod");
  ret->push_back("parallel_len");
  ret->push_back("parallel_type.kPosNone");
  ret->push_back("parallel_type.kPosInnerSpatial");
  ret->push_back("parallel_type.kPosMiddleSpatial");
  ret->push_back("parallel_type.kPosOuterSpatial");
  ret->push_back("parallel_type.kPosInnerReduce");
  ret->push_back("parallel_type.kPosMiddleReduce");
  ret->push_back("parallel_type.kPosOuterReduce");
  ret->push_back("parallel_type.kPosMixed");
  ret->push_back("is_gpu");
  ret->push_back("blockIdx_x_len");
  ret->push_back("blockIdx_y_len");
  ret->push_back("blockIdx_z_len");
  ret->push_back("threadIdx_x_len");
  ret->push_back("threadIdx_y_len");
  ret->push_back("threadIdx_z_len");
  ret->push_back("vthread_len");
  // section total: 57

  /***** Group 2: Buffer access related features *****/
  for (size_t i = 0; i < static_cast<size_t>(max_n_bufs); ++i) {
    std::string prefix = "B" + std::to_string(i) + ".";
    ret->push_back(prefix + "acc_type.kRead");
    ret->push_back(prefix + "acc_type.kWrite");
    ret->push_back(prefix + "acc_type.kReadWrite");
    ret->push_back(prefix + "bytes");
    ret->push_back(prefix + "unique_bytes");
    ret->push_back(prefix + "lines");
    ret->push_back(prefix + "unique_lines");
    ret->push_back(prefix + "reuse_type.kLoopMultipleRead");
    ret->push_back(prefix + "reuse_type.kSerialMultipleReadWrite");
    ret->push_back(prefix + "reuse_type.kNoReuse");
    ret->push_back(prefix + "reuse_dis_iter");
    ret->push_back(prefix + "reuse_dis_bytes");
    ret->push_back(prefix + "reuse_ct");
    ret->push_back(prefix + "bytes_d_reuse_ct");
    ret->push_back(prefix + "unique_bytes_d_reuse_ct");
    ret->push_back(prefix + "lines_d_reuse_ct");
    ret->push_back(prefix + "unique_lines_d_reuse_ct");
    ret->push_back(prefix + "stride");
  }
  // section total: max_n_bufs * 18

  /***** Group 3: Arithmetic intensity related features *****/
  for (size_t i = 0; i < ARITH_INTENSITY_CURVE_SAMPLE_N; ++i) {
    ret->push_back("arith_intensity_curve_" + std::to_string(i));
  }
  // section total: ARITH_INTENSITY_CURVE_SAMPLE_N = 10

  /***** Group 4: Allocation related features *****/
  ret->push_back("alloc_size");
  ret->push_back("alloc_prod");
  ret->push_back("alloc_outer_prod");
  ret->push_back("alloc_inner_prod");
  // section total: 4

  /***** Group 5: Outer scope related features *****/
  ret->push_back("outer_prod");
  ret->push_back("num_loops");
  ret->push_back("auto_unroll_max_step");
  // section total: 3
}
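
// Summing the section totals above, each feature set has
// 57 + max_n_bufs * 18 + 10 + 4 + 3 = 74 + max_n_bufs * 18 entries,
// matching the per-store blocks emitted by GetPerStoreFeature.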

void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, int max_n_bufs,
                                   std::vector<float>* feature, std::atomic<int>* error_ct) {
  auto [sch, tensors] = task->compute_dag.ApplySteps(state->transform_steps);

  // When inlining, replace const matrices with const values.
  // Produces wrong IR, but good enough for feature extraction, and
  // can improve the speed of feature extraction/search. Must be
  // called before ScheduleToModule to have an effect.
  sch = sch.normalize_for_feature_extraction();

  try {
    const std::string& name = "main";
    auto pass_ctx = tvm::transform::PassContext::Current();

    auto mod = ScheduleToModule(sch, Array<ObjectRef>{tensors.begin(), tensors.end()}, name,
                                std::unordered_map<te::Tensor, te::Buffer>(),
                                GlobalVarSupply(NameSupply("")));

    bool disable_vectorize =
        pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
    bool instrument_bound_checkers =
        pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();

    if (IsGPUTask(task)) {
      auto pass_list = Array<tvm::transform::Pass>();
      // Phase 0
      pass_list.push_back(tir::transform::InjectPrefetch());
      pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
      // Phase 1
      pass_list.push_back(tir::transform::NarrowDataType(32));
      pass_list.push_back(tir::transform::Simplify());
      pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
      pass_list.push_back(tir::transform::InjectVirtualThread());
      pass_list.push_back(tir::transform::StorageRewrite());
      pass_list.push_back(tir::transform::Simplify());
      tvm::Map<String, tvm::PrimExpr> gpu_params{
          {"max_shared_memory_per_block", task->hardware_params->max_shared_memory_per_block},
          {"max_local_memory_per_block", task->hardware_params->max_local_memory_per_block},
          {"max_threads_per_block", task->hardware_params->max_threads_per_block},
          {"max_vector_bytes", task->hardware_params->vector_unit_bytes},
          {"max_vthread", task->hardware_params->max_vthread_extent},
      };
      pass_list.push_back(tir::transform::VerifyGPUCode(gpu_params));
      const auto& optimize = tir::transform::Sequential(pass_list);
      optimize(mod);
    }
    if (IsHexagonTask(task)) {
      Target target = task->target;
      const auto vtcm_capacity = target->GetAttr<Integer>("vtcm-capacity").value().IntValue();
      const auto& optimize =
          tir::transform::Sequential({tir::transform::VerifyVTCMLimit(vtcm_capacity)});
      optimize(mod);
    }
    const auto& optimize =
        tir::transform::Sequential(Array<tvm::transform::Pass>{tir::transform::Simplify()});
    mod = optimize(std::move(mod));
    PrimFunc prim_func = Downcast<PrimFunc>(mod->Lookup(name));
    GetPerStoreFeature(prim_func, task->hardware_params->cache_line_bytes, max_n_bufs, feature);
  } catch (Error& e) {
    (*error_ct)++;
  }
}

void GetPerStoreFeaturesFromStates(const Array<State>& states, const SearchTask& task,
                                   int skip_first_n_feature_extraction, int max_n_bufs,
                                   std::vector<std::vector<float>>* features) {
  // extract features
  features->assign(states.size(), std::vector<float>());

  std::atomic<int> error_ct(0);

  support::parallel_for(skip_first_n_feature_extraction, states.size(),
                        [&task, &states, &max_n_bufs, &features, &error_ct](int i) {
                          GetPerStoreFeaturesWorkerFunc(task, states[i], max_n_bufs,
                                                        &(*features)[i], &error_ct);
                        });
}

void GetPerStoreFeaturesFromStates(const Array<State>& states, const std::vector<SearchTask>& tasks,
                                   int skip_first_n_feature_extraction, int max_n_bufs,
                                   std::vector<std::vector<float>>* features) {
  // extract features
  features->assign(states.size(), std::vector<float>());

  std::atomic<int> error_ct(0);

  support::parallel_for(skip_first_n_feature_extraction, states.size(),
                        [&tasks, &states, &max_n_bufs, &features, &error_ct](int i) {
                          GetPerStoreFeaturesWorkerFunc(tasks[i], states[i], max_n_bufs,
                                                        &(*features)[i], &error_ct);
                        });
}
1455
1456void GetPerStoreFeaturesFromFile(const std::string& filename, int max_lines, int max_n_bufs,
1457 std::vector<std::vector<float>>* features,
1458 std::vector<float>* normalized_throughputs,
1459 std::vector<int>* task_ids) {
1460 Array<State> states;
1461 std::vector<SearchTask> tasks;
1462
1463 normalized_throughputs->clear();
1464 task_ids->clear();
1465
1466 // (workload_key, target) -> (search_task, task_id)
1467 std::unordered_map<std::pair<std::string, std::string>, std::pair<SearchTask, size_t>> task_cache;
1468 // task_id -> min_cost
1469 std::vector<float> min_costs;
1470
1471 const auto* workload_key_to_tensors =
1472 tvm::runtime::Registry::Get("auto_scheduler.workload_key_to_tensors");
1473 ICHECK(workload_key_to_tensors != nullptr);
1474
1475 // read from file
1476 RecordReader reader(filename);
1477 auto cur_inp = make_object<MeasureInputNode>();
1478 auto cur_res = make_object<MeasureResultNode>();
1479 while (reader->ReadNext(cur_inp.get(), cur_res.get())) {
1480 float cost = static_cast<float>(FloatArrayMean(cur_res->costs));
1481 const std::string& workload_key = cur_inp->task->workload_key;
1482
1483 SearchTask task;
1484 size_t task_id;
1485 std::pair<std::string, std::string> key(workload_key, cur_inp->task->target->str());
1486 auto find_res = task_cache.find(key);
1487 if (find_res == task_cache.end()) {
1488 // rebuild task
1489 Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
1490 Target target = cur_inp->task->target;
1491 Target target_host = cur_inp->task->target_host;
1492 CheckAndUpdateHostConsistency(&target, &target_host);
1493 task = SearchTask(ComputeDAG(tensors), workload_key, target, target_host,
1494 cur_inp->task->hardware_params, cur_inp->task->layout_rewrite_option,
1495 cur_inp->task->task_input_names);
1496 task_id = task_cache.size();
1497
1498 // compute min cost for each task
1499 task_cache.insert(std::make_pair(key, std::make_pair(task, task_id)));
1500 min_costs.push_back(cost);
1501 } else {
1502 std::tie(task, task_id) = find_res->second;
1503 min_costs[task_id] = std::min(min_costs[task_id], cost);
1504 }
1505
1506 tasks.push_back(std::move(task));
1507 task_ids->push_back(task_id);
1508 states.push_back(cur_inp->state);
1509 normalized_throughputs->push_back(cost);
1510
1511 if (max_lines > 0 && static_cast<int>(states.size()) >= max_lines) {
1512 break;
1513 }
1514 }
1515
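  // Normalize: each record's score becomes min_cost(task) / cost, so the fastest record
  // of a task maps to 1.0 and slower ones fall into (0, 1). For example, costs of
  // {2ms, 1ms, 4ms} for one task normalize to {0.5, 1.0, 0.25}.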
  for (size_t i = 0; i < normalized_throughputs->size(); ++i) {
    (*normalized_throughputs)[i] = min_costs[(*task_ids)[i]] / (*normalized_throughputs)[i];
  }

  GetPerStoreFeaturesFromStates(states, tasks, 0, max_n_bufs, features);
}

void GetPerStoreFeaturesFromMeasurePairs(const Array<MeasureInput>& inputs,
                                         const Array<MeasureResult>& results,
                                         int skip_first_n_feature_extraction, int max_n_bufs,
                                         std::vector<std::vector<float>>* features,
                                         std::vector<float>* normalized_throughputs,
                                         std::vector<int>* task_ids) {
  Array<State> states;
  std::vector<SearchTask> tasks;

  normalized_throughputs->clear();
  task_ids->clear();

  // (workload_key, target) -> (search_task, task_id)
  std::unordered_map<std::pair<std::string, std::string>, std::pair<SearchTask, size_t>> task_cache;
  // task_id -> min_cost
  std::vector<float> min_costs;

  const auto* workload_key_to_tensors =
      tvm::runtime::Registry::Get("auto_scheduler.workload_key_to_tensors");
  ICHECK(workload_key_to_tensors != nullptr);

  tasks.reserve(inputs.size());
  normalized_throughputs->reserve(inputs.size());
  task_ids->reserve(inputs.size());
  for (size_t i = 0; i < inputs.size(); ++i) {
    float cost = static_cast<float>(FloatArrayMean(results[i]->costs));
    const std::string& workload_key = inputs[i]->task->workload_key;
    SearchTask task;

    size_t task_id;
    std::pair<std::string, std::string> key(workload_key, inputs[i]->task->target->str());
    auto find_res = task_cache.find(key);
    if (find_res == task_cache.end()) {
      if (inputs[i]->task->compute_dag.defined()) {  // the measure input is complete
        task = inputs[i]->task;
      } else {
        // The measure input is incomplete; rebuild the task for measure pairs read from a file.
        try {
          Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
          Target target = inputs[i]->task->target;
          Target target_host = inputs[i]->task->target_host;
          CheckAndUpdateHostConsistency(&target, &target_host);
          task =
              SearchTask(ComputeDAG(tensors), workload_key, target, target_host,
                         inputs[i]->task->hardware_params, inputs[i]->task->layout_rewrite_option,
                         inputs[i]->task->task_input_names);
        } catch (std::exception& e) {
          // Cannot build the ComputeDAG from the workload key; the task may not have been
          // registered in this search round.
          continue;
        }
      }
      task_id = task_cache.size();

      // compute min cost for each task
      task_cache.insert(std::make_pair(key, std::make_pair(task, task_id)));
      min_costs.push_back(cost);
    } else {
      std::tie(task, task_id) = find_res->second;
      min_costs[task_id] = std::min(min_costs[task_id], cost);
    }

    tasks.push_back(std::move(task));
    task_ids->push_back(task_id);
    states.push_back(inputs[i]->state);
    normalized_throughputs->push_back(cost);
  }

  for (size_t i = 0; i < normalized_throughputs->size(); ++i) {
    (*normalized_throughputs)[i] = min_costs[(*task_ids)[i]] / (*normalized_throughputs)[i];
  }

  GetPerStoreFeaturesFromStates(states, tasks, skip_first_n_feature_extraction, max_n_bufs,
                                features);
}

1599/*
1600 * \brief Serialize a two-dimensional variable-size feature vector with normalized throughputs
1601 * and task ids to a one-dimensional flatten byte array.
1602 *
1603 * For faster data copy between c++ and python, the c++ part returns features in a single
1604 * flatten array using a packed format. The python part then unpacks the flatten array.
1605 *
1606 * The packed format for n records is:
1607 * {
1608 * int n;
1609 * int sizes[n+2]; // The sizes for the following arrays
1610 *
1611 * float features_0[size[0]]; // The features for record 0
1612 * float features_1[size[1]]; // The features for record 1
1613 * ...
1614 * float features_i[size[i]]; // The features for record i
1615 * ... // until i == n - 1
1616 *
1617 * float throughputs[sizes[n]]; // The normalized throughputs for n records
1618 * int task_ids[size[n+1]]; // The task ids for n records
1619 *
1620 * }
1621 * To implement this format, we also store int as float, so we can store all numbers
1622 * into a single float array.
1623 */
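/*
 * An illustrative example (values hypothetical): for n = 2 records with 3 and 2
 * features respectively, 2 throughputs, and 2 task ids, the packed array is
 * {
 *   2,              // n
 *   3, 2, 2, 2,     // sizes: feature lengths, then #throughputs, then #task_ids
 *   f00, f01, f02,  // features of record 0
 *   f10, f11,       // features of record 1
 *   t0, t1,         // normalized throughputs
 *   id0, id1        // task ids
 * }
 */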
TVMByteArray SerializeFeatures(std::vector<std::vector<float>>&& features,
                               std::vector<float>&& normalized_throughputs,
                               std::vector<int>&& task_ids, std::vector<char>* out_data) {
  size_t total_bytes = 0;
  std::vector<int> size_vector;

  int n = features.size();

  // serialize sizes
  size_t size_vector_size = 1 + n + 2;
  total_bytes += size_vector_size * sizeof(int);

  size_vector.reserve(size_vector_size);
  size_vector.push_back(features.size());
  for (const auto& x : features) {
    size_vector.push_back(static_cast<int>(x.size()));
    total_bytes += sizeof(float) * x.size();
  }
  size_vector.push_back(static_cast<int>(normalized_throughputs.size()));
  total_bytes += sizeof(float) * normalized_throughputs.size();
  size_vector.push_back(static_cast<int>(task_ids.size()));
  total_bytes += sizeof(int) * task_ids.size();

  ICHECK_EQ(size_vector.size(), size_vector_size);

  // allocate memory; use resize() rather than reserve() so the writes through data()
  // below stay within the vector's size
  out_data->resize(total_bytes);
  char* ptr = out_data->data();

  // serialize size_vector
  memmove(ptr, reinterpret_cast<char*>(size_vector.data()), size_vector.size() * sizeof(int));
  ptr += size_vector.size() * sizeof(int);

  // serialize features
  for (auto& x : features) {
    memmove(ptr, x.data(), sizeof(float) * x.size());
    ptr += sizeof(float) * x.size();
    x.clear();
  }

  // serialize normalized_throughputs
  memmove(ptr, reinterpret_cast<char*>(normalized_throughputs.data()),
          normalized_throughputs.size() * sizeof(float));
  ptr += normalized_throughputs.size() * sizeof(float);

  // serialize task_ids
  memmove(ptr, reinterpret_cast<char*>(task_ids.data()), task_ids.size() * sizeof(int));
  ptr += task_ids.size() * sizeof(int);

  ICHECK_EQ(ptr - out_data->data(), total_bytes);

  return TVMByteArray{out_data->data(), total_bytes};
}

TVM_REGISTER_GLOBAL("auto_scheduler.GetPerStoreFeaturesFromFile")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      std::string filename = args[0];
      int max_lines = args[1];
      int max_n_bufs = args[2];

      std::vector<std::vector<float>> features;
      std::vector<float> normalized_throughputs;
      std::vector<int> task_ids;

      GetPerStoreFeaturesFromFile(filename, max_lines, max_n_bufs, &features,
                                  &normalized_throughputs, &task_ids);

      std::vector<char> byte_data;
      *ret = SerializeFeatures(std::move(features), std::move(normalized_throughputs),
                               std::move(task_ids), &byte_data);
    });

TVM_REGISTER_GLOBAL("auto_scheduler.GetPerStoreFeaturesFromMeasurePairs")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      Array<MeasureInput> inputs = args[0];
      Array<MeasureResult> results = args[1];
      int skip_first_n_feature_extraction = args[2];
      int max_n_bufs = args[3];

      std::vector<std::vector<float>> features;
      std::vector<float> normalized_throughputs;
      std::vector<int> task_ids;

      GetPerStoreFeaturesFromMeasurePairs(inputs, results, skip_first_n_feature_extraction,
                                          max_n_bufs, &features, &normalized_throughputs,
                                          &task_ids);

      std::vector<char> byte_data;
      *ret = SerializeFeatures(std::move(features), std::move(normalized_throughputs),
                               std::move(task_ids), &byte_data);
    });

TVM_REGISTER_GLOBAL("auto_scheduler.GetPerStoreFeaturesFromStates")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      Array<State> states = args[0];
      SearchTask task = args[1];
      int max_n_bufs = args[2];

      std::vector<std::vector<float>> features;
      std::vector<float> normalized_throughputs;
      std::vector<int> task_ids;

      GetPerStoreFeaturesFromStates(states, task, 0, max_n_bufs, &features);

      std::vector<char> byte_data;
      *ret = SerializeFeatures(std::move(features), std::move(normalized_throughputs),
                               std::move(task_ids), &byte_data);
    });

TVM_REGISTER_GLOBAL("auto_scheduler.GetPerStoreFeatureNames")
    .set_body([](TVMArgs args, TVMRetValue* ret) {
      int max_n_bufs = args[0];
      std::vector<std::string> names;

      GetPerStoreFeatureName(max_n_bufs, &names);

      Array<String> arr;
      for (const auto& x : names) {
        arr.push_back(x);
      }
      *ret = arr;
    });

TVM_REGISTER_GLOBAL("auto_scheduler.FeaturesFromPrimFunc")
    .set_body_typed([](const PrimFunc& func, int cache_line_size, int max_n_bufs, bool log_scale) {
      std::vector<float> vec;
      GetPerStoreFeature(func, cache_line_size, max_n_bufs, &vec, log_scale);
      int64_t num_feature_rows = vec[0];  // first element is number of rows
      int64_t row_length = 0;
      if (num_feature_rows != 0) {
        row_length = (vec.size() - 1) / num_feature_rows;
      }
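      // vec is packed as {num_rows, row_0 ..., row_{n-1} ...}; skip the leading count
      // and copy the remaining floats into a [num_feature_rows x row_length] array.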
      auto ary =
          runtime::NDArray::Empty({num_feature_rows, row_length}, {kDLFloat, 32, 1}, {kDLCPU, 0});
      // NDArray is row major by default
      ary.CopyFromBytes(vec.data() + 1, sizeof(float) * num_feature_rows * row_length);
      return ary;
    });

}  // namespace auto_scheduler
}  // namespace tvm