1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#ifndef TVM_TIR_SCHEDULE_IR_COMPARATOR_H_
20#define TVM_TIR_SCHEDULE_IR_COMPARATOR_H_
21
22#include <string>
23#include <unordered_map>
24#include <utility>
25#include <vector>
26
27#include "./utils.h"
28
29namespace tvm {
30namespace tir {
31
32using ExprComparator = ExprFunctor<bool(const PrimExpr& n, const PrimExpr& other)>;
33using StmtComparator = StmtFunctor<bool(const Stmt& n, const Stmt& other)>;
34
/*!
 * \brief Deep comparison to check if two IR ASTs are equivalent for tensorization.
 *
 * The comparator is at the same time an expression functor and a statement functor. Every
 * visitor receives a node from the LHS (`n` / `op`) together with the corresponding node from the
 * RHS (`other`) and returns a bool; `true` means the two subtrees are considered equivalent.
 *
 * Only the node kinds that have a VisitStmt_/VisitExpr_ overload below are handled here; any
 * other kind goes to the functor base class's default handler, which this class does not
 * override.
 */
class TensorizeComparator : public ExprComparator, public StmtComparator {
 public:
  /*!
   * \brief Constructor of TensorizeComparator
   * \param lhs_mod The IRModule of the LHS. This is used for error reporting.
   * \param assert_mode Whether to raise an error if the two IR ASTs do not match.
   */
  explicit TensorizeComparator(IRModule lhs_mod, bool assert_mode = true)
      : lhs_mod_(std::move(lhs_mod)), assert_mode_(assert_mode) {}

  /*! \brief Compare LHS expression `n` against `other`; returns whether they match. */
  bool VisitExpr(const PrimExpr& n, const PrimExpr& other) override;
  /*! \brief Compare LHS statement `n` against `other`; returns whether they match. */
  bool VisitStmt(const Stmt& n, const Stmt& other) override;

  // Statement visitors: `op` is the LHS node, `other` is the RHS statement.
  bool VisitStmt_(const ForNode* op, const Stmt& other) override;
  bool VisitStmt_(const SeqStmtNode* op, const Stmt& other) override;
  bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) override;
  bool VisitStmt_(const BlockRealizeNode* op, const Stmt& other) override;
  bool VisitStmt_(const BlockNode* op, const Stmt& other) override;

  // Expression visitors: `op` is the LHS node, `other` is the RHS expression.
  bool VisitExpr_(const AddNode* op, const PrimExpr& other) override;
  bool VisitExpr_(const SubNode* op, const PrimExpr& other) override;
  bool VisitExpr_(const MulNode* op, const PrimExpr& other) override;
  bool VisitExpr_(const DivNode* op, const PrimExpr& other) override;
  bool VisitExpr_(const ModNode* op, const PrimExpr& other) override;
  bool VisitExpr_(const EQNode* op, const PrimExpr& other) override;
  bool VisitExpr_(const NENode* op, const PrimExpr& other) override;
  bool VisitExpr_(const LTNode* op, const PrimExpr& other) override;
  bool VisitExpr_(const LENode* op, const PrimExpr& other) override;
  bool VisitExpr_(const GTNode* op, const PrimExpr& other) override;
  bool VisitExpr_(const GENode* op, const PrimExpr& other) override;
  bool VisitExpr_(const AndNode* op, const PrimExpr& other) override;
  bool VisitExpr_(const OrNode* op, const PrimExpr& other) override;
  bool VisitExpr_(const MinNode* op, const PrimExpr& other) override;
  bool VisitExpr_(const MaxNode* op, const PrimExpr& other) override;
  bool VisitExpr_(const FloorDivNode* op, const PrimExpr& other) override;
  bool VisitExpr_(const FloorModNode* op, const PrimExpr& other) override;
  bool VisitExpr_(const IntImmNode* op, const PrimExpr& other) override;
  bool VisitExpr_(const FloatImmNode* op, const PrimExpr& other) override;
  bool VisitExpr_(const CastNode* op, const PrimExpr& other) override;
  bool VisitExpr_(const VarNode* op, const PrimExpr& other) override;
  bool VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) override;
  bool VisitExpr_(const SelectNode* op, const PrimExpr& other) override;

  /*!
   * \brief Map from RHS buffer to LHS buffer.
   * NOTE(review): this map is keyed with ObjectHash/ObjectEqual, whereas the other buffer-keyed
   * maps in this file use ObjectPtrHash/ObjectPtrEqual. Confirm the difference is intentional.
   */
  std::unordered_map<Buffer, Buffer, ObjectHash, ObjectEqual> rhs_buffer_map_;
  /*! \brief Base indices of the LHS buffer. */
  std::unordered_map<Buffer, std::vector<PrimExpr>, ObjectPtrHash, ObjectPtrEqual> buffer_indices_;

 protected:
  /*!
   * \brief Check whether LHS variable `lhs` corresponds to RHS variable `rhs`.
   * NOTE(review): presumably records/consults the pairing in `equal_map_`; confirm in the
   * implementation.
   */
  bool DefEqual(const Var& lhs, const Var& rhs);
  /*!
   * \brief Compare a LHS buffer against a RHS buffer.
   * Virtual so that subclasses (AutoTensorizeComparator overrides it) can customize the check.
   */
  virtual bool CompareBuffer(const Buffer& lhs, const Buffer& rhs);
  /*! \brief Compare two buffer regions (LHS vs RHS). */
  bool CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs);
  /*! \brief Compare a single (key, value) annotation entry (LHS vs RHS). */
  bool CompareAnnotation(const std::pair<String, ObjectRef>& lhs,
                         const std::pair<String, ObjectRef>& rhs);
  /*! \brief Compare two annotation maps (LHS vs RHS). */
  bool CompareAnnotationMap(const Map<String, ObjectRef>& lhs, const Map<String, ObjectRef>& rhs);
  /*!
   * \brief Compare two buffer accesses of the same node type T (LHS vs RHS).
   * T is presumably BufferLoadNode or BufferStoreNode; confirm against the call sites.
   * This is a non-virtual member template. A subclass that declares a template with the same name
   * hides this one instead of overriding it, so calls made from this class's visitors still
   * resolve here.
   */
  template <typename T>
  bool CompareBufferAccess(const T* lhs, const T* rhs);
  /*!
   * \brief Compare two arrays with the member function `cmp` of `Self`.
   * NOTE(review): presumably element-wise, with a size mismatch treated as inequality. Confirm in
   * the implementation.
   */
  template <typename T, typename Self, typename F>
  bool CompareArray(const Array<T>& lhs, const Array<T>& rhs, F Self::*cmp);
  /*! \brief Compare two ranges (LHS vs RHS). */
  bool CompareRange(const Range& lhs, const Range& rhs);
  /*! \brief Compare two iteration variables (LHS vs RHS). */
  bool CompareIterVar(const IterVar& lhs, const IterVar& rhs);
  /*!
   * \brief Report a mismatch with the given message.
   * NOTE(review): presumably appends to `error_messages_` and/or raises depending on
   * `assert_mode_`. Confirm in the implementation.
   */
  void EmitError(const std::string& error_message);

  /*! \brief IRModule of the LHS stmt. */
  IRModule lhs_mod_;
  /*! \brief Whether assertion mode is enabled. */
  bool assert_mode_;
  /*! \brief Whether it is visiting the scope block (the outermost block). */
  bool is_scope_block = true;
  /*! \brief The arithmetic analyzer. */
  arith::Analyzer analyzer_;
  /*! \brief Additional error messages. Only used when assert_mode is true. */
  std::vector<std::string> error_messages_;
  /*!
   * \brief Object remapping (e.g. variables) established between LHS and RHS during comparison.
   * NOTE(review): the key/value direction (LHS->RHS or RHS->LHS) is not visible in this header.
   * Confirm before relying on it.
   */
  std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> equal_map_;
};
112
/*!
 * \brief IR comparator for auto tensorization.
 *
 * This comparator extracts the correspondence between the IR of the workload (LHS) and the tensor
 * intrin (RHS). Unlike `TensorizeComparator`, it relaxes the requirements during comparison:
 * - It ignores the loop structure (the number of loops and their extents).
 * - It ignores buffer indices.
 *
 * It only requires the LHS and the RHS to have the same arithmetic operations and the same
 * dtype. Because of this, workloads that match the tensor intrin only after certain
 * transformations (e.g. im2col for conv2d) can still be auto-tensorized.
 */
class AutoTensorizeComparator : public TensorizeComparator {
 public:
  /*!
   * \brief Construct the comparator with assertion mode disabled.
   * A mismatch is therefore reported through the visitors' bool result instead of raising an
   * error.
   * \param lhs_mod The IRModule of the LHS (the workload).
   */
  explicit AutoTensorizeComparator(const IRModule& lhs_mod)
      : TensorizeComparator(lhs_mod, /* assert_mode=*/false) {}

 private:
  // Fallback handlers, invoked by the functor base classes for node kinds that have no dedicated
  // VisitExpr_/VisitStmt_ overload.
  bool VisitExprDefault_(const Object* op, const PrimExpr& other) override;
  bool VisitStmtDefault_(const Object* op, const Stmt& other) override;

  // Relaxed visitors for blocks and buffer accesses (see the class description).
  bool VisitStmt_(const BlockNode* op, const Stmt& other) override;
  bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) override;

  bool VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) override;

  /*! \brief Relaxed buffer comparison; overrides TensorizeComparator::CompareBuffer. */
  bool CompareBuffer(const Buffer& lhs, const Buffer& rhs) override;
  /*!
   * \brief Relaxed buffer-access comparison.
   * This template hides (it cannot override) TensorizeComparator::CompareBufferAccess. That is
   * presumably why the BufferStore/BufferLoad visitors above are overridden too: it makes the
   * base visitors' calls to the base template get replaced. Confirm in the implementation.
   */
  template <typename T>
  bool CompareBufferAccess(const T* lhs, const T* rhs);

 public:
  // Additional information extracted from LHS (the workload) and RHS (the tensor intrin).

  /*! \brief Block iters in the LHS stmt. */
  std::vector<IterVar> lhs_iters_;
  /*! \brief Block iters in the RHS stmt. */
  std::vector<IterVar> rhs_iters_;
  /*! \brief The buffer and its access indices in the LHS stmt. */
  std::unordered_map<Buffer, Array<PrimExpr>, ObjectPtrHash, ObjectPtrEqual>
      lhs_buffer_indices_map_;
  /*! \brief The buffer and its access indices in the RHS stmt. */
  std::unordered_map<Buffer, Array<PrimExpr>, ObjectPtrHash, ObjectPtrEqual>
      rhs_buffer_indices_map_;
  /*! \brief Map from LHS buffer to RHS buffer */
  std::unordered_map<Buffer, Buffer, ObjectHash, ObjectEqual> lhs_buffer_map_;

 private:
  /*! \brief The domain of the inner block iters. */
  Map<Var, arith::IntSet> inner_iter_dom_map_;
};
160
161} // namespace tir
162} // namespace tvm
163
164#endif // TVM_TIR_SCHEDULE_IR_COMPARATOR_H_
165