1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #ifndef TVM_TIR_SCHEDULE_IR_COMPARATOR_H_ |
20 | #define TVM_TIR_SCHEDULE_IR_COMPARATOR_H_ |
21 | |
22 | #include <string> |
23 | #include <unordered_map> |
24 | #include <utility> |
25 | #include <vector> |
26 | |
27 | #include "./utils.h" |
28 | |
29 | namespace tvm { |
30 | namespace tir { |
31 | |
32 | using ExprComparator = ExprFunctor<bool(const PrimExpr& n, const PrimExpr& other)>; |
33 | using StmtComparator = StmtFunctor<bool(const Stmt& n, const Stmt& other)>; |
34 | |
35 | /*! \brief Deep comparison to check if two IR ASTs are equivalent for tensorization*/ |
36 | class TensorizeComparator : public ExprComparator, public StmtComparator { |
37 | public: |
38 | /*! |
39 | * \brief Constructor of TensorizeComparator |
40 | * \param assert_mode Whether to raise an error if the two IR ASTs do not match. |
41 | * \param lhs_mod The IRModule of the LHS. This is used for error reporting. |
42 | */ |
43 | explicit TensorizeComparator(IRModule lhs_mod, bool assert_mode = true) |
44 | : lhs_mod_(std::move(lhs_mod)), assert_mode_(assert_mode) {} |
45 | |
46 | bool VisitExpr(const PrimExpr& n, const PrimExpr& other) override; |
47 | bool VisitStmt(const Stmt& n, const Stmt& other) override; |
48 | |
49 | bool VisitStmt_(const ForNode* op, const Stmt& other) override; |
50 | bool VisitStmt_(const SeqStmtNode* op, const Stmt& other) override; |
51 | bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) override; |
52 | bool VisitStmt_(const BlockRealizeNode* op, const Stmt& other) override; |
53 | bool VisitStmt_(const BlockNode* op, const Stmt& other) override; |
54 | |
55 | bool VisitExpr_(const AddNode* op, const PrimExpr& other) override; |
56 | bool VisitExpr_(const SubNode* op, const PrimExpr& other) override; |
57 | bool VisitExpr_(const MulNode* op, const PrimExpr& other) override; |
58 | bool VisitExpr_(const DivNode* op, const PrimExpr& other) override; |
59 | bool VisitExpr_(const ModNode* op, const PrimExpr& other) override; |
60 | bool VisitExpr_(const EQNode* op, const PrimExpr& other) override; |
61 | bool VisitExpr_(const NENode* op, const PrimExpr& other) override; |
62 | bool VisitExpr_(const LTNode* op, const PrimExpr& other) override; |
63 | bool VisitExpr_(const LENode* op, const PrimExpr& other) override; |
64 | bool VisitExpr_(const GTNode* op, const PrimExpr& other) override; |
65 | bool VisitExpr_(const GENode* op, const PrimExpr& other) override; |
66 | bool VisitExpr_(const AndNode* op, const PrimExpr& other) override; |
67 | bool VisitExpr_(const OrNode* op, const PrimExpr& other) override; |
68 | bool VisitExpr_(const MinNode* op, const PrimExpr& other) override; |
69 | bool VisitExpr_(const MaxNode* op, const PrimExpr& other) override; |
70 | bool VisitExpr_(const FloorDivNode* op, const PrimExpr& other) override; |
71 | bool VisitExpr_(const FloorModNode* op, const PrimExpr& other) override; |
72 | bool VisitExpr_(const IntImmNode* op, const PrimExpr& other) override; |
73 | bool VisitExpr_(const FloatImmNode* op, const PrimExpr& other) override; |
74 | bool VisitExpr_(const CastNode* op, const PrimExpr& other) override; |
75 | bool VisitExpr_(const VarNode* op, const PrimExpr& other) override; |
76 | bool VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) override; |
77 | bool VisitExpr_(const SelectNode* op, const PrimExpr& other) override; |
78 | |
79 | /*! \brief Map from RHS buffer to LHS buffer */ |
80 | std::unordered_map<Buffer, Buffer, ObjectHash, ObjectEqual> rhs_buffer_map_; |
81 | /*! \brief Base indices of the LHS buffer. */ |
82 | std::unordered_map<Buffer, std::vector<PrimExpr>, ObjectPtrHash, ObjectPtrEqual> buffer_indices_; |
83 | |
84 | protected: |
85 | bool DefEqual(const Var& lhs, const Var& rhs); |
86 | virtual bool CompareBuffer(const Buffer& lhs, const Buffer& rhs); |
87 | bool CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs); |
88 | bool CompareAnnotation(const std::pair<String, ObjectRef>& lhs, |
89 | const std::pair<String, ObjectRef>& rhs); |
90 | bool CompareAnnotationMap(const Map<String, ObjectRef>& lhs, const Map<String, ObjectRef>& rhs); |
91 | template <typename T> |
92 | bool CompareBufferAccess(const T* lhs, const T* rhs); |
93 | template <typename T, typename Self, typename F> |
94 | bool CompareArray(const Array<T>& lhs, const Array<T>& rhs, F Self::*cmp); |
95 | bool CompareRange(const Range& lhs, const Range& rhs); |
96 | bool CompareIterVar(const IterVar& lhs, const IterVar& rhs); |
97 | void EmitError(const std::string& error_message); |
98 | |
99 | /*! \brief IRModule of the LHS stmt. */ |
100 | IRModule lhs_mod_; |
101 | /*! \brief Whether assertion mode is enabled. */ |
102 | bool assert_mode_; |
103 | /*! \brief Whether it is visiting the scope block (the outermost block). */ |
104 | bool is_scope_block = true; |
105 | /*! \brief The arithmetic analyzer. */ |
106 | arith::Analyzer analyzer_; |
107 | /*! \brief Additional error messages. Only used when assert_mode is true. */ |
108 | std::vector<std::string> error_messages_; |
109 | // variable remap if any |
110 | std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> equal_map_; |
111 | }; |
112 | |
113 | /*! |
114 | * \brief IR comparator for auto tensorization. |
115 | * This comparator is used to extract correspondence between the IR of the workload (LHS) and the |
116 | * tensor intrin (RHS). Unlike `TensorizeComparator`, this comparator has relaxed requirements |
117 | * during comparison. It ignores the loop structure (number of loops and their extents) and buffer |
118 | * indices. It only requires the LHS and the RHS to have the same arithmetic operations and the same |
119 | * dtype. With such relaxed requirements, workloads that can only match the tensor intrin after |
120 | * certain transformations (e.g. im2col for conv2d) are allowed for auto tensorization. |
121 | */ |
122 | class AutoTensorizeComparator : public TensorizeComparator { |
123 | public: |
124 | explicit AutoTensorizeComparator(const IRModule& lhs_mod) |
125 | : TensorizeComparator(lhs_mod, /* assert_mode=*/false) {} |
126 | |
127 | private: |
128 | bool VisitExprDefault_(const Object* op, const PrimExpr& other) override; |
129 | bool VisitStmtDefault_(const Object* op, const Stmt& other) override; |
130 | |
131 | bool VisitStmt_(const BlockNode* op, const Stmt& other) override; |
132 | bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) override; |
133 | |
134 | bool VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) override; |
135 | |
136 | bool CompareBuffer(const Buffer& lhs, const Buffer& rhs) override; |
137 | template <typename T> |
138 | bool CompareBufferAccess(const T* lhs, const T* rhs); |
139 | |
140 | public: |
141 | // Additional information extracted from LHS (the workload) and RHS (the tensor intrin). |
142 | |
143 | /*! \brief Block iters in the LHS stmt. */ |
144 | std::vector<IterVar> lhs_iters_; |
145 | /*! \brief Block iters in the RHS stmt. */ |
146 | std::vector<IterVar> rhs_iters_; |
147 | /*! \brief The buffer and its access indices in the LHS stmt. */ |
148 | std::unordered_map<Buffer, Array<PrimExpr>, ObjectPtrHash, ObjectPtrEqual> |
149 | lhs_buffer_indices_map_; |
150 | /*! \brief The buffer and its access indices in the RHS stmt. */ |
151 | std::unordered_map<Buffer, Array<PrimExpr>, ObjectPtrHash, ObjectPtrEqual> |
152 | rhs_buffer_indices_map_; |
153 | /*! \brief Map from LHS buffer to RHS buffer */ |
154 | std::unordered_map<Buffer, Buffer, ObjectHash, ObjectEqual> lhs_buffer_map_; |
155 | |
156 | private: |
157 | /*! \brief The domain of the inner block iters. */ |
158 | Map<Var, arith::IntSet> inner_iter_dom_map_; |
159 | }; |
160 | |
161 | } // namespace tir |
162 | } // namespace tvm |
163 | |
164 | #endif // TVM_TIR_SCHEDULE_IR_COMPARATOR_H_ |
165 | |