1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file data_type_rewriter.h
22 * \brief Rewrite the data type of expressions.
23 */
24#ifndef TVM_TIR_DATA_TYPE_REWRITER_H_
25#define TVM_TIR_DATA_TYPE_REWRITER_H_
26
27#include <tvm/tir/stmt_functor.h>
28
29#include <unordered_map>
30
31namespace tvm {
32namespace tir {
33
34/*!
35 * \brief Legalize the data types of expressions to make sure they are consistent with other
36 * parts of the program.
37 *
38 * It enforces the following rules:
39 * - The data type of the index variable in a loop must be consistent with the data type of the loop
40 * bounds.
41 * - The data type of the binary and ternary expressions must be consistent with the data types of
42 * each of their operands.
43 * - The data type of the bounds and binding values of block iter vars must be consistent with the
44 * data type of the block iter vars.
45 *
46 * Usually we enforce the consistency of data types when constructing the IR nodes. However, such
 * inconsistency may happen as a result of IR mutation in some passes. This class can be used as a
 * base class of such passes to ensure the consistency of data types.
49 */
50class DataTypeLegalizer : public StmtExprMutator {
51 protected:
52 Stmt VisitStmt_(const ForNode* op) override;
53 Stmt VisitStmt_(const AttrStmtNode* op) override;
54 Stmt VisitStmt_(const BlockRealizeNode* op) override;
55 Stmt VisitStmt_(const BlockNode* op) override;
56 Stmt VisitStmt_(const LetStmtNode* op) override;
57 PrimExpr VisitExpr_(const VarNode* op) override;
58 PrimExpr VisitExpr_(const SelectNode* op) override;
59 PrimExpr VisitExpr_(const RampNode* op) override;
60 PrimExpr VisitExpr_(const AddNode* op) override;
61 PrimExpr VisitExpr_(const SubNode* op) override;
62 PrimExpr VisitExpr_(const MulNode* op) override;
63 PrimExpr VisitExpr_(const DivNode* op) override;
64 PrimExpr VisitExpr_(const ModNode* op) override;
65 PrimExpr VisitExpr_(const FloorDivNode* op) override;
66 PrimExpr VisitExpr_(const FloorModNode* op) override;
67 PrimExpr VisitExpr_(const MinNode* op) override;
68 PrimExpr VisitExpr_(const MaxNode* op) override;
69 PrimExpr VisitExpr_(const EQNode* op) override;
70 PrimExpr VisitExpr_(const NENode* op) override;
71 PrimExpr VisitExpr_(const LTNode* op) override;
72 PrimExpr VisitExpr_(const LENode* op) override;
73 PrimExpr VisitExpr_(const GTNode* op) override;
74 PrimExpr VisitExpr_(const GENode* op) override;
75 PrimExpr VisitExpr_(const CallNode* op) override;
76 PrimExpr VisitExpr_(const CastNode* op) override;
77
78 using StmtExprMutator::VisitExpr_;
79 using StmtExprMutator::VisitStmt_;
80
81 // a map from IterVar before rewrite to that after rewrite,
82 // ensures one old IterVar maps to exactly one new IterVar
83 std::unordered_map<const IterVarNode*, IterVar> ivmap_;
84 // a map from original vars to ones with new dtype
85 std::unordered_map<const VarNode*, Var> var_remap_;
86};
87
88/*!
89 * \brief Data type rewriter for buffer indices.
90 *
91 * Detect the components of buffer indices that should be considered for data type rewriting.
92 * This class doesn't perform actual rewriting of data types. During recursive visiting, the
 * internal flags `is_enabled_` and `is_condition_` are used to indicate whether the current
94 * expression is a buffer index or a conditional expression, which can be used in the sub-classes to
95 * implement different rewriting rules.
96 */
97class IndexDataTypeRewriter : public DataTypeLegalizer {
98 protected:
99 using Parent = DataTypeLegalizer;
100 using Parent::VisitExpr_;
101 using Parent::VisitStmt_;
102
103 Stmt VisitStmt_(const BlockRealizeNode* op) override;
104 Stmt VisitStmt_(const BlockNode* op) override;
105 Stmt VisitStmt_(const BufferStoreNode* op) override;
106 PrimExpr VisitExpr_(const BufferLoadNode* op) override;
107 Array<PrimExpr> VisitIndices(Array<PrimExpr> indices);
108 Stmt VisitStmt_(const IfThenElseNode* op) override;
109 Stmt VisitStmt_(const DeclBufferNode* op) override;
110 Stmt VisitStmt_(const AllocateNode* op) override;
111 PrimExpr VisitExpr_(const EQNode* op) override;
112 PrimExpr VisitExpr_(const NENode* op) override;
113 PrimExpr VisitExpr_(const LTNode* op) override;
114 PrimExpr VisitExpr_(const LENode* op) override;
115 PrimExpr VisitExpr_(const GTNode* op) override;
116 PrimExpr VisitExpr_(const GENode* op) override;
117 PrimExpr VisitExpr_(const CallNode* op) override;
118 Stmt VisitStmt_(const ForNode* op) override;
119
120 Buffer VisitBuffer(const Buffer& buffer);
121 Buffer GetRemappedBuffer(const Buffer& buffer);
122 Map<String, ObjectRef> VisitBlockAnnotations(const Map<String, ObjectRef>& annotations);
123 BufferRegion VisitBufferRegion(const BufferRegion& region);
124 IterVar VisitIterVar(const IterVar& iter_var);
125 // indicator of index expr to rewrite
126 bool is_enabled_{false};
127 // indicator of condition
128 bool is_condition_{false};
129
130 Map<Buffer, Buffer> buffer_remap_;
131};
132
133/*!
134 * \brief Normalize the data types of buffer shapes and indices to the same data type.
135 *
136 * This pass rewrites the data types of buffer shapes and indices to the specified data type. It
137 * assumes the specified data type is large enough to hold the original ranges of buffer shapes and
138 * indices.
139 */
140class IndexDataTypeNormalizer : public IndexDataTypeRewriter {
141 public:
142 explicit IndexDataTypeNormalizer(DataType target_data_type);
143 PrimFunc Rewrite(PrimFunc func);
144
145 protected:
146 using Parent = IndexDataTypeRewriter;
147 using Parent::VisitExpr_;
148 using Parent::VisitStmt_;
149 PrimExpr VisitExpr_(const IntImmNode* op) final;
150 PrimExpr VisitExpr_(const VarNode* op) final;
151 PrimExpr VisitExpr_(const CastNode* op) final;
152
153 DataType target_data_type_ = DataType::Int(64);
154};
155
156} // namespace tir
157} // namespace tvm
158
159#endif // TVM_TIR_DATA_TYPE_REWRITER_H_
160