1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file feature_visitor.h
 * \brief Base class for feature extractors.
 * The extracted features are used by the machine learning cost model.
 */
25
26#ifndef TVM_AUTOTVM_FEATURE_VISITOR_H_
27#define TVM_AUTOTVM_FEATURE_VISITOR_H_
28
29#include <tvm/tir/expr.h>
30#include <tvm/tir/stmt.h>
31#include <tvm/tir/stmt_functor.h>
32
33#include <string>
34
35namespace tvm {
36namespace autotvm {
37
38using namespace tvm::tir;
39
/*!
 * \brief Kind of annotation attached to a for loop.
 *
 * Each enumerator's integer value is the slot it occupies in the one-hot
 * encoding of loop types inside the extracted features, so the order (and
 * therefore the values) must stay stable.
 */
enum AnnotationType {
  // GPU block-index bindings.
  kBlockX = 0,
  kBlockY = 1,
  kBlockZ = 2,
  // GPU thread-index bindings.
  kThreadX = 3,
  kThreadY = 4,
  kThreadZ = 5,
  // Loop-level annotations.
  kUnrolled = 6,
  kVectorized = 7,
  kParallel = 8,
  kSerial = 9,
  kVirtualThread = 10,
  // Number of kinds listed above; must remain the last enumerator.
  kNum = 11,
};
57
58/*!
59 * \brief A base class for feature extractor, used for processing
60 * for loop and memory access in the IR
61 */
62class FeatureVisitor : public StmtExprVisitor {
63 public:
64 // for loop
65 void VisitStmt_(const ForNode* op) final;
66 void VisitStmt_(const AttrStmtNode* op) final;
67
68 // memory access
69 void VisitExpr_(const BufferLoadNode* op) final;
70 void VisitStmt_(const BufferStoreNode* op) final;
71
72 using StmtExprVisitor::VisitExpr_;
73 using StmtExprVisitor::VisitStmt_;
74
75 protected:
76 /*!
77 * \brief Enter a for loop node
78 * \param var The expression to be printed.
79 * \param length The output stream
80 * \param ann_type The type for the for loop
81 * \return skip Whether skip this node
82 */
83 virtual bool EnterItervar_(tir::Var var, int64_t length, AnnotationType ann_type) = 0;
84 /*! \brief Exit a for loop subtree */
85 virtual void ExitItervar_() = 0;
86 /*!
87 * \brief Enter a memory access node
88 * \param buffer_var The buffer to access.
89 * \param index Index expression
90 */
91 virtual void EnterMem_(tir::Var buffer_var, tvm::PrimExpr index) = 0;
92 /*! \brief Exit a memory access node */
93 virtual void ExitMem_() = 0;
94};
95
96} // namespace autotvm
97} // namespace tvm
98
99#endif // TVM_AUTOTVM_FEATURE_VISITOR_H_
100