1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file feature_visitor.h |
 * \brief Base class for feature extractors.
 *  These features are used by the machine-learning cost model.
24 | */ |
25 | |
26 | #ifndef TVM_AUTOTVM_FEATURE_VISITOR_H_ |
27 | #define TVM_AUTOTVM_FEATURE_VISITOR_H_ |
28 | |
29 | #include <tvm/tir/expr.h> |
30 | #include <tvm/tir/stmt.h> |
31 | #include <tvm/tir/stmt_functor.h> |
32 | |
33 | #include <string> |
34 | |
35 | namespace tvm { |
36 | namespace autotvm { |
37 | |
38 | using namespace tvm::tir; |
39 | |
/*!
 * \brief Kind of annotation attached to a for loop; used as a one-hot
 *  encoding in the extracted features.
 *
 *  Values are spelled out explicitly; they are identical to the implicit
 *  0..10 numbering, and kNum is the count of real kinds.
 */
enum AnnotationType {
  kBlockX = 0,         // bound to blockIdx.x
  kBlockY = 1,         // bound to blockIdx.y
  kBlockZ = 2,         // bound to blockIdx.z
  kThreadX = 3,        // bound to threadIdx.x
  kThreadY = 4,        // bound to threadIdx.y
  kThreadZ = 5,        // bound to threadIdx.z
  kUnrolled = 6,       // unrolled loop
  kVectorized = 7,     // vectorized loop
  kParallel = 8,       // parallel loop
  kSerial = 9,         // plain serial loop
  kVirtualThread = 10, // virtual thread
  kNum = 11,           // sentinel: number of kinds above (presumably one-hot width)
};
57 | |
58 | /*! |
59 | * \brief A base class for feature extractor, used for processing |
60 | * for loop and memory access in the IR |
61 | */ |
62 | class FeatureVisitor : public StmtExprVisitor { |
63 | public: |
64 | // for loop |
65 | void VisitStmt_(const ForNode* op) final; |
66 | void VisitStmt_(const AttrStmtNode* op) final; |
67 | |
68 | // memory access |
69 | void VisitExpr_(const BufferLoadNode* op) final; |
70 | void VisitStmt_(const BufferStoreNode* op) final; |
71 | |
72 | using StmtExprVisitor::VisitExpr_; |
73 | using StmtExprVisitor::VisitStmt_; |
74 | |
75 | protected: |
76 | /*! |
77 | * \brief Enter a for loop node |
78 | * \param var The expression to be printed. |
79 | * \param length The output stream |
80 | * \param ann_type The type for the for loop |
81 | * \return skip Whether skip this node |
82 | */ |
83 | virtual bool EnterItervar_(tir::Var var, int64_t length, AnnotationType ann_type) = 0; |
84 | /*! \brief Exit a for loop subtree */ |
85 | virtual void ExitItervar_() = 0; |
86 | /*! |
87 | * \brief Enter a memory access node |
88 | * \param buffer_var The buffer to access. |
89 | * \param index Index expression |
90 | */ |
91 | virtual void EnterMem_(tir::Var buffer_var, tvm::PrimExpr index) = 0; |
92 | /*! \brief Exit a memory access node */ |
93 | virtual void ExitMem_() = 0; |
94 | }; |
95 | |
96 | } // namespace autotvm |
97 | } // namespace tvm |
98 | |
99 | #endif // TVM_AUTOTVM_FEATURE_VISITOR_H_ |
100 | |