1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \brief Replace certain copy with copy intrinsics.
22 * \file copy_intrin_rewrite.cc
23 */
24#include <tvm/arith/analyzer.h>
25#include <tvm/arith/pattern.h>
26#include <tvm/runtime/registry.h>
27#include <tvm/tir/expr.h>
28#include <tvm/tir/stmt_functor.h>
29#include <tvm/tir/transform.h>
30
31#include "../../arith/pattern_match.h"
32#include "ir_utils.h"
33
34namespace tvm {
35namespace tir {
36
37using runtime::PackedFunc;
38
39class CopyIntrinInjector : public StmtMutator {
40 public:
41 CopyIntrinInjector(const std::string& pragma_key, const PackedFunc& flower_copy_fromto)
42 : pragma_key_(attr::pragma_scope_prefix + pragma_key),
43 flower_copy_fromto_(flower_copy_fromto) {}
44
45 Stmt VisitStmt_(const AttrStmtNode* op) final {
46 if (op->attr_key == pragma_key_) {
47 Stmt ret;
48 std::string error_info;
49 ICHECK(MatchCopyPattern(op->body, &ret, &error_info))
50 << "Cannot match copy pattern. The error is " << error_info << " The body is "
51 << op->body;
52 return ret;
53 }
54 return StmtMutator::VisitStmt_(op);
55 }
56
57 private:
58 bool MatchCopyPattern(Stmt stmt, Stmt* out, std::string* error_info) {
59 using namespace arith;
60 Stmt body = stmt;
61
62 // strip the loops
63 std::vector<const ForNode*> loops;
64 while (const ForNode* op = body.as<ForNode>()) {
65 if (!is_zero(op->min)) {
66 *error_info = "the 'min' value of body 'Fonode' is 0.";
67 return false;
68 }
69 loops.push_back(op);
70 body = op->body;
71 }
72 auto store = body.as<BufferStoreNode>();
73 if (store == nullptr) {
74 *error_info = "the body is not a 'BufferStoreNode'";
75 return false;
76 }
77 // Expr sel_cond, sel_true_value, sel_false_value;
78 // match select or if
79 PVar<PrimExpr> sel_cond, sel_true_value, sel_false_value;
80 bool has_cond = if_then_else(sel_cond, sel_true_value, sel_false_value).Match(store->value) ||
81 select(sel_cond, sel_true_value, sel_false_value).Match(store->value);
82
83 const CastNode* cast = store->value.as<CastNode>();
84 auto load = store->value.as<BufferLoadNode>();
85 if (0 == loops.size()) {
86 ICHECK(!has_cond);
87 }
88 // for now only support true condition matching
89 if (has_cond) {
90 load = sel_true_value.Eval().as<BufferLoadNode>();
91 }
92 // cast can be part of the pattern
93 if (cast != nullptr) {
94 load = cast->value.as<BufferLoadNode>();
95 }
96 if (load == nullptr) {
97 *error_info = "the 'LoadNode' of body is a nullptr.";
98 return false;
99 }
100 if (load->dtype.lanes() != 1) return false;
101 Array<Var> loop_vars;
102 for (const ForNode* op : loops) {
103 loop_vars.push_back(op->loop_var);
104 }
105 // TODO(Lunderberg): Move this pass to be before
106 // StorageFlatten/FlattenBuffer. That will simplify the
107 // implementation, since the pre-flattened indices/strides can be
108 // used directly.
109 ICHECK((store->indices.size() == 1) && (load->indices.size() == 1))
110 << "InjectDoubleBuffer expects flat 1-d buffers. "
111 << "Has StorageFlatten (TE-based schedules) or "
112 << "FlattenBuffer (TIR-based schedules) been run?";
113
114 Array<PrimExpr> store_strides = arith::DetectLinearEquation(store->indices[0], loop_vars);
115 Array<PrimExpr> load_strides = arith::DetectLinearEquation(load->indices[0], loop_vars);
116 if (load_strides.size() == 0 || store_strides.size() == 0) return false;
117 Array<PrimExpr> dst_shape;
118 const size_t loop_var_size = loop_vars.size();
119 if (loop_var_size == 0) {
120 dst_shape.push_back(make_const(DataType::Int(32), 1));
121 } else {
122 for (const ForNode* op : loops) {
123 dst_shape.push_back(op->extent);
124 }
125 }
126 Array<PrimExpr> src_shape = dst_shape;
127 Array<PrimExpr> pad_before, pad_after;
128 PrimExpr pad_value;
129 PrimExpr src_elem_offset = load_strides[loop_var_size];
130 if (has_cond) {
131 Array<PrimExpr> clip_bound = arith::DetectClipBound(sel_cond.Eval(), loop_vars);
132 pad_value = sel_false_value.Eval();
133 if (clip_bound.size() == 0) {
134 *error_info = "the size of clip bound is 0.";
135 return false;
136 }
137 ICHECK_EQ(src_shape.size(), loop_vars.size());
138 ICHECK_EQ(clip_bound.size(), loop_vars.size() * 2);
139 for (size_t i = 0; i < src_shape.size(); ++i) {
140 PrimExpr min_value = clip_bound[2 * i];
141 PrimExpr max_value = clip_bound[2 * i + 1];
142 DataType t = loop_vars[i].dtype();
143 PrimExpr svalue = src_shape[i];
144 if (min_value.defined()) {
145 PrimExpr pbefore = analyzer_.Simplify(Max(min_value, make_zero(t)));
146 src_elem_offset = src_elem_offset + pbefore * load_strides[i];
147 svalue = svalue - pbefore;
148 pad_before.push_back(pbefore);
149 } else {
150 pad_before.push_back(make_zero(t));
151 }
152 if (max_value.defined()) {
153 PrimExpr pafter = analyzer_.Simplify(
154 max(loops[i]->extent - max_value - make_const(t, 1), make_zero(t)));
155 svalue = svalue - pafter;
156 pad_after.push_back(pafter);
157 } else {
158 pad_after.push_back(make_zero(t));
159 }
160 src_shape.Set(i, analyzer_.Simplify(svalue));
161 }
162 src_elem_offset = analyzer_.Simplify(src_elem_offset);
163 }
164 ICHECK_EQ(load_strides.size(), store_strides.size());
165 ICHECK_EQ(load_strides.size(), loop_var_size + 1);
166 Array<PrimExpr> src_strides(load_strides.begin(), load_strides.begin() + loop_var_size);
167 Array<PrimExpr> dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size);
168 if (loop_var_size == 0) {
169 src_strides.push_back(make_const(DataType::Int(32), 1));
170 dst_strides.push_back(make_const(DataType::Int(32), 1));
171 }
172 Buffer dst = store->buffer;
173 {
174 auto writer = dst.CopyOnWrite();
175 writer->shape = dst_shape;
176 writer->strides = dst_strides;
177 writer->elem_offset = store_strides[loop_var_size];
178 }
179
180 Buffer src = load->buffer;
181 {
182 auto writer = src.CopyOnWrite();
183 writer->shape = src_shape;
184 writer->strides = src_strides;
185 writer->elem_offset = src_elem_offset;
186 }
187 *out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value);
188 if (!out->defined()) {
189 *error_info = "flower function did not return correct stmt";
190 return false;
191 }
192 return true;
193 }
194
195 // pragma key
196 std::string pragma_key_;
197 // function to lower copy intrinsics.
198 const PackedFunc& flower_copy_fromto_;
199 // arith analyzer
200 arith::Analyzer analyzer_;
201};
202
203Stmt InjectCopyIntrin(Stmt stmt, const std::string& pragma_key,
204 const PackedFunc& flower_copy_fromto) {
205 return CopyIntrinInjector(pragma_key, flower_copy_fromto)(std::move(stmt));
206}
207
208namespace transform {
209
210Pass InjectCopyIntrin(String pragma_key, PackedFunc flower_copy_fromto) {
211 auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
212 auto* n = f.CopyOnWrite();
213 n->body = CopyIntrinInjector(pragma_key, flower_copy_fromto)(std::move(n->body));
214 return f;
215 };
216 return CreatePrimFuncPass(pass_func, 0, "tir.InjectCopyIntrin", {});
217}
218
219TVM_REGISTER_GLOBAL("tir.transform.InjectCopyIntrin").set_body_typed(InjectCopyIntrin);
220
221} // namespace transform
222
223} // namespace tir
224} // namespace tvm
225