1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file touch_extractor.h
22 * \brief Extract feature of touch pattern of axes in lowered IR
23 */
24
25#ifndef TVM_AUTOTVM_TOUCH_EXTRACTOR_H_
26#define TVM_AUTOTVM_TOUCH_EXTRACTOR_H_
27
28#include <tvm/runtime/registry.h>
29#include <tvm/tir/expr.h>
30#include <tvm/tir/expr_functor.h>
31
32#include <deque>
33#include <map>
34#include <stack>
35#include <string>
36#include <unordered_map>
37#include <vector>
38
39#include "feature_visitor.h"
40
41namespace tvm {
42namespace autotvm {
43
// Name type for a touched buffer; used as the key of ItervarFeature::touch_feature.
using TouchedBuffer = std::string;
45
/*!
 * \brief Touch pattern of one axis on one buffer.
 *
 * Models an access of the form buf[((stride * var) % mod) + other].
 */
struct TouchPattern {
  /*! \brief Multiplier applied to the axis variable in the access index. */
  int64_t stride = 0;
  /*! \brief Modulus applied to (stride * var); -1 stands for +inf. */
  int64_t mod = -1;

  /*! \brief Touch count; starts at 1. */
  int64_t count = 1;
  /*! \brief Reuse ratio; starts at 1. */
  int64_t reuse = 1;
  /*! \brief Count when the thread axis is moved into the innermost position. */
  int64_t thread_count = 0;
  /*! \brief Reuse ratio when the thread axis is moved into the innermost position. */
  int64_t thread_reuse = 0;
};
56
57// all the feature of an iter var
58struct ItervarFeature {
59 ItervarFeature(Var var, int64_t extent, int nest, AnnotationType ann_type, int64_t topdown,
60 int counter)
61 : length(extent), nest_level(nest), ann(ann_type), topdown_product(topdown), order(counter) {}
62 ItervarFeature() {}
63
64 // Axis Attributes
65 int64_t length;
66 int nest_level;
67 AnnotationType ann; // one-hot axis type
68 int64_t topdown_product; // accumulative product of axis length, in top-down order
69 int64_t bottomup_product; // accumulative product of axis length, in bottom-up order
70 // bottomup_product = reuse * count for any touched buffer
71
72 int order; // used for soring axis
73
74 // Arithmetic feature
75 int add_ct{0};
76 int mul_ct{0};
77 int div_ct{0};
78
79 // Memory Touch Feature
80 std::unordered_map<TouchedBuffer, TouchPattern> touch_feature;
81};
82
83// extract iter vars and their touch pattern from ir
84class TouchExtractor : public FeatureVisitor {
85 public:
86 void Analyze(const Stmt& stmt) { operator()(stmt); }
87
88 // arithmetic stats
89 void VisitExpr_(const AddNode* op) final {
90 if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
91 itervar_map[itervar_stack_.back()].add_ct++;
92 }
93 FeatureVisitor::VisitExpr_(op);
94 }
95
96 void VisitExpr_(const SubNode* op) final {
97 if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
98 itervar_map[itervar_stack_.back()].add_ct++;
99 }
100 FeatureVisitor::VisitExpr_(op);
101 }
102
103 void VisitExpr_(const MulNode* op) final {
104 if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
105 itervar_map[itervar_stack_.back()].mul_ct++;
106 }
107 FeatureVisitor::VisitExpr_(op);
108 }
109
110 void VisitExpr_(const DivNode* op) final {
111 if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
112 itervar_map[itervar_stack_.back()].div_ct++;
113 }
114 FeatureVisitor::VisitExpr_(op);
115 }
116
117 void VisitExpr_(const ModNode* op) final {
118 if (op->dtype.is_float() || op->dtype.is_bfloat16()) {
119 itervar_map[itervar_stack_.back()].div_ct++;
120 }
121 FeatureVisitor::VisitExpr_(op);
122 }
123
124 std::unordered_map<Var, ItervarFeature, tvm::ObjectPtrHash, tvm::ObjectPtrEqual> itervar_map;
125
126 private:
127 bool EnterItervar_(Var var, int64_t length, AnnotationType ann_type);
128 void ExitItervar_();
129 void EnterMem_(Var buffer_var, PrimExpr index);
130 void ExitMem_();
131
132 int64_t topdown_product_{1};
133 std::map<std::string, size_t> buffer_counter_;
134 size_t itervar_counter_{0};
135 std::deque<Var> itervar_stack_; // use deque instead of stack for indexing
136 std::deque<size_t> skip_stack_size_;
137
138 using FeatureVisitor::VisitExpr_;
139};
140
141} // namespace autotvm
142} // namespace tvm
143
144#endif // TVM_AUTOTVM_TOUCH_EXTRACTOR_H_
145