1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file touch_extractor.h |
 * \brief Extract features of the touch patterns of axes (iteration variables) in lowered IR
23 | */ |
24 | |
25 | #ifndef TVM_AUTOTVM_TOUCH_EXTRACTOR_H_ |
26 | #define |
27 | |
28 | #include <tvm/runtime/registry.h> |
29 | #include <tvm/tir/expr.h> |
30 | #include <tvm/tir/expr_functor.h> |
31 | |
32 | #include <deque> |
33 | #include <map> |
34 | #include <stack> |
35 | #include <string> |
36 | #include <unordered_map> |
37 | #include <vector> |
38 | |
39 | #include "feature_visitor.h" |
40 | |
41 | namespace tvm { |
42 | namespace autotvm { |
43 | |
// Key identifying a touched buffer (used as the key of ItervarFeature::touch_feature).
using TouchedBuffer = std::string;

// Touch pattern of one buffer with respect to one iter var, modeling an access
// of the form buf[((stride * var) % mod) + other].
struct TouchPattern {
  int64_t stride = 0;
  int64_t mod = -1;  // -1 stands for +inf (no wrap-around)

  int64_t count = 1;
  int64_t reuse = 1;
  int64_t thread_count = 0;  // count when the thread axis is moved to the innermost position
  int64_t thread_reuse = 0;  // reuse ratio when the thread axis is moved to the innermost position
};
56 | |
57 | // all the feature of an iter var |
58 | struct ItervarFeature { |
59 | ItervarFeature(Var var, int64_t extent, int nest, AnnotationType ann_type, int64_t topdown, |
60 | int counter) |
61 | : length(extent), nest_level(nest), ann(ann_type), topdown_product(topdown), order(counter) {} |
62 | ItervarFeature() {} |
63 | |
64 | // Axis Attributes |
65 | int64_t length; |
66 | int nest_level; |
67 | AnnotationType ann; // one-hot axis type |
68 | int64_t topdown_product; // accumulative product of axis length, in top-down order |
69 | int64_t bottomup_product; // accumulative product of axis length, in bottom-up order |
70 | // bottomup_product = reuse * count for any touched buffer |
71 | |
72 | int order; // used for soring axis |
73 | |
74 | // Arithmetic feature |
75 | int add_ct{0}; |
76 | int mul_ct{0}; |
77 | int div_ct{0}; |
78 | |
79 | // Memory Touch Feature |
80 | std::unordered_map<TouchedBuffer, TouchPattern> touch_feature; |
81 | }; |
82 | |
83 | // extract iter vars and their touch pattern from ir |
84 | class : public FeatureVisitor { |
85 | public: |
86 | void (const Stmt& stmt) { operator()(stmt); } |
87 | |
88 | // arithmetic stats |
89 | void (const AddNode* op) final { |
90 | if (op->dtype.is_float() || op->dtype.is_bfloat16()) { |
91 | itervar_map[itervar_stack_.back()].add_ct++; |
92 | } |
93 | FeatureVisitor::VisitExpr_(op); |
94 | } |
95 | |
96 | void (const SubNode* op) final { |
97 | if (op->dtype.is_float() || op->dtype.is_bfloat16()) { |
98 | itervar_map[itervar_stack_.back()].add_ct++; |
99 | } |
100 | FeatureVisitor::VisitExpr_(op); |
101 | } |
102 | |
103 | void (const MulNode* op) final { |
104 | if (op->dtype.is_float() || op->dtype.is_bfloat16()) { |
105 | itervar_map[itervar_stack_.back()].mul_ct++; |
106 | } |
107 | FeatureVisitor::VisitExpr_(op); |
108 | } |
109 | |
110 | void (const DivNode* op) final { |
111 | if (op->dtype.is_float() || op->dtype.is_bfloat16()) { |
112 | itervar_map[itervar_stack_.back()].div_ct++; |
113 | } |
114 | FeatureVisitor::VisitExpr_(op); |
115 | } |
116 | |
117 | void (const ModNode* op) final { |
118 | if (op->dtype.is_float() || op->dtype.is_bfloat16()) { |
119 | itervar_map[itervar_stack_.back()].div_ct++; |
120 | } |
121 | FeatureVisitor::VisitExpr_(op); |
122 | } |
123 | |
124 | std::unordered_map<Var, ItervarFeature, tvm::ObjectPtrHash, tvm::ObjectPtrEqual> ; |
125 | |
126 | private: |
127 | bool (Var var, int64_t length, AnnotationType ann_type); |
128 | void (); |
129 | void (Var buffer_var, PrimExpr index); |
130 | void (); |
131 | |
132 | int64_t {1}; |
133 | std::map<std::string, size_t> ; |
134 | size_t {0}; |
135 | std::deque<Var> ; // use deque instead of stack for indexing |
136 | std::deque<size_t> ; |
137 | |
138 | using FeatureVisitor::VisitExpr_; |
139 | }; |
140 | |
141 | } // namespace autotvm |
142 | } // namespace tvm |
143 | |
144 | #endif // TVM_AUTOTVM_TOUCH_EXTRACTOR_H_ |
145 | |