1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
 * \brief Replace certain copy loop nests with calls to copy intrinsics.
22 | * \file copy_intrin_rewrite.cc |
23 | */ |
24 | #include <tvm/arith/analyzer.h> |
25 | #include <tvm/arith/pattern.h> |
26 | #include <tvm/runtime/registry.h> |
27 | #include <tvm/tir/expr.h> |
28 | #include <tvm/tir/stmt_functor.h> |
29 | #include <tvm/tir/transform.h> |
30 | |
31 | #include "../../arith/pattern_match.h" |
32 | #include "ir_utils.h" |
33 | |
34 | namespace tvm { |
35 | namespace tir { |
36 | |
37 | using runtime::PackedFunc; |
38 | |
39 | class CopyIntrinInjector : public StmtMutator { |
40 | public: |
41 | CopyIntrinInjector(const std::string& pragma_key, const PackedFunc& flower_copy_fromto) |
42 | : pragma_key_(attr::pragma_scope_prefix + pragma_key), |
43 | flower_copy_fromto_(flower_copy_fromto) {} |
44 | |
45 | Stmt VisitStmt_(const AttrStmtNode* op) final { |
46 | if (op->attr_key == pragma_key_) { |
47 | Stmt ret; |
48 | std::string error_info; |
49 | ICHECK(MatchCopyPattern(op->body, &ret, &error_info)) |
50 | << "Cannot match copy pattern. The error is " << error_info << " The body is " |
51 | << op->body; |
52 | return ret; |
53 | } |
54 | return StmtMutator::VisitStmt_(op); |
55 | } |
56 | |
57 | private: |
58 | bool MatchCopyPattern(Stmt stmt, Stmt* out, std::string* error_info) { |
59 | using namespace arith; |
60 | Stmt body = stmt; |
61 | |
62 | // strip the loops |
63 | std::vector<const ForNode*> loops; |
64 | while (const ForNode* op = body.as<ForNode>()) { |
65 | if (!is_zero(op->min)) { |
66 | *error_info = "the 'min' value of body 'Fonode' is 0." ; |
67 | return false; |
68 | } |
69 | loops.push_back(op); |
70 | body = op->body; |
71 | } |
72 | auto store = body.as<BufferStoreNode>(); |
73 | if (store == nullptr) { |
74 | *error_info = "the body is not a 'BufferStoreNode'" ; |
75 | return false; |
76 | } |
77 | // Expr sel_cond, sel_true_value, sel_false_value; |
78 | // match select or if |
79 | PVar<PrimExpr> sel_cond, sel_true_value, sel_false_value; |
80 | bool has_cond = if_then_else(sel_cond, sel_true_value, sel_false_value).Match(store->value) || |
81 | select(sel_cond, sel_true_value, sel_false_value).Match(store->value); |
82 | |
83 | const CastNode* cast = store->value.as<CastNode>(); |
84 | auto load = store->value.as<BufferLoadNode>(); |
85 | if (0 == loops.size()) { |
86 | ICHECK(!has_cond); |
87 | } |
88 | // for now only support true condition matching |
89 | if (has_cond) { |
90 | load = sel_true_value.Eval().as<BufferLoadNode>(); |
91 | } |
92 | // cast can be part of the pattern |
93 | if (cast != nullptr) { |
94 | load = cast->value.as<BufferLoadNode>(); |
95 | } |
96 | if (load == nullptr) { |
97 | *error_info = "the 'LoadNode' of body is a nullptr." ; |
98 | return false; |
99 | } |
100 | if (load->dtype.lanes() != 1) return false; |
101 | Array<Var> loop_vars; |
102 | for (const ForNode* op : loops) { |
103 | loop_vars.push_back(op->loop_var); |
104 | } |
105 | // TODO(Lunderberg): Move this pass to be before |
106 | // StorageFlatten/FlattenBuffer. That will simplify the |
107 | // implementation, since the pre-flattened indices/strides can be |
108 | // used directly. |
109 | ICHECK((store->indices.size() == 1) && (load->indices.size() == 1)) |
110 | << "InjectDoubleBuffer expects flat 1-d buffers. " |
111 | << "Has StorageFlatten (TE-based schedules) or " |
112 | << "FlattenBuffer (TIR-based schedules) been run?" ; |
113 | |
114 | Array<PrimExpr> store_strides = arith::DetectLinearEquation(store->indices[0], loop_vars); |
115 | Array<PrimExpr> load_strides = arith::DetectLinearEquation(load->indices[0], loop_vars); |
116 | if (load_strides.size() == 0 || store_strides.size() == 0) return false; |
117 | Array<PrimExpr> dst_shape; |
118 | const size_t loop_var_size = loop_vars.size(); |
119 | if (loop_var_size == 0) { |
120 | dst_shape.push_back(make_const(DataType::Int(32), 1)); |
121 | } else { |
122 | for (const ForNode* op : loops) { |
123 | dst_shape.push_back(op->extent); |
124 | } |
125 | } |
126 | Array<PrimExpr> src_shape = dst_shape; |
127 | Array<PrimExpr> pad_before, pad_after; |
128 | PrimExpr pad_value; |
129 | PrimExpr src_elem_offset = load_strides[loop_var_size]; |
130 | if (has_cond) { |
131 | Array<PrimExpr> clip_bound = arith::DetectClipBound(sel_cond.Eval(), loop_vars); |
132 | pad_value = sel_false_value.Eval(); |
133 | if (clip_bound.size() == 0) { |
134 | *error_info = "the size of clip bound is 0." ; |
135 | return false; |
136 | } |
137 | ICHECK_EQ(src_shape.size(), loop_vars.size()); |
138 | ICHECK_EQ(clip_bound.size(), loop_vars.size() * 2); |
139 | for (size_t i = 0; i < src_shape.size(); ++i) { |
140 | PrimExpr min_value = clip_bound[2 * i]; |
141 | PrimExpr max_value = clip_bound[2 * i + 1]; |
142 | DataType t = loop_vars[i].dtype(); |
143 | PrimExpr svalue = src_shape[i]; |
144 | if (min_value.defined()) { |
145 | PrimExpr pbefore = analyzer_.Simplify(Max(min_value, make_zero(t))); |
146 | src_elem_offset = src_elem_offset + pbefore * load_strides[i]; |
147 | svalue = svalue - pbefore; |
148 | pad_before.push_back(pbefore); |
149 | } else { |
150 | pad_before.push_back(make_zero(t)); |
151 | } |
152 | if (max_value.defined()) { |
153 | PrimExpr pafter = analyzer_.Simplify( |
154 | max(loops[i]->extent - max_value - make_const(t, 1), make_zero(t))); |
155 | svalue = svalue - pafter; |
156 | pad_after.push_back(pafter); |
157 | } else { |
158 | pad_after.push_back(make_zero(t)); |
159 | } |
160 | src_shape.Set(i, analyzer_.Simplify(svalue)); |
161 | } |
162 | src_elem_offset = analyzer_.Simplify(src_elem_offset); |
163 | } |
164 | ICHECK_EQ(load_strides.size(), store_strides.size()); |
165 | ICHECK_EQ(load_strides.size(), loop_var_size + 1); |
166 | Array<PrimExpr> src_strides(load_strides.begin(), load_strides.begin() + loop_var_size); |
167 | Array<PrimExpr> dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size); |
168 | if (loop_var_size == 0) { |
169 | src_strides.push_back(make_const(DataType::Int(32), 1)); |
170 | dst_strides.push_back(make_const(DataType::Int(32), 1)); |
171 | } |
172 | Buffer dst = store->buffer; |
173 | { |
174 | auto writer = dst.CopyOnWrite(); |
175 | writer->shape = dst_shape; |
176 | writer->strides = dst_strides; |
177 | writer->elem_offset = store_strides[loop_var_size]; |
178 | } |
179 | |
180 | Buffer src = load->buffer; |
181 | { |
182 | auto writer = src.CopyOnWrite(); |
183 | writer->shape = src_shape; |
184 | writer->strides = src_strides; |
185 | writer->elem_offset = src_elem_offset; |
186 | } |
187 | *out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value); |
188 | if (!out->defined()) { |
189 | *error_info = "flower function did not return correct stmt" ; |
190 | return false; |
191 | } |
192 | return true; |
193 | } |
194 | |
195 | // pragma key |
196 | std::string pragma_key_; |
197 | // function to lower copy intrinsics. |
198 | const PackedFunc& flower_copy_fromto_; |
199 | // arith analyzer |
200 | arith::Analyzer analyzer_; |
201 | }; |
202 | |
203 | Stmt InjectCopyIntrin(Stmt stmt, const std::string& pragma_key, |
204 | const PackedFunc& flower_copy_fromto) { |
205 | return CopyIntrinInjector(pragma_key, flower_copy_fromto)(std::move(stmt)); |
206 | } |
207 | |
208 | namespace transform { |
209 | |
210 | Pass InjectCopyIntrin(String pragma_key, PackedFunc flower_copy_fromto) { |
211 | auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { |
212 | auto* n = f.CopyOnWrite(); |
213 | n->body = CopyIntrinInjector(pragma_key, flower_copy_fromto)(std::move(n->body)); |
214 | return f; |
215 | }; |
216 | return CreatePrimFuncPass(pass_func, 0, "tir.InjectCopyIntrin" , {}); |
217 | } |
218 | |
219 | TVM_REGISTER_GLOBAL("tir.transform.InjectCopyIntrin" ).set_body_typed(InjectCopyIntrin); |
220 | |
221 | } // namespace transform |
222 | |
223 | } // namespace tir |
224 | } // namespace tvm |
225 | |