1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file data_type_rewriter.h |
22 | * \brief Rewrite the data type of expressions. |
23 | */ |
24 | #ifndef TVM_TIR_DATA_TYPE_REWRITER_H_ |
25 | #define TVM_TIR_DATA_TYPE_REWRITER_H_ |
26 | |
27 | #include <tvm/tir/stmt_functor.h> |
28 | |
29 | #include <unordered_map> |
30 | |
31 | namespace tvm { |
32 | namespace tir { |
33 | |
34 | /*! |
35 | * \brief Legalize the data types of expressions to make sure they are consistent with other |
36 | * parts of the program. |
37 | * |
38 | * It enforces the following rules: |
39 | * - The data type of the index variable in a loop must be consistent with the data type of the loop |
40 | * bounds. |
41 | * - The data type of the binary and ternary expressions must be consistent with the data types of |
42 | * each of their operands. |
43 | * - The data type of the bounds and binding values of block iter vars must be consistent with the |
44 | * data type of the block iter vars. |
45 | * |
46 | * Usually we enforce the consistency of data types when constructing the IR nodes. However, such |
47 | * inconsistency may happen as a result of IR mutation in some passes. This class can be used as |
48 | * base class of such passes to ensure the consistency of data types. |
49 | */ |
50 | class DataTypeLegalizer : public StmtExprMutator { |
51 | protected: |
52 | Stmt VisitStmt_(const ForNode* op) override; |
53 | Stmt VisitStmt_(const AttrStmtNode* op) override; |
54 | Stmt VisitStmt_(const BlockRealizeNode* op) override; |
55 | Stmt VisitStmt_(const BlockNode* op) override; |
56 | Stmt VisitStmt_(const LetStmtNode* op) override; |
57 | PrimExpr VisitExpr_(const VarNode* op) override; |
58 | PrimExpr VisitExpr_(const SelectNode* op) override; |
59 | PrimExpr VisitExpr_(const RampNode* op) override; |
60 | PrimExpr VisitExpr_(const AddNode* op) override; |
61 | PrimExpr VisitExpr_(const SubNode* op) override; |
62 | PrimExpr VisitExpr_(const MulNode* op) override; |
63 | PrimExpr VisitExpr_(const DivNode* op) override; |
64 | PrimExpr VisitExpr_(const ModNode* op) override; |
65 | PrimExpr VisitExpr_(const FloorDivNode* op) override; |
66 | PrimExpr VisitExpr_(const FloorModNode* op) override; |
67 | PrimExpr VisitExpr_(const MinNode* op) override; |
68 | PrimExpr VisitExpr_(const MaxNode* op) override; |
69 | PrimExpr VisitExpr_(const EQNode* op) override; |
70 | PrimExpr VisitExpr_(const NENode* op) override; |
71 | PrimExpr VisitExpr_(const LTNode* op) override; |
72 | PrimExpr VisitExpr_(const LENode* op) override; |
73 | PrimExpr VisitExpr_(const GTNode* op) override; |
74 | PrimExpr VisitExpr_(const GENode* op) override; |
75 | PrimExpr VisitExpr_(const CallNode* op) override; |
76 | PrimExpr VisitExpr_(const CastNode* op) override; |
77 | |
78 | using StmtExprMutator::VisitExpr_; |
79 | using StmtExprMutator::VisitStmt_; |
80 | |
81 | // a map from IterVar before rewrite to that after rewrite, |
82 | // ensures one old IterVar maps to exactly one new IterVar |
83 | std::unordered_map<const IterVarNode*, IterVar> ivmap_; |
84 | // a map from original vars to ones with new dtype |
85 | std::unordered_map<const VarNode*, Var> var_remap_; |
86 | }; |
87 | |
88 | /*! |
89 | * \brief Data type rewriter for buffer indices. |
90 | * |
91 | * Detect the components of buffer indices that should be considered for data type rewriting. |
92 | * This class doesn't perform actual rewriting of data types. During recursive visiting, the |
93 | * internal flags `is_enabled_` and `is_conditional_` are used to indicate whether the current |
94 | * expression is a buffer index or a conditional expression, which can be used in the sub-classes to |
95 | * implement different rewriting rules. |
96 | */ |
97 | class IndexDataTypeRewriter : public DataTypeLegalizer { |
98 | protected: |
99 | using Parent = DataTypeLegalizer; |
100 | using Parent::VisitExpr_; |
101 | using Parent::VisitStmt_; |
102 | |
103 | Stmt VisitStmt_(const BlockRealizeNode* op) override; |
104 | Stmt VisitStmt_(const BlockNode* op) override; |
105 | Stmt VisitStmt_(const BufferStoreNode* op) override; |
106 | PrimExpr VisitExpr_(const BufferLoadNode* op) override; |
107 | Array<PrimExpr> VisitIndices(Array<PrimExpr> indices); |
108 | Stmt VisitStmt_(const IfThenElseNode* op) override; |
109 | Stmt VisitStmt_(const DeclBufferNode* op) override; |
110 | Stmt VisitStmt_(const AllocateNode* op) override; |
111 | PrimExpr VisitExpr_(const EQNode* op) override; |
112 | PrimExpr VisitExpr_(const NENode* op) override; |
113 | PrimExpr VisitExpr_(const LTNode* op) override; |
114 | PrimExpr VisitExpr_(const LENode* op) override; |
115 | PrimExpr VisitExpr_(const GTNode* op) override; |
116 | PrimExpr VisitExpr_(const GENode* op) override; |
117 | PrimExpr VisitExpr_(const CallNode* op) override; |
118 | Stmt VisitStmt_(const ForNode* op) override; |
119 | |
120 | Buffer VisitBuffer(const Buffer& buffer); |
121 | Buffer GetRemappedBuffer(const Buffer& buffer); |
122 | Map<String, ObjectRef> VisitBlockAnnotations(const Map<String, ObjectRef>& annotations); |
123 | BufferRegion VisitBufferRegion(const BufferRegion& region); |
124 | IterVar VisitIterVar(const IterVar& iter_var); |
125 | // indicator of index expr to rewrite |
126 | bool is_enabled_{false}; |
127 | // indicator of condition |
128 | bool is_condition_{false}; |
129 | |
130 | Map<Buffer, Buffer> buffer_remap_; |
131 | }; |
132 | |
133 | /*! |
134 | * \brief Normalize the data types of buffer shapes and indices to the same data type. |
135 | * |
136 | * This pass rewrites the data types of buffer shapes and indices to the specified data type. It |
137 | * assumes the specified data type is large enough to hold the original ranges of buffer shapes and |
138 | * indices. |
139 | */ |
140 | class IndexDataTypeNormalizer : public IndexDataTypeRewriter { |
141 | public: |
142 | explicit IndexDataTypeNormalizer(DataType target_data_type); |
143 | PrimFunc Rewrite(PrimFunc func); |
144 | |
145 | protected: |
146 | using Parent = IndexDataTypeRewriter; |
147 | using Parent::VisitExpr_; |
148 | using Parent::VisitStmt_; |
149 | PrimExpr VisitExpr_(const IntImmNode* op) final; |
150 | PrimExpr VisitExpr_(const VarNode* op) final; |
151 | PrimExpr VisitExpr_(const CastNode* op) final; |
152 | |
153 | DataType target_data_type_ = DataType::Int(64); |
154 | }; |
155 | |
156 | } // namespace tir |
157 | } // namespace tvm |
158 | |
159 | #endif // TVM_TIR_DATA_TYPE_REWRITER_H_ |
160 | |