1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include <tvm/tir/transform.h>
20
#include <algorithm>
#include <cctype>
#include <cmath>
#include <limits>
#include <memory>
#include <numeric>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
28
29#include "../utils.h"
30
31namespace tvm {
32namespace tir {
33
34using support::NDIntSet;
35
36/*! \brief Type for multi-dimensional index */
37using MultiIndex = std::vector<PrimExpr>;
38/*! \brief Vector of int64_t */
39using IntVec = std::vector<int64_t>;
40/*! \brief Vector of for loops */
41using ForVec = std::vector<const ForNode*>;
42
43/*!
44 * \brief An unordered_map for (for, buffer) => V
45 * \tparam V The value type
46 */
47template <class V>
48using ForBufferMap = std::unordered_map<const ForNode*, std::unordered_map<const BufferNode*, V>>;
49
50/*! \brief Given x, compute log2(|x| + 1) */
51inline double slog(double x) { return x >= 0 ? std::log2(x + 1) : std::log2(-x + 1); }
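// For example, slog(7) == slog(-7) == log2(8) == 3, and slog(0) == 0.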
52
53namespace utils {
54
/*!
 * \brief Get the shape of the buffer as constant integers
 * \param buffer The buffer
 * \param analyzer The analyzer used to bound symbolic shape dimensions
 * \return The shape of the buffer; a symbolic dimension is replaced by its constant upper bound,
 * or by 1 if no finite bound can be proven
 */
61std::vector<int64_t> GetBufferShape(const Buffer& buffer, arith::Analyzer* analyzer) {
62 int ndim = buffer->shape.size();
63 std::vector<int64_t> result;
64 result.reserve(ndim);
65 for (const PrimExpr& i : buffer->shape) {
66 if (const IntImmNode* int_imm = i.as<IntImmNode>()) {
67 result.push_back(int_imm->value);
68 continue;
69 }
70 arith::ConstIntBound bound = analyzer->const_int_bound(i);
71 if (0 <= bound->max_value && bound->max_value < arith::ConstIntBound::kPosInf) {
72 result.push_back(bound->max_value);
73 } else {
74 result.push_back(1);
75 }
76 }
77 return result;
78}
79
80/*!
81 * \brief Given a loop, return its `pragma_auto_unroll_max_step` annotation if it exists
82 * \param loop The loop to be checked
83 * \return The value of `pragma_auto_unroll_max_step` if it exists, or -1 if it does not exist
84 */
85int64_t GetPragmaAutoUnroll(const ForNode* loop) {
86 if (Optional<IntImm> auto_unroll = GetAnn<IntImm>(loop, tir::attr::pragma_auto_unroll_max_step)) {
87 return auto_unroll.value()->value;
88 }
89 return -1;
90}
91
/*!
 * \brief Return the constant extent of the first loop in the list, if any
 * \param loops The list of loops to be checked
 * \param default_value The value to be returned if the list is empty or the first loop does not
 * have a constant extent
 * \return The extent of the first loop if the list is non-empty and the first loop has a constant
 * extent; otherwise the default value
 */
101int64_t FirstLoopExtent(const ForVec& loops, int64_t default_value) {
102 if (!loops.empty()) {
103 if (const int64_t* extent = GetLoopIntExtent(loops[0])) {
104 return *extent;
105 }
106 }
107 return default_value;
108}
109
/*!
 * \brief Relax each multi-indexing pattern according to the domains bound in the analyzer,
 * and then union them into a single region
 * \param multi_indices The list of multi-index patterns to be relaxed
 * \param numel Output: the total number of elements in the unioned region
 * \param analyzer The analyzer that contains the domain information
 * \return The shape of the relaxed and unioned region
 */
118IntVec RelaxAndUnion(const std::vector<MultiIndex>& multi_indices, int64_t* numel,
119 arith::Analyzer* analyzer) {
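  // Illustrative example: with the two index patterns {i, j} and {i, j + 4}, where the analyzer
  // bounds i to [0, 16) and j to [0, 8), the patterns relax to [0, 16) x [0, 8) and
  // [0, 16) x [4, 12); their union gives access_shape = {16, 12} and *numel = 192.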
120 *numel = 1;
121 if (multi_indices.empty()) {
122 return {};
123 }
124 int n_indices = multi_indices.size();
125 int ndim = multi_indices[0].size();
126 IntVec access_shape(ndim, 0);
127 for (int i = 0; i < ndim; ++i) {
128 int64_t minimum = arith::ConstIntBound::kPosInf;
129 int64_t maximum = arith::ConstIntBound::kNegInf;
130 for (int j = 0; j < n_indices; ++j) {
131 arith::ConstIntBound bound = analyzer->const_int_bound(multi_indices[j][i]);
132 minimum = std::min(minimum, bound->min_value);
133 maximum = std::max(maximum, bound->max_value);
134 }
135 *numel *= maximum - minimum + 1;
136 access_shape[i] = maximum - minimum + 1;
137 }
138 return access_shape;
139}
140
/*!
 * \brief Given a list of multi-index patterns, return the minimal stride of a variable across them
 * \param multi_indices The list of multi-index patterns
 * \param buffer_stride The strides of the buffer, one per dimension
 * \param var The variable to be checked
 * \return The minimal stride of the variable across the multi-index patterns, or 0 if the
 * variable does not appear
 */
148int64_t GetVarStride(const std::vector<MultiIndex>& multi_indices, const IntVec& buffer_stride,
149 const Var& var) {
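  // Illustrative example: with buffer_stride = {1024, 32, 1} and a multi-index {i, j * 2, k},
  // the stride of `j` is |2| * buffer_stride[1] = 64, taken from the rightmost dimension in
  // which `j` appears.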
150 class CoefficientExtractor : private ExprVisitor {
151 public:
152 static int64_t Extract(const PrimExpr& expr, const Var& var) {
153 CoefficientExtractor extractor(var);
154 extractor.VisitExpr(expr);
155 return (extractor.visited_var && !extractor.visited_mul && !extractor.visited_add)
156 ? 1
157 : (extractor.visited_var ? extractor.stride : 0);
158 }
159
160 private:
161 explicit CoefficientExtractor(const Var& var)
162 : var(var), stride(0), visited_var(false), visited_add(false), visited_mul(false) {}
163
164 void VisitExpr_(const MulNode* node) override {
165 ExprVisitor::VisitExpr_(node);
166 if (visited_var && !visited_add) {
167 if (const auto* a = node->a.as<IntImmNode>()) {
168 visited_mul = true;
169 stride = a->value;
170 } else if (const auto* b = node->b.as<IntImmNode>()) {
171 visited_mul = true;
172 stride = b->value;
173 }
174 }
175 }
176
177 void VisitExpr_(const AddNode* node) override {
178 ExprVisitor::VisitExpr_(node);
179 if (visited_var && !visited_mul) {
180 visited_add = true;
181 stride = 1;
182 }
183 }
184
185 void VisitExpr_(const VarNode* node) override {
186 if (node == var.get()) {
187 visited_var = true;
188 stride = 2;
189 }
190 }
191
192 const Var& var;
193 int64_t stride;
194 bool visited_var;
195 bool visited_add;
196 bool visited_mul;
197 };
198
199 constexpr int64_t kNotFound = std::numeric_limits<int64_t>::max();
200 int ndim = buffer_stride.size();
201 // Calculate the min stride possible
202 int64_t result = kNotFound;
203 for (const MultiIndex& multi_index : multi_indices) {
204 ICHECK_EQ(multi_index.size(), buffer_stride.size());
    // Find the rightmost dimension that contains the given variable
206 for (int i = ndim - 1; i >= 0; --i) {
207 int64_t coef = CoefficientExtractor::Extract(multi_index[i], var);
208 if (coef != 0) {
209 result = std::min(result, std::abs(coef) * buffer_stride[i]);
210 break;
211 }
212 }
213 }
214 return (result == kNotFound) ? 0 : result;
215}
216
217/*!
218 * \brief Converts a 2-dimensional STL vector to a TVM NDArray
219 * \param src The source 2-dimensional STL vector
220 * \return The converted TVM NDArray
221 */
222runtime::NDArray AsNDArray(const std::vector<std::vector<double>>& src) {
223 ICHECK(!src.empty());
224 int n = src.size();
225 int m = src[0].size();
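  // All rows are assumed to have the same length as src[0]; values are copied in row-major order.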
226 runtime::NDArray tgt = runtime::NDArray::Empty(
227 /*shape=*/{n, m},
228 /*dtype=*/DLDataType{kDLFloat, 64, 1},
229 /*ctx=*/DLDevice{kDLCPU, 0});
230 double* data = static_cast<double*>(tgt->data);
231 for (const std::vector<double>& row : src) {
232 for (double v : row) {
233 *data++ = v;
234 }
235 }
236 return tgt;
237}
238
239} // namespace utils
240
241namespace transform {
242
243/*!
244 * \brief Create a pass that simplifies the IR for feature extraction
245 * \return The pass created
246 */
247Pass SimplifyForFeatureExtraction() {
248 class Simplifier : private StmtExprMutator {
249 public:
250 static Stmt Run(Stmt stmt) { return Simplifier()(std::move(stmt)); }
251
252 private:
253 static bool HasBufferLoad(const PrimExpr& expr) {
254 bool found = false;
255 PostOrderVisit(expr, [&found](const ObjectRef& node) {
256 if (node->IsInstance<BufferLoadNode>()) {
257 found = true;
258 }
259 });
260 return found;
261 }
262
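    // A Select whose operands never touch memory does not affect per-store memory features, so
    // it is folded into a constant to simplify the downstream analysis.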
263 PrimExpr VisitExpr_(const SelectNode* node) final {
264 if (HasBufferLoad(node->true_value) || HasBufferLoad(node->false_value) ||
265 HasBufferLoad(node->condition)) {
266 return GetRef<Select>(node);
267 }
268 return make_const(node->dtype, 1.0);
269 }
270
271 PrimExpr VisitExpr_(const VarNode* var) final {
272 if (unit_vars_.count(GetRef<Var>(var))) {
273 return make_const(var->dtype, 0.0);
274 }
275 return GetRef<Var>(var);
276 }
277
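    // Single-iteration serial loops are inlined: the loop is dropped and its loop variable is
    // substituted with 0 (see the VarNode visitor above).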
278 Stmt VisitStmt_(const ForNode* loop) final {
279 if (is_zero(loop->min) && is_one(loop->extent) && loop->kind == ForKind::kSerial &&
280 loop->annotations.empty()) {
281 unit_vars_.insert(loop->loop_var);
282 return VisitStmt(loop->body);
283 } else {
284 return StmtExprMutator::VisitStmt_(loop);
285 }
286 }
287
288 std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> unit_vars_;
289 };
290 auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
291 PrimFuncNode* n = f.CopyOnWrite();
292 n->body = Simplifier::Run(std::move(n->body));
293 return f;
294 };
295 return CreatePrimFuncPass(pass_func, 0, "tir.SimplifyForFeatureExtraction", {});
296}
297
298/*!
299 * \brief Create a list of passes that preprocesses the IR for feature extraction
300 * \return The list of passes created
301 */
302Sequential PassListForPerStoreFeature() {
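  // Lower block constructs and tighten buffer allocations so that loop nests and buffer accesses
  // appear explicitly in the TIR before features are extracted.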
303 return Sequential({
304 tir::transform::RemoveWeightLayoutRewriteBlock(/*skip_ndarray_rewrite*/ true),
305 tir::transform::SimplifyForFeatureExtraction(),
306 tir::transform::LowerCrossThreadReduction(),
307 tir::transform::LowerInitBlock(),
308 tir::transform::PlanAndUpdateBufferAllocationLocation(),
309 tir::transform::ConvertBlocksToOpaque(),
310 tir::transform::UnifyThreadBinding(),
311 tir::transform::CompactBufferAllocation(),
312 tir::transform::LowerMatchBuffer(),
313 tir::transform::Simplify(),
314 });
315}
316
317} // namespace transform
318
319/*! \brief A data structure managing loop nests */
320struct LoopNest {
321 int64_t prod = 1; // The product of the extents of all the loops
322 ForVec loops; // All the loops
323 IntVec auto_unroll; // The loops with auto unroll pragma
324 ForVec parallel; // The loops whose ForKind are kParallel
325 ForVec vectorize; // The loops whose ForKind are kVectorized
326 ForVec unroll; // The loops whose ForKind are kUnrolled
327 ForVec blockIdx_x; // The loops whose ForKind are kThreadBinding to blockIdx.x
328 ForVec blockIdx_y; // The loops whose ForKind are kThreadBinding to blockIdx.y
329 ForVec blockIdx_z; // The loops whose ForKind are kThreadBinding to blockIdx.z
330 ForVec threadIdx_x; // The loops whose ForKind are kThreadBinding to threadIdx.x
331 ForVec threadIdx_y; // The loops whose ForKind are kThreadBinding to threadIdx.y
332 ForVec threadIdx_z; // The loops whose ForKind are kThreadBinding to threadIdx.z
333 ForVec vthread; // The loops whose ForKind are kThreadBinding to vthread.*
334
  /*!
   * \brief Push a new loop into the loop nest
   * \param loop The loop to be pushed
   * \param auto_unroll_attr Output: the value of the loop's `pragma_auto_unroll_max_step`
   * annotation, or -1 if the annotation is absent
   * \return A pointer to the category (ForVec) the loop is classified into, or nullptr for a
   * plain serial loop
   */
341 ForVec* Push(const ForNode* loop, int64_t* auto_unroll_attr) {
342 if (const int64_t* extent = GetLoopIntExtent(loop)) {
343 this->prod *= *extent;
344 }
345 this->loops.push_back(loop);
346 if ((*auto_unroll_attr = utils::GetPragmaAutoUnroll(loop)) > 0) {
347 this->auto_unroll.push_back(*auto_unroll_attr);
348 }
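    // Classify the loop by its ForKind and, for thread bindings, by its thread tag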
349 ForVec* ref_loops = nullptr;
350 if (loop->kind == ForKind::kParallel) {
351 ref_loops = &parallel;
352 } else if (loop->kind == ForKind::kVectorized) {
353 ref_loops = &vectorize;
354 } else if (loop->kind == ForKind::kUnrolled) {
355 ref_loops = &unroll;
356 } else if (loop->kind == ForKind::kThreadBinding) {
357 std::string thread_tag = loop->thread_binding.value()->thread_tag;
358 if (thread_tag == "blockIdx.x") {
359 ref_loops = &blockIdx_x;
360 } else if (thread_tag == "blockIdx.y") {
361 ref_loops = &blockIdx_y;
362 } else if (thread_tag == "blockIdx.z") {
363 ref_loops = &blockIdx_z;
364 } else if (thread_tag == "threadIdx.x") {
365 ref_loops = &threadIdx_x;
366 } else if (thread_tag == "threadIdx.y") {
367 ref_loops = &threadIdx_y;
368 } else if (thread_tag == "threadIdx.z") {
369 ref_loops = &threadIdx_z;
370 } else if (support::StartsWith(thread_tag, "vthread")) {
371 ref_loops = &vthread;
372 } else {
373 LOG(FATAL) << "ValueError: Unable to recognize thread tag: " << thread_tag;
374 }
375 }
376 if (ref_loops != nullptr) {
377 ref_loops->push_back(loop);
378 }
379 return ref_loops;
380 }
381
  /*!
   * \brief Pop the last loop from the loop nest
   * \param loop The loop to be popped
   * \param ref_loops The category (ForVec) returned by the matching call to `Push`, if any
   * \param auto_unroll_attr The auto unroll attribute reported by the matching call to `Push`
   */
388 void Pop(const ForNode* loop, ForVec* ref_loops, int auto_unroll_attr) {
389 if (ref_loops) {
390 ref_loops->pop_back();
391 }
392 if (auto_unroll_attr > 0) {
393 this->auto_unroll.pop_back();
394 }
395 if (const int64_t* extent = GetLoopIntExtent(loop)) {
396 this->prod /= *extent;
397 }
398 this->loops.pop_back();
399 }
400};
401
402/****** Group 1: Computation related features ******/
403
404namespace group1 {
405
406/*! \brief Group 1 features */
407struct Feature {
408 /*! \brief Arithmetic features */
409 struct ArithOps {
    // Floating-point arithmetic features
411 int64_t float_mad = 0; // The number of float MAD (Multiply–add) ops
412 int64_t float_add_sub = 0; // The number of float add and sub ops
413 int64_t float_mul = 0; // The number of float multiply ops
414 int64_t float_div_mod = 0; // The number of float div and mod ops
415 int64_t float_cmp = 0; // The number of float comparison ops
416 int64_t float_math_func = 0; // The number of float math func calls
417 int64_t float_other_func = 0; // The number of other float func calls
418 // Integer arithmetic features
419 int64_t int_mad = 0; // The number of integer MAD (Multiply–add) ops
420 int64_t int_add_sub = 0; // The number of integer add and sub ops
421 int64_t int_mul = 0; // The number of integer multiply ops
422 int64_t int_div_mod = 0; // The number of integer div and mod ops
423 int64_t int_cmp = 0; // The number of integer comparison ops
424 int64_t int_math_func = 0; // The number of integer math func calls
425 int64_t int_other_func = 0; // The number of other integer func calls
426 // Other arithmetic features
427 int64_t bool_op = 0; // The number of bool ops
428 int64_t select_op = 0; // The number of select ops
429
430 static constexpr int64_t kCount = 16;
431
432 ArithOps() = default;
433 ArithOps(const BufferStoreNode* store, int64_t prod_loop_extent);
434
435 void Export(std::vector<double>* v) const {
436 double vs[] = {
437 slog(float_mad), slog(float_add_sub), slog(float_mul), slog(float_div_mod),
438 slog(float_cmp), slog(float_math_func), slog(float_other_func), //
439 slog(int_mad), slog(int_add_sub), slog(int_mul), slog(int_div_mod),
440 slog(int_cmp), slog(int_math_func), slog(int_other_func), //
441 slog(bool_op), slog(select_op),
442 };
443 v->insert(v->end(), std::begin(vs), std::end(vs));
444 }
445 };
446
447 /*! \brief Loop binding features */
448 struct ForKindFeature {
449 enum class Pos : int {
450 kPosNone = 0, // Does not have this kind of annotation
451 kPosInnerSpatial = 1, // The annotated iterator is the innermost spatial iterator
452 kPosMiddleSpatial = 2, // The annotated iterator is a middle spatial iterator
453 kPosOuterSpatial = 3, // The annotated iterator is the outermost spatial iterator
454 kPosInnerReduce = 4, // The annotated iterator is the innermost reduce iterator
455 kPosMiddleReduce = 5, // The annotated iterator is a middle reduce iterator
456 kPosOuterReduce = 6, // The annotated iterator is the outermost reduce iterator
457 kPosMixed = 7, // The annotated iterator is a mixed space and reduce iterator
458 };
459 int64_t num = 0; // The number of iterators with the annotation
460 int64_t prod = 0; // The product of the lengths of iterators with the annotation
461 int64_t len = 0; // The length of the innermost iterator with the annotation
462 Pos pos = Pos::kPosMixed; // The position of the iterators with the annotation
463
464 static constexpr int64_t kCount = 11;
465
466 explicit ForKindFeature(const ForVec& loops);
467
468 void Export(std::vector<double>* v) const {
469 double vs[] = {
470 slog(num),
471 slog(prod),
472 slog(len),
473 static_cast<double>(static_cast<int>(pos) == 0),
474 static_cast<double>(static_cast<int>(pos) == 1),
475 static_cast<double>(static_cast<int>(pos) == 2),
476 static_cast<double>(static_cast<int>(pos) == 3),
477 static_cast<double>(static_cast<int>(pos) == 4),
478 static_cast<double>(static_cast<int>(pos) == 5),
479 static_cast<double>(static_cast<int>(pos) == 6),
480 static_cast<double>(static_cast<int>(pos) == 7),
481 };
482 v->insert(v->end(), std::begin(vs), std::end(vs));
483 }
484 };
485
486 ArithOps arith_ops; // Arithmetic features
487 ForKindFeature vectorize; // Loop binding features: kVectorize
488 ForKindFeature unroll; // Loop binding features: kUnroll
489 ForKindFeature parallel; // Loop binding features: kParallel
490 bool is_gpu = false; // If the program is running on GPU
491 int64_t blockIdx_x_len = 1; // The length of blockIdx.x
492 int64_t blockIdx_y_len = 1; // The length of blockIdx.y
493 int64_t blockIdx_z_len = 1; // The length of blockIdx.z
494 int64_t threadIdx_x_len = 1; // The length of threadIdx.x
495 int64_t threadIdx_y_len = 1; // The length of threadIdx.y
496 int64_t threadIdx_z_len = 1; // The length of threadIdx.z
497 int64_t vthread_len = 1; // The length of virtual thread
498
499 static constexpr int64_t kCount = ArithOps::kCount + ForKindFeature::kCount * 3 + 8;
500
501 explicit Feature(const BufferStoreNode* store, const LoopNest& loop_nest, bool is_gpu)
502 : arith_ops(store, loop_nest.prod),
503 vectorize(loop_nest.vectorize),
504 unroll(loop_nest.unroll),
505 parallel(loop_nest.parallel) {
506 if (is_gpu) {
507 this->is_gpu = true;
508 this->blockIdx_x_len = utils::FirstLoopExtent(loop_nest.blockIdx_x, 1);
509 this->blockIdx_y_len = utils::FirstLoopExtent(loop_nest.blockIdx_y, 1);
510 this->blockIdx_z_len = utils::FirstLoopExtent(loop_nest.blockIdx_z, 1);
511 this->threadIdx_x_len = utils::FirstLoopExtent(loop_nest.threadIdx_x, 1);
512 this->threadIdx_y_len = utils::FirstLoopExtent(loop_nest.threadIdx_y, 1);
513 this->threadIdx_z_len = utils::FirstLoopExtent(loop_nest.threadIdx_z, 1);
514 this->vthread_len = utils::FirstLoopExtent(loop_nest.vthread, 1);
515 }
516 }
517
518 void Export(std::vector<double>* v) const {
519 this->arith_ops.Export(v);
520 this->vectorize.Export(v);
521 this->unroll.Export(v);
522 this->parallel.Export(v);
523 double vs[] = {
524 static_cast<double>(is_gpu), //
525 slog(blockIdx_x_len), slog(blockIdx_y_len), slog(blockIdx_z_len),
526 slog(threadIdx_x_len), slog(threadIdx_y_len), slog(threadIdx_z_len),
527 slog(vthread_len),
528 };
529 v->insert(v->end(), std::begin(vs), std::end(vs));
530 }
531};
532
533Feature::ArithOps::ArithOps(const BufferStoreNode* store, int64_t prod_loop_extent) {
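  // Count the arithmetic ops in the stored value; each op is weighted by `prod_loop_extent`,
  // i.e. the number of times the surrounding loop nest executes the store.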
534 class ArithOpCounter : public ExprVisitor {
535 public:
536#define TVM_FEATURE_SIMPLE(Type, Counter) \
537 void VisitExpr_(const Type* op) final { \
538 result_.Counter += this->prod_loop_extent_; \
539 ExprVisitor::VisitExpr_(op); \
540 }
541#define TVM_FEATURE_BINARY(Type, FloatCounter, IntCounter) \
542 void VisitExpr_(const Type* op) final { \
543 if (op->dtype.is_float()) { \
544 result_.FloatCounter += this->prod_loop_extent_; \
545 } else { \
546 result_.IntCounter += this->prod_loop_extent_; \
547 } \
548 ExprVisitor::VisitExpr_(op); \
549 }
550 TVM_FEATURE_SIMPLE(AndNode, bool_op);
551 TVM_FEATURE_SIMPLE(OrNode, bool_op);
552 TVM_FEATURE_SIMPLE(NotNode, bool_op);
553 TVM_FEATURE_SIMPLE(SelectNode, select_op);
554 TVM_FEATURE_BINARY(AddNode, float_add_sub, int_add_sub);
555 TVM_FEATURE_BINARY(SubNode, float_add_sub, int_add_sub);
556 TVM_FEATURE_BINARY(MulNode, float_mul, int_mul);
557 TVM_FEATURE_BINARY(DivNode, float_div_mod, int_div_mod);
558 TVM_FEATURE_BINARY(ModNode, float_div_mod, int_div_mod);
559 TVM_FEATURE_BINARY(FloorDivNode, float_div_mod, int_div_mod);
560 TVM_FEATURE_BINARY(FloorModNode, float_div_mod, int_div_mod);
561 TVM_FEATURE_BINARY(MaxNode, float_cmp, int_cmp);
562 TVM_FEATURE_BINARY(MinNode, float_cmp, int_cmp);
563 TVM_FEATURE_BINARY(EQNode, float_cmp, int_cmp);
564 TVM_FEATURE_BINARY(NENode, float_cmp, int_cmp);
565 TVM_FEATURE_BINARY(LTNode, float_cmp, int_cmp);
566 TVM_FEATURE_BINARY(LENode, float_cmp, int_cmp);
567 TVM_FEATURE_BINARY(GTNode, float_cmp, int_cmp);
568 TVM_FEATURE_BINARY(GENode, float_cmp, int_cmp);
569#undef TVM_FEATURE_BINARY
570#undef TVM_FEATURE_SIMPLE
571
572 void VisitExpr_(const CallNode* op) final {
573 static auto op_call_effect_ = Op::GetAttrMap<TCallEffectKind>("TCallEffectKind");
574 TCallEffectKind effect_kind = op_call_effect_[Downcast<Op>(op->op)];
575 bool is_pure =
576 effect_kind == CallEffectKind::kPure || effect_kind == CallEffectKind::kExprAnnotation;
577 if (is_pure) {
578 if (op->dtype.is_float()) {
579 result_.float_math_func += prod_loop_extent_;
580 } else {
581 result_.int_math_func += prod_loop_extent_;
582 }
583 } else {
584 if (op->dtype.is_float()) {
585 result_.float_other_func += prod_loop_extent_;
586 } else {
587 result_.int_other_func += prod_loop_extent_;
588 }
589 }
590 ExprVisitor::VisitExpr_(op);
591 }
592
593 int64_t prod_loop_extent_;
594 ArithOps result_;
595 };
596 ArithOpCounter counter;
597 counter.prod_loop_extent_ = prod_loop_extent;
598 counter(store->value);
599 *this = counter.result_;
600}
601
602Feature::ForKindFeature::ForKindFeature(const ForVec& loops) {
603 if (loops.empty()) {
604 this->num = 0;
605 this->prod = 0;
606 this->len = 0;
607 this->pos = ForKindFeature::Pos::kPosNone;
608 } else {
609 const int64_t* last_loop_extent = GetLoopIntExtent(loops.back());
610 this->num = loops.size();
611 this->len = last_loop_extent ? *last_loop_extent : 1;
612 this->pos = ForKindFeature::Pos::kPosMixed;
613 int64_t& prod = this->prod = 1;
614 for (const ForNode* loop : loops) {
615 if (const int64_t* extent = GetLoopIntExtent(loop)) {
616 prod *= *extent;
617 }
618 }
619 }
620}
621
622} // namespace group1
623
624namespace group2 {
625
626/*! \brief Group 2 features */
627struct Feature {
628 enum class AccessType : int {
629 /*! The buffer is read but not written */
630 kRead = 0,
631 /*! The buffer is written but not read */
632 kWrite = 1,
633 /*! The buffer is both read and written */
634 kReadWrite = 2,
635 /*! Unknown type */
636 kUnknownRW = 3,
637 };
638 enum class ReuseType : int {
639 /*! Buffer reuse because accessed on each iteration of a loop */
640 kLoopMultipleRead = 0,
641 /*! Buffer reuse because it is serially accessed */
642 kSerialMultipleReadWrite = 1,
643 /*! No buffer reuse */
644 kNoReuse = 2,
645 };
646
647 struct SubFeature {
648 /*! \brief The buffer this feature is for */
649 const BufferNode* buffer = nullptr;
650 /*! \brief The access type of the buffer */
651 AccessType access_type = AccessType::kUnknownRW;
    /*! \brief A list of multi-dimensional indices used to access the buffer */
653 std::vector<MultiIndex> multi_indices = {};
654 // Access information
655 /*! \brief loop_accessed_numel[i][...] means the number of elements accessed by loops[i] */
656 std::vector<std::unordered_map<const BufferNode*, int64_t>> loop_accessed_numel = {};
657 /*! \brief The shape of the data access */
658 IntVec access_shape;
659 /*! \brief The bytes that are continuously accessed */
660 int64_t num_continuous_bytes = 1;
661 // Stride information
662 /*! \brief The min stride of the access */
663 int64_t min_stride = 0;
664 /*! \brief The innermost stride */
665 int64_t innermost_stride = 0;
666 /*! \brief The product of the non-strided loops */
667 int64_t prod_non_strided_loop_extent = 0;
668 // Reuse information
669 /*! The type of data reuse */
670 ReuseType reuse_type = ReuseType::kNoReuse;
671 /*! The reuse distance in terms of number of iterations */
672 double reuse_dis_iter = 0.0;
673 /*! The reuse distance in terms of bytes */
674 double reuse_dis_bytes = 0.0;
675 /*! The reuse count */
676 int64_t reuse_ct = 0;
677 // Features
678 /*! The touched memory in bytes */
679 double bytes;
680 /*! The touched unique memory in bytes */
681 double unique_bytes;
682 /*! The number of touched cache lines */
683 double lines;
    /*! The number of touched unique cache lines */
685 double unique_lines;
686 /*! bytes / reuse_ct */
687 double bytes_d_reuse_ct;
688 /*! unique_bytes / reuse_ct */
689 double unique_bytes_d_reuse_ct;
690 /*! lines / reuse_ct */
691 double lines_d_reuse_ct;
692 /*! unique_lines / reuse_ct */
693 double unique_lines_d_reuse_ct;
694 /*! The stride in access */
695 double stride;
696
697 static constexpr int64_t kCount = 18;
698
699 void Export(std::vector<double>* v) const {
700 double vs[] = {
701 static_cast<double>(static_cast<int>(access_type) == 0),
702 static_cast<double>(static_cast<int>(access_type) == 1),
703 static_cast<double>(static_cast<int>(access_type) == 2),
704 // FeatureSet::BufferAccess::AccessType::kUnknownRW is ignored
705 slog(bytes),
706 slog(unique_bytes),
707 slog(lines),
708 slog(unique_lines),
709 static_cast<double>(static_cast<int>(reuse_type) == 0),
710 static_cast<double>(static_cast<int>(reuse_type) == 1),
711 static_cast<double>(static_cast<int>(reuse_type) == 2),
712 slog(reuse_dis_iter),
713 slog(reuse_dis_bytes),
714 slog(reuse_ct),
715 slog(bytes_d_reuse_ct),
716 slog(unique_bytes_d_reuse_ct),
717 slog(lines_d_reuse_ct),
718 slog(unique_lines_d_reuse_ct),
719 slog(stride),
720 };
721 v->insert(v->end(), std::begin(vs), std::end(vs));
722 }
723
    static void Pad(std::vector<double>* v) { v->insert(v->end(), kCount, 0.0); }
725
726 void SetStride(const LoopNest& loop_nest, arith::Analyzer* analyzer);
727
728 void SetReuse(const LoopNest& loop_nest, //
729 int64_t top_loop_touch_bytes, //
730 const ForBufferMap<IntVec>& buffer_touched_under_loop);
731
732 void SetFeature(const LoopNest& loop_nest, int64_t cache_line_bytes);
733
    explicit SubFeature(const BufferNode* buffer, AccessType access_type,
                        std::vector<MultiIndex> multi_indices, int n_loops)
        : buffer(buffer),
          access_type(access_type),
          multi_indices(std::move(multi_indices)),
          loop_accessed_numel(n_loops) {}
740 };
741
742 void Export(std::vector<double>* v, int buffers_per_store) const {
743 int n = sub_features.size();
744 for (int i = 0; i < buffers_per_store; ++i) {
745 if (i < n) {
746 sub_features[i].Export(v);
747 } else {
748 SubFeature::Pad(v);
749 }
750 }
751 }
752
753 explicit Feature(const BufferStoreNode* store, const LoopNest& loop_nest,
754 int64_t cache_line_bytes, IntVec* for_touched_bytes,
755 ForBufferMap<IntVec>* buffer_touched_under_loop, arith::Analyzer* analyzer);
756
757 void Init(const BufferStoreNode* store, int n_loops);
758
759 void SetRegion(const LoopNest& loop_nest, //
760 IntVec* for_touched_bytes, //
761 ForBufferMap<IntVec>* buffer_touched_under_loop, //
762 arith::Analyzer* analyzer);
763
764 std::vector<SubFeature> sub_features;
765};
766
767void Feature::Init(const BufferStoreNode* store, int n_loops) {
768 struct Info {
769 AccessType access_type = AccessType::kUnknownRW;
770 std::vector<MultiIndex> multi_indices;
771 };
772 std::unordered_map<const BufferNode*, Info> buffer_info;
773 {
774 Info& info = buffer_info[store->buffer.get()];
775 info.access_type = AccessType::kWrite;
776 info.multi_indices.push_back({store->indices.begin(), store->indices.end()});
777 }
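  // Scan the RHS for buffer loads. A buffer that is both the store target and loaded becomes
  // kReadWrite, and once a buffer is kReadWrite its load indices are no longer recorded.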
778 PostOrderVisit(store->value, [&buffer_info](const ObjectRef& obj) -> void {
779 if (const BufferLoadNode* load = obj.as<BufferLoadNode>()) {
780 const BufferNode* buffer = load->buffer.get();
781 Info& info = buffer_info[buffer];
782 switch (info.access_type) {
783 case AccessType::kRead:
784 break;
785 case AccessType::kWrite:
786 info.access_type = AccessType::kReadWrite;
787 break;
788 case AccessType::kReadWrite:
789 break;
790 case AccessType::kUnknownRW:
791 default:
792 info.access_type = AccessType::kRead;
793 break;
794 }
795 if (info.access_type != AccessType::kReadWrite) {
796 info.multi_indices.push_back({load->indices.begin(), load->indices.end()});
797 }
798 }
799 });
800 this->sub_features.reserve(buffer_info.size());
  for (auto& kv : buffer_info) {  // non-const so that `multi_indices` can actually be moved out
802 this->sub_features.emplace_back(kv.first, kv.second.access_type,
803 std::move(kv.second.multi_indices), n_loops);
804 }
805}
806
807void Feature::SetRegion(const LoopNest& loop_nest, IntVec* for_touched_bytes,
808 ForBufferMap<IntVec>* buffer_touched_under_loop,
809 arith::Analyzer* analyzer) {
810 int n_loops = loop_nest.loops.size();
811 const std::vector<const ForNode*>& loops = loop_nest.loops;
812 // Step 1. Initialize and bind all the loop variables to a constant
813 *for_touched_bytes = IntVec(n_loops, 0);
814 for (int i = 0; i < n_loops; ++i) {
815 const ForNode* loop = loops[i];
816 analyzer->Bind(loop->loop_var, loop->min, /*allow_override=*/true);
817 }
818 // Step 2. Corner case: no loops
819 if (n_loops == 0) {
    // With no loops there is nothing to relax; fall back to touching a single element per dimension
821 for (SubFeature& feature : sub_features) {
822 feature.access_shape = IntVec(feature.buffer->shape.size(), 1);
823 }
824 return;
825 }
826 // Step 3. Gradually bind the loops from inner to outer,
827 // calculate the area the loops touch on each buffer
828 for (int i = n_loops - 1; i >= 0; --i) {
829 const ForNode* loop = loops[i];
830 analyzer->Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent),
831 /*allow_override=*/true);
832 int64_t& touched_bytes = (*for_touched_bytes)[i] = 0;
833 for (SubFeature& feature : sub_features) {
834 const BufferNode* buffer = feature.buffer;
835 // Note: `feature.access_shape` for `i == 0` is the only one preserved,
836 // while others are discarded
837 int64_t numel;
838 feature.access_shape = utils::RelaxAndUnion(feature.multi_indices, &numel, analyzer);
839 numel = std::max<int64_t>(0, numel);
840 feature.loop_accessed_numel[i][buffer] = numel;
841 touched_bytes += numel * buffer->dtype.bytes();
842 (*buffer_touched_under_loop)[loop][buffer].push_back(numel);
843 }
844 }
845}
846
847void Feature::SubFeature::SetStride(const LoopNest& loop_nest, arith::Analyzer* analyzer) {
848 int n_loops = loop_nest.loops.size();
849 const std::vector<const ForNode*>& loops = loop_nest.loops;
850 // For each buffer, we find the loop stride on it
851 const BufferNode* buffer = this->buffer;
852 int ndim = this->buffer->shape.size();
853 IntVec buffer_shape = utils::GetBufferShape(GetRef<Buffer>(buffer), analyzer);
854 // Calculate the buffer's stride from its shape
855 IntVec buffer_stride(ndim);
856 if (ndim >= 1) {
857 buffer_stride[ndim - 1] = 1;
858 for (int i = ndim - 2; i >= 0; --i) {
859 buffer_stride[i] = buffer_stride[i + 1] * buffer_shape[i + 1];
860 }
861 }
862 // Calculate `num_continuous_bytes`
863 {
864 int64_t& num_continuous_bytes = this->num_continuous_bytes = 1;
865 const IntVec& access_shape = this->access_shape;
866 ICHECK_EQ(access_shape.size(), buffer_shape.size());
867 for (int i = ndim - 1; i >= 0; --i) {
868 if (access_shape[i] == buffer_shape[i]) {
869 num_continuous_bytes = buffer_shape[i] * buffer->dtype.bytes();
870 break;
871 }
872 }
873 }
874 // Enumerate loops from inner to outer
875 int i = 0;
876 // Calculate this->min_stride
877 int64_t& stride = this->min_stride = 0;
878 for (i = n_loops - 1; i >= 0; --i) {
879 stride = utils::GetVarStride(this->multi_indices, buffer_stride, loops[i]->loop_var);
880 if (stride != 0) {
881 break;
882 }
883 }
884 // Calculate this->innermost_stride
885 this->innermost_stride = (i == n_loops - 1) ? stride : 0;
886 // Calculate this->prod
887 int64_t& prod = this->prod_non_strided_loop_extent = 1;
888 for (int j = n_loops - 1; j > i; --j) {
889 if (const int64_t* extent = GetLoopIntExtent(loops[j])) {
890 prod *= *extent;
891 }
892 }
893}
894
895void Feature::SubFeature::SetReuse(const LoopNest& loop_nest, int64_t top_loop_touch_bytes,
896 const ForBufferMap<IntVec>& buffer_touched_under_loop) {
897 const BufferNode* buffer = this->buffer;
  // Step 3.1. Collect all `Var`s that appear in the buffer region
899 std::unordered_set<const VarNode*> region_vars;
900 for (const MultiIndex& multi_index : this->multi_indices) {
901 for (const PrimExpr& index : multi_index) {
902 PostOrderVisit(index, [&region_vars](const ObjectRef& obj) -> void {
903 if (const auto* var = obj.as<VarNode>()) {
904 region_vars.insert(var);
905 }
906 });
907 }
908 }
909 // Default case: no reuse
910 ReuseType& reuse_type = this->reuse_type = ReuseType::kNoReuse;
911 double& reuse_dis_iter = this->reuse_dis_iter = 0;
912 double& reuse_dis_bytes = this->reuse_dis_bytes = 0;
913 int64_t& reuse_ct = this->reuse_ct = 0;
914
915 // Step 3.2. Enumerate loops from inner to outer, find the first loop with reuse
916 int n_loops = loop_nest.loops.size();
917 const std::vector<const ForNode*>& loops = loop_nest.loops;
918 for (int i = n_loops - 1; i >= 0; --i) {
919 const ForNode* loop = loops[i];
920 // Case 1. Find an invariant loop, i.e. reuse with kLoopMultipleRead
921 if (!region_vars.count(loop->loop_var.get())) {
922 reuse_type = ReuseType::kLoopMultipleRead;
923 if (const int64_t* extent = GetLoopIntExtent(loop)) {
924 reuse_ct = *extent;
925 } else {
926 reuse_ct = 1;
927 }
928 reuse_dis_iter = 1;
929 for (int j = n_loops - 1; j > i; --j) {
930 if (const int64_t* extent = GetLoopIntExtent(loops[j])) {
931 reuse_dis_iter *= *extent;
932 }
933 }
934 reuse_dis_bytes = 0.0;
935 if (i == n_loops - 1) {
936 reuse_dis_bytes = top_loop_touch_bytes;
937 } else {
938 for (const auto& iter : buffer_touched_under_loop.at(loops[i + 1])) {
939 const BufferNode* buffer = iter.first;
940 const IntVec& numels = iter.second;
941 int64_t numel = std::accumulate(numels.begin(), numels.end(), int64_t(0));
942 reuse_dis_bytes += numel * buffer->dtype.bytes();
943 }
944 }
945 break;
946 }
947 // Case 2. Find serial reuse, i.e. reuse with kSerialMultipleReadWrite
948 const IntVec& touched = buffer_touched_under_loop.at(loop).at(buffer);
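    // More than one region recorded for this (loop, buffer) pair typically means several store
    // statements under the same loop touch the buffer, which is treated as serial read/write reuse.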
949 if (touched.size() >= 2) {
950 int64_t extent = 1;
951 if (const int64_t* ext = GetLoopIntExtent(loop)) {
952 extent = *ext;
953 }
954 reuse_type = ReuseType::kSerialMultipleReadWrite;
955 reuse_ct = touched.size() - 1;
956 reuse_dis_iter = *std::min_element(touched.begin(), touched.end());
957 reuse_dis_bytes = 0.0;
958 for (const auto& iter : buffer_touched_under_loop.at(loop)) {
959 const BufferNode* buffer = iter.first;
960 const IntVec& numels = iter.second;
961 int64_t numel = std::accumulate(numels.begin(), numels.end(), int64_t(0));
962 reuse_dis_bytes += numel * buffer->dtype.bytes();
963 }
964 reuse_dis_iter /= extent;
965 reuse_dis_bytes /= extent;
966 break;
967 }
968 }
969}
970
971void Feature::SubFeature::SetFeature(const LoopNest& loop_nest, int64_t cache_line_bytes) {
972 int64_t dtype_bytes = this->buffer->dtype.bytes();
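  // `bytes` counts one element access per iteration of the whole loop nest, whereas
  // `unique_bytes` counts only the distinct elements touched under the outermost loop.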
973 this->stride = this->innermost_stride;
974 this->bytes = dtype_bytes * loop_nest.prod;
975 if (loop_nest.loops.empty()) {
976 this->unique_bytes = 1;
977 this->lines = 1;
978 this->unique_lines = 1;
979 } else {
980 this->unique_bytes =
981 static_cast<double>(this->loop_accessed_numel.front().at(buffer)) * dtype_bytes;
982 this->lines = static_cast<double>(loop_nest.prod) / this->prod_non_strided_loop_extent *
983 std::min(1.0, 1.0 * this->min_stride * dtype_bytes / cache_line_bytes);
984 this->lines = std::max(1.0, this->lines);
985 this->unique_lines = static_cast<double>(this->unique_bytes) /
986 std::min(cache_line_bytes, this->num_continuous_bytes);
987 this->unique_lines = std::max(1.0, this->unique_lines);
988 }
989 double proxy_reuse_ct = this->reuse_ct > 0 ? this->reuse_ct : 0.5;
990 this->bytes_d_reuse_ct = this->bytes / proxy_reuse_ct;
991 this->unique_bytes_d_reuse_ct = this->unique_bytes / proxy_reuse_ct;
992 this->lines_d_reuse_ct = this->lines / proxy_reuse_ct;
993 this->unique_lines_d_reuse_ct = this->unique_lines / proxy_reuse_ct;
994}
995
996Feature::Feature(const BufferStoreNode* store, const LoopNest& loop_nest, int64_t cache_line_bytes,
997 IntVec* for_touched_bytes, ForBufferMap<IntVec>* buffer_touched_under_loop,
998 arith::Analyzer* analyzer) {
999 int n_loops = loop_nest.loops.size();
1000 // Step 0. Initialize data structures
1001 this->Init(store, n_loops);
1002 // Step 1. Calculate region-related feature
1003 this->SetRegion(loop_nest, for_touched_bytes, buffer_touched_under_loop, analyzer);
1004 // Step 2. Calculate stride-related feature
1005 for (auto& feature : sub_features) {
1006 feature.SetStride(loop_nest, analyzer);
1007 }
1008 // Step 3. Calculate reuse-related feature
  int64_t top_loop_touch_bytes = 0;
1010 if (n_loops > 0) {
1011 for (const SubFeature& feature : sub_features) {
1012 int64_t bytes = feature.buffer->dtype.bytes();
1013 int64_t n_buffer = feature.loop_accessed_numel[0].size();
1014 top_loop_touch_bytes += bytes * n_buffer;
1015 }
1016 }
1017 for (auto& feature : sub_features) {
1018 feature.SetReuse(loop_nest, top_loop_touch_bytes, *buffer_touched_under_loop);
1019 }
1020 // Step 4. Calculate rest of the features
1021 for (auto& feature : sub_features) {
1022 feature.SetFeature(loop_nest, cache_line_bytes);
1023 }
1024 // Step 5. Sort the features
1025 std::sort(sub_features.begin(), sub_features.end(), [](const SubFeature& a, const SubFeature& b) {
1026 if (a.lines != b.lines) {
1027 return a.lines > b.lines;
1028 }
1029 if (a.bytes != b.bytes) {
1030 return a.bytes > b.bytes;
1031 }
1032 return a.buffer->name < b.buffer->name;
1033 });
1034}
1035
1036} // namespace group2
1037
1038namespace group3 {
1039
1040/*! \brief Group 3 feature */
1041struct Feature {
1042 /*!
1043 * \brief See the wiki page [1] for details
1044 *
1045 * Arithmetic intensity is FLOPs/unique bytes of memory touched. A value is computed
1046 * for each set of loop nests starting with just the innermost loop and
1047 * reaching to include all loops. There are a variable number of loops, so
1048 * n_samples are taken from the curve of arithmetic intensity vs flops. This
1049 * biases the values towards larger loops.
1050 *
1051 * Note that the denominator is unique bytes of memory touched. Repeated
1052 * access to the same byte of memory counts as only a single byte touched.
1053 *
1054 * Values are scaled by log2(x + 1).
1055 *
1056 * [1] https://en.wikipedia.org/wiki/Roofline_model
1057 */
1058 std::vector<double> arith_intensity_curve;
1059
1060 void Export(std::vector<double>* v) const {
1061 v->insert(v->end(), arith_intensity_curve.begin(), arith_intensity_curve.end());
1062 }
1063
1064 explicit Feature(int n_samples, const LoopNest& loop_nest, const IntVec& for_touched_bytes,
1065 const group1::Feature::ArithOps& arith_ops)
1066 : arith_intensity_curve(n_samples, 0.0) {
1067 const std::vector<const ForNode*>& loops = loop_nest.loops;
1068 ICHECK_EQ(loops.size(), for_touched_bytes.size());
1069 int n_loops = loops.size();
1070 // Calculate `memory_bytes`
1071 std::vector<double> memory_bytes;
1072 memory_bytes.resize(n_loops);
1073 for (int i = 0; i < n_loops; ++i) {
1074 memory_bytes[n_loops - 1 - i] = for_touched_bytes[i];
1075 }
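    // After the reversal above, memory_bytes[d] is the number of bytes touched by the innermost
    // (d + 1) loops, matching the order in which compute_ops is accumulated below.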
1076 // Calculate `compute_ops` and `cur_compute_ops`
1077 std::vector<double> compute_ops;
1078 double total_compute_ops = arith_ops.float_mad + arith_ops.float_add_sub + arith_ops.float_mul +
1079 arith_ops.float_div_mod + arith_ops.float_cmp +
1080 arith_ops.float_math_func + arith_ops.float_other_func;
1081 total_compute_ops /= loop_nest.prod;
1082 for (int i = n_loops - 1; i >= 0; --i) {
1083 if (const int64_t* extent = GetLoopIntExtent(loops[i])) {
1084 total_compute_ops *= *extent;
1085 }
1086 compute_ops.push_back(total_compute_ops);
1087 }
1088 // Fill the feature set
1089 if (total_compute_ops <= 0 || compute_ops.empty()) {
1090 for (int i = 0; i < n_samples; ++i) {
1091 arith_intensity_curve[i] = 0.0;
1092 }
1093 return;
1094 }
1095 total_compute_ops = compute_ops.back();
1096 int p = 0;
1097 for (int i = 0; i < n_samples; ++i) {
1098 double& result = arith_intensity_curve[i];
1099 double cur_compute_ops = static_cast<double>(i + 1) / n_samples * total_compute_ops;
      // Find the first `p` such that `compute_ops[p] >= total_compute_ops * (i + 1) / n_samples`
1101 for (; p < n_loops; ++p) {
1102 if (compute_ops[p] >= cur_compute_ops - 1e-4) {
1103 break;
1104 }
1105 }
1106 CHECK_LT(p, n_loops);
1107 if (p == 0) {
1108 result = slog(compute_ops[p] / memory_bytes[p]);
1109 } else {
1110 double base = compute_ops[p - 1] / memory_bytes[p - 1];
1111 double slope =
1112 (compute_ops[p] / memory_bytes[p] - compute_ops[p - 1] / memory_bytes[p - 1]) /
1113 (compute_ops[p] - compute_ops[p - 1]);
1114 result = slog(base + slope * (cur_compute_ops - compute_ops[p - 1]));
1115 }
1116 }
1117 }
1118};
1119
1120} // namespace group3
1121
1122namespace group4 {
1123
1124/*! \brief Group 4 feature */
1125struct Feature {
1126 int64_t alloc_size = 0; // The size of allocated buffer in bytes
1127 int64_t alloc_prod = 0; // alloc_outer_prod * alloc_inner_prod
1128 int64_t alloc_outer_prod = 1; // The product of lengths of loops outside the scope of the alloc
1129
1130 static constexpr int64_t kCount = 4;
1131
1132 void Export(std::vector<double>* v, int64_t outer_prod) const {
1133 double vs[] = {
1134 slog(alloc_size),
1135 slog(alloc_prod),
1136 slog(alloc_outer_prod),
1137 slog(static_cast<double>(outer_prod) / alloc_outer_prod),
1138 };
1139 v->insert(v->end(), std::begin(vs), std::end(vs));
1140 }
1141
1142 Feature() = default;
1143
1144 explicit Feature(const LoopNest& loop_nest, const Buffer& buffer, arith::Analyzer* analyzer) {
1145 std::vector<int64_t> shape = utils::GetBufferShape(buffer, analyzer);
1146 int64_t numel = 1;
1147 for (int64_t x : shape) {
1148 numel *= x;
1149 }
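    // alloc_size is the allocation footprint in bytes; alloc_prod additionally scales the element
    // count by the product of the loop extents outside the allocation site.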
1150 alloc_size = numel * buffer->dtype.bytes();
1151 alloc_prod = numel * loop_nest.prod;
1152 alloc_outer_prod = loop_nest.prod;
1153 }
1154};
1155
1156} // namespace group4
1157
1158namespace group5 {
1159
1160/*! \brief Group 5 feature */
1161struct Feature {
1162 int64_t outer_prod; // The product of lengths of outer loops
1163 int num_loops; // The number of outer loops
1164 int auto_unroll_max_step; // The value of pragma "auto_unroll_max_step"
1165
1166 static constexpr int64_t kCount = 3;
1167
1168 void Export(std::vector<double>* v) const {
1169 double vs[] = {
1170 slog(outer_prod),
1171 slog(num_loops),
1172 slog(auto_unroll_max_step),
1173 };
1174 v->insert(v->end(), std::begin(vs), std::end(vs));
1175 }
1176
1177 explicit Feature(const LoopNest& loop_nest) {
1178 this->outer_prod = loop_nest.prod;
1179 this->num_loops = loop_nest.loops.size();
1180 this->auto_unroll_max_step = loop_nest.auto_unroll.empty() ? 0 : loop_nest.auto_unroll.back();
1181 }
1182};
1183
1184} // namespace group5
1185
1186namespace group6 {
1187
1188/*! \brief The auxiliary feature extractor for workloads */
1189class WorkloadEmbeddingExtractor : private StmtVisitor {
1190 public:
1191 static std::vector<double> Extract(const IRModule& mod) {
1192 WorkloadEmbeddingExtractor self;
1193 for (const auto& kv : mod->functions) {
1194 if (const PrimFuncNode* func = kv.second.as<PrimFuncNode>()) {
1195 self(func->body);
1196 }
1197 }
1198 return self.embedding;
1199 }
1200
1201 private:
1202 void VisitStmt_(const BlockNode* block) final {
1203 StmtVisitor::VisitStmt_(block);
1204 std::string name = block->name_hint;
1205 std::for_each(name.begin(), name.end(), [](char& c) { c = ::tolower(c); });
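    // Each slot of the 8-element embedding flags the presence of one operator family in the block
    // names: softmax, max/min, add, batch_matmul, matmul, depthwise conv2d, winograd conv2d, and
    // generic conv2d, in that order.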
1206 if (name.find("softmax") != std::string::npos) {
1207 embedding[0] = 1.0;
1208 } else if ((name.find("max") != std::string::npos) || (name.find("min") != std::string::npos)) {
1209 embedding[1] = 1.0;
1210 } else if (name.find("add") != std::string::npos) {
1211 embedding[2] = 1.0;
1212 } else if (name.find("batch_matmul") != std::string::npos) {
1213 embedding[3] = 1.0;
1214 } else if (name.find("matmul") != std::string::npos) {
1215 embedding[4] = 1.0;
1216 } else if (name.find("depthwiseconv2d") != std::string::npos) {
1217 embedding[5] = 1.0;
1218 } else if (name.find("conv2d_winograd") != std::string::npos) {
1219 embedding[6] = 1.0;
1220 } else if (name.find("conv2d") != std::string::npos) {
1221 embedding[7] = 1.0;
1222 }
1223 }
1224
1225 std::vector<double> embedding = std::vector<double>(8, 0.0);
1226};
1227
1228/*! \brief Group 6 feature */
1229struct Feature {
1230 explicit Feature(const IRModule& mod) {
1231 this->feature = WorkloadEmbeddingExtractor::Extract(mod);
1232 }
1233
1234 void Export(std::vector<double>* v) const {
1235 v->insert(v->end(), std::begin(feature), std::end(feature));
1236 }
1237
1238 std::vector<double> feature; // The workload embedding
1239 static constexpr int64_t kCount = 8;
1240};
1241
1242} // namespace group6
1243
1244/*! \brief The feature extracted */
1245struct Feature {
1246 const BufferNode* buffer = nullptr;
1247 int buffer_order = -1;
1248 std::unique_ptr<group1::Feature> group1 = nullptr;
1249 std::unique_ptr<group2::Feature> group2 = nullptr;
1250 std::unique_ptr<group3::Feature> group3 = nullptr;
1251 std::unique_ptr<group4::Feature> group4 = nullptr;
1252 std::unique_ptr<group5::Feature> group5 = nullptr;
1253 std::shared_ptr<group6::Feature> group6 = nullptr;
1254
1255 bool operator<(const Feature& other) const { return buffer_order < other.buffer_order; }
1256};
1257
1258/*! \brief The main feature extractor */
1259class PerStoreFeatureCollector : private StmtVisitor {
1260 public:
1261 static std::vector<Feature> Collect(bool is_gpu, int64_t cache_line_bytes,
1262 int64_t arith_intensity_curve_num_samples,
1263 const IRModule& mod) {
1264 PerStoreFeatureCollector collector(is_gpu, cache_line_bytes, arith_intensity_curve_num_samples);
1265 for (const auto& kv : mod->functions) {
1266 if (const PrimFuncNode* func = kv.second.as<PrimFuncNode>()) {
1267 collector(func->body);
1268 for (const auto& it : func->buffer_map) {
1269 collector.HandleBufferAlloc(it.second);
1270 }
1271 }
1272 }
1273 std::vector<Feature> result;
1274 result.reserve(collector.buffer_features_.size());
1275 for (auto& it : collector.buffer_features_) {
1276 Feature& feature = it.second;
1277 if (feature.buffer != nullptr) {
1278 ICHECK(feature.group1);
1279 ICHECK(feature.group2);
1280 ICHECK(feature.group3);
1281 ICHECK(feature.group5);
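        // A buffer that is stored to but never seen at an allocation site still gets a default
        // group4 feature, keeping the exported feature vector length uniform.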
1282 if (feature.group4 == nullptr) {
1283 feature.group4 = std::make_unique<group4::Feature>();
1284 }
1285 result.push_back(std::move(feature));
1286 }
1287 }
1288 std::sort(result.begin(), result.end());
1289 return result;
1290 }
1291
1292 private:
1293 void VisitStmt_(const ForNode* loop) final {
1294 int64_t auto_unroll;
1295 ForVec* for_vec = loop_nest_.Push(loop, &auto_unroll);
1296 StmtVisitor::VisitStmt_(loop);
1297 loop_nest_.Pop(loop, for_vec, auto_unroll);
1298 }
1299
1300 void VisitStmt_(const BufferStoreNode* store) final {
1301 if (store->value->IsInstance<IntImmNode>() || store->value->IsInstance<FloatImmNode>()) {
1302 return;
1303 }
1304 const BufferNode* buffer = store->buffer.get();
1305 Feature& feature = buffer_features_[buffer];
1306 if (feature.buffer == nullptr) {
1307 feature.buffer = buffer;
1308 feature.buffer_order = buffer_features_.size();
1309 }
1310 feature.group1 = std::make_unique<group1::Feature>(store, loop_nest_, is_gpu_);
1311 feature.group2 =
1312 std::make_unique<group2::Feature>(store, loop_nest_, cache_line_bytes_, &for_touched_bytes_,
1313 &buffer_touched_under_loop_, &analyzer_);
1314 feature.group3 =
1315 std::make_unique<group3::Feature>(arith_intensity_curve_num_samples_, loop_nest_,
1316 for_touched_bytes_, feature.group1->arith_ops);
1317 feature.group5 = std::make_unique<group5::Feature>(loop_nest_);
1318 }
1319
1320 void VisitStmt_(const BlockNode* block) final {
1321 StmtVisitor::VisitStmt_(block);
1322 for (const Buffer& buffer : block->alloc_buffers) {
1323 HandleBufferAlloc(buffer);
1324 }
1325 }
1326
1327 void HandleBufferAlloc(const Buffer& buffer) {
1328 Feature& feature = buffer_features_[buffer.get()];
1329 feature.group4 = std::make_unique<group4::Feature>(loop_nest_, buffer, &analyzer_);
1330 }
1331
1332 explicit PerStoreFeatureCollector(bool is_gpu, int64_t cache_line_bytes,
1333 int64_t arith_intensity_curve_num_samples)
1334 : is_gpu_(is_gpu),
1335 cache_line_bytes_(cache_line_bytes),
1336 arith_intensity_curve_num_samples_(arith_intensity_curve_num_samples) {}
1337
1338 bool is_gpu_;
1339 int64_t cache_line_bytes_;
1340 int64_t arith_intensity_curve_num_samples_;
1341 arith::Analyzer analyzer_;
1342 LoopNest loop_nest_ = {};
1343 IntVec for_touched_bytes_ = {};
1344 ForBufferMap<IntVec> buffer_touched_under_loop_ = {};
1345 std::unordered_map<const BufferNode*, Feature> buffer_features_ = {};
1346};
1347
1348} // namespace tir
1349} // namespace tvm
1350
1351namespace tvm {
1352namespace meta_schedule {
1353
1354class PerStoreFeatureNode : public FeatureExtractorNode {
1355 public:
1356 int buffers_per_store;
1357 int arith_intensity_curve_num_samples;
1358 int cache_line_bytes;
1359 bool extract_workload;
1360 int feature_vector_length;
1361
1362 void VisitAttrs(tvm::AttrVisitor* v) {
1363 v->Visit("buffers_per_store", &buffers_per_store);
1364 v->Visit("arith_intensity_curve_num_samples", &arith_intensity_curve_num_samples);
1365 v->Visit("cache_line_bytes", &cache_line_bytes);
1366 v->Visit("feature_vector_length", &feature_vector_length);
1367 }
1368
1369 void ExtractSingle(IRModule mod, bool is_gpu, std::vector<std::vector<double>>* results) {
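    // The preprocessing pass list is built once (function-local static) and reused across calls;
    // each call lowers its own copy of the module.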
1370 static transform::Sequential passes = tir::transform::PassListForPerStoreFeature();
1371 mod = passes(std::move(mod));
1372 std::vector<tir::Feature> features = tir::PerStoreFeatureCollector::Collect(
1373 is_gpu, this->cache_line_bytes, this->arith_intensity_curve_num_samples, mod);
1374 int n_features = features.size();
1375 results->resize(n_features);
1376 for (int i = 0; i < n_features; ++i) {
1377 const tir::Feature& feature = features[i];
1378 std::vector<double>& result = (*results)[i];
1379 result.reserve(feature_vector_length);
1380 feature.group1->Export(&result);
1381 feature.group2->Export(&result, this->buffers_per_store);
1382 feature.group3->Export(&result);
1383 feature.group4->Export(&result, feature.group5->outer_prod);
1384 feature.group5->Export(&result);
1385 }
1386 }
1387
1388 Array<runtime::NDArray> ExtractFrom(const TuneContext& tune_context,
1389 const Array<MeasureCandidate>& candidates) {
1390 bool is_gpu = tune_context->target.value()->kind->name == "cuda";
1391 std::vector<runtime::NDArray> results;
1392 results.resize(candidates.size());
1393 std::unique_ptr<tir::group6::Feature> feature_group6 = nullptr;
1394 if (extract_workload) {
1395 feature_group6 = std::make_unique<tir::group6::Feature>(tune_context->mod.value());
1396 }
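    // Features are extracted for each candidate in parallel. The group6 workload embedding only
    // depends on the workload, so it is computed once and appended to every feature vector.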
1397 auto f = [this, is_gpu, &feature_group6, &candidates, &results](int, int task_id) -> void {
1398 const auto& candidate = candidates[task_id];
1399 std::vector<std::vector<double>> features;
1400 ExtractSingle(DeepCopyIRModule(candidate->sch->mod()), is_gpu, &features);
1401 if (extract_workload) {
1402 for (auto& feature : features) {
1403 feature_group6->Export(&feature);
1404 }
1405 }
1406 results[task_id] = tir::utils::AsNDArray(features);
1407 };
1408 support::parallel_for_dynamic(0, candidates.size(), tune_context->num_threads, f);
1409 return results;
1410 }
1411
1412 static constexpr const char* _type_key = "meta_schedule.PerStoreFeature";
1413 TVM_DECLARE_FINAL_OBJECT_INFO(PerStoreFeatureNode, FeatureExtractorNode);
1414};
1415
1416FeatureExtractor FeatureExtractor::PerStoreFeature(int buffers_per_store,
1417 int arith_intensity_curve_num_samples,
1418 int cache_line_bytes, bool extract_workload) {
1419 ObjectPtr<PerStoreFeatureNode> n = make_object<PerStoreFeatureNode>();
1420 n->buffers_per_store = buffers_per_store;
1421 n->arith_intensity_curve_num_samples = arith_intensity_curve_num_samples;
1422 n->cache_line_bytes = cache_line_bytes;
1423 n->extract_workload = extract_workload;
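  // The feature vector concatenates group1, group2 (padded or truncated to `buffers_per_store`
  // buffers), group3 (whose length is `arith_intensity_curve_num_samples`), group4 and group5,
  // plus group6 when workload embedding is enabled.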
1424 n->feature_vector_length = tir::group1::Feature::kCount + //
1425 tir::group2::Feature::SubFeature::kCount * buffers_per_store + //
1426 arith_intensity_curve_num_samples + //
1427 tir::group4::Feature::kCount + //
1428 tir::group5::Feature::kCount;
1429 if (extract_workload) {
1430 n->feature_vector_length += tir::group6::Feature::kCount;
1431 }
1432 return FeatureExtractor(n);
1433}
1434
1435TVM_REGISTER_NODE_TYPE(PerStoreFeatureNode);
1436TVM_REGISTER_GLOBAL("meta_schedule.FeatureExtractorPerStoreFeature")
1437 .set_body_typed(FeatureExtractor::PerStoreFeature);
1438
1439} // namespace meta_schedule
1440} // namespace tvm
1441